probly.representation.distribution.torch_dirichlet.TorchDirichletDistribution

class probly.representation.distribution.torch_dirichlet.TorchDirichletDistribution(alphas: Tensor)

Bases: TorchAxisProtected[Any], DirichletDistribution[TorchCategoricalDistribution]

A Dirichlet distribution stored as a torch tensor.

Shape: (…, num_classes). The last axis is the category dimension.
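The following sketch illustrates the shape convention. It uses PyTorch's own torch.distributions.Dirichlet (not part of probly) as a reference and assumes this class follows the same layout: every leading axis is a batch axis, and the last axis indexes the categories.

```python
import torch
from torch.distributions import Dirichlet

# A batch of 2 Dirichlet distributions over 3 categories: shape (2, 3).
alphas = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5]])

reference = Dirichlet(alphas)
batch_shape = tuple(alphas.shape[:-1])   # leading axes: one distribution each
num_classes = alphas.shape[-1]           # last axis: category dimension

print(batch_shape, num_classes)          # (2,) 3
print(tuple(reference.batch_shape), tuple(reference.event_shape))  # (2,) (3,)
```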

property T: Self

Reverses the order of the dimensions of the underlying array.

alphas: torch.Tensor

Dirichlet alpha (concentration) parameters, with shape (…, num_classes).

clone(*, memory_format: memory_format = torch.preserve_format) → Self

Return a copy of the array.

cpu(memory_format: memory_format = torch.preserve_format) → Self

Move the array to the CPU.

cuda(device: device | str | None = None, non_blocking: bool = False, memory_format: memory_format = torch.preserve_format) → Self

Move the array to the GPU.

detach() → Self

Return a detached version of the array.

property device: device

Device of the array.

property dtype: dtype

Data type of the array.

entropy() → ArrayLike

Compute the differential entropy of the Dirichlet distribution.
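The return value of entropy() is only typed as ArrayLike. For reference, the closed-form differential entropy of a Dirichlet with K categories and α_0 = Σ_i α_i is H = ln B(α) + (α_0 − K) ψ(α_0) − Σ_i (α_i − 1) ψ(α_i), where ln B(α) = Σ_i ln Γ(α_i) − ln Γ(α_0) and ψ is the digamma function. The plain-PyTorch sketch below evaluates this formula per batch element and checks the result against torch.distributions.Dirichlet.entropy. It is an assumption that probly's entropy() uses this formula and returns one value per distribution.

```python
import torch
from torch.distributions import Dirichlet

def dirichlet_entropy(alphas: torch.Tensor) -> torch.Tensor:
    """Differential entropy of each Dirichlet along the last (category) axis."""
    alpha0 = alphas.sum(dim=-1)                      # total concentration per distribution
    k = alphas.shape[-1]                             # number of categories
    log_beta = torch.lgamma(alphas).sum(dim=-1) - torch.lgamma(alpha0)
    return (
        log_beta
        + (alpha0 - k) * torch.digamma(alpha0)
        - ((alphas - 1.0) * torch.digamma(alphas)).sum(dim=-1)
    )

alphas = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5]], dtype=torch.float64)
h = dirichlet_entropy(alphas)                        # shape (2,): one value per distribution
assert torch.allclose(h, Dirichlet(alphas).entropy())
```

For the uniform Dirichlet with α = (1, 1, 1), both digamma terms vanish, so the entropy reduces to ln B(α) = −ln 2.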

classmethod from_tensor(alphas: Tensor | list[float], dtype: dtype | None = None) → TorchDirichletDistribution

Create a Dirichlet distribution from a tensor or list.

Parameters:
  • alphas – Dirichlet alpha parameters.

  • dtype – Desired tensor dtype.

Returns:

The created torch Dirichlet distribution.
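from_tensor accepts either a tensor or a list of floats, plus an optional dtype. The conversion step can be sketched with torch.as_tensor. Here, to_alpha_tensor is a hypothetical helper, and its positivity check is an extra safeguard that the documented method may or may not perform.

```python
import torch

def to_alpha_tensor(alphas, dtype=None):
    """Hypothetical helper: turn a tensor or list into an alpha tensor."""
    tensor = torch.as_tensor(alphas, dtype=dtype)
    # Extra sanity check (an assumption, not documented for from_tensor):
    # Dirichlet concentration parameters must be strictly positive.
    if not bool((tensor > 0).all()):
        raise ValueError("Dirichlet alphas must be positive")
    return tensor

from_list = to_alpha_tensor([1.0, 2.0, 3.0])
as_double = to_alpha_tensor(torch.tensor([1.0, 2.0, 3.0]), dtype=torch.float64)
print(from_list.dtype, as_double.dtype)  # torch.float32 torch.float64 (default dtype float32)
```

With probly installed, the documented entry point is called the same way, for example TorchDirichletDistribution.from_tensor([1.0, 2.0, 3.0], dtype=torch.float64).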

property mH: Self

The conjugate-transposed (adjoint) version of the underlying array.

property mT: Self

The transposed version of the underlying array.

property mean: TorchCategoricalDistribution

Return the expected categorical probabilities, α_i / Σ_j α_j, as a TorchCategoricalDistribution.
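The mean of a Dirichlet is α_i / α_0, where α_0 = Σ_j α_j. The plain-PyTorch sketch below computes this mean along the category axis and checks it against torch.distributions.Dirichlet.mean. It is an assumption that these are the probabilities wrapped in the returned TorchCategoricalDistribution.

```python
import torch
from torch.distributions import Dirichlet

alphas = torch.tensor([[1.0, 2.0, 3.0],
                       [2.0, 2.0, 4.0]])

# E[p_i] = alpha_i / sum_j alpha_j, normalised along the category axis.
expected_probs = alphas / alphas.sum(dim=-1, keepdim=True)

# first row: [1/6, 1/3, 1/2]; second row: [0.25, 0.25, 0.5]
print(expected_probs)
assert torch.allclose(expected_probs, Dirichlet(alphas).mean)
assert torch.allclose(expected_probs.sum(dim=-1), torch.ones(2))
```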

property ndim: int

Number of dimensions.

numpy(*, force: bool = False) → np.ndarray

Convert the Dirichlet alpha parameters to a numpy array.
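The force keyword presumably mirrors torch.Tensor.numpy(*, force=False). That correspondence is an assumption based on the matching name and default. In PyTorch itself, a plain numpy() call fails on a tensor that requires grad, whereas force=True first detaches the tensor and copies it to the CPU.

```python
import torch

alphas = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

try:
    alphas.numpy()                        # refused: tensor is part of an autograd graph
except RuntimeError as err:
    print("plain numpy() failed:", err)

# force=True detaches, moves to CPU, and resolves conj/neg bits before converting.
array = alphas.numpy(force=True)
print(type(array).__name__, array.dtype)  # ndarray float32
```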

permitted_functions: ClassVar[set[Callable]] = {torch.mean, torch.sum, torch_average}

permute(*dims: Size | int | tuple[int] | list[int]) → Self

Return a permuted version of the array.

classmethod primary_protected_name() → str

Return the name of the first protected field, in dictionary order.

protected_axes: ClassVar[dict[str, int]] = {'alphas': 1}

Number of trailing axes protected for each field; for alphas, this is the single category axis.

property protected_shape: tuple[int, ...]

Protected trailing shape of the primary field.

protected_value() → TorchProtectedValue

Return the primary protected value.

protected_values(func: Callable | None = None) → dict[str, TorchProtectedValue] | None

Return all protected field values as-is.

Optionally takes the torch function that triggered the call for context. This can be used to conditionally modify the returned values or prevent them from being accessed.

reshape(*shape: int | tuple[int, ...]) → Self

Return a copy with reshaped protected values.
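The documentation does not state which axes reshape touches. Assuming it reshapes only the batch axes and keeps the protected trailing axes intact (here the category axis, per protected_axes = {'alphas': 1}), the operation reduces to the sketch below. The names reshape_batch, n_protected and PROTECTED_AXES are hypothetical.

```python
import torch

# Mirrors the documented class attribute; assumed to mean
# "the last 1 axis of alphas is protected" (the category axis).
PROTECTED_AXES = {"alphas": 1}

def reshape_batch(values: torch.Tensor, *shape: int, n_protected: int) -> torch.Tensor:
    """Hypothetical helper: reshape only the leading (batch) axes."""
    trailing = values.shape[values.ndim - n_protected:]   # protected trailing shape
    return values.reshape(*shape, *trailing)

alphas = torch.rand(2, 3, 4) + 0.1          # batch shape (2, 3), 4 categories
flat = reshape_batch(alphas, 6, n_protected=PROTECTED_AXES["alphas"])
print(tuple(flat.shape))                    # (6, 4)
```

Slicing with values.ndim - n_protected, rather than -n_protected, keeps the helper correct when no axes are protected.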

resolve_conj() → Self

Return a version of the array with any conjugate operations resolved.

resolve_neg() → Self

Return a version of the array with any negation operations resolved.

sample(num_samples: int = 1) → TorchSample[TorchCategoricalDistribution]

Sample categorical distributions from the Dirichlet distribution.
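Each draw from a Dirichlet is a probability vector, that is, the parameters of a categorical distribution. The sketch below uses torch.distributions.Dirichlet. It is an assumption that probly places the sample axis first inside the returned TorchSample.

```python
import torch
from torch.distributions import Dirichlet

alphas = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5]])
num_samples = 4

# Each draw is a probability vector over the 3 categories.
draws = Dirichlet(alphas).sample((num_samples,))
print(tuple(draws.shape))                 # (4, 2, 3): samples, batch, categories
assert torch.allclose(draws.sum(dim=-1), torch.ones(num_samples, 2))
assert bool((draws >= 0).all())
```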

property shape: tuple[int, ...]

Shape of the array.

size(dim: int | None = None) → int | Size

Return the size of the array along the given dimension.

to(*args: Any, **kwargs: Any) → Self

Move and/or cast the tensor, mirroring torch.Tensor.to.

to_device(device: Literal['cpu'], /, *, stream: int | Any | None = None) → Self

Move the array to a device.

transpose(dim0: int, dim1: int) → Self

Return a transposed version of the array.

type = 'dirichlet'

with_protected_value(value: TorchProtectedValue) → Self

Return a copy with a replaced primary protected value.

with_protected_values(values: dict[str, TorchProtectedValue]) → Self

Return a copy with updated protected field values.
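Both with_protected_value and with_protected_values are documented as returning copies. The sketch below shows that copy-on-write pattern, assuming a dataclass-style container. DirichletParams is a hypothetical stand-in, not probly's class.

```python
import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class DirichletParams:
    """Hypothetical stand-in holding a single protected field."""
    alphas: torch.Tensor

    def with_protected_values(self, values: dict[str, torch.Tensor]) -> "DirichletParams":
        # Build a new instance; the original object is never mutated.
        return dataclasses.replace(self, **values)

    def with_protected_value(self, value: torch.Tensor) -> "DirichletParams":
        # "alphas" is the primary (and only) protected field here.
        return self.with_protected_values({"alphas": value})

original = DirichletParams(torch.tensor([1.0, 2.0, 3.0]))
scaled = original.with_protected_value(original.alphas * 2)
print(original.alphas.tolist(), scaled.alphas.tolist())  # [1.0, 2.0, 3.0] [2.0, 4.0, 6.0]
```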