probly.transformation.batchensemble

probly.transformation.batchensemble(base, num_members=1, use_base_weights=False, s_mean=1.0, s_std=0.01, r_mean=1.0, r_std=0.01, rngs=1)

Create a BatchEnsemble predictor from a base predictor.

The function uses a traverser to walk the base model and replace every linear and convolutional layer with its BatchEnsemble counterpart.
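To illustrate what a BatchEnsemble counterpart of a linear layer computes, here is a minimal pure-Python sketch. It is not probly's implementation: the helper name `batchensemble_linear` is hypothetical, and sharing the bias across members is an assumption made for brevity. Each member i reuses the same weight matrix but scales the input by its own vector s_i and the output by its own vector r_i.

```python
def batchensemble_linear(x, weight, bias, s, r):
    """Apply one shared linear layer once per ensemble member.

    x:      input vector of length n_in
    weight: shared weight matrix, n_out rows of n_in entries
    bias:   bias of length n_out (assumption: shared by all members)
    s:      input modulation, one row of n_in entries per member
    r:      output modulation, one row of n_out entries per member
    """
    outputs = []
    for s_i, r_i in zip(s, r):
        # Scale the input element-wise with this member's s vector.
        x_mod = [x_j * s_ij for x_j, s_ij in zip(x, s_i)]
        # Apply the shared weight matrix.
        y = [sum(w * v for w, v in zip(row, x_mod)) for row in weight]
        # Scale the output element-wise with this member's r vector, then add the bias.
        outputs.append([y_k * r_ik + b_k for y_k, r_ik, b_k in zip(y, r_i, bias)])
    return outputs


weight = [[1.0, 2.0], [0.0, -1.0]]
bias = [0.5, 0.0]
x = [3.0, 4.0]
s = [[1.0, 1.0], [2.0, 0.5]]  # two members
r = [[1.0, 1.0], [0.5, 2.0]]
outputs = batchensemble_linear(x, weight, bias, s, r)
print(outputs)  # [[11.5, -4.0], [5.5, -4.0]]
```

The first member uses s and r equal to one, so its output matches the plain linear layer. The second member produces a different prediction from the same shared weights.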

Parameters:
  • base (T) – The base predictor whose layers will be replaced by BatchEnsemble layers.

  • num_members (int) – The number of members in the BatchEnsemble.

  • use_base_weights (bool) – Whether to use the weights of the base layer as prior means.

  • s_mean (float) – The mean of the input modulation s, which is initialized with nn.init.normal_ using mean s_mean and standard deviation s_std.

  • s_std (float) – The standard deviation of the input modulation s, which is initialized with nn.init.normal_ using mean s_mean and standard deviation s_std.

  • r_mean (float) – The mean of the output modulation r, which is initialized with nn.init.normal_ using mean r_mean and standard deviation r_std.

  • r_std (float) – The standard deviation of the output modulation r, which is initialized with nn.init.normal_ using mean r_mean and standard deviation r_std.

  • rngs (nnx.Rngs | int) – The rngs used for Flax layer initialization.
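The defaults for s_mean, s_std, r_mean and r_std place every modulation entry close to 1, so each member presumably starts as a small perturbation of the shared layer. The sketch below shows this with the standard library only. `init_modulation` is a hypothetical helper, not part of probly, and the `seed` argument is an assumption that stands in for whatever random state the library actually uses.

```python
import random


def init_modulation(num_members, size, mean=1.0, std=0.01, seed=1):
    """Draw one modulation vector per member from a normal distribution.

    Hypothetical helper: mirrors the documented defaults (mean 1.0, std 0.01)
    but is not how the library itself creates its parameters.
    """
    rng = random.Random(seed)  # local generator, so results are reproducible
    return [[rng.gauss(mean, std) for _ in range(size)] for _ in range(num_members)]


# Four members, a layer with 3 inputs and 2 outputs.
s = init_modulation(num_members=4, size=3)          # input modulation
r = init_modulation(num_members=4, size=2, seed=2)  # output modulation

# With std=0.01, every entry stays close to 1.0.
largest_deviation = max(abs(v - 1.0) for row in s + r for v in row)
```

A larger standard deviation makes the members differ more at initialization, while a standard deviation of zero would make all members identical to start with.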

Returns:

The BatchEnsemble predictor.

Return type:

T