probly.representation.distribution.torch_categorical.TorchCategoricalDistribution
- class probly.representation.distribution.torch_categorical.TorchCategoricalDistribution(unnormalized_probabilities: Tensor = <property object>)
Bases: TorchAxisProtected[Any], CategoricalDistribution[Tensor]
A categorical distribution stored as a torch tensor.
Shape: (…, num_classes). Any leading axes are batch axes; the last axis is the category axis.
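The following sketch uses plain torch, not the probly API, to illustrate the (…, num_classes) shape convention. It assumes that unnormalized_probabilities holds nonnegative weights that become probabilities when divided by their sum along the last axis. That normalization rule is an assumption; this page does not state it.

```python
import torch

# Hypothetical weights for a batch of 2 distributions over 3 classes.
# Shape (..., num_classes): batch shape (2,), num_classes = 3.
unnormalized = torch.tensor([[1.0, 1.0, 2.0],
                             [3.0, 0.0, 1.0]])

# Assumption: normalize by summing over the last (category) axis only.
probs = unnormalized / unnormalized.sum(dim=-1, keepdim=True)

batch_shape = probs.shape[:-1]  # torch.Size([2])
num_classes = probs.shape[-1]   # 3
```

Each row of probs now sums to 1, and the batch shape is everything except the last axis.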
- clone(*, memory_format: memory_format = torch.preserve_format) → Self
Return a copy of the array.
- cpu(memory_format: memory_format = torch.preserve_format) → Self
Move the array to the CPU.
- cuda(device: device | str | None = None, non_blocking: bool = False, memory_format: memory_format = torch.preserve_format) → Self
Move the array to the GPU.
- permitted_functions: ClassVar[set[Callable]] = {torch.mean, torch_average}
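This page does not say why these functions are permitted. The plain-torch sketch below only illustrates one relevant property: averaging probability vectors over a batch axis yields another probability vector (a mixture), whereas reducing over the category axis does not. Linking this property to permitted_functions is an interpretation, not documented behavior.

```python
import torch

probs = torch.tensor([[0.25, 0.25, 0.50],
                      [0.75, 0.00, 0.25]])

# Mean over the batch axis: an equal-weight mixture that still sums to 1.
mixture = torch.mean(probs, dim=0)  # tensor([0.5000, 0.1250, 0.3750])

# Mean over the category axis: one number per row (here 1/3 each),
# which is no longer a categorical distribution.
collapsed = torch.mean(probs, dim=-1)
```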
- permute(*dims: Size | int | tuple[int] | list[int]) → Self
Return a permuted version of the array.
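This page does not document whether permute may move the category axis. Assuming the category axis is protected and must stay last, the hypothetical helper permute_batch_axes (not part of probly) shows the idea on a plain tensor: only the batch axes are reordered.

```python
import torch

def permute_batch_axes(t: torch.Tensor, *dims: int) -> torch.Tensor:
    """Hypothetical helper: permute only the batch axes of a
    (..., num_classes) tensor, keeping the category axis last."""
    ndim = t.dim()
    # dims must be a permutation of the batch axes 0 .. ndim - 2.
    if sorted(dims) != list(range(ndim - 1)):
        raise ValueError("dims must be a permutation of the batch axes")
    # Append the category axis so it stays in the last position.
    return t.permute(*dims, ndim - 1)

x = torch.rand(2, 4, 3)          # batch shape (2, 4), 3 classes
y = permute_batch_axes(x, 1, 0)  # shape (4, 2, 3); y[i, j] == x[j, i]
```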
- protected_values(func: Callable | None = None) → dict[str, TorchProtectedValue] | None
Return all protected field values as-is.
The optional func argument is the torch function that triggered the call, passed for context. Implementations can use it to modify the returned values conditionally or to block access to them.
- sample(num_samples: int = 1, rng: Generator | None = None) → TorchSample[Tensor]
Draw num_samples samples from the categorical distribution using the torch backend.
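A minimal sketch of categorical sampling in plain torch, assuming rng is a torch.Generator. sample_categorical is a hypothetical helper, not probly's implementation, and putting the sample axis first is a choice of this sketch; the layout of the returned TorchSample is not documented on this page. torch.multinomial only accepts 1-D or 2-D input, so the batch axes are flattened first.

```python
from typing import Optional

import torch

def sample_categorical(probs: torch.Tensor, num_samples: int = 1,
                       rng: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical helper: draw class indices from (..., num_classes)
    probabilities. Returns a tensor of shape (num_samples, ...)."""
    batch_shape = probs.shape[:-1]
    # Flatten all batch axes: (batch, num_classes).
    flat = probs.reshape(-1, probs.shape[-1])
    # Sample with replacement so num_samples may exceed num_classes: (batch, num_samples).
    idx = torch.multinomial(flat, num_samples, replacement=True, generator=rng)
    # Move the sample axis first and restore the batch shape.
    return idx.T.reshape(num_samples, *batch_shape)

probs = torch.tensor([[0.0, 1.0, 0.0],
                      [0.5, 0.0, 0.5]])
s = sample_categorical(probs, num_samples=5, rng=torch.Generator().manual_seed(0))
```

Passing a seeded generator makes the draw reproducible; zero-probability classes are never returned.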
- size(dim: int | None = None) → int | Size
Return the size of the array along the given dimension.
- to_device(device: Literal['cpu'], /, *, stream: int | Any | None = None) → Self
Move the array to a device.
- type = 'categorical'
- unnormalized_probabilities: torch.Tensor