probly.conformal_prediction.scores.raps.flax

Flax/JAX implementation of RAPS (Regularized Adaptive Prediction Sets) scores.

Functions

raps_score_jax(probs[, lambda_reg, k_reg, ...])

Compute RAPS scores for JAX arrays.

probly.conformal_prediction.scores.raps.flax.raps_score_jax(probs, lambda_reg=0.1, k_reg=0, epsilon=0.01)

Compute RAPS scores for JAX arrays.

For each sample, classes are sorted by descending probability. The score for each class is the cumulative sum of the sorted probabilities up to that class, plus a rank-based regularization penalty and a small epsilon term.

Returned shape: (n_samples, n_classes)
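The description above can be illustrated with a minimal sketch. This is not the library's implementation: the function name raps_score_sketch is hypothetical, and the exact form of the penalty (assumed here to be lambda_reg times the number of ranks beyond k_reg), the way epsilon enters (assumed here to be a constant offset), the inclusive cumulative sum, and the mapping of scores back to the original class order are all assumptions made for illustration.

```python
import jax.numpy as jnp


def raps_score_sketch(probs, lambda_reg=0.1, k_reg=0, epsilon=0.01):
    """Illustrative RAPS-style score (hypothetical, not the probly code)."""
    probs = jnp.asarray(probs)
    # Sort classes by descending probability for each sample.
    order = jnp.argsort(-probs, axis=-1)
    sorted_probs = jnp.take_along_axis(probs, order, axis=-1)
    # Cumulative probability mass up to and including each sorted class.
    cumulative = jnp.cumsum(sorted_probs, axis=-1)
    # 1-based rank of each position in the sorted order.
    ranks = jnp.arange(1, probs.shape[-1] + 1)
    # ASSUMPTION: the penalty grows linearly for ranks beyond k_reg.
    penalty = lambda_reg * jnp.maximum(ranks - k_reg, 0)
    # ASSUMPTION: epsilon is added as a constant offset.
    sorted_scores = cumulative + penalty + epsilon
    # ASSUMPTION: scores are mapped back to the original class order.
    inverse = jnp.argsort(order, axis=-1)
    return jnp.take_along_axis(sorted_scores, inverse, axis=-1)


probs = jnp.array([[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]])
scores = raps_score_sketch(probs)  # shape (n_samples, n_classes) = (2, 3)
```

Under these assumptions, the most probable class in each row receives the lowest score, and less probable classes are pushed further out by both the accumulated probability mass and the rank penalty.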

Parameters: probs, lambda_reg (default 0.1), k_reg (default 0), epsilon (default 0.01)

Return type: Array