probly.conformal_prediction.scores.cqr.flax

Flax/JAX implementation of conformalized quantile regression (CQR) nonconformity scores.

Functions

cqr_score_jax(y_true, y_pred)

Compute CQR nonconformity scores for JAX arrays.

probly.conformal_prediction.scores.cqr.flax.cqr_score_jax(y_true, y_pred)

Compute CQR nonconformity scores for JAX arrays.

Parameters:
  • y_true (Array) – True targets as JAX array of shape (n_samples,).

  • y_pred (Array) – Predicted lower and upper quantiles as JAX array of shape (n_samples, 2).

Returns:

  • Array – One-dimensional array of nonconformity scores with shape (n_samples,).

Return type:

Array
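The entry above documents only shapes, not the formula. The sketch below assumes the standard CQR score from the conformalized quantile regression literature, s_i = max(lo_i - y_i, y_i - hi_i), with column 0 of y_pred taken as the lower quantile and column 1 as the upper. It is an illustration under those assumptions, not necessarily how probly implements cqr_score_jax; the name cqr_score_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def cqr_score_sketch(y_true: jax.Array, y_pred: jax.Array) -> jax.Array:
    """Hypothetical sketch of the standard CQR nonconformity score.

    Assumes y_true has shape (n_samples,) and y_pred has shape
    (n_samples, 2) with columns [lower quantile, upper quantile].
    """
    lower = y_pred[:, 0]
    upper = y_pred[:, 1]
    # Positive when y_true lies outside [lower, upper] (distance to the
    # violated bound); negative or zero when it lies inside (minus the
    # distance to the nearer bound).
    return jnp.maximum(lower - y_true, y_true - upper)


# Usage: three samples with predicted [lower, upper] quantile bands.
y_true = jnp.array([1.0, 5.0, 3.0])
y_pred = jnp.array([[0.0, 2.0], [1.0, 4.0], [2.5, 3.5]])
scores = cqr_score_sketch(y_true, y_pred)  # values: -1.0, 1.0, -0.5
```

Writing the score as an elementwise jnp.maximum keeps it vectorized and jit-compatible, so no Python loop over samples is needed and the output has the documented shape (n_samples,).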