Source code for ripple.base
from copy import deepcopy
import torch
import torch.nn as nn
from .constants import FEATURE_EXTRACTOR, OUTPUT_LAYER
from .model import RippleModel
class Base(nn.Module):
    """Base class for Ripple modules.

    Wraps a :class:`RippleModel` and provides helpers to copy its
    sub-modules and to run its feature extractor / output layer with an
    explicitly chosen train/eval mode.
    """

    def __init__(self, model: RippleModel) -> None:
        """Base Class for Ripple modules.

        Args:
            model (RippleModel): Ripple Model
        """
        super().__init__()
        self.model = model

    def copy_layer(self, name: str) -> nn.Module:
        """Provides a copy of a particular part of the model.

        Args:
            name (Literal[FEATURE_EXTRACTOR, OUTPUT_LAYER]): Supported
                layers that can be copied. ``FEATURE_EXTRACTOR`` copies
                ``model.feature_extractor``; ``OUTPUT_LAYER`` copies
                ``model.output_layer[0]`` (the first element only, not the
                whole ``output_layer`` container).

        Raises:
            NotImplementedError: If the input argument is not
                in the supported choices.
            TypeError: If the selected attribute is not an ``nn.Module``.

        Returns:
            nn.Module: An independent deep copy of the requested module;
            its parameters and buffers are not shared with the original.
        """
        # Explicit raise instead of a leading ``assert``: the old assert
        # fired first with AssertionError (so the documented
        # NotImplementedError branch was unreachable) and was stripped
        # entirely under ``python -O``.
        if name == FEATURE_EXTRACTOR:
            layer = self.model.feature_extractor
        elif name == OUTPUT_LAYER:
            layer = self.model.output_layer[0]
        else:
            raise NotImplementedError("Unsupported layer name")
        if not isinstance(layer, nn.Module):
            raise TypeError(
                f"Expected nn.Module, received {type(layer)} for Name: {name} Model:"
                f" {self.model}"
            )
        # deepcopy already duplicates every parameter and buffer, so the
        # previous state_dict copy + load_state_dict round-trip only copied
        # the same tensors a second time.
        return deepcopy(layer)

    @property
    def is_classification(self):
        """Whether the wrapped model is a classification model.

        Delegates directly to ``self.model.is_classification``.
        """
        return self.model.is_classification

    def get_features(self, input_tensor: torch.Tensor, training: bool) -> torch.Tensor:
        """Run only the feature extractor on ``input_tensor``.

        Args:
            input_tensor: Input passed directly to ``model.feature_extractor``.
            training: Mode applied to the feature extractor before the call
                (``True`` -> train mode, ``False`` -> eval mode). The mode
                is not restored afterwards.

        Returns:
            torch.Tensor: Output of ``model.feature_extractor``.
        """
        self.model.feature_extractor.train(training)
        return self.model.feature_extractor(input_tensor)

    def get_output(self, input_tensor: torch.Tensor, training: bool) -> torch.Tensor:
        """Run the full wrapped model on ``input_tensor``.

        Args:
            input_tensor: Input passed to ``self.model``'s forward.
            training: Mode applied to both ``model.feature_extractor`` and
                ``model.output_layer`` before the call. The mode is not
                restored afterwards.

        Returns:
            torch.Tensor: Output of ``self.model(input_tensor)``.
        """
        # NOTE(review): only these two sub-modules have their mode set; any
        # other sub-modules of ``self.model`` keep whatever mode they had.
        self.model.feature_extractor.train(training)
        self.model.output_layer.train(training)
        return self.model(input_tensor)

    def get_output_from_features(
        self, input_features: torch.Tensor, training: bool
    ) -> torch.Tensor:
        """Run only the output layer on pre-computed features.

        Args:
            input_features: Features fed directly to ``model.output_layer``
                (presumably produced by ``get_features`` — confirm at call
                sites).
            training: Mode applied to ``model.output_layer`` before the
                call. The mode is not restored afterwards.

        Returns:
            torch.Tensor: Output of ``model.output_layer``.
        """
        self.model.output_layer.train(training)
        return self.model.output_layer(input_features)