probly.representation.sampling.jax_sample

JAX sample implementation.

Classes

JaxArraySample(array, sample_axis)

A sample implementation for JAX arrays.

class probly.representation.sampling.jax_sample.JaxArraySample(array, sample_axis)[source]

Bases: Sample[Array]

A sample implementation for JAX arrays.

Parameters:
  • array (Array)

  • sample_axis (int)

classmethod from_iterable(samples, sample_axis='auto', dtype=None)[source]

Create a JaxArraySample from a sequence of samples.

Parameters:
  • samples (Iterable[jax.Array]) – The predictions to create the sample from.

  • sample_axis (SampleAxis) – The dimension along which samples are organized.

  • dtype (DTypeLike | None) – Desired data type of the array.

Returns:

The created JaxArraySample.

Return type:

Self
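
As a rough illustration of what organizing samples along an axis means, the sketch below uses plain `jax.numpy` rather than `JaxArraySample` itself. It assumes that `from_iterable` behaves roughly like `jnp.stack` along the sample axis, which this page does not state explicitly. The names `predictions` and `stacked` are made up for the example.

```python
import jax.numpy as jnp

# Three hypothetical predictions, each of shape (2,).
predictions = [
    jnp.array([0.1, 0.9]),
    jnp.array([0.2, 0.8]),
    jnp.array([0.3, 0.7]),
]

# Stack along a new leading axis; that axis then plays the role of the
# sample axis (assumption: this mirrors what from_iterable produces).
stacked = jnp.stack(predictions, axis=0)
print(stacked.shape)  # (3, 2)
```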

classmethod from_sample(sample, sample_axis='auto', dtype=None)[source]

Parameters:
  • sample (Sample[jax.Array])

  • sample_axis (SampleAxis)

  • dtype (DTypeLike | None)

Return type:

Self

concat(other)[source]

Parameters:

other (Sample[Array])

Return type:

Self
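
This page gives no description for `concat`. One plausible reading, shown here only as an assumption, is that two samples are joined along their shared sample axis. The sketch uses plain `jax.numpy`; `first`, `second`, and `combined` are hypothetical names.

```python
import jax.numpy as jnp

# Two hypothetical sample arrays with the sample axis at position 0.
first = jnp.zeros((3, 2))   # 3 samples
second = jnp.ones((4, 2))   # 4 samples

# Assumed behaviour: join along the sample axis, keeping other axes intact.
combined = jnp.concatenate([first, second], axis=0)
print(combined.shape)  # (7, 2)
```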

copy()[source]

Create a copy of the JaxArraySample.

Returns:

A copy of the JaxArraySample.

Return type:

Self

move_sample_axis(new_sample_axis)[source]

Return a new JaxArraySample with the sample dimension moved to new_sample_axis.

Parameters:

new_sample_axis (int) – The new sample dimension.

Returns:

A new JaxArraySample with the sample dimension moved.

Return type:

JaxArraySample
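
Moving the sample dimension can be pictured with `jnp.moveaxis` on a raw array. This is a sketch under the assumption that `move_sample_axis` performs an equivalent axis move on the underlying array; the shapes and names are illustrative only.

```python
import jax.numpy as jnp

# Hypothetical array: 5 samples on axis 0, each with a (2, 3) payload.
array = jnp.zeros((5, 2, 3))

# Move the sample axis from position 0 to the last position.
moved = jnp.moveaxis(array, 0, -1)
print(moved.shape)  # (2, 3, 5)
```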

sample_mean()[source]

Compute the mean of the sample.

Return type:

Array

sample_std(ddof=1)[source]

Compute the standard deviation of the sample.

Parameters:

ddof (int)

Return type:

Array

sample_var(ddof=1)[source]

Compute the variance of the sample.

Parameters:

ddof (int)

Return type:

Array
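
The three statistics above can be sketched with `jax.numpy` reductions over the sample axis. This assumes `ddof` has the usual NumPy meaning (the divisor is `N - ddof`, so the default `ddof=1` gives the unbiased variance estimate); the page itself does not spell this out.

```python
import jax.numpy as jnp

# Three hypothetical samples on axis 0, each of shape (2,).
samples = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
sample_axis = 0

mean = jnp.mean(samples, axis=sample_axis)        # [3., 4.]
var = jnp.var(samples, axis=sample_axis, ddof=1)  # [4., 4.]
std = jnp.std(samples, axis=sample_axis, ddof=1)  # [2., 2.]
```

With `ddof=0` the same data would give a variance of 8/3 per entry instead of 4, which is why the default matters for small sample sizes.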

to_device(device, *, stream=None)[source]

Move the underlying array to the specified device.

Parameters:
  • device (Device | Sharding) – The target device.

  • stream (int | Any | None) – Not implemented; passing a non-None value raises an error.

Returns:

A new JaxArraySample on the specified device.

Return type:

Self
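
For orientation, placing a raw JAX array on a device is commonly done with `jax.device_put`. Whether `to_device` uses this call internally is an assumption; the sketch only shows the underlying JAX operation on a plain array.

```python
import jax
import jax.numpy as jnp

array = jnp.arange(4.0)

# Pick the first available device (CPU on machines without accelerators).
target = jax.devices()[0]

# device_put returns an array placed on the target; the input is unchanged.
moved = jax.device_put(array, target)
print(moved.devices())
```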

property T: Array

The transposed version of the underlying array.

array: Array

property device: Any

The device of the underlying array.

property dtype: DTypeLike

The data type of the underlying array.

property mT: Array

The matrix-transposed version of the underlying array (only the last two axes are swapped).
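
The difference between `T` and `mT` only shows up for arrays with more than two dimensions. The sketch below uses a raw JAX array and assumes the two properties forward to the corresponding array-level operations; `jnp.swapaxes` stands in for the matrix transpose.

```python
import jax.numpy as jnp

array = jnp.zeros((4, 2, 3))

# Full transpose: reverses the order of all axes.
full_transpose = array.T
print(full_transpose.shape)  # (3, 2, 4)

# Matrix transpose: swaps only the last two axes.
matrix_transpose = jnp.swapaxes(array, -1, -2)
print(matrix_transpose.shape)  # (4, 3, 2)
```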

property ndim: int

The number of dimensions of the underlying array.

sample_axis: int

property sample_size: int

Return the number of samples.

property samples: Array

Return an iterator over the samples.

property shape: tuple[int, ...]

The shape of the underlying array.

property size: int

The total number of elements in the underlying array.