Multilib Demo
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
    """Two-headed fully connected network.

    Maps a 1-d input through two ReLU hidden layers and then through two
    parallel linear heads; the forward pass returns both head outputs
    concatenated along the feature dimension.

    Attributes:
        fc1: nn.Module, first fully connected layer
        fc2: nn.Module, second fully connected layer
        fc31: nn.Module, fully connected layer of the first head (mu)
        fc32: nn.Module, fully connected layer of the second head (sigma2)
        act: nn.Module, ReLU activation shared by the hidden layers
    """

    def __init__(self) -> None:
        """Set up the layers of the network."""
        super().__init__()
        self.fc1 = nn.Linear(1, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc31 = nn.Linear(32, 1)
        self.fc32 = nn.Linear(32, 1)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: torch.Tensor, input batch whose last dimension is 1

        Returns:
            torch.Tensor, concatenation [mu, sigma2] with 2 output features
        """
        hidden = self.act(self.fc2(self.act(self.fc1(x))))
        mu = self.fc31(hidden)
        # softplus keeps the second head's output strictly positive
        sigma2 = F.softplus(self.fc32(hidden))
        return torch.cat([mu, sigma2], dim=1)
import jax
# Public qualified name of the jax array type (cell output below: 'Array').
jax.Array.__qualname__
'Array'
import jax.numpy as jnp
# Concrete jnp arrays are instances of jax.Array (cell output below: True).
isinstance(jnp.array([1, 2, 3]), jax.Array)
True
from probly.transformation import dropout
net = Net()
# Wrap the network with probly's dropout transformation with drop rate 0.1.
# NOTE(review): judging by the module reprs printed below, each Linear after
# the first is replaced by Sequential(Dropout(p=0.1), Linear) — confirm
# against the dropout transformation's documentation.
drop_net = dropout(net, p=0.1)
net
Net(
(fc1): Linear(in_features=1, out_features=32, bias=True)
(fc2): Linear(in_features=32, out_features=32, bias=True)
(fc31): Linear(in_features=32, out_features=1, bias=True)
(fc32): Linear(in_features=32, out_features=1, bias=True)
(act): ReLU()
)
drop_net
Net(
(fc1): Linear(in_features=1, out_features=32, bias=True)
(fc2): Sequential(
(0): Dropout(p=0.1, inplace=False)
(1): Linear(in_features=32, out_features=32, bias=True)
)
(fc31): Sequential(
(0): Dropout(p=0.1, inplace=False)
(1): Linear(in_features=32, out_features=1, bias=True)
)
(fc32): Sequential(
(0): Dropout(p=0.1, inplace=False)
(1): Linear(in_features=32, out_features=1, bias=True)
)
(act): ReLU()
)
drop_net.eval()  # switch to evaluation mode (standard nn.Dropout is a no-op in eval)
drop_net(torch.tensor([[1.0]]))  # single forward pass: returns [mu, sigma2] for one input
tensor([[-0.0048, 0.5399]], grad_fn=<CatBackward0>)
from probly.representation import Distribution
x = torch.rand((64, 1))  # batch of 64 scalar inputs
distribution = Distribution(drop_net)
# Draw 20 samples per input; `.tensor` stacks them — the printed shape below
# is (batch, num_samples, output_dim) = (64, 20, 2).
outputs = distribution.predict(x, num_samples=20).tensor
print(outputs.shape)
torch.Size([64, 20, 2])
from probly.transformation import ensemble

# Build a 5-member ensemble from the base Net.
# FIX: bind the result to a new name instead of rebinding `ensemble`, which
# shadowed the imported transformation function and made it uncallable for
# the rest of the session.
ensemble_net = ensemble(Net(), n_members=5)

from probly.representation.sampling.sampler import EnsembleSampler

ensemble_sampler = EnsembleSampler(ensemble_net)
x = torch.rand((64, 1))  # batch of 64 scalar inputs
# One prediction per ensemble member, stacked into a single tensor.
outputs = ensemble_sampler.sample(x).tensor
print(outputs.shape)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[19], line 4
2 ensemble_sampler = EnsembleSampler(ensemble)
3 x = torch.rand((64, 1))
----> 4 outputs = ensemble_sampler.sample(x).tensor
5 print(outputs.shape)
File ~/Documents/PhD/UncertaintyPackage/probly/src/probly/representation/sampling/sampler.py:129, in EnsembleSampler.sample(self, *args, **kwargs)
126 def sample(self, *args: In, **kwargs: Unpack[KwIn]) -> Sample[Out]:
127 """Sample from the ensemble predictor for a given input."""
128 return self.sample_factory(
--> 129 self.predictor(*args, **kwargs),
130 )
File ~/Documents/PhD/UncertaintyPackage/probly/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1775, in Module._wrapped_call_impl(self, *args, **kwargs)
1773 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1774 else:
-> 1775 return self._call_impl(*args, **kwargs)
File ~/Documents/PhD/UncertaintyPackage/probly/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1786, in Module._call_impl(self, *args, **kwargs)
1781 # If we don't have any hooks, we want to skip the rest of the logic in
1782 # this function, and just call forward.
1783 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1784 or _global_backward_pre_hooks or _global_backward_hooks
1785 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1786 return forward_call(*args, **kwargs)
1788 result = None
1789 called_always_called_hooks = set()
File ~/Documents/PhD/UncertaintyPackage/probly/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py:399, in _forward_unimplemented(self, *input)
388 def _forward_unimplemented(self, *input: Any) -> None:
389 r"""Define the computation performed at every call.
390
391 Should be overridden by all subclasses.
(...)
397 registered hooks while the latter silently ignores them.
398 """
--> 399 raise NotImplementedError(
400 f'Module [{type(self).__name__}] is missing the required "forward" function'
401 )
NotImplementedError: Module [ModuleList] is missing the required "forward" function