mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-25 05:11:10 +01:00
update
This commit is contained in:
parent
db4ff958f6
commit
0813989fd9
@ -9,6 +9,36 @@ import pydantic
|
||||
|
||||
|
||||
class TrainingConfig(config_mod.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the training process.
|
||||
|
||||
This class defines various training-related settings, including the directory to save checkpoints, the number
|
||||
of training steps, logging intervals, and other training parameters.
|
||||
|
||||
Attributes:
|
||||
save_dir (str): The directory where checkpoints and training artifacts will be saved.
|
||||
num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
|
||||
initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
|
||||
checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
|
||||
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
|
||||
train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
|
||||
num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
|
||||
eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
|
||||
eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
|
||||
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
|
||||
|
||||
Example:
|
||||
To configure training with checkpoints saved every 1000 steps, use the following settings:
|
||||
|
||||
```python
|
||||
TrainingConfig(
|
||||
save_dir="/tmp/model",
|
||||
num_train_steps=1000000,
|
||||
checkpoint_every_n=1000,
|
||||
train_log_every_n=1000,
|
||||
)
|
||||
```
|
||||
"""
|
||||
save_dir: str = "/tmp/model"
|
||||
num_train_steps: pydantic.PositiveInt = 1000000
|
||||
initial_checkpoint_dir: str = pydantic.Field(
|
||||
@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):
|
||||
|
||||
|
||||
class RecapConfig(config_mod.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the Recap model training process.
|
||||
|
||||
This class defines the overall configuration for the training process of a Recap model. It includes settings for
|
||||
training, model architecture, data, optimization, and evaluation.
|
||||
|
||||
Attributes:
|
||||
training (TrainingConfig): Configuration settings for the training process.
|
||||
model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
|
||||
train_data (data_config.RecapDataConfig): Configuration settings for training data.
|
||||
validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
|
||||
optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
|
||||
which_metrics (Optional[str]): Optional specification of which metrics to pick.
|
||||
|
||||
Note:
|
||||
This class encapsulates all the necessary configurations to train a Recap model. It defines settings for
|
||||
training, the model architecture, data loading, optimization, and evaluation.
|
||||
|
||||
Example:
|
||||
To configure a Recap model training process, use the following settings:
|
||||
|
||||
```python
|
||||
RecapConfig(
|
||||
training=TrainingConfig(
|
||||
save_dir="/tmp/model",
|
||||
num_train_steps=1000000,
|
||||
checkpoint_every_n=1000,
|
||||
train_log_every_n=1000,
|
||||
),
|
||||
model=model_config.ModelConfig(...),
|
||||
train_data=data_config.RecapDataConfig(...),
|
||||
validation_data={"dev": data_config.RecapDataConfig(...)},
|
||||
optimizer=optimizer_config.RecapOptimizerConfig(...),
|
||||
)
|
||||
```
|
||||
"""
|
||||
training: TrainingConfig = pydantic.Field(TrainingConfig())
|
||||
model: model_config.ModelConfig
|
||||
train_data: data_config.RecapDataConfig
|
||||
|
@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class EmbeddingOptimizerConfig(base_config.BaseConfig):
|
||||
"""Configuration for the optimizer used for embedding tables."""
|
||||
learning_rate: optimizer_config.LearningRate = pydantic.Field(
|
||||
None, description="learning rate scheduler for the EBC"
|
||||
)
|
||||
@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class StratifierConfig(base_config.BaseConfig):
|
||||
"""Configuration for Stratifier."""
|
||||
name: str
|
||||
index: int
|
||||
value: int
|
||||
@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
|
||||
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
|
||||
serving time due to size (>>1 GB).
|
||||
|
||||
This small embeddings table uses the same optimizer as the rest of the model."""
|
||||
This small embeddings table uses the same optimizer as the rest of the model.
|
||||
"""
|
||||
|
||||
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
|
||||
..., description="list of embedding tables"
|
||||
|
@ -34,6 +34,33 @@ FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
|
||||
"""
|
||||
Main function to run the training of a ranking model.
|
||||
|
||||
This function initializes and runs the training process for a ranking model based on the provided configuration.
|
||||
|
||||
Args:
|
||||
unused_argv (str): Unused argument.
|
||||
data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
AssertionError: If the configuration or input data is not valid.
|
||||
|
||||
Note:
|
||||
This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
|
||||
the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
|
||||
training loop.
|
||||
|
||||
Example:
|
||||
To run the training process, use the following command:
|
||||
|
||||
```
|
||||
python run_training.py --config_path=config.yaml
|
||||
```
|
||||
"""
|
||||
print("#" * 100)
|
||||
|
||||
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
|
||||
|
@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class DenseLayerConfig(base_config.BaseConfig):
|
||||
"""Configuration for the dense layer."""
|
||||
layer_size: pydantic.PositiveInt
|
||||
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
|
||||
|
||||
@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class DoubleNormLogConfig(base_config.BaseConfig):
|
||||
"""Configuration for the double norm log transform."""
|
||||
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
|
||||
clip_magnitude: float = pydantic.Field(
|
||||
5.0, description="Threshold to clip the normalized input values."
|
||||
@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class ClipLog1pAbsConfig(base_config.BaseConfig):
|
||||
"""Configuration for the clip log transform."""
|
||||
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
|
||||
3e38, description="Threshold to clip the input values."
|
||||
)
|
||||
|
||||
|
||||
class ZScoreLogConfig(base_config.BaseConfig):
|
||||
"""Configuration for the z-score log transform."""
|
||||
analysis_path: str
|
||||
schema_path: str = pydantic.Field(
|
||||
None,
|
||||
@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class MaskBlockConfig(base_config.BaseConfig):
|
||||
"""Config for MaskNet block."""
|
||||
output_size: int
|
||||
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
|
||||
None, one_of="aggregation_size"
|
||||
@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class MaskNetConfig(base_config.BaseConfig):
|
||||
"""Config for MaskNet model."""
|
||||
mask_blocks: List[MaskBlockConfig]
|
||||
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
|
||||
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
|
||||
@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):
|
||||
|
||||
|
||||
class DLRMConfig(base_config.BaseConfig):
|
||||
"""Config for DLRM model."""
|
||||
bottom_mlp: MlpConfig = pydantic.Field(
|
||||
...,
|
||||
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
|
||||
@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class TaskModel(base_config.BaseConfig):
|
||||
"""Configuration for a single task."""
|
||||
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
|
||||
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
|
||||
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
|
||||
@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):
|
||||
|
||||
|
||||
class MultiTaskType(str, enum.Enum):
|
||||
"""Type of multi task architecture."""
|
||||
SHARE_NONE = "share_none" # Tasks are separate.
|
||||
SHARE_ALL = "share_all" # Tasks share same backbone.
|
||||
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
|
||||
@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):
|
||||
|
||||
@pydantic.root_validator()
|
||||
def _validate_mtl(cls, values):
|
||||
"""Validate the multi task architecture."""
|
||||
if values.get("multi_task_type", None) is None:
|
||||
return values
|
||||
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
|
||||
|
@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):
|
||||
|
||||
|
||||
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
|
||||
""" "Builds a model for a single task"""
|
||||
"""
|
||||
Build a model for a single task based on the provided configuration.
|
||||
|
||||
Args:
|
||||
task (model_config_mod.TaskModel): The task model configuration.
|
||||
input_shape (int): The input shape for the model.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: The constructed model for the single task.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task configuration is not recognized.
|
||||
"""
|
||||
if task.mlp_config:
|
||||
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
|
||||
elif task.dcn_config:
|
||||
@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)
|
||||
|
||||
|
||||
class MultiTaskRankingModel(torch.nn.Module):
|
||||
"""Multi-task ranking model."""
|
||||
"""
|
||||
Multi-task ranking model that handles multiple ranking tasks simultaneously.
|
||||
|
||||
This model takes various input features and predicts rankings for multiple
|
||||
tasks using shared or separate towers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module):
|
||||
data_config: RecapDataConfig,
|
||||
return_backbone: bool = False,
|
||||
):
|
||||
"""Constructor for Multi task learning.
|
||||
"""
|
||||
Constructor for Multi-task ranking model.
|
||||
|
||||
Args:
|
||||
input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
|
||||
config (ModelConfig): The model configuration.
|
||||
data_config (RecapDataConfig): The data configuration.
|
||||
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
|
||||
|
||||
Assumptions made:
|
||||
1. Tasks specified in data config match model architecture.
|
||||
|
||||
These are all validated in config.
|
||||
1. Tasks specified in data config match model architecture.
|
||||
These are all validated in config.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
weights: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Forward pass of the Multi-task ranking model.
|
||||
|
||||
Args:
|
||||
continuous_features (torch.Tensor): Continuous input features.
|
||||
binary_features (torch.Tensor): Binary input features.
|
||||
discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
|
||||
sparse_features ([type], optional): Sparse input features. Defaults to None.
|
||||
user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
|
||||
user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
|
||||
author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
|
||||
labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
|
||||
weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
|
||||
"""
|
||||
concat_dense_features = [
|
||||
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
|
||||
]
|
||||
@ -270,6 +310,20 @@ def create_ranking_model(
|
||||
data_config=None,
|
||||
return_backbone=False,
|
||||
):
|
||||
"""
|
||||
Creates a ranking model based on the provided specifications and configuration.
|
||||
|
||||
Args:
|
||||
data_spec: The input data specifications.
|
||||
config (config_mod.RecapConfig): The model configuration.
|
||||
device (torch.device): The device where the model should be placed.
|
||||
loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
|
||||
data_config: The data configuration. Defaults to None.
|
||||
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: The created ranking model.
|
||||
"""
|
||||
|
||||
if list(config.model.tasks.values())[0].dlrm_config:
|
||||
raise NotImplementedError()
|
||||
|
@ -11,21 +11,52 @@ import torch
|
||||
|
||||
|
||||
def log_transform(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Safe log transform that works across both negative, zero, and positive floats."""
|
||||
"""
|
||||
Safe log transform that works across both negative, zero, and positive floats.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with log1p applied to absolute values.
|
||||
"""
|
||||
return torch.sign(x) * torch.log1p(torch.abs(x))
|
||||
|
||||
|
||||
class BatchNorm(torch.nn.Module):
|
||||
def __init__(self, num_features: int, config: BatchNormConfig):
|
||||
"""
|
||||
Batch normalization layer.
|
||||
|
||||
Args:
|
||||
num_features (int): Number of input features.
|
||||
config (BatchNormConfig): Configuration for batch normalization.
|
||||
"""
|
||||
super().__init__()
|
||||
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the batch normalization layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after batch normalization.
|
||||
"""
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
|
||||
"""
|
||||
Layer normalization layer.
|
||||
|
||||
Args:
|
||||
normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
|
||||
config (LayerNormConfig): Configuration for layer normalization.
|
||||
"""
|
||||
super().__init__()
|
||||
if config.axis != -1:
|
||||
raise NotImplementedError
|
||||
@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the layer normalization layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after layer normalization.
|
||||
"""
|
||||
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that applies a log transformation to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with log applied to absolute values.
|
||||
"""
|
||||
|
||||
return log_transform(x)
|
||||
|
||||
|
||||
class InputNonFinite(torch.nn.Module):
|
||||
def __init__(self, fill_value: float = 0):
|
||||
"""
|
||||
Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.
|
||||
|
||||
Args:
|
||||
fill_value (float): The value to fill non-finite elements with. Default is 0.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer(
|
||||
@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that replaces non-finite values in the input tensor with the specified fill value.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with non-finite values replaced.
|
||||
"""
|
||||
return torch.where(torch.isfinite(x), x, self.fill_value)
|
||||
|
||||
|
||||
class Clamp(torch.nn.Module):
|
||||
def __init__(self, min_value: float, max_value: float):
|
||||
"""
|
||||
Applies element-wise clamping to a tensor, ensuring that values are within a specified range.
|
||||
|
||||
Args:
|
||||
min_value (float): The minimum value to clamp elements to.
|
||||
max_value (float): The maximum value to clamp elements to.
|
||||
"""
|
||||
super().__init__()
|
||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||
# Will also be part of state_dict.
|
||||
@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that clamps the input tensor element-wise within the specified range.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with elements clamped within the specified range.
|
||||
"""
|
||||
return torch.clamp(x, min=self.min_value, max=self.max_value)
|
||||
|
||||
|
||||
class DoubleNormLog(torch.nn.Module):
|
||||
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
|
||||
"""
|
||||
Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.
|
||||
|
||||
Args:
|
||||
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
|
||||
config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.
|
||||
|
||||
Attributes:
|
||||
_before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
|
||||
batch normalization (optional), and clamping.
|
||||
layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_shapes: Mapping[str, Sequence[int]],
|
||||
@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
|
||||
def forward(
|
||||
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that processes continuous and binary features using batch normalization, log transformation,
|
||||
optional batch normalization (if configured), clamping, and layer normalization (if configured).
|
||||
|
||||
Args:
|
||||
continuous_features (torch.Tensor): Input tensor of continuous features.
|
||||
binary_features (torch.Tensor): Input tensor of binary features.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor containing both continuous and binary features.
|
||||
"""
|
||||
x = self._before_concat_layers(continuous_features)
|
||||
x = torch.cat([x, binary_features], dim=1)
|
||||
if self.layer_norm:
|
||||
@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
|
||||
def build_features_preprocessor(
|
||||
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
|
||||
):
|
||||
"""Trivial right now, but we will change in the future."""
|
||||
"""
|
||||
Build a feature preprocessor module based on the provided configuration.
|
||||
Trivial right now, but we will change in the future.
|
||||
|
||||
Args:
|
||||
config (FeaturizationConfig): Configuration for feature preprocessing.
|
||||
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
|
||||
|
||||
Returns:
|
||||
DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
|
||||
"""
|
||||
return DoubleNormLog(input_shapes, config.double_norm_log_config)
|
||||
|
@ -6,15 +6,84 @@ import torch
|
||||
|
||||
|
||||
def _init_weights(module):
|
||||
"""Initializes weights
|
||||
|
||||
Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Define a simple linear layer
|
||||
linear_layer = nn.Linear(64, 32)
|
||||
|
||||
# Initialize the weights and biases using _init_weights
|
||||
_init_weights(linear_layer)
|
||||
```
|
||||
|
||||
"""
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
torch.nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class MaskBlock(torch.nn.Module):
|
||||
"""
|
||||
MaskBlock module in a mask-based neural network.
|
||||
|
||||
This module represents a MaskBlock, which applies a masking operation to the input data and then
|
||||
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
|
||||
|
||||
Args:
|
||||
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
||||
input_dim (int): Dimensionality of the input data.
|
||||
mask_input_dim (int): Dimensionality of the mask input.
|
||||
|
||||
Example:
|
||||
To create and use a MaskBlock within a MaskNet, follow these steps:
|
||||
|
||||
```python
|
||||
# Define the configuration for the MaskBlock
|
||||
mask_block_config = MaskBlockConfig(
|
||||
input_layer_norm=True, # Apply input layer normalization
|
||||
reduction_factor=0.5 # Reduce input dimensionality by 50%
|
||||
)
|
||||
|
||||
# Create an instance of the MaskBlock
|
||||
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
|
||||
|
||||
# Generate input tensors
|
||||
input_data = torch.randn(batch_size, 64)
|
||||
mask_input = torch.randn(batch_size, 32)
|
||||
|
||||
# Perform a forward pass through the MaskBlock
|
||||
output = mask_block(input_data, mask_input)
|
||||
```
|
||||
|
||||
Note:
|
||||
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
|
||||
operation that combines the input and mask input. Then, it passes the result through a hidden layer
|
||||
with optional dimensionality reduction.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(
|
||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MaskBlock module.
|
||||
|
||||
Args:
|
||||
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
||||
input_dim (int): Dimensionality of the input data.
|
||||
mask_input_dim (int): Dimensionality of the mask input.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
super(MaskBlock, self).__init__()
|
||||
self.mask_block_config = mask_block_config
|
||||
output_size = mask_block_config.output_size
|
||||
@ -42,6 +111,16 @@ class MaskBlock(torch.nn.Module):
|
||||
self._layer_norm = torch.nn.LayerNorm(output_size)
|
||||
|
||||
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
|
||||
"""
|
||||
Performs a forward pass through the MaskBlock.
|
||||
|
||||
Args:
|
||||
net (torch.Tensor): Input data tensor.
|
||||
mask_input (torch.Tensor): Mask input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of the MaskBlock.
|
||||
"""
|
||||
if self._input_layer_norm:
|
||||
net = self._input_layer_norm(net)
|
||||
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
||||
@ -49,7 +128,60 @@ class MaskBlock(torch.nn.Module):
|
||||
|
||||
|
||||
class MaskNet(torch.nn.Module):
|
||||
"""
|
||||
MaskNet module in a mask-based neural network.
|
||||
|
||||
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
|
||||
create mask-based neural networks with parallel or stacked MaskBlocks.
|
||||
|
||||
Args:
|
||||
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
||||
in_features (int): Dimensionality of the input data.
|
||||
|
||||
Example:
|
||||
To create and use a MaskNet, you can follow these steps:
|
||||
|
||||
```python
|
||||
# Define the configuration for the MaskNet
|
||||
mask_net_config = MaskNetConfig(
|
||||
use_parallel=True, # Use parallel MaskBlocks
|
||||
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
|
||||
)
|
||||
|
||||
# Create an instance of the MaskNet
|
||||
mask_net = MaskNet(mask_net_config, in_features=64)
|
||||
|
||||
# Generate input tensors
|
||||
input_data = torch.randn(batch_size, 64)
|
||||
|
||||
# Perform a forward pass through the MaskNet
|
||||
outputs = mask_net(input_data)
|
||||
|
||||
# Access the output and shared layer
|
||||
output = outputs["output"]
|
||||
shared_layer = outputs["shared_layer"]
|
||||
```
|
||||
|
||||
Note:
|
||||
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
|
||||
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||
"""
|
||||
Initializes the MaskNet module.
|
||||
|
||||
Args:
|
||||
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
||||
in_features (int): Dimensionality of the input data.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.mask_net_config = mask_net_config
|
||||
mask_blocks = []
|
||||
@ -77,6 +209,15 @@ class MaskNet(torch.nn.Module):
|
||||
self.shared_size = total_output_mask_blocks
|
||||
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
"""
|
||||
Performs a forward pass through the MaskNet.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): Input data tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of the MaskNet.
|
||||
"""
|
||||
if self.mask_net_config.use_parallel:
|
||||
mask_outputs = []
|
||||
for mask_layer in self._mask_blocks:
|
||||
|
@ -7,13 +7,83 @@ from absl import logging
|
||||
|
||||
|
||||
def _init_weights(module):
|
||||
"""Initializes weights
|
||||
|
||||
Example
|
||||
-------
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Define a simple linear layer
|
||||
linear_layer = nn.Linear(64, 32)
|
||||
|
||||
# Initialize the weights and biases using _init_weights
|
||||
_init_weights(linear_layer)
|
||||
```
|
||||
|
||||
"""
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
torch.nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class Mlp(torch.nn.Module):
|
||||
"""
|
||||
Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.
|
||||
|
||||
This module defines an MLP with customizable layers and activation functions. It is suitable for various
|
||||
applications such as deep learning for tabular data, feature extraction, and more.
|
||||
|
||||
Args:
|
||||
in_features (int): The number of input features or input dimensions.
|
||||
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
|
||||
|
||||
Example:
|
||||
To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:
|
||||
|
||||
```python
|
||||
# Define the configuration for the MLP
|
||||
mlp_config = MlpConfig(
|
||||
layer_sizes=[128, 64], # Specify the sizes of hidden layers
|
||||
batch_norm=True, # Enable batch normalization
|
||||
dropout=0.2, # Apply dropout with a rate of 0.2
|
||||
final_layer_activation=True # Apply ReLU activation to the final layer
|
||||
)
|
||||
|
||||
# Create an instance of the MLP module
|
||||
mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)
|
||||
|
||||
# Generate an input tensor
|
||||
input_tensor = torch.randn(batch_size, input_dim)
|
||||
|
||||
# Perform a forward pass through the MLP
|
||||
outputs = mlp_model(input_tensor)
|
||||
|
||||
# Access the output and shared layer
|
||||
output = outputs["output"]
|
||||
shared_layer = outputs["shared_layer"]
|
||||
```
|
||||
|
||||
Note:
|
||||
The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
|
||||
enabling batch normalization and dropout, and choosing the activation function for the final layer.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(self, in_features: int, mlp_config: MlpConfig):
|
||||
"""
|
||||
Initializes the Mlp module.
|
||||
|
||||
Args:
|
||||
in_features (int): The number of input features or input dimensions.
|
||||
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
super().__init__()
|
||||
self._mlp_config = mlp_config
|
||||
input_size = in_features
|
||||
@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
|
||||
self.layers.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Performs a forward pass through the MLP.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of the MLP.
|
||||
"""
|
||||
net = x
|
||||
for i, layer in enumerate(self.layers):
|
||||
net = layer(net)
|
||||
@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):
|
||||
|
||||
@property
|
||||
def shared_size(self):
|
||||
"""
|
||||
Returns the size of the shared layer in the MLP.
|
||||
|
||||
Returns:
|
||||
int: Size of the shared layer.
|
||||
"""
|
||||
return self._mlp_config.layer_sizes[-1]
|
||||
|
||||
@property
|
||||
def out_features(self):
|
||||
"""
|
||||
Returns the number of output features from the MLP.
|
||||
|
||||
Returns:
|
||||
int: Number of output features.
|
||||
"""
|
||||
|
||||
return self._mlp_config.layer_sizes[-1]
|
||||
|
@ -5,6 +5,53 @@ from absl import logging
|
||||
|
||||
|
||||
class ModelAndLoss(torch.nn.Module):
|
||||
"""
|
||||
PyTorch module that combines a neural network model and loss function.
|
||||
|
||||
This module wraps a neural network model and facilitates the forward pass through the model
|
||||
while also calculating the loss based on the model's predictions and provided labels.
|
||||
|
||||
Args:
|
||||
model: The torch module to wrap.
|
||||
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
||||
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
||||
for metrics stratification. Each stratifier config includes the name and index of discrete features
|
||||
to emit for stratification.
|
||||
|
||||
Example:
|
||||
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
|
||||
and loss function as arguments:
|
||||
|
||||
```python
|
||||
# Create a neural network model
|
||||
model = YourNeuralNetworkModel()
|
||||
|
||||
# Define a loss function
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# Create an instance of ModelAndLoss
|
||||
model_and_loss = ModelAndLoss(model, loss_fn)
|
||||
|
||||
# Generate a batch of training data (e.g., RecapBatch)
|
||||
batch = generate_training_batch()
|
||||
|
||||
# Perform a forward pass through the model and calculate the loss
|
||||
loss, outputs = model_and_loss(batch)
|
||||
|
||||
# You can now backpropagate and optimize using the computed loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
Note:
|
||||
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
|
||||
calculating loss, making it easier to integrate the model into your training loop. Additionally,
|
||||
it supports the addition of stratifiers for metrics stratification, if needed.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@ -12,10 +59,13 @@ class ModelAndLoss(torch.nn.Module):
|
||||
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
model: torch module to wrap.
|
||||
loss_fn: Function for calculating loss, should accept logits and labels.
|
||||
straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification.
|
||||
Initializes the ModelAndLoss module.
|
||||
|
||||
Args:
|
||||
model: The torch module to wrap.
|
||||
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
||||
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
||||
for metrics stratification.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
@ -2,11 +2,57 @@ import torch
|
||||
|
||||
|
||||
class NumericCalibration(torch.nn.Module):
|
||||
"""
|
||||
Numeric calibration module for adjusting probability scores.
|
||||
|
||||
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
|
||||
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
|
||||
for tasks such as binary classification.
|
||||
|
||||
Args:
|
||||
pos_downsampling_rate (float): The downsampling rate for positive samples.
|
||||
neg_downsampling_rate (float): The downsampling rate for negative samples.
|
||||
|
||||
Example:
|
||||
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
|
||||
scores like this:
|
||||
|
||||
```python
|
||||
# Create a NumericCalibration instance with downsampling rates
|
||||
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
|
||||
|
||||
# Generate probability scores (e.g., from a neural network)
|
||||
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
|
||||
|
||||
# Apply numeric calibration to adjust the probabilities
|
||||
calibrated_probs = calibration(raw_probs)
|
||||
|
||||
# The `calibrated_probs` now contains the adjusted probability scores
|
||||
```
|
||||
|
||||
Note:
|
||||
The `NumericCalibration` module is used to adjust probability scores to account for differences in
|
||||
the number of positive and negative samples in a dataset. It can help improve the calibration of
|
||||
probability estimates in imbalanced classification problems.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
pos_downsampling_rate: float,
|
||||
neg_downsampling_rate: float,
|
||||
):
|
||||
"""
|
||||
Apply numeric calibration to probability scores.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability scores to be calibrated.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Calibrated probability scores.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||
|
@ -9,12 +9,60 @@ import pydantic
|
||||
|
||||
|
||||
class RecapAdamConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the Adam optimizer used in Recap.
|
||||
|
||||
Args:
|
||||
beta_1 (float): Momentum term (default: 0.9).
|
||||
beta_2 (float): Exponential weighted decay factor (default: 0.999).
|
||||
epsilon (float): Numerical stability in the denominator (default: 1e-7).
|
||||
|
||||
Example:
|
||||
To define an Adam optimizer configuration for Recap, use:
|
||||
|
||||
```python
|
||||
adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
|
||||
```
|
||||
|
||||
Note:
|
||||
This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
|
||||
"""
|
||||
|
||||
beta_1: float = 0.9 # Momentum term.
|
||||
beta_2: float = 0.999 # Exponential weighted decay factor.
|
||||
epsilon: float = 1e-7 # Numerical stability in denominator.
|
||||
|
||||
|
||||
class MultiTaskLearningRates(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration settings for multiple learning rates in Recap.
|
||||
|
||||
Args:
|
||||
tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
|
||||
backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).
|
||||
|
||||
Example:
|
||||
To define multiple learning rates for different towers in Recap, use:
|
||||
|
||||
```python
|
||||
multi_task_lr = MultiTaskLearningRates(
|
||||
tower_learning_rates={
|
||||
'task1': learning_rate1,
|
||||
'task2': learning_rate2,
|
||||
},
|
||||
backbone_learning_rate=backbone_lr,
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
|
||||
"""
|
||||
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
|
||||
description="Learning rates for different towers of the model."
|
||||
)
|
||||
@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):
|
||||
|
||||
|
||||
class RecapOptimizerConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the Recap optimizer.
|
||||
|
||||
Args:
|
||||
multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
|
||||
single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
|
||||
adam (RecapAdamConfig): Configuration settings for the Adam optimizer.
|
||||
|
||||
Example:
|
||||
To define an optimizer configuration for training with Recap, use:
|
||||
|
||||
```python
|
||||
optimizer_config = RecapOptimizerConfig(
|
||||
multi_task_learning_rates=multi_task_lr,
|
||||
single_task_learning_rate=single_task_lr,
|
||||
adam=adam_config,
|
||||
)
|
||||
```
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use to configure the optimizer settings within Recap and should not be
|
||||
directly accessed by external code.
|
||||
"""
|
||||
|
||||
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
|
||||
None, description="Multiple learning rates for different tasks.", one_of="lr"
|
||||
)
|
||||
|
@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc"
|
||||
|
||||
|
||||
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
|
||||
"""Shim to get learning rates into a LRScheduler.
|
||||
|
||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
||||
e.g. exponential decay can be used.
|
||||
|
||||
"""
|
||||
A shim to get learning rates into a LRScheduler.
|
||||
|
||||
This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
|
||||
learning rate schedules, such as exponential decay.
|
||||
|
||||
Args:
|
||||
optimizer: The optimizer to which this scheduler is applied.
|
||||
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
|
||||
emb_learning_rate: The learning rate for embeddings (optional).
|
||||
last_epoch (int): The index of the last epoch (default: -1).
|
||||
verbose (bool): If True, print warnings for deprecated functions (default: False).
|
||||
|
||||
Example:
|
||||
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
|
||||
|
||||
```python
|
||||
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
|
||||
```
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
|
||||
be directly accessed by external code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -80,15 +98,25 @@ def build_optimizer(
|
||||
optimizer_config: config.OptimizerConfig,
|
||||
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
|
||||
):
|
||||
"""Builds an optimizer and scheduler.
|
||||
|
||||
Args:
|
||||
model: A torch model, probably with DDP/DMP.
|
||||
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
|
||||
|
||||
Returns:
|
||||
A torch.optim instance, and a scheduler instance.
|
||||
"""
|
||||
Build an optimizer and scheduler for training.
|
||||
|
||||
Args:
|
||||
model: The torch model, possibly with DDP/DMP.
|
||||
optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
|
||||
emb_optimizer_config: Configuration settings for embedding optimization (optional).
|
||||
|
||||
Returns:
|
||||
torch.optim.Optimizer: The optimizer for training.
|
||||
RecapLRShim: The learning rate scheduler for the optimizer.
|
||||
|
||||
Example:
|
||||
To build an optimizer and scheduler for training, use:
|
||||
|
||||
```python
|
||||
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
|
||||
```
|
||||
"""
|
||||
optimizer_fn = functools.partial(
|
||||
torch.optim.Adam,
|
||||
lr=_DEFAULT_LR,
|
||||
|
@ -23,6 +23,25 @@ RELATIONS = [
|
||||
|
||||
|
||||
def test_gen():
|
||||
"""Test function for generating edge-based datasets and dataloaders.
|
||||
|
||||
This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
|
||||
instance and a dataloader for it.
|
||||
|
||||
The test includes the following steps:
|
||||
1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
|
||||
2. Write the synthetic data to a Parquet file.
|
||||
3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
|
||||
4. Initialize the local file system for the dataset.
|
||||
5. Create a dataloader for the dataset and retrieve the first batch.
|
||||
6. Assert that the labels in the batch are positive.
|
||||
7. Verify that the positive examples in the batch match the expected values.
|
||||
|
||||
This function serves as a test case for the data generation and dataset creation process.
|
||||
|
||||
Raises:
|
||||
AssertionError: If any of the test assertions fail.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
@ -105,6 +105,23 @@ def test_twhin_model():
|
||||
|
||||
|
||||
def test_unequal_dims():
|
||||
"""
|
||||
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.
|
||||
|
||||
This function tests whether the validation logic correctly raises a `ValidationError` when
|
||||
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.
|
||||
|
||||
The test includes the following steps:
|
||||
1. Create two embedding configurations with different embedding dimensions.
|
||||
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
|
||||
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.
|
||||
|
||||
This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
|
||||
in the `TwhinEmbeddingsConfig` for all tables.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the expected `ValidationError` is not raised.
|
||||
"""
|
||||
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
||||
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
|
||||
table0 = EmbeddingBagConfig(
|
||||
|
Loading…
x
Reference in New Issue
Block a user