probly.representation.sampling.flax_sampler

Sampling preparation for Flax models.

Functions

register_forced_train_mode(cls)

Register a class to be forced into train mode during sampling.

probly.representation.sampling.flax_sampler.register_forced_train_mode(cls)

Register a class to be forced into train mode during sampling.

Instances of a registered class keep their training-time behavior during sampling. This enables Monte Carlo sampling techniques such as MC Dropout [GG16b].
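The following is a minimal sketch of the MC Dropout idea in plain Python. It is written without Flax or probly so that it runs with no extra dependencies. The names `dropout`, `toy_model`, and `mc_dropout_predict` are hypothetical illustrations and are not part of probly. The sketch shows why train mode matters: the dropout mask is resampled on every forward pass, so repeated predictions differ. The mean of those predictions serves as the prediction, and their variance serves as an uncertainty estimate.

```python
import random
import statistics


# Hypothetical helpers for illustration only; not probly or Flax APIs.
def dropout(values: list[float], rate: float, rng: random.Random) -> list[float]:
    """Inverted dropout: zero each entry with probability `rate`, rescale survivors."""
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]


def toy_model(x: list[float], weights: list[float], rate: float, rng: random.Random) -> float:
    """A one-layer linear model whose dropout stays active (train-mode behavior)."""
    dropped = dropout(x, rate, rng)
    return sum(w * v for w, v in zip(weights, dropped))


def mc_dropout_predict(
    x: list[float],
    weights: list[float],
    rate: float = 0.5,
    num_samples: int = 100,
    seed: int = 0,
) -> tuple[float, float]:
    """Run `num_samples` stochastic forward passes; return (mean, variance)."""
    rng = random.Random(seed)
    samples = [toy_model(x, weights, rate, rng) for _ in range(num_samples)]
    # With rate=0.0 every pass is identical, so the variance is 0.
    return statistics.fmean(samples), statistics.pvariance(samples)


if __name__ == "__main__":
    mean, var = mc_dropout_predict([1.0, 2.0, 3.0], [0.5, -0.25, 1.0], rate=0.5, num_samples=1000)
    print(f"mean={mean:.3f}, variance={var:.3f}")
```

If dropout were switched off at inference, as it normally is in eval mode, every sample would be identical and the variance would carry no information. This is the behavior that forcing train mode avoids.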

Parameters:

cls (LazyType) – The class to register.

Return type:

None
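A registration hook of this kind could be implemented as a module-level registry of classes that the sampling code consults. The sketch below assumes that design. It is not probly's actual implementation, and `_FORCED_TRAIN_MODE`, `register_forced_train_mode_sketch`, and `is_forced_train_mode` are hypothetical names. The sketch also accepts only concrete classes, whereas the real `cls` parameter is typed `LazyType`, which this sketch does not model.

```python
# Hypothetical sketch of a forced-train-mode registry; NOT probly's implementation.
_FORCED_TRAIN_MODE: set[type] = set()


def register_forced_train_mode_sketch(cls: type) -> None:
    """Record `cls` so that sampling code can force its instances into train mode."""
    _FORCED_TRAIN_MODE.add(cls)


def is_forced_train_mode(module: object) -> bool:
    """Return True if `module` is an instance of any registered class (or subclass)."""
    # isinstance with an empty tuple returns False, so an empty registry matches nothing.
    return isinstance(module, tuple(_FORCED_TRAIN_MODE))


# Usage with hypothetical stand-in layer classes.
class Dropout:  # stands in for a stochastic layer such as a dropout module
    pass


class Dense:  # stands in for a deterministic layer
    pass


register_forced_train_mode_sketch(Dropout)
print(is_forced_train_mode(Dropout()))  # True
print(is_forced_train_mode(Dense()))    # False
```

Because the sketch uses `isinstance`, it also matches subclasses of a registered class. The source does not state whether probly's implementation behaves the same way.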