Source code for ripple.model

from typing import Tuple

import torch
import torch.nn as nn


[docs]def get_activation(activation: str) -> nn.Module: """Return's the torch activation module. Args: activation (str): Choices: ["relu", "softmax", "identity"] Returns: nn.Module: The corresponding activation module """ activation = activation.lower() if activation == 'relu': return torch.nn.ReLU() elif activation == 'softmax': return torch.nn.Softmax() elif activation == 'identity': return torch.nn.Identity() else: raise AssertionError(f"Unsupported activation name: {activation}")
[docs]class Conv1dActivationDropout(nn.Module):
[docs] def __init__(self, conv: torch.nn.Conv1d, activation: str, dropout_rate: float) \ -> None: super().__init__() self.conv = conv self.activation = get_activation(activation) self.dropout = torch.nn.Dropout(dropout_rate)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.activation(self.conv(x)))
[docs]class Conv2dActivationDropout(nn.Module):
[docs] def __init__(self, conv: torch.nn.Conv1d, activation: str, dropout_rate: float) \ -> None: super().__init__() self.conv = conv self.activation = get_activation(activation) self.dropout = torch.nn.Dropout2d(dropout_rate)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.activation(self.conv(x)))
[docs]class Conv3dActivationDropout(nn.Module):
[docs] def __init__(self, conv: torch.nn.Conv1d, activation: str, dropout_rate: float) \ -> None: super().__init__() self.conv = conv self.activation = get_activation(activation) self.dropout = torch.nn.Dropout3d(dropout_rate)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.activation(self.conv(x)))
[docs]class LinearActivationDropout(nn.Module):
[docs] def __init__(self, linear: torch.nn.Linear, activation: str, dropout_rate: float) \ -> None: super().__init__() self.linear = linear self.activation = get_activation(activation) self.dropout = torch.nn.Dropout(dropout_rate)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.activation(self.linear(x)))
[docs]def replace_module(module: nn.Module, activation: str, dropout_rate: float) \ -> torch.nn.Module: """Recursively replaces every module with module + dropout inplace. Source: https://github.com/vballoli/nfnets-pytorch/blob/main/nfnets/utils.py#LL9C1-L28C38 Usage: replace_conv(model) #(In-line replacement) Args: module (nn.Module): target's model whose convolutions must be replaced. dropout_rate (float): rate of dropout for that module Returns: nn.Module: new module with added dropout layers """ if isinstance(module, torch.nn.Conv1d): new_module = Conv1dActivationDropout(module, activation, dropout_rate) return new_module elif isinstance(module, torch.nn.Conv2d): new_module = Conv2dActivationDropout(module, activation, dropout_rate) return new_module elif isinstance(module, torch.nn.Conv3d): new_module = Conv3dActivationDropout(module, activation, dropout_rate) return new_module elif isinstance(module, torch.nn.Linear): new_module = LinearActivationDropout(module, activation, dropout_rate) return new_module for name, mod in module.named_children(): if isinstance(mod, torch.nn.Dropout): new_module = torch.nn.Identity() if isinstance(mod, (torch.nn.ReLU, torch.nn.Sigmoid, torch.nn.Softmax)): new_module = torch.nn.Identity() named_children = list(module.named_children()) for name, mod in named_children: if isinstance(mod, torch.nn.Conv1d): new_module = Conv1dActivationDropout( mod, activation, dropout_rate ) elif isinstance(mod, torch.nn.Conv2d): new_module = Conv2dActivationDropout( mod, activation, dropout_rate ) elif isinstance(mod, torch.nn.Conv3d): new_module = Conv3dActivationDropout( mod, activation, dropout_rate ) elif isinstance(mod, torch.nn.Linear): new_module = LinearActivationDropout( mod, activation, dropout_rate ) elif isinstance(mod, torch.nn.Dropout): new_module = torch.nn.Identity() elif isinstance(mod, (Conv1dActivationDropout, Conv2dActivationDropout, Conv3dActivationDropout)): new_module = mod else: new_module = mod setattr(module, name, new_module) for name, mod in named_children: replace_module(mod, activation, dropout_rate) return module
[docs]class RippleModel(nn.Module): """Ripple Model wrapper. This class enables wrapping generic PyTorch models to be compatible \ with Ripple's functionalities. Args: feature_extractor (nn.Module): Feature extractor of the model output_layer (nn.Module): Final output module of the model input_shape (Tuple): Shape of the input tensor is_classification (bool): If this model is a classification model Raises: AssertionError: Multiple dimension outputs are not supported """ feature_extractor: nn.Module output_layer: nn.Module input_shape: Tuple feature_shape: Tuple output_dim: int is_classification: bool
[docs] def __init__( self, feature_extractor: nn.Module, output_layer: nn.Module, input_shape: Tuple, is_classification: bool, ) -> None: super().__init__() self.feature_extractor = feature_extractor self.output_layer = output_layer self.input_shape = input_shape self.is_classification = is_classification test_feature = self.feature_extractor(torch.randn((1, *self.input_shape))) self.feature_shape = test_feature.shape[1:] output_dim = self.output_layer(test_feature).shape[1:] if len(output_dim) > 1: raise AssertionError("More than one dim not supported") self.output_dim = output_dim
[docs] def forward(self, x: torch.Tensor, training: bool=True) -> torch.Tensor: self.feature_extractor.train(training) self.output_layer.train(training) return self.output_layer(self.feature_extractor(x))
def replace_modules(self, feature_dropout_rate: float, feature_activation: str, output_dropout_rate: float, output_activation: str) -> None: self.feature_extractor = replace_module( self.feature_extractor, feature_activation, feature_dropout_rate ) self.output_layer = replace_module( self.output_layer, output_activation, output_dropout_rate )