commit a8258a5da4
Rockerz, 2023-09-21 17:35:12 +00:00 (committed via GitHub)
64 changed files with 4507 additions and 1233 deletions

View File

@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC):
"""
A base class for batches used in pipelines.
Attributes:
None
"""
@abc.abstractmethod
def as_dict(self) -> Dict:
"""
Convert the batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False):
"""
Move the batch to the specified device.
Args:
device (torch.device): The target device.
non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.
Returns:
BatchBase: A new batch on the target device.
"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
"""
Record a CUDA stream for all tensors in the batch.
Args:
stream (torch.cuda.streams.Stream): The CUDA stream to record.
Returns:
None
"""
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self):
"""
Pin memory for all tensors in the batch.
Returns:
BatchBase: A new batch with pinned memory.
"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
"""
Generate a string representation of the batch.
Returns:
str: A string representation of the batch.
"""
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):
@property
def batch_size(self) -> int:
"""
Get the batch size from the tensors in the batch.
Returns:
int: The batch size.
Raises:
Exception: If the batch size cannot be determined from the tensors.
"""
for tensor in self.as_dict().values():
if tensor is None:
continue
@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):
@dataclass
class DataclassBatch(BatchBase):
"""
A batch class that uses dataclasses to define its fields.
Attributes:
None
"""
@classmethod
def feature_names(cls):
"""
Get the feature names of the dataclass.
Returns:
List[str]: A list of feature names.
"""
return list(cls.__dataclass_fields__.keys())
def as_dict(self):
"""
Convert the dataclass batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):
@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
"""
Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
Args:
name (str): The name of the custom batch class.
schema: The schema or structure of the batch.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):
@staticmethod
def from_fields(name: str, fields: dict):
"""
Create a custom batch subclass from a set of fields.
Args:
name (str): The name of the custom batch class.
fields (dict): A dictionary specifying the fields and their types.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):
class DictionaryBatch(BatchBase, dict):
"""
A batch class that represents data as a dictionary.
Attributes:
None
"""
def as_dict(self) -> Dict:
"""
Convert the dictionary batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return self
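As an illustration of how these batch helpers compose, here is a minimal, hedged sketch; the import path and the field names are assumptions for illustration, not taken from this diff:

```python
import torch
from tml.common.batch import DataclassBatch  # assumed module path

# Build a batch class with two hypothetical tensor fields.
ExampleBatch = DataclassBatch.from_fields(
    "ExampleBatch", {"continuous_features": torch.Tensor, "labels": torch.Tensor}
)
batch = ExampleBatch(
    continuous_features=torch.rand(8, 10),
    labels=torch.ones(8, 1),
)
print(batch.as_dict().keys())              # the dict view used by to() / pin_memory()
cpu_batch = batch.to(torch.device("cpu"))  # returns a new ExampleBatch on the target device
```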

View File

@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
class Snapshot:
"""Checkpoints using torchsnapshot.
Also saves step to be updated by the training loop.
"""
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
"""
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
"""
Initializes a Snapshot object.
Args:
save_dir (str): Directory where checkpoints will be saved.
state (Dict[str, Any]): State dictionary containing checkpoint information.
"""
self.save_dir = save_dir
self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property
def step(self):
"""Get the current training step."""
return self.state["extra_state"]["step"]
@step.setter
def step(self, step: int) -> None:
"""Set the current training step."""
self.state["extra_state"]["step"] = step
@property
@ -41,7 +48,15 @@ class Snapshot:
self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step."""
"""
Saves a checkpoint with a given global step.
Args:
global_step (int): The global step to associate with the checkpoint.
Returns:
PendingSnapshot: A pending snapshot object.
"""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time()
@ -58,7 +73,12 @@ class Snapshot:
return snapshot
def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint."""
"""
Restores a given checkpoint.
Args:
checkpoint (str): Path to the checkpoint to restore.
"""
snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time()
@ -83,12 +103,17 @@ class Snapshot:
global_step: Optional[int] = None,
missing_ok: bool = False,
) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it.
Args:
snapshot_path: path to the model snapshot
global_step: restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
Get a torch stateless snapshot, without actually loading it.
Args:
snapshot_path (str): Path to the model snapshot.
global_step (int, optional): Restores from this checkpoint if specified.
missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
Returns:
torchsnapshot.Snapshot: A torch snapshot object.
"""
path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.")
return torchsnapshot.Snapshot(path=path)
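A hedged usage sketch of the checkpointing flow documented above; the directory, state keys, and model are illustrative only:

```python
import torch

# `state` maps logical names to stateful objects torchsnapshot knows how to capture.
model = torch.nn.Linear(2, 1)
snapshot = Snapshot(save_dir="/tmp/model", state={"model": model})

snapshot.step = 100                       # stored in extra_state alongside walltime
pending = snapshot.save(global_step=100)  # writes /tmp/model/100 asynchronously

# Later (or in another process), restore from the written checkpoint.
snapshot.restore("/tmp/model/100")
```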
@ -100,13 +125,14 @@ class Snapshot:
snapshot_emb_name: str,
weight_tensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilise partial lodaing meachanism from torchsnapshot.
Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
"""
Loads pretrained embedding from the snapshot to the model.
Args:
embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
"""
start_time = time.time()
manifest = embedding_snapshot.get_manifest()
for path in manifest.keys():
@ -209,7 +235,22 @@ def get_checkpoint(
def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written."""
"""
Get a list of fully written checkpoints in the specified directory.
This function retrieves a list of fully written checkpoints in the given directory.
Checkpoints that are considered fully written include those that have a
corresponding snapshot metadata file.
Args:
save_dir (str): The directory where checkpoints are stored.
Returns:
List[str]: A list of fully written checkpoint paths.
Note:
Checkpoints are sorted by their numeric filenames in ascending order.
"""
checkpoints = []
fs = infer_fs(save_dir)
if fs.exists(save_dir):
@ -232,6 +273,18 @@ def wait_for_evaluators(
global_step: int,
timeout: int,
) -> None:
"""
Waits for all evaluators to finish and checks for their completion status.
Args:
save_dir (str): Directory where checkpoints are saved.
partition_names (List[str]): List of partition names to check for completion.
global_step (int): The global step for which to wait for evaluators.
timeout (int): Maximum time in seconds to wait for evaluators to finish.
Returns:
None: This function returns nothing but logs the progress and results.
"""
logging.info("Waiting for all evaluators to finish.")
start_time = time.time()

View File

@ -5,6 +5,15 @@ import torch.distributed as dist
def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available.
This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
Returns:
None
"""
try:
import tensorflow as tf
except ImportError:
@ -14,6 +23,18 @@ def maybe_setup_tensorflow():
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device.
This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
Args:
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
Returns:
torch.device: The torch device for the current process.
"""
if tf_ok:
maybe_setup_tensorflow()
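The rest of the function body is elided in this hunk. As a rough sketch only, the usual pattern for this kind of helper (not necessarily the repository's exact logic) looks like:

```python
import os
import torch
import torch.distributed as dist

# One GPU per local rank when CUDA is available, otherwise CPU with the gloo backend.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
else:
    dist.init_process_group(backend="gloo")
    device = torch.device("cpu")
```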

View File

@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()
def infer_fs(path: str):
"""
Infer the file system (fs) type based on the given path.
Args:
path (str): The file path.
Returns:
The fsspec filesystem object inferred from the path: GCS_FS for "gs://" paths, or LOCAL_FS for local paths.
Raises:
NotImplementedError: If the path uses the Hadoop Distributed File System ("hdfs://"), which is not yet supported.
"""
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
@ -18,8 +30,26 @@ def infer_fs(path: str):
def is_local_fs(fs):
"""
Check if the given file system is the local file system.
Args:
fs: The file system object to check (e.g., LOCAL_FS or GCS_FS).
Returns:
bool: True if the file system is the local file system, False otherwise.
"""
return fs == LOCAL_FS
def is_gcs_fs(fs):
"""
Check if the given file system is Google Cloud Storage (GCS).
Args:
fs: The file system object to check (e.g., LOCAL_FS or GCS_FS).
Returns:
bool: True if the file system is GCS, False otherwise.
"""
return fs == GCS_FS
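A small usage sketch of these helpers; the bucket name is hypothetical:

```python
path = "gs://example-bucket/checkpoints/100"
fs = infer_fs(path)         # GCS_FS for gs:// paths, LOCAL_FS otherwise
if is_gcs_fs(fs):
    print(fs.exists(path))  # fsspec filesystems expose exists(), ls(), open(), ...
```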

View File

@ -49,15 +49,25 @@ def log_ebc_norms(
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict: model.state_dict()
ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict (dict): The state dictionary of the model.
ebc_keys (List[str]): List of embedding keys from state_dict to log.
Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
sample_size (int, optional): Limits the number of rows per rank to compute average on
to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
The keys are in the format "{emb_key}-norm-{idx}".
Note:
This function computes and logs the average norm of embedding tables across ranks.
It gathers the norms from all ranks and returns them as a dictionary.
"""
norm_logs = dict()
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))

View File

@ -11,6 +11,20 @@ import numpy as np
class LargeEmbeddings(nn.Module):
"""
A module for handling large embeddings.
Args:
large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
Attributes:
ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models.
Note:
The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving.
"""
def __init__(
self,
large_embeddings_config: LargeEmbeddingsConfig,
@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
self,
sparse_features: KeyedJaggedTensor,
) -> KeyedTensor:
"""
Forward pass of the LargeEmbeddings module.
Args:
sparse_features (KeyedJaggedTensor): Sparse input features.
Returns:
KeyedTensor: The output of the forward pass, a KeyedTensor.
"""
pooled_embs = self.ebc(sparse_features)
# a KeyedTensor

View File

@ -11,6 +11,12 @@ import torch.distributed.run
def is_distributed_worker():
"""
Checks if the current process is a distributed worker.
Returns:
bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
"""
world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None
@ -25,25 +31,36 @@ def maybe_run_training(
is_chief: Optional[bool] = False,
**training_kwargs,
):
"""Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables
(WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn: The function that is responsible for training
module_name: The name of the module that this function was called from;
used to indicate torchrun entrypoint.
nproc_per_node: Number of workers per node; supported values.
num_nodes: Number of nodes, otherwise inferred from environment.
is_chief: If process is running on chief.
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
"""
Wrapper function for single node, multi-GPU PyTorch training.
If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn (callable): The function responsible for training.
module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint.
nproc_per_node (int, optional): Number of workers per node. Defaults to None.
num_nodes (int, optional): Number of nodes. Defaults to None.
is_chief (bool, optional): If the process is running on the chief node. Defaults to False.
set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
**training_kwargs: Additional keyword arguments to pass to the `train_fn`.
Note:
This function checks if the current process is a distributed worker by examining the environment variables.
If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary
environment variables and launches the training process using torchrun.
Example:
To run training on a single node with 4 GPUs, you can use:
```
maybe_run_training(train_function, __name__, nproc_per_node=4)
```
"""
machines = utils.machine_from_env()
if num_nodes is None:

View File

@ -19,6 +19,23 @@ MOCK_ENV = {
@contextmanager
def mock_pg():
"""
A context manager for mocking the distributed process group for testing purposes.
This context manager temporarily sets environment variables to mock the distributed process group
and initializes it using the Gloo backend. It is useful for testing distributed training without
actually launching multiple processes.
Example:
```python
with mock_pg():
# Your distributed training code here
```
Note:
This context manager should be used within a testing environment to simulate distributed training
without actually creating multiple processes.
"""
with patch.dict(os.environ, MOCK_ENV):
try:
dist.init_process_group(

View File

@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)
def _read_file(f):
"""
Read the contents of a file using fsspec.
Args:
f: File path or URL.
Returns:
The contents of the file.
"""
with fsspec.open(f) as f:
return f.read()
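For context, `fsspec.open` handles local paths and any remote scheme fsspec supports (such as gs://) through one interface; a quick local illustration:

```python
import fsspec

with fsspec.open("/tmp/example_config.yaml", mode="w") as f:
    f.write("x: 42\n")
with fsspec.open("/tmp/example_config.yaml") as f:
    print(f.read())  # b'x: 42\n' (bytes, since the default mode is "rb")
```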
@ -21,18 +30,24 @@ def setup_configuration(
yaml_path: str,
substitute_env_variable: bool = False,
) -> Tuple[C, str]:
"""Resolves a config at a yaml path.
Args:
config_type: Pydantic config class to load.
yaml_path: yaml path of the config file.
substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
environment variable value whenever possible. If an environment variable doesn't exist,
the string is left unchanged.
Returns:
The pydantic config object.
"""
Load a Pydantic config object from a YAML file and optionally substitute environment variables.
Args:
config_type: Pydantic config class to load.
yaml_path: YAML path of the config file.
substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
with their environment variable values whenever possible.
If an environment variable doesn't exist, the string is left unchanged.
Returns:
A tuple containing the Pydantic config object and the resolved YAML content.
Example:
```python
config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
```
"""
def _substitute(s):
if substitute_env_variable:

View File

@ -6,6 +6,33 @@ import pydantic
class WandbConfig(base_config.BaseConfig):
"""
Configuration for integrating with Weights and Biases (WandB).
Attributes:
host (str): Host of the Weights and Biases instance, passed to login.
key_path (str): Path to the key file.
name (str): Name of the experiment, passed to init.
entity (str): Name of the user/service account, passed to init.
project (str): Name of the WandB project, passed to init.
tags (List[str]): List of tags, passed to init.
notes (str): Notes, passed to init.
metadata (Dict[str, Any]): Additional metadata to log.
Example:
```python
wandb_config = WandbConfig(
host="https://wandb.example.com",
key_path="/path/to/key",
name="experiment_1",
entity="user123",
project="my_project",
tags=["experiment", "ml"],
notes="This is a test experiment.",
metadata={"version": "1.0"}
)
```
"""
host: str = pydantic.Field(
"https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
description="Host of Weights and Biases instance, passed to login.",

View File

@ -8,59 +8,69 @@ import pydantic
class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes.
"""Base class for all derived config classes.
This class provides some convenient functionality:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
This class provides convenient functionality and constraints for derived config classes:
```
- Disallows extra fields when constructing an object. User errors due to extraneous arguments
are minimized.
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
be set. For example:
```python
class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok
ExampleConfig(y=1) # ok
ExampleConfig(x=1, y=1) # throws error
ExampleConfig(x=1) # OK
ExampleConfig(y=1) # OK
ExampleConfig(x=1, y=1) # Raises an error
```
"""
class Config:
"""Forbids extras."""
Attributes:
Config (class): Configuration options for this class, forbidding extra fields.
extra = pydantic.Extra.forbid # noqa
Methods:
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with provided the field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
"""
class Config:
"""Configuration options that forbid extra fields."""
extra = pydantic.Extra.forbid # noqa
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields are appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with the provided field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
def pretty_print(self) -> str:
"""Return a human legible (yaml) representation of the config useful for logging."""
return yaml.dump(self.dict())
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human-readable (YAML) representation of the config useful for logging."""
return yaml.dump(self.dict())
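The `one_of` / `at_most_one_of` grouping works because, in pydantic v1 (the API used here), extra keyword arguments passed to `pydantic.Field` are surfaced in the generated JSON schema, which `_field_data_map` scans. A minimal standalone sketch:

```python
import pydantic

class Example(pydantic.BaseModel):
    x: int = pydantic.Field(None, one_of="group_1")
    y: int = pydantic.Field(None, one_of="group_1")

# The extra Field kwarg shows up in the schema properties; this is the structure
# BaseConfig._field_data_map() iterates to build its one_of groups.
print(Example.schema()["properties"]["x"]["one_of"])  # -> group_1
```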

View File

@ -6,33 +6,67 @@ import pydantic
class BaseConfigTest(TestCase):
"""
Unit tests for the BaseConfig class.
"""
def test_extra_forbidden(self):
"""
Test that extra fields are forbidden when creating a Config instance.
This test case checks that the `BaseConfig` class correctly raises a
`pydantic.ValidationError` when extra fields are provided while creating a
`Config` instance.
"""
class Config(BaseConfig):
x: int
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
def test_one_of(self):
"""
Test the use of the `one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `one_of` attribute in a `Config`
instance. It verifies that `pydantic.Field` correctly enforces the
specified constraint.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
def test_at_most_one_of(self):
"""
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `at_most_one_of` attribute in a
`Config` instance. It verifies that `pydantic.Field` enforces the
constraint that at most one of the specified fields can be provided.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")

View File

@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommend method to load a config file (a yaml file) and parse it.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
"""
Recommend method to Load and parse a configuration from a YAML file.
This function loads a configuration from a YAML file, parses it, and returns an instance of the
specified config type.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
Args:
config_type (Type[BaseConfig]): The Pydantic config class to load.
yaml_path (str): The path to the YAML configuration file.
Returns:
BaseConfig: An instance of the specified config type populated with values from the YAML file.
Example:
Suppose you have a YAML file 'my_config.yaml' containing the following:
```yaml
x: 42
y: "hello"
```
You can load and parse it using this function as follows:
```python
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
```
Note:
This function performs environment variable substitution in the YAML file. It replaces
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
values. If an environment variable does not exist, the string is left unchanged.
"""
def _substitute(s):
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
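The `_substitute` helper relies on `string.Template.safe_substitute`, which expands known variables and leaves unknown ones untouched instead of raising; a quick illustration:

```python
import getpass
import os
import string

raw = "save_dir: ${HOME}/runs/${UNDEFINED_VAR}\nuser: ${USER}\n"
# ${HOME} and ${USER} are expanded; ${UNDEFINED_VAR} is left as-is rather than raising KeyError.
print(string.Template(raw).safe_substitute(os.environ, USER=getpass.getuser()))
```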

View File

@ -8,11 +8,27 @@ import pydantic
class _PointlessConfig(BaseConfig):
a: int
user: str
def test_load_config_from_yaml(tmp_path):
"""Test loading a configuration from a YAML file and verifying its values.
This test function checks the functionality of the `load_config_from_yaml` function by creating
a temporary YAML configuration file, loading it, and asserting that the loaded config object
has the expected values.
Args:
tmp_path: A temporary directory provided by the `pytest` framework.
Test Steps:
1. Create a temporary YAML file containing configuration data.
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
3. Assert that the loaded configuration object has the expected values.
"""
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""")

View File

@ -9,6 +9,14 @@ import pydantic
class RuntimeConfig(base_config.BaseConfig):
"""
Configuration for runtime settings.
Attributes:
- wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
- enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices.
- enable_amp (bool): Enable automatic mixed precision.
"""
wandb: WandbConfig = pydantic.Field(None)
enable_tensorfloat32: bool = pydantic.Field(
False, description="Use tensorfloat32 if on Ampere devices."
@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):
class TrainingConfig(base_config.BaseConfig):
"""
Configuration for training settings.
Attributes:
- save_dir (str): Directory to save checkpoints.
- num_train_steps (pydantic.PositiveInt): Number of training steps.
- initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
- checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps.
- checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
- train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps.
- num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used.
- eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps.
- eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
- gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
- num_epochs (pydantic.PositiveInt): Number of training epochs.
"""
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
num_train_steps: pydantic.PositiveInt = 10000
initial_checkpoint_dir: str = pydantic.Field(

View File

@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
def _get_step_fn(pipeline, data_iterator, training: bool):
"""
Returns a function to perform a single evaluation step.
Args:
pipeline (Pipeline): The pipeline object containing the model.
data_iterator (Iterator): The data iterator for evaluation.
training (bool): Flag indicating if the model should be in training mode.
Returns:
function: A function that performs a single evaluation step.
"""
def step_fn():
"""
Perform a single evaluation step.
Returns:
Any: The evaluation results after a single step.
"""
# It turns out that model.train() and model.eval() simply switch a single field inside the model
# class, so it's somewhat safer to wrap it in here.
if training:
@ -69,7 +86,21 @@ def _run_evaluation(
eval_batch_size: int,
logger=None,
):
"""Runs the evaluation loop over all evaluation iterators."""
"""
Run the evaluation loop over all evaluation iterators.
Args:
pipeline (Pipeline): The pipeline object containing the model.
dataset (Dataset): The dataset to evaluate.
eval_steps (int): The number of evaluation steps to perform.
metrics (tm.MetricCollection): A collection of evaluation metrics.
eval_batch_size (int): Batch size for evaluation.
logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
Returns:
dict: A dictionary containing the computed evaluation metrics.
"""
dataset = get_new_iterator(dataset)
step_fn = _get_step_fn(pipeline, dataset, training=False)
last_time = datetime.datetime.now()
@ -109,15 +140,29 @@ def train(
parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None,
) -> None:
"""Runs training and eval on the given TrainPipeline
Args:
dataset: data iterator for the training set
evaluation_iterators: data iterators for the different evaluation sets
scheduler: optional learning rate scheduler
output_transform_for_metrics: optional transformation functions to transorm the model
output and labels into a format the metrics can understand
"""
Runs training and evaluation on the given TrainPipeline.
Args:
model (torch.nn.Module): The neural network model to train.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
device (str): The target device for model training (e.g., 'cuda' or 'cpu').
save_dir (str): The directory to save model checkpoints and logs.
logging_interval (int): Interval for logging training progress.
train_steps (int): The number of training steps to perform.
checkpoint_frequency (int): Frequency of saving model checkpoints.
dataset (Iterable): Data iterator for the training set.
worker_batch_size (int): Batch size for data loading workers.
num_workers (Optional[int]): Number of data loading workers (default: 0).
enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
logger_initializer (Optional[Callable]): A logger initializer function (default: None).
scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
tables_to_log (Optional[List[str]]): List of tables to log (default: None).
"""
train_pipeline = TrainPipelineSparseDist(
model=model,
@ -262,6 +307,15 @@ def log_eval_results(
partition_name: str,
step: int,
):
"""
Logs evaluation results and optionally records them using a provided logger.
Args:
results (Any): The evaluation results to log.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
step (int): The current step in the evaluation.
"""
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
logging.info(f"Step: {step}, evaluation ({partition_name}).")
for metric_name, metric_value in results.items():
@ -285,6 +339,23 @@ def only_evaluate(
partition_name: str,
metrics: Optional[tm.MetricCollection] = None,
):
"""
Performs evaluation on a given dataset partition.
Args:
model (torch.nn.Module): The neural network model for evaluation.
optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
save_dir (str): The directory containing model checkpoints.
num_train_steps (int): The total number of training steps.
dataset (Iterable): Data iterator for evaluation.
eval_batch_size (int): Batch size for evaluation.
num_eval_steps (int): The number of evaluation steps to perform.
eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
"""
logging.info(f"Evaluating on partition {partition_name}.")
logging.info("Computing metrics:")
logging.info(metrics)

View File

@ -28,6 +28,18 @@ def train(
*args,
**kwargs,
) -> None:
"""
Debugging training loop. Do not use for actual model training.
Args:
model (torch.nn.Module): The neural network model.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
train_steps (int): The number of training steps to perform.
dataset (Iterable): Data iterator for training data.
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
*args: Additional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).
"""
logging.warning("Running debug training loop, don't use for model training.")

View File

@ -10,8 +10,11 @@ import torch
def _maybe_warn(reduction: str):
"""
Warning for reduction different than mean.
"""
Emit a warning if the reduction method is different from 'mean'.
Args:
reduction (str): The reduction method being used.
"""
if reduction != "mean":
logging.warn(
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
@ -24,6 +27,16 @@ def build_loss(
loss_type: LossType,
reduction="mean",
):
"""
Build a loss function based on the specified loss type and reduction method.
Args:
loss_type (LossType): The type of loss to build.
reduction (str): The reduction method for the loss (default: 'mean').
Returns:
Callable: A loss function that takes logits and labels as input.
"""
_maybe_warn(reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
@ -35,11 +48,15 @@ def build_loss(
def get_global_loss_detached(local_loss, reduction="mean"):
"""
Perform all_reduce to obtain the global loss function using the provided reduction.
:param local_loss: The local loss of the current rank.
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
:return: The reduced & detached global loss.
"""
Perform all_reduce to obtain the global loss function using the provided reduction.
Args:
local_loss (torch.Tensor): The local loss of the current rank.
reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.
Returns:
torch.Tensor: The reduced and detached global loss.
"""
if reduction != "mean":
logging.warn(
f"The reduction used in this function should be the same as the one used by "
@ -66,6 +83,19 @@ def build_multi_task_loss(
global_reduction="mean",
pos_weights=None,
):
"""
Build a multi-task loss function based on the specified loss type and configurations.
Args:
loss_type (LossType): The type of loss to build.
tasks (typing.List[str]): List of task names.
task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
global_reduction (str): Reduction method for the global loss (default: 'mean').
pos_weights (Optional): Positive class weights for tasks (default: None).
Returns:
Callable: A multi-task loss function that takes logits, labels, and weights as input.
"""
_maybe_warn(global_reduction)
_maybe_warn(task_loss_reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
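For context, the "global loss" described above is the standard detached all_reduce average across ranks; a minimal sketch of that pattern (not the repository's exact implementation):

```python
import torch
import torch.distributed as dist

def global_mean_loss(local_loss: torch.Tensor) -> torch.Tensor:
    # Detach and clone so the collective is purely for logging and never touches gradients.
    global_loss = local_loss.detach().clone()
    dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
    return global_loss / dist.get_world_size()
```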

View File

@ -36,9 +36,24 @@ import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
"""
Abstract method to transform model outputs into a dictionary of metrics.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
Returns:
Dict: A dictionary of computed metrics.
"""
...
def update(self, outputs: Dict[str, torch.Tensor]):
"""
Update the metrics based on model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
"""
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
@ -49,6 +64,13 @@ class MetricMixin:
class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs):
"""
Initialize a TaskMixin instance.
Args:
task_idx (int): Index of the task associated with this mixin (default: -1).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self._task_idx = task_idx
@ -59,13 +81,31 @@ class StratifyMixin:
stratifier=None,
**kwargs,
):
"""
Initialize a StratifyMixin instance.
Args:
stratifier: A stratifier for filtering outputs (default: None).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs)
self._stratifier = stratifier
def maybe_apply_stratification(
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
"""
Apply stratification to filter examples in the outputs.
Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
value_names (List[str]): Names of values to filter.
Returns:
Dict[str, torch.Tensor]: Filtered outputs.
"""
outputs = outputs.copy()
if not self._stratifier:
return outputs
@ -84,12 +124,20 @@ class StratifyMixin:
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
"""Returns new class using MetricMixin and given base_metric.
Functionally the same using inheritance, just saves some lines of code
if no need for class attributes.
"""
Returns a new class using MetricMixin and the given base_metric.
Functionally the same as using inheritance, but it saves some lines of code
if there's no need for class attributes.
Args:
base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
transform (Callable): The transformation function to prepend to the metric.
Returns:
Type: A new class that includes MetricMixin and the provided base_metric
with the specified transformation method.
"""
def transform_method(_self, *args, **kwargs):
return transform(*args, **kwargs)
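The tests later in this commit use `prepend_transform` to wrap a torchmetrics metric without writing a subclass; a short usage sketch along those lines:

```python
import torchmetrics as tm

# Equivalent to subclassing MetricMixin + MaxMetric and defining transform() by hand.
Max = prepend_transform(tm.MaxMetric, lambda outputs: {"value": outputs["value"]})

metric = Max()
metric.update({"value": 3})
metric.update({"value": 7})
print(metric.compute())  # tensor(7.)
```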

View File

@ -15,6 +15,16 @@ def probs_and_labels(
outputs: Dict[str, torch.Tensor],
task_idx: int,
) -> Dict[str, torch.Tensor]:
"""
Extract probabilities and labels from model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
task_idx (int): Index of the task.
Returns:
Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
"""
preds = outputs["probabilities"]
target = outputs["labels"]
if task_idx >= 0:
@ -28,6 +38,11 @@ def probs_and_labels(
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
"""
Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.
This metric counts values after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value using probabilities after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
value = outputs["probabilities"]
@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
"""
Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.
This metric computes precision after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
"""
Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.
This metric computes recall after potential stratification and task selection.
"""
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
AUC (Area Under the ROC Curve) metric class.
This metric computes the AUC metric based on the logits and labels in the model outputs.
Args:
num_samples (int): The number of samples used to compute AUC.
**kwargs: Additional keyword arguments.
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
"""
@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The ranks of all positives
Based on:
PosRanks metric class.
This metric computes the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
"""
@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The reciprocal of the ranks of all
Based on:
ReciprocalRank metric class.
This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
"""
@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The fraction of positives that rank in the top K among their negatives
Note that this is basically precision@k
Based on:
HitAtK metric class.
This metric computes the fraction of positive examples that rank in the top K among their negatives,
which is equivalent to precision@K.
Args:
k (int): The value of K.
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
"""

View File

@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric
@dataclass
class MockStratifierConfig:
"""
Configuration dataclass for mocking a stratifier.
Args:
name (str): The name of the stratifier.
index (int): The index of the stratifier.
value (int): The value of the stratifier.
"""
name: str
index: int
value: int
class Count(MetricMixin, SumMetric):
"""
Count metric class that inherits from MetricMixin and SumMetric.
This metric counts occurrences.
"""
def transform(self, outputs):
return {"value": 1}
@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
def test_count_metric():
"""
Test function for the Count metric.
It checks if the Count metric correctly counts the number of examples.
"""
num_examples = 123
examples = [
{"stuff": 0},
@ -36,6 +56,12 @@ def test_count_metric():
def test_collections():
"""
Test function for metric collections.
It tests if metric collections correctly aggregate metrics.
"""
max_metric = Max()
count_metric = Count()
metric = MetricCollection([max_metric, count_metric])
@ -51,6 +77,12 @@ def test_collections():
def test_task_dependent_ctr():
"""
Test function for task-dependent Ctr (Click-Through Rate) metric.
It checks if the Ctr metric computes the correct value for different tasks.
"""
num_examples = 144
batch_size = 1024
outputs = [
@ -69,6 +101,13 @@ def test_task_dependent_ctr():
def test_stratified_ctr():
"""
Test function for the Stratified Ctr (Click-Through Rate) metric.
It checks if the Stratified Ctr metric computes the correct value for different tasks
and stratified samples.
"""
outputs = [
{
"stuff": 0,
@ -114,6 +153,12 @@ def test_stratified_ctr():
def test_auc():
"""
Test function for the AUC (Area Under the Curve) metric.
It checks if the AUC metric correctly computes the Area Under the ROC Curve.
"""
num_samples = 10000
metric = core_metrics.Auc(num_samples)
target = torch.tensor([0, 0, 1, 1, 1])
@ -131,6 +176,12 @@ def test_auc():
def test_pos_rank():
"""
Test function for the PosRanks metric.
It checks if the PosRanks metric correctly computes the ranks of positive samples.
"""
metric = core_metrics.PosRanks()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@ -147,6 +198,12 @@ def test_pos_rank():
def test_reciprocal_rank():
"""
Test function for the Reciprocal Rank metric.
It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
"""
metric = core_metrics.ReciprocalRank()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@ -163,6 +220,12 @@ def test_reciprocal_rank():
def test_hit_k():
"""
Test function for the Hit@K metric.
It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
"""
hit1_metric = core_metrics.HitAtK(1)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])

View File

@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel
@dataclass
class MockDataclassBatch(DataclassBatch):
"""
Mock data class batch for testing purposes.
This class represents a batch of data with continuous features and labels.
Attributes:
continuous_features (torch.Tensor): Tensor containing continuous feature data.
labels (torch.Tensor): Tensor containing label data.
"""
continuous_features: torch.Tensor
labels: torch.Tensor
class MockModule(torch.nn.Module):
"""
Mock PyTorch module for testing purposes.
This module defines a simple neural network model with a linear layer
followed by a BCEWithLogitsLoss loss function.
Attributes:
model (torch.nn.Linear): The linear model layer.
loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
"""
def __init__(self) -> None:
super().__init__()
self.model = torch.nn.Linear(10, 1)
self.loss_fn = torch.nn.BCEWithLogitsLoss()
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the mock module.
Args:
batch (MockDataclassBatch): Input data batch with continuous features and labels.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
"""
pred = self.model(batch.continuous_features)
loss = self.loss_fn(pred, batch.labels)
return (loss, pred)
def create_batch(bsz: int):
"""
Create a mock data batch with random continuous features and labels.
Args:
bsz (int): Batch size.
Returns:
MockDataclassBatch: A batch of data with continuous features and labels.
"""
return MockDataclassBatch(
continuous_features=torch.rand(bsz, 10).float(),
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
@ -35,6 +72,13 @@ def create_batch(bsz: int):
def test_sparse_pipeline():
"""
Test function for the sparse pipeline with distributed model parallelism.
This function tests the behavior of the sparse training pipeline using
a mock module and data.
"""
device = torch.device("cpu")
model = MockModule().to(device)
@ -65,6 +109,15 @@ def test_sparse_pipeline():
def test_amp():
"""
Test automatic mixed-precision (AMP) training with the sparse pipeline.
This function tests the behavior of the sparse training pipeline with
automatic mixed-precision (AMP) enabled, using a mock module and data.
AMP allows for faster training by using lower-precision data types, such as
torch.bfloat16, while maintaining model accuracy.
"""
device = torch.device("cpu")
model = MockModule().to(device)

View File

@ -39,12 +39,42 @@ Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
"""
Abstract base class for training pipelines.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
"""
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Make progress in the training pipeline.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
"""
Move a batch of data to a specified device.
Args:
batch (In): The input batch.
device (torch.device): The target device.
non_blocking (bool): If True, move the data asynchronously.
Returns:
In: The batch of data on the target device.
"""
assert isinstance(
batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface"
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
"""
Wait for a batch of data on a specified stream.
Args:
batch (In): The input batch.
stream (Optional[Stream]): The CUDA stream to wait for.
Note:
This function is used for managing asynchronous CUDA operations.
"""
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
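For background, the host-to-device pipelining that `_to_device`, `_wait_for_batch`, and `record_stream` implement follows the standard CUDA side-stream copy pattern; a minimal sketch (assumes a CUDA device is available):

```python
import torch

copy_stream = torch.cuda.Stream()
cpu_batch = torch.randn(8, 16).pin_memory()  # pinned memory enables async copies

with torch.cuda.stream(copy_stream):
    gpu_batch = cpu_batch.to("cuda", non_blocking=True)

# The default stream must wait for the copy, and the tensor must record the consuming
# stream so the caching allocator does not reuse its memory too early.
torch.cuda.current_stream().wait_stream(copy_stream)
gpu_batch.record_stream(torch.cuda.current_stream())
loss = gpu_batch.sum()  # forward/backward/optimizer work runs on the default stream
```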
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
Initialize the TrainPipelineBase.
_connect(dataloader_iter: Iterator[In]) -> None:
Establish a connection to the data loader and move the input data to the GPU.
progress(dataloader_iter: Iterator[In]) -> Out:
Execute a training iteration, including forward and backward passes.
"""
def __init__(
self,
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> None:
"""
Initialize the TrainPipelineBase.
Args:
model (torch.nn.Module): The PyTorch model to be trained.
optimizer (torch.optim.Optimizer): The optimizer used for training.
device (torch.device): The target device for training (CPU or GPU).
"""
self._model = model
self._optimizer = optimizer
self._device = device
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Establish a connection to the data loader and move the input data to the GPU.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
"""
cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream):
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Execute a training iteration, including forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
if not self._connected:
self._connect(dataloader_iter)
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
class Tracer(torch.fx.Tracer):
"""
Custom tracer class for PyTorch models.
This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.
Attributes:
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
_leaf_modules (List[str]): List of qualified names of leaf modules.
"""
# Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
"""
Initialize the Tracer.
Args:
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
"""
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
Check if a module is a leaf module during tracing.
Args:
m (torch.nn.Module): The PyTorch module.
module_qualified_name (str): The qualified name of the module.
Returns:
bool: True if the module is considered a leaf module, False otherwise.
"""
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True
return super().is_leaf_module(m, module_qualified_name)
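To make the leaf-module allow-list concrete, here is a small self-contained sketch of the same idea using plain `torch.fx`; the `ShardedModule` check is omitted and the module names are illustrative, not taken from this codebase:
```python
import torch

class LeafListTracer(torch.fx.Tracer):
    """Minimal tracer that treats an explicit list of qualified names as leaves."""

    def __init__(self, leaf_modules=None):
        super().__init__()
        self._leaf_modules = leaf_modules or []

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Any module named in the allow-list stays opaque (a single call_module node).
        if module_qualified_name in self._leaf_modules:
            return True
        return super().is_leaf_module(m, module_qualified_name)

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
graph = LeafListTracer(leaf_modules=["0"]).trace(net)
print(graph)  # the submodule at qualified name "0" appears as one call_module node
```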
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
@dataclass
class TrainPipelineContext:
"""
Dataclass to store information related to the training pipeline context.
Attributes:
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
feature_processor_forwards (List[Any]): A list of feature processor forwards.
"""
# pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
@ -166,6 +279,14 @@ class TrainPipelineContext:
@dataclass
class ArgInfo:
"""
Dataclass to store information about arguments in the training pipeline.
Attributes:
input_attrs (List[str]): List of attribute names of the input batch.
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
"""
# attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"]
input_attrs: List[str]
@ -177,6 +298,16 @@ class ArgInfo:
class PipelinedForward:
"""
Represents a pipelined forward pass operation.
Attributes:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
def __init__(
self,
name: str,
@ -185,6 +316,16 @@ class PipelinedForward:
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> None:
"""
Initialize a PipelinedForward instance.
Args:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
self._name = name
self._args = args
self._module = module
@ -193,6 +334,16 @@ class PipelinedForward:
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
"""
Perform the pipelined forward pass operation.
Args:
*input: Variable-length positional arguments.
**kwargs: Variable-length keyword arguments.
Returns:
Awaitable: An awaitable object representing the forward pass result.
"""
assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable)
@ -230,10 +381,22 @@ class PipelinedForward:
@property
def name(self) -> str:
"""
Get the name of the forward pass.
Returns:
str: The name of the forward pass.
"""
return self._name
@property
def args(self) -> List[ArgInfo]:
"""
Get the list of argument information for the forward pass.
Returns:
List[ArgInfo]: List of argument information.
"""
return self._args
@ -242,6 +405,17 @@ def _start_data_dist(
batch: In,
context: TrainPipelineContext,
) -> None:
"""
Start data distribution for a list of pipelined modules.
Args:
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
batch (In): The input batch.
context (TrainPipelineContext): The training pipeline context.
Returns:
None: This function doesn't return a value.
"""
context.input_dist_requests.clear()
context.module_contexts.clear()
for module in pipelined_modules:
@ -286,9 +460,17 @@ def _get_node_args_helper(
feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
Args:
arguments: The arguments to process.
num_found: The current count of arguments found.
feature_processor_arguments: Optional list of feature processor arguments.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
"""
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
@ -332,6 +514,16 @@ def _get_node_args_helper(
def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
"""
Get argument information for a given node.
Args:
node (Node): The node to process.
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
"""
num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
path: str,
unsharded_module_names: Set[str],
) -> bool:
"""
Get the names of unsharded modules in a model.
Args:
model (torch.nn.Module): The model to analyze.
path (str): The current path in the model hierarchy.
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
Returns:
bool: True if any sharded modules were found in the hierarchy, False otherwise.
"""
sharded_children = set()
for name, child in model.named_children():
curr_path = path + name
@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top-level modules that do not contain any sharded sub-modules.
Args:
model (torch.nn.Module): The model to analyze.
Returns:
List[str]: A list of top-level module names without sharded sub-modules.
"""
unsharded_module_names: Set[str] = set()
_get_unsharded_module_names_helper(
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
"""
Rewrites the model to enable pipelined execution for selected sharded modules.
This function traces the input model using a custom tracer and identifies sharded modules
that can be pipelined. It then creates PipelinedForward objects for these modules,
which enable pipelining during training.
Args:
model (torch.nn.Module): The input model to be rewritten.
context (TrainPipelineContext): The context containing information needed for pipelining.
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
Returns:
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
"""
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
"""
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
Attributes:
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
"""
synced_pipeline_id: Dict[int, int] = {}
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None,
) -> None:
"""
Initializes the training pipeline.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
"""
self._model = model
self._optimizer = optimizer
self._device = device
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Connects the training pipeline to data and prepares for forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
"""
# batch 1
with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter)
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Progresses through the training pipeline, performing forward and backward passes.
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
Returns:
Out: The output of the forward pass.
"""
should_step_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
module.forward._context = self._context
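The gradient-accumulation behaviour documented in `progress` above can be summarised with a short standalone sketch; this is not the pipeline's actual implementation, just the control flow it describes, with illustrative names:
```python
import torch

def train_with_grad_accum(model, optimizer, batches, grad_accum: int):
    """Scale each loss by 1/grad_accum and step/zero the optimizer every grad_accum calls."""
    for step, (inputs, targets) in enumerate(batches):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        # Scaling keeps the accumulated gradient comparable to a single large-batch step.
        (loss / grad_accum).backward()
        if (step + 1) % grad_accum == 0:
            optimizer.step()
            optimizer.zero_grad()
```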

View File

@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
def on_kf():
return "SPEC_TYPE" in os.environ
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
Returns:
bool: True if running on KF, False otherwise.
"""
return "SPEC_TYPE" in os.environ
def has_readers():
"""Check if the current task has dataset workers.
Returns:
bool: True if the task has dataset workers, False otherwise.
"""
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env.get("dataset_worker") is not None
return os.environ.get("HAS_READERS", "False") == "True"
def get_task_type():
"""Get the type of the current task.
Returns:
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
"""
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]
def is_chief() -> bool:
return get_task_type() == "chief"
"""Check if the current task is the 'chief'.
Returns:
bool: True if the current task is the 'chief', False otherwise.
"""
return get_task_type() == "chief"
def is_reader() -> bool:
return get_task_type() == "datasetworker"
"""Check if the current task is a 'datasetworker'.
Returns:
bool: True if the current task is a 'datasetworker', False otherwise.
"""
return get_task_type() == "datasetworker"
def is_dispatcher() -> bool:
return get_task_type() == "datasetdispatcher"
"""Check if the current task is a 'datasetdispatcher'.
Returns:
bool: True if the current task is a 'datasetdispatcher', False otherwise.
"""
return get_task_type() == "datasetdispatcher"
def get_task_index():
"""Get the index of the current task.
Returns:
int: Task index.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError
def get_reader_port():
"""Get the port used by readers.
Returns:
int: Reader port.
"""
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT
def get_dds():
"""Get the Distributed Data Service (DDS) address.
Returns:
str: DDS address in the format 'grpc://host:port'.
Raises:
ValueError: If the job does not have DDS.
"""
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")
def get_dds_dispatcher_address():
"""Get the DDS dispatcher address.
Returns:
str: DDS dispatcher address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"
def get_dds_worker_address():
"""Get the DDS worker address.
Returns:
str: DDS worker address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"
def get_num_readers():
"""Get the number of dataset workers.
Returns:
int: Number of dataset workers.
"""
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env.get("num_dataset_workers") or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
def get_flight_server_addresses():
"""Get Flight server addresses for dataset workers.
Returns:
List[str]: List of Flight server addresses in the format 'grpc://host:port'.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError
def get_dds_journaling_dir():
return os.environ.get("DATASET_JOURNALING_DIR", None)
"""Get the DDS journaling directory.
Returns:
str: DDS journaling directory.
"""
return os.environ.get("DATASET_JOURNALING_DIR", None)

View File

@ -14,18 +14,21 @@ def update_mean(
weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the mean according to the Welford formula:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
This function updates the mean and the weighted sum of values using the Welford algorithm.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args:
current_mean (torch.Tensor): The value of the current accumulated mean.
current_weight_sum (torch.Tensor): The current weighted sum.
value (torch.Tensor): The new value that needs to be added to get a new mean.
weight (torch.Tensor): The weights for the new value.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
"""
weight = torch.broadcast_to(weight, value.shape)
# Avoiding (on purpose) in-place operation when using += in case
@ -38,12 +41,15 @@ def update_mean(
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
"""
Merge the state from multiple workers.
This function merges the state from multiple workers to compute the accumulated mean.
Args:
state (torch.Tensor): A tensor with the first dimension indicating workers.
Returns:
torch.Tensor: The accumulated mean from all workers.
"""
mean, weight_sum = update_mean(
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
class StableMean(torchmetrics.Metric):
"""
A numerically stable mean metric using the Welford algorithm.
This class implements a numerically stable mean computation using the Welford algorithm:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
For example, when using float32, the algorithm gives a valid output even if the "sum" is larger
than the maximum float32, as long as the mean is within the limit of float32.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args:
**kwargs: Additional parameters supported by all torchmetrics.Metric.
Attributes:
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
"""
def __init__(self, **kwargs):
"""
@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
)
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
"""
Update the current mean.
Args:
value: Value to update the mean with.
weight: weight to use. Shape should be broadcastable to that of value.
"""
"""Update the current mean.
Args:
value (torch.Tensor): Value to update the mean with.
weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
"""
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
if not isinstance(weight, torch.Tensor):
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
)
def compute(self) -> torch.Tensor:
"""
Compute and return the accumulated mean.
"""
"""Compute and return the accumulated mean.
Returns:
torch.Tensor: The accumulated mean.
"""
return self.mean_and_weight_sum[0]
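A quick usage sketch for `StableMean` (assuming the class above is in scope); the expected value follows directly from the weighted-mean definition:
```python
import torch

metric = StableMean()
metric.update(torch.tensor([1.0, 2.0, 3.0]))    # three values, weight 1 each
metric.update(torch.tensor([4.0]), weight=2.0)  # one value counted twice
# Weighted mean: (1 + 2 + 3 + 2 * 4) / 5 = 2.8
assert torch.isclose(metric.compute(), torch.tensor(2.8))
```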

View File

@ -29,6 +29,9 @@ def _compute_helper(
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
we assume that they are correct predictions (i.e., weight = 1) when this is False. Otherwise,
we treat them as incorrect predictions (i.e., weight = 0).
Returns:
torch.Tensor: The computed AUROC
"""
dim = 0
@ -52,24 +55,34 @@ def _compute_helper(
class AUROCWithMWU(torchmetrics.Metric):
"""
AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.
This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
Note: Currently, this implementation only supports binary classification.
Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
"""
Initializes the AUROCWithMWU metric.
Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
self.add_state("predictions", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

View File

@ -14,7 +14,24 @@ from absl import logging as logging
def setup_absl_logging():
"""Make sure that absl logging pushes to stdout rather than stderr."""
"""
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
This function ensures that log messages generated by the absl-py library are written to stdout
rather than stderr. It also applies a custom log message format that includes module, function,
line number, log level, and the log message content.
Note:
This function should be called once at the beginning of your script or application to
configure absl-py logging.
Example:
To use this function, simply call it at the start of your script:
```
setup_absl_logging()
```
"""
logging.get_absl_handler().python_handler.stream = sys.stdout
formatter = py_logging.Formatter(
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"

View File

@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
class Testtlogging(unittest.TestCase):
def test_warn_once(self):
"""
Test that warning messages are logged only once when using the assertLogs context manager.
This unit test checks the behavior of the logging system when warning messages are issued
multiple times within the same context. It uses the assertLogs context manager to capture
log messages at the INFO level and verifies that warning messages are logged only once.
Example:
To use this test case, call it using a test runner like unittest:
```
python -m unittest your_test_module.TestLogging.test_warn_once
```
"""
with self.assertLogs(level="INFO") as captured_logs:
logging.info("first info")
logging.warning("first warning")

View File

@ -18,7 +18,35 @@ import torch.distributed as dist
def rank_specific(logger):
"""Ensures that we only override a given logger once."""
"""
Customize logger behavior based on the distributed environment and rank.
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
depending on the rank or limit the number of redundant logs.
Args:
logger: The logger object to customize.
Returns:
The customized logger.
Example:
To use this function with the `logging` module:
```python
import logging
from rank_specific_logging import rank_specific
logger = logging.getLogger(__name__)
rank_specific(logger)
```
Customization:
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
- The 'warning' method is limited to logging a single redundant warning.
- Logging from rank -1 is redirected to include the rank information.
"""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger

View File

@ -54,13 +54,21 @@ def maybe_shard_model(
model,
device: torch.device,
):
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
"""
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
DistributedModelParallel.
If not in a distributed environment, returns the model directly.
Args:
model: The PyTorch model.
device: The target device (e.g., 'cuda').
Returns:
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
"""
if dist.is_initialized():
logging.info("***** Wrapping in DistributedModelParallel *****")
logging.info(f"Model before wrapping: {model}")
@ -74,14 +82,15 @@ def maybe_shard_model(
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
"""Handy function to log the content of EBC embedding layer.
Only works for single GPU machines.
Args:
weight_name: name of tensor, as defined in model
table_name: name of the EBC table the weight is taken from
weight_tensor: embedding weight tensor
"""
Handy function to log the content of an EBC (EmbeddingBagCollection) embedding layer.
Only works for single GPU machines.
Args:
weight_name: Name of the tensor, as defined in the model.
table_name: Name of the EBC table the weight is taken from.
weight_tensor: Embedding weight tensor.
"""
logging.info(f"{weight_name}, {table_name}", rank=-1)
logging.info(f"{weight_tensor.metadata()}", rank=-1)
output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))

View File

@ -8,11 +8,60 @@ import pydantic
class PiecewiseConstant(base_config.BaseConfig):
"""
Configuration for a piecewise constant learning rate schedule.
This configuration class allows you to specify a piecewise constant learning rate schedule
by defining boundaries and corresponding learning rate values.
Attributes:
learning_rate_boundaries (List[int], optional): List of step boundaries at which
the learning rate will change. If None, no boundaries are defined.
learning_rate_values (List[float], optional): List of learning rate values
corresponding to the boundaries. If None, no values are defined.
Example:
To configure a piecewise constant learning rate schedule, create an instance of this class
and set the attributes accordingly. For example:
```python
piecewise_lr = PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
```
Note:
The number of learning rate values should be one more than the number of boundaries.
"""
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
learning_rate_values: typing.List[float] = pydantic.Field(None)
class LinearRampToConstant(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to constant learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero to a constant value over a specified number of steps.
Attributes:
learning_rate (float): The final constant learning rate.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
Example:
To configure a linear ramp-up to a constant learning rate, create an instance of this class
and set the attributes accordingly. For example:
```python
linear_ramp_lr = LinearRampToConstant(
learning_rate=0.1,
num_ramp_steps=1000
)
```
"""
learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
description="Number of steps to ramp this up from zero."
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
class LinearRampToCosine(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to cosine decay learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero, then decays following a cosine schedule to a final constant learning rate.
Attributes:
learning_rate (float): The initial learning rate at the start of ramp-up.
final_learning_rate (float): The final constant learning rate after decay.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
final_num_steps (PositiveInt): Final number of steps where decay stops.
Example:
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
class and set the attributes accordingly. For example:
```python
ramp_to_cosine_lr = LinearRampToCosine(
learning_rate=0.01,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
```
"""
learning_rate: float
final_learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
class LearningRate(base_config.BaseConfig):
"""
Learning rate configuration for training.
This configuration class allows you to specify different learning rate schedules
for your training process.
Attributes:
constant (float, optional): Constant learning rate to be used throughout training.
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
and then decays following a cosine schedule.
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
linearly and then remains constant.
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
boundaries with corresponding values.
Example:
To configure a learning rate schedule, create an instance of this class and set the
attributes accordingly. For example:
```python
learning_rate = LearningRate(
constant=0.01,
linear_ramp_to_cosine=LinearRampToCosine(
learning_rate=0.1,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
)
```
Note:
Each learning rate schedule attribute can be set to `None` if not needed.
"""
constant: float = pydantic.Field(None, one_of="lr")
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
class OptimizerAlgorithmConfig(base_config.BaseConfig):
"""Base class for optimizer configurations."""
"""
Base class for optimizer configurations.
This base configuration class provides a structure for specifying various optimizer-related
settings, including the learning rate and different learning rate schedules.
Attributes:
lr (float): The base learning rate used by the optimizer.
Subclasses should inherit from this base class and define additional attributes specific to
the optimizer algorithm they represent.
Example:
To create a custom optimizer configuration, create a subclass of this base class and
define the necessary attributes. For example:
```python
class MyOptimizerConfig(OptimizerAlgorithmConfig):
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
```
Note:
This base class does not include specific optimizer settings. Subclasses should define
the optimizer-specific attributes as needed.
"""
lr: float
...
class AdamConfig(OptimizerAlgorithmConfig):
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
"""
Configuration for the Adam optimizer.
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
Attributes:
lr (float): The learning rate for optimization.
betas (Tuple[float, float], optional): Coefficients used for computing running averages
of gradient and squared gradient. Defaults to (0.9, 0.999).
eps (float, optional): A small constant added to the denominator for numerical stability.
Defaults to 1e-7.
Example:
To configure the Adam optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
adam_optimizer = AdamConfig(
lr=0.001,
betas=(0.9, 0.999),
eps=1e-8
)
```
See Also:
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
"""
lr: float
betas: typing.Tuple[float, float] = [0.9, 0.999]
eps: float = 1e-7 # Numerical stability in denominator.
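As a sketch of how such a config might feed a real optimizer, the snippet below unpacks an `AdamConfig` into `torch.optim.Adam`; it assumes the pydantic `.dict()` keys line up exactly with Adam's keyword arguments, which holds for the fields shown here:
```python
import torch

adam_cfg = AdamConfig(lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
model = torch.nn.Linear(4, 1)
# lr, betas and eps map one-to-one onto torch.optim.Adam's keyword arguments.
optimizer = torch.optim.Adam(model.parameters(), **adam_cfg.dict())
```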
class SgdConfig(OptimizerAlgorithmConfig):
"""
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
Attributes:
lr (float): The learning rate for optimization.
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
Example:
To configure the SGD optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
sgd_optimizer = SgdConfig(
lr=0.01,
momentum=0.9
)
```
"""
lr: float
momentum: float = 0.0
class AdagradConfig(OptimizerAlgorithmConfig):
"""
Configuration for the optimizer used during training.
This configuration class allows you to specify the optimizer for training, including
options for various optimizer algorithms.
Attributes:
learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None.
adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None.
sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
Defaults to None.
adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None.
Example:
To configure the optimizer for training, create an instance of this class and set the
attributes accordingly. For example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
```
"""
lr: float
eps: float = 0
class OptimizerConfig(base_config.BaseConfig):
"""
Configuration for defining different optimizer algorithms and their parameters.
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
along with their respective hyperparameters.
Args:
learning_rate (LearningRate): The learning rate configuration, which can include
constant learning rates or other learning rate schedules.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
)
```
Attributes:
learning_rate (LearningRate): The learning rate configuration.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Note:
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
`OptimizerConfig` instance.
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `AdamConfig`: Configuration for the Adam optimizer.
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
- `AdagradConfig`: Configuration for the Adagrad optimizer.
"""
learning_rate: LearningRate = pydantic.Field(
None,
description="Constant learning rates",
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
"""
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
This function extracts and returns the specific optimizer algorithm configuration
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
Args:
optimizer_config (OptimizerConfig): The optimizer configuration object containing
one of the optimizer algorithm configurations.
Returns:
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
configuration extracted from `optimizer_config`.
Raises:
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
Example:
```python
optimizer_config = OptimizerConfig(
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
# `algorithm_config` will be an instance of `AdamConfig`.
```
"""
if optimizer_config.adam is not None:
return optimizer_config.adam
elif optimizer_config.sgd is not None:

View File

@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
"""Compute a learning rate."""
"""
Compute the learning rate based on the specified learning rate configuration.
This function calculates the learning rate according to the given configuration, which can include
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
Args:
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
step (int): The current training step or iteration.
Returns:
float: The computed learning rate for the current step.
Raises:
ValueError: If the `lr_config` is invalid or contains conflicting options.
Example:
```python
lr_schedule = LearningRate(
constant=0.001,
piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
)
current_step = 2500
learning_rate = compute_lr(lr_schedule, current_step)
```
"""
if lr_config.constant is not None:
return lr_config.constant
elif lr_config.piecewise_constant is not None:
@ -46,11 +74,54 @@ def compute_lr(lr_config, step):
class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
Learning Rate Scheduler Shim to adjust learning rates during training.
This class acts as a shim to apply different learning rates to individual parameter groups
within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
optimizers, allowing fine-grained control over learning rates based on configuration.
Args:
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
corresponding learning rate configurations.
last_epoch (int, optional): The index of the last epoch. Default is -1.
verbose (bool, optional): If True, prints a warning message when accessing learning rates
using the deprecated `get_lr()` method. Default is False.
Raises:
ValueError: If the number of parameter groups in the optimizer does not match the number
of learning rate configurations provided.
Note:
To obtain the last computed learning rates, please use `get_last_lr()`.
Example:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lr_schedule = {
'main': LearningRate(constant=0.01),
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000],
learning_rate_values=[0.01, 0.001]
))
}
lr_shim = LRShim(optimizer, lr_schedule)
for epoch in range(num_epochs):
# Train the model
train(...)
# Update learning rates at the end of each epoch
lr_shim.step(epoch)
final_lr_main = lr_shim.get_last_lr()['main']
final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary']
```
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
"""
def __init__(
self,
@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
def build_optimizer(
model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
"""
Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
Args:
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
algorithm and learning rate settings.
Returns:
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
objects.
Note:
This function is intended for cases where you want the same optimizer and learning rate
schedule for all model parameters.
Example:
```python
model = MyModel()
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.01),
sgd=SgdConfig(lr=0.01, momentum=0.9)
)
optimizer, scheduler = build_optimizer(model, optimizer_config)
for epoch in range(num_epochs):
# Train the model with the optimizer
train(model, optimizer, ...)
# Update learning rates at the end of each epoch
scheduler.step(epoch)
```
See Also:
- `OptimizerConfig`: Configuration for specifying optimizer settings.
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
"""
optimizer_class = get_optimizer_class(optimizer_config)
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
# We're passing everything in as one group here

View File

@ -9,6 +9,36 @@ import pydantic
class TrainingConfig(config_mod.BaseConfig):
"""
Configuration settings for the training process.
This class defines various training-related settings, including the directory to save checkpoints, the number
of training steps, logging intervals, and other training parameters.
Attributes:
save_dir (str): The directory where checkpoints and training artifacts will be saved.
num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
Example:
To configure training with checkpoints saved every 1000 steps, use the following settings:
```python
TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
)
```
"""
save_dir: str = "/tmp/model"
num_train_steps: pydantic.PositiveInt = 1000000
initial_checkpoint_dir: str = pydantic.Field(
@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):
class RecapConfig(config_mod.BaseConfig):
"""
Configuration settings for the Recap model training process.
This class defines the overall configuration for the training process of a Recap model. It includes settings for
training, model architecture, data, optimization, and evaluation.
Attributes:
training (TrainingConfig): Configuration settings for the training process.
model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
train_data (data_config.RecapDataConfig): Configuration settings for training data.
validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
which_metrics (Optional[str]): Optional specification of which metrics to pick.
Note:
This class encapsulates all the necessary configurations to train a Recap model. It defines settings for
training, the model architecture, data loading, optimization, and evaluation.
Example:
To configure a Recap model training process, use the following settings:
```python
RecapConfig(
training=TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
),
model=model_config.ModelConfig(...),
train_data=data_config.RecapDataConfig(...),
validation_data={"dev": data_config.RecapDataConfig(...)},
optimizer=optimizer_config.RecapOptimizerConfig(...),
)
```
"""
training: TrainingConfig = pydantic.Field(TrainingConfig())
model: model_config.ModelConfig
train_data: data_config.RecapDataConfig

View File

@ -8,239 +8,250 @@ import pydantic
class ExplicitDateInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_date and days of data."""
"""Arguments to select train/validation data using end_date and days of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
num_missing_days_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_days_tol days of missing data."
)
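For illustration, this is how the date-based selector above might be instantiated; the root path and dates are placeholders, not real datasets:
```python
date_inputs = ExplicitDateInputs(
    data_root="/placeholder/data/root",  # hypothetical prefix
    end_date="2023-01-31",               # inclusive end date
    days=7,                              # one week of data
    num_missing_days_tol=1,              # tolerate a single missing day
)
```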
class ExplicitDatetimeInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_datetime and hours of data."""
"""Arguments to select train/validation data using end_datetime and hours of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
num_missing_hours_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
)
class DdsCompressionOption(str, Enum):
"""The only valid compression option is 'AUTO'"""
"""The only valid compression option is 'AUTO'"""
AUTO = "AUTO"
AUTO = "AUTO"
class DatasetConfig(base_config.BaseConfig):
inputs: str = pydantic.Field(
None, description="A glob for selecting data.", one_of="date_inputs_format"
)
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
None, one_of="date_inputs_format"
)
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
global_batch_size: pydantic.PositiveInt
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Number of shards to keep."
)
repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to."
)
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
cache: bool = pydantic.Field(
False,
description="Cache dataset in memory. Careful to only use this when you"
" have enough memory to fit entire dataset.",
)
data_service_dispatcher: str = pydantic.Field(None)
ignore_data_errors: bool = pydantic.Field(
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
)
dataset_service_compression: DdsCompressionOption = pydantic.Field(
None,
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
)
# tf.data.Dataset options
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of parallel calls."
)
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of shards to interleave."
)
class TruncateAndSlice(base_config.BaseConfig):
# Apply truncation and then slice.
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
None, description="Experimental. Truncates continuous features to this amount for efficiency."
)
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
None, description="Experimental. Truncates binary features to this amount for efficiency."
)
continuous_feature_mask_path: str = pydantic.Field(
None, description="Path of mask used to slice input continuous features."
)
binary_feature_mask_path: str = pydantic.Field(
None, description="Path of mask used to slice input binary features."
)
class DataType(str, Enum):
BFLOAT16 = "bfloat16"
BOOL = "bool"
BFLOAT16 = "bfloat16"
BOOL = "bool"
FLOAT32 = "float32"
FLOAT16 = "float16"
FLOAT32 = "float32"
FLOAT16 = "float16"
UINT8 = "uint8"
UINT8 = "uint8"
class DownCast(base_config.BaseConfig):
# Apply down casting to selected features.
features: typing.Dict[str, DataType] = pydantic.Field(
None, description="Map features to down cast data types."
)
class TaskData(base_config.BaseConfig):
pos_downsampling_rate: float = pydantic.Field(
1.0,
description="Downsampling rate of positives used to generate dataset.",
)
neg_downsampling_rate: float = pydantic.Field(
1.0,
description="Downsampling rate of negatives used to generate dataset.",
)
class SegDenseSchema(base_config.BaseConfig):
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
features: typing.List[str] = pydantic.Field(
[],
description="List of features (in addition to the renamed features) to read from schema path above.",
)
renamed_features: typing.Dict[str, str] = pydantic.Field(
{}, description="Dictionary of renamed features."
)
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
{},
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
)
class RectifyLabels(base_config.BaseConfig):
label_rectification_window_in_hours: float = pydantic.Field(
3.0, description="overlap time in hours for which to flip labels"
)
served_timestamp_field: str = pydantic.Field(
..., description="input field corresponding to served time"
)
impressed_timestamp_field: str = pydantic.Field(
..., description="input field corresponding to impressed time"
)
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
..., description="label to the input field corresponding to engagement time"
)
class ExtractFeaturesRow(base_config.BaseConfig):
  name: str = pydantic.Field(
    ...,
    description="Name of the new field to be created.",
  )
  source_tensor: str = pydantic.Field(
    ...,
    description="Name of the dense tensor in which to look for the feature.",
  )
  index: int = pydantic.Field(
    ...,
    description="Index of the feature in the dense tensor.",
  )
class ExtractFeatures(base_config.BaseConfig):
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
[],
description="list of features to be extracted with their name, source tensor and index",
)
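For context, a single extraction row names the new field, the dense tensor to read from, and the column index; an illustrative (hypothetical) instance of the config above:

extract_features = ExtractFeatures(
  extract_feature_table=[
    ExtractFeaturesRow(name="new_feature", source_tensor="discrete", index=3),
  ]
)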
class DownsampleNegatives(base_config.BaseConfig):
batch_multiplier: int = pydantic.Field(
None,
description="batch multiplier",
)
engagements_list: typing.List[str] = pydantic.Field(
[],
description="engagements with kept positives",
)
  num_engagements: int = pydantic.Field(
    ...,
    description="Number of engagements used in the model, including ones excluded in engagements_list.",
  )
class Preprocess(base_config.BaseConfig):
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
rectify_labels: RectifyLabels = pydantic.Field(
None, description="Rectify labels for a given overlap window"
)
extract_features: ExtractFeatures = pydantic.Field(
None, description="Extract features from dense tensors."
)
downsample_negatives: DownsampleNegatives = pydantic.Field(
None, description="Downsample negatives."
)
class Sampler(base_config.BaseConfig):
"""Assumes function is defined in data/samplers.py.
"""Assumes function is defined in data/samplers.py.
Only use this for quick experimentation.
If samplers are useful, we should sample from upstream data generation.
Only use this for quick experimentation.
If samplers are useful, we should sample from upstream data generation.
DEPRICATED, DO NOT USE.
"""
DEPRICATED, DO NOT USE.
"""
name: str
kwargs: typing.Dict
name: str
kwargs: typing.Dict
class RecapDataConfig(DatasetConfig):
  seg_dense_schema: SegDenseSchema

  tasks: typing.Dict[str, TaskData] = pydantic.Field(
    description="Description of individual tasks in this dataset."
  )
  evaluation_tasks: typing.List[str] = pydantic.Field(
    [], description="If specified, lists the tasks we're generating metrics for."
  )

  preprocess: Preprocess = pydantic.Field(
    None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
  )

  sampler: Sampler = pydantic.Field(
    None,
    description="""DEPRECATED, DO NOT USE. Sampling function for offline experiments.""",
  )

  @pydantic.root_validator()
  def _validate_evaluation_tasks(cls, values):
    if values.get("evaluation_tasks") is not None:
      for task in values["evaluation_tasks"]:
        if task not in values["tasks"]:
          raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
    return values
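A minimal sketch of composing the preprocessing config above; the field values are illustrative, not taken from this commit:

preprocess = Preprocess(
  downcast=DownCast(features={"continuous": DataType.BFLOAT16}),
  rectify_labels=RectifyLabels(
    served_timestamp_field="served_ts",
    impressed_timestamp_field="impressed_ts",
    label_to_engaged_timestamp_field={"some_label": "some_label_engaged_ts"},
  ),
)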

View File

@ -11,8 +11,8 @@ from tml.projects.home.recap.data import preprocessors
from tml.projects.home.recap.config import JobMode
from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn
from tml.projects.home.recap.data.util import (
keyed_jagged_tensor_from_tensors_dict,
sparse_or_dense_tf_to_torch,
)
from absl import logging
import torch.distributed as dist
@ -20,458 +20,546 @@ import torch.distributed as dist
@dataclass
class RecapBatch(DataclassBatch):
"""Holds features and labels from the Recap dataset."""
"""Holds features and labels from the Recap dataset."""
continuous_features: torch.Tensor
binary_features: torch.Tensor
discrete_features: torch.Tensor
sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821
labels: torch.Tensor
user_embedding: torch.Tensor = None
user_eng_embedding: torch.Tensor = None
author_embedding: torch.Tensor = None
weights: torch.Tensor = None
continuous_features: torch.Tensor
binary_features: torch.Tensor
discrete_features: torch.Tensor
sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821
labels: torch.Tensor
user_embedding: torch.Tensor = None
user_eng_embedding: torch.Tensor = None
author_embedding: torch.Tensor = None
weights: torch.Tensor = None
def __post_init__(self):
if self.weights is None:
self.weights = torch.ones_like(self.labels)
for feature_name, feature_value in self.as_dict().items():
if ("embedding" in feature_name) and (feature_value is None):
setattr(self, feature_name, torch.empty([0, 0]))
def __post_init__(self):
if self.weights is None:
self.weights = torch.ones_like(self.labels)
for feature_name, feature_value in self.as_dict().items():
if ("embedding" in feature_name) and (feature_value is None):
setattr(self, feature_name, torch.empty([0, 0]))
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
"""Converts a torch data loader output into `RecapBatch`."""
"""Converts a torch data loader output into `RecapBatch`."""
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
try:
features_in, labels = x
except ValueError:
# For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple
features_in, labels = x, None
x = tf.nest.map_structure(functools.partial(
sparse_or_dense_tf_to_torch, pin_memory=False), x)
try:
features_in, labels = x
except ValueError:
# For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple
features_in, labels = x, None
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
if sparse_feature_names:
sparse_features = keyed_jagged_tensor_from_tensors_dict(
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
if sparse_feature_names:
sparse_features = keyed_jagged_tensor_from_tensors_dict(
{embedding_name: features_in[embedding_name]
for embedding_name in sparse_feature_names}
)
user_embedding, user_eng_embedding, author_embedding = None, None, None
if "user_embedding" in features_in:
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
raise ValueError(
"Only one source of embedding for user is supported")
else:
user_embedding = features_in["user_embedding"]
if "user_eng_embedding" in features_in:
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
raise ValueError(
"Only one source of embedding for user is supported")
else:
user_eng_embedding = features_in["user_eng_embedding"]
if "author_embedding" in features_in:
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
raise ValueError(
"Only one source of embedding for user is supported")
else:
author_embedding = features_in["author_embedding"]
return RecapBatch(
continuous_features=features_in["continuous"],
binary_features=features_in["binary"],
discrete_features=features_in["discrete"],
sparse_features=sparse_features,
user_embedding=user_embedding,
user_eng_embedding=user_eng_embedding,
author_embedding=author_embedding,
labels=labels,
# Defaults to torch.ones_like(labels)
weights=features_in.get("weights", None),
)
user_embedding, user_eng_embedding, author_embedding = None, None, None
if "user_embedding" in features_in:
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
user_embedding = features_in["user_embedding"]
if "user_eng_embedding" in features_in:
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
user_eng_embedding = features_in["user_eng_embedding"]
if "author_embedding" in features_in:
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
author_embedding = features_in["author_embedding"]
return RecapBatch(
continuous_features=features_in["continuous"],
binary_features=features_in["binary"],
discrete_features=features_in["discrete"],
sparse_features=sparse_features,
user_embedding=user_embedding,
user_eng_embedding=user_eng_embedding,
author_embedding=author_embedding,
labels=labels,
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
)
def _chain(param, f1, f2):
"""
Reduce multiple functions into one chained function
_chain(x, f1, f2) -> f2(f1(x))
"""
output = param
fns = [f1, f2]
for f in fns:
output = f(output)
return output
"""
Reduce multiple functions into one chained function
_chain(x, f1, f2) -> f2(f1(x))
"""
output = param
fns = [f1, f2]
for f in fns:
output = f(output)
return output
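For example, chaining a parse function with an output-mapping function is equivalent to composing them:

assert _chain(2, f1=lambda v: v + 1, f2=lambda v: v * 3) == 9  # f2(f1(2)) == (2 + 1) * 3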
def _add_weights(inputs, tasks: Dict[str, TaskData]):
"""Adds weights based on label sampling for positive and negatives.
"""Adds weights based on label sampling for positive and negatives.
This is useful for numeric calibration etc. This mutates inputs.
This is useful for numeric calibration etc. This mutates inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Returns:
A tuple of features and labels; weights are added to features.
"""
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
)
weights.append(
float_label / task.pos_downsampling_rate +
(1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
return inputs
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(
tf.convert_to_tensor(weights)), axis=0)
return inputs
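As a quick sanity check of the weighting formula (rates below are illustrative), a positive kept at a 10% downsampling rate is re-weighted by 10x and a negative kept at 50% by 2x:

labels = tf.constant([1.0, 0.0])
pos_rate, neg_rate = 0.1, 0.5
weights = labels / pos_rate + (1.0 - labels) / neg_rate  # -> [10.0, 2.0]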
def get_datetimes(explicit_datetime_inputs):
"""Compute list datetime strings for train/validation data."""
datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
dates = sorted(
[
(end - timedelta(hours=i + 1)).strftime(datetime_format)
for i in range(int(explicit_datetime_inputs.hours))
]
)
return dates
"""Compute list datetime strings for train/validation data."""
datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime(
explicit_datetime_inputs.end_datetime, datetime_format)
dates = sorted(
[
(end - timedelta(hours=i + 1)).strftime(datetime_format)
for i in range(int(explicit_datetime_inputs.hours))
]
)
return dates
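For instance, with an (illustrative) input whose end_datetime is "2023/01/02/00" and hours is 3, the helper returns the three preceding hourly partitions in sorted order:

class _FakeInputs:  # hypothetical stand-in for the explicit_datetime_inputs config
  end_datetime = "2023/01/02/00"
  hours = 3

print(get_datetimes(_FakeInputs()))  # ['2023/01/01/21', '2023/01/01/22', '2023/01/01/23']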
def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
"""
Compile list of files for training/validation.
"""
Compile list of files for training/validation.
Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
For each hour of data, if the directory is missing or empty, we increment a counter to keep
track of the number of missing data hours.
Returns only files with a `.gz` extension.
Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
For each hour of data, if the directory is missing or empty, we increment a counter to keep
track of the number of missing data hours.
Returns only files with a `.gz` extension.
Args:
explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object
Args:
explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object
Returns:
data_files: Sorted list of files to read corresponding to data at the desired datetimes
num_hours_missing: Number of hours that we are missing data
Returns:
data_files: Sorted list of files to read corresponding to data at the desired datetimes
num_hours_missing: Number of hours that we are missing data
"""
datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
data_files = []
num_hours_missing = 0
for folder in folders:
try:
files = tf.io.gfile.listdir(folder)
if not files:
logging.warning(f"{folder} contained no data files")
num_hours_missing += 1
data_files.extend(
[
os.path.join(folder, filename)
for filename in files
if filename.rsplit(".", 1)[-1].lower() == "gz"
]
)
except tf.errors.NotFoundError as e:
num_hours_missing += 1
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
return sorted(data_files), num_hours_missing
"""
datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime)
for datetime in datetimes]
data_files = []
num_hours_missing = 0
for folder in folders:
try:
files = tf.io.gfile.listdir(folder)
if not files:
logging.warning(f"{folder} contained no data files")
num_hours_missing += 1
data_files.extend(
[
os.path.join(folder, filename)
for filename in files
if filename.rsplit(".", 1)[-1].lower() == "gz"
]
)
except tf.errors.NotFoundError as e:
num_hours_missing += 1
logging.warning(
f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
return sorted(data_files), num_hours_missing
def _map_output_for_inference(
  inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
  """
  Map the input data for inference.

  Args:
    inputs (dict): Input data dictionary.
    tasks (Dict[str, TaskData]): Dictionary of task data configurations.
    preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
    add_weights (bool, optional): Whether to add weights. Defaults to False.

  Returns:
    dict: Modified input data dictionary for inference.
  """
  if preprocessor:
    raise ValueError("No preprocessor should be used at inference time.")
  if add_weights:
    raise NotImplementedError()

  # Add zero weights.
  inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
  for label in tasks:
    del inputs[label]
  return inputs
def _map_output_for_train_eval(
  inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
  """
  Map the input data for training and evaluation.

  Args:
    inputs (dict): Input data dictionary.
    tasks (Dict[str, TaskData]): Dictionary of task data configurations.
    preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
    add_weights (bool, optional): Whether to add weights. Defaults to False.

  Returns:
    Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation.
  """
  if add_weights:
    inputs = _add_weights_based_on_sampling_rates(inputs, tasks)

  # Warning this has to happen first as it changes the input
  if preprocessor:
    inputs = preprocessor(inputs)

  label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])

  for label in tasks:
    del inputs[label]

  return inputs, label_values
def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
"""Adds weights based on label sampling for positive and negatives.
"""Adds weights based on label sampling for positive and negatives.
This is useful for numeric calibration etc. This mutates inputs.
This is useful for numeric calibration etc. This mutates inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
)
weights.append(
float_label / task.pos_downsampling_rate +
(1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
return inputs
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(
tf.convert_to_tensor(weights)), axis=0)
return inputs
class RecapDataset(torch.utils.data.IterableDataset):
  def __init__(
    self,
    data_config: RecapDataConfig,
    dataset_service: Optional[str] = None,
    mode: JobMode = JobMode.TRAIN,
    compression: Optional[str] = "AUTO",
    repeat: bool = False,
    vocab_mapper: tf.keras.Model = None,
  ):
    """
    Create a RecapDataset for training or inference.

    Args:
      data_config (RecapDataConfig): Data configuration.
      dataset_service (str, optional): Dataset service identifier. Defaults to None.
      mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN.
      compression (str, optional): Compression type. Defaults to "AUTO".
      repeat (bool, optional): Whether to repeat the dataset. Defaults to False.
      vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None.
    """
    logging.info("***** Labels *****")
    logging.info(list(data_config.tasks.keys()))

    self._data_config = data_config
    self._parse_fn = get_seg_dense_parse_fn(data_config)
    self._mode = mode
    self._repeat = repeat
    self._num_concurrent_iterators = 1
    self._vocab_mapper = vocab_mapper
    self.dataset_service = dataset_service

    preprocessor = None
    self._batch_size_multiplier = 1
    if data_config.preprocess:
      preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
      if data_config.preprocess.downsample_negatives:
        self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier

    self._preprocessor = preprocessor

    if mode == JobMode.INFERENCE:
      if preprocessor is not None:
        raise ValueError("Expect no preprocessor at inference time.")
      should_add_weights = False
      output_map_fn = _map_output_for_inference  # (features,)
    else:
      # Only add weights if there is a reason to! If all weights will
      # be equal to 1.0, save bandwidth between DDS and Chief by simply
      # relying on the fact that weights default to 1.0 in `RecapBatch`
      # WARNING: Weights may still be added as a side effect of a preprocessor
      # such as `DownsampleNegatives`.
      should_add_weights = any(
        [
          task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0
          for task_cfg in data_config.tasks.values()
        ]
      )
      output_map_fn = _map_output_for_train_eval  # (features, labels)

    self._output_map_fn = functools.partial(
      output_map_fn,
      tasks=data_config.tasks,
      preprocessor=preprocessor,
      add_weights=should_add_weights,
    )

    sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None

    self._tf_dataset = self._create_tf_dataset()
    self._init_tensor_spec()

  def _init_tensor_spec(self):
    """
    Initialize the tensor specification for the dataset.
    """
    def _tensor_spec_to_torch_shape(spec):
      if spec.shape is None:
        return None
      shape = [x if x is not None else -1 for x in spec.shape]
      return torch.Size(shape)

    self.torch_element_spec = tf.nest.map_structure(
      _tensor_spec_to_torch_shape, self._tf_dataset.element_spec
    )

  def _create_tf_dataset(self):
    """
    Create a TensorFlow dataset from the data files.

    Returns:
      tf.data.Dataset: TensorFlow dataset.
    """
    if hasattr(self, "_tf_dataset"):
      raise ValueError("Do not call `_create_tf_dataset` more than once.")

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    per_replica_bsz = (
      self._batch_size_multiplier * self._data_config.global_batch_size // world_size
    )

    dataset: tf.data.Dataset = self._create_base_tf_dataset(
      batch_size=per_replica_bsz,
    )

    if self._repeat:
      logging.info("Repeating dataset")
      dataset = dataset.repeat()

    if self.dataset_service:
      if self._num_concurrent_iterators > 1:
        if not self.machines_config:
          raise ValueError(
            "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
          )
        dataset = dataset_lib.with_auto_tune_budget(
          dataset,
          machine_config=self.machines_config.chief,
          num_concurrent_iterators=self.num_concurrent_iterators,
          on_chief=False,
        )

      self.dataset_id, self.job_name = register_dataset(
        dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
      )
      dataset = distribute_from_dataset_id(
        dataset_id=self.dataset_id,  # type: ignore[arg-type]
        job_name=self.job_name,
        dataset_service=self.dataset_service,
        compression=self.compression,
      )
    elif self._num_concurrent_iterators > 1:
      if not self.machines_config:
        raise ValueError(
          "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
        )
      dataset = dataset_lib.with_auto_tune_budget(
        dataset,
        machine_config=self.machines_config.chief,
        num_concurrent_iterators=self._num_concurrent_iterators,
        on_chief=True,
      )

    # Vocabulary mapping happens on the training node, not in dds because of size.
    if self._vocab_mapper:
      dataset = dataset.map(self._vocab_mapper)

    return dataset.prefetch(world_size * 2)

  def _create_base_tf_dataset(self, batch_size: int):
    """
    Create the base TensorFlow dataset.

    Args:
      batch_size (int): Batch size.

    Returns:
      tf.data.Dataset: Base TensorFlow dataset.
    """
    if self._data_config.inputs:
      glob = self._data_config.inputs
      filenames = sorted(tf.io.gfile.glob(glob))
    elif self._data_config.explicit_datetime_inputs:
      num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
      filenames, num_hours_missing = get_explicit_datetime_inputs_files(
        self._data_config.explicit_datetime_inputs,
        increment="hourly",
      )
      if num_hours_missing > num_missing_hours_tol:
        raise ValueError(
          f"We are missing {num_hours_missing} hours of data "
          f"more than tolerance {num_missing_hours_tol}."
        )
    elif self._data_config.explicit_date_inputs:
      num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
      filenames, num_days_missing = get_explicit_datetime_inputs_files(
        self._data_config.explicit_date_inputs,
        increment="daily",
      )
      if num_days_missing > num_missing_days_tol:
        raise ValueError(
          f"We are missing {num_days_missing} days of data "
          f"more than tolerance {num_missing_days_tol}."
        )
    else:
      raise ValueError(
        "Must specify either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
      )

    num_files = len(filenames)
    logging.info(f"Found {num_files} data files")
    if num_files < 1:
      raise ValueError("No data files found")

    if self._data_config.num_files_to_keep is not None:
      filenames = filenames[: self._data_config.num_files_to_keep]
      logging.info(f"Retaining only {len(filenames)} files.")

    filenames_ds = (
      tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
      # Because of drop_remainder, if our dataset does not fill
      # up a batch, it will emit nothing without this repeat.
      .repeat(-1)
    )

    if self._data_config.file_batch_size:
      filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)

    def per_shard_dataset(filename):
      """
      Create a TensorFlow dataset for a single shard file.

      Args:
        filename (str): Path to the shard file.

      Returns:
        tf.data.Dataset: TensorFlow dataset for the shard file.
      """
      ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
      return ds.prefetch(4)

    ds = filenames_ds.interleave(
      per_shard_dataset,
      block_length=4,
      deterministic=False,
      num_parallel_calls=self._data_config.interleave_num_parallel_calls
      or tf.data.experimental.AUTOTUNE,
    )

    # Combine functions into one map call to reduce overhead.
    map_fn = functools.partial(
      _chain,
      f1=self._parse_fn,
      f2=self._output_map_fn,
    )

    # Shuffle -> Batch -> Parse is the correct ordering
    # Shuffling needs to be performed before batching otherwise there is not much point
    # Batching happens before parsing because tf.Example parsing is actually vectorized
    # and works much faster overall on batches of data.
    ds = (
      # DANGER DANGER: there is a default shuffle size here.
      ds.shuffle(self._data_config.examples_shuffle_buffer_size)
      .batch(batch_size=batch_size, drop_remainder=True)
      .map(
        map_fn,
        num_parallel_calls=self._data_config.map_num_parallel_calls
        or tf.data.experimental.AUTOTUNE,
      )
    )

    if self._data_config.cache:
      ds = ds.cache()

    if self._data_config.ignore_data_errors:
      ds = ds.apply(tf.data.experimental.ignore_errors())

    options = tf.data.Options()
    options.experimental_deterministic = False
    ds = ds.with_options(options)

    return ds

  def _gen(self):
    for x in self._tf_dataset:
      yield to_batch(x)

  def to_dataloader(self) -> Dict[str, torch.Tensor]:
    """
    Convert the RecapDataset to a PyTorch DataLoader.

    Returns:
      torch.utils.data.DataLoader: PyTorch DataLoader for the dataset.
    """
    return torch.utils.data.DataLoader(self, batch_size=None)

  def __iter__(self):
    return iter(self._gen())
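A hedged usage sketch; `my_data_config` is assumed to be a fully populated `RecapDataConfig`:

dataset = RecapDataset(data_config=my_data_config, mode=JobMode.TRAIN, repeat=True)
for batch in dataset.to_dataloader():
  # batch is a RecapBatch; weights default to torch.ones_like(batch.labels)
  break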

View File

@ -17,6 +17,16 @@ FLAGS = flags.FLAGS
def _generate_random_example(
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
"""
Generate a random example based on the provided TensorFlow example schema.
Args:
tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example.
Returns:
Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema.
"""
example = {}
for feature_name, feature_spec in tf_example_schema.items():
dtype = feature_spec.dtype
@ -33,14 +43,43 @@ def _generate_random_example(
def _float_feature(value):
"""
Create a TensorFlow float feature.
Args:
value: A float or list of floats.
Returns:
tf.train.Feature: A TensorFlow feature containing the float value(s).
"""
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
"""
Create a TensorFlow int64 feature.
Args:
value: An integer or list of integers.
Returns:
tf.train.Feature: A TensorFlow feature containing the int64 value(s).
"""
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
"""
Serialize a dictionary of TensorFlow tensors into a binary string.
Args:
x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors.
Returns:
bytes: The serialized binary string.
"""
feature = {}
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
for feature_name, tensor in x.items():
@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
"""
Generate random data based on the provided configuration and save it as a TFRecord file.
Args:
data_path (str): The path where the TFRecord file will be saved.
config (recap_config_mod.RecapConfig): The configuration for generating the random data.
"""
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
def _generate_data_main(unused_argv):
"""
Main function to generate random data according to the provided configuration.
Args:
unused_argv: Unused command-line arguments.
"""
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
# Find the path where to put the data

View File

@ -9,9 +9,20 @@ import numpy as np
class TruncateAndSlice(tf.keras.Model):
"""Class for truncating and slicing."""
"""
A class for truncating and slicing input features based on the provided configuration.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
def __init__(self, truncate_and_slice_config):
"""
Initializes the TruncateAndSlice model.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
super().__init__()
self._truncate_and_slice_config = truncate_and_slice_config
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
self._binary_mask = None
def call(self, inputs, training=None, mask=None):
"""
Applies truncation and slicing to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of truncated and sliced input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
if self._truncate_and_slice_config.continuous_feature_truncation:
logging.info("Truncating continuous")
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
class DownCast(tf.keras.Model):
"""Class for Down casting dataset before serialization and transferring to training host.
Depends on the data type and the actual data range, the down casting can be lossless or not.
It is strongly recommended to compare the metrics before and after down casting.
"""
A class for downcasting dataset before serialization and transferring to the training host.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
def __init__(self, downcast_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__()
self.config = downcast_config
self._type_map = {
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
}
def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
for feature, type_str in self.config.features.items():
assert type_str in self._type_map
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
class RectifyLabels(tf.keras.Model):
"""Class for rectifying labels"""
"""
A class for downcasting dataset before serialization and transferring to the training host.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
def __init__(self, rectify_label_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__()
self._config = rectify_label_config
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
served_ts_field = self._config.served_timestamp_field
impressed_ts_field = self._config.impressed_timestamp_field
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
class ExtractFeatures(tf.keras.Model):
"""Class for extracting individual features from dense tensors by their index."""
"""
A class for rectifying labels based on specified conditions.
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
"""
def __init__(self, extract_features_config):
"""
Initializes the RectifyLabels model.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
"""
super().__init__()
self._config = extract_features_config
def call(self, inputs, training=None, mask=None):
"""
Rectifies label values based on the specified conditions.
Args:
inputs: A dictionary of input features including timestamp fields and labels.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of input features with rectified label values.
"""
for row in self._config.extract_feature_table:
inputs[row.name] = inputs[row.source_tensor][:, row.index]
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
"""Builds a preprocess model to apply all preprocessing stages."""
"""
Builds a preprocess model to apply all preprocessing stages.
Args:
preprocess_config: A configuration object specifying the preprocessing parameters.
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
Returns:
A preprocess model that applies all specified preprocessing stages.
"""
if mode == config_mod.JobMode.INFERENCE:
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
return None

View File

@ -8,122 +8,129 @@ import tensorflow as tf
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
DTYPE_MAP = {"int64_list": tf.int64,
"float_list": tf.float32, "bytes_list": tf.string}
def create_tf_example_schema(
  data_config: recap_data_config.SegDenseSchema,
  segdense_schema,
):
  """Generate schema for deserializing tf.Example.

  Args:
    segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).
    labels: List of strings denoting labels.

  Returns:
    A dictionary schema suitable for deserializing tf.Example.
  """
  segdense_config = data_config.seg_dense_schema
  labels = list(data_config.tasks.keys())
  used_features = (
    segdense_config.features + list(segdense_config.renamed_features.values()) + labels
  )
  logging.info(used_features)

  tfe_schema = {}
  for entry in segdense_schema:
    feature_name = entry["feature_name"]

    if feature_name in used_features:
      length = entry["length"]
      dtype = entry["dtype"]

      if feature_name in labels:
        logging.info(f"Label: feature name is {feature_name} type is {dtype}")
        tfe_schema[feature_name] = tf.io.FixedLenFeature(
          length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
        )
      elif length == -1:
        tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
      else:
        tfe_schema[feature_name] = tf.io.FixedLenFeature(
          length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
        )
  for feature_name in used_features:
    if feature_name not in tfe_schema:
      raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
  return tfe_schema
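For reference, a fixed-length (hypothetical) segdense entry maps to a FixedLenFeature, while length == -1 maps to a VarLenFeature:

# {"feature_name": "continuous", "dtype": "float_list", "length": 100}
#   -> tf.io.FixedLenFeature(100, tf.float32, [0.0] * 100)
# {"feature_name": "some_ids", "dtype": "int64_list", "length": -1}
#   -> tf.io.VarLenFeature(tf.int64)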
@functools.lru_cache(1)
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
"""For experimentating with emulating bfloat16 or less precise types."""
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
"""For experimentating with emulating bfloat16 or less precise types."""
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
"""For experimentating with emulating bfloat16 or less precise types."""
mask: tf.Tensor = make_mantissa_mask(mask_length)
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
"""For experimentating with emulating bfloat16 or less precise types."""
mask: tf.Tensor = make_mantissa_mask(mask_length)
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
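As an illustrative note, masking 16 mantissa bits of a float32 leaves 7 mantissa bits, which roughly emulates bfloat16 precision while keeping the float32 exponent:

x = tf.constant([0.1234567], dtype=tf.float32)
emulated = mask_mantissa(x, 16)  # same dtype and shape, reduced mantissa precision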
def parse_tf_example(
  serialized_example,
  tfe_schema,
  seg_dense_schema_config,
):
  """Parse serialized tf.Example into dict of tensors.

  Args:
    serialized_example: Serialized tf.Example to be parsed.
    tfe_schema: Dictionary schema suitable for deserializing tf.Example.

  Returns:
    Dictionary of tensors to be used as model input.
  """
  inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)

  for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
    inputs[new_feature_name] = inputs.pop(old_feature_name)

  # This should not actually be used except for experimentation with low precision floats.
  if "mask_mantissa_features" in seg_dense_schema_config:
    for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
      inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)

  # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
  # at TF level.
  # We should not return empty tensors if we don't use embeddings.
  # Otherwise, it breaks numpy->pt conversion
  renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
  for renamed_key in renamed_keys:
    if "embedding" in renamed_key and (renamed_key not in inputs):
      inputs[renamed_key] = tf.zeros([], tf.float32)

  logging.info(f"parsed example and inputs are {inputs}")
  return inputs
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
"""Placeholder for seg dense.
"""Placeholder for seg dense.
In the future, when we use more seg dense variations, we can change this.
"""
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
In the future, when we use more seg dense variations, we can change this.
"""
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
tf_example_schema = create_tf_example_schema(
data_config,
seg_dense_schema,
)
tf_example_schema = create_tf_example_schema(
data_config,
seg_dense_schema,
)
logging.info("***** TF Example Schema *****")
logging.info(tf_example_schema)
logging.info("***** TF Example Schema *****")
logging.info(tf_example_schema)
parse = functools.partial(
parse_tf_example,
tfe_schema=tf_example_schema,
seg_dense_schema_config=data_config.seg_dense_schema,
)
return parse
parse = functools.partial(
parse_tf_example,
tfe_schema=tf_example_schema,
seg_dense_schema_config=data_config.seg_dense_schema,
)
return parse

View File

@ -6,115 +6,160 @@ import tensorflow as tf
def keyed_tensor_from_tensors_dict(
  tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
  """
  Convert a dictionary of torch tensors to a torchrec KeyedTensor.

  Args:
    tensor_map: A mapping of tensor names to torch tensors.

  Returns:
    A torchrec KeyedTensor.
  """
  keys = list(tensor_map.keys())
  # We expect batch size to be first dim. However, if we get a shape [Batch_size],
  # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
  # [Batch_size x 1].
  values = [
    tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
    for key in keys
  ]
  return torchrec.KeyedTensor.from_tensor_list(keys, values)
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Compute a jagged tensor from a torch tensor.

  Args:
    tensor: Input torch tensor.

  Returns:
    A tuple containing the values and lengths of the jagged tensor.
  """
  if tensor.is_sparse:
    x = tensor.coalesce()  # Ensure that the indices are ordered.
    lengths = torch.bincount(x.indices()[0])
    values = x.values()
  else:
    values = tensor
    lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
  return values, lengths
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
"""
Convert a torch tensor to torchrec jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
dense_shape of the sparse tensor can be arbitrary.
Args:
tensor: a torch (sparse) tensor.
Returns:
"""
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths)
"""
Convert a torch tensor to a torchrec jagged tensor.
Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
Args:
tensor: A torch (sparse) tensor.
Returns:
A torchrec JaggedTensor.
"""
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths)
def keyed_jagged_tensor_from_tensors_dict(
  tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
  """
  Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.

  Note: Currently, this function only supports input tensors with shapes of [Batch_size] or
  [Batch_size x 1] for dense tensors. For sparse tensors, the shape of .values() should be
  [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.

  Args:
    tensor_map: A mapping of tensor names to torch tensors.

  Returns:
    A torchrec KeyedJaggedTensor.
  """
  if not tensor_map:
    return torchrec.KeyedJaggedTensor(
      keys=[],
      values=torch.zeros(0, dtype=torch.int),
      lengths=torch.zeros(0, dtype=torch.int),
    )
  values = []
  lengths = []
  for tensor in tensor_map.values():
    tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
    values.append(torch.squeeze(tensor_val))
    lengths.append(tensor_len)

  values = torch.cat(values, axis=0)
  lengths = torch.cat(lengths, axis=0)

  return torchrec.KeyedJaggedTensor(
    keys=list(tensor_map.keys()),
    values=values,
    lengths=lengths,
  )
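An illustrative conversion (values made up): a dense [Batch_size] tensor contributes one value per row with lengths of 1:

kjt = keyed_jagged_tensor_from_tensors_dict({"clicks": torch.tensor([3, 7])})
# kjt.keys() == ["clicks"], kjt.values() == [3, 7], kjt.lengths() == [1, 1]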
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
  """
  Convert a TensorFlow tensor to a NumPy array.

  Args:
    tf_tensor: TensorFlow tensor.

  Returns:
    NumPy array.
  """
  return tf_tensor._numpy()  # noqa
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
"""
Convert a dense TensorFlow tensor to a PyTorch tensor.
Args:
tensor: Dense TensorFlow tensor.
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
tensor = _tf_to_numpy(tensor)
# PyTorch does not support bfloat16 here; upcast to float32 to keep the same number of exponent bits.
if tensor.dtype.name == "bfloat16":
tensor = tensor.astype(np.float32)
tensor = torch.from_numpy(tensor)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
def sparse_or_dense_tf_to_torch(
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
"""
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
Args:
tensor: TensorFlow tensor (sparse or dense).
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
if isinstance(tensor, tf.SparseTensor):
tensor = torch.sparse_coo_tensor(
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
_dense_tf_to_torch(tensor.values, pin_memory),
torch.Size(_tf_to_numpy(tensor.dense_shape)),
)
else:
tensor = _dense_tf_to_torch(tensor, pin_memory)
return tensor
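A minimal sketch of the helper above on both tensor kinds, assuming eager TensorFlow; pin_memory is left off so the snippet runs without CUDA.
```python
import tensorflow as tf

dense = tf.constant([[1.0, 0.0], [0.0, 2.0]])
sparse = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[2, 2])

dense_pt = sparse_or_dense_tf_to_torch(dense, pin_memory=False)    # dense torch.Tensor, shape [2, 2]
sparse_pt = sparse_or_dense_tf_to_torch(sparse, pin_memory=False)  # torch sparse COO tensor, shape [2, 2]
print(sparse_pt.to_dense())
```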

View File

@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
class EmbeddingOptimizerConfig(base_config.BaseConfig):
"""Configuration for the optimizer used for embedding tables."""
learning_rate: optimizer_config.LearningRate = pydantic.Field(
None, description="learning rate scheduler for the EBC"
)
@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):
class StratifierConfig(base_config.BaseConfig):
"""Configuration for Stratifier."""
name: str
index: int
value: int
@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
serving time due to size (>>1 GB).
This small embeddings table uses the same optimizer as the rest of the model.
"""
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
..., description="list of embedding tables"

View File

@ -34,6 +34,33 @@ FLAGS = flags.FLAGS
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
"""
Main function to run the training of a ranking model.
This function initializes and runs the training process for a ranking model based on the provided configuration.
Args:
unused_argv (str): Unused argument.
data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).
Returns:
None
Raises:
AssertionError: If the configuration or input data is not valid.
Note:
This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
training loop.
Example:
To run the training process, use the following command:
```
python run_training.py --config_path=config.yaml
```
"""
print("#" * 100)
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)

View File

@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DenseLayerConfig(base_config.BaseConfig):
"""Configuration for the dense layer."""
layer_size: pydantic.PositiveInt
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DoubleNormLogConfig(base_config.BaseConfig):
"""Configuration for the double norm log transform."""
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):
class ClipLog1pAbsConfig(base_config.BaseConfig):
"""Configuration for the clip log transform."""
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
3e38, description="Threshold to clip the input values."
)
class ZScoreLogConfig(base_config.BaseConfig):
"""Configuration for the z-score log transform."""
analysis_path: str
schema_path: str = pydantic.Field(
None,
@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):
class MaskBlockConfig(base_config.BaseConfig):
"""Config for MaskNet block."""
output_size: int
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
None, one_of="aggregation_size"
@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):
class MaskNetConfig(base_config.BaseConfig):
"""Config for MaskNet model."""
mask_blocks: List[MaskBlockConfig]
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):
class DLRMConfig(base_config.BaseConfig):
"""Config for DLRM model."""
bottom_mlp: MlpConfig = pydantic.Field(
...,
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):
class TaskModel(base_config.BaseConfig):
"""Configuration for a single task."""
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):
class MultiTaskType(str, enum.Enum):
"""Type of multi task architecture."""
SHARE_NONE = "share_none" # Tasks are separate.
SHARE_ALL = "share_all" # Tasks share same backbone.
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):
@pydantic.root_validator()
def _validate_mtl(cls, values):
"""Validate the multi task architecture."""
if values.get("multi_task_type", None) is None:
return values
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:

View File

@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
""" "Builds a model for a single task"""
"""
Build a model for a single task based on the provided configuration.
Args:
task (model_config_mod.TaskModel): The task model configuration.
input_shape (int): The input shape for the model.
Returns:
torch.nn.Module: The constructed model for the single task.
Raises:
ValueError: If the task configuration is not recognized.
"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:
@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)
class MultiTaskRankingModel(torch.nn.Module):
"""Multi-task ranking model."""
"""
Multi-task ranking model that handles multiple ranking tasks simultaneously.
This model takes various input features and predicts rankings for multiple
tasks using shared or separate towers.
"""
def __init__(
self,
@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module):
data_config: RecapDataConfig,
return_backbone: bool = False,
):
"""Constructor for Multi task learning.
"""
Constructor for Multi-task ranking model.
Args:
input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
config (ModelConfig): The model configuration.
data_config (RecapDataConfig): The data configuration.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Assumptions made:
1. Tasks specified in data config match model architecture.
These are all validated in config.
"""
super().__init__()
@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
labels: Optional[torch.Tensor] = None,
weights: Optional[torch.Tensor] = None,
):
"""
Forward pass of the Multi-task ranking model.
Args:
continuous_features (torch.Tensor): Continuous input features.
binary_features (torch.Tensor): Binary input features.
discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
sparse_features ([type], optional): Sparse input features. Defaults to None.
user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
"""
concat_dense_features = [
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
]
@ -270,6 +310,20 @@ def create_ranking_model(
data_config=None,
return_backbone=False,
):
"""
Creates a ranking model based on the provided specifications and configuration.
Args:
data_spec: The input data specifications.
config (config_mod.RecapConfig): The model configuration.
device (torch.device): The device where the model should be placed.
loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
data_config: The data configuration. Defaults to None.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Returns:
torch.nn.Module: The created ranking model.
"""
if list(config.model.tasks.values())[0].dlrm_config:
raise NotImplementedError()

View File

@ -11,21 +11,52 @@ import torch
def log_transform(x: torch.Tensor) -> torch.Tensor:
"""Safe log transform that works across both negative, zero, and positive floats."""
"""
Safe log transform that works across both negative, zero, and positive floats.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Sign-preserving transform, computed as sign(x) * log1p(|x|).
"""
return torch.sign(x) * torch.log1p(torch.abs(x))
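A quick illustrative check (not from the diff) that the transform is finite at zero and symmetric in sign; printed values are approximate.
```python
import torch

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
print(log_transform(x))
# tensor([-2.3979, -0.6931,  0.0000,  0.6931,  2.3979])
```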
class BatchNorm(torch.nn.Module):
def __init__(self, num_features: int, config: BatchNormConfig):
"""
Batch normalization layer.
Args:
num_features (int): Number of input features.
config (BatchNormConfig): Configuration for batch normalization.
"""
super().__init__()
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the batch normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after batch normalization.
"""
return self.layer(x)
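A minimal sketch of the wrapper above; the BatchNormConfig field values are assumptions, though affine and momentum are the fields read in the constructor.
```python
import torch

bn = BatchNorm(num_features=8, config=BatchNormConfig(affine=True, momentum=0.1))
out = bn(torch.randn(4, 8))   # normalized over the batch dimension, shape [4, 8]
```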
class LayerNorm(torch.nn.Module):
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
"""
Layer normalization layer.
Args:
normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
config (LayerNormConfig): Configuration for layer normalization.
"""
super().__init__()
if config.axis != -1:
raise NotImplementedError
@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the layer normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after layer normalization.
"""
return self.layer(x)
@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that applies a log transformation to the input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log applied to absolute values.
"""
return log_transform(x)
class InputNonFinite(torch.nn.Module):
def __init__(self, fill_value: float = 0):
"""
Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.
Args:
fill_value (float): The value to fill non-finite elements with. Default is 0.
"""
super().__init__()
self.register_buffer(
@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that replaces non-finite values in the input tensor with the specified fill value.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with non-finite values replaced.
"""
return torch.where(torch.isfinite(x), x, self.fill_value)
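An illustrative sketch of the replacement behaviour, assuming the default fill_value of 0.
```python
import torch

layer = InputNonFinite()
x = torch.tensor([1.0, float("nan"), float("inf"), -2.0])
print(layer(x))   # tensor([ 1.,  0.,  0., -2.])
```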
class Clamp(torch.nn.Module):
def __init__(self, min_value: float, max_value: float):
"""
Applies element-wise clamping to a tensor, ensuring that values are within a specified range.
Args:
min_value (float): The minimum value to clamp elements to.
max_value (float): The maximum value to clamp elements to.
"""
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that clamps the input tensor element-wise within the specified range.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with elements clamped within the specified range.
"""
return torch.clamp(x, min=self.min_value, max=self.max_value)
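A short sketch of the clamp behaviour with assumed bounds.
```python
import torch

clamp = Clamp(min_value=-3.0, max_value=3.0)
print(clamp(torch.tensor([-10.0, 0.5, 7.0])))   # tensor([-3.0000,  0.5000,  3.0000])
```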
class DoubleNormLog(torch.nn.Module):
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
"""
Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.
Args:
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.
Attributes:
_before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
batch normalization (optional), and clamping.
layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
"""
def __init__(
self,
input_shapes: Mapping[str, Sequence[int]],
@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
def forward(
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
) -> torch.Tensor:
"""
Forward pass that processes continuous and binary features using batch normalization, log transformation,
optional batch normalization (if configured), clamping, and layer normalization (if configured).
Args:
continuous_features (torch.Tensor): Input tensor of continuous features.
binary_features (torch.Tensor): Input tensor of binary features.
Returns:
torch.Tensor: Transformed tensor containing both continuous and binary features.
"""
x = self._before_concat_layers(continuous_features)
x = torch.cat([x, binary_features], dim=1)
if self.layer_norm:
@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
def build_features_preprocessor(
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
):
"""Trivial right now, but we will change in the future."""
"""
Build a feature preprocessor module based on the provided configuration.
Trivial right now, but we will change in the future.
Args:
config (FeaturizationConfig): Configuration for feature preprocessing.
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
Returns:
DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
"""
return DoubleNormLog(input_shapes, config.double_norm_log_config)

View File

@ -6,93 +6,243 @@ import torch
def _init_weights(module):
"""Initializes weights
Example
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module):
"""
MaskBlock module in a mask-based neural network.
This module represents a MaskBlock, which applies a masking operation to the input data and then
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Example:
To create and use a MaskBlock within a MaskNet, follow these steps:
```python
# Define the configuration for the MaskBlock
mask_block_config = MaskBlockConfig(
input_layer_norm=True,  # Apply input layer normalization
reduction_factor=0.5  # Reduce input dimensionality by 50%
)
# Create an instance of the MaskBlock
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
mask_input = torch.randn(batch_size, 32)
# Perform a forward pass through the MaskBlock
output = mask_block(input_data, mask_input)
```
Note:
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
operation that combines the input and mask input. Then, it passes the result through a hidden layer
with optional dimensionality reduction.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None:
"""
Initializes the MaskBlock module.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Returns:
None
"""
super(MaskBlock, self).__init__()
self.mask_block_config = mask_block_config
output_size = mask_block_config.output_size
if mask_block_config.input_layer_norm:
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
else:
self._input_layer_norm = None
if mask_block_config.reduction_factor:
aggregation_size = int(
mask_input_dim * mask_block_config.reduction_factor)
elif mask_block_config.aggregation_size is not None:
aggregation_size = mask_block_config.aggregation_size
else:
raise ValueError(
"Need one of reduction factor or aggregation size.")
self._mask_layer = torch.nn.Sequential(
torch.nn.Linear(mask_input_dim, aggregation_size),
torch.nn.ReLU(),
torch.nn.Linear(aggregation_size, input_dim),
)
self._mask_layer.apply(_init_weights)
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
self._hidden_layer.apply(_init_weights)
self._layer_norm = torch.nn.LayerNorm(output_size)
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
"""
Performs a forward pass through the MaskBlock.
Args:
net (torch.Tensor): Input data tensor.
mask_input (torch.Tensor): Mask input tensor.
Returns:
torch.Tensor: Output tensor of the MaskBlock.
"""
if self._input_layer_norm:
net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(
net * self._mask_layer(mask_input))
return self._layer_norm(hidden_layer_output)
class MaskNet(torch.nn.Module):
"""
MaskNet module in a mask-based neural network.
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
create mask-based neural networks with parallel or stacked MaskBlocks.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Example:
To create and use a MaskNet, you can follow these steps:
```python
# Define the configuration for the MaskNet
mask_net_config = MaskNetConfig(
use_parallel=True,  # Use parallel MaskBlocks
mlp=MlpConfig(layer_sizes=[128, 64])  # Optional MLP on the outputs
)
# Create an instance of the MaskNet
mask_net = MaskNet(mask_net_config, in_features=64)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
# Perform a forward pass through the MaskNet
outputs = mask_net(input_data)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
"""
Initializes the MaskNet module.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Returns:
None
"""
super().__init__()
self.mask_net_config = mask_net_config
mask_blocks = []
if mask_net_config.use_parallel:
total_output_mask_blocks = 0
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(
MaskBlock(mask_block_config, in_features, in_features))
total_output_mask_blocks += mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
else:
input_size = in_features
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(
MaskBlock(mask_block_config, input_size, in_features))
input_size = mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
total_output_mask_blocks = mask_block_config.output_size
if mask_net_config.mlp:
self._dense_layers = mlp.Mlp(
total_output_mask_blocks, mask_net_config.mlp)
self.out_features = mask_net_config.mlp.layer_sizes[-1]
else:
self.out_features = total_output_mask_blocks
self.shared_size = total_output_mask_blocks
def forward(self, inputs: torch.Tensor):
"""
Performs a forward pass through the MaskNet.
Args:
inputs (torch.Tensor): Input data tensor.
Returns:
torch.Tensor: Output tensor of the MaskNet.
"""
if self.mask_net_config.use_parallel:
mask_outputs = []
for mask_layer in self._mask_blocks:
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
# Share the outputs of the MaskBlocks.
all_mask_outputs = torch.cat(mask_outputs, dim=1)
output = (
all_mask_outputs
if self.mask_net_config.mlp is None
else self._dense_layers(all_mask_outputs)["output"]
)
return {"output": output, "shared_layer": all_mask_outputs}
else:
net = inputs
for mask_layer in self._mask_blocks:
net = mask_layer(net=net, mask_input=inputs)
# Share the output of the stacked MaskBlocks.
output = net if self.mask_net_config.mlp is None else self._dense_layers(net)["output"]
return {"output": output, "shared_layer": net}

View File

@ -7,13 +7,83 @@ from absl import logging
def _init_weights(module):
"""Initializes weights
Example
-------
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class Mlp(torch.nn.Module):
"""
Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.
This module defines an MLP with customizable layers and activation functions. It is suitable for various
applications such as deep learning for tabular data, feature extraction, and more.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Example:
To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:
```python
# Define the configuration for the MLP
mlp_config = MlpConfig(
layer_sizes=[128, 64], # Specify the sizes of hidden layers
batch_norm=True, # Enable batch normalization
dropout=0.2, # Apply dropout with a rate of 0.2
final_layer_activation=True # Apply ReLU activation to the final layer
)
# Create an instance of the MLP module
mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)
# Generate an input tensor
input_tensor = torch.randn(batch_size, input_dim)
# Perform a forward pass through the MLP
outputs = mlp_model(input_tensor)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
enabling batch normalization and dropout, and choosing the activation function for the final layer.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, in_features: int, mlp_config: MlpConfig):
"""
Initializes the Mlp module.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Returns:
None
"""
super().__init__()
self._mlp_config = mlp_config
input_size = in_features
@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
self.layers.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs a forward pass through the MLP.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: Output tensor of the MLP.
"""
net = x
for i, layer in enumerate(self.layers):
net = layer(net)
@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):
@property
def shared_size(self):
"""
Returns the size of the shared layer in the MLP.
Returns:
int: Size of the shared layer.
"""
return self._mlp_config.layer_sizes[-1]
@property
def out_features(self):
"""
Returns the number of output features from the MLP.
Returns:
int: Number of output features.
"""
return self._mlp_config.layer_sizes[-1]

View File

@ -5,63 +5,117 @@ from absl import logging
class ModelAndLoss(torch.nn.Module):
"""
PyTorch module that combines a neural network model and loss function.
This module wraps a neural network model and facilitates the forward pass through the model
while also calculating the loss based on the model's predictions and provided labels.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification. Each stratifier config includes the name and index of discrete features
to emit for stratification.
Example:
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
and loss function as arguments:
```python
# Create a neural network model
model = YourNeuralNetworkModel()
# Define a loss function
loss_fn = torch.nn.CrossEntropyLoss()
# Create an instance of ModelAndLoss
model_and_loss = ModelAndLoss(model, loss_fn)
# Generate a batch of training data (e.g., RecapBatch)
batch = generate_training_batch()
# Perform a forward pass through the model and calculate the loss
loss, outputs = model_and_loss(batch)
# You can now backpropagate and optimize using the computed loss
loss.backward()
optimizer.step()
```
Note:
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
calculating loss, making it easier to integrate the model into your training loop. Additionally,
it supports the addition of stratifiers for metrics stratification, if needed.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
model,
loss_fn: Callable,
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
) -> None:
"""
Initializes the ModelAndLoss module.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.stratifiers = stratifiers
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(
continuous_features=batch.continuous_features,
binary_features=batch.binary_features,
discrete_features=batch.discrete_features,
sparse_features=batch.sparse_features,
user_embedding=batch.user_embedding,
user_eng_embedding=batch.user_eng_embedding,
author_embedding=batch.author_embedding,
labels=batch.labels,
weights=batch.weights,
)
losses = self.loss_fn(
outputs["logits"], batch.labels.float(), batch.weights.float())
if self.stratifiers:
logging.info(
f"***** Adding stratifiers *****\n {self.stratifiers}")
outputs["stratifiers"] = {}
for stratifier in self.stratifiers:
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
stratifier.index]
# In general, we can have a large number of losses returned by our loss function.
if isinstance(losses, dict):
return losses["loss"], {
**outputs,
**losses,
"labels": batch.labels,
"weights": batch.weights,
}
else: # Assume that this is a float.
return losses, {
**outputs,
"loss": losses,
"labels": batch.labels,
"weights": batch.weights,
}

View File

@ -2,18 +2,65 @@ import torch
class NumericCalibration(torch.nn.Module):
"""
Numeric calibration module for adjusting probability scores.
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
for tasks such as binary classification.
Args:
pos_downsampling_rate (float): The downsampling rate for positive samples.
neg_downsampling_rate (float): The downsampling rate for negative samples.
Example:
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
scores like this:
```python
# Create a NumericCalibration instance with downsampling rates
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
# Generate probability scores (e.g., from a neural network)
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
# Apply numeric calibration to adjust the probabilities
calibrated_probs = calibration(raw_probs)
# The `calibrated_probs` now contains the adjusted probability scores
```
Note:
The `NumericCalibration` module is used to adjust probability scores to account for differences in
the number of positive and negative samples in a dataset. It can help improve the calibration of
probability estimates in imbalanced classification problems.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(
self,
pos_downsampling_rate: float,
neg_downsampling_rate: float,
):
"""
Apply numeric calibration to probability scores.
Args:
probs (torch.Tensor): Probability scores to be calibrated.
Returns:
torch.Tensor: Calibrated probability scores.
"""
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
self.register_buffer(
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
)
def forward(self, probs: torch.Tensor):
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
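A numeric illustration (not part of the diff) of the ratio correction in forward: with positives kept at 10% and negatives at 20%, ratio = 0.2 / 0.1 = 2, so raw scores are pushed upward to undo the relative under-representation of positives in the downsampled training data.
```python
import torch

calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
raw = torch.tensor([0.10, 0.50, 0.90])
print(calibration(raw))
# tensor([0.1818, 0.6667, 0.9474])  e.g. 0.5 * 2 / (1 - 0.5 + 2 * 0.5) = 2/3
```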

View File

@ -9,12 +9,60 @@ import pydantic
class RecapAdamConfig(base_config.BaseConfig):
"""
Configuration settings for the Adam optimizer used in Recap.
Args:
beta_1 (float): Momentum term (default: 0.9).
beta_2 (float): Exponential weighted decay factor (default: 0.999).
epsilon (float): Numerical stability in the denominator (default: 1e-7).
Example:
To define an Adam optimizer configuration for Recap, use:
```python
adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
```
Note:
This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator.
class MultiTaskLearningRates(base_config.BaseConfig):
"""
Configuration settings for multiple learning rates in Recap.
Args:
tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).
Example:
To define multiple learning rates for different towers in Recap, use:
```python
multi_task_lr = MultiTaskLearningRates(
tower_learning_rates={
'task1': learning_rate1,
'task2': learning_rate2,
},
backbone_learning_rate=backbone_lr,
)
```
Note:
This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model."
)
@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):
class RecapOptimizerConfig(base_config.BaseConfig):
"""
Configuration settings for the Recap optimizer.
Args:
multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
adam (RecapAdamConfig): Configuration settings for the Adam optimizer.
Example:
To define an optimizer configuration for training with Recap, use:
```python
optimizer_config = RecapOptimizerConfig(
multi_task_learning_rates=multi_task_lr,
single_task_learning_rate=single_task_lr,
adam=adam_config,
)
```
Warning:
This class is intended for internal use to configure the optimizer settings within Recap and should not be
directly accessed by external code.
"""
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr"
)

View File

@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
A shim to get learning rates into a LRScheduler.
This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
learning rate schedules, such as exponential decay.
Args:
optimizer: The optimizer to which this scheduler is applied.
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
emb_learning_rate: The learning rate for embeddings (optional).
last_epoch (int): The index of the last epoch (default: -1).
verbose (bool): If True, print warnings for deprecated functions (default: False).
Example:
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
```python
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
```
Warning:
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
be directly accessed by external code.
"""
def __init__(
self,
@ -80,15 +98,25 @@ def build_optimizer(
optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
):
"""Builds an optimizer and scheduler.
Args:
model: A torch model, probably with DDP/DMP.
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
Returns:
A torch.optim instance, and a scheduler instance.
"""
Build an optimizer and scheduler for training.
Args:
model: The torch model, possibly with DDP/DMP.
optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
emb_optimizer_config: Configuration settings for embedding optimization (optional).
Returns:
torch.optim.Optimizer: The optimizer for training.
RecapLRShim: The learning rate scheduler for the optimizer.
Example:
To build an optimizer and scheduler for training, use:
```python
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
```
"""
optimizer_fn = functools.partial(
torch.optim.Adam,
lr=_DEFAULT_LR,

View File

@ -4,6 +4,17 @@ import pydantic
class TwhinDataConfig(base_config.BaseConfig):
"""
Configuration for Twhin model training data.
Args:
data_root (str): The root directory for the training data.
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
global_negatives (int): The number of global negatives.
in_batch_negatives (int): The number of in-batch negatives.
limit (pydantic.PositiveInt): The limit on the number of data points to use.
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
"""
data_root: str
per_replica_batch_size: pydantic.PositiveInt
global_negatives: int

View File

@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
"""
Create a dataset for Twhin model training.
Args:
data_config (TwhinDataConfig): The data configuration for the dataset.
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
Returns:
EdgesDataset: The dataset for Twhin model training.
"""
tables = model_config.embeddings.tables
table_sizes = {table.name: table.num_embeddings for table in tables}
relations = model_config.relations

View File

@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@dataclass
class EdgeBatch(DataclassBatch):
"""
Batch data structure for edge-based models.
Args:
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
labels (torch.Tensor): Tensor containing labels.
rels (torch.Tensor): Tensor containing relation information.
weights (torch.Tensor): Tensor containing weights.
"""
nodes: KeyedJaggedTensor
labels: torch.Tensor
rels: torch.Tensor
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
class EdgesDataset(Dataset):
"""
Dataset for edge-based models.
Args:
file_pattern (str): The file pattern for the dataset.
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
relations (List[Relation]): A list of relations between tables.
lhs_column_name (str): The name of the left-hand-side column.
rhs_column_name (str): The name of the right-hand-side column.
rel_column_name (str): The name of the relation column.
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
"""
rng = np.random.default_rng()
def __init__(
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
def pa_to_batch(self, batch: pa.RecordBatch):
"""
Converts a pyarrow RecordBatch to an EdgeBatch.
Args:
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
Returns:
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
"""
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
"""Process edges that contain lhs index, rhs index, relation index.
Args:
lhs (torch.Tensor): Tensor containing left-hand-side indices.
rhs (torch.Tensor): Tensor containing right-hand-side indices.
rel (torch.Tensor): Tensor containing relation indices.
Returns:
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
Example:
```
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
def to_batches(self):
"""
Converts data to batches.
Yields:
pa.RecordBatch: A pyarrow RecordBatch containing data.
"""
ds = super().to_batches()
batch_size = self._dataset_kwargs["batch_size"]

View File

@ -23,6 +23,25 @@ RELATIONS = [
def test_gen():
"""Test function for generating edge-based datasets and dataloaders.
This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
instance and a dataloader for it.
The test includes the following steps:
1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
2. Write the synthetic data to a Parquet file.
3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
4. Initialize the local file system for the dataset.
5. Create a dataloader for the dataset and retrieve the first batch.
6. Assert that the labels in the batch are positive.
7. Verify that the positive examples in the batch match the expected values.
This function serves as a test case for the data generation and dataset creation process.
Raises:
AssertionError: If any of the test assertions fail.
"""
import os
import tempfile

View File

@ -10,8 +10,29 @@ from pydantic import validator
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
"""
Configuration class for Twhin model embeddings.
This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
for all tables in the Twhin model embeddings configuration match.
Attributes:
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
"""
@validator("tables")
def embedding_dims_match(cls, tables):
"""
Validate that embedding dimensions and data types match for all tables.
Args:
tables (List[TableConfig]): List of table configurations.
Returns:
List[TableConfig]: The list of validated table configurations.
Raises:
AssertionError: If embedding dimensions or data types do not match.
"""
embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type
for table in tables:
@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
class Operator(str, enum.Enum):
"""
Enumeration of operator types.
This enumeration defines different types of operators that can be applied to Twhin model relations.
"""
TRANSLATION = "translation"
class Relation(pydantic.BaseModel):
"""graph relationship properties and operator"""
"""
Configuration class for graph relationships in the Twhin model.
This class defines properties and operators for graph relationships in the Twhin model.
Attributes:
name (str): The name of the relationship.
lhs (str): The name of the entity on the left-hand side of the relation.
rhs (str): The name of the entity on the right-hand side of the relation.
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
"""
name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field(
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):
class TwhinModelConfig(base_config.BaseConfig):
"""
Configuration class for the Twhin model.
This class defines configuration options specific to the Twhin model.
Attributes:
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
relations (List[Relation]): List of graph relationship configurations.
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
"""
embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation]
translation_optimizer: OptimizerConfig
@validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs):
"""
Validate that the specified node types in relations are valid table names in embeddings.
Args:
relation (Relation): A single relation configuration.
values (dict): The values dictionary containing the "embeddings" configuration.
Returns:
Relation: The validated relation configuration.
Raises:
AssertionError: If the specified node types are not valid table names in embeddings.
"""
table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"

View File

@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa
class TwhinModel(nn.Module):
"""
Twhin model for graph-based entity embeddings and translation.
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
and applying translations to these embeddings based on graph relationships.
Args:
model_config (TwhinModelConfig): Configuration for the Twhin model.
data_config (TwhinDataConfig): Configuration for the data used by the model.
Attributes:
batch_size (int): The batch size used for training.
table_names (List[str]): Names of tables in the model's embeddings.
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
embedding_dim (int): Dimensionality of entity embeddings.
num_tables (int): Number of tables in the model's embeddings.
in_batch_negatives (int): Number of in-batch negative samples to use during training.
global_negatives (int): Number of global negative samples to use during training.
num_relations (int): Number of graph relationships in the model.
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
"""
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
)
def forward(self, batch: EdgeBatch):
"""
Forward pass of the Twhin model.
Args:
batch (EdgeBatch): Input batch containing graph edge information.
Returns:
dict: A dictionary containing model output with "logits" and "probabilities".
- "logits" (torch.Tensor): Logit scores.
- "probabilities" (torch.Tensor): Sigmoid probabilities.
"""
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
@ -98,6 +130,18 @@ class TwhinModel(nn.Module):
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
"""
Apply optimizers to the Twhin model's embeddings.
This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.
Args:
model (TwhinModel): The Twhin model to apply optimizers to.
model_config (TwhinModelConfig): Configuration for the Twhin model.
Returns:
TwhinModel: The Twhin model with optimizers applied to its embeddings.
"""
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
device: torch.device,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
Initialize a TwhinModelAndLoss module.
Args:
model: The torch module to wrap.
loss_fn: A function for calculating loss, should accept logits and labels.
data_config: Configuration for Twhin data.
device: The torch device to use for calculations.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
Run the model forward and calculate the loss according to the given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types
Args:
batch ("RecapBatch"): The input batch for model inference.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
additional outputs including logits, labels, and weights.
"""
outputs = self.model(batch)
logits = outputs["logits"]

View File

@ -18,6 +18,12 @@ EMB_DIM = 128
def twhin_model_config() -> TwhinModelConfig:
"""
Create a configuration for the Twhin model.
Returns:
TwhinModelConfig: The Twhin model configuration.
"""
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
def twhin_data_config() -> TwhinDataConfig:
"""
Create a configuration for the Twhin data.
Returns:
TwhinDataConfig: The Twhin data configuration.
"""
data_config = TwhinDataConfig(
data_root="/",
per_replica_batch_size=10,
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
def test_twhin_model():
"""
Test the Twhin model creation and optimization.
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
the device placement of model parameters.
Returns:
None
"""
model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits
@ -84,6 +105,23 @@ def test_twhin_model():
def test_unequal_dims():
"""
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.
This function tests whether the validation logic correctly raises a `ValidationError` when
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.
The test includes the following steps:
1. Create two embedding configurations with different embedding dimensions.
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.
This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
in the `TwhinEmbeddingsConfig` for all tables.
Raises:
AssertionError: If the expected `ValidationError` is not raised.
"""
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig(

View File

@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
def _lr_from_config(optimizer_config):
"""Get the learning rate from an optimizer configuration.
Args:
optimizer_config: Optimizer configuration.
Returns:
Learning rate from the optimizer configuration.
"""
if optimizer_config.learning_rate is not None:
return optimizer_config.learning_rate
else:
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
Args:
model: TwhinModel to build optimizer for.
config: TwhinConfig for model.
Args:
model: TwhinModel to build optimizer for.
config: TwhinModelConfig for model.
Returns:
Optimizer for model.
"""
Returns:
Optimizer for model.
"""
translation_optimizer_fn = functools.partial(
get_optimizer_class(config.translation_optimizer),
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),

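To make the fallback that `_lr_from_config` describes concrete, here is a hedged sketch: prefer an explicit `learning_rate` on the optimizer config, otherwise fall back to the lr embedded in the algorithm config. The classes below are simplified stand-ins, not the repo's real config types.

```python
# Hedged sketch of the learning-rate fallback; class names are stand-ins.
from dataclasses import dataclass
from typing import Optional

@dataclass
class SgdConfig:
  lr: float

@dataclass
class OptimizerConfig:
  sgd: SgdConfig
  learning_rate: Optional[float] = None

def lr_from_config(cfg: OptimizerConfig) -> float:
  if cfg.learning_rate is not None:
    return cfg.learning_rate
  return cfg.sgd.lr  # fall back to the algorithm-level lr

print(lr_from_config(OptimizerConfig(sgd=SgdConfig(lr=0.01))))                      # 0.01
print(lr_from_config(OptimizerConfig(sgd=SgdConfig(lr=0.01), learning_rate=0.1)))   # 0.1
```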
View File

@@ -37,6 +37,12 @@ def run(
all_config: TwhinConfig,
save_dir: Optional[str] = None,
):
"""Run the training process for TwhinModel.
Args:
all_config (TwhinConfig): The configuration for the entire Twhin model.
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
"""
train_dataset = create_dataset(all_config.train_data, all_config.model)
if env.is_reader():
@@ -80,6 +86,11 @@ def run(
def main(argv):
"""Main entry point for the Twhin training script.
Args:
argv: Command-line arguments.
"""
logging.info("Starting")
logging.info(f"parsing config from {FLAGS.config_yaml_path}...")

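A hedged sketch of the absl-style entry point the docstring describes: a `--config_yaml_path` flag parsed inside `main`. The flag name matches the log line above; `parse_config` and the PyYAML dependency are assumptions for illustration, not code from this change.

```python
# Hedged sketch of an absl entry point; parse_config is a hypothetical helper
# and assumes PyYAML is installed.
from absl import app, flags, logging

FLAGS = flags.FLAGS
flags.DEFINE_string("config_yaml_path", None, "Path to the training config YAML.")

def parse_config(path: str) -> dict:
  import yaml
  with open(path) as f:
    return yaml.safe_load(f)

def main(argv):
  logging.info("Starting")
  logging.info(f"parsing config from {FLAGS.config_yaml_path}...")
  config = parse_config(FLAGS.config_yaml_path)
  # ... build dataset/model from `config` and run training ...

if __name__ == "__main__":
  app.run(main)
```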
View File

@@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging
class _Reader(pa.flight.FlightServerBase):
"""Distributed reader flight server wrapping a dataset."""
"""
Distributed reader flight server wrapping a dataset.
This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
for efficient data access.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
Attributes:
_location (str): The location of the Flight server.
_ds (Dataset): The dataset wrapped by the Flight server.
Methods:
do_get(_, __): Handles Flight requests for data retrieval.
Note:
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
This class allows clients to retrieve data from the dataset using Flight.
"""
def __init__(self, location: str, ds: "Dataset"):
"""
Initialize a new _Reader instance.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
"""
super().__init__(location=location)
self._location = location
self._ds = ds
def do_get(self, _, __):
"""
Handle Flight requests for data retrieval.
This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.
Args:
_: Unused argument.
__: Unused argument.
Returns:
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.
Note:
An updated schema (to account for column selection) must be given to the stream.
"""
# NB: An updated schema (to account for column selection) has to be given the stream.
schema = next(iter(self._ds.to_batches())).schema
batches = self._ds.to_batches()
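A hedged sketch of the client side of this Flight exchange: it assumes a `_Reader`-style server is already listening at the given URI and, as `do_get` above suggests, ignores the ticket contents.

```python
# Hedged sketch of a Flight client pulling record batches from a server like
# the one above. The URI matches Dataset.LOCATION below; the empty ticket is
# an assumption based on do_get ignoring its arguments.
import pyarrow.flight as flight

uri = "grpc://0.0.0.0:2222"
client = flight.connect(uri)
reader = client.do_get(flight.Ticket(b""))
for chunk in reader:            # FlightStreamChunk objects
  record_batch = chunk.data     # pyarrow.RecordBatch
  print(record_batch.num_rows)
```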
@@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):
class Dataset(torch.utils.data.IterableDataset):
"""
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.
This class enables efficient loading of data from Parquet files using PyArrow.
It is designed to be used as an IterableDataset in PyTorch for training and inference.
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
for more details.
Attributes:
LOCATION (str): The default location for the Flight server used for data distribution.
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
_fs: The filesystem object used for file operations.
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
_files (list): A list of file paths matching the glob pattern.
_schema (pa.Schema): The schema of the Parquet dataset.
Methods:
serve(): Start serving the dataset using a Flight server.
to_batches(): Generate batches of data from the Parquet dataset.
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.
Note:
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
to create DataLoader instances for training or inference.
"""
LOCATION = "grpc://0.0.0.0:2222"
def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
"""Specify batch size and column to select for.
"""
Initialize a new Dataset instance. Specify batch size and column to select for.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
"""
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
"""
self._file_pattern = file_pattern
self._fs = infer_fs(self._file_pattern)
self._dataset_kwargs = dataset_kwargs
@@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
self._validate_columns()
def _validate_columns(self):
"""
Validate the specified columns against the dataset schema.
Raises:
Exception: If any specified columns are not found in the dataset schema.
"""
columns = set(self._dataset_kwargs.get("columns", []))
wrong_columns = set(columns) - set(self._schema.names)
if wrong_columns:
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
def serve(self):
"""Start serving the dataset using a Flight server."""
self.reader = _Reader(location=self.LOCATION, ds=self)
self.reader.serve()
def _create_dataset(self):
"""Create a PyArrow dataset for data retrieval."""
return pads.dataset(
source=random.sample(self._files, len(self._files))[0],
format="parquet",
@@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
@abc.abstractmethod
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
"""
Convert a Parquet RecordBatch to a custom data batch.
Args:
batch (pa.RecordBatch): A batch of data from the Parquet dataset.
Returns:
DataclassBatch: A custom data batch used in PyTorch training.
Raises:
NotImplementedError: This method must be implemented in derived classes.
"""
raise NotImplementedError
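A hedged sketch of a concrete subclass wiring `pa_to_batch` to `DataclassBatch`. The import paths, file pattern, and column names are assumptions for illustration rather than code from this change.

```python
# Hedged sketch of a concrete Dataset subclass; import paths, the file pattern,
# and column names are assumptions, not confirmed repo layout.
import pyarrow as pa
import torch

from tml.common.batch import DataclassBatch  # assumed import path
from tml.reader.dataset import Dataset       # assumed import path

class MyParquetDataset(Dataset):
  def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
    MyBatch = DataclassBatch.from_schema("MyBatch", schema=batch.schema)
    return MyBatch(**{
      name: torch.from_numpy(col.to_numpy(zero_copy_only=False))
      for name, col in zip(batch.schema.names, batch.columns)
    })

# Usage sketch:
# ds = MyParquetDataset("data/part-*.parquet", batch_size=1024, columns=["user_id", "label"])
# for b in ds.dataloader(remote=False):
#   ...
```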
def dataloader(self, remote: bool = False):
"""
Create a PyTorch DataLoader for iterating through the dataset.
Args:
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.
Returns:
DataLoader: A PyTorch DataLoader for iterating through the dataset.
Note:
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
"""
if not remote:
return map(self.pa_to_batch, self.to_batches())
readers = get_readers(2)
@@ -117,6 +230,25 @@ GRPC_OPTIONS = [
def get_readers(num_readers_per_worker: int):
"""
Get Flight readers for distributed data loading.
This function retrieves Flight readers for distributed data loading in a PyTorch environment.
Args:
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.
Returns:
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.
Note:
Flight readers are used to fetch data in a distributed manner for efficient data loading.
Example:
To obtain Flight readers, use the following code:
>>> readers = get_readers(num_readers_per_worker=2)
"""
addresses = env.get_flight_server_addresses()
readers = []

View File

@@ -21,6 +21,16 @@ import torch.distributed as dist
def maybe_start_dataset_service():
"""
Start the dataset service if readers are available and required dependencies are met.
This function checks if readers are available and if the required TensorFlow version is >= 2.5.
If both conditions are met and the current environment is the dispatcher or reader, it starts
the TensorFlow dataset service.
Raises:
Exception: If the required TensorFlow version is not met (>= 2.5).
"""
if not env.has_readers():
return
@@ -59,6 +69,24 @@ def maybe_start_dataset_service():
def register_dataset(
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
):
"""
Register a dataset with the distributed dataset service.
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
and job name to all processes in the distributed environment.
Args:
dataset (tf.data.Dataset): The dataset to be registered.
dataset_service (str): The name of the dataset service.
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
Returns:
Tuple[int, str]: A tuple containing the dataset ID and job name.
  Note:
    All ranks must call this function; only the rank 0 process performs the registration, after which the
    dataset ID and job name are broadcast to the other ranks.
"""
if dist.get_rank() == 0:
dataset_id = _register_dataset(
service=dataset_service,
@@ -82,6 +110,23 @@ def distribute_from_dataset_id(
compression: Optional[str] = "AUTO",
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
"""
Distribute a dataset from a registered dataset ID.
This function consumes a dataset from the distributed dataset service using the provided dataset ID
and job name. It also supports prefetching for improved performance.
Args:
dataset_service (str): The name of the dataset service.
dataset_id (int): The ID of the dataset to be consumed.
job_name (Optional[str]): The name of the job associated with the dataset (optional).
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).
Returns:
tf.data.Dataset: The distributed dataset.
"""
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
dataset = _from_dataset_id(
processing_mode="parallel_epochs",
@@ -97,15 +142,28 @@ def distribute_from_dataset_id(
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
"""Torch-compatible and distributed-training-aware dataset service distributor.
- rank 0 process will register the given dataset.
- rank 0 process will broadcast job name and dataset id.
- all rank processes will consume from the same job/dataset.
Without this, dataset workers will try to serve 1 job per rank process and OOM.
"""
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.
This function is used to distribute a dataset in a distributed training environment. It performs the
following steps:
- On the rank 0 process, it registers the given dataset with the distributed dataset service.
- It broadcasts the job name and dataset ID to all rank processes.
- All rank processes then consume the same dataset from the distributed dataset service.
Args:
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
Returns:
tf.data.Dataset: The distributed TensorFlow dataset.
Note:
- If there are no reader processes in the distributed environment, the original dataset is returned
without any distribution.
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
issues caused by each rank process trying to serve one job.
"""
if not env.has_readers():
return dataset
dataset_service = env.get_dds()

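A hedged sketch of the register-then-broadcast protocol described above, using the public tf.data service API and `torch.distributed.broadcast_object_list`. The service address, job name, and the assumption of an already-initialized process group are placeholders, not values from this change.

```python
# Hedged sketch of rank-0-registers, everyone-consumes. Assumes an initialized
# torch.distributed process group; "shared_job" is a placeholder job name.
import tensorflow as tf
import torch.distributed as dist

def share_dataset(dataset: tf.data.Dataset, service: str) -> tf.data.Dataset:
  if dist.get_rank() == 0:
    dataset_id = tf.data.experimental.service.register_dataset(service, dataset)
    payload = [dataset_id.numpy(), "shared_job"]
  else:
    payload = [None, None]
  dist.broadcast_object_list(payload, src=0)  # all ranks now see the same id/job
  dataset_id, job_name = payload
  return tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=service,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec,
    job_name=job_name,
  )
```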
View File

@@ -12,6 +12,17 @@ import torch
def create_dataset(tmpdir):
"""
Create a mock dataset for testing.
This function creates a mock dataset using PyArrow and Parquet for testing purposes.
Args:
tmpdir: A temporary directory where the dataset will be created.
Returns:
MockDataset: A mock dataset for testing.
"""
table = pa.table(
{
@@ -34,6 +45,14 @@ def create_dataset(tmpdir):
def test_dataset(tmpdir):
"""
Test the created dataset.
This function tests the created mock dataset and checks if it behaves as expected.
Args:
tmpdir: A temporary directory used for testing.
"""
ds = create_dataset(tmpdir)
batch = next(iter(ds.dataloader(remote=False)))
assert batch.batch_size == 2
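A hedged sketch of the kind of fixture `create_dataset` builds: write a tiny Parquet table into a temporary directory and hand back its path. Column names and values are illustrative only.

```python
# Hedged sketch of a minimal Parquet fixture; column names/values are made up.
import os
import pyarrow as pa
import pyarrow.parquet as pq

def write_mock_parquet(tmpdir) -> str:
  table = pa.table({
    "continuous": [[0.1] * 4, [0.2] * 4],  # two rows of list-valued features
    "label": [0, 1],
  })
  path = os.path.join(str(tmpdir), "data.parquet")
  pq.write_table(table, path)
  return path
```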
@@ -46,6 +65,14 @@ def test_dataset(tmpdir):
reason="Multiprocessing doesn't work on github yet.",
)
def test_distributed_dataset(tmpdir):
"""
Test the distributed dataset.
This function tests the distributed version of the mock dataset using multiprocessing.
Args:
tmpdir: A temporary directory used for testing.
"""
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
def _client():

View File

@@ -11,11 +11,55 @@ import torch
def roundrobin(*iterables):
"""Round robin through provided iterables, useful for simple load balancing.
Adapted from https://docs.python.org/3/library/itertools.html.
"""
Iterate through provided iterables in a round-robin fashion.
This function takes multiple iterables and returns an iterator that yields elements from
each iterable in a round-robin manner. It continues cycling through the iterables until
all of them are exhausted.
Adapted from https://docs.python.org/3/library/itertools.html.
Args:
*iterables: One or more iterable objects to iterate through.
Yields:
Elements from the provided iterables in a round-robin fashion.
Example:
```python
iterable1 = [1, 2, 3]
iterable2 = ['a', 'b', 'c']
iterable3 = [0.1, 0.2, 0.3]
for item in roundrobin(iterable1, iterable2, iterable3):
print(item)
# Output:
# 1
# 'a'
# 0.1
# 2
# 'b'
# 0.2
# 3
# 'c'
# 0.3
```
Note:
- If one of the provided iterables is shorter than the others, the function will
continue iterating through the remaining iterables until all are exhausted.
- If an iterable raises an exception during iteration, a warning message is logged,
and the function continues with the next iterable.
See Also:
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
"""
num_active = len(iterables)
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
while num_active:
@@ -35,6 +79,48 @@ def roundrobin(*iterables):
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
"""
Monitor the speed and progress of data loading using a data loader.
This function iterates through a data loader for a specified number of steps or until
the end of the data loader is reached, periodically logging progress information.
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Example:
```python
import torch
from torch.utils.data import DataLoader
# Create a data loader (replace with your own DataLoader configuration)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Monitor data loading speed and progress
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
```
Note:
- The function logs information about elapsed time, the number of examples processed,
and the processing speed in examples per second.
- If `peek` is provided, batch contents will be logged for inspection at the specified
frequency.
See Also:
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
iterating through datasets.
"""
num_examples = 0
prev = time.perf_counter()
for idx, batch in enumerate(data_loader):
@@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]
def pa_to_torch(array: pa.array) -> torch.Tensor:
"""
Convert a PyArrow Array to a PyTorch Tensor.
Args:
array (pa.array): The PyArrow Array to convert.
Returns:
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
Example:
```python
import pyarrow as pa
import torch
# Create a PyArrow Array
arrow_array = pa.array([1, 2, 3])
# Convert it to a PyTorch Tensor
torch_tensor = pa_to_torch(arrow_array)
```
"""
return torch.from_numpy(array.to_numpy())
def create_default_pa_to_batch(schema) -> DataclassBatch:
""" """
"""
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
Args:
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
Returns:
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
Example:
```python
  import pandas as pd
  import pyarrow as pa
from dataclass_batch import DataclassBatch
# Define a PyArrow schema
schema = pa.schema([
("feature1", pa.float64()),
("feature2", pa.int64()),
("label", pa.int64()),
])
# Create the conversion function
pa_to_batch = create_default_pa_to_batch(schema)
# Create a PyArrow RecordBatch
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
"feature1": [1.0, 2.0, None],
"feature2": [10, 20, 30],
"label": [0, 1, None],
}))
# Convert the RecordBatch to a custom DataclassBatch
custom_batch = pa_to_batch(record_batch)
```
"""
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
def get_imputation_value(pa_type):

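The body of `get_imputation_value` is truncated in this hunk, so here is a hedged sketch of one plausible approach: pick a type-appropriate fill value and apply it with `pyarrow.compute.fill_null`. This is an assumption about the behaviour, not the repo's actual helper.

```python
# Hedged sketch only: one plausible way to impute nulls per PyArrow type.
import pyarrow as pa
import pyarrow.compute as pc

def impute_nulls(array: pa.Array) -> pa.Array:
  if pa.types.is_floating(array.type):
    fill = 0.0
  elif pa.types.is_integer(array.type):
    fill = 0
  elif pa.types.is_string(array.type):
    fill = ""
  else:
    return array  # leave unsupported types untouched
  return pc.fill_null(array, fill)

print(impute_nulls(pa.array([1.0, None, 3.0])))  # -> [1, 0, 3]
```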
View File

@@ -38,6 +38,15 @@ import pyarrow.parquet as pq
def _create_dataset(path: str):
"""
Create a PyArrow dataset from Parquet files located at the specified path.
Args:
path (str): The path to the Parquet files.
Returns:
pyarrow.dataset.Dataset: The PyArrow dataset.
"""
fs = infer_fs(path)
files = fs.glob(path)
return pads.dataset(files, format="parquet", filesystem=fs)
@@ -47,12 +56,27 @@ class PqReader:
def __init__(
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
):
"""
Initialize a Parquet Reader.
Args:
path (str): The path to the Parquet files.
num (int): The maximum number of rows to read.
batch_size (int): The batch size for reading data.
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
"""
self._ds = _create_dataset(path)
self._batch_size = batch_size
self._num = num
self._columns = columns
def __iter__(self):
"""
Iterate through the Parquet data and yield batches of rows.
Yields:
pyarrow.RecordBatch: A batch of rows.
"""
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
rows_seen = 0
for count, record in enumerate(batches):
@@ -62,6 +86,12 @@ class PqReader:
rows_seen += record.data.num_rows
def _head(self):
"""
Get the first `num` rows of the Parquet data.
Returns:
pyarrow.RecordBatch: A batch of rows.
"""
total_read = self._num * self.bytes_per_row
if total_read >= int(500e6):
raise Exception(
@@ -71,6 +101,12 @@
@property
def bytes_per_row(self) -> int:
"""
Calculate the estimated bytes per row in the dataset.
Returns:
int: The estimated bytes per row.
"""
nbits = 0
for t in self._ds.schema.types:
try:
@@ -81,18 +117,23 @@
return nbits // 8
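A hedged sketch of the schema-based row-size estimate this property appears to compute: sum `DataType.bit_width` over fixed-width columns and skip variable-width ones, which raise `ValueError`.

```python
# Hedged sketch of a bytes-per-row estimate from a schema; variable-width types
# (strings, lists) have no fixed bit width and are skipped.
import pyarrow as pa

def estimated_bytes_per_row(schema: pa.Schema) -> int:
  nbits = 0
  for t in schema.types:
    try:
      nbits += t.bit_width
    except ValueError:  # raised for non-fixed-width types
      pass
  return nbits // 8

schema = pa.schema([("user_id", pa.int64()), ("score", pa.float32()), ("name", pa.string())])
print(estimated_bytes_per_row(schema))  # 12
```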
def schema(self):
"""
Display the schema of the Parquet dataset.
"""
print(f"\n# Schema\n{self._ds.schema}")
def head(self):
"""Displays first --num rows."""
"""
Display the first `num` rows of the Parquet data as a pandas DataFrame.
"""
print(self._head().to_pandas())
def distinct(self):
"""Displays unique values seen in specified columns in the first `--num` rows.
Useful for getting an approximate vocabulary for certain columns.
"""
Display unique values seen in specified columns in the first `num` rows.
Useful for getting an approximate vocabulary for certain columns.
"""
for col_name, column in zip(self._head().column_names, self._head().columns):
print(col_name)
print("unique:", column.unique().to_pylist())