probly.layers.flax

Flax layer implementations.

Classes

BatchEnsembleConv(*args, **kwargs)

Implements a BatchEnsemble convolutional layer.

BatchEnsembleLinear(*args, **kwargs)

Implements a BatchEnsemble Linear layer.

DropConnectLinear(*args, **kwargs)

Custom Linear layer with DropConnect applied to the weights during training, based on [ASSR20].

class probly.layers.flax.BatchEnsembleConv(*args, **kwargs)[source]

Bases: Conv

Implements a BatchEnsemble convolutional layer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

kernel_shape

Sequence[int], (in_features, out_features, kernel_size)

kernel

nnx.Param, weight matrix of the layer.

bias

nnx.Param, bias of the layer.

in_features

int, number of input features.

out_features

int, number of output features.

kernel_size

int or Sequence[int], size of the kernel.

strides

tp.Union[None, int, tp.Sequence[int]], representing the inter-window strides.

padding

flax.typing.PaddingLike, either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), the string 'REFLECT' (reflection across the padding boundary), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.

input_dilation

tp.Union[None, int, tp.Sequence[int]], giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.

kernel_dilation

tp.Union[None, int, tp.Sequence[int]], giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.

feature_group_count

int, if specified, divides the input features into groups.

use_bias

bool, whether to add bias to the output.

mask

typing.Optional[Array], optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

dtype

typing.Optional[flax.typing.Dtype], the dtype of the computation (default: infer from input and params).

param_dtype

flax.typing.Dtype, the dtype passed to parameter initializers.

precision

flax.typing.PrecisionLike, numerical precision of the computation; see jax.lax.Precision for details.

conv_general_dilated

flax.typing.DotGeneralT, dot product function.

promote_dtype

flax.typing.PromoteDtypeFn, function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.

preferred_element_type

flax.typing.Dtype, optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function. See jax.lax.dot for details.

num_members

int, number of batch ensemble members.

s

nnx.Param, rank-one factor for input features.

r

nnx.Param, rank-one factor for output features.
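
In the standard BatchEnsemble formulation (Wen et al., 2020), which the s and r attributes above correspond to, ensemble member i never stores a full kernel of its own: its effective kernel is the shared kernel W modulated elementwise by the rank-one matrix r_i s_i^T, i.e. W_i = W * (r_i s_i^T). For a convolution this is applied cheaply by scaling the input channels by s_i before the shared convolution and the output channels by r_i after it, so every member shares one convolution and adds only two small vectors of parameters.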

__call__(inputs)[source]

Forward pass of the BatchEnsembleConv layer.

Parameters:

inputs (Array) – jax.Array, the input of shape [B, *spatial_dims, in_features] or [E, B, *spatial_dims, in_features], where B is the batch size, E is the number of ensemble members (num_members), and spatial_dims are the n spatial dimensions of the convolution.

Returns:

jax.Array, output of shape [E, B, *spatial_dims, out_features].

Return type:

Array
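
A minimal usage sketch. The constructor signature is assumed to mirror nnx.Conv with an additional num_members argument, matching the attributes above; this page does not spell it out:

>>> from probly.layers.flax import BatchEnsembleConv
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> layer = BatchEnsembleConv(3, 8, kernel_size=(3, 3), num_members=4, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2, 32, 32, 3))   # [B, H, W, in_features]
>>> layer(x).shape                 # members are broadcast over the batch
(4, 2, 32, 32, 8)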

eval(**attributes)

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters:

**attributes – additional attributes passed to set_attributes.

iter_children()

Iterates over all child Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down.

iter_children creates a generator that yields the key and the Module instance, where the key is a string attribute name used to access the corresponding child Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...  print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
Return type:

Iterator[tuple[Key, Module]]

iter_modules()

Recursively iterates over all nested Modules of the current Module, including the current Module.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
Return type:

Iterator[tuple[tuple[Key, …], Module]]

perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)

Adds a zero-value variable (“perturbation”) to the intermediate value.

The gradient of value will be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.

Note

This creates extra dummy variables of the same size as value, and thus occupies more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.value.shape == (1, 3)   # same as the intermediate value
>>> graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=(0, 1))
... def grad_loss(params, perturbations, inputs, targets):
...   model = nnx.merge(graphdef, params, perturbations)
...   return jnp.mean((model(inputs) - targets) ** 2)

>>> (grads, perturbations) = grad_loss(params, perturbations, x, y)
>>> # `perturbations.xgrad.value` is the intermediate gradient
>>> assert not jnp.array_equal(perturbations.xgrad.value, jnp.zeros((1, 3)))
Parameters:
  • name (str) – A string denoting the Module attribute name for the perturbation value.

  • value (Any) – The value to take the intermediate gradient of.

  • variable_type (str | type[Variable[Any]]) – The Variable type for the stored perturbation. Defaults to nnx.Perturbation.

set_attributes(*filters, raise_if_not_found=True, **attributes)

Sets the attributes of nested Modules, including the current Module. If the attribute is not found in the Module, it is ignored.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters (filterlib.Filter) – Filters to select the Modules to set the attributes of.

  • raise_if_not_found (bool) – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes (tp.Any) – The attributes to set.

Return type:

None

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which can be useful to filter for in split(), state() and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

Alternatively, a custom init/reduce function can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
Parameters:
  • variable_type (type[Variable[B]] | str) – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name (str) – A string denoting the Module attribute name, where the sowed value is stored.

  • value (A) – The value to be stored.

  • reduce_fn (Callable[[B, A], B]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn (Callable[[], B]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Return type:

bool

train(**attributes)

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters:

**attributes – additional attributes passed to set_attributes.

class probly.layers.flax.BatchEnsembleLinear(*args, **kwargs)[source]

Bases: Linear

Implements a BatchEnsemble Linear layer.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

kernel

nnx.Param, weight matrix of the layer.

bias

nnx.Param, bias of the layer.

in_features

int, number of input features.

out_features

int, number of output features.

use_bias

bool, whether to add bias to the output.

dtype

typing.Optional[flax.typing.Dtype], the dtype of the computation (default: infer from input and params).

param_dtype

flax.typing.Dtype, the dtype passed to parameter initializers.

precision

flax.typing.PrecisionLike, numerical precision of the computation; see jax.lax.Precision for details.

dot_general

flax.typing.DotGeneralT, dot product function.

promote_dtype

flax.typing.PromoteDtypeFn, function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.

preferred_element_type

flax.typing.Dtype, optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function. See jax.lax.dot for details.

num_members

int, number of batch ensemble members.

s

nnx.Param, rank-one factor for input features.

r

nnx.Param, rank-one factor for output features.

__call__(inputs)[source]

Forward pass of the BatchEnsembleLinear layer.

Parameters:

inputs (Array) – jax.Array, the input of shape [B, in_features] or [E, B, in_features], where B is the batch size and E is the number of ensemble members (num_members).

Returns:

jax.Array, output of shape [E, B, out_features].

Return type:

Array
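
Because the per-member kernels are rank-one modulations of a shared kernel, the forward pass reduces to one shared matmul framed by two elementwise scalings. A pure-jnp sketch of that computation (W, s, r, and the shapes here are illustrative stand-ins, not the layer's internals):

>>> import jax.numpy as jnp
>>> E, B, D_in, D_out = 4, 2, 3, 5
>>> W = jnp.ones((D_in, D_out))   # shared kernel
>>> s = jnp.ones((E, D_in))       # per-member input factors
>>> r = jnp.ones((E, D_out))      # per-member output factors
>>> x = jnp.ones((E, B, D_in))
>>> y = ((x * s[:, None, :]) @ W) * r[:, None, :]
>>> y.shape
(4, 2, 5)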

eval(**attributes)

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters:

**attributes – additional attributes passed to set_attributes.

iter_children()

Iterates over all child Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down.

iter_children creates a generator that yields the key and the Module instance, where the key is a string attribute name used to access the corresponding child Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...  print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
Return type:

Iterator[tuple[Key, Module]]

iter_modules()

Recursively iterates over all nested Modules of the current Module, including the current Module.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
Return type:

Iterator[tuple[tuple[Key, …], Module]]

perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)

Adds a zero-value variable (“perturbation”) to the intermediate value.

The gradient of value will be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.

Note

This creates extra dummy variables of the same size as value, and thus occupies more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.value.shape == (1, 3)   # same as the intermediate value
>>> graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=(0, 1))
... def grad_loss(params, perturbations, inputs, targets):
...   model = nnx.merge(graphdef, params, perturbations)
...   return jnp.mean((model(inputs) - targets) ** 2)

>>> (grads, perturbations) = grad_loss(params, perturbations, x, y)
>>> # `perturbations.xgrad.value` is the intermediate gradient
>>> assert not jnp.array_equal(perturbations.xgrad.value, jnp.zeros((1, 3)))
Parameters:
  • name (str) – A string denoting the Module attribute name for the perturbation value.

  • value (Any) – The value to take the intermediate gradient of.

  • variable_type (str | type[Variable[Any]]) – The Variable type for the stored perturbation. Defaults to nnx.Perturbation.

set_attributes(*filters, raise_if_not_found=True, **attributes)

Sets the attributes of nested Modules, including the current Module. If the attribute is not found in the Module, it is ignored.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters (filterlib.Filter) – Filters to select the Modules to set the attributes of.

  • raise_if_not_found (bool) – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes (tp.Any) – The attributes to set.

Return type:

None

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which can be useful to filter for in split(), state() and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

Alternatively, a custom init/reduce function can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
Parameters:
  • variable_type (type[Variable[B]] | str) – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name (str) – A string denoting the Module attribute name, where the sowed value is stored.

  • value (A) – The value to be stored.

  • reduce_fn (Callable[[B, A], B]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn (Callable[[], B]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Return type:

bool

train(**attributes)

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters:

**attributes – additional attributes passed to set_attributes.

class probly.layers.flax.DropConnectLinear(*args, **kwargs)[source]

Bases: Module

Custom Linear layer with DropConnect applied to the weights during training, based on [ASSR20].

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

kernel

nnx.Param, weight matrix of the layer.

bias

nnx.Param, bias of the layer.

in_features

int, number of input features.

out_features

int, number of output features.

use_bias

bool, whether to add bias to the output.

dtype

typing.Optional[flax.typing.Dtype], the dtype of the computation (default: infer from input and params).

param_dtype

flax.typing.Dtype, the dtype passed to parameter initializers.

precision

flax.typing.PrecisionLike, numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

flax.typing.Initializer, initializer function for the weight matrix.

bias_init

flax.typing.Initializer, initializer function for the bias.

dot_general

flax.typing.DotGeneralT, dot product function.

promote_dtype

flax.typing.PromoteDtypeFn, function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.

preferred_element_type

flax.typing.Dtype, optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function. See jax.lax.dot for details.

rate

float, probability of dropping individual weights.

deterministic

bool, if False, the DropConnect mask is applied to the weights and the surviving weights are scaled by 1/(1 - rate); if True, no mask is applied and the layer behaves deterministically.

rng_collection

str, the rng collection name to use when requesting an rng key.

rngs

rnglib.Rngs or rnglib.RngStream or None, rng key.
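
DropConnect zeroes individual weights with probability rate and rescales the survivors so the expected pre-activation is unchanged. A sketch of that masking step with inverted scaling (a minimal illustration of the attribute descriptions above, not the layer's exact internals):

>>> import jax
>>> import jax.numpy as jnp
>>> rate = 0.5
>>> kernel = jnp.ones((3, 4))
>>> keep = jax.random.bernoulli(jax.random.key(0), 1.0 - rate, kernel.shape)
>>> masked = jnp.where(keep, kernel / (1.0 - rate), 0.0)  # inverted scaling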

__call__(inputs, *, deterministic=False, rngs=None)[source]

Forward pass of the DropConnectLinear layer.

Parameters:
  • inputs (Array) – jax.Array, input data.

  • deterministic (bool) – bool, if False, the DropConnect mask is applied to the weights; if True, no mask is applied and the layer computes a standard linear transform.

  • rngs (Rngs | RngStream | Array | None) – nnx.Rngs, nnx.RngStream or jax.Array, optional key used to generate the DropConnect mask.

Returns:

jax.Array, layer output.

Return type:

Array
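
A minimal call sketch. The constructor is assumed to mirror nnx.Linear with an additional rate argument, matching the attributes above, and the rng collection name is assumed to be 'dropout'; neither is confirmed by this page:

>>> from probly.layers.flax import DropConnectLinear
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> layer = DropConnectLinear(3, 5, rate=0.1, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2, 3))
>>> y_train = layer(x, deterministic=False, rngs=nnx.Rngs(dropout=1))  # stochastic
>>> y_eval = layer(x, deterministic=True)                              # no mask applied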

eval(**attributes)

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters:

**attributes – additional attributes passed to set_attributes.

iter_children()

Iterates over all child Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down.

iter_children creates a generator that yields the key and the Module instance, where the key is a string attribute name used to access the corresponding child Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...  print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
Return type:

Iterator[tuple[Key, Module]]

iter_modules()

Recursively iterates over all nested Modules of the current Module, including the current Module.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
Return type:

Iterator[tuple[tuple[Key, …], Module]]

perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)

Adds a zero-value variable (“perturbation”) to the intermediate value.

The gradient of value will be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.

Note

This creates extra dummy variables of the same size as value, and thus occupies more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.value.shape == (1, 3)   # same as the intermediate value
>>> graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=(0, 1))
... def grad_loss(params, perturbations, inputs, targets):
...   model = nnx.merge(graphdef, params, perturbations)
...   return jnp.mean((model(inputs) - targets) ** 2)

>>> (grads, perturbations) = grad_loss(params, perturbations, x, y)
>>> # `perturbations.xgrad.value` is the intermediate gradient
>>> assert not jnp.array_equal(perturbations.xgrad.value, jnp.zeros((1, 3)))
Parameters:
  • name (str) – A string denoting the Module attribute name for the perturbation value.

  • value (Any) – The value to take the intermediate gradient of.

  • variable_type (str | type[Variable[Any]]) – The Variable type for the stored perturbation. Defaults to nnx.Perturbation.

set_attributes(*filters, raise_if_not_found=True, **attributes)

Sets the attributes of nested Modules, including the current Module. If the attribute is not found in the Module, it is ignored.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters (filterlib.Filter) – Filters to select the Modules to set the attributes of.

  • raise_if_not_found (bool) – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes (tp.Any) – The attributes to set.

Return type:

None

set_mode(deterministic=None, **kwargs)[source]

Class method used by nnx.set_mode.

Parameters:

deterministic (bool | None) – if True, disables DropConnect masking.

Return type:

dict

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which can be useful to filter for in split(), state() and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()

Alternatively, a custom init/reduce function can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
Parameters:
  • variable_type (type[Variable[B]] | str) – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name (str) – A string denoting the Module attribute name, where the sowed value is stored.

  • value (A) – The value to be stored.

  • reduce_fn (Callable[[B, A], B]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn (Callable[[], B]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Return type:

bool

train(**attributes)

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters:

**attributes – additional attributes passed to set_attributes.