probly.representation.distribution.array_dirichlet.ArrayDirichletDistribution
- class probly.representation.distribution.array_dirichlet.ArrayDirichletDistribution(alphas: ndarray)
Bases: ArrayAxisProtected[ndarray], DirichletDistribution[ArrayCategoricalDistribution]
A Dirichlet distribution whose concentration parameters (alphas) are stored as a NumPy array.
Shape: (…, num_classes). The leading axes form the batch shape, with one distribution per index; the last axis is the category dimension.
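To make the shape convention concrete, here is a small NumPy sketch that uses a plain array of hypothetical concentration values rather than the class itself: the leading axes give the batch shape, and any reduction over classes uses axis=-1.

```python
import numpy as np

# Hypothetical alphas: a 2 x 3 batch of Dirichlet distributions over 4 classes.
alphas = np.arange(1, 25, dtype=float).reshape(2, 3, 4)

batch_shape = alphas.shape[:-1]  # (2, 3): one distribution per leading index
num_classes = alphas.shape[-1]   # 4: the last axis is the category dimension

# Total concentration of each distribution: reduce over the category axis only.
total_concentration = alphas.sum(axis=-1)  # shape (2, 3)
```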
- alphas: np.ndarray
Concentration parameters, one per category along the last axis.
- astype(dtype: DTypeLike, order: Order = 'K', casting: Literal['no', 'equiv', 'safe', 'same_kind', 'unsafe'] = 'unsafe', subok: bool = True, copy: bool = True) → Self
Copy of the array, cast to a specified type.
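The astype signature appears to mirror numpy.ndarray.astype, including its default casting='unsafe'. The plain-NumPy sketch below does not call probly; it only illustrates why that default deserves care for concentration parameters: a lossy float-to-int cast can silently turn a valid alpha into 0.

```python
import numpy as np

alphas = np.array([0.5, 1.5, 2.0])

# casting="unsafe" (NumPy's default) silently truncates: 0.5 -> 0, an invalid alpha.
truncated = alphas.astype(np.int64)

# casting="safe" refuses lossy float -> int conversions and raises TypeError instead.
try:
    alphas.astype(np.int64, casting="safe")
    safe_cast_failed = False
except TypeError:
    safe_cast_failed = True
```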
- property flags: ArrayFlagsLike
The flags of the array.
- classmethod from_array(alphas: np.ndarray | list, dtype: DTypeLike | None = None) → Self
Create a Dirichlet distribution from an array or list.
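A minimal sketch of the kind of conversion from_array performs, written as a hypothetical standalone helper (alphas_from_array is not part of probly). np.asarray accepts both lists and arrays and applies the optional dtype; the positivity check is an assumption based on the definition of a Dirichlet distribution, not documented behavior of from_array.

```python
import numpy as np

def alphas_from_array(alphas, dtype=None):
    """Hypothetical helper: convert a list or array of alphas to an ndarray."""
    arr = np.asarray(alphas, dtype=dtype)
    if arr.ndim < 1:
        raise ValueError("alphas needs at least one axis (the category axis)")
    # Assumption: Dirichlet concentration parameters must be strictly positive.
    if not np.all(arr > 0):
        raise ValueError("all alphas must be > 0")
    return arr
```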
- property mean: ArrayCategoricalDistribution
Return the mean of the Dirichlet distribution as a categorical distribution.
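The mean of a Dirichlet distribution has a closed form, E[p_i] = alpha_i / sum_j alpha_j, computed independently for each distribution in the batch. Below is a plain-NumPy sketch with a hypothetical helper, dirichlet_mean; the actual property wraps its result in an ArrayCategoricalDistribution, which this sketch omits.

```python
import numpy as np

def dirichlet_mean(alphas):
    """Hypothetical helper: alpha_i / sum_j alpha_j along the last (category) axis."""
    alphas = np.asarray(alphas, dtype=float)
    # keepdims=True lets each distribution's total broadcast over its classes.
    return alphas / alphas.sum(axis=-1, keepdims=True)
```

For example, dirichlet_mean([1.0, 1.0, 2.0]) gives probabilities 0.25, 0.25 and 0.5, and every row of a batched result sums to 1.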
- permitted_functions: ClassVar[set[Callable]] = {<function average>, <function mean>, <function sum>}
- permitted_ufuncs: ClassVar[dict[np.ufunc, list[str]]] = {<ufunc 'add'>: ['__call__'], <ufunc 'subtract'>: ['__call__']}
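These two class variables read as allow-lists for NumPy's dispatch protocols (__array_function__ from NEP 18 and __array_ufunc__ from NEP 13): only the listed functions, and only the listed ufunc methods, are accepted. The function entries presumably refer to np.average, np.mean and np.sum. The class below is a hypothetical, simplified sketch of that pattern, not probly's implementation: it unwraps its data and forwards permitted calls, and returns NotImplemented for everything else so that NumPy raises a TypeError.

```python
import numpy as np

class AllowListedArray:
    """Hypothetical wrapper that only lets allow-listed NumPy calls through."""

    permitted_functions = {np.average, np.mean, np.sum}
    permitted_ufuncs = {np.add: ["__call__"], np.subtract: ["__call__"]}

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func not in self.permitted_functions:
            return NotImplemented  # NumPy turns this into a TypeError
        # Unwrap our instances and run the plain NumPy implementation.
        args = [a.data if isinstance(a, AllowListedArray) else a for a in args]
        return func(*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Both the ufunc and the method ("__call__", "reduce", ...) must be allowed.
        if method not in self.permitted_ufuncs.get(ufunc, []):
            return NotImplemented
        inputs = [x.data if isinstance(x, AllowListedArray) else x for x in inputs]
        return AllowListedArray(getattr(ufunc, method)(*inputs, **kwargs))
```

With this sketch, np.sum(x) and np.add(x, 1.0) work, while np.median(x), np.multiply(x, 2.0) and np.add.reduce(x) raise TypeError.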
- protected_values(func: Callable | None = None, method: str | None = None) → dict[str, ArrayProtectedValue] | None
Return all protected field values.
The values are preserved as-is and are not coerced to np.ndarray. Optionally takes the function (func) and method name (method) that triggered the call as context; this can be used to conditionally modify the returned values or to prevent them from being accessed.
- reshape(*shape: int | tuple[int, ...], order: str = 'C', copy: bool | None = None) → Self
Return a copy with reshaped protected values.
- sample(num_samples: int = 1, rng: Generator | None = None) → ArraySample[ArrayCategoricalDistribution]
Sample from the Dirichlet distribution (NumPy backend).
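A standard way to draw Dirichlet samples with NumPy, which also works for batched alphas (Generator.dirichlet takes a single one-dimensional alpha vector), is to draw independent Gamma(alpha_i, 1) variates and normalize them along the category axis. The sketch below is a hypothetical helper, not the library's sampler. In particular, placing the sample axis first is an assumption, since the layout of ArraySample is not described here.

```python
import numpy as np

def sample_dirichlet(alphas, num_samples=1, rng=None):
    """Hypothetical helper: Dirichlet samples via normalized Gamma draws."""
    rng = np.random.default_rng() if rng is None else rng
    alphas = np.asarray(alphas, dtype=float)
    # One Gamma(alpha_i, 1) draw per sample, per batch entry, per class;
    # shape=alphas broadcasts against the leading sample axis (an assumption).
    gammas = rng.gamma(shape=alphas, scale=1.0, size=(num_samples, *alphas.shape))
    # Normalizing over the category axis maps each draw onto the probability simplex.
    return gammas / gammas.sum(axis=-1, keepdims=True)
```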
- transpose(*axes: int | None) → Self
Return a transposed version of the distribution.
This method implicitly also provides full axis-tracking support for np.moveaxis and np.rollaxis, because those functions call the transpose method of custom array types.
- Parameters:
axes – The axes to transpose.
- Returns:
A transposed version of the distribution.
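The np.moveaxis and np.rollaxis behavior can be checked with any duck-array type. When no __array_function__ override intercepts them, both NumPy functions read a.ndim and then, in the cases shown, call a.transpose with a list of axes. The hypothetical AxisTracked class below (not part of probly) carries one label per axis, so the results show the labels following the data.

```python
import numpy as np

class AxisTracked:
    """Hypothetical duck array that keeps a label for every axis."""

    def __init__(self, data, labels):
        self.data = np.asarray(data)
        self.labels = tuple(labels)

    @property
    def ndim(self):
        return self.data.ndim

    def transpose(self, *axes):
        # Accept transpose(), transpose(None), transpose(1, 0) and transpose([1, 0]).
        if len(axes) == 1 and (axes[0] is None or isinstance(axes[0], (list, tuple))):
            axes = () if axes[0] is None else tuple(axes[0])
        if not axes:
            axes = tuple(reversed(range(self.ndim)))
        # Permute the data and the labels with the same axis order.
        return AxisTracked(self.data.transpose(axes), [self.labels[i] for i in axes])

x = AxisTracked(np.zeros((2, 3, 4)), ["batch", "time", "class"])
moved = np.moveaxis(x, 0, -1)   # calls x.transpose([1, 2, 0])
rolled = np.rollaxis(x, 2)      # calls x.transpose([2, 0, 1])
```

Here moved.labels is ('time', 'class', 'batch') and rolled.labels is ('class', 'batch', 'time'), matching the permuted data shapes (3, 4, 2) and (4, 2, 3).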
- type = 'dirichlet'