probly.conformal_prediction.scores.saps.flax

Flax/JAX implementation for SAPS scores.

Functions

saps_score_jax(probs, lambda_val, u)

Compute SAPS Nonconformity Score for JAX arrays.

probly.conformal_prediction.scores.saps.flax.saps_score_jax(probs, lambda_val, u)

Compute the SAPS nonconformity score for a JAX array of softmax probabilities.

Parameters:
  • probs (Array) – 1D array with softmax probabilities.

  • lambda_val (float) – Lambda value for SAPS.

  • u (Array) – Optional random value in [0, 1). If None, it is generated from key.

  • key – JAX random key used to generate u when u is None.

Returns:

SAPS nonconformity score.

Return type:

float
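
Example:

The following is a minimal sketch of how a SAPS score is commonly defined, not the probly implementation. It assumes the usual formulation: the top-ranked label scores u times the maximum probability, and every other label scores the maximum probability plus a rank-based penalty scaled by lambda_val. The function name saps_score_sketch and its label argument are hypothetical; the documented signature above does not show how the label is selected.

```python
import jax
import jax.numpy as jnp


def saps_score_sketch(probs, label, lambda_val, u):
    """Hypothetical SAPS score for one class label (not the probly API)."""
    p_max = jnp.max(probs)
    # 1-based rank of `label` when classes are sorted by descending
    # probability. Tied classes share the same rank in this sketch.
    rank = jnp.sum(probs > probs[label]) + 1
    # Top-ranked label: a random fraction of the maximum probability.
    # Any other label: maximum probability plus a rank-based penalty.
    return jnp.where(
        rank == 1,
        u * p_max,
        p_max + (rank - 2 + u) * lambda_val,
    )


probs = jnp.array([0.6, 0.3, 0.1])

# Draw u in [0, 1) from a JAX random key, as the parameter notes describe.
key = jax.random.PRNGKey(0)
u = jax.random.uniform(key)

# Score of the most probable class: u * 0.6.
top_score = saps_score_sketch(probs, 0, 0.1, u)

# Score of the second-ranked class with a fixed u = 0.5:
# 0.6 + (2 - 2 + 0.5) * 0.1, which is approximately 0.65.
second_score = saps_score_sketch(probs, 1, 0.1, 0.5)
```

Passing u explicitly keeps the score deterministic, which is convenient in tests. Drawing it from a key gives the randomized score.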