Source code for ripple.aleatoric

from typing import Tuple

import torch
import torch.nn as nn

from ripple.base import Base
from ripple.constants import OUTPUT_LAYER
from ripple.model import RippleModel


def sample(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw a reparameterized sample from a mean and log(variance) tensor.

    Args:
        mu (torch.Tensor): Input mean, must be 2-dimensional.
        logvar (torch.Tensor): Input log(variance), must be 2-dimensional.

    Returns:
        torch.Tensor: ``mu + exp(0.5 * logvar) * epsilon`` where ``epsilon``
            is standard normal noise with the same shape as ``mu``.

    Raises:
        AssertionError: If ``mu`` or ``logvar`` is not 2-dimensional.
    """
    assert mu.ndim == 2
    assert logvar.ndim == 2
    # Allocate the noise on the same device as ``mu``; noise created on the
    # default device cannot be combined with tensors living elsewhere.
    # A freshly drawn tensor carries no autograd history, so gradients only
    # flow through ``mu`` and ``logvar``.
    epsilon = torch.randn(mu.shape, device=mu.device)
    # exp(0.5 * logvar) is the standard deviation.
    return mu + torch.exp(0.5 * logvar) * epsilon
def neg_log_likelihood(
    y: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
) -> torch.Tensor:
    """Element-wise Gaussian negative log likelihood term.

    Computes ``logvar + (y - mu) ** 2 / exp(logvar)``. No reduction is
    applied, and the 0.5 factor and additive constant of the exact Gaussian
    negative log likelihood are not included.

    Args:
        y (torch.Tensor): Label
        mu (torch.Tensor): Input mean tensor
        logvar (torch.Tensor): Input log(variance) tensor

    Returns:
        torch.Tensor: Negative log likelihood loss wrt label.
    """
    squared_error = (y - mu).pow(2)
    return logvar + squared_error / logvar.exp()
class MVE(Base):
    """Implementation of Mean and Variance Estimation.

    Reference: https://doi.org/10.1109/ICNN.1994.374138.

    Two additional heads, ``mu`` and ``logvar``, are evaluated on the
    features of the wrapped model. Classification models are trained with
    cross entropy on a value sampled from those heads; all other models use
    the negative log likelihood loss.

    Args:
        model (ripple.model.RippleModel): Base Ripple Model
    """

    mu: nn.Module
    logvar: nn.Module

    def __init__(self, model: RippleModel) -> None:
        """Wrap ``model`` and create the mean and log(variance) heads.

        Args:
            model (RippleModel): Base Ripple Model
        """
        super().__init__(model)
        # Both heads come from ``copy_layer(OUTPUT_LAYER)``.
        # NOTE(review): presumably independent copies of the model's output
        # layer -- confirm against ``Base.copy_layer``.
        self.mu = self.copy_layer(OUTPUT_LAYER)
        self.logvar = self.copy_layer(OUTPUT_LAYER)

    def forward(
        self, x: torch.Tensor, training: bool = True, return_risk: bool = False
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the wrapped model and both MVE heads on ``x``.

        Args:
            x (torch.Tensor): Input data
            training (bool): Passed through to ``get_features`` and
                ``get_output_from_features``.
            return_risk (bool): When True, ``exp(logvar)`` is attached to the
                prediction as its ``aleatoric`` attribute.

        Returns:
            Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
                Prediction, and the ``(mu, logvar)`` pair.
        """
        feats = self.get_features(x, training)
        prediction = self.get_output_from_features(feats, training)
        mean = self.mu(feats)
        log_variance = self.logvar(feats)

        if return_risk:
            # The variance is exposed as an attribute on the prediction; the
            # returned tuple has the same layout either way.
            prediction.aleatoric = log_variance.exp()

        return prediction, (mean, log_variance)

    def train_forward(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get mve loss for input data and label.

        Args:
            x (torch.Tensor): Input training data
            y (torch.Tensor): Ground truth label

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: MVE Loss, prediction for
                input data.
        """
        prediction, (mean, log_variance) = self.forward(x)
        return self.get_loss(y, mean, log_variance), prediction

    def get_loss(
        self, y: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
    ) -> torch.Tensor:
        """Implementation of MVE loss.

        Args:
            y (torch.Tensor): Ground truth label
            mu (torch.Tensor): Mean tensor
            logvar (torch.Tensor): Log(variance) tensor

        Returns:
            torch.Tensor: MVE Loss
        """
        if not self.is_classification:
            return neg_log_likelihood(y, mu, logvar)

        # Classification: cross entropy on a value sampled from the
        # predicted mean and log(variance), with labels cast to long.
        sampled = sample(mu, logvar)
        criterion = nn.CrossEntropyLoss()
        return criterion(sampled, y.long())