probly.method.batchensemble

probly.method.batchensemble(base: Predictor, num_members: int = 1, use_base_weights: bool = False, init: InitMethod = 'normal', r_mean: float = 1.0, r_std: float = 0.5, s_mean: float = 1.0, s_std: float = 0.5, rngs: Rngs | int = 1) → BatchEnsemblePredictor

Create a BatchEnsemble predictor from a base predictor based on [WTB20].

Replaces all linear and convolutional layers with their BatchEnsemble counterparts and tags the result with num_members so predict() can tile inputs and wrap outputs as a Sample.
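The tiling and un-tiling described above can be illustrated with a small NumPy sketch of the rank-1 formulation from the BatchEnsemble paper. This is not probly's implementation: `batchensemble_linear` and its argument names are hypothetical, the bias is omitted, and a plain array with a leading member axis stands in for the Sample wrapper.

```python
import numpy as np


def batchensemble_linear(x, w, r, s):
    """Rank-1 BatchEnsemble linear layer (hypothetical helper, no bias).

    x: (batch, d_in) inputs shared by all members
    w: (d_in, d_out) weight shared by all members
    r: (num_members, d_in) per-member input factors
    s: (num_members, d_out) per-member output factors
    Returns an array of shape (num_members, batch, d_out).
    """
    num_members, batch = r.shape[0], x.shape[0]
    # Tile the batch once per member: rows [i*batch, (i+1)*batch) belong to member i.
    x_tiled = np.tile(x, (num_members, 1))
    # Repeat each member's factors over its block of rows.
    r_rows = np.repeat(r, batch, axis=0)
    s_rows = np.repeat(s, batch, axis=0)
    # Equivalent to x @ (w * outer(r_i, s_i)) for every member i.
    y = ((x_tiled * r_rows) @ w) * s_rows
    # Un-tile so the leading axis indexes ensemble members.
    return y.reshape(num_members, batch, -1)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w = rng.normal(size=(3, 2))
r = rng.normal(loc=1.0, scale=0.5, size=(5, 3))
s = rng.normal(loc=1.0, scale=0.5, size=(5, 2))
samples = batchensemble_linear(x, w, r, s)
print(samples.shape)  # (5, 4, 2)
```

Because only r and s differ between members, all members share one matrix multiplication over the tiled batch, which is what makes a single forward pass sufficient.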

Parameters:
  • base – Predictor, The model in which the layers will be replaced by BatchEnsemble layers.

  • num_members – int, The number of members in the BatchEnsemble.

  • use_base_weights – bool, Whether to use the weights of the base layer as initial weights.

  • init – InitMethod, Initialization scheme for r and s: "normal" (Gaussian, ImageNet baseline default) or "random_sign" (values in {-1, +1}, paper Appendix B).

  • r_mean – float, Mean of the Gaussian initializer for r when init="normal".

  • r_std – float, Standard deviation of the Gaussian initializer for r when init="normal".

  • s_mean – float, Mean of the Gaussian initializer for s when init="normal".

  • s_std – float, Standard deviation of the Gaussian initializer for s when init="normal".

  • rngs – nnx.Rngs | int, The random number generators (or an integer seed) used for Flax layer initialization.
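As a rough illustration of the two init schemes, the sketch below draws r or s either from a Gaussian or uniformly from {-1, +1}. It uses NumPy rather than Flax, and `init_fast_weights` together with its `shape` and `seed` arguments are assumptions made for this example; probly's own initializers may differ in detail.

```python
import numpy as np


def init_fast_weights(shape, init="normal", mean=1.0, std=0.5, seed=1):
    """Draw per-member factors (r or s) with the requested scheme (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    if init == "normal":
        # Gaussian centred on `mean`; mean=1.0 keeps members close to the shared weight.
        return rng.normal(loc=mean, scale=std, size=shape)
    if init == "random_sign":
        # Each entry is -1 or +1 with equal probability; mean and std are ignored.
        return rng.choice(np.array([-1.0, 1.0]), size=shape)
    raise ValueError(f"init must be 'normal' or 'random_sign', got {init!r}")


r_gauss = init_fast_weights((4, 8), init="normal", mean=1.0, std=0.5)
r_sign = init_fast_weights((4, 8), init="random_sign")
```

Here the first axis of `shape` is taken to be the number of members and the second the feature dimension of the layer being wrapped.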

Returns:

BatchEnsemblePredictor, The BatchEnsemble predictor.

Raises:
  • ValueError – If num_members is not a positive integer.

  • ValueError – If init is not "normal" or "random_sign".

  • ValueError – If r_std or s_std is not strictly positive when init="normal".
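The three error conditions above could be checked along the following lines. This is only a sketch of equivalent validation logic: `check_batchensemble_args` is a hypothetical helper, and the order of the checks and the message wording are assumptions, not probly's actual behavior.

```python
def check_batchensemble_args(num_members, init, r_std, s_std):
    """Raise ValueError for the documented invalid argument combinations."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(num_members, bool) or not isinstance(num_members, int) or num_members < 1:
        raise ValueError(f"num_members must be a positive integer, got {num_members!r}")
    if init not in ("normal", "random_sign"):
        raise ValueError(f"init must be 'normal' or 'random_sign', got {init!r}")
    # The standard deviations only matter for the Gaussian initializer.
    if init == "normal" and (r_std <= 0 or s_std <= 0):
        raise ValueError("r_std and s_std must be strictly positive when init='normal'")


check_batchensemble_args(num_members=4, init="normal", r_std=0.5, s_std=0.5)
```

Note that under this reading a non-positive r_std or s_std is accepted when init="random_sign", since the values are unused in that case.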