# Source code for ripple.epistemic
from typing import Tuple
import torch
from ripple.base import Base
from ripple.model import RippleModel
class DropoutEpistemic(Base):
    """Estimate epistemic uncertainty with Monte Carlo dropout.

    Runs ``T`` stochastic forward passes of the wrapped model with dropout
    enabled, then summarises them with the element-wise mean and standard
    deviation taken across the passes.

    Args:
        model (RippleModel): RippleModel to sample from.
        T (int): Number of stochastic forward passes; must be >= 1. Note that
            with ``T == 1`` the (unbiased) standard deviation is NaN.

    Raises:
        ValueError: If ``T`` is less than 1.
    """

    T: int  # number of stochastic forward passes per call

    def __init__(self, model: RippleModel, T: int) -> None:
        # Fail fast: T == 0 would otherwise surface later as an obscure
        # "stack expects a non-empty TensorList" error inside forward().
        if T < 1:
            raise ValueError(f"T must be a positive integer, got {T}")
        super().__init__(model)
        self.T = T

    def forward(
        self, x: torch.Tensor, training: bool = True
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``T`` dropout-enabled passes and summarise them.

        Args:
            x (torch.Tensor): Input passed unchanged to the wrapped model on
                every pass.
            training (bool): Accepted for interface compatibility but not
                used; the wrapped model is always called as
                ``self.model(x, True)``.

        Returns:
            ``(mean, (mean, std))``, where ``mean`` and ``std`` are computed
            over the sample axis (dim 0 of the stacked passes), so each has
            the same shape as a single model output.
        """
        # Always pass True so dropout stays active on every pass, which is the
        # point of MC dropout. NOTE(review): this presumes the second argument
        # of RippleModel's call is its training/dropout switch; confirm.
        samples = torch.stack([self.model(x, True) for _ in range(self.T)])
        # Reduce only over the sample axis (dim 0). Reducing over every
        # dimension would collapse the whole batch into a single scalar.
        mean = samples.mean(dim=0)
        std = samples.std(dim=0)
        return mean, (mean, std)