probly.conformal_prediction.scores.aps.flax

Flax/JAX implementation of APS (Adaptive Prediction Sets) nonconformity scores.

Functions

aps_score_jax(probs)

Compute APS scores for JAX arrays (keeping data on GPU/TPU).

probly.conformal_prediction.scores.aps.flax.aps_score_jax(probs)

Compute APS scores for JAX arrays (keeping data on GPU/TPU).

Parameters:

probs (Array)

Return type:

Array
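For orientation, the following is a minimal sketch of how non-randomized APS scores can be computed with pure jax.numpy operations. It is not the probly implementation. The function name aps_scores_sketch is hypothetical, and the sketch assumes probs has shape (n_samples, n_classes) with each row summing to 1. The score of a class is the total probability mass of every class ranked at or above it when classes are sorted by decreasing probability. Because only jax.numpy operations are used, the computation stays on the accelerator and can be wrapped in jax.jit.

```python
import jax
import jax.numpy as jnp


def aps_scores_sketch(probs: jax.Array) -> jax.Array:
    """Hypothetical, non-randomized APS scores for every class.

    Assumes ``probs`` has shape (n_samples, n_classes) and rows sum to 1.
    Not the probly implementation; illustrative only.
    """
    # Column indices that sort each row by decreasing probability.
    order = jnp.argsort(-probs, axis=-1)
    sorted_probs = jnp.take_along_axis(probs, order, axis=-1)
    # Entry j holds the mass of the top (j + 1) classes in sorted order.
    cumulative = jnp.cumsum(sorted_probs, axis=-1)
    # Invert the permutation so each score returns to its original class column.
    inverse = jnp.argsort(order, axis=-1)
    return jnp.take_along_axis(cumulative, inverse, axis=-1)


probs = jnp.array([[0.1, 0.6, 0.3],
                   [0.5, 0.2, 0.3]])
scores = jax.jit(aps_scores_sketch)(probs)
```

In the first row, class 1 is most likely, so its score is 0.6. Class 2 comes next with 0.6 + 0.3 = 0.9, and class 0 accumulates the full mass of 1.0. Randomized APS variants, which this sketch omits, additionally subtract a uniformly drawn fraction of each class's own probability to break ties.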