This commit is contained in:
Rockerz 2023-09-14 06:00:51 +00:00 committed by GitHub
commit 63573c6f22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 3260 additions and 357 deletions

View File

@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC):
"""
A base class for batches used in pipelines.
Attributes:
None
"""
@abc.abstractmethod
def as_dict(self) -> Dict:
"""
Convert the batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False):
"""
Move the batch to the specified device.
Args:
device (torch.device): The target device.
non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.
Returns:
BatchBase: A new batch on the target device.
"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
"""
Record a CUDA stream for all tensors in the batch.
Args:
stream (torch.cuda.streams.Stream): The CUDA stream to record.
Returns:
None
"""
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self):
"""
Pin memory for all tensors in the batch.
Returns:
BatchBase: A new batch with pinned memory.
"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
"""
Generate a string representation of the batch.
Returns:
str: A string representation of the batch.
"""
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):
@property
def batch_size(self) -> int:
"""
Get the batch size from the tensors in the batch.
Returns:
int: The batch size.
Raises:
Exception: If the batch size cannot be determined from the tensors.
"""
for tensor in self.as_dict().values():
if tensor is None:
continue
@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):
@dataclass
class DataclassBatch(BatchBase):
"""
A batch class that uses dataclasses to define its fields.
Attributes:
None
"""
@classmethod
def feature_names(cls):
"""
Get the feature names of the dataclass.
Returns:
List[str]: A list of feature names.
"""
return list(cls.__dataclass_fields__.keys())
def as_dict(self):
"""
Convert the dataclass batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):
@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
"""
Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
Args:
name (str): The name of the custom batch class.
schema: The schema or structure of the batch.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):
@staticmethod
def from_fields(name: str, fields: dict):
"""
Create a custom batch subclass from a set of fields.
Args:
name (str): The name of the custom batch class.
fields (dict): A dictionary specifying the fields and their types.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):
class DictionaryBatch(BatchBase, dict):
"""
A batch class that represents data as a dictionary.
Attributes:
None
"""
def as_dict(self) -> Dict:
"""
Convert the dictionary batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return self

View File

@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
class Snapshot:
"""Checkpoints using torchsnapshot.
Also saves step to be updated by the training loop.
"""
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
"""
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
"""
Initializes a Snapshot object.
Args:
save_dir (str): Directory where checkpoints will be saved.
state (Dict[str, Any]): State dictionary containing checkpoint information.
"""
self.save_dir = save_dir
self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property
def step(self):
"""Get the current training step."""
return self.state["extra_state"]["step"]
@step.setter
def step(self, step: int) -> None:
"""Set the current training step."""
self.state["extra_state"]["step"] = step
@property
@ -41,7 +48,15 @@ class Snapshot:
self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step."""
"""
Saves a checkpoint with a given global step.
Args:
global_step (int): The global step to associate with the checkpoint.
Returns:
PendingSnapshot: A pending snapshot object.
"""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time()
@ -58,7 +73,12 @@ class Snapshot:
return snapshot
def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint."""
"""
Restores a given checkpoint.
Args:
checkpoint (str): Path to the checkpoint to restore.
"""
snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time()
@ -83,12 +103,17 @@ class Snapshot:
global_step: Optional[int] = None,
missing_ok: bool = False,
) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it.
Args:
snapshot_path: path to the model snapshot
global_step: restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
Get a torch stateless snapshot, without actually loading it.
Args:
snapshot_path (str): Path to the model snapshot.
global_step (int, optional): Restores from this checkpoint if specified.
missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
Returns:
torchsnapshot.Snapshot: A torch snapshot object.
"""
path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.")
return torchsnapshot.Snapshot(path=path)
@ -100,13 +125,14 @@ class Snapshot:
snapshot_emb_name: str,
weight_tensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilise partial lodaing meachanism from torchsnapshot.
Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
"""
Loads pretrained embedding from the snapshot to the model.
Args:
embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
"""
start_time = time.time()
manifest = embedding_snapshot.get_manifest()
for path in manifest.keys():
@ -209,7 +235,22 @@ def get_checkpoint(
def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written."""
"""
Get a list of fully written checkpoints in the specified directory.
This function retrieves a list of fully written checkpoints in the given directory.
Checkpoints that are considered fully written include those that have a
corresponding snapshot metadata file.
Args:
save_dir (str): The directory where checkpoints are stored.
Returns:
List[str]: A list of fully written checkpoint paths.
Note:
Checkpoints are sorted by their numeric filenames in ascending order.
"""
checkpoints = []
fs = infer_fs(save_dir)
if fs.exists(save_dir):
@ -232,6 +273,18 @@ def wait_for_evaluators(
global_step: int,
timeout: int,
) -> None:
"""
Waits for all evaluators to finish and checks for their completion status.
Args:
save_dir (str): Directory where checkpoints are saved.
partition_names (List[str]): List of partition names to check for completion.
global_step (int): The global step for which to wait for evaluators.
timeout (int): Maximum time in seconds to wait for evaluators to finish.
Returns:
None: This function returns nothing but logs the progress and results.
"""
logging.info("Waiting for all evaluators to finish.")
start_time = time.time()

View File

@ -5,6 +5,15 @@ import torch.distributed as dist
def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available.
This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
Returns:
None
"""
try:
import tensorflow as tf
except ImportError:
@ -14,6 +23,18 @@ def maybe_setup_tensorflow():
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device.
This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
Args:
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
Returns:
torch.device: The torch device for the current process.
"""
if tf_ok:
maybe_setup_tensorflow()

View File

@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()
def infer_fs(path: str):
"""
Infer the file system (fs) type based on the given path.
Args:
path (str): The file path.
Returns:
str: The inferred file system type ("gs://" for Google Cloud Storage, "hdfs://" for Hadoop Distributed File System, or "local" for local file system).
Raises:
NotImplementedError: If the path indicates Hadoop Distributed File System (HDFS) which is not yet supported.
"""
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
@ -18,8 +30,26 @@ def infer_fs(path: str):
def is_local_fs(fs):
"""
Check if the given file system is the local file system.
Args:
fs (str): The file system type to check.
Returns:
bool: True if the file system is the local file system, False otherwise.
"""
return fs == LOCAL_FS
def is_gcs_fs(fs):
"""
Check if the given file system is Google Cloud Storage (GCS).
Args:
fs (str): The file system type to check.
Returns:
bool: True if the file system is GCS, False otherwise.
"""
return fs == GCS_FS

View File

@ -49,15 +49,25 @@ def log_ebc_norms(
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict: model.state_dict()
ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict (dict): The state dictionary of the model.
ebc_keys (List[str]): List of embedding keys from state_dict to log.
Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
sample_size (int, optional): Limits the number of rows per rank to compute average on
to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
The keys are in the format "{emb_key}-norm-{idx}".
Note:
This function computes and logs the average norm of embedding tables across ranks.
It gathers the norms from all ranks and returns them as a dictionary.
"""
norm_logs = dict()
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))

View File

@ -11,6 +11,20 @@ import numpy as np
class LargeEmbeddings(nn.Module):
"""
A module for handling large embeddings.
Args:
large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
Attributes:
ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models.
Note:
The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving.
"""
def __init__(
self,
large_embeddings_config: LargeEmbeddingsConfig,
@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
self,
sparse_features: KeyedJaggedTensor,
) -> KeyedTensor:
"""
Forward pass of the LargeEmbeddings module.
Args:
sparse_features (KeyedJaggedTensor): Sparse input features.
Returns:
KeyedTensor: The output of the forward pass, a KeyedTensor.
"""
pooled_embs = self.ebc(sparse_features)
# a KeyedTensor

View File

@ -11,6 +11,12 @@ import torch.distributed.run
def is_distributed_worker():
"""
Checks if the current process is a distributed worker.
Returns:
bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
"""
world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None
@ -25,25 +31,36 @@ def maybe_run_training(
is_chief: Optional[bool] = False,
**training_kwargs,
):
"""Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables
(WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn: The function that is responsible for training
module_name: The name of the module that this function was called from;
used to indicate torchrun entrypoint.
nproc_per_node: Number of workers per node; supported values.
num_nodes: Number of nodes, otherwise inferred from environment.
is_chief: If process is running on chief.
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
"""
Wrapper function for single node, multi-GPU PyTorch training.
If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn (callable): The function responsible for training.
module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint.
nproc_per_node (int, optional): Number of workers per node. Defaults to None.
num_nodes (int, optional): Number of nodes. Defaults to None.
is_chief (bool, optional): If the process is running on the chief node. Defaults to False.
set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
**training_kwargs: Additional keyword arguments to pass to the `train_fn`.
Note:
This function checks if the current process is a distributed worker by examining the environment variables.
If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary
environment variables and launches the training process using torchrun.
Example:
To run training on a single node with 4 GPUs, you can use:
```
maybe_run_training(train_function, __name__, nproc_per_node=4)
```
"""
machines = utils.machine_from_env()
if num_nodes is None:

View File

@ -19,6 +19,23 @@ MOCK_ENV = {
@contextmanager
def mock_pg():
"""
A context manager for mocking the distributed process group for testing purposes.
This context manager temporarily sets environment variables to mock the distributed process group
and initializes it using the Gloo backend. It is useful for testing distributed training without
actually launching multiple processes.
Example:
```python
with mock_pg():
# Your distributed training code here
```
Note:
This context manager should be used within a testing environment to simulate distributed training
without actually creating multiple processes.
"""
with patch.dict(os.environ, MOCK_ENV):
try:
dist.init_process_group(

View File

@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)
def _read_file(f):
"""
Read the contents of a file using fsspec.
Args:
f: File path or URL.
Returns:
The contents of the file.
"""
with fsspec.open(f) as f:
return f.read()
@ -21,18 +30,24 @@ def setup_configuration(
yaml_path: str,
substitute_env_variable: bool = False,
) -> Tuple[C, str]:
"""Resolves a config at a yaml path.
Args:
config_type: Pydantic config class to load.
yaml_path: yaml path of the config file.
substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
environment variable value whenever possible. If an environment variable doesn't exist,
the string is left unchanged.
Returns:
The pydantic config object.
"""
Load a Pydantic config object from a YAML file and optionally substitute environment variables.
Args:
config_type: Pydantic config class to load.
yaml_path: YAML path of the config file.
substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
with their environment variable values whenever possible.
If an environment variable doesn't exist, the string is left unchanged.
Returns:
A tuple containing the Pydantic config object and the resolved YAML content.
Example:
```python
config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
```
"""
def _substitute(s):
if substitute_env_variable:

View File

@ -6,6 +6,33 @@ import pydantic
class WandbConfig(base_config.BaseConfig):
"""
Configuration for integrating with Weights and Biases (WandB).
Attributes:
host (str): Host of the Weights and Biases instance, passed to login.
key_path (str): Path to the key file.
name (str): Name of the experiment, passed to init.
entity (str): Name of the user/service account, passed to init.
project (str): Name of the WandB project, passed to init.
tags (List[str]): List of tags, passed to init.
notes (str): Notes, passed to init.
metadata (Dict[str, Any]): Additional metadata to log.
Example:
```python
wandb_config = WandbConfig(
host="https://wandb.example.com",
key_path="/path/to/key",
name="experiment_1",
entity="user123",
project="my_project",
tags=["experiment", "ml"],
notes="This is a test experiment.",
metadata={"version": "1.0"}
)
```
"""
host: str = pydantic.Field(
"https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
description="Host of Weights and Biases instance, passed to login.",

View File

@ -8,59 +8,69 @@ import pydantic
class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes.
"""Base class for all derived config classes.
This class provides some convenient functionality:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
This class provides convenient functionality and constraints for derived config classes:
```
- Disallows extra fields when constructing an object. User errors due to extraneous arguments
are minimized.
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
be set. For example:
```python
class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok
ExampleConfig(y=1) # ok
ExampleConfig(x=1, y=1) # throws error
ExampleConfig(x=1) # OK
ExampleConfig(y=1) # OK
ExampleConfig(x=1, y=1) # Raises an error
```
"""
class Config:
"""Forbids extras."""
Attributes:
Config (class): Configuration options for this class, forbidding extra fields.
extra = pydantic.Extra.forbid # noqa
Methods:
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with provided the field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
"""
class Config:
"""Configuration options that forbid extra fields."""
extra = pydantic.Extra.forbid # noqa
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields are appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with the provided field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
def pretty_print(self) -> str:
"""Return a human legible (yaml) representation of the config useful for logging."""
return yaml.dump(self.dict())
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human-readable (YAML) representation of the config useful for logging."""
return yaml.dump(self.dict())

View File

@ -6,33 +6,67 @@ import pydantic
class BaseConfigTest(TestCase):
def test_extra_forbidden(self):
class Config(BaseConfig):
x: int
"""
Unit tests for the BaseConfig class.
"""
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
def test_extra_forbidden(self):
"""
Test that extra fields are forbidden when creating a Config instance.
def test_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
This test case checks whether the `BaseConfig` class correctly raises a
`pydantic.ValidationError` when extra fields are provided when creating a
`Config` instance.
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig):
x: int
def test_at_most_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")
def test_one_of(self):
"""
Test the use of the `one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `one_of` attribute in a `Config`
instance. It verifies that the `pydantic.Field` correctly enforces the
specified constraint.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
def test_at_most_one_of(self):
"""
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `at_most_one_of` attribute in a
`Config` instance. It verifies that the `pydantic.Field` enforces the
constraint where at most one of the specified fields can be provided.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")

View File

@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommend method to load a config file (a yaml file) and parse it.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
"""
Recommend method to Load and parse a configuration from a YAML file.
This function loads a configuration from a YAML file, parses it, and returns an instance of the
specified config type.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
Args:
config_type (Type[BaseConfig]): The Pydantic config class to load.
yaml_path (str): The path to the YAML configuration file.
Returns:
BaseConfig: An instance of the specified config type populated with values from the YAML file.
Example:
Suppose you have a YAML file 'my_config.yaml' containing the following:
```yaml
x: 42
y: "hello"
```
You can load and parse it using this function as follows:
```python
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
```
Note:
This function performs environment variable substitution in the YAML file. It replaces
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
values. If an environment variable does not exist, the string is left unchanged.
"""
def _substitute(s):
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())

View File

@ -8,11 +8,27 @@ import pydantic
class _PointlessConfig(BaseConfig):
a: int
user: str
def test_load_config_from_yaml(tmp_path):
"""Test loading a configuration from a YAML file and verifying its values.
This test function checks the functionality of the `load_config_from_yaml` function by creating
a temporary YAML configuration file, loading it, and asserting that the loaded config object
has the expected values.
Args:
tmp_path: A temporary directory provided by the `pytest` framework.
Test Steps:
1. Create a temporary YAML file containing configuration data.
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
3. Assert that the loaded configuration object has the expected values.
"""
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""")

View File

@ -9,6 +9,14 @@ import pydantic
class RuntimeConfig(base_config.BaseConfig):
"""
Configuration for runtime settings.
Attributes:
- wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
- enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices.
- enable_amp (bool): Enable automatic mixed precision.
"""
wandb: WandbConfig = pydantic.Field(None)
enable_tensorfloat32: bool = pydantic.Field(
False, description="Use tensorfloat32 if on Ampere devices."
@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):
class TrainingConfig(base_config.BaseConfig):
"""
Configuration for training settings.
Attributes:
- save_dir (str): Directory to save checkpoints.
- num_train_steps (pydantic.PositiveInt): Number of training steps.
- initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
- checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps.
- checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
- train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps.
- num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used.
- eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps.
- eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
- gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
- num_epochs (pydantic.PositiveInt): Number of training epochs.
"""
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
num_train_steps: pydantic.PositiveInt = 10000
initial_checkpoint_dir: str = pydantic.Field(

View File

@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
def _get_step_fn(pipeline, data_iterator, training: bool):
"""
Returns a function to perform a single evaluation step.
Args:
pipeline (Pipeline): The pipeline object containing the model.
data_iterator (Iterator): The data iterator for evaluation.
training (bool): Flag indicating if the model should be in training mode.
Returns:
function: A function that performs a single evaluation step.
"""
def step_fn():
"""
Perform a single evaluation step.
Returns:
Any: The evaluation results after a single step.
"""
# It turns out that model.train() and model.eval() simply switch a single field inside the model
# class,so it's somewhat safer to wrap in here.
if training:
@ -69,7 +86,21 @@ def _run_evaluation(
eval_batch_size: int,
logger=None,
):
"""Runs the evaluation loop over all evaluation iterators."""
"""
Run the evaluation loop over all evaluation iterators.
Args:
pipeline (Pipeline): The pipeline object containing the model.
dataset (Dataset): The dataset to evaluate.
eval_steps (int): The number of evaluation steps to perform.
metrics (tm.MetricCollection): A collection of evaluation metrics.
eval_batch_size (int): Batch size for evaluation.
logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
Returns:
dict: A dictionary containing the computed evaluation metrics.
"""
dataset = get_new_iterator(dataset)
step_fn = _get_step_fn(pipeline, dataset, training=False)
last_time = datetime.datetime.now()
@ -109,15 +140,29 @@ def train(
parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None,
) -> None:
"""Runs training and eval on the given TrainPipeline
Args:
dataset: data iterator for the training set
evaluation_iterators: data iterators for the different evaluation sets
scheduler: optional learning rate scheduler
output_transform_for_metrics: optional transformation functions to transorm the model
output and labels into a format the metrics can understand
"""
Runs training and evaluation on the given TrainPipeline.
Args:
model (torch.nn.Module): The neural network model to train.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
device (str): The target device for model training (e.g., 'cuda' or 'cpu').
save_dir (str): The directory to save model checkpoints and logs.
logging_interval (int): Interval for logging training progress.
train_steps (int): The number of training steps to perform.
checkpoint_frequency (int): Frequency of saving model checkpoints.
dataset (Iterable): Data iterator for the training set.
worker_batch_size (int): Batch size for data loading workers.
num_workers (Optional[int]): Number of data loading workers (default: 0).
enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
logger_initializer (Optional[Callable]): A logger initializer function (default: None).
scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
tables_to_log (Optional[List[str]]): List of tables to log (default: None).
"""
train_pipeline = TrainPipelineSparseDist(
model=model,
@ -262,6 +307,15 @@ def log_eval_results(
partition_name: str,
step: int,
):
"""
Logs evaluation results and optionally records them using a provided logger.
Args:
results (Any): The evaluation results to log.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
step (int): The current step in the evaluation.
"""
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
logging.info(f"Step: {step}, evaluation ({partition_name}).")
for metric_name, metric_value in results.items():
@ -285,6 +339,23 @@ def only_evaluate(
partition_name: str,
metrics: Optional[tm.MetricCollection] = None,
):
"""
Performs evaluation on a given dataset partition.
Args:
model (torch.nn.Module): The neural network model for evaluation.
optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
save_dir (str): The directory containing model checkpoints.
num_train_steps (int): The total number of training steps.
dataset (Iterable): Data iterator for evaluation.
eval_batch_size (int): Batch size for evaluation.
num_eval_steps (int): The number of evaluation steps to perform.
eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
"""
logging.info(f"Evaluating on partition {partition_name}.")
logging.info("Computing metrics:")
logging.info(metrics)

View File

@ -28,6 +28,18 @@ def train(
*args,
**kwargs,
) -> None:
"""
Debugging training loop. Do not use for actual model training.
Args:
model (torch.nn.Module): The neural network model.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
train_steps (int): The number of training steps to perform.
dataset (Iterable): Data iterator for training data.
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
*args: Additional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).
"""
logging.warning("Running debug training loop, don't use for model training.")

View File

@ -10,8 +10,11 @@ import torch
def _maybe_warn(reduction: str):
"""
Warning for reduction different than mean.
"""
Emit a warning if the reduction method is different from 'mean'.
Args:
reduction (str): The reduction method being used.
"""
if reduction != "mean":
logging.warn(
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
@ -24,6 +27,16 @@ def build_loss(
loss_type: LossType,
reduction="mean",
):
"""
Build a loss function based on the specified loss type and reduction method.
Args:
loss_type (LossType): The type of loss to build.
reduction (str): The reduction method for the loss (default: 'mean').
Returns:
Callable: A loss function that takes logits and labels as input.
"""
_maybe_warn(reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
@ -35,11 +48,15 @@ def build_loss(
def get_global_loss_detached(local_loss, reduction="mean"):
"""
Perform all_reduce to obtain the global loss function using the provided reduction.
:param local_loss: The local loss of the current rank.
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
:return: The reduced & detached global loss.
"""
Perform all_reduce to obtain the global loss function using the provided reduction.
Args:
local_loss (torch.Tensor): The local loss of the current rank.
reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.
Returns:
torch.Tensor: The reduced and detached global loss.
"""
if reduction != "mean":
logging.warn(
f"The reduction used in this function should be the same as the one used by "
@ -66,6 +83,19 @@ def build_multi_task_loss(
global_reduction="mean",
pos_weights=None,
):
"""
Build a multi-task loss function based on the specified loss type and configurations.
Args:
loss_type (LossType): The type of loss to build.
tasks (typing.List[str]): List of task names.
task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
global_reduction (str): Reduction method for the global loss (default: 'mean').
pos_weights (Optional): Positive class weights for tasks (default: None).
Returns:
Callable: A multi-task loss function that takes logits, labels, and weights as input.
"""
_maybe_warn(global_reduction)
_maybe_warn(task_loss_reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]

View File

@ -36,9 +36,24 @@ import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
"""
Abstract method to transform model outputs into a dictionary of metrics.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
Returns:
Dict: A dictionary of computed metrics.
"""
...
def update(self, outputs: Dict[str, torch.Tensor]):
"""
Update the metrics based on model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
"""
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
@ -49,6 +64,13 @@ class MetricMixin:
class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs):
"""
Initialize a TaskMixin instance.
Args:
task_idx (int): Index of the task associated with this mixin (default: -1).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self._task_idx = task_idx
@ -59,13 +81,31 @@ class StratifyMixin:
stratifier=None,
**kwargs,
):
"""
Initialize a StratifyMixin instance.
Args:
stratifier: A stratifier for filtering outputs (default: None).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self._stratifier = stratifier
def maybe_apply_stratification(
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
"""
Apply stratification to filter examples in the outputs.
Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
value_names (List[str]): Names of values to filter.
Returns:
Dict[str, torch.Tensor]: Filtered outputs.
"""
outputs = outputs.copy()
if not self._stratifier:
return outputs
@ -84,12 +124,20 @@ class StratifyMixin:
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
"""Returns new class using MetricMixin and given base_metric.
Functionally the same using inheritance, just saves some lines of code
if no need for class attributes.
"""
Returns a new class using MetricMixin and the given base_metric.
Functionally the same as using inheritance, but it saves some lines of code
if there's no need for class attributes.
Args:
base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
transform (Callable): The transformation function to prepend to the metric.
Returns:
Type: A new class that includes MetricMixin and the provided base_metric
with the specified transformation method.
"""
def transform_method(_self, *args, **kwargs):
return transform(*args, **kwargs)

View File

@ -15,6 +15,16 @@ def probs_and_labels(
outputs: Dict[str, torch.Tensor],
task_idx: int,
) -> Dict[str, torch.Tensor]:
"""
Extract probabilities and labels from model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
task_idx (int): Index of the task.
Returns:
Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
"""
preds = outputs["probabilities"]
target = outputs["labels"]
if task_idx >= 0:
@ -28,6 +38,11 @@ def probs_and_labels(
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
def transform(self, outputs):
"""
Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.
This metric counts values after potential stratification and task selection.
"""
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value using probabilities after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
value = outputs["probabilities"]
@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
"""
Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.
This metric computes precision after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
"""
Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.
This metric computes recall after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
AUC (Area Under the ROC Curve) metric class.
This metric computes the AUC metric based on the logits and labels in the model outputs.
Args:
num_samples (int): The number of samples used to compute AUC.
**kwargs: Additional keyword arguments.
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
"""
@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The ranks of all positives
Based on:
PosRanks metric class.
This metric computes the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
"""
@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The reciprocal of the ranks of all
Based on:
ReciprocalRank metric class.
This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
"""
@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The fraction of positives that rank in the top K among their negatives
Note that this is basically precision@k
Based on:
HitAtK metric class.
This metric computes the fraction of positive examples that rank in the top K among their negatives,
which is equivalent to precision@K.
Args:
k (int): The value of K.
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
"""

View File

@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric
@dataclass
class MockStratifierConfig:
"""
Configuration dataclass for mocking a stratifier.
Args:
name (str): The name of the stratifier.
index (int): The index of the stratifier.
value (int): The value of the stratifier.
"""
name: str
index: int
value: int
class Count(MetricMixin, SumMetric):
"""
Count metric class that inherits from MetricMixin and SumMetric.
This metric counts occurrences.
"""
def transform(self, outputs):
return {"value": 1}
@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
def test_count_metric():
"""
Test function for the Count metric.
It checks if the Count metric correctly counts the number of examples.
"""
num_examples = 123
examples = [
{"stuff": 0},
@ -36,6 +56,12 @@ def test_count_metric():
def test_collections():
"""
Test function for metric collections.
It tests if metric collections correctly aggregate metrics.
"""
max_metric = Max()
count_metric = Count()
metric = MetricCollection([max_metric, count_metric])
@ -51,6 +77,12 @@ def test_collections():
def test_task_dependent_ctr():
"""
Test function for task-dependent Ctr (Click-Through Rate) metric.
It checks if the Ctr metric computes the correct value for different tasks.
"""
num_examples = 144
batch_size = 1024
outputs = [
@ -69,6 +101,13 @@ def test_task_dependent_ctr():
def test_stratified_ctr():
"""
Test function for the Stratified Ctr (Click-Through Rate) metric.
It checks if the Stratified Ctr metric computes the correct value for different tasks
and stratified samples.
"""
outputs = [
{
"stuff": 0,
@ -114,6 +153,12 @@ def test_stratified_ctr():
def test_auc():
"""
Test function for the AUC (Area Under the Curve) metric.
It checks if the AUC metric correctly computes the Area Under the ROC Curve.
"""
num_samples = 10000
metric = core_metrics.Auc(num_samples)
target = torch.tensor([0, 0, 1, 1, 1])
@ -131,6 +176,12 @@ def test_auc():
def test_pos_rank():
"""
Test function for the PosRanks metric.
It checks if the PosRanks metric correctly computes the ranks of positive samples.
"""
metric = core_metrics.PosRanks()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@ -147,6 +198,12 @@ def test_pos_rank():
def test_reciprocal_rank():
"""
Test function for the Reciprocal Rank metric.
It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
"""
metric = core_metrics.ReciprocalRank()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@ -163,6 +220,12 @@ def test_reciprocal_rank():
def test_hit_k():
"""
Test function for the Hit@K metric.
It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
"""
hit1_metric = core_metrics.HitAtK(1)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])

View File

@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel
@dataclass
class MockDataclassBatch(DataclassBatch):
"""
Mock data class batch for testing purposes.
This class represents a batch of data with continuous features and labels.
Attributes:
continuous_features (torch.Tensor): Tensor containing continuous feature data.
labels (torch.Tensor): Tensor containing label data.
"""
continuous_features: torch.Tensor
labels: torch.Tensor
class MockModule(torch.nn.Module):
"""
Mock PyTorch module for testing purposes.
This module defines a simple neural network model with a linear layer
followed by a BCEWithLogitsLoss loss function.
Attributes:
model (torch.nn.Linear): The linear model layer.
loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
"""
def __init__(self) -> None:
super().__init__()
self.model = torch.nn.Linear(10, 1)
self.loss_fn = torch.nn.BCEWithLogitsLoss()
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the mock module.
Args:
batch (MockDataclassBatch): Input data batch with continuous features and labels.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
"""
pred = self.model(batch.continuous_features)
loss = self.loss_fn(pred, batch.labels)
return (loss, pred)
def create_batch(bsz: int):
"""
Create a mock data batch with random continuous features and labels.
Args:
bsz (int): Batch size.
Returns:
MockDataclassBatch: A batch of data with continuous features and labels.
"""
return MockDataclassBatch(
continuous_features=torch.rand(bsz, 10).float(),
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
@ -35,6 +72,13 @@ def create_batch(bsz: int):
def test_sparse_pipeline():
"""
Test function for the sparse pipeline with distributed model parallelism.
This function tests the behavior of the sparse training pipeline using
a mock module and data.
"""
device = torch.device("cpu")
model = MockModule().to(device)
@ -65,6 +109,15 @@ def test_sparse_pipeline():
def test_amp():
"""
Test automatic mixed-precision (AMP) training with the sparse pipeline.
This function tests the behavior of the sparse training pipeline with
automatic mixed-precision (AMP) enabled, using a mock module and data.
AMP allows for faster training by using lower-precision data types, such as
torch.bfloat16, while maintaining model accuracy.
"""
device = torch.device("cpu")
model = MockModule().to(device)

View File

@ -39,12 +39,42 @@ Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
"""
Abstract base class for training pipelines.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
"""
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Make progress in the training pipeline.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
"""
Move a batch of data to a specified device.
Args:
batch (In): The input batch.
device (torch.device): The target device.
non_blocking (bool): If True, move the data asynchronously.
Returns:
In: The batch of data on the target device.
"""
assert isinstance(
batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface"
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
"""
Wait for a batch of data on a specified stream.
Args:
batch (In): The input batch.
stream (Optional[Stream]): The CUDA stream to wait for.
Note:
This function is used for managing asynchronous CUDA operations.
"""
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
Initialize the TrainPipelineBase.
_connect(dataloader_iter: Iterator[In]) -> None:
Establish a connection to the data loader and move the input data to the GPU.
progress(dataloader_iter: Iterator[In]) -> Out:
Execute a training iteration, including forward and backward passes.
"""
def __init__(
self,
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> None:
"""
Initialize the TrainPipelineBase.
Args:
model (torch.nn.Module): The PyTorch model to be trained.
optimizer (torch.optim.Optimizer): The optimizer used for training.
device (torch.device): The target device for training (CPU or GPU).
"""
self._model = model
self._optimizer = optimizer
self._device = device
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Establish a connection to the data loader and move the input data to the GPU.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
"""
cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream):
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Execute a training iteration, including forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
if not self._connected:
self._connect(dataloader_iter)
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
class Tracer(torch.fx.Tracer):
"""
Custom tracer class for PyTorch models.
This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.
Attributes:
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
_leaf_modules (List[str]): List of qualified names of leaf modules.
"""
# Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
"""
Initialize the Tracer.
Args:
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
"""
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
Check if a module is a leaf module during tracing.
Args:
m (torch.nn.Module): The PyTorch module.
module_qualified_name (str): The qualified name of the module.
Returns:
bool: True if the module is considered a leaf module, False otherwise.
"""
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True
return super().is_leaf_module(m, module_qualified_name)
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
@dataclass
class TrainPipelineContext:
"""
Dataclass to store information related to the training pipeline context.
Attributes:
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
feature_processor_forwards (List[Any]): A list of feature processor forwards.
"""
# pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
@ -166,6 +279,14 @@ class TrainPipelineContext:
@dataclass
class ArgInfo:
"""
Dataclass to store information about arguments in the training pipeline.
Attributes:
input_attrs (List[str]): List of attribute names of the input batch.
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
"""
# attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"]
input_attrs: List[str]
@ -177,6 +298,16 @@ class ArgInfo:
class PipelinedForward:
"""
Represents a pipelined forward pass operation.
Attributes:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
def __init__(
self,
name: str,
@ -185,6 +316,16 @@ class PipelinedForward:
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> None:
"""
Initialize a PipelinedForward instance.
Args:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
self._name = name
self._args = args
self._module = module
@ -193,6 +334,16 @@ class PipelinedForward:
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
"""
Perform the pipelined forward pass operation.
Args:
*input: Variable-length positional arguments.
**kwargs: Variable-length keyword arguments.
Returns:
Awaitable: An awaitable object representing the forward pass result.
"""
assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable)
@ -230,10 +381,22 @@ class PipelinedForward:
@property
def name(self) -> str:
"""
Get the name of the forward pass.
Returns:
str: The name of the forward pass.
"""
return self._name
@property
def args(self) -> List[ArgInfo]:
"""
Get the list of argument information for the forward pass.
Returns:
List[ArgInfo]: List of argument information.
"""
return self._args
@ -242,6 +405,17 @@ def _start_data_dist(
batch: In,
context: TrainPipelineContext,
) -> None:
"""
Start data distribution for a list of pipelined modules.
Args:
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
batch (In): The input batch.
context (TrainPipelineContext): The training pipeline context.
Returns:
None: This function doesn't return a value.
"""
context.input_dist_requests.clear()
context.module_contexts.clear()
for module in pipelined_modules:
@ -286,9 +460,17 @@ def _get_node_args_helper(
feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
Args:
arguments: The arguments to process.
num_found: The current count of arguments found.
feature_processor_arguments: Optional list of feature processor arguments.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
"""
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
@ -332,6 +514,16 @@ def _get_node_args_helper(
def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
"""
Get argument information for a given node.
Args:
node (Node): The node to process.
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
"""
num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
path: str,
unsharded_module_names: Set[str],
) -> bool:
"""
Get the names of unsharded modules in a model.
Args:
model (torch.nn.Module): The model to analyze.
path (str): The current path in the model hierarchy.
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
Returns:
bool: True if any sharded modules were found in the hierarchy, False otherwise.
"""
sharded_children = set()
for name, child in model.named_children():
curr_path = path + name
@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top level modules do not contain any sharded sub modules.
"""
Returns a list of top-level modules that do not contain any sharded sub-modules.
Args:
model (torch.nn.Module): The model to analyze.
Returns:
List[str]: A list of top-level module names without sharded sub-modules.
"""
unsharded_module_names: Set[str] = set()
_get_unsharded_module_names_helper(
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
"""
Rewrites the model to enable pipelined execution for selected sharded modules.
This function traces the input model using a custom tracer and identifies sharded modules
that can be pipelined. It then creates PipelinedForward objects for these modules,
which enable pipelining during training.
Args:
model (torch.nn.Module): The input model to be rewritten.
context (TrainPipelineContext): The context containing information needed for pipelining.
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
Returns:
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
"""
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
"""
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
Attributes:
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
"""
synced_pipeline_id: Dict[int, int] = {}
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None,
) -> None:
"""
Initializes the training pipeline.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
"""
self._model = model
self._optimizer = optimizer
self._device = device
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Connects the training pipeline to data and prepares for forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
"""
# batch 1
with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter)
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
Progresses through the training pipeline, performing forward and backward passes.
"""
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
Returns:
Out: The output of the forward pass.
"""
should_step_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
module.forward._context = self._context

View File

@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
def on_kf():
return "SPEC_TYPE" in os.environ
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
Returns:
bool: True if running on KF, False otherwise.
"""
return "SPEC_TYPE" in os.environ
def has_readers():
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env["dataset_worker"] is not None
return os.environ.get("HAS_READERS", "False") == "True"
"""Check if the current task has dataset workers.
Returns:
bool: True if the task has dataset workers, False otherwise.
"""
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env.get("dataset_worker") is not None
return os.environ.get("HAS_READERS", "False") == "True"
def get_task_type():
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]
"""Get the type of the current task.
Returns:
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
"""
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]
def is_chief() -> bool:
return get_task_type() == "chief"
"""Check if the current task is the 'chief'.
Returns:
bool: True if the current task is the 'chief', False otherwise.
"""
return get_task_type() == "chief"
def is_reader() -> bool:
return get_task_type() == "datasetworker"
"""Check if the current task is a 'datasetworker'.
Returns:
bool: True if the current task is a 'datasetworker', False otherwise.
"""
return get_task_type() == "datasetworker"
def is_dispatcher() -> bool:
return get_task_type() == "datasetdispatcher"
"""Check if the current task is a 'datasetdispatcher'.
Returns:
bool: True if the current task is a 'datasetdispatcher', False otherwise.
"""
return get_task_type() == "datasetdispatcher"
def get_task_index():
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError
"""Get the index of the current task.
Returns:
int: Task index.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError
def get_reader_port():
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT
"""Get the port used by readers.
Returns:
int: Reader port.
"""
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT
def get_dds():
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")
"""Get the Distributed Data Service (DDS) address.
Returns:
str: DDS address in the format 'grpc://host:port'.
Raises:
ValueError: If the job does not have DDS.
"""
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")
def get_dds_dispatcher_address():
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"
"""Get the DDS dispatcher address.
Returns:
str: DDS dispatcher address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"
def get_dds_worker_address():
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"
"""Get the DDS worker address.
Returns:
str: DDS worker address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"
def get_num_readers():
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env["num_dataset_workers"] or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
"""Get the number of dataset workers.
Returns:
int: Number of dataset workers.
"""
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env.get("num_dataset_workers") or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
def get_flight_server_addresses():
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError
"""Get Flight server addresses for dataset workers.
Returns:
List[str]: List of Flight server addresses in the format 'grpc://host:port'.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError
def get_dds_journaling_dir():
return os.environ.get("DATASET_JOURNALING_DIR", None)
"""Get the DDS journaling directory.
Returns:
str: DDS journaling directory.
"""
return os.environ.get("DATASET_JOURNALING_DIR", None)

View File

@ -14,18 +14,21 @@ def update_mean(
weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the mean according to Welford formula:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
Update the mean according to the Welford formula.
This function updates the mean and the weighted sum of values using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args:
current_mean: The value of the current accumulated mean.
current_weight_sum: The current weighted sum.
value: The new value that needs to be added to get a new mean.
weight: The weights for the new value.
Returns: The updated mean and updated weighted sum.
Args:
current_mean (torch.Tensor): The value of the current accumulated mean.
current_weight_sum (torch.Tensor): The current weighted sum.
value (torch.Tensor): The new value that needs to be added to get a new mean.
weight (torch.Tensor): The weights for the new value.
"""
Returns:
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
"""
weight = torch.broadcast_to(weight, value.shape)
# Avoiding (on purpose) in-place operation when using += in case
@ -38,12 +41,15 @@ def update_mean(
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
"""
Merge the state from multiple workers.
Args:
state: A tensor with the first dimension indicating workers.
Returns: The accumulated mean from all workers.
This function merges the state from multiple workers to compute the accumulated mean.
"""
Args:
state (torch.Tensor): A tensor with the first dimension indicating workers.
Returns:
torch.Tensor: The accumulated mean from all workers.
"""
mean, weight_sum = update_mean(
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
class StableMean(torchmetrics.Metric):
"""
This implements a numerical stable mean metrics computation using Welford algorithm according to
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
A numerical stable mean metric using the Welford algorithm.
This class implements a numerical stable mean metrics computation using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
than the maximum float32 as far as the mean is within the limit of float32.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
"""
Args:
**kwargs: Additional parameters supported by all torchmetrics.Metric.
Attributes:
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
"""
def __init__(self, **kwargs):
"""
@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
)
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
"""
Update the current mean.
Args:
value: Value to update the mean with.
weight: weight to use. Shape should be broadcastable to that of value.
"""
"""Update the current mean.
Args:
value (torch.Tensor): Value to update the mean with.
weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
"""
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
if not isinstance(weight, torch.Tensor):
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
)
def compute(self) -> torch.Tensor:
"""
Compute and return the accumulated mean.
"""
"""Compute and return the accumulated mean.
Returns:
torch.Tensor: The accumulated mean.
"""
return self.mean_and_weight_sum[0]

View File

@ -29,6 +29,9 @@ def _compute_helper(
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise,
we assume that they are correct prediction (i.e weight = 0).
Returns:
torch.Tensor: The computed AUROC
"""
dim = 0
@ -52,24 +55,34 @@ def _compute_helper(
class AUROCWithMWU(torchmetrics.Metric):
"""
AUROC using Mann-Whitney U-test.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.
This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return
the correct AUROC even if the predicted probabilities are all close to 0.
Currently only support binary classification.
"""
This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
Note: Currently, this implementation only supports binary classification.
Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
"""
Initializes the AUROCWithMWU metric.
Args:
label_threshold: Labels strictly above this threshold are considered positive labels,
otherwise, they are considered negative.
raise_missing_class: If True, an error will be raise if negative or positive class is missing.
Otherwise, we will simply log a warning.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
self.add_state("predictions", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

View File

@ -14,7 +14,24 @@ from absl import logging as logging
def setup_absl_logging():
"""Make sure that absl logging pushes to stdout rather than stderr."""
"""
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
This function ensures that log messages generated by the absl-py library are written to stdout
rather than stderr. It also applies a custom log message format that includes module, function,
line number, log level, and the log message content.
Note:
This function should be called once at the beginning of your script or application to
configure absl-py logging.
Example:
To use this function, simply call it at the start of your script:
```
setup_absl_logging()
```
"""
logging.get_absl_handler().python_handler.stream = sys.stdout
formatter = py_logging.Formatter(
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"

View File

@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
class Testtlogging(unittest.TestCase):
def test_warn_once(self):
"""
Test that warning messages are logged only once when using the assertLogs context manager.
This unit test checks the behavior of the logging system when warning messages are issued
multiple times within the same context. It uses the assertLogs context manager to capture
log messages at the INFO level and verifies that warning messages are logged only once.
Example:
To use this test case, call it using a test runner like unittest:
```
python -m unittest your_test_module.TestLogging.test_warn_once
```
"""
with self.assertLogs(level="INFO") as captured_logs:
logging.info("first info")
logging.warning("first warning")

View File

@ -18,7 +18,35 @@ import torch.distributed as dist
def rank_specific(logger):
"""Ensures that we only override a given logger once."""
"""
Customize logger behavior based on the distributed environment and rank.
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
depending on the rank or limit the number of redundant logs.
Args:
logger: The logger object to customize.
Returns:
The customized logger.
Example:
To use this function with the `logging` module:
```python
import logging
from rank_specific_logging import rank_specific
logger = logging.getLogger(__name__)
rank_specific(logger)
```
Customization:
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
- The 'warning' method is limited to logging a single redundant warning.
- Logging from rank -1 is redirected to include the rank information.
"""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger

View File

@ -8,11 +8,60 @@ import pydantic
class PiecewiseConstant(base_config.BaseConfig):
"""
Configuration for a piecewise constant learning rate schedule.
This configuration class allows you to specify a piecewise constant learning rate schedule
by defining boundaries and corresponding learning rate values.
Attributes:
learning_rate_boundaries (List[int], optional): List of step boundaries at which
the learning rate will change. If None, no boundaries are defined.
learning_rate_values (List[float], optional): List of learning rate values
corresponding to the boundaries. If None, no values are defined.
Example:
To configure a piecewise constant learning rate schedule, create an instance of this class
and set the attributes accordingly. For example:
```python
piecewise_lr = PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
```
Note:
The number of learning rate values should be one more than the number of boundaries.
"""
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
learning_rate_values: typing.List[float] = pydantic.Field(None)
class LinearRampToConstant(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to constant learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero to a constant value over a specified number of steps.
Attributes:
learning_rate (float): The final constant learning rate.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
Example:
To configure a linear ramp-up to a constant learning rate, create an instance of this class
and set the attributes accordingly. For example:
```python
linear_ramp_lr = LinearRampToConstant(
learning_rate=0.1,
num_ramp_steps=1000
)
```
"""
learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
description="Number of steps to ramp this up from zero."
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
class LinearRampToCosine(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to cosine decay learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero, then decays following a cosine schedule to a final constant learning rate.
Attributes:
learning_rate (float): The initial learning rate at the start of ramp-up.
final_learning_rate (float): The final constant learning rate after decay.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
final_num_steps (PositiveInt): Final number of steps where decay stops.
Example:
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
class and set the attributes accordingly. For example:
```python
ramp_to_cosine_lr = LinearRampToCosine(
learning_rate=0.01,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
```
"""
learning_rate: float
final_learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
class LearningRate(base_config.BaseConfig):
"""
Learning rate configuration for training.
This configuration class allows you to specify different learning rate schedules
for your training process.
Attributes:
constant (float, optional): Constant learning rate to be used throughout training.
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
and then decays following a cosine schedule.
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
linearly and then remains constant.
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
boundaries with corresponding values.
Example:
To configure a learning rate schedule, create an instance of this class and set the
attributes accordingly. For example:
```python
learning_rate = LearningRate(
constant=0.01,
linear_ramp_to_cosine=LinearRampToCosine(
learning_rate=0.1,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
)
```
Note:
Each learning rate schedule attribute can be set to `None` if not needed.
"""
constant: float = pydantic.Field(None, one_of="lr")
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
class OptimizerAlgorithmConfig(base_config.BaseConfig):
"""Base class for optimizer configurations."""
"""
Base class for optimizer configurations.
This base configuration class provides a structure for specifying various optimizer-related
settings, including the learning rate and different learning rate schedules.
Attributes:
lr (float): The base learning rate used by the optimizer.
Subclasses should inherit from this base class and define additional attributes specific to
the optimizer algorithm they represent.
Example:
To create a custom optimizer configuration, create a subclass of this base class and
define the necessary attributes. For example:
```python
class MyOptimizerConfig(OptimizerAlgorithmConfig):
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
```
Note:
This base class does not include specific optimizer settings. Subclasses should define
the optimizer-specific attributes as needed.
"""
lr: float
...
class AdamConfig(OptimizerAlgorithmConfig):
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
"""
Configuration for the Adam optimizer.
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
Attributes:
lr (float): The learning rate for optimization.
betas (Tuple[float, float], optional): Coefficients used for computing running averages
of gradient and squared gradient. Defaults to (0.9, 0.999).
eps (float, optional): A small constant added to the denominator for numerical stability.
Defaults to 1e-7.
Example:
To configure the Adam optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
adam_optimizer = AdamConfig(
lr=0.001,
betas=(0.9, 0.999),
eps=1e-8
)
```
See Also:
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
"""
lr: float
betas: typing.Tuple[float, float] = [0.9, 0.999]
eps: float = 1e-7 # Numerical stability in denominator.
class SgdConfig(OptimizerAlgorithmConfig):
"""
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
Attributes:
lr (float): The learning rate for optimization.
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
Example:
To configure the SGD optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
sgd_optimizer = SgdConfig(
lr=0.01,
momentum=0.9
)
```
"""
lr: float
momentum: float = 0.0
class AdagradConfig(OptimizerAlgorithmConfig):
"""
Configuration for the optimizer used during training.
This configuration class allows you to specify the optimizer for training, including
options for various optimizer algorithms.
Attributes:
learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None.
adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None.
sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
Defaults to None.
adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None.
Example:
To configure the optimizer for training, create an instance of this class and set the
attributes accordingly. For example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
```
"""
lr: float
eps: float = 0
class OptimizerConfig(base_config.BaseConfig):
"""
Configuration for defining different optimizer algorithms and their parameters.
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
along with their respective hyperparameters.
Args:
learning_rate (LearningRate): The learning rate configuration, which can include
constant learning rates or other learning rate schedules.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
)
```
Attributes:
learning_rate (LearningRate): The learning rate configuration.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Note:
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
`OptimizerConfig` instance.
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `AdamConfig`: Configuration for the Adam optimizer.
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
- `AdagradConfig`: Configuration for the Adagrad optimizer.
"""
learning_rate: LearningRate = pydantic.Field(
None,
description="Constant learning rates",
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
"""
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
This function extracts and returns the specific optimizer algorithm configuration
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
Args:
optimizer_config (OptimizerConfig): The optimizer configuration object containing
one of the optimizer algorithm configurations.
Returns:
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
configuration extracted from `optimizer_config`.
Raises:
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
Example:
```python
optimizer_config = OptimizerConfig(
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
# `algorithm_config` will be an instance of `AdamConfig`.
```
"""
if optimizer_config.adam is not None:
return optimizer_config.adam
elif optimizer_config.sgd is not None:

View File

@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
"""Compute a learning rate."""
"""
Compute the learning rate based on the specified learning rate configuration.
This function calculates the learning rate according to the given configuration, which can include
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
Args:
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
step (int): The current training step or iteration.
Returns:
float: The computed learning rate for the current step.
Raises:
ValueError: If the `lr_config` is invalid or contains conflicting options.
Example:
```python
lr_schedule = LearningRate(
constant=0.001,
piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
)
current_step = 2500
learning_rate = compute_lr(lr_schedule, current_step)
```
"""
if lr_config.constant is not None:
return lr_config.constant
elif lr_config.piecewise_constant is not None:
@ -46,11 +74,54 @@ def compute_lr(lr_config, step):
class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
Learning Rate Scheduler Shim to adjust learning rates during training.
This class acts as a shim to apply different learning rates to individual parameter groups
within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
optimizers, allowing fine-grained control over learning rates based on configuration.
Args:
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
corresponding learning rate configurations.
last_epoch (int, optional): The index of the last epoch. Default is -1.
verbose (bool, optional): If True, prints a warning message when accessing learning rates
using the deprecated `get_lr()` method. Default is False.
Raises:
ValueError: If the number of parameter groups in the optimizer does not match the number
of learning rate configurations provided.
Note:
To obtain the last computed learning rates, please use `get_last_lr()`.
Example:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lr_schedule = {
'main': LearningRate(constant=0.01),
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000],
learning_rate_values=[0.01, 0.001]
))
}
lr_shim = LRShim(optimizer, lr_schedule)
for epoch in range(num_epochs):
# Train the model
train(...)
# Update learning rates at the end of each epoch
lr_shim.step(epoch)
final_lr_main = lr_shim.get_last_lr()['main']
final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary']
```
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
"""
def __init__(
self,
@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
def build_optimizer(
model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
"""
Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
Args:
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
algorithm and learning rate settings.
Returns:
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
objects.
Note:
This function is intended for cases where you want the same optimizer and learning rate
schedule for all model parameters.
Example:
```python
model = MyModel()
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.01),
sgd=SgdConfig(lr=0.01, momentum=0.9)
)
optimizer, scheduler = build_optimizer(model, optimizer_config)
for epoch in range(num_epochs):
# Train the model with the optimizer
train(model, optimizer, ...)
# Update learning rates at the end of each epoch
scheduler.step(epoch)
```
See Also:
- `OptimizerConfig`: Configuration for specifying optimizer settings.
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
"""
optimizer_class = get_optimizer_class(optimizer_config)
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
# We're passing everything in as one group here

View File

@ -9,6 +9,36 @@ import pydantic
class TrainingConfig(config_mod.BaseConfig):
"""
Configuration settings for the training process.
This class defines various training-related settings, including the directory to save checkpoints, the number
of training steps, logging intervals, and other training parameters.
Attributes:
save_dir (str): The directory where checkpoints and training artifacts will be saved.
num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
Example:
To configure training with checkpoints saved every 1000 steps, use the following settings:
```python
TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
)
```
"""
save_dir: str = "/tmp/model"
num_train_steps: pydantic.PositiveInt = 1000000
initial_checkpoint_dir: str = pydantic.Field(
@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):
class RecapConfig(config_mod.BaseConfig):
"""
Configuration settings for the Recap model training process.
This class defines the overall configuration for the training process of a Recap model. It includes settings for
training, model architecture, data, optimization, and evaluation.
Attributes:
training (TrainingConfig): Configuration settings for the training process.
model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
train_data (data_config.RecapDataConfig): Configuration settings for training data.
validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
which_metrics (Optional[str]): Optional specification of which metrics to pick.
Note:
This class encapsulates all the necessary configurations to train a Recap model. It defines settings for
training, the model architecture, data loading, optimization, and evaluation.
Example:
To configure a Recap model training process, use the following settings:
```python
RecapConfig(
training=TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
),
model=model_config.ModelConfig(...),
train_data=data_config.RecapDataConfig(...),
validation_data={"dev": data_config.RecapDataConfig(...)},
optimizer=optimizer_config.RecapOptimizerConfig(...),
)
```
"""
training: TrainingConfig = pydantic.Field(TrainingConfig())
model: model_config.ModelConfig
train_data: data_config.RecapDataConfig

View File

@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
class EmbeddingOptimizerConfig(base_config.BaseConfig):
"""Configuration for the optimizer used for embedding tables."""
learning_rate: optimizer_config.LearningRate = pydantic.Field(
None, description="learning rate scheduler for the EBC"
)
@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):
class StratifierConfig(base_config.BaseConfig):
"""Configuration for Stratifier."""
name: str
index: int
value: int
@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
serving time due to size (>>1 GB).
This small embeddings table uses the same optimizer as the rest of the model."""
This small embeddings table uses the same optimizer as the rest of the model.
"""
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
..., description="list of embedding tables"

View File

@ -34,6 +34,33 @@ FLAGS = flags.FLAGS
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
"""
Main function to run the training of a ranking model.
This function initializes and runs the training process for a ranking model based on the provided configuration.
Args:
unused_argv (str): Unused argument.
data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).
Returns:
None
Raises:
AssertionError: If the configuration or input data is not valid.
Note:
This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
training loop.
Example:
To run the training process, use the following command:
```
python run_training.py --config_path=config.yaml
```
"""
print("#" * 100)
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)

View File

@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DenseLayerConfig(base_config.BaseConfig):
"""Configuration for the dense layer."""
layer_size: pydantic.PositiveInt
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DoubleNormLogConfig(base_config.BaseConfig):
"""Configuration for the double norm log transform."""
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):
class ClipLog1pAbsConfig(base_config.BaseConfig):
"""Configuration for the clip log transform."""
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
3e38, description="Threshold to clip the input values."
)
class ZScoreLogConfig(base_config.BaseConfig):
"""Configuration for the z-score log transform."""
analysis_path: str
schema_path: str = pydantic.Field(
None,
@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):
class MaskBlockConfig(base_config.BaseConfig):
"""Config for MaskNet block."""
output_size: int
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
None, one_of="aggregation_size"
@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):
class MaskNetConfig(base_config.BaseConfig):
"""Config for MaskNet model."""
mask_blocks: List[MaskBlockConfig]
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):
class DLRMConfig(base_config.BaseConfig):
"""Config for DLRM model."""
bottom_mlp: MlpConfig = pydantic.Field(
...,
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):
class TaskModel(base_config.BaseConfig):
"""Configuration for a single task."""
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):
class MultiTaskType(str, enum.Enum):
"""Type of multi task architecture."""
SHARE_NONE = "share_none" # Tasks are separate.
SHARE_ALL = "share_all" # Tasks share same backbone.
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):
@pydantic.root_validator()
def _validate_mtl(cls, values):
"""Validate the multi task architecture."""
if values.get("multi_task_type", None) is None:
return values
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:

View File

@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
""" "Builds a model for a single task"""
"""
Build a model for a single task based on the provided configuration.
Args:
task (model_config_mod.TaskModel): The task model configuration.
input_shape (int): The input shape for the model.
Returns:
torch.nn.Module: The constructed model for the single task.
Raises:
ValueError: If the task configuration is not recognized.
"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:
@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)
class MultiTaskRankingModel(torch.nn.Module):
"""Multi-task ranking model."""
"""
Multi-task ranking model that handles multiple ranking tasks simultaneously.
This model takes various input features and predicts rankings for multiple
tasks using shared or separate towers.
"""
def __init__(
self,
@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module):
data_config: RecapDataConfig,
return_backbone: bool = False,
):
"""Constructor for Multi task learning.
"""
Constructor for Multi-task ranking model.
Args:
input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
config (ModelConfig): The model configuration.
data_config (RecapDataConfig): The data configuration.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Assumptions made:
1. Tasks specified in data config match model architecture.
These are all validated in config.
1. Tasks specified in data config match model architecture.
These are all validated in config.
"""
super().__init__()
@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
labels: Optional[torch.Tensor] = None,
weights: Optional[torch.Tensor] = None,
):
"""
Forward pass of the Multi-task ranking model.
Args:
continuous_features (torch.Tensor): Continuous input features.
binary_features (torch.Tensor): Binary input features.
discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
sparse_features ([type], optional): Sparse input features. Defaults to None.
user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
"""
concat_dense_features = [
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
]
@ -270,6 +310,20 @@ def create_ranking_model(
data_config=None,
return_backbone=False,
):
"""
Creates a ranking model based on the provided specifications and configuration.
Args:
data_spec: The input data specifications.
config (config_mod.RecapConfig): The model configuration.
device (torch.device): The device where the model should be placed.
loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
data_config: The data configuration. Defaults to None.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Returns:
torch.nn.Module: The created ranking model.
"""
if list(config.model.tasks.values())[0].dlrm_config:
raise NotImplementedError()

View File

@ -11,21 +11,52 @@ import torch
def log_transform(x: torch.Tensor) -> torch.Tensor:
"""Safe log transform that works across both negative, zero, and positive floats."""
"""
Safe log transform that works across both negative, zero, and positive floats.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log1p applied to absolute values.
"""
return torch.sign(x) * torch.log1p(torch.abs(x))
class BatchNorm(torch.nn.Module):
def __init__(self, num_features: int, config: BatchNormConfig):
"""
Batch normalization layer.
Args:
num_features (int): Number of input features.
config (BatchNormConfig): Configuration for batch normalization.
"""
super().__init__()
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the batch normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after batch normalization.
"""
return self.layer(x)
class LayerNorm(torch.nn.Module):
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
"""
Layer normalization layer.
Args:
normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
config (LayerNormConfig): Configuration for layer normalization.
"""
super().__init__()
if config.axis != -1:
raise NotImplementedError
@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the layer normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after layer normalization.
"""
return self.layer(x)
@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that applies a log transformation to the input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log applied to absolute values.
"""
return log_transform(x)
class InputNonFinite(torch.nn.Module):
def __init__(self, fill_value: float = 0):
"""
Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.
Args:
fill_value (float): The value to fill non-finite elements with. Default is 0.
"""
super().__init__()
self.register_buffer(
@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that replaces non-finite values in the input tensor with the specified fill value.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with non-finite values replaced.
"""
return torch.where(torch.isfinite(x), x, self.fill_value)
class Clamp(torch.nn.Module):
def __init__(self, min_value: float, max_value: float):
"""
Applies element-wise clamping to a tensor, ensuring that values are within a specified range.
Args:
min_value (float): The minimum value to clamp elements to.
max_value (float): The maximum value to clamp elements to.
"""
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that clamps the input tensor element-wise within the specified range.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with elements clamped within the specified range.
"""
return torch.clamp(x, min=self.min_value, max=self.max_value)
class DoubleNormLog(torch.nn.Module):
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
"""
Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.
Args:
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.
Attributes:
_before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
batch normalization (optional), and clamping.
layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
"""
def __init__(
self,
input_shapes: Mapping[str, Sequence[int]],
@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
def forward(
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
) -> torch.Tensor:
"""
Forward pass that processes continuous and binary features using batch normalization, log transformation,
optional batch normalization (if configured), clamping, and layer normalization (if configured).
Args:
continuous_features (torch.Tensor): Input tensor of continuous features.
binary_features (torch.Tensor): Input tensor of binary features.
Returns:
torch.Tensor: Transformed tensor containing both continuous and binary features.
"""
x = self._before_concat_layers(continuous_features)
x = torch.cat([x, binary_features], dim=1)
if self.layer_norm:
@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
def build_features_preprocessor(
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
):
"""Trivial right now, but we will change in the future."""
"""
Build a feature preprocessor module based on the provided configuration.
Trivial right now, but we will change in the future.
Args:
config (FeaturizationConfig): Configuration for feature preprocessing.
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
Returns:
DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
"""
return DoubleNormLog(input_shapes, config.double_norm_log_config)

View File

@ -6,15 +6,84 @@ import torch
def _init_weights(module):
"""Initializes weights
Example
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module):
"""
MaskBlock module in a mask-based neural network.
This module represents a MaskBlock, which applies a masking operation to the input data and then
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Example:
To create and use a MaskBlock within a MaskNet, follow these steps:
```python
# Define the configuration for the MaskBlock
mask_block_config = MaskBlockConfig(
input_layer_norm=True, # Apply input layer normalization
reduction_factor=0.5 # Reduce input dimensionality by 50%
)
# Create an instance of the MaskBlock
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
mask_input = torch.randn(batch_size, 32)
# Perform a forward pass through the MaskBlock
output = mask_block(input_data, mask_input)
```
Note:
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
operation that combines the input and mask input. Then, it passes the result through a hidden layer
with optional dimensionality reduction.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None:
"""
Initializes the MaskBlock module.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Returns:
None
"""
super(MaskBlock, self).__init__()
self.mask_block_config = mask_block_config
output_size = mask_block_config.output_size
@ -42,6 +111,16 @@ class MaskBlock(torch.nn.Module):
self._layer_norm = torch.nn.LayerNorm(output_size)
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
"""
Performs a forward pass through the MaskBlock.
Args:
net (torch.Tensor): Input data tensor.
mask_input (torch.Tensor): Mask input tensor.
Returns:
torch.Tensor: Output tensor of the MaskBlock.
"""
if self._input_layer_norm:
net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
@ -49,7 +128,60 @@ class MaskBlock(torch.nn.Module):
class MaskNet(torch.nn.Module):
"""
MaskNet module in a mask-based neural network.
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
create mask-based neural networks with parallel or stacked MaskBlocks.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Example:
To create and use a MaskNet, you can follow these steps:
```python
# Define the configuration for the MaskNet
mask_net_config = MaskNetConfig(
use_parallel=True, # Use parallel MaskBlocks
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
)
# Create an instance of the MaskNet
mask_net = MaskNet(mask_net_config, in_features=64)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
# Perform a forward pass through the MaskNet
outputs = mask_net(input_data)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
"""
Initializes the MaskNet module.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Returns:
None
"""
super().__init__()
self.mask_net_config = mask_net_config
mask_blocks = []
@ -77,6 +209,15 @@ class MaskNet(torch.nn.Module):
self.shared_size = total_output_mask_blocks
def forward(self, inputs: torch.Tensor):
"""
Performs a forward pass through the MaskNet.
Args:
inputs (torch.Tensor): Input data tensor.
Returns:
torch.Tensor: Output tensor of the MaskNet.
"""
if self.mask_net_config.use_parallel:
mask_outputs = []
for mask_layer in self._mask_blocks:

View File

@ -7,13 +7,83 @@ from absl import logging
def _init_weights(module):
"""Initializes weights
Example
-------
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class Mlp(torch.nn.Module):
"""
Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.
This module defines an MLP with customizable layers and activation functions. It is suitable for various
applications such as deep learning for tabular data, feature extraction, and more.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Example:
To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:
```python
# Define the configuration for the MLP
mlp_config = MlpConfig(
layer_sizes=[128, 64], # Specify the sizes of hidden layers
batch_norm=True, # Enable batch normalization
dropout=0.2, # Apply dropout with a rate of 0.2
final_layer_activation=True # Apply ReLU activation to the final layer
)
# Create an instance of the MLP module
mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)
# Generate an input tensor
input_tensor = torch.randn(batch_size, input_dim)
# Perform a forward pass through the MLP
outputs = mlp_model(input_tensor)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
enabling batch normalization and dropout, and choosing the activation function for the final layer.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, in_features: int, mlp_config: MlpConfig):
"""
Initializes the Mlp module.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Returns:
None
"""
super().__init__()
self._mlp_config = mlp_config
input_size = in_features
@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
self.layers.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs a forward pass through the MLP.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: Output tensor of the MLP.
"""
net = x
for i, layer in enumerate(self.layers):
net = layer(net)
@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):
@property
def shared_size(self):
"""
Returns the size of the shared layer in the MLP.
Returns:
int: Size of the shared layer.
"""
return self._mlp_config.layer_sizes[-1]
@property
def out_features(self):
"""
Returns the number of output features from the MLP.
Returns:
int: Number of output features.
"""
return self._mlp_config.layer_sizes[-1]

View File

@ -5,6 +5,53 @@ from absl import logging
class ModelAndLoss(torch.nn.Module):
"""
PyTorch module that combines a neural network model and loss function.
This module wraps a neural network model and facilitates the forward pass through the model
while also calculating the loss based on the model's predictions and provided labels.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification. Each stratifier config includes the name and index of discrete features
to emit for stratification.
Example:
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
and loss function as arguments:
```python
# Create a neural network model
model = YourNeuralNetworkModel()
# Define a loss function
loss_fn = torch.nn.CrossEntropyLoss()
# Create an instance of ModelAndLoss
model_and_loss = ModelAndLoss(model, loss_fn)
# Generate a batch of training data (e.g., RecapBatch)
batch = generate_training_batch()
# Perform a forward pass through the model and calculate the loss
loss, outputs = model_and_loss(batch)
# You can now backpropagate and optimize using the computed loss
loss.backward()
optimizer.step()
```
Note:
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
calculating loss, making it easier to integrate the model into your training loop. Additionally,
it supports the addition of stratifiers for metrics stratification, if needed.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
model,
@ -12,10 +59,13 @@ class ModelAndLoss(torch.nn.Module):
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification.
Initializes the ModelAndLoss module.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification.
"""
super().__init__()
self.model = model

View File

@ -2,11 +2,57 @@ import torch
class NumericCalibration(torch.nn.Module):
"""
Numeric calibration module for adjusting probability scores.
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
for tasks such as binary classification.
Args:
pos_downsampling_rate (float): The downsampling rate for positive samples.
neg_downsampling_rate (float): The downsampling rate for negative samples.
Example:
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
scores like this:
```python
# Create a NumericCalibration instance with downsampling rates
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
# Generate probability scores (e.g., from a neural network)
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
# Apply numeric calibration to adjust the probabilities
calibrated_probs = calibration(raw_probs)
# The `calibrated_probs` now contains the adjusted probability scores
```
Note:
The `NumericCalibration` module is used to adjust probability scores to account for differences in
the number of positive and negative samples in a dataset. It can help improve the calibration of
probability estimates in imbalanced classification problems.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
pos_downsampling_rate: float,
neg_downsampling_rate: float,
):
"""
Apply numeric calibration to probability scores.
Args:
probs (torch.Tensor): Probability scores to be calibrated.
Returns:
torch.Tensor: Calibrated probability scores.
"""
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).

View File

@ -9,12 +9,60 @@ import pydantic
class RecapAdamConfig(base_config.BaseConfig):
"""
Configuration settings for the Adam optimizer used in Recap.
Args:
beta_1 (float): Momentum term (default: 0.9).
beta_2 (float): Exponential weighted decay factor (default: 0.999).
epsilon (float): Numerical stability in the denominator (default: 1e-7).
Example:
To define an Adam optimizer configuration for Recap, use:
```python
adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
```
Note:
This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator.
class MultiTaskLearningRates(base_config.BaseConfig):
"""
Configuration settings for multiple learning rates in Recap.
Args:
tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).
Example:
To define multiple learning rates for different towers in Recap, use:
```python
multi_task_lr = MultiTaskLearningRates(
tower_learning_rates={
'task1': learning_rate1,
'task2': learning_rate2,
},
backbone_learning_rate=backbone_lr,
)
```
Note:
This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model."
)
@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):
class RecapOptimizerConfig(base_config.BaseConfig):
"""
Configuration settings for the Recap optimizer.
Args:
multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
adam (RecapAdamConfig): Configuration settings for the Adam optimizer.
Example:
To define an optimizer configuration for training with Recap, use:
```python
optimizer_config = RecapOptimizerConfig(
multi_task_learning_rates=multi_task_lr,
single_task_learning_rate=single_task_lr,
adam=adam_config,
)
```
Warning:
This class is intended for internal use to configure the optimizer settings within Recap and should not be
directly accessed by external code.
"""
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr"
)

View File

@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
A shim to get learning rates into a LRScheduler.
This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
learning rate schedules, such as exponential decay.
Args:
optimizer: The optimizer to which this scheduler is applied.
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
emb_learning_rate: The learning rate for embeddings (optional).
last_epoch (int): The index of the last epoch (default: -1).
verbose (bool): If True, print warnings for deprecated functions (default: False).
Example:
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
```python
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
```
Warning:
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
be directly accessed by external code.
"""
def __init__(
self,
@ -80,15 +98,25 @@ def build_optimizer(
optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
):
"""Builds an optimizer and scheduler.
Args:
model: A torch model, probably with DDP/DMP.
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
Returns:
A torch.optim instance, and a scheduler instance.
"""
Build an optimizer and scheduler for training.
Args:
model: The torch model, possibly with DDP/DMP.
optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
emb_optimizer_config: Configuration settings for embedding optimization (optional).
Returns:
torch.optim.Optimizer: The optimizer for training.
RecapLRShim: The learning rate scheduler for the optimizer.
Example:
To build an optimizer and scheduler for training, use:
```python
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
```
"""
optimizer_fn = functools.partial(
torch.optim.Adam,
lr=_DEFAULT_LR,

View File

@ -4,6 +4,17 @@ import pydantic
class TwhinDataConfig(base_config.BaseConfig):
"""
Configuration for Twhin model training data.
Args:
data_root (str): The root directory for the training data.
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
global_negatives (int): The number of global negatives.
in_batch_negatives (int): The number of in-batch negatives.
limit (pydantic.PositiveInt): The limit on the number of data points to use.
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
"""
data_root: str
per_replica_batch_size: pydantic.PositiveInt
global_negatives: int

View File

@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
"""
Create a dataset for Twhin model training.
Args:
data_config (TwhinDataConfig): The data configuration for the dataset.
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
Returns:
EdgesDataset: The dataset for Twhin model training.
"""
tables = model_config.embeddings.tables
table_sizes = {table.name: table.num_embeddings for table in tables}
relations = model_config.relations

View File

@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@dataclass
class EdgeBatch(DataclassBatch):
"""
Batch data structure for edge-based models.
Args:
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
labels (torch.Tensor): Tensor containing labels.
rels (torch.Tensor): Tensor containing relation information.
weights (torch.Tensor): Tensor containing weights.
"""
nodes: KeyedJaggedTensor
labels: torch.Tensor
rels: torch.Tensor
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
class EdgesDataset(Dataset):
"""
Dataset for edge-based models.
Args:
file_pattern (str): The file pattern for the dataset.
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
relations (List[Relation]): A list of relations between tables.
lhs_column_name (str): The name of the left-hand-side column.
rhs_column_name (str): The name of the right-hand-side column.
rel_column_name (str): The name of the relation column.
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
"""
rng = np.random.default_rng()
def __init__(
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
def pa_to_batch(self, batch: pa.RecordBatch):
"""
Converts a pyarrow RecordBatch to an EdgeBatch.
Args:
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
Returns:
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
"""
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
"""Process edges that contain lhs index, rhs index, relation index.
Args:
lhs (torch.Tensor): Tensor containing left-hand-side indices.
rhs (torch.Tensor): Tensor containing right-hand-side indices.
rel (torch.Tensor): Tensor containing relation indices.
Returns:
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
Example:
```
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
def to_batches(self):
"""
Converts data to batches.
Yields:
pa.RecordBatch: A pyarrow RecordBatch containing data.
"""
ds = super().to_batches()
batch_size = self._dataset_kwargs["batch_size"]

View File

@ -23,6 +23,25 @@ RELATIONS = [
def test_gen():
"""Test function for generating edge-based datasets and dataloaders.
This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
instance and a dataloader for it.
The test includes the following steps:
1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
2. Write the synthetic data to a Parquet file.
3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
4. Initialize the local file system for the dataset.
5. Create a dataloader for the dataset and retrieve the first batch.
6. Assert that the labels in the batch are positive.
7. Verify that the positive examples in the batch match the expected values.
This function serves as a test case for the data generation and dataset creation process.
Raises:
AssertionError: If any of the test assertions fail.
"""
import os
import tempfile

View File

@ -10,8 +10,29 @@ from pydantic import validator
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
"""
Configuration class for Twhin model embeddings.
This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
for all tables in the Twhin model embeddings configuration match.
Attributes:
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
"""
@validator("tables")
def embedding_dims_match(cls, tables):
"""
Validate that embedding dimensions and data types match for all tables.
Args:
tables (List[TableConfig]): List of table configurations.
Returns:
List[TableConfig]: The list of validated table configurations.
Raises:
AssertionError: If embedding dimensions or data types do not match.
"""
embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type
for table in tables:
@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
class Operator(str, enum.Enum):
"""
Enumeration of operator types.
This enumeration defines different types of operators that can be applied to Twhin model relations.
"""
TRANSLATION = "translation"
class Relation(pydantic.BaseModel):
"""graph relationship properties and operator"""
"""
Configuration class for graph relationships in the Twhin model.
This class defines properties and operators for graph relationships in the Twhin model.
Attributes:
name (str): The name of the relationship.
lhs (str): The name of the entity on the left-hand side of the relation.
rhs (str): The name of the entity on the right-hand side of the relation.
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
"""
name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field(
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):
class TwhinModelConfig(base_config.BaseConfig):
"""
Configuration class for the Twhin model.
This class defines configuration options specific to the Twhin model.
Attributes:
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
relations (List[Relation]): List of graph relationship configurations.
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
"""
embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation]
translation_optimizer: OptimizerConfig
@validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs):
"""
Validate that the specified node types in relations are valid table names in embeddings.
Args:
relation (Relation): A single relation configuration.
values (dict): The values dictionary containing the "embeddings" configuration.
Returns:
Relation: The validated relation configuration.
Raises:
AssertionError: If the specified node types are not valid table names in embeddings.
"""
table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"

View File

@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa
class TwhinModel(nn.Module):
"""
Twhin model for graph-based entity embeddings and translation.
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
and applying translations to these embeddings based on graph relationships.
Args:
model_config (TwhinModelConfig): Configuration for the Twhin model.
data_config (TwhinDataConfig): Configuration for the data used by the model.
Attributes:
batch_size (int): The batch size used for training.
table_names (List[str]): Names of tables in the model's embeddings.
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
embedding_dim (int): Dimensionality of entity embeddings.
num_tables (int): Number of tables in the model's embeddings.
in_batch_negatives (int): Number of in-batch negative samples to use during training.
global_negatives (int): Number of global negative samples to use during training.
num_relations (int): Number of graph relationships in the model.
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
"""
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
)
def forward(self, batch: EdgeBatch):
"""
Forward pass of the Twhin model.
Args:
batch (EdgeBatch): Input batch containing graph edge information.
Returns:
dict: A dictionary containing model output with "logits" and "probabilities".
- "logits" (torch.Tensor): Logit scores.
- "probabilities" (torch.Tensor): Sigmoid probabilities.
"""
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
@ -98,6 +130,18 @@ class TwhinModel(nn.Module):
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
"""
Apply optimizers to the Twhin model's embeddings.
This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.
Args:
model (TwhinModel): The Twhin model to apply optimizers to.
model_config (TwhinModelConfig): Configuration for the Twhin model.
Returns:
TwhinModel: The Twhin model with optimizers applied to its embeddings.
"""
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
device: torch.device,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
Initialize a TwhinModelAndLoss module.
Args:
model: The torch module to wrap.
loss_fn: A function for calculating loss, should accept logits and labels.
data_config: Configuration for Twhin data.
device: The torch device to use for calculations.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
Run the model forward and calculate the loss according to the given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types
Args:
batch ("RecapBatch"): The input batch for model inference.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
additional outputs including logits, labels, and weights.
"""
outputs = self.model(batch)
logits = outputs["logits"]

View File

@ -18,6 +18,12 @@ EMB_DIM = 128
def twhin_model_config() -> TwhinModelConfig:
"""
Create a configuration for the Twhin model.
Returns:
TwhinModelConfig: The Twhin model configuration.
"""
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
def twhin_data_config() -> TwhinDataConfig:
"""
Create a configuration for the Twhin data.
Returns:
TwhinDataConfig: The Twhin data configuration.
"""
data_config = TwhinDataConfig(
data_root="/",
per_replica_batch_size=10,
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
def test_twhin_model():
"""
Test the Twhin model creation and optimization.
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
the device placement of model parameters.
Returns:
None
"""
model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits
@ -84,6 +105,23 @@ def test_twhin_model():
def test_unequal_dims():
"""
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.
This function tests whether the validation logic correctly raises a `ValidationError` when
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.
The test includes the following steps:
1. Create two embedding configurations with different embedding dimensions.
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.
This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
in the `TwhinEmbeddingsConfig` for all tables.
Raises:
AssertionError: If the expected `ValidationError` is not raised.
"""
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig(

View File

@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
def _lr_from_config(optimizer_config):
"""Get the learning rate from an optimizer configuration.
Args:
optimizer_config: Optimizer configuration.
Returns:
Learning rate from the optimizer configuration.
"""
if optimizer_config.learning_rate is not None:
return optimizer_config.learning_rate
else:
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
Args:
model: TwhinModel to build optimizer for.
config: TwhinConfig for model.
Args:
model: TwhinModel to build optimizer for.
config: TwhinModelConfig for model.
Returns:
Optimizer for model.
"""
Returns:
Optimizer for model.
"""
translation_optimizer_fn = functools.partial(
get_optimizer_class(config.translation_optimizer),
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),

View File

@ -37,6 +37,12 @@ def run(
all_config: TwhinConfig,
save_dir: Optional[str] = None,
):
"""Run the training process for TwhinModel.
Args:
all_config (TwhinConfig): The configuration for the entire Twhin model.
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
"""
train_dataset = create_dataset(all_config.train_data, all_config.model)
if env.is_reader():
@ -80,6 +86,11 @@ def run(
def main(argv):
"""Main entry point for the Twhin training script.
Args:
argv: Command-line arguments.
"""
logging.info("Starting")
logging.info(f"parsing config from {FLAGS.config_yaml_path}...")

View File

@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging
class _Reader(pa.flight.FlightServerBase):
"""Distributed reader flight server wrapping a dataset."""
"""
Distributed reader flight server wrapping a dataset.
This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
for efficient data access.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
Attributes:
_location (str): The location of the Flight server.
_ds (Dataset): The dataset wrapped by the Flight server.
Methods:
do_get(_, __): Handles Flight requests for data retrieval.
Note:
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
This class allows clients to retrieve data from the dataset using Flight.
"""
def __init__(self, location: str, ds: "Dataset"):
"""
Initialize a new _Reader instance.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
"""
super().__init__(location=location)
self._location = location
self._ds = ds
def do_get(self, _, __):
"""
Handle Flight requests for data retrieval.
This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.
Args:
_: Unused argument.
__: Unused argument.
Returns:
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.
Note:
An updated schema (to account for column selection) must be given to the stream.
"""
# NB: An updated schema (to account for column selection) has to be given the stream.
schema = next(iter(self._ds.to_batches())).schema
batches = self._ds.to_batches()
@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):
class Dataset(torch.utils.data.IterableDataset):
"""
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.
This class enables efficient loading of data from Parquet files using PyArrow.
It is designed to be used as an IterableDataset in PyTorch for training and inference.
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
for more details.
Attributes:
LOCATION (str): The default location for the Flight server used for data distribution.
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
_fs: The filesystem object used for file operations.
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
_files (list): A list of file paths matching the glob pattern.
_schema (pa.Schema): The schema of the Parquet dataset.
Methods:
serve(): Start serving the dataset using a Flight server.
to_batches(): Generate batches of data from the Parquet dataset.
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.
Note:
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
to create DataLoader instances for training or inference.
"""
LOCATION = "grpc://0.0.0.0:2222"
def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
"""Specify batch size and column to select for.
"""
Initialize a new Dataset instance. Specify batch size and column to select for.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
"""
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
"""
self._file_pattern = file_pattern
self._fs = infer_fs(self._file_pattern)
self._dataset_kwargs = dataset_kwargs
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
self._validate_columns()
def _validate_columns(self):
"""
Validate the specified columns against the dataset schema.
Raises:
Exception: If any specified columns are not found in the dataset schema.
"""
columns = set(self._dataset_kwargs.get("columns", []))
wrong_columns = set(columns) - set(self._schema.names)
if wrong_columns:
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
def serve(self):
"""Start serving the dataset using a Flight server."""
self.reader = _Reader(location=self.LOCATION, ds=self)
self.reader.serve()
def _create_dataset(self):
"""Create a PyArrow dataset for data retrieval."""
return pads.dataset(
source=random.sample(self._files, len(self._files))[0],
format="parquet",
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
@abc.abstractmethod
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
"""
Convert a Parquet RecordBatch to a custom data batch.
Args:
batch (pa.RecordBatch): A batch of data from the Parquet dataset.
Returns:
DataclassBatch: A custom data batch used in PyTorch training.
Raises:
NotImplementedError: This method must be implemented in derived classes.
"""
raise NotImplementedError
def dataloader(self, remote: bool = False):
"""
Create a PyTorch DataLoader for iterating through the dataset.
Args:
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.
Returns:
DataLoader: A PyTorch DataLoader for iterating through the dataset.
Note:
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
"""
if not remote:
return map(self.pa_to_batch, self.to_batches())
readers = get_readers(2)
@ -117,6 +230,25 @@ GRPC_OPTIONS = [
def get_readers(num_readers_per_worker: int):
"""
Get Flight readers for distributed data loading.
This function retrieves Flight readers for distributed data loading in a PyTorch environment.
Args:
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.
Returns:
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.
Note:
Flight readers are used to fetch data in a distributed manner for efficient data loading.
Example:
To obtain Flight readers, use the following code:
>>> readers = get_readers(num_readers_per_worker=2)
"""
addresses = env.get_flight_server_addresses()
readers = []

View File

@ -21,6 +21,16 @@ import torch.distributed as dist
def maybe_start_dataset_service():
"""
Start the dataset service if readers are available and required dependencies are met.
This function checks if readers are available and if the required TensorFlow version is >= 2.5.
If both conditions are met and the current environment is the dispatcher or reader, it starts
the TensorFlow dataset service.
Raises:
Exception: If the required TensorFlow version is not met (>= 2.5).
"""
if not env.has_readers():
return
@ -59,6 +69,24 @@ def maybe_start_dataset_service():
def register_dataset(
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
):
"""
Register a dataset with the distributed dataset service.
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
and job name to all processes in the distributed environment.
Args:
dataset (tf.data.Dataset): The dataset to be registered.
dataset_service (str): The name of the dataset service.
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
Returns:
Tuple[int, str]: A tuple containing the dataset ID and job name.
Note:
This function should be called on the rank 0 process.
"""
if dist.get_rank() == 0:
dataset_id = _register_dataset(
service=dataset_service,
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
compression: Optional[str] = "AUTO",
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
"""
Distribute a dataset from a registered dataset ID.
This function consumes a dataset from the distributed dataset service using the provided dataset ID
and job name. It also supports prefetching for improved performance.
Args:
dataset_service (str): The name of the dataset service.
dataset_id (int): The ID of the dataset to be consumed.
job_name (Optional[str]): The name of the job associated with the dataset (optional).
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).
Returns:
tf.data.Dataset: The distributed dataset.
"""
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
dataset = _from_dataset_id(
processing_mode="parallel_epochs",
@ -97,15 +142,28 @@ def distribute_from_dataset_id(
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
"""Torch-compatible and distributed-training-aware dataset service distributor.
- rank 0 process will register the given dataset.
- rank 0 process will broadcast job name and dataset id.
- all rank processes will consume from the same job/dataset.
Without this, dataset workers will try to serve 1 job per rank process and OOM.
"""
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.
This function is used to distribute a dataset in a distributed training environment. It performs the
following steps:
- On the rank 0 process, it registers the given dataset with the distributed dataset service.
- It broadcasts the job name and dataset ID to all rank processes.
- All rank processes then consume the same dataset from the distributed dataset service.
Args:
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
Returns:
tf.data.Dataset: The distributed TensorFlow dataset.
Note:
- If there are no reader processes in the distributed environment, the original dataset is returned
without any distribution.
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
issues caused by each rank process trying to serve one job.
"""
if not env.has_readers():
return dataset
dataset_service = env.get_dds()

View File

@ -12,6 +12,17 @@ import torch
def create_dataset(tmpdir):
"""
Create a mock dataset for testing.
This function creates a mock dataset using PyArrow and Parquet for testing purposes.
Args:
tmpdir: A temporary directory where the dataset will be created.
Returns:
MockDataset: A mock dataset for testing.
"""
table = pa.table(
{
@ -34,6 +45,14 @@ def create_dataset(tmpdir):
def test_dataset(tmpdir):
"""
Test the created dataset.
This function tests the created mock dataset and checks if it behaves as expected.
Args:
tmpdir: A temporary directory used for testing.
"""
ds = create_dataset(tmpdir)
batch = next(iter(ds.dataloader(remote=False)))
assert batch.batch_size == 2
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
reason="Multiprocessing doesn't work on github yet.",
)
def test_distributed_dataset(tmpdir):
"""
Test the distributed dataset.
This function tests the distributed version of the mock dataset using multiprocessing.
Args:
tmpdir: A temporary directory used for testing.
"""
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
def _client():

View File

@ -11,11 +11,55 @@ import torch
def roundrobin(*iterables):
"""Round robin through provided iterables, useful for simple load balancing.
Adapted from https://docs.python.org/3/library/itertools.html.
"""
Iterate through provided iterables in a round-robin fashion.
This function takes multiple iterables and returns an iterator that yields elements from
each iterable in a round-robin manner. It continues cycling through the iterables until
all of them are exhausted.
Adapted from https://docs.python.org/3/library/itertools.html.
Args:
*iterables: One or more iterable objects to iterate through.
Yields:
Elements from the provided iterables in a round-robin fashion.
Raises:
StopIteration: If all provided iterables are exhausted.
Example:
```python
iterable1 = [1, 2, 3]
iterable2 = ['a', 'b', 'c']
iterable3 = [0.1, 0.2, 0.3]
for item in roundrobin(iterable1, iterable2, iterable3):
print(item)
# Output:
# 1
# 'a'
# 0.1
# 2
# 'b'
# 0.2
# 3
# 'c'
# 0.3
```
Note:
- If one of the provided iterables is shorter than the others, the function will
continue iterating through the remaining iterables until all are exhausted.
- If an iterable raises an exception during iteration, a warning message is logged,
and the function continues with the next iterable.
See Also:
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
"""
num_active = len(iterables)
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
while num_active:
@ -35,6 +79,48 @@ def roundrobin(*iterables):
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
"""
Monitor the speed and progress of data loading using a data loader.
This function iterates through a data loader for a specified number of steps or until
the end of the data loader is reached, periodically logging progress information.
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Example:
```python
import torch
from torch.utils.data import DataLoader
# Create a data loader (replace with your own DataLoader configuration)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Monitor data loading speed and progress
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
```
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Note:
- The function logs information about elapsed time, the number of examples processed,
and the processing speed in examples per second.
- If `peek` is provided, batch contents will be logged for inspection at the specified
frequency.
See Also:
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
iterating through datasets.
"""
num_examples = 0
prev = time.perf_counter()
for idx, batch in enumerate(data_loader):
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]
def pa_to_torch(array: pa.array) -> torch.Tensor:
"""
Convert a PyArrow Array to a PyTorch Tensor.
Args:
array (pa.array): The PyArrow Array to convert.
Returns:
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
Example:
```python
import pyarrow as pa
import torch
# Create a PyArrow Array
arrow_array = pa.array([1, 2, 3])
# Convert it to a PyTorch Tensor
torch_tensor = pa_to_torch(arrow_array)
```
"""
return torch.from_numpy(array.to_numpy())
def create_default_pa_to_batch(schema) -> DataclassBatch:
""" """
"""
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
Args:
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
Returns:
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
Example:
```python
import pyarrow as pa
from dataclass_batch import DataclassBatch
# Define a PyArrow schema
schema = pa.schema([
("feature1", pa.float64()),
("feature2", pa.int64()),
("label", pa.int64()),
])
# Create the conversion function
pa_to_batch = create_default_pa_to_batch(schema)
# Create a PyArrow RecordBatch
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
"feature1": [1.0, 2.0, None],
"feature2": [10, 20, 30],
"label": [0, 1, None],
}))
# Convert the RecordBatch to a custom DataclassBatch
custom_batch = pa_to_batch(record_batch)
```
"""
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
def get_imputation_value(pa_type):