Source code for probly.representation.distribution.mixture_gaussian

"""Mixture distribution for Gaussian components."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Self

import numpy as np

from probly.representation.sampling.array_sample import ArraySample

if TYPE_CHECKING:
    from collections.abc import Sequence

    from numpy.typing import DTypeLike

    from probly.representation.distribution.array_gaussian import ArrayGaussian


[docs] @dataclass(frozen=True, slots=True, weakref_slot=True) class ArrayGaussianMixture: """Gaussian mixture.""" components: Sequence[ArrayGaussian] weights: np.ndarray def __post_init__(self) -> None: """Validate and normalize the mixture weights.""" w = np.asarray(self.weights, dtype=float) if w.ndim != 1: msg = "weights must be 1D -> (K,)." raise ValueError(msg) if len(self.components) != w.shape[0]: msg = "for every components there must be just one weights." raise ValueError(msg) if np.any(w < 0): msg = "weights must be non-negative." raise ValueError(msg) s = w.sum() if not np.isclose(s, 1.0): w = w / s object.__setattr__(self, "weights", w) def __array_namespace__(self) -> Any: # noqa: ANN401 """Get the array namespace of the underlying array.""" return self.components[0].__array_namespace__() @property def dtype(self) -> DTypeLike: """The data type of the underlying array.""" return self.components[0].dtype @property def device(self) -> str: """The device of the underlying array.""" return self.components[0].device @property def ndim(self) -> int: """The number of dimensions of the underlying array.""" return self.components[0].ndim @property def shape(self) -> tuple[int, ...]: """The shape of the underlying array.""" return self.components[0].shape @property def size(self) -> int: """The total number of elements in the underlying array.""" return self.components[0].size @property def T(self) -> Self: # noqa: N802 """The transposed version of the mixture components.""" return type(self)( components=[c.T for c in self.components], weights=self.weights, )
[docs] def sample( self, num_samples: int, rng: np.random.Generator | None = None, ) -> ArraySample: """Draw samples from the Gaussian mixture. Returns an ArraySample.""" if rng is None: rng = np.random.default_rng() k = len(self.components) weights = self.weights comp_idx = rng.choice(k, size=num_samples, p=weights) reference_comp = self.components[0] reference_array = reference_comp.sample(1).array out_shape = (num_samples, *reference_array.shape[1:]) result = np.empty(out_shape, dtype=reference_array.dtype) for k, component in enumerate(self.components): indices_for_component = comp_idx == k num_samples_for_component = int(indices_for_component.sum()) if num_samples_for_component == 0: continue samples_for_component = component.sample(num_samples_for_component, rng=rng).array result[indices_for_component] = samples_for_component return ArraySample(array=result, sample_axis=0)