Multilib Demo

%load_ext autoreload
%autoreload 2
import torch
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
    """Simple fully connected neural network with two heads.

    Two hidden layers are shared. Judging by the variable names in ``forward``,
    the first head (``fc31``) is intended to output a mean ``mu`` and the
    second head (``fc32``, passed through softplus) a variance ``sigma2``. The
    two head outputs are concatenated along the last dimension.

    Attributes:
        fc1: nn.Module, first fully connected layer (input -> hidden)
        fc2: nn.Module, second fully connected layer (hidden -> hidden)
        fc31: nn.Module, fully connected layer of first head (hidden -> 1)
        fc32: nn.Module, fully connected layer of second head (hidden -> 1)
        act: nn.Module, ReLU activation applied after fc1 and after fc2
    """

    def __init__(self, in_features: int = 1, hidden_features: int = 32) -> None:
        """Initialize an instance of the Net class.

        Args:
            in_features: int, number of input features per sample (default 1)
            hidden_features: int, width of both hidden layers (default 32)
        """
        super().__init__()
        # Registration order (fc1, fc2, fc31, fc32, act) is kept as before so the
        # module repr, state_dict layout and seeded weight initialization match.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc31 = nn.Linear(hidden_features, 1)
        self.fc32 = nn.Linear(hidden_features, 1)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the neural network.

        Args:
            x: torch.Tensor of shape (..., in_features), input data; any number
                of leading batch dimensions (including none) is accepted
        Returns:
            torch.Tensor of shape (..., 2): index 0 of the last dimension is the
            first head's output, index 1 the softplus-transformed second head
        """
        hidden = self.act(self.fc1(x))
        hidden = self.act(self.fc2(hidden))
        mu = self.fc31(hidden)
        # softplus keeps the second head non-negative, as required for a variance.
        sigma2 = F.softplus(self.fc32(hidden))
        # Concatenate along the last dim rather than dim=1: identical for the usual
        # (batch, in_features) input, but also correct for unbatched (in_features,)
        # inputs and inputs with additional leading batch dimensions.
        return torch.cat([mu, sigma2], dim=-1)
# Inspect the qualified name of jax's public array type (echoed output below).
import jax

jax.Array.__qualname__
'Array'
# Check that an array created with jax.numpy is an instance of jax.Array
# (echoed output below).
import jax.numpy as jnp

isinstance(jnp.array([1, 2, 3]), jax.Array)
True
# Apply probly's dropout transformation with p=0.1. The reprs printed below show
# that `net` itself is left unchanged, while `drop_net` has a Dropout layer
# inserted before fc2, fc31 and fc32 (not before fc1).
from probly.transformation import dropout

net = Net()
drop_net = dropout(net, p=0.1)
net
Net(
  (fc1): Linear(in_features=1, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=32, bias=True)
  (fc31): Linear(in_features=32, out_features=1, bias=True)
  (fc32): Linear(in_features=32, out_features=1, bias=True)
  (act): ReLU()
)
drop_net  # display the dropout-transformed network
Net(
  (fc1): Linear(in_features=1, out_features=32, bias=True)
  (fc2): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=32, bias=True)
  )
  (fc31): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=1, bias=True)
  )
  (fc32): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=1, bias=True)
  )
  (act): ReLU()
)
# eval() switches the module to evaluation mode, in which nn.Dropout is a no-op,
# so this single forward pass on one input is deterministic.
drop_net.eval()
drop_net(torch.tensor([[1.0]]))
tensor([[-0.0048,  0.5399]], grad_fn=<CatBackward0>)
# Draw 20 predictions per input from the dropout network through probly's
# Distribution representation; for a (64, 1) input the printed shape below is
# (64, 20, 2), i.e. (batch, num_samples, outputs).
# NOTE(review): drop_net was put into eval() mode in an earlier cell; presumably
# Distribution.predict re-enables dropout while sampling — confirm, otherwise all
# 20 samples per input would be identical.
from probly.representation import Distribution

x = torch.rand((64, 1))
distribution = Distribution(drop_net)
outputs = distribution.predict(x, num_samples=20).tensor
print(outputs.shape)
torch.Size([64, 20, 2])
# Build a 5-member ensemble of Net via probly's ensemble transformation.
# NOTE(review): the result is bound to `ensemble`, shadowing the imported
# transformation of the same name; any later `ensemble(...)` call in this session
# would hit the returned model instead. Consider a distinct name such as
# `ensemble_model` (and update the EnsembleSampler cell accordingly).
from probly.transformation import ensemble

ensemble = ensemble(Net(), n_members=5)
# Sample predictions from the ensemble for a batch of 64 random inputs.
# NOTE(review): as recorded in the traceback below, this currently raises
# NotImplementedError. EnsembleSampler.sample calls self.predictor(...) directly,
# and the predictor here is an nn.ModuleList, which defines no forward(). Either
# the ensemble transformation should return a callable module, or EnsembleSampler
# should iterate over the members — TODO decide in probly.
from probly.representation.sampling.sampler import EnsembleSampler

ensemble_sampler = EnsembleSampler(ensemble)
x = torch.rand((64, 1))
outputs = ensemble_sampler.sample(x).tensor
print(outputs.shape)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[19], line 4
      2 ensemble_sampler = EnsembleSampler(ensemble)
      3 x = torch.rand((64, 1))
----> 4 outputs = ensemble_sampler.sample(x).tensor
      5 print(outputs.shape)

File ~/Documents/PhD/UncertaintyPackage/probly/src/probly/representation/sampling/sampler.py:129, in EnsembleSampler.sample(self, *args, **kwargs)
    126 def sample(self, *args: In, **kwargs: Unpack[KwIn]) -> Sample[Out]:
    127     """Sample from the ensemble predictor for a given input."""
    128     return self.sample_factory(
--> 129         self.predictor(*args, **kwargs),
    130     )

File ~/Documents/PhD/UncertaintyPackage/probly/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
   1773     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1774 else:
-> 1775     return self._call_impl(*args, **kwargs)

File ~/Documents/PhD/UncertaintyPackage/probly/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
   1781 # If we don't have any hooks, we want to skip the rest of the logic in
   1782 # this function, and just call forward.
   1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1784         or _global_backward_pre_hooks or _global_backward_hooks
   1785         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786     return forward_call(*args, **kwargs)
   1788 result = None
   1789 called_always_called_hooks = set()

File ~/Documents/PhD/UncertaintyPackage/probly/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py:399, in _forward_unimplemented(self, *input)
    388 def _forward_unimplemented(self, *input: Any) -> None:
    389     r"""Define the computation performed at every call.
    390 
    391     Should be overridden by all subclasses.
   (...)
    397         registered hooks while the latter silently ignores them.
    398     """
--> 399     raise NotImplementedError(
    400         f'Module [{type(self).__name__}] is missing the required "forward" function'
    401     )

NotImplementedError: Module [ModuleList] is missing the required "forward" function