probly.method.batchensemble
- probly.method.batchensemble(base: Predictor, num_members: int = 1, use_base_weights: bool = False, init: InitMethod = 'normal', r_mean: float = 1.0, r_std: float = 0.5, s_mean: float = 1.0, s_std: float = 0.5, rngs: Rngs | int = 1) → BatchEnsemblePredictor
Create a BatchEnsemble predictor [WTB20] from a base predictor.
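As a rough illustration of the technique in [WTB20], and not probly's actual layer code, a BatchEnsemble dense layer keeps one shared weight matrix W and gives each member i two fast-weight vectors r_i and s_i. Member i effectively uses the weights W ∘ (r_i s_iᵀ), but that matrix is never built: the input batch is tiled once per member, scaled by r_i, multiplied by the shared W, and scaled by s_i. The sketch below assumes a (d_in, d_out) weight layout, omits biases, and uses NumPy; the function name batchensemble_dense is hypothetical.

```python
import numpy as np


def batchensemble_dense(x, W, r, s):
    """Hypothetical BatchEnsemble dense forward pass (bias omitted).

    x: (batch, d_in) inputs shared by all members
    W: (d_in, d_out) shared weight matrix
    r: (num_members, d_in) per-member input fast weights
    s: (num_members, d_out) per-member output fast weights
    Returns an array of shape (num_members, batch, d_out).
    """
    num_members = r.shape[0]
    batch = x.shape[0]
    # Tile the batch once per member: rows [i*batch, (i+1)*batch) belong to member i.
    x_tiled = np.tile(x, (num_members, 1))
    # Repeat each member's fast weights for every example in that member's slice.
    r_tiled = np.repeat(r, batch, axis=0)
    s_tiled = np.repeat(s, batch, axis=0)
    # ((x * r) @ W) * s equals x @ (W * outer(r, s)) row by row.
    y = ((x_tiled * r_tiled) @ W) * s_tiled
    return y.reshape(num_members, batch, -1)


rng = np.random.default_rng(0)
M, B, d_in, d_out = 4, 3, 5, 2
x = rng.normal(size=(B, d_in))
W = rng.normal(size=(d_in, d_out))
r = rng.normal(size=(M, d_in))
s = rng.normal(size=(M, d_out))
y = batchensemble_dense(x, W, r, s)

# Each member matches an ordinary dense layer with weights W * outer(r[i], s[i]).
for i in range(M):
    assert np.allclose(y[i], x @ (W * np.outer(r[i], s[i])))
print(y.shape)  # (4, 3, 2)
```

Tiling lets all members share a single matrix multiplication against W, which is the efficiency argument behind BatchEnsemble; the exact tiled-batch layout probly uses internally is not stated on this page and may differ from this sketch.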
Replaces all linear and convolutional layers with their BatchEnsemble counterparts and tags the result with num_members so that predict() can tile inputs and wrap outputs as a Sample.
- Parameters:
base – Predictor, The model whose linear and convolutional layers will be replaced by BatchEnsemble layers.
num_members – int, The number of members in the BatchEnsemble.
use_base_weights – bool, Whether to use the weights of each base layer as the initial weights of its BatchEnsemble replacement.
init – InitMethod, Initialization scheme for r and s: "normal" (Gaussian; the ImageNet baseline default) or "random_sign" (entries in {-1, +1}, as in Appendix B of the paper).
r_mean – float, Mean of the Gaussian initializer for r when init="normal".
r_std – float, Standard deviation of the Gaussian initializer for r when init="normal".
s_mean – float, Mean of the Gaussian initializer for s when init="normal".
s_std – float, Standard deviation of the Gaussian initializer for s when init="normal".
rngs – nnx.Rngs | int, The random number generators (or integer seed) used for Flax layer initialization.
- Returns:
BatchEnsemblePredictor, The BatchEnsemble predictor.
- Raises:
ValueError – If num_members is not a positive integer.
ValueError – If init is not "normal" or "random_sign".
ValueError – If r_std or s_std is not strictly positive when init="normal".
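The init, r_mean, r_std, s_mean and s_std parameters, together with the ValueError cases listed above, can be mirrored in a small standalone sketch. It is not probly's internal code: the helper init_fast_weights is hypothetical, it assumes r has shape (num_members, d_in) and s has shape (num_members, d_out), it assumes "random_sign" draws each entry uniformly from {-1, +1}, and it uses a NumPy seed in place of nnx.Rngs.

```python
import numpy as np

VALID_INITS = ("normal", "random_sign")


def init_fast_weights(num_members, d_in, d_out, init="normal",
                      r_mean=1.0, r_std=0.5, s_mean=1.0, s_std=0.5, seed=1):
    """Hypothetical initializer for BatchEnsemble fast weights r and s."""
    # Validation mirrors the documented ValueError cases.
    if not isinstance(num_members, int) or num_members < 1:
        raise ValueError(f"num_members must be a positive integer, got {num_members!r}")
    if init not in VALID_INITS:
        raise ValueError(f"init must be one of {VALID_INITS}, got {init!r}")
    if init == "normal" and (r_std <= 0 or s_std <= 0):
        raise ValueError("r_std and s_std must be strictly positive when init='normal'")

    rng = np.random.default_rng(seed)
    if init == "normal":
        # Gaussian fast weights with the configured means and standard deviations.
        r = rng.normal(r_mean, r_std, size=(num_members, d_in))
        s = rng.normal(s_mean, s_std, size=(num_members, d_out))
    else:
        # Assumed: each entry drawn uniformly from {-1, +1}.
        r = rng.choice([-1.0, 1.0], size=(num_members, d_in))
        s = rng.choice([-1.0, 1.0], size=(num_members, d_out))
    return r, s


r, s = init_fast_weights(num_members=4, d_in=5, d_out=2, init="random_sign")
print(r.shape, s.shape)  # (4, 5) (4, 2)
```

Because r and s scale the shared weights elementwise, the default means of 1.0 start every member centred on the shared weights, with r_std and s_std controlling how far members initially diverge from one another.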