probly.representation.sample.jax.JaxArraySample

class probly.representation.sample.jax.JaxArraySample(array: Array, sample_axis: int, weights: Array | None = None)

Bases: Sample[Array]

A sample implementation for JAX arrays.

property T: Array

The transpose of the underlying array (all axes reversed).

array: Array
concat(other: Sample[Array]) → Self

Append another sample to this sample.
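A minimal plain-JAX sketch of what appending can amount to. It assumes that both samples use the same sample axis and that `concat` joins the underlying arrays along that axis. The helper `concat_arrays` is hypothetical, not part of probly, and weights are not handled here.

```python
import jax.numpy as jnp

def concat_arrays(a, b, sample_axis):
    """Hypothetical helper: join two sample arrays along the sample axis (assumption)."""
    return jnp.concatenate([a, b], axis=sample_axis)

# Two samples with 4 and 2 draws of a 3-dimensional prediction (sample axis 0).
first = jnp.ones((4, 3))
second = jnp.zeros((2, 3))

combined = concat_arrays(first, second, sample_axis=0)
print(combined.shape)  # (6, 3)
```

The draws from `first` occupy the leading positions of the sample axis, followed by those from `second`.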

copy() → Self

Create a copy of the JaxArraySample.

Returns:

A copy of the JaxArraySample.

property device: Any

The device of the underlying array.

property dtype: DTypeLike

The data type of the underlying array.

classmethod from_iterable(samples: Iterable[jax.Array], weights: Iterable[float] | None = None, sample_axis: SampleAxis = 'auto', dtype: DTypeLike | None = None) → Self

Create a JaxArraySample from an iterable of samples.

Parameters:
  • samples – The predictions to create the sample from.

  • weights – Optional weights for the samples.

  • sample_axis – The dimension along which samples are organized.

  • dtype – Desired data type of the array.

Returns:

The created JaxArraySample.
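The plain-JAX sketch below shows how an iterable of per-sample predictions can be combined into a single array with a dedicated sample axis. It assumes the samples are stacked along a new leading axis (sample axis 0) and cast to the requested dtype. How probly resolves `sample_axis='auto'` is not shown here.

```python
import jax.numpy as jnp

# Three predictions of shape (2,), e.g. one per ensemble member.
predictions = [jnp.array([0.1, 0.9]), jnp.array([0.2, 0.8]), jnp.array([0.3, 0.7])]

# Stack along a new leading axis so that axis 0 indexes the samples (assumption).
stacked = jnp.stack(predictions, axis=0).astype(jnp.float32)

# Optional per-sample weights: one entry per element of `predictions`.
weights = jnp.asarray([0.5, 0.25, 0.25])

print(stacked.shape)  # (3, 2)
```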

classmethod from_sample(sample: Sample[jax.Array], sample_axis: SampleAxis = 'auto', dtype: DTypeLike | None = None) → Self

Create a new JaxArraySample from an existing Sample.

Parameters:
  • sample – The sample to create the new sample from.

  • sample_axis – The dimension along which samples are organized.

  • dtype – Desired data type of the array.

Returns:

The created JaxArraySample.

property is_weighted: bool

Return whether the samples are weighted.

property mT: Array

The matrix transpose of the underlying array, with only the last two axes swapped.
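For arrays with more than two dimensions, `T` and `mT` differ. The example below uses the JAX array attributes of the same names and assumes the two properties forward to them:

```python
import jax.numpy as jnp

x = jnp.zeros((5, 3, 2))  # e.g. 5 samples of 3x2 matrices along axis 0

print(x.T.shape)   # (2, 3, 5): all axes reversed
print(x.mT.shape)  # (5, 2, 3): only the last two axes swapped
```

With the sample axis at position 0, `T` moves it to the last position, whereas `mT` leaves it in place.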

move_sample_axis(new_sample_axis: int) → JaxArraySample

Return a new JaxArraySample with the sample dimension moved to new_sample_axis.

Parameters:

new_sample_axis – The axis to which the sample dimension is moved.

Returns:

A new JaxArraySample with the sample dimension moved.
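In plain JAX, moving an axis is done with `jnp.moveaxis`. The sketch assumes `move_sample_axis` applies the equivalent operation to the underlying array and records the new position as the sample axis.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(4, 3, 2)  # 4 samples along axis 0
sample_axis = 0
new_sample_axis = 2

# Move the sample dimension to the last position.
moved = jnp.moveaxis(x, sample_axis, new_sample_axis)
print(moved.shape)  # (3, 2, 4)

# Each sample is unchanged; it is just indexed along a different axis.
assert bool((moved[:, :, 1] == x[1]).all())
```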

property ndim: int

The number of dimensions of the underlying array.

sample_axis: int
sample_mean() → Array

Compute the mean of the sample.
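The computation can be sketched in plain JAX. This sketch makes two assumptions that the text above does not state: the mean is taken over the sample axis, and a weighted sample uses its weights as `jnp.average` does. The weights here are hypothetical example values.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 samples, sample axis 0
w = jnp.array([1.0, 1.0, 2.0])  # hypothetical per-sample weights

unweighted = jnp.mean(x, axis=0)               # [3.0, 4.0]
weighted = jnp.average(x, axis=0, weights=w)   # [3.5, 4.5]
```

With weights, the third sample counts twice. For example, the first entry becomes (1 + 3 + 2·5) / 4 = 3.5.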

property sample_size: int

Return the number of samples.
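Assuming `sample_size` is the length of the underlying array along `sample_axis`, you can read it directly from the shape:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 10, 4))
sample_axis = 1  # samples are stored along the second dimension

sample_size = x.shape[sample_axis]
print(sample_size)  # 10
```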

sample_std(ddof: int = 0) → Array

Compute the standard deviation of the sample.

sample_var(ddof: int = 0) → Array

Compute the variance of the sample.
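Under the NumPy convention, `ddof` ("delta degrees of freedom") changes the divisor of the summed squared deviations to N - ddof, where N is the number of samples. The plain-JAX sketch below assumes `sample_var` and `sample_std` follow this convention over the sample axis of an unweighted sample.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])  # 4 scalar samples, sample axis 0

var_pop = jnp.var(x, axis=0)               # divides by N = 4:     1.25
var_unbiased = jnp.var(x, axis=0, ddof=1)  # divides by N - 1 = 3: about 1.667
std_unbiased = jnp.std(x, axis=0, ddof=1)  # square root of var_unbiased
```

The default `ddof=0` gives the population variance. `ddof=1` gives the usual unbiased estimate.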

property samples: Array

Return the samples as a JAX array.

property shape: tuple[int, ...]

The shape of the underlying array.

property size: int

The total number of elements in the underlying array.
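Because these properties describe the underlying array, they are related in the usual way for JAX arrays. `ndim` is the length of `shape`, and `size` is the product of its entries:

```python
import math

import jax.numpy as jnp

x = jnp.zeros((5, 3, 2))

assert x.ndim == len(x.shape) == 3
assert x.size == math.prod(x.shape) == 30
```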

to_device(device: Device | Sharding, *, stream: int | Any | None = None) → Self

Move the underlying array to the specified device.

Parameters:
  • device – The target device.

  • stream – Not implemented; passing any value other than None raises an error.

Returns:

A new JaxArraySample on the specified device.
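In plain JAX, you place an array on a device with `jax.device_put`. The sketch assumes `to_device` performs the equivalent operation on the underlying array. It targets the first available device, which is the CPU on a machine without accelerators.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(3, 2)
target = jax.devices()[0]  # first available device

moved = jax.device_put(x, target)
print(moved.devices())  # prints a set containing only `target`
```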

weights: Array | None