probly.representation.distribution.torch_dirichlet.TorchDirichletDistribution
- class probly.representation.distribution.torch_dirichlet.TorchDirichletDistribution(alphas: Tensor = <property object>)
Bases: TorchAxisProtected[Any], DirichletDistribution[TorchCategoricalDistribution]
A Dirichlet distribution whose alpha (concentration) parameters are stored as a torch tensor.
Shape: (…, num_classes). The last axis is the category dimension; any leading axes index a batch of Dirichlet distributions.
- alphas: torch.Tensor
The Dirichlet alpha parameters, with shape (…, num_classes).
- clone(*, memory_format: memory_format = torch.preserve_format) → Self
Return a copy of the distribution.
- cpu(memory_format: memory_format = torch.preserve_format) → Self
Move the distribution's alpha tensor to the CPU.
- cuda(device: device | str | None = None, non_blocking: bool = False, memory_format: memory_format = torch.preserve_format) → Self
Move the distribution's alpha tensor to a CUDA device (GPU).
- classmethod from_tensor(alphas: Tensor | list[float], dtype: dtype | None = None) → TorchDirichletDistribution
Create a Dirichlet distribution from a tensor or list.
- Parameters:
alphas – Dirichlet alpha parameters.
dtype – Desired tensor dtype.
- Returns:
The created torch Dirichlet distribution.
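The source does not state whether from_tensor validates its input. As a plain-Python sketch (no torch; the helper name check_alphas is hypothetical and not part of probly), the following shows the constraint any list passed as alphas must satisfy to describe a proper Dirichlet distribution: a non-empty category axis and every alpha_k finite and strictly positive.

```python
import math


def check_alphas(values):
    """Hypothetical helper: convert numbers to float alphas and validate them.

    A Dirichlet distribution requires every concentration parameter alpha_k > 0.
    """
    alphas = [float(v) for v in values]
    if not alphas:
        raise ValueError("alphas must contain at least one category")
    if any(not math.isfinite(a) or a <= 0.0 for a in alphas):
        raise ValueError("Dirichlet alpha parameters must be finite and strictly positive")
    return alphas


print(check_alphas([1, 2, 3]))  # [1.0, 2.0, 3.0]
```

In the documented API, the validated values would instead be held in a torch tensor of the requested dtype.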
- property mean: TorchCategoricalDistribution
Return the expected categorical probabilities, E[p_k] = alpha_k / sum_j alpha_j, as a TorchCategoricalDistribution.
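The mean of a Dirichlet distribution is its alpha vector normalized along the last (category) axis. A plain-Python sketch of that computation for a single alpha vector or a batch of them, without torch or probly (the function name dirichlet_mean is illustrative only):

```python
def dirichlet_mean(alphas):
    """Return alpha_k / sum_j alpha_j along the last axis of a nested list."""
    # Recurse through the leading (batch) axes until a 1-D alpha vector is reached.
    if alphas and isinstance(alphas[0], list):
        return [dirichlet_mean(row) for row in alphas]
    total = sum(alphas)
    return [a / total for a in alphas]


print(dirichlet_mean([1.0, 1.0, 2.0]))  # [0.25, 0.25, 0.5]
```

Each output row sums to 1, which is why the mean is naturally a categorical distribution over the same classes.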
- numpy(*, force: bool = False) → np.ndarray
Convert the Dirichlet alpha parameters to a numpy array.
- permitted_functions: ClassVar[set[Callable]] = {torch.mean, torch.sum, torch_average}
- permute(*dims: Size | int | tuple[int] | list[int]) → Self
Return the distribution with the dimensions of its alpha tensor permuted.
- protected_values(func: Callable | None = None) → dict[str, TorchProtectedValue] | None
Return all protected field values as-is.
The optional func argument is the torch function that triggered the call. It can be used to modify the returned values conditionally or to prevent them from being accessed.
- sample(num_samples: int = 1) → TorchSample[TorchCategoricalDistribution]
Sample categorical distributions from the Dirichlet distribution.
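The source does not describe how sample draws its values. A standard construction, shown here as a plain-Python sketch rather than probly's implementation, draws independent Gamma(alpha_k, 1) variates and normalizes them; each normalized vector is one categorical distribution. The function name sample_dirichlet and its rng parameter are hypothetical.

```python
import random


def sample_dirichlet(alphas, num_samples=1, rng=None):
    """Draw num_samples probability vectors from Dirichlet(alphas).

    Gamma-normalization: if g_k ~ Gamma(alpha_k, 1) independently,
    then g / sum(g) ~ Dirichlet(alpha).
    """
    if rng is None:
        rng = random.Random()
    samples = []
    for _ in range(num_samples):
        # random.gammavariate(shape, scale); scale 1.0 gives Gamma(alpha_k, 1).
        gammas = [rng.gammavariate(a, 1.0) for a in alphas]
        total = sum(gammas)
        samples.append([g / total for g in gammas])
    return samples


draws = sample_dirichlet([2.0, 3.0, 5.0], num_samples=4, rng=random.Random(0))
for p in draws:
    print([round(x, 3) for x in p])
```

Averaged over many draws, each component approaches the mean alpha_k / sum_j alpha_j. In the documented API the draws are returned as a TorchSample[TorchCategoricalDistribution] rather than as nested lists.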
- size(dim: int | None = None) → int | Size
Return the size of the alpha tensor along dimension dim, or its full shape as a torch.Size if dim is None.
- to_device(device: Literal['cpu'], /, *, stream: int | Any | None = None) → Self
Move the distribution to a device. Only 'cpu' is accepted.
- type = 'dirichlet'