Source code for ripple.models.basic

from typing import List

import torch
import torch.nn as nn


class SimpleFF(nn.Module):
    """Simple feedforward network built from a list of layer widths.

    Every entry of ``layer_config`` except the last adds a ``Linear`` layer
    followed by a ``ReLU``; the last entry adds a single ``Linear`` layer
    with no activation after it.

    Attributes:
        layer_config: The list of layer widths passed to the constructor.
        input_dim: The input width passed to the constructor.
        layers: Plain Python list of every layer, in forward order.
        module: ``nn.Sequential`` over all of ``layers``; used by ``forward``.
        feature_extractor: ``nn.Sequential`` of the hidden ``Linear``/``ReLU``
            pairs (empty when ``layer_config`` has a single entry).
        output_layer: ``nn.Sequential`` holding only the final ``Linear``.
    """

    def __init__(self, input_dim: int, layer_config: List[int]) -> None:
        """Build the network.

        Args:
            input_dim: Size of the last dimension of the input tensor.
            layer_config: Output widths of the successive ``Linear`` layers.
                Must contain at least one entry.

        Raises:
            AssertionError: If ``layer_config`` is empty.
        """
        super().__init__()
        assert len(layer_config) > 0, "layer_config must contain at least one layer size"

        self.layer_config = layer_config
        self.input_dim = input_dim

        # Width of the tensor entering the next Linear layer. Tracking it
        # explicitly (rather than indexing with a loop variable) keeps the
        # output layer well-defined when there are no hidden layers.
        in_features = input_dim

        hidden_layers = []
        for hidden_dim in layer_config[:-1]:
            hidden_layers.append(nn.Linear(in_features, hidden_dim))
            hidden_layers.append(nn.ReLU())
            in_features = hidden_dim

        # Final projection; no activation is applied after it.
        final_layers = [nn.Linear(in_features, layer_config[-1])]

        self.layers = hidden_layers + final_layers
        # The three Sequential containers share the same layer objects, so
        # their parameters are the same tensors, not copies.
        self.module = nn.Sequential(*self.layers)
        self.feature_extractor = nn.Sequential(*hidden_layers)
        self.output_layer = nn.Sequential(*final_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer in order and return the result."""
        return self.module(x)