2023-03-31 20:05:14 +02:00
|
|
|
"""MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619)."""
|
|
|
|
|
|
|
|
from tml.projects.home.recap.model import config, mlp
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weights(module):
|
2023-09-14 08:00:10 +02:00
|
|
|
"""Initializes weights
|
|
|
|
|
|
|
|
Example
|
|
|
|
|
|
|
|
```python
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
# Define a simple linear layer
|
|
|
|
linear_layer = nn.Linear(64, 32)
|
|
|
|
|
|
|
|
# Initialize the weights and biases using _init_weights
|
|
|
|
_init_weights(linear_layer)
|
|
|
|
```
|
|
|
|
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
if isinstance(module, torch.nn.Linear):
|
|
|
|
torch.nn.init.xavier_uniform_(module.weight)
|
|
|
|
torch.nn.init.constant_(module.bias, 0)
|
|
|
|
|
|
|
|
|
|
|
|
class MaskBlock(torch.nn.Module):
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
MaskBlock module in a mask-based neural network.
|
|
|
|
|
|
|
|
This module represents a MaskBlock, which applies a masking operation to the input data and then
|
|
|
|
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
|
|
|
input_dim (int): Dimensionality of the input data.
|
|
|
|
mask_input_dim (int): Dimensionality of the mask input.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
To create and use a MaskBlock within a MaskNet, follow these steps:
|
|
|
|
|
|
|
|
```python
|
|
|
|
# Define the configuration for the MaskBlock
|
|
|
|
mask_block_config = MaskBlockConfig(
|
|
|
|
input_layer_norm=True, # Apply input layer normalization
|
|
|
|
reduction_factor=0.5 # Reduce input dimensionality by 50%
|
|
|
|
)
|
|
|
|
|
|
|
|
# Create an instance of the MaskBlock
|
|
|
|
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
|
|
|
|
|
|
|
|
# Generate input tensors
|
|
|
|
input_data = torch.randn(batch_size, 64)
|
|
|
|
mask_input = torch.randn(batch_size, 32)
|
|
|
|
|
|
|
|
# Perform a forward pass through the MaskBlock
|
|
|
|
output = mask_block(input_data, mask_input)
|
|
|
|
```
|
|
|
|
|
|
|
|
Note:
|
|
|
|
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
|
|
|
|
operation that combines the input and mask input. Then, it passes the result through a hidden layer
|
|
|
|
with optional dimensionality reduction.
|
|
|
|
|
|
|
|
Warning:
|
|
|
|
This class is intended for internal use within neural network architectures and should not be
|
|
|
|
directly accessed or modified by external code.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
def __init__(
|
|
|
|
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
|
|
|
) -> None:
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
Initializes the MaskBlock module.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
|
|
|
input_dim (int): Dimensionality of the input data.
|
|
|
|
mask_input_dim (int): Dimensionality of the mask input.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
None
|
|
|
|
"""
|
|
|
|
|
2023-03-31 20:05:14 +02:00
|
|
|
super(MaskBlock, self).__init__()
|
|
|
|
self.mask_block_config = mask_block_config
|
|
|
|
output_size = mask_block_config.output_size
|
|
|
|
|
|
|
|
if mask_block_config.input_layer_norm:
|
|
|
|
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
|
|
|
|
else:
|
|
|
|
self._input_layer_norm = None
|
|
|
|
|
|
|
|
if mask_block_config.reduction_factor:
|
|
|
|
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
|
|
|
|
elif mask_block_config.aggregation_size is not None:
|
|
|
|
aggregation_size = mask_block_config.aggregation_size
|
|
|
|
else:
|
|
|
|
raise ValueError("Need one of reduction factor or aggregation size.")
|
|
|
|
|
|
|
|
self._mask_layer = torch.nn.Sequential(
|
|
|
|
torch.nn.Linear(mask_input_dim, aggregation_size),
|
|
|
|
torch.nn.ReLU(),
|
|
|
|
torch.nn.Linear(aggregation_size, input_dim),
|
|
|
|
)
|
|
|
|
self._mask_layer.apply(_init_weights)
|
|
|
|
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
|
|
|
|
self._hidden_layer.apply(_init_weights)
|
|
|
|
self._layer_norm = torch.nn.LayerNorm(output_size)
|
|
|
|
|
|
|
|
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
Performs a forward pass through the MaskBlock.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
net (torch.Tensor): Input data tensor.
|
|
|
|
mask_input (torch.Tensor): Mask input tensor.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
torch.Tensor: Output tensor of the MaskBlock.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
if self._input_layer_norm:
|
|
|
|
net = self._input_layer_norm(net)
|
|
|
|
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
|
|
|
return self._layer_norm(hidden_layer_output)
|
|
|
|
|
|
|
|
|
|
|
|
class MaskNet(torch.nn.Module):
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
MaskNet module in a mask-based neural network.
|
|
|
|
|
|
|
|
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
|
|
|
|
create mask-based neural networks with parallel or stacked MaskBlocks.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
|
|
|
in_features (int): Dimensionality of the input data.
|
|
|
|
|
|
|
|
Example:
|
|
|
|
To create and use a MaskNet, you can follow these steps:
|
|
|
|
|
|
|
|
```python
|
|
|
|
# Define the configuration for the MaskNet
|
|
|
|
mask_net_config = MaskNetConfig(
|
|
|
|
use_parallel=True, # Use parallel MaskBlocks
|
|
|
|
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
|
|
|
|
)
|
|
|
|
|
|
|
|
# Create an instance of the MaskNet
|
|
|
|
mask_net = MaskNet(mask_net_config, in_features=64)
|
|
|
|
|
|
|
|
# Generate input tensors
|
|
|
|
input_data = torch.randn(batch_size, 64)
|
|
|
|
|
|
|
|
# Perform a forward pass through the MaskNet
|
|
|
|
outputs = mask_net(input_data)
|
|
|
|
|
|
|
|
# Access the output and shared layer
|
|
|
|
output = outputs["output"]
|
|
|
|
shared_layer = outputs["shared_layer"]
|
|
|
|
```
|
|
|
|
|
|
|
|
Note:
|
|
|
|
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
|
|
|
|
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
|
|
|
|
|
|
|
|
Warning:
|
|
|
|
This class is intended for internal use within neural network architectures and should not be
|
|
|
|
directly accessed or modified by external code.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
Initializes the MaskNet module.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
|
|
|
in_features (int): Dimensionality of the input data.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
None
|
|
|
|
"""
|
|
|
|
|
2023-03-31 20:05:14 +02:00
|
|
|
super().__init__()
|
|
|
|
self.mask_net_config = mask_net_config
|
|
|
|
mask_blocks = []
|
|
|
|
|
|
|
|
if mask_net_config.use_parallel:
|
|
|
|
total_output_mask_blocks = 0
|
|
|
|
for mask_block_config in mask_net_config.mask_blocks:
|
|
|
|
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
|
|
|
|
total_output_mask_blocks += mask_block_config.output_size
|
|
|
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
|
|
|
else:
|
|
|
|
input_size = in_features
|
|
|
|
for mask_block_config in mask_net_config.mask_blocks:
|
|
|
|
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
|
|
|
|
input_size = mask_block_config.output_size
|
|
|
|
|
|
|
|
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
|
|
|
total_output_mask_blocks = mask_block_config.output_size
|
|
|
|
|
|
|
|
if mask_net_config.mlp:
|
|
|
|
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
|
|
|
|
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
|
|
|
else:
|
|
|
|
self.out_features = total_output_mask_blocks
|
|
|
|
self.shared_size = total_output_mask_blocks
|
|
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor):
|
2023-09-14 08:00:10 +02:00
|
|
|
"""
|
|
|
|
Performs a forward pass through the MaskNet.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
inputs (torch.Tensor): Input data tensor.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
torch.Tensor: Output tensor of the MaskNet.
|
|
|
|
"""
|
2023-03-31 20:05:14 +02:00
|
|
|
if self.mask_net_config.use_parallel:
|
|
|
|
mask_outputs = []
|
|
|
|
for mask_layer in self._mask_blocks:
|
|
|
|
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
|
|
|
|
# Share the outputs of the MaskBlocks.
|
|
|
|
all_mask_outputs = torch.cat(mask_outputs, dim=1)
|
|
|
|
output = (
|
|
|
|
all_mask_outputs
|
|
|
|
if self.mask_net_config.mlp is None
|
|
|
|
else self._dense_layers(all_mask_outputs)["output"]
|
|
|
|
)
|
|
|
|
return {"output": output, "shared_layer": all_mask_outputs}
|
|
|
|
else:
|
|
|
|
net = inputs
|
|
|
|
for mask_layer in self._mask_blocks:
|
|
|
|
net = mask_layer(net=net, mask_input=inputs)
|
|
|
|
# Share the output of the stacked MaskBlocks.
|
|
|
|
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
|
|
|
|
return {"output": output, "shared_layer": net}
|