"""MLP feed forward stack in torch."""

from typing import Dict

import torch
from absl import logging

from tml.projects.home.recap.model.config import MlpConfig


def _init_weights(module):
  """Initializes module weights.

  Applies Xavier (Glorot) uniform initialization to the weights of
  `torch.nn.Linear` modules and zeroes their biases; other module types are
  left untouched, so this function can be safely passed to `Module.apply`
  over a whole stack.

  Args:
    module (torch.nn.Module): The module whose parameters are initialized.

  Example
  -------
  ```python
  import torch
  import torch.nn as nn

  # Define a simple linear layer
  linear_layer = nn.Linear(64, 32)

  # Initialize the weights and biases using _init_weights
  _init_weights(linear_layer)
  ```
  """
  if isinstance(module, torch.nn.Linear):
    torch.nn.init.xavier_uniform_(module.weight)
    torch.nn.init.constant_(module.bias, 0)
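
# Illustrative check of what `_init_weights` does (kept as a comment so it
# does not run at import time): Xavier uniform draws Linear weights from
# U(-a, a) with a = sqrt(6 / (fan_in + fan_out)), and biases are zeroed.
#
#   layer = torch.nn.Linear(64, 32)
#   _init_weights(layer)
#   bound = (6.0 / (64 + 32)) ** 0.5  # = 0.25
#   assert layer.weight.abs().max() <= bound
#   assert (layer.bias == 0).all()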


class Mlp(torch.nn.Module):
  """Multi-Layer Perceptron (MLP) feed-forward neural network module in PyTorch.

  This module builds an MLP from configurable layer sizes, with optional batch
  normalization and dropout after each hidden layer and an optional ReLU on
  the final layer. It is suitable for applications such as deep learning on
  tabular data and feature extraction.

  Args:
    in_features (int): The number of input features (input dimensionality).
    mlp_config (MlpConfig): Configuration object specifying the MLP's
      architecture.

  Example:
    To create an instance of the `Mlp` module and run a forward pass:

    ```python
    # Define the configuration for the MLP. `batch_norm` and `dropout` are
    # config sub-objects rather than plain values: this module reads
    # `batch_norm.affine`, `batch_norm.momentum`, and `dropout.rate`. The
    # class names below are illustrative; use the ones from the config module.
    mlp_config = MlpConfig(
      layer_sizes=[128, 64],        # Hidden and output layer sizes
      batch_norm=BatchNormConfig(affine=True, momentum=0.1),
      dropout=DropoutConfig(rate=0.2),
      final_layer_activation=True,  # Apply ReLU to the final layer
    )

    # Create an instance of the MLP module
    input_dim = 256
    batch_size = 32
    mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)

    # Generate an input tensor and perform a forward pass
    input_tensor = torch.randn(batch_size, input_dim)
    outputs = mlp_model(input_tensor)

    # `forward` returns a dict: the final output and a shared early activation
    output = outputs["output"]
    shared_layer = outputs["shared_layer"]
    ```

  Note:
    `forward` returns a dict rather than a bare tensor: "output" holds the
    final activations and "shared_layer" holds an early intermediate
    activation that other model components can reuse.

  Warning:
    This class is intended for internal use within neural network
    architectures and should not be directly accessed or modified by
    external code.
  """

  def __init__(self, in_features: int, mlp_config: MlpConfig):
    """Initializes the Mlp module.

    Args:
      in_features (int): The number of input features (input dimensionality).
      mlp_config (MlpConfig): Configuration object specifying the MLP's
        architecture.
    """
    super().__init__()
    self._mlp_config = mlp_config
    input_size = in_features
    layer_sizes = mlp_config.layer_sizes
    modules = []

    # Hidden layers: Linear -> optional BatchNorm1d -> ReLU -> optional Dropout.
    for layer_size in layer_sizes[:-1]:
      modules.append(torch.nn.Linear(input_size, layer_size, bias=True))

      if mlp_config.batch_norm:
        modules.append(
          torch.nn.BatchNorm1d(
            layer_size,
            affine=mlp_config.batch_norm.affine,
            momentum=mlp_config.batch_norm.momentum,
          )
        )

      modules.append(torch.nn.ReLU())

      if mlp_config.dropout:
        modules.append(torch.nn.Dropout(mlp_config.dropout.rate))

      input_size = layer_size

    # Final projection, with an optional ReLU on top.
    modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True))
    if mlp_config.final_layer_activation:
      modules.append(torch.nn.ReLU())

    self.layers = torch.nn.ModuleList(modules)
    self.layers.apply(_init_weights)
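
    # For intuition (illustrative values): layer_sizes=[128, 64] with batch
    # norm and dropout enabled produces
    #   [Linear(in, 128), BatchNorm1d(128), ReLU(), Dropout(rate),
    #    Linear(128, 64)]
    # plus a trailing ReLU() when final_layer_activation is set.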

  def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Performs a forward pass through the MLP.

    Args:
      x (torch.Tensor): Input tensor of shape (batch_size, in_features).

    Returns:
      Dict[str, torch.Tensor]: "output" holds the final activations and
      "shared_layer" holds an early intermediate activation intended for
      reuse by other model components.
    """
    net = x
    shared_layer = None  # Stays None if the stack has fewer than two modules.
    for i, layer in enumerate(self.layers):
      net = layer(net)
      if i == 1:  # Share the first (widest?) set of activations for other applications.
        shared_layer = net
    return {"output": net, "shared_layer": shared_layer}

  @property
  def shared_size(self):
    """Returns the size of the shared layer in the MLP.

    Returns:
      int: Size of the shared layer.
    """
    return self._mlp_config.layer_sizes[-1]

  @property
  def out_features(self):
    """Returns the number of output features from the MLP.

    Returns:
      int: Number of output features.
    """
    return self._mlp_config.layer_sizes[-1]
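

if __name__ == "__main__":
  # Minimal smoke test (a sketch, not part of the module's API). Real
  # `MlpConfig` objects come from tml.projects.home.recap.model.config; the
  # SimpleNamespace below is an assumed stand-in that exposes only the
  # attributes this module actually reads: layer_sizes, batch_norm, dropout,
  # and final_layer_activation.
  from types import SimpleNamespace

  cfg = SimpleNamespace(
    layer_sizes=[64, 32],          # one hidden layer of 64 units, 32 outputs
    batch_norm=None,               # disabled; otherwise needs .affine/.momentum
    dropout=None,                  # disabled; otherwise needs .rate
    final_layer_activation=False,  # no ReLU after the final Linear
  )
  model = Mlp(in_features=16, mlp_config=cfg)

  x = torch.randn(8, 16)
  outputs = model(x)
  # "output" is the final activation; "shared_layer" is the activation after
  # the second module (here Linear -> ReLU, i.e. the first hidden layer).
  print(outputs["output"].shape)        # torch.Size([8, 32])
  print(outputs["shared_layer"].shape)  # torch.Size([8, 64])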