probly.representation.sample.jax.JaxArraySample
- class probly.representation.sample.jax.JaxArraySample(array: Array, sample_axis: int, weights: Array | None = None)
Bases: Sample[Array]
A sample implementation for JAX arrays.
- property T: Array
The transposed version of the underlying array.
- array: Array
- property dtype: DTypeLike
The data type of the underlying array.
- classmethod from_iterable(samples: Iterable[jax.Array], weights: Iterable[float] | None = None, sample_axis: SampleAxis = 'auto', dtype: DTypeLike | None = None) -> Self
Create a JaxArraySample from a sequence of samples.
- Parameters:
samples – The predictions to create the sample from.
weights – Optional weights for the samples.
sample_axis – The dimension along which samples are organized.
dtype – Desired data type of the array.
- Returns:
The created JaxArraySample.
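As a rough illustration of what organizing samples along a sample axis means, the following sketch stacks a few predictions with plain JAX. It is not the library's implementation: the `predictions` values are made up, the axis is chosen explicitly as 0, and how `'auto'` resolves the axis is not covered here.

```python
import jax.numpy as jnp

# Three hypothetical predictions, each of shape (2,).
predictions = [
    jnp.array([0.1, 0.9]),
    jnp.array([0.3, 0.7]),
    jnp.array([0.2, 0.8]),
]

# Assumption: samples are stacked along a new axis; axis 0 is picked explicitly.
sample_axis = 0
stacked = jnp.stack(predictions, axis=sample_axis, dtype=jnp.float32)

# Optional weights, one entry per stacked prediction.
weights = jnp.asarray([0.5, 0.25, 0.25])
```

With these inputs the stacked array has one entry along `sample_axis` per prediction, and `weights` has the same length as that axis.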
- classmethod from_sample(sample: Sample[jax.Array], sample_axis: SampleAxis = 'auto', dtype: DTypeLike | None = None) -> Self
Create a new Sample from an existing Sample.
- Parameters:
sample – The sample to create the new sample from.
sample_axis – The dimension along which samples are organized.
dtype – Desired data type of the array.
- Returns:
The created Sample.
- property mT: Array
The matrix-transposed version of the underlying array (only the last two axes are swapped).
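The difference between `T` and `mT` is easiest to see on a plain `jax.Array` with more than two dimensions. The sketch below uses only JAX itself and a made-up shape; how JaxArraySample tracks its sample axis across a transpose is not shown here.

```python
import jax.numpy as jnp

# Hypothetical batch: 4 samples, each a 2x3 matrix.
x = jnp.zeros((4, 2, 3))

# .T reverses the order of all axes.
full_transpose = x.T

# .mT swaps only the last two axes and leaves leading axes in place.
matrix_transpose = x.mT
```

Here `full_transpose` has shape `(3, 2, 4)`, while `matrix_transpose` has shape `(4, 3, 2)`.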
- move_sample_axis(new_sample_axis: int) -> JaxArraySample
Return a new JaxArraySample with the sample dimension moved to new_sample_axis.
- Parameters:
new_sample_axis – The new sample dimension.
- Returns:
A new JaxArraySample with the sample dimension moved.
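Moving a sample axis can be pictured with `jnp.moveaxis` on a raw array. This is a sketch under the assumption that the operation amounts to an axis move of the underlying data; the shapes are illustrative only.

```python
import jax.numpy as jnp

# Hypothetical array with the sample axis first: 4 samples of shape (2, 3).
array = jnp.arange(24).reshape(4, 2, 3)

# Move the sample axis from position 0 to the last position.
moved = jnp.moveaxis(array, 0, -1)

# The second sample is the same data, now indexed along the last axis.
second_before = array[1]
second_after = moved[..., 1]
```

The data is unchanged; only the position of the axis that indexes the samples differs.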
- property samples: Array
Return an iterator over the samples.
- to_device(device: Device | Sharding, *, stream: int | Any | None = None) -> Self
Move the underlying array to the specified device.
- Parameters:
device – The target device.
stream – Not implemented; passing a non-None value raises an error.
- Returns:
A new JaxArraySample on the specified device.
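For orientation, placing a plain JAX array on a device looks as follows. The sketch uses `jax.device_put` as a stand-in; whether `to_device` calls it internally is an assumption, and the target is simply the first device JAX reports.

```python
import jax
import jax.numpy as jnp

array = jnp.ones((4, 2))

# Pick the first available device (CPU on machines without an accelerator).
target = jax.devices()[0]

# device_put returns an array placed on the target device.
placed = jax.device_put(array, target)
```

The result holds the same values and shape as the input and reports `target` among its devices.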