probly.layers.flax.DropConnectLinear
- class probly.layers.flax.DropConnectLinear(*args: Any, **kwargs: Any)
Bases: Module
Custom Linear layer with DropConnect applied to weights during training, based on [ASSR20].
- kernel
nnx.Param, weight matrix of the layer.
- bias
nnx.Param, bias of the layer.
- in_features
int, number of input features.
- out_features
int, number of output features.
- use_bias
bool, whether to add bias to the output.
- dtype
typing.Optional[flax.typing.Dtype], the dtype of the computation (default: infer from input and params).
- param_dtype
flax.typing.Dtype, the dtype passed to parameter initializers.
- precision
flax.typing.PrecisionLike, numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init
flax.typing.Initializer, initializer function for the weight matrix.
- bias_init
flax.typing.Initializer, initializer function for the bias.
- dot_general
flax.typing.DotGeneralT, dot product function.
- promote_dtype
flax.typing.PromoteDtypeFn, function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (inputs, kernel, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
- preferred_element_type
flax.typing.Dtype, optional parameter that controls the data type output by the dot product. This argument is passed to the dot_general function; see jax.lax.dot for details.
- rate
float, probability of dropping individual weights.
- deterministic
bool, if False the weights are scaled by 1/(1-rate) and masked, whereas if True, no mask is applied and the layer behaves like a standard linear layer.
- rng_collection
str, the rng collection name to use when requesting an rng key.
- rngs
rnglib.Rngs or rnglib.RngStream or None, rng key.
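For intuition, the behavior controlled by rate and deterministic can be sketched in plain JAX. This is an illustrative reconstruction of DropConnect-style weight masking, not necessarily the layer's exact implementation:
>>> import jax
>>> import jax.numpy as jnp
>>> rate = 0.3
>>> kernel = jnp.ones((4, 8))
>>> # keep each weight independently with probability 1 - rate ...
>>> mask = jax.random.bernoulli(jax.random.key(0), 1.0 - rate, kernel.shape)
>>> # ... and rescale the survivors so the kernel's expectation is unchanged
>>> masked_kernel = jnp.where(mask, kernel / (1.0 - rate), 0.0)
>>> masked_kernel.shape
(4, 8)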
Initialize a DropConnectLinear layer based on a given linear base layer.
- Parameters:
base_layer – nnx.Linear, the original linear layer to be wrapped.
rate – float, the DropConnect probability.
rng_collection – str, rng collection name to use when requesting an rng key.
rngs – nnx.Rngs or nnx.RngStream or None, rng key.
- __call__(inputs: Array, *, deterministic: bool = False, rngs: Rngs | RngStream | Array | None = None) → Array
Forward pass of the DropConnectLinear layer.
- Parameters:
inputs – jax.Array, input data.
deterministic – bool, if False the weights are masked, whereas if True, no mask is applied and the layer behaves like a standard linear layer.
rngs – nnx.Rngs, nnx.RngStream or jax.Array, optional key used to generate the DropConnect mask.
- Returns:
jax.Array, layer output.
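A minimal end-to-end sketch of constructing and calling the layer. It assumes the constructor signature documented above (wrapping a standard nnx.Linear); the shapes and rng seeds are arbitrary:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> from probly.layers.flax import DropConnectLinear
...
>>> base = nnx.Linear(4, 8, rngs=nnx.Rngs(0))
>>> layer = DropConnectLinear(base, rate=0.3, rngs=nnx.Rngs(1))
>>> x = jnp.ones((2, 4))
>>> y = layer(x)                           # stochastic pass: weights are masked
>>> y_eval = layer(x, deterministic=True)  # deterministic pass: plain linear layer
>>> y_eval.shape
(2, 8)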
- eval(**attributes)
Sets the Module to evaluation mode.
eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.
Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
- Parameters:
**attributes – additional attributes passed to set_attributes.
- iter_children() → Iterator[tuple[Key, Module]]
Warning: this method is deprecated; use iter_children() instead.
Iterates over all children Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down. Alias of iter_children().
- iter_modules() → Iterator[tuple[tuple[Key, ...], Module]]
Warning: this method is deprecated; use iter_modules() instead.
Recursively iterates over all nested Modules of the current Module, including the current Module. Alias of iter_modules().
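A short sketch of walking the module tree with these iterators, using a hypothetical two-layer model (the output order and path keys follow whatever the installed Flax version produces):
>>> from flax import nnx
>>> model = nnx.Sequential(
...     nnx.Linear(2, 3, rngs=nnx.Rngs(0)),
...     nnx.Linear(3, 4, rngs=nnx.Rngs(1)),
... )
>>> for path, module in model.iter_modules():
...     print(path, type(module).__name__)  # path tuple and class of each nested Module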
- perturb(name: str, value: Any, variable_type: str | type[Variable[Any]] = Perturbation)
Extract gradients of intermediate values during training.
Used with nnx.capture() to record intermediate values in the forward pass and their gradients in the backward pass. Returns the value plus whatever perturbation is stored under name in the current capture context, allowing gradient computation via nnx.grad.
The workflow has four steps:
1. Initialize perturbations with nnx.capture(model, nnx.Perturbation)
2. Run the model with nnx.capture(model, nnx.Intermediate, init=perturbations)
3. Take gradients with respect to the perturbations using nnx.grad
4. Combine results with nnx.merge_state(perturb_grads, intermediates)
Note
This creates extra variables of the same size as value and thus occupies more memory. Use it only to debug gradients during training.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __call__(self, x):
...     x2 = self.perturb('grad_of_x', x)
...     return 3 * x2

>>> model = Model()
>>> x = 1.0

>>> # Step 1: Initialize perturbations
>>> forward = nnx.capture(model, nnx.Perturbation)
>>> _, perturbations = forward(x)

>>> # Steps 2-4: Capture gradients
>>> def train_step(model, perturbations, x):
...   def loss(model, perturbations, x):
...     return nnx.capture(model, nnx.Intermediate, init=perturbations)(x)
...   (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)
...   return nnx.merge_state(perturb_grads, sowed)

>>> metrics = train_step(model, perturbations, x)
>>> # metrics contains gradients of intermediate values
- Parameters:
name – A string key for storing the perturbation value.
value – The intermediate value to capture gradients for. You must use the returned value (not the original) for gradient capturing to work.
variable_type – The Variable type for the stored perturbation. Default is nnx.Perturbation.
- set_attributes(*filters: filterlib.Filter, raise_if_not_found: bool = True, graph: bool | None = None, **attributes: tp.Any) → None
Sets the attributes of nested Modules including the current Module. If the attribute is not found in the Module, it is ignored.
Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Filters can be used to set the attributes of specific Modules:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
- Parameters:
*filters – Filters to select the Modules to set the attributes of.
raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.
**attributes – The attributes to set.
- set_mode(deterministic: bool | None = None, **kwargs) → dict
Class method used by nnx.set_mode.
- Parameters:
deterministic – if True, disables DropConnect masking.
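A hedged sketch of how this hook is reached in practice, reusing the layer from the construction sketch under __call__ above; the exact nnx.set_mode call signature is an assumption inferred from the parameter list here:
>>> # assumed call pattern: nnx.set_mode dispatches to each module's set_mode hook,
>>> # which here would disable the DropConnect masking
>>> nnx.set_mode(layer, deterministic=True)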
- sow(variable_type: type[Variable[B]] | str, name: str, value: A, reduce_fn: Callable[[B, A], B] = <function <lambda>>, init_fn: Callable[[], B] = <function <lambda>>) → bool
Store intermediate values during module execution for later extraction.
Used with the nnx.capture() decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified name in a collection associated with variable_type.
By default, values are appended to a tuple, allowing multiple values to be tracked when the same module is called multiple times.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'features', x)
...     x = self.linear2(x)
...     return x

>>> # With the capture decorator, sow returns intermediates
>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   return model(x)

>>> result, intermediates = forward(model, jnp.ones(2))
>>> assert 'features' in intermediates
Custom init/reduce functions can be passed to control accumulation:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...       init_fn=lambda: 0,
...       reduce_fn=lambda prev, curr: prev + curr)
...     return x
- Parameters:
variable_type – The Variable type for the stored value. Typically Intermediate or a subclass is used.
name – A string key for storing the value in the collection.
value – The value to be stored.
reduce_fn – Function to combine existing and new values. Default appends to a tuple.
init_fn – Function providing the initial value for the first reduce_fn call. Default is an empty tuple.
- train(**attributes)
Sets the Module to training mode.
train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It's primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.
Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
- Parameters:
**attributes – additional attributes passed to set_attributes.