99 lines
3.7 KiB
Python

"""MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619)."""
from tml.projects.home.recap.model import config, mlp
import torch
def _init_weights(module):
    """Initialize a ``torch.nn.Linear`` layer in place; ignore every other module.

    Meant to be handed to ``torch.nn.Module.apply``, which invokes it on each
    submodule. Linear weights receive Xavier-uniform initialization and the
    bias, when the layer has one, is set to zero.

    Args:
        module: Any module; only ``torch.nn.Linear`` instances are modified.
    """
    if not isinstance(module, torch.nn.Linear):
        return
    torch.nn.init.xavier_uniform_(module.weight)
    # ``Linear(..., bias=False)`` keeps ``bias`` as None, so there is nothing to zero.
    if module.bias is not None:
        torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module):
    """One MaskNet block: a learned element-wise mask followed by a linear layer.

    The mask is produced from ``mask_input`` by a two-layer network
    (Linear -> ReLU -> Linear) whose output has ``input_dim`` features. It is
    multiplied element-wise with the (optionally layer-normalized) ``net``
    tensor; the product goes through a linear layer to
    ``mask_block_config.output_size`` features and a final LayerNorm.
    """

    def __init__(
        self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
    ) -> None:
        """Build the block's layers.

        Args:
            mask_block_config: Block settings; this constructor reads
                ``output_size``, ``input_layer_norm``, ``reduction_factor``
                and ``aggregation_size``.
            input_dim: Feature size of the ``net`` tensor given to ``forward``.
            mask_input_dim: Feature size of the ``mask_input`` tensor.

        Raises:
            ValueError: If ``reduction_factor`` is falsy and
                ``aggregation_size`` is None.
        """
        super().__init__()
        self.mask_block_config = mask_block_config
        out_dim = mask_block_config.output_size

        # Optional normalization applied to ``net`` before masking.
        self._input_layer_norm = (
            torch.nn.LayerNorm(input_dim) if mask_block_config.input_layer_norm else None
        )

        # Hidden width of the mask network: a multiple of the mask input size
        # when a reduction factor is set, otherwise the explicit size.
        factor = mask_block_config.reduction_factor
        if factor:
            bottleneck_dim = int(mask_input_dim * factor)
        else:
            bottleneck_dim = mask_block_config.aggregation_size
            if bottleneck_dim is None:
                raise ValueError("Need one of reduction factor or aggregation size.")

        to_bottleneck = torch.nn.Linear(mask_input_dim, bottleneck_dim)
        to_mask = torch.nn.Linear(bottleneck_dim, input_dim)
        self._mask_layer = torch.nn.Sequential(to_bottleneck, torch.nn.ReLU(), to_mask)
        self._mask_layer.apply(_init_weights)

        self._hidden_layer = torch.nn.Linear(input_dim, out_dim)
        self._hidden_layer.apply(_init_weights)

        self._layer_norm = torch.nn.LayerNorm(out_dim)

    def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
        """Apply the mask to ``net`` and project the result.

        Args:
            net: Tensor whose last dimension is ``input_dim``.
            mask_input: Tensor whose last dimension is ``mask_input_dim``.

        Returns:
            The layer-normalized output of the hidden linear layer.
        """
        normalizer = self._input_layer_norm
        if normalizer is not None:
            net = normalizer(net)
        mask = self._mask_layer(mask_input)
        projected = self._hidden_layer(net * mask)
        return self._layer_norm(projected)
class MaskNet(torch.nn.Module):
    """Stack of ``MaskBlock`` modules, optionally followed by an MLP.

    With ``mask_net_config.use_parallel`` every block receives the raw inputs
    and the block outputs are concatenated along dimension 1. Otherwise the
    blocks are chained: each block consumes the previous block's output while
    still deriving its mask from the raw inputs.

    Attributes set by the constructor:
        out_features: Size of the ``"output"`` tensor returned by ``forward``.
        shared_size: Size of the ``"shared_layer"`` tensor returned by ``forward``.
    """

    def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
        """Create the mask blocks and the optional MLP head.

        Args:
            mask_net_config: Network settings; this constructor reads
                ``use_parallel``, ``mask_blocks`` and ``mlp``.
            in_features: Feature size of the tensor passed to ``forward``.
        """
        super().__init__()
        self.mask_net_config = mask_net_config

        mask_blocks = []
        if mask_net_config.use_parallel:
            # Parallel blocks all see the raw input; their outputs are concatenated.
            total_output_mask_blocks = 0
            for mask_block_config in mask_net_config.mask_blocks:
                mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
                total_output_mask_blocks += mask_block_config.output_size
        else:
            # Chained blocks: each block's input size is the previous block's output size.
            input_size = in_features
            for mask_block_config in mask_net_config.mask_blocks:
                mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
                input_size = mask_block_config.output_size
            # ``input_size`` now holds the last block's output size (or
            # ``in_features`` when there are no blocks), so the loop variable
            # is not read after the loop.
            total_output_mask_blocks = input_size
        self._mask_blocks = torch.nn.ModuleList(mask_blocks)

        # Same ``None`` test as ``forward`` so both agree on whether the MLP exists.
        if mask_net_config.mlp is not None:
            self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
            self.out_features = mask_net_config.mlp.layer_sizes[-1]
        else:
            self.out_features = total_output_mask_blocks
        self.shared_size = total_output_mask_blocks

    def forward(self, inputs: torch.Tensor):
        """Run the mask blocks and, if configured, the MLP.

        Args:
            inputs: Input tensor; assumed to be 2-D ``(batch, in_features)``
                since parallel outputs are concatenated on dim 1 (TODO confirm).

        Returns:
            A dict with ``"shared_layer"`` (the combined mask-block output)
            and ``"output"`` (the MLP's ``"output"`` entry when an MLP is
            configured, otherwise the same tensor as ``"shared_layer"``).
        """
        if self.mask_net_config.use_parallel:
            # Share the outputs of the MaskBlocks.
            shared = torch.cat(
                [block(net=inputs, mask_input=inputs) for block in self._mask_blocks],
                dim=1,
            )
        else:
            # Share the output of the stacked MaskBlocks.
            shared = inputs
            for block in self._mask_blocks:
                shared = block(net=shared, mask_input=inputs)

        if self.mask_net_config.mlp is None:
            output = shared
        else:
            # The MLP is a module and must be called, not subscripted.
            output = self._dense_layers(shared)["output"]
        return {"output": output, "shared_layer": shared}