rajveer43 2023-09-14 11:30:10 +05:30
parent db4ff958f6
commit 0813989fd9
14 changed files with 765 additions and 27 deletions

View File

@ -9,6 +9,36 @@ import pydantic
class TrainingConfig(config_mod.BaseConfig):
"""
Configuration settings for the training process.
This class defines various training-related settings, including the directory to save checkpoints, the number
of training steps, logging intervals, and other training parameters.
Attributes:
save_dir (str): The directory where checkpoints and training artifacts will be saved.
num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
Example:
To configure training with checkpoints saved every 1000 steps, use the following settings:
```python
TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
)
```
"""
save_dir: str = "/tmp/model"
num_train_steps: pydantic.PositiveInt = 1000000
initial_checkpoint_dir: str = pydantic.Field(
@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):
class RecapConfig(config_mod.BaseConfig):
"""
Configuration settings for the Recap model training process.
This class defines the overall configuration for the training process of a Recap model. It includes settings for
training, model architecture, data, optimization, and evaluation.
Attributes:
training (TrainingConfig): Configuration settings for the training process.
model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
train_data (data_config.RecapDataConfig): Configuration settings for training data.
validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
which_metrics (Optional[str]): Optional specification of which metrics to pick.
Note:
This class encapsulates all the necessary configurations to train a Recap model. It defines settings for
training, the model architecture, data loading, optimization, and evaluation.
Example:
To configure a Recap model training process, use the following settings:
```python
RecapConfig(
training=TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
),
model=model_config.ModelConfig(...),
train_data=data_config.RecapDataConfig(...),
validation_data={"dev": data_config.RecapDataConfig(...)},
optimizer=optimizer_config.RecapOptimizerConfig(...),
)
```
"""
training: TrainingConfig = pydantic.Field(TrainingConfig())
model: model_config.ModelConfig
train_data: data_config.RecapDataConfig

View File

@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
class EmbeddingOptimizerConfig(base_config.BaseConfig):
"""Configuration for the optimizer used for embedding tables."""
learning_rate: optimizer_config.LearningRate = pydantic.Field(
None, description="learning rate scheduler for the EBC"
)
@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):
class StratifierConfig(base_config.BaseConfig):
"""Configuration for Stratifier."""
name: str
index: int
value: int
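The three fields above are enough to sketch how a stratifier might be declared. The feature name and index below are hypothetical, chosen only to illustrate the shape of the config:
```python
# Hypothetical stratifier: emit metrics separately for one value of a discrete feature.
# The name and index are illustrative and not taken from this commit.
stratifier = StratifierConfig(
    name="author_is_verified",  # label used when emitting stratified metrics
    index=3,                    # position of the discrete feature to stratify on
    value=1,                    # feature value that selects this stratum
)
```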
@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
serving time due to size (>>1 GB).
This small embeddings table uses the same optimizer as the rest of the model.
"""
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
..., description="list of embedding tables"

View File

@ -34,6 +34,33 @@ FLAGS = flags.FLAGS
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
"""
Main function to run the training of a ranking model.
This function initializes and runs the training process for a ranking model based on the provided configuration.
Args:
unused_argv (str): Unused argument.
data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).
Returns:
None
Raises:
AssertionError: If the configuration or input data is not valid.
Note:
This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
training loop.
Example:
To run the training process, use the following command:
```
python run_training.py --config_path=config.yaml
```
"""
print("#" * 100)
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)

View File

@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DenseLayerConfig(base_config.BaseConfig):
"""Configuration for the dense layer."""
layer_size: pydantic.PositiveInt
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DoubleNormLogConfig(base_config.BaseConfig):
"""Configuration for the double norm log transform."""
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):
class ClipLog1pAbsConfig(base_config.BaseConfig):
"""Configuration for the clip log transform."""
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
3e38, description="Threshold to clip the input values."
)
class ZScoreLogConfig(base_config.BaseConfig):
"""Configuration for the z-score log transform."""
analysis_path: str
schema_path: str = pydantic.Field(
None,
@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):
class MaskBlockConfig(base_config.BaseConfig):
"""Config for MaskNet block."""
output_size: int
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
None, one_of="aggregation_size"
@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):
class MaskNetConfig(base_config.BaseConfig):
"""Config for MaskNet model."""
mask_blocks: List[MaskBlockConfig]
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):
class DLRMConfig(base_config.BaseConfig):
"""Config for DLRM model."""
bottom_mlp: MlpConfig = pydantic.Field(
...,
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):
class TaskModel(base_config.BaseConfig):
"""Configuration for a single task."""
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
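Because the three fields share the one_of="architecture" group, each task selects exactly one tower type. A minimal sketch, assuming layer_sizes is the only required MlpConfig field (as the Mlp examples later in this commit suggest):
```python
# A task backed by an MLP tower; dcn_config and dlrm_config stay unset because the
# "architecture" one_of group permits exactly one of the three.
task = TaskModel(
    mlp_config=MlpConfig(layer_sizes=[256, 128, 1]),
)
```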
@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):
class MultiTaskType(str, enum.Enum):
"""Type of multi task architecture."""
SHARE_NONE = "share_none" # Tasks are separate.
SHARE_ALL = "share_all" # Tasks share same backbone.
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):
@pydantic.root_validator()
def _validate_mtl(cls, values):
"""Validate the multi task architecture."""
if values.get("multi_task_type", None) is None:
return values
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:

View File

@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
"""
Build a model for a single task based on the provided configuration.
Args:
task (model_config_mod.TaskModel): The task model configuration.
input_shape (int): The input shape for the model.
Returns:
torch.nn.Module: The constructed model for the single task.
Raises:
ValueError: If the task configuration is not recognized.
"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:
@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)
class MultiTaskRankingModel(torch.nn.Module):
"""
Multi-task ranking model that handles multiple ranking tasks simultaneously.
This model takes various input features and predicts rankings for multiple
tasks using shared or separate towers.
"""
def __init__(
self,
@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module):
data_config: RecapDataConfig,
return_backbone: bool = False,
):
"""
Constructor for Multi-task ranking model.
Args:
input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
config (ModelConfig): The model configuration.
data_config (RecapDataConfig): The data configuration.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Assumptions made:
1. Tasks specified in data config match model architecture.
These are all validated in config.
"""
super().__init__()
@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
labels: Optional[torch.Tensor] = None,
weights: Optional[torch.Tensor] = None,
):
"""
Forward pass of the Multi-task ranking model.
Args:
continuous_features (torch.Tensor): Continuous input features.
binary_features (torch.Tensor): Binary input features.
discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
sparse_features (optional): Sparse input features. Defaults to None.
user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
"""
concat_dense_features = [
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
]
@ -270,6 +310,20 @@ def create_ranking_model(
data_config=None,
return_backbone=False,
):
"""
Creates a ranking model based on the provided specifications and configuration.
Args:
data_spec: The input data specifications.
config (config_mod.RecapConfig): The model configuration.
device (torch.device): The device where the model should be placed.
loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
data_config: The data configuration. Defaults to None.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Returns:
torch.nn.Module: The created ranking model.
"""
if list(config.model.tasks.values())[0].dlrm_config:
raise NotImplementedError()

View File

@ -11,21 +11,52 @@ import torch
def log_transform(x: torch.Tensor) -> torch.Tensor:
"""
Safe log transform that works across negative, zero, and positive floats.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log1p applied to absolute values, with the original sign preserved.
"""
return torch.sign(x) * torch.log1p(torch.abs(x))
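A quick numerical check of the transform defined above:
```python
import torch

x = torch.tensor([-9.0, 0.0, 9.0])
y = torch.sign(x) * torch.log1p(torch.abs(x))  # same expression as log_transform(x)
print(y)  # tensor([-2.3026, 0.0000, 2.3026]): zero maps to zero, sign is preserved
```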
class BatchNorm(torch.nn.Module):
def __init__(self, num_features: int, config: BatchNormConfig):
"""
Batch normalization layer.
Args:
num_features (int): Number of input features.
config (BatchNormConfig): Configuration for batch normalization.
"""
super().__init__()
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the batch normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after batch normalization.
"""
return self.layer(x)
class LayerNorm(torch.nn.Module):
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
"""
Layer normalization layer.
Args:
normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
config (LayerNormConfig): Configuration for layer normalization.
"""
super().__init__()
if config.axis != -1:
raise NotImplementedError
@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the layer normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after layer normalization.
"""
return self.layer(x)
@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that applies a log transformation to the input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log1p applied to absolute values, with the original sign preserved.
"""
return log_transform(x)
class InputNonFinite(torch.nn.Module):
def __init__(self, fill_value: float = 0):
"""
Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.
Args:
fill_value (float): The value to fill non-finite elements with. Default is 0.
"""
super().__init__()
self.register_buffer(
@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that replaces non-finite values in the input tensor with the specified fill value.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with non-finite values replaced.
"""
return torch.where(torch.isfinite(x), x, self.fill_value)
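A short usage sketch of the module as defined above:
```python
import torch

layer = InputNonFinite(fill_value=0.0)
x = torch.tensor([1.0, float("nan"), float("inf"), -2.0])
print(layer(x))  # tensor([ 1., 0., 0., -2.]): NaN and Inf are replaced by the fill value
```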
class Clamp(torch.nn.Module):
def __init__(self, min_value: float, max_value: float):
"""
Applies element-wise clamping to a tensor, ensuring that values are within a specified range.
Args:
min_value (float): The minimum value to clamp elements to.
max_value (float): The maximum value to clamp elements to.
"""
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that clamps the input tensor element-wise within the specified range.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with elements clamped within the specified range.
"""
return torch.clamp(x, min=self.min_value, max=self.max_value)
class DoubleNormLog(torch.nn.Module):
"""
Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.
Args:
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.
Attributes:
_before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
batch normalization (optional), and clamping.
layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
"""
def __init__(
self,
input_shapes: Mapping[str, Sequence[int]],
@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
def forward(
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
) -> torch.Tensor:
"""
Forward pass that processes continuous and binary features using batch normalization, log transformation,
optional batch normalization (if configured), clamping, and layer normalization (if configured).
Args:
continuous_features (torch.Tensor): Input tensor of continuous features.
binary_features (torch.Tensor): Input tensor of binary features.
Returns:
torch.Tensor: Transformed tensor containing both continuous and binary features.
"""
x = self._before_concat_layers(continuous_features)
x = torch.cat([x, binary_features], dim=1)
if self.layer_norm:
@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
def build_features_preprocessor(
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
):
"""
Build a feature preprocessor module based on the provided configuration.
Trivial right now, but we will change in the future.
Args:
config (FeaturizationConfig): Configuration for feature preprocessing.
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
Returns:
DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
"""
return DoubleNormLog(input_shapes, config.double_norm_log_config)
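A configuration-side sketch of what this function consumes; whether FeaturizationConfig declares additional required fields is not visible in this hunk, so treat the construction as illustrative:
```python
# Build the featurization config that build_features_preprocessor expects.
# clip_magnitude mirrors the default shown in DoubleNormLogConfig above.
featurization_config = FeaturizationConfig(
    double_norm_log_config=DoubleNormLogConfig(clip_magnitude=5.0),
)
# build_features_preprocessor(featurization_config, input_shapes) then returns a
# DoubleNormLog, where input_shapes maps feature-group names to their shapes.
```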

View File

@ -6,15 +6,84 @@ import torch
def _init_weights(module):
"""Initializes weights
Example
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module):
"""
MaskBlock module in a mask-based neural network.
This module represents a MaskBlock, which applies a masking operation to the input data and then
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Example:
To create and use a MaskBlock within a MaskNet, follow these steps:
```python
# Define the configuration for the MaskBlock
mask_block_config = MaskBlockConfig(
input_layer_norm=True, # Apply input layer normalization
reduction_factor=0.5 # Reduce input dimensionality by 50%
)
# Create an instance of the MaskBlock
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
mask_input = torch.randn(batch_size, 32)
# Perform a forward pass through the MaskBlock
output = mask_block(input_data, mask_input)
```
Note:
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
operation that combines the input and mask input. Then, it passes the result through a hidden layer
with optional dimensionality reduction.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None:
"""
Initializes the MaskBlock module.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Returns:
None
"""
super(MaskBlock, self).__init__()
self.mask_block_config = mask_block_config
output_size = mask_block_config.output_size
@ -42,6 +111,16 @@ class MaskBlock(torch.nn.Module):
self._layer_norm = torch.nn.LayerNorm(output_size)
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
"""
Performs a forward pass through the MaskBlock.
Args:
net (torch.Tensor): Input data tensor.
mask_input (torch.Tensor): Mask input tensor.
Returns:
torch.Tensor: Output tensor of the MaskBlock.
"""
if self._input_layer_norm:
net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
@ -49,7 +128,60 @@ class MaskBlock(torch.nn.Module):
class MaskNet(torch.nn.Module):
"""
MaskNet module in a mask-based neural network.
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
create mask-based neural networks with parallel or stacked MaskBlocks.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Example:
To create and use a MaskNet, you can follow these steps:
```python
# Define the configuration for the MaskNet
mask_net_config = MaskNetConfig(
use_parallel=True, # Use parallel MaskBlocks
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
)
# Create an instance of the MaskNet
mask_net = MaskNet(mask_net_config, in_features=64)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
# Perform a forward pass through the MaskNet
outputs = mask_net(input_data)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
"""
Initializes the MaskNet module.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Returns:
None
"""
super().__init__()
self.mask_net_config = mask_net_config
mask_blocks = []
@ -77,6 +209,15 @@ class MaskNet(torch.nn.Module):
self.shared_size = total_output_mask_blocks
def forward(self, inputs: torch.Tensor):
"""
Performs a forward pass through the MaskNet.
Args:
inputs (torch.Tensor): Input data tensor.
Returns:
Dict[str, torch.Tensor]: A dictionary of MaskNet outputs, including the "output" and "shared_layer" entries shown in the example above.
"""
if self.mask_net_config.use_parallel:
mask_outputs = []
for mask_layer in self._mask_blocks:

View File

@ -7,13 +7,83 @@ from absl import logging
def _init_weights(module):
"""Initializes weights
Example
-------
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class Mlp(torch.nn.Module):
"""
Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.
This module defines an MLP with customizable layers and activation functions. It is suitable for various
applications such as deep learning for tabular data, feature extraction, and more.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Example:
To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:
```python
# Define the configuration for the MLP
mlp_config = MlpConfig(
layer_sizes=[128, 64], # Specify the sizes of hidden layers
batch_norm=True, # Enable batch normalization
dropout=0.2, # Apply dropout with a rate of 0.2
final_layer_activation=True # Apply ReLU activation to the final layer
)
# Create an instance of the MLP module
mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)
# Generate an input tensor
input_tensor = torch.randn(batch_size, input_dim)
# Perform a forward pass through the MLP
outputs = mlp_model(input_tensor)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
enabling batch normalization and dropout, and choosing the activation function for the final layer.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, in_features: int, mlp_config: MlpConfig):
"""
Initializes the Mlp module.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Returns:
None
"""
super().__init__()
self._mlp_config = mlp_config
input_size = in_features
@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
self.layers.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs a forward pass through the MLP.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
Dict[str, torch.Tensor]: A dictionary of MLP outputs, including the "output" and "shared_layer" entries shown in the example above.
"""
net = x
for i, layer in enumerate(self.layers):
net = layer(net)
@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):
@property
def shared_size(self):
"""
Returns the size of the shared layer in the MLP.
Returns:
int: Size of the shared layer.
"""
return self._mlp_config.layer_sizes[-1]
@property
def out_features(self):
"""
Returns the number of output features from the MLP.
Returns:
int: Number of output features.
"""
return self._mlp_config.layer_sizes[-1]

View File

@ -5,6 +5,53 @@ from absl import logging
class ModelAndLoss(torch.nn.Module):
"""
PyTorch module that combines a neural network model and loss function.
This module wraps a neural network model and facilitates the forward pass through the model
while also calculating the loss based on the model's predictions and provided labels.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification. Each stratifier config includes the name and index of discrete features
to emit for stratification.
Example:
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
and loss function as arguments:
```python
# Create a neural network model
model = YourNeuralNetworkModel()
# Define a loss function
loss_fn = torch.nn.CrossEntropyLoss()
# Create an instance of ModelAndLoss
model_and_loss = ModelAndLoss(model, loss_fn)
# Generate a batch of training data (e.g., RecapBatch)
batch = generate_training_batch()
# Perform a forward pass through the model and calculate the loss
loss, outputs = model_and_loss(batch)
# You can now backpropagate and optimize using the computed loss
loss.backward()
optimizer.step()
```
Note:
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
calculating loss, making it easier to integrate the model into your training loop. Additionally,
it supports the addition of stratifiers for metrics stratification, if needed.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
model,
@ -12,10 +59,13 @@ class ModelAndLoss(torch.nn.Module):
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
) -> None:
"""
Initializes the ModelAndLoss module.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification.
"""
super().__init__()
self.model = model

View File

@ -2,11 +2,57 @@ import torch
class NumericCalibration(torch.nn.Module):
"""
Numeric calibration module for adjusting probability scores.
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
for tasks such as binary classification.
Args:
pos_downsampling_rate (float): The downsampling rate for positive samples.
neg_downsampling_rate (float): The downsampling rate for negative samples.
Example:
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
scores like this:
```python
# Create a NumericCalibration instance with downsampling rates
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
# Generate probability scores (e.g., from a neural network)
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
# Apply numeric calibration to adjust the probabilities
calibrated_probs = calibration(raw_probs)
# The `calibrated_probs` now contains the adjusted probability scores
```
Note:
The `NumericCalibration` module is used to adjust probability scores to account for differences in
the number of positive and negative samples in a dataset. It can help improve the calibration of
probability estimates in imbalanced classification problems.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
pos_downsampling_rate: float,
neg_downsampling_rate: float,
):
"""
Initializes the NumericCalibration module.
Args:
pos_downsampling_rate (float): The rate at which positive examples were downsampled in the training data.
neg_downsampling_rate (float): The rate at which negative examples were downsampled in the training data.
"""
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
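The rest of the class body is truncated in this hunk. For reference, the standard correction for class-dependent downsampling re-weights the predicted odds by the two sampling rates. A minimal sketch under the assumption that each rate is the fraction of examples kept for that class (if the rates are defined as one-in-N factors instead, the two terms swap); this is not necessarily the repository's exact forward pass:
```python
import torch

def undo_downsampling(probs: torch.Tensor, pos_rate: float, neg_rate: float) -> torch.Tensor:
    # probs were learned on data where positives were kept with probability pos_rate and
    # negatives with probability neg_rate; rescale the odds back to the original distribution.
    return probs * neg_rate / (probs * neg_rate + (1.0 - probs) * pos_rate)

raw = torch.tensor([0.8, 0.6, 0.2, 0.9])
print(undo_downsampling(raw, pos_rate=0.1, neg_rate=0.2))  # positives were rarer in training, so scores move up
```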

View File

@ -9,12 +9,60 @@ import pydantic
class RecapAdamConfig(base_config.BaseConfig):
"""
Configuration settings for the Adam optimizer used in Recap.
Args:
beta_1 (float): Momentum term (default: 0.9).
beta_2 (float): Exponential weighted decay factor (default: 0.999).
epsilon (float): Numerical stability in the denominator (default: 1e-7).
Example:
To define an Adam optimizer configuration for Recap, use:
```python
adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
```
Note:
This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator.
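These three fields map directly onto torch.optim.Adam's betas and eps arguments. A sketch of that wiring; the helper below is illustrative and not a function from this commit:
```python
import torch

def adam_from_config(params, cfg: RecapAdamConfig, lr: float = 1e-3) -> torch.optim.Adam:
    # betas packs (beta_1, beta_2); eps is the numerical-stability term in the denominator.
    return torch.optim.Adam(params, lr=lr, betas=(cfg.beta_1, cfg.beta_2), eps=cfg.epsilon)

model = torch.nn.Linear(8, 1)
optimizer = adam_from_config(model.parameters(), RecapAdamConfig())
```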
class MultiTaskLearningRates(base_config.BaseConfig):
"""
Configuration settings for multiple learning rates in Recap.
Args:
tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).
Example:
To define multiple learning rates for different towers in Recap, use:
```python
multi_task_lr = MultiTaskLearningRates(
tower_learning_rates={
'task1': learning_rate1,
'task2': learning_rate2,
},
backbone_learning_rate=backbone_lr,
)
```
Note:
This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model."
)
@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):
class RecapOptimizerConfig(base_config.BaseConfig):
"""
Configuration settings for the Recap optimizer.
Args:
multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
adam (RecapAdamConfig): Configuration settings for the Adam optimizer.
Example:
To define an optimizer configuration for training with Recap, use:
```python
optimizer_config = RecapOptimizerConfig(
multi_task_learning_rates=multi_task_lr,
single_task_learning_rate=single_task_lr,
adam=adam_config,
)
```
Warning:
This class is intended for internal use to configure the optimizer settings within Recap and should not be
directly accessed by external code.
"""
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr"
)

View File

@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""
A shim to get learning rates into a LRScheduler.
This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
learning rate schedules, such as exponential decay.
Args:
optimizer: The optimizer to which this scheduler is applied.
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
emb_learning_rate: The learning rate for embeddings (optional).
last_epoch (int): The index of the last epoch (default: -1).
verbose (bool): If True, prints a message to stdout for each learning-rate update (default: False).
Example:
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
```python
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
```
Warning:
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
be directly accessed by external code.
"""
def __init__(
self,
@ -80,15 +98,25 @@ def build_optimizer(
optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
):
"""
Build an optimizer and scheduler for training.
Args:
model: The torch model, possibly with DDP/DMP.
optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
emb_optimizer_config: Configuration settings for embedding optimization (optional).
Returns:
torch.optim.Optimizer: The optimizer for training.
RecapLRShim: The learning rate scheduler for the optimizer.
Example:
To build an optimizer and scheduler for training, use:
```python
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
```
"""
optimizer_fn = functools.partial(
torch.optim.Adam,
lr=_DEFAULT_LR,
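The hunk above is truncated, but the idea behind RecapLRShim, a scheduler that simply emits configured per-group learning rates, can be shown with a generic sketch (this is not the repository's implementation):
```python
import torch

class ConstantPerGroupLR(torch.optim.lr_scheduler._LRScheduler):
    """Returns a fixed, configured learning rate for each optimizer param group."""

    def __init__(self, optimizer, group_lrs, last_epoch=-1):
        self._group_lrs = list(group_lrs)  # one learning rate per param group, e.g. per tower
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Ignore the step count and emit the configured rate for every group.
        return list(self._group_lrs)

model = torch.nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=1.0)  # placeholder lr, overridden by the shim
sched = ConstantPerGroupLR(opt, group_lrs=[1e-3])
opt.step()
sched.step()
print(sched.get_last_lr())  # [0.001]
```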

View File

@ -23,6 +23,25 @@ RELATIONS = [
def test_gen():
"""Test function for generating edge-based datasets and dataloaders.
This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
instance and a dataloader for it.
The test includes the following steps:
1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
2. Write the synthetic data to a Parquet file.
3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
4. Initialize the local file system for the dataset.
5. Create a dataloader for the dataset and retrieve the first batch.
6. Assert that the labels in the batch are positive.
7. Verify that the positive examples in the batch match the expected values.
This function serves as a test case for the data generation and dataset creation process.
Raises:
AssertionError: If any of the test assertions fail.
"""
import os
import tempfile

View File

@ -105,6 +105,23 @@ def test_twhin_model():
def test_unequal_dims():
"""
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.
This function tests whether the validation logic correctly raises a `ValidationError` when
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.
The test includes the following steps:
1. Create two embedding configurations with different embedding dimensions.
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.
This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
in the `TwhinEmbeddingsConfig` for all tables.
Raises:
AssertionError: If the expected `ValidationError` is not raised.
"""
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig(