probly.method.conformal.flax.FlaxConformalSetPredictor

class probly.method.conformal.flax.FlaxConformalSetPredictor(*args: Any, **kwargs: Any)[source]

Bases: _ConformalPredictorBase[In, Out], Module, Generic[In, Out]

Base Flax conformal wrapper that forwards __call__ to the wrapped model.

Initialize the flax conformal wrapper.

__call__(*args: object, **kwargs: object) Any[source]

Forward to the wrapped model.

calibrate(alpha: float, y_calib: Out, *calib_args: In.args, **calib_kwargs: In.kwargs) Self[source]

Calibrate the predictor using calibration data.

property conformal_quantile: float | None

Return the calibrated conformal quantile if available.
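The documentation does not state how this quantile is obtained. As a sketch of the standard split-conformal recipe (illustrative only, not probly's implementation), the calibrated quantile is the ceil((n + 1) * (1 - alpha))-th smallest of the n calibration nonconformity scores:

```python
import math

def split_conformal_quantile(scores: list[float], alpha: float) -> float:
    """Finite-sample-corrected empirical quantile used in split conformal
    prediction (illustrative helper, not part of probly)."""
    n = len(scores)
    rank = math.ceil((n + 1) * (1 - alpha))  # 1-based rank into sorted scores
    if rank > n:
        # Too few calibration points for this coverage level.
        return float("inf")
    return sorted(scores)[rank - 1]

# With 100 calibration scores 1..100 and alpha = 0.1,
# rank = ceil(101 * 0.9) = 91, so the quantile is 91.
q = split_conformal_quantile([float(s) for s in range(1, 101)], alpha=0.1)
```

Prediction sets then contain every candidate output whose nonconformity score is at most this quantile, which gives marginal coverage of at least 1 - alpha.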

eval(**attributes)[source]

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters:

**attributes – additional attributes passed to set_attributes.

iter_children() Iterator[tuple[Key, Module]][source]

Iterates over the immediate child Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down.

iter_modules() Iterator[tuple[tuple[Key, ...], Module]][source]

Recursively iterates over all nested Modules of the current Module, including the current Module itself.

non_conformity_score
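non_conformity_score is listed without a description. As a hedged illustration (these helpers are hypothetical, not probly's API), typical nonconformity scores are the absolute residual for regression and one minus the predicted probability of the true class for classification:

```python
def regression_score(y_true: float, y_pred: float) -> float:
    # Absolute residual: larger means the prediction was further off.
    return abs(y_true - y_pred)

def classification_score(probs: dict[str, float], true_label: str) -> float:
    # 1 - p(true class): small when the model is confident and correct.
    return 1.0 - probs[true_label]

r = regression_score(3.0, 2.5)                              # 0.5
c = classification_score({"cat": 0.8, "dog": 0.2}, "cat")   # roughly 0.2
```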
perturb(name: str, value: Any, variable_type: str | type[Variable[Any]] = <class 'flax.nnx.variablelib.Perturbation'>)[source]

Extract gradients of intermediate values during training.

Used with nnx.capture() to record intermediate values in the forward pass and their gradients in the backward pass. Returns the value plus whatever perturbation is stored under name in the current capture context, allowing gradient computation via nnx.grad.

The workflow has four steps:

1. Initialize perturbations with nnx.capture(model, nnx.Perturbation)
2. Run the model with nnx.capture(model, nnx.Intermediate, init=perturbations)
3. Take gradients with respect to the perturbations using nnx.grad
4. Combine the results with nnx.merge_state(perturb_grads, intermediates)

Note

This creates extra variables of the same size as value and thus occupies more memory. Use it only for debugging gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __call__(self, x):
...     x2 = self.perturb('grad_of_x', x)
...     return 3 * x2

>>> model = Model()
>>> x = 1.0

>>> # Step 1: Initialize perturbations
>>> forward = nnx.capture(model, nnx.Perturbation)
>>> _, perturbations = forward(x)

>>> # Steps 2-4: Capture gradients
>>> def train_step(model, perturbations, x):
...   def loss(model, perturbations, x):
...     return nnx.capture(model, nnx.Intermediate, init=perturbations)(x)
...   (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)
...   return nnx.merge_state(perturb_grads, sowed)

>>> metrics = train_step(model, perturbations, x)
>>> # metrics contains gradients of intermediate values
Parameters:
  • name – A string key for storing the perturbation value.

  • value – The intermediate value to capture gradients for. You must use the returned value (not the original) for gradient capturing to work.

  • variable_type – The Variable type for the stored perturbation. Default is nnx.Perturbation.

predictor: nnx.Module
set_attributes(*filters: filterlib.Filter, raise_if_not_found: bool = True, graph: bool | None = None, **attributes: tp.Any) None[source]

Sets the attributes of nested Modules including the current Module. If the attribute is not found in the Module, it is ignored.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes – The attributes to set.

sow(variable_type: type[Variable[B]] | str, name: str, value: A, reduce_fn: Callable[[B, A], B] = <function <lambda>>, init_fn: Callable[[], B] = <function <lambda>>) bool[source]

Store intermediate values during module execution for later extraction.

Used with nnx.capture() decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified name in a collection associated with variable_type.

By default, values are appended to a tuple, allowing multiple values to be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'features', x)
...     x = self.linear2(x)
...     return x

>>> # With the capture decorator, sow returns intermediates
>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   return model(x)
>>> result, intermediates = forward(model, jnp.ones(2))
>>> assert 'features' in intermediates

Custom init/reduce functions can be passed to control accumulation:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     return x
Parameters:
  • variable_type – The Variable type for the stored value. Typically Intermediate or a subclass is used.

  • name – A string key for storing the value in the collection.

  • value – The value to be stored.

  • reduce_fn – Function to combine existing and new values. Default appends to a tuple.

  • init_fn – Function providing initial value for first reduce_fn call. Default is an empty tuple.

train(**attributes)[source]

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters:

**attributes – additional attributes passed to set_attributes.