from typing import Tuple

import torch
from torch import nn


def sample(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample from mean and variance tensor.

    Args:
        mu (torch.Tensor): Input mean.
        logvar (torch.Tensor): Input log(variance).

    Returns:
        torch.Tensor: Sampled tensor from a normal distribution
            based on the input mean and variance.
    """
    assert mu.ndim == 2
    assert logvar.ndim == 2
    # Reparameterization trick: z = mu + sigma * epsilon with epsilon ~ N(0, 1).
    epsilon = torch.normal(mean=0.0, std=1.0, size=mu.shape).detach()
    return mu + torch.exp(0.5 * logvar) * epsilon
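

# Illustrative sketch (not part of the original module): a quick check that
# ``sample`` draws from N(mu, exp(logvar)). The batch shape and the target
# variance of 4.0 below are arbitrary example values.
def _demo_sample() -> None:
    mu = torch.zeros(4, 3)
    logvar = torch.log(torch.full((4, 3), 4.0))  # variance 4.0 -> std 2.0
    draws = torch.stack([sample(mu, logvar) for _ in range(10_000)])
    # Empirical mean should be close to 0 and empirical std close to 2.
    print(draws.mean(dim=0))
    print(draws.std(dim=0))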
def neg_log_likelihood(y: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Negative log likelihood.

    Args:
        y (torch.Tensor): Label (integer).
        mu (torch.Tensor): Input mean tensor.
        logvar (torch.Tensor): Input log(variance) tensor.

    Returns:
        torch.Tensor: Negative log likelihood loss w.r.t. the label.
    """
    # Element-wise Gaussian negative log likelihood, up to a factor of 1/2 and
    # an additive constant.
    var = torch.exp(logvar)
    loss = logvar + torch.pow(y - mu, 2) / var
    return loss
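

# Illustrative check (not part of the original module): ``neg_log_likelihood``
# equals the unreduced ``torch.nn.GaussianNLLLoss`` scaled by 2, since the
# latter carries the usual factor of 1/2. The tensors below are arbitrary
# example values.
def _demo_neg_log_likelihood() -> None:
    y = torch.tensor([[1.0], [2.0]])
    mu = torch.tensor([[0.5], [2.5]])
    logvar = torch.tensor([[0.0], [0.2]])
    ours = neg_log_likelihood(y, mu, logvar)
    reference = 2.0 * nn.GaussianNLLLoss(reduction="none")(mu, y, torch.exp(logvar))
    print(torch.allclose(ours, reference, atol=1e-5))  # expected: True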
class MVE(Base):
    """Implementation of Mean and Variance Estimation:
    https://doi.org/10.1109/ICNN.1994.374138.

    This implementation uses cross entropy loss for classification models and
    negative log likelihood loss for regression models.

    Args:
        model (ripple.model.RippleModel): Base Ripple model
    """

    mu: nn.Module
    logvar: nn.Module
    def train_forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get MVE loss for input data and label.

        Args:
            x (torch.Tensor): Input training data
            y (torch.Tensor): Ground truth label

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: MVE loss, prediction for input data.
        """
        y_hat, (mu, logvar) = self.forward(x)
        loss = self.get_loss(y, mu, logvar)
        return loss, y_hat
    def get_loss(self, y: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Implementation of MVE loss.

        Args:
            y (torch.Tensor): Ground truth label
            mu (torch.Tensor): Mean tensor
            logvar (torch.Tensor): Log(variance) tensor

        Returns:
            torch.Tensor: MVE loss
        """
        if self.is_classification:
            # Draw stochastic logits from the predicted mean/variance and score
            # them with standard cross entropy.
            z = sample(mu, logvar)
            loss = nn.CrossEntropyLoss()(z, y.long())
        else:
            # Regression: element-wise Gaussian negative log likelihood.
            loss = neg_log_likelihood(y, mu, logvar)
        return loss
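

# Illustrative sketches (not part of the original module). ``_demo_train_step``
# shows one optimisation step driven by ``train_forward``; the ``mve``,
# ``optimizer``, ``x`` and ``y`` arguments are assumed to be an already
# constructed MVE instance, an optimiser over its parameters, and one training
# batch; construction of the underlying Ripple model is out of scope here.
def _demo_train_step(
    mve: MVE,
    optimizer: torch.optim.Optimizer,
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    optimizer.zero_grad()
    loss, _ = mve.train_forward(x, y)
    # The regression branch returns an element-wise loss, so reduce it before
    # backpropagation; the classification branch is already a scalar and
    # ``mean()`` leaves it unchanged.
    loss = loss.mean()
    loss.backward()
    optimizer.step()
    return loss.detach()


# Illustrative sketch (not part of the original module): the classification
# branch of ``get_loss`` in isolation, using stand-alone tensors instead of an
# MVE instance. The batch size of 8 and 5 classes are arbitrary example values.
def _demo_classification_loss() -> None:
    mu = torch.randn(8, 5)                 # per-class logit means
    logvar = torch.zeros(8, 5)             # unit variance for every logit
    y = torch.randint(0, 5, (8,))
    z = sample(mu, logvar)                 # stochastic logits
    loss = nn.CrossEntropyLoss()(z, y.long())
    print(loss.item())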