probly.transformation.batchensemble
- probly.transformation.batchensemble(base, num_members=1, use_base_weights=False, s_mean=1.0, s_std=0.01, r_mean=1.0, r_std=0.01, rngs=1)
Create a BatchEnsemble predictor from a base predictor.
It uses a traverser to replace every linear and convolutional layer in the base predictor with its BatchEnsemble counterpart.
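To illustrate what a BatchEnsemble counterpart of a linear layer computes, here is a minimal, framework-free sketch. Each member i shares the weight matrix W and owns two vectors, an input modulation s_i and an output modulation r_i, and computes ((x * s_i) @ W) * r_i, which equals x @ (W * outer(s_i, r_i)). The helper names are hypothetical, only the linear case is shown, and this is not probly's implementation.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(x: Vector, w: Matrix) -> Vector:
    """Compute x @ W for W with shape (in_features, out_features)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def batchensemble_linear(x: Vector, w: Matrix, s: Vector, r: Vector) -> Vector:
    """One member's forward pass: ((x * s) @ W) * r, elementwise on s and r.

    Equivalent to x @ (W * outer(s, r)), i.e. the shared weight modulated by a
    member-specific rank-1 factor, without materializing that matrix.
    """
    scaled_in = [xi * si for xi, si in zip(x, s)]  # input modulation
    out = matvec(scaled_in, w)                     # shared weight
    return [oj * rj for oj, rj in zip(out, r)]     # output modulation


def explicit_member_weight(w: Matrix, s: Vector, r: Vector) -> Matrix:
    """Materialize W_i = W * outer(s_i, r_i), used only to check the identity."""
    return [[w[a][b] * s[a] * r[b] for b in range(len(r))] for a in range(len(s))]


# Shared weight: 3 inputs -> 2 outputs.
W = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
# Two members, each with its own (s, r) pair.
members = [
    ([1.0, 1.0, 1.0], [1.0, 1.0]),  # all-ones modulation reproduces the base layer
    ([2.0, 0.5, 1.0], [1.0, -1.0]),
]
x = [1.0, 2.0, 3.0]

outputs = [batchensemble_linear(x, W, s, r) for s, r in members]
print(outputs)  # [[11.0, 0.0], [11.5, -3.0]]
```

Only the small vectors s_i and r_i are stored per member, which is why the ensemble costs little more memory than the base layer.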
- Parameters:
base (T) – The base predictor whose layers will be replaced by BatchEnsemble layers.
num_members (int) – The number of members in the BatchEnsemble.
use_base_weights (bool) – Whether to use the weights of the base layer as prior means.
s_mean (float) – The mean of the normal distribution used to initialize the input modulation s (nn.init.normal_(s, s_mean, s_std)).
s_std (float) – The standard deviation of the normal distribution used to initialize the input modulation s.
r_mean (float) – The mean of the normal distribution used to initialize the output modulation r (nn.init.normal_(r, r_mean, r_std)).
r_std (float) – The standard deviation of the normal distribution used to initialize the output modulation r.
rngs (Rngs | int) – The nnx.Rngs object, or an integer seed, used to initialize Flax layers.
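The s_mean/s_std and r_mean/r_std parameters control how each member's modulation vectors start out. A minimal sketch of that initialization, assuming one s vector of length in_features and one r vector of length out_features per member, and using Python's random module with an integer seed as a stand-in for rngs. The function name init_modulations is hypothetical.

```python
import random
from typing import List, Tuple


def init_modulations(
    in_features: int,
    out_features: int,
    num_members: int,
    s_mean: float = 1.0,
    s_std: float = 0.01,
    r_mean: float = 1.0,
    r_std: float = 0.01,
    seed: int = 1,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Draw per-member s ~ Normal(s_mean, s_std) and r ~ Normal(r_mean, r_std)."""
    # A seeded generator stands in for an integer `rngs` value (assumption).
    rng = random.Random(seed)
    s = [[rng.gauss(s_mean, s_std) for _ in range(in_features)] for _ in range(num_members)]
    r = [[rng.gauss(r_mean, r_std) for _ in range(out_features)] for _ in range(num_members)]
    return s, r


s, r = init_modulations(in_features=3, out_features=2, num_members=4)
print(len(s), len(s[0]), len(r), len(r[0]))  # 4 3 4 2
```

With the defaults (both means 1.0, standard deviations 0.01), every modulation entry starts close to 1, so each member initially computes nearly the same output as the shared base layer.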
- Returns:
The BatchEnsemble predictor.
- Raises:
ValueError – If num_members is not a positive integer.
ValueError – If s_std is not greater than 0.
ValueError – If s_mean is not greater than 0.
ValueError – If r_std is not greater than 0.
ValueError – If r_mean is not greater than 0.
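The conditions listed under Raises can be mirrored by a small guard such as the following. This is a sketch only: the function name validate_batchensemble_args, the error messages, and the rejection of bool for num_members are assumptions, not probly's code.

```python
def validate_batchensemble_args(
    num_members: int,
    s_mean: float,
    s_std: float,
    r_mean: float,
    r_std: float,
) -> None:
    """Raise ValueError for the invalid configurations listed under Raises."""
    # bool is a subclass of int, so exclude it explicitly (assumed intent).
    if isinstance(num_members, bool) or not isinstance(num_members, int) or num_members <= 0:
        raise ValueError(f"num_members must be a positive integer, got {num_members!r}")
    # Check in the same order as the Raises list.
    for name, value in (("s_std", s_std), ("s_mean", s_mean), ("r_std", r_std), ("r_mean", r_mean)):
        # `not value > 0` also rejects NaN, for which `value <= 0` would be False.
        if not value > 0:
            raise ValueError(f"{name} must be greater than 0, got {value!r}")


validate_batchensemble_args(num_members=4, s_mean=1.0, s_std=0.01, r_mean=1.0, r_std=0.01)  # passes
try:
    validate_batchensemble_args(num_members=4, s_mean=1.0, s_std=0.0, r_mean=1.0, r_std=0.01)
except ValueError as err:
    print(err)  # s_std must be greater than 0, got 0.0
```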
- Return type:
T