This commit is contained in:
Rockerz 2023-09-21 17:35:12 +00:00 committed by GitHub
commit a8258a5da4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
64 changed files with 4507 additions and 1233 deletions

View File

@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC): class BatchBase(Pipelineable, abc.ABC):
"""
A base class for batches used in pipelines.
Attributes:
None
"""
@abc.abstractmethod @abc.abstractmethod
def as_dict(self) -> Dict: def as_dict(self) -> Dict:
"""
Convert the batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False): def to(self, device: torch.device, non_blocking: bool = False):
"""
Move the batch to the specified device.
Args:
device (torch.device): The target device.
non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.
Returns:
BatchBase: A new batch on the target device.
"""
args = {} args = {}
for feature_name, feature_value in self.as_dict().items(): for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking) args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args) return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None: def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
"""
Record a CUDA stream for all tensors in the batch.
Args:
stream (torch.cuda.streams.Stream): The CUDA stream to record.
Returns:
None
"""
for feature_value in self.as_dict().values(): for feature_value in self.as_dict().values():
feature_value.record_stream(stream) feature_value.record_stream(stream)
def pin_memory(self): def pin_memory(self):
"""
Pin memory for all tensors in the batch.
Returns:
BatchBase: A new batch with pinned memory.
"""
args = {} args = {}
for feature_name, feature_value in self.as_dict().items(): for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory() args[feature_name] = feature_value.pin_memory()
return self.__class__(**args) return self.__class__(**args)
def __repr__(self) -> str: def __repr__(self) -> str:
"""
Generate a string representation of the batch.
Returns:
str: A string representation of the batch.
"""
def obj2str(v): def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):
@property @property
def batch_size(self) -> int: def batch_size(self) -> int:
"""
Get the batch size from the tensors in the batch.
Returns:
int: The batch size.
Raises:
Exception: If the batch size cannot be determined from the tensors.
"""
for tensor in self.as_dict().values(): for tensor in self.as_dict().values():
if tensor is None: if tensor is None:
continue continue
@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):
@dataclass @dataclass
class DataclassBatch(BatchBase): class DataclassBatch(BatchBase):
"""
A batch class that uses dataclasses to define its fields.
Attributes:
None
"""
@classmethod @classmethod
def feature_names(cls): def feature_names(cls):
"""
Get the feature names of the dataclass.
Returns:
List[str]: A list of feature names.
"""
return list(cls.__dataclass_fields__.keys()) return list(cls.__dataclass_fields__.keys())
def as_dict(self): def as_dict(self):
"""
Convert the dataclass batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return { return {
feature_name: getattr(self, feature_name) feature_name: getattr(self, feature_name)
for feature_name in self.feature_names() for feature_name in self.feature_names()
@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):
@staticmethod @staticmethod
def from_schema(name: str, schema): def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor.""" """
Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
Args:
name (str): The name of the custom batch class.
schema: The schema or structure of the batch.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass( return dataclasses.make_dataclass(
cls_name=name, cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names], fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):
@staticmethod @staticmethod
def from_fields(name: str, fields: dict): def from_fields(name: str, fields: dict):
"""
Create a custom batch subclass from a set of fields.
Args:
name (str): The name of the custom batch class.
fields (dict): A dictionary specifying the fields and their types.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass( return dataclasses.make_dataclass(
cls_name=name, cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()], fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):
class DictionaryBatch(BatchBase, dict): class DictionaryBatch(BatchBase, dict):
"""
A batch class that represents data as a dictionary.
Attributes:
None
"""
def as_dict(self) -> Dict: def as_dict(self) -> Dict:
"""
Convert the dictionary batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return self return self

View File

@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
class Snapshot: class Snapshot:
"""Checkpoints using torchsnapshot. """
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
Also saves step to be updated by the training loop.
""" """
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None: def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
"""
Initializes a Snapshot object.
Args:
save_dir (str): Directory where checkpoints will be saved.
state (Dict[str, Any]): State dictionary containing checkpoint information.
"""
self.save_dir = save_dir self.save_dir = save_dir
self.state = state self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0) self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property @property
def step(self): def step(self):
"""Get the current training step."""
return self.state["extra_state"]["step"] return self.state["extra_state"]["step"]
@step.setter @step.setter
def step(self, step: int) -> None: def step(self, step: int) -> None:
"""Set the current training step."""
self.state["extra_state"]["step"] = step self.state["extra_state"]["step"] = step
@property @property
@ -41,7 +48,15 @@ class Snapshot:
self.state["extra_state"]["walltime"] = walltime self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot": def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step.""" """
Saves a checkpoint with a given global step.
Args:
global_step (int): The global step to associate with the checkpoint.
Returns:
PendingSnapshot: A pending snapshot object.
"""
path = os.path.join(self.save_dir, str(global_step)) path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.") logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time() start_time = time.time()
@ -58,7 +73,12 @@ class Snapshot:
return snapshot return snapshot
def restore(self, checkpoint: str) -> None: def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint.""" """
Restores a given checkpoint.
Args:
checkpoint (str): Path to the checkpoint to restore.
"""
snapshot = torchsnapshot.Snapshot(path=checkpoint) snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.") logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time() start_time = time.time()
@ -83,11 +103,16 @@ class Snapshot:
global_step: Optional[int] = None, global_step: Optional[int] = None,
missing_ok: bool = False, missing_ok: bool = False,
) -> torchsnapshot.Snapshot: ) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it. """
Get a torch stateless snapshot, without actually loading it.
Args: Args:
snapshot_path: path to the model snapshot snapshot_path (str): Path to the model snapshot.
global_step: restores from this checkpoint if specified. global_step (int, optional): Restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration. missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
Returns:
torchsnapshot.Snapshot: A torch snapshot object.
""" """
path = get_checkpoint(snapshot_path, global_step, missing_ok) path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.") logging.info(f"Loading snapshot from {path}.")
@ -100,12 +125,13 @@ class Snapshot:
snapshot_emb_name: str, snapshot_emb_name: str,
weight_tensor, weight_tensor,
) -> None: ) -> None:
"""Loads pretrained embedding from the snapshot to the model. """
Utilise partial lodaing meachanism from torchsnapshot. Loads pretrained embedding from the snapshot to the model.
Args: Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC). embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC. snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded. weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
""" """
start_time = time.time() start_time = time.time()
manifest = embedding_snapshot.get_manifest() manifest = embedding_snapshot.get_manifest()
@ -209,7 +235,22 @@ def get_checkpoint(
def get_checkpoints(save_dir: str) -> List[str]: def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written.""" """
Get a list of fully written checkpoints in the specified directory.
This function retrieves a list of fully written checkpoints in the given directory.
Checkpoints that are considered fully written include those that have a
corresponding snapshot metadata file.
Args:
save_dir (str): The directory where checkpoints are stored.
Returns:
List[str]: A list of fully written checkpoint paths.
Note:
Checkpoints are sorted by their numeric filenames in ascending order.
"""
checkpoints = [] checkpoints = []
fs = infer_fs(save_dir) fs = infer_fs(save_dir)
if fs.exists(save_dir): if fs.exists(save_dir):
@ -232,6 +273,18 @@ def wait_for_evaluators(
global_step: int, global_step: int,
timeout: int, timeout: int,
) -> None: ) -> None:
"""
Waits for all evaluators to finish and checks for their completion status.
Args:
save_dir (str): Directory where checkpoints are saved.
partition_names (List[str]): List of partition names to check for completion.
global_step (int): The global step for which to wait for evaluators.
timeout (int): Maximum time in seconds to wait for evaluators to finish.
Returns:
None: This function returns nothing but logs the progress and results.
"""
logging.info("Waiting for all evaluators to finish.") logging.info("Waiting for all evaluators to finish.")
start_time = time.time() start_time = time.time()

View File

@ -5,6 +5,15 @@ import torch.distributed as dist
def maybe_setup_tensorflow(): def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available.
This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
Returns:
None
"""
try: try:
import tensorflow as tf import tensorflow as tf
except ImportError: except ImportError:
@ -14,6 +23,18 @@ def maybe_setup_tensorflow():
def setup_and_get_device(tf_ok: bool = True) -> torch.device: def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device.
This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
Args:
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
Returns:
torch.device: The torch device for the current process.
"""
if tf_ok: if tf_ok:
maybe_setup_tensorflow() maybe_setup_tensorflow()

View File

@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()
def infer_fs(path: str): def infer_fs(path: str):
"""
Infer the file system (fs) type based on the given path.
Args:
path (str): The file path.
Returns:
str: The inferred file system type ("gs://" for Google Cloud Storage, "hdfs://" for Hadoop Distributed File System, or "local" for local file system).
Raises:
NotImplementedError: If the path indicates Hadoop Distributed File System (HDFS) which is not yet supported.
"""
if path.startswith("gs://"): if path.startswith("gs://"):
return GCS_FS return GCS_FS
elif path.startswith("hdfs://"): elif path.startswith("hdfs://"):
@ -18,8 +30,26 @@ def infer_fs(path: str):
def is_local_fs(fs): def is_local_fs(fs):
"""
Check if the given file system is the local file system.
Args:
fs (str): The file system type to check.
Returns:
bool: True if the file system is the local file system, False otherwise.
"""
return fs == LOCAL_FS return fs == LOCAL_FS
def is_gcs_fs(fs): def is_gcs_fs(fs):
"""
Check if the given file system is Google Cloud Storage (GCS).
Args:
fs (str): The file system type to check.
Returns:
bool: True if the file system is GCS, False otherwise.
"""
return fs == GCS_FS return fs == GCS_FS

View File

@ -49,14 +49,24 @@ def log_ebc_norms(
ebc_keys: List[str], ebc_keys: List[str],
sample_size: int = 4_000_000, sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys. """
Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank. As of now, log average norm per rank.
Args: Args:
model_state_dict: model.state_dict() model_state_dict (dict): The state dictionary of the model.
ebc_keys: list of embedding keys from state_dict to log. Must contain full name, ebc_keys (List[str]): List of embedding keys from state_dict to log.
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
sample_size: Limits number of rows per rank to compute average on to avoid OOM. sample_size (int, optional): Limits the number of rows per rank to compute average on
to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
The keys are in the format "{emb_key}-norm-{idx}".
Note:
This function computes and logs the average norm of embedding tables across ranks.
It gathers the norms from all ranks and returns them as a dictionary.
""" """
norm_logs = dict() norm_logs = dict()
for emb_key in ebc_keys: for emb_key in ebc_keys:

View File

@ -11,6 +11,20 @@ import numpy as np
class LargeEmbeddings(nn.Module): class LargeEmbeddings(nn.Module):
"""
A module for handling large embeddings.
Args:
large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
Attributes:
ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models.
Note:
The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving.
"""
def __init__( def __init__(
self, self,
large_embeddings_config: LargeEmbeddingsConfig, large_embeddings_config: LargeEmbeddingsConfig,
@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
self, self,
sparse_features: KeyedJaggedTensor, sparse_features: KeyedJaggedTensor,
) -> KeyedTensor: ) -> KeyedTensor:
"""
Forward pass of the LargeEmbeddings module.
Args:
sparse_features (KeyedJaggedTensor): Sparse input features.
Returns:
KeyedTensor: The output of the forward pass, a KeyedTensor.
"""
pooled_embs = self.ebc(sparse_features) pooled_embs = self.ebc(sparse_features)
# a KeyedTensor # a KeyedTensor

View File

@ -11,6 +11,12 @@ import torch.distributed.run
def is_distributed_worker(): def is_distributed_worker():
"""
Checks if the current process is a distributed worker.
Returns:
bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
"""
world_size = os.environ.get("WORLD_SIZE", None) world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None) rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None return world_size is not None and rank is not None
@ -25,10 +31,10 @@ def maybe_run_training(
is_chief: Optional[bool] = False, is_chief: Optional[bool] = False,
**training_kwargs, **training_kwargs,
): ):
"""Wrapper function for single node, multi-GPU Pytorch training. """
Wrapper function for single node, multi-GPU PyTorch training.
If the necessary distributed Pytorch environment variables If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes
(WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`. `train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module Otherwise, this function calls torchrun and points at the calling module
@ -36,13 +42,24 @@ def maybe_run_training(
and training will commence. and training will commence.
Args: Args:
train_fn: The function that is responsible for training train_fn (callable): The function responsible for training.
module_name: The name of the module that this function was called from; module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint.
used to indicate torchrun entrypoint. nproc_per_node (int, optional): Number of workers per node. Defaults to None.
nproc_per_node: Number of workers per node; supported values. num_nodes (int, optional): Number of nodes. Defaults to None.
num_nodes: Number of nodes, otherwise inferred from environment. is_chief (bool, optional): If the process is running on the chief node. Defaults to False.
is_chief: If process is running on chief. set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH. **training_kwargs: Additional keyword arguments to pass to the `train_fn`.
Note:
This function checks if the current process is a distributed worker by examining the environment variables.
If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary
environment variables and launches the training process using torchrun.
Example:
To run training on a single node with 4 GPUs, you can use:
```
maybe_run_training(train_function, __name__, nproc_per_node=4)
```
""" """
machines = utils.machine_from_env() machines = utils.machine_from_env()

View File

@ -19,6 +19,23 @@ MOCK_ENV = {
@contextmanager @contextmanager
def mock_pg(): def mock_pg():
"""
A context manager for mocking the distributed process group for testing purposes.
This context manager temporarily sets environment variables to mock the distributed process group
and initializes it using the Gloo backend. It is useful for testing distributed training without
actually launching multiple processes.
Example:
```python
with mock_pg():
# Your distributed training code here
```
Note:
This context manager should be used within a testing environment to simulate distributed training
without actually creating multiple processes.
"""
with patch.dict(os.environ, MOCK_ENV): with patch.dict(os.environ, MOCK_ENV):
try: try:
dist.init_process_group( dist.init_process_group(

View File

@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)
def _read_file(f): def _read_file(f):
"""
Read the contents of a file using fsspec.
Args:
f: File path or URL.
Returns:
The contents of the file.
"""
with fsspec.open(f) as f: with fsspec.open(f) as f:
return f.read() return f.read()
@ -21,17 +30,23 @@ def setup_configuration(
yaml_path: str, yaml_path: str,
substitute_env_variable: bool = False, substitute_env_variable: bool = False,
) -> Tuple[C, str]: ) -> Tuple[C, str]:
"""Resolves a config at a yaml path. """
Load a Pydantic config object from a YAML file and optionally substitute environment variables.
Args: Args:
config_type: Pydantic config class to load. config_type: Pydantic config class to load.
yaml_path: yaml path of the config file. yaml_path: YAML path of the config file.
substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
environment variable value whenever possible. If an environment variable doesn't exist, with their environment variable values whenever possible.
the string is left unchanged. If an environment variable doesn't exist, the string is left unchanged.
Returns: Returns:
The pydantic config object. A tuple containing the Pydantic config object and the resolved YAML content.
Example:
```python
config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
```
""" """
def _substitute(s): def _substitute(s):

View File

@ -6,6 +6,33 @@ import pydantic
class WandbConfig(base_config.BaseConfig): class WandbConfig(base_config.BaseConfig):
"""
Configuration for integrating with Weights and Biases (WandB).
Attributes:
host (str): Host of the Weights and Biases instance, passed to login.
key_path (str): Path to the key file.
name (str): Name of the experiment, passed to init.
entity (str): Name of the user/service account, passed to init.
project (str): Name of the WandB project, passed to init.
tags (List[str]): List of tags, passed to init.
notes (str): Notes, passed to init.
metadata (Dict[str, Any]): Additional metadata to log.
Example:
```python
wandb_config = WandbConfig(
host="https://wandb.example.com",
key_path="/path/to/key",
name="experiment_1",
entity="user123",
project="my_project",
tags=["experiment", "ml"],
notes="This is a test experiment.",
metadata={"version": "1.0"}
)
```
"""
host: str = pydantic.Field( host: str = pydantic.Field(
"https://https--wandb--prod--wandb.service.qus1.twitter.biz/", "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
description="Host of Weights and Biases instance, passed to login.", description="Host of Weights and Biases instance, passed to login.",

View File

@ -10,32 +10,41 @@ import pydantic
class BaseConfig(pydantic.BaseModel): class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes. """Base class for all derived config classes.
This class provides some convenient functionality: This class provides convenient functionality and constraints for derived config classes:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
``` - Disallows extra fields when constructing an object. User errors due to extraneous arguments
are minimized.
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
be set. For example:
```python
class ExampleConfig(BaseConfig): class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1") x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1") y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok ExampleConfig(x=1) # OK
ExampleConfig(y=1) # ok ExampleConfig(y=1) # OK
ExampleConfig(x=1, y=1) # throws error ExampleConfig(x=1, y=1) # Raises an error
``` ```
Attributes:
Config (class): Configuration options for this class, forbidding extra fields.
Methods:
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
""" """
class Config: class Config:
"""Forbids extras.""" """Configuration options that forbid extra fields."""
extra = pydantic.Extra.forbid # noqa extra = pydantic.Extra.forbid # noqa
@classmethod @classmethod
@functools.lru_cache() @functools.lru_cache()
def _field_data_map(cls, field_data_name): def _field_data_map(cls, field_data_name):
"""Create a map of fields with provided the field data.""" """Create a map of fields with the provided field data."""
schema = cls.schema() schema = cls.schema()
one_of = collections.defaultdict(list) one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items(): for field, fdata in schema["properties"].items():
@ -45,7 +54,7 @@ class BaseConfig(pydantic.BaseModel):
@pydantic.root_validator @pydantic.root_validator
def _one_of_check(cls, values): def _one_of_check(cls, values):
"""Validate that all 'one of' fields are appear exactly once.""" """Validate that all 'one of' fields appear exactly once."""
one_of_map = cls._field_data_map("one_of") one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items(): for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1: if sum([values.get(n, None) is not None for n in field_names]) != 1:
@ -59,8 +68,9 @@ class BaseConfig(pydantic.BaseModel):
for one_of, field_names in at_most_one_of_map.items(): for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1: if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.") raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values return values
def pretty_print(self) -> str: def pretty_print(self) -> str:
"""Return a human legible (yaml) representation of the config useful for logging.""" """Return a human-readable (YAML) representation of the config useful for logging."""
return yaml.dump(self.dict()) return yaml.dump(self.dict())

View File

@ -6,7 +6,21 @@ import pydantic
class BaseConfigTest(TestCase): class BaseConfigTest(TestCase):
"""
Unit tests for the BaseConfig class.
"""
def test_extra_forbidden(self): def test_extra_forbidden(self):
"""
Test that extra fields are forbidden when creating a Config instance.
This test case checks whether the `BaseConfig` class correctly raises a
`pydantic.ValidationError` when extra fields are provided when creating a
`Config` instance.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig): class Config(BaseConfig):
x: int x: int
@ -15,6 +29,16 @@ class BaseConfigTest(TestCase):
Config(x=1, y=2) Config(x=1, y=2)
def test_one_of(self): def test_one_of(self):
"""
Test the use of the `one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `one_of` attribute in a `Config`
instance. It verifies that the `pydantic.Field` correctly enforces the
specified constraint.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig): class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f") x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f") y: int = pydantic.Field(None, one_of="f")
@ -27,6 +51,16 @@ class BaseConfigTest(TestCase):
Config(x=1, y=3) Config(x=1, y=3)
def test_at_most_one_of(self): def test_at_most_one_of(self):
"""
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `at_most_one_of` attribute in a
`Config` instance. It verifies that the `pydantic.Field` enforces the
constraint where at most one of the specified fields can be provided.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig): class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f") x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f") y: str = pydantic.Field(None, at_most_one_of="f")

View File

@ -8,10 +8,41 @@ from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str): def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommend method to load a config file (a yaml file) and parse it. """
Recommend method to Load and parse a configuration from a YAML file.
This function loads a configuration from a YAML file, parses it, and returns an instance of the
specified config type.
Because we have a shared filesystem the recommended route to running jobs it put modified config Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them. files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
Args:
config_type (Type[BaseConfig]): The Pydantic config class to load.
yaml_path (str): The path to the YAML configuration file.
Returns:
BaseConfig: An instance of the specified config type populated with values from the YAML file.
Example:
Suppose you have a YAML file 'my_config.yaml' containing the following:
```yaml
x: 42
y: "hello"
```
You can load and parse it using this function as follows:
```python
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
```
Note:
This function performs environment variable substitution in the YAML file. It replaces
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
values. If an environment variable does not exist, the string is left unchanged.
""" """
def _substitute(s): def _substitute(s):

View File

@ -8,11 +8,27 @@ import pydantic
class _PointlessConfig(BaseConfig): class _PointlessConfig(BaseConfig):
a: int a: int
user: str user: str
def test_load_config_from_yaml(tmp_path): def test_load_config_from_yaml(tmp_path):
"""Test loading a configuration from a YAML file and verifying its values.
This test function checks the functionality of the `load_config_from_yaml` function by creating
a temporary YAML configuration file, loading it, and asserting that the loaded config object
has the expected values.
Args:
tmp_path: A temporary directory provided by the `pytest` framework.
Test Steps:
1. Create a temporary YAML file containing configuration data.
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
3. Assert that the loaded configuration object has the expected values.
"""
yaml_path = tmp_path.joinpath("test.yaml").as_posix() yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file: with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""") yaml_file.write("""a: 3\nuser: ${USER}\n""")

View File

@ -9,6 +9,14 @@ import pydantic
class RuntimeConfig(base_config.BaseConfig): class RuntimeConfig(base_config.BaseConfig):
"""
Configuration for runtime settings.
Attributes:
- wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
- enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices.
- enable_amp (bool): Enable automatic mixed precision.
"""
wandb: WandbConfig = pydantic.Field(None) wandb: WandbConfig = pydantic.Field(None)
enable_tensorfloat32: bool = pydantic.Field( enable_tensorfloat32: bool = pydantic.Field(
False, description="Use tensorfloat32 if on Ampere devices." False, description="Use tensorfloat32 if on Ampere devices."
@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):
class TrainingConfig(base_config.BaseConfig): class TrainingConfig(base_config.BaseConfig):
"""
Configuration for training settings.
Attributes:
- save_dir (str): Directory to save checkpoints.
- num_train_steps (pydantic.PositiveInt): Number of training steps.
- initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
- checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps.
- checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
- train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps.
- num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used.
- eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps.
- eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
- gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
- num_epochs (pydantic.PositiveInt): Number of training epochs.
"""
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.") save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
num_train_steps: pydantic.PositiveInt = 10000 num_train_steps: pydantic.PositiveInt = 10000
initial_checkpoint_dir: str = pydantic.Field( initial_checkpoint_dir: str = pydantic.Field(

View File

@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
def _get_step_fn(pipeline, data_iterator, training: bool): def _get_step_fn(pipeline, data_iterator, training: bool):
"""
Returns a function to perform a single evaluation step.
Args:
pipeline (Pipeline): The pipeline object containing the model.
data_iterator (Iterator): The data iterator for evaluation.
training (bool): Flag indicating if the model should be in training mode.
Returns:
function: A function that performs a single evaluation step.
"""
def step_fn(): def step_fn():
"""
Perform a single evaluation step.
Returns:
Any: The evaluation results after a single step.
"""
# It turns out that model.train() and model.eval() simply switch a single field inside the model # It turns out that model.train() and model.eval() simply switch a single field inside the model
# class,so it's somewhat safer to wrap in here. # class,so it's somewhat safer to wrap in here.
if training: if training:
@ -69,7 +86,21 @@ def _run_evaluation(
eval_batch_size: int, eval_batch_size: int,
logger=None, logger=None,
): ):
"""Runs the evaluation loop over all evaluation iterators.""" """
Run the evaluation loop over all evaluation iterators.
Args:
pipeline (Pipeline): The pipeline object containing the model.
dataset (Dataset): The dataset to evaluate.
eval_steps (int): The number of evaluation steps to perform.
metrics (tm.MetricCollection): A collection of evaluation metrics.
eval_batch_size (int): Batch size for evaluation.
logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
Returns:
dict: A dictionary containing the computed evaluation metrics.
"""
dataset = get_new_iterator(dataset) dataset = get_new_iterator(dataset)
step_fn = _get_step_fn(pipeline, dataset, training=False) step_fn = _get_step_fn(pipeline, dataset, training=False)
last_time = datetime.datetime.now() last_time = datetime.datetime.now()
@ -109,14 +140,28 @@ def train(
parameters_to_log: Optional[Dict[str, Callable]] = None, parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None, tables_to_log: Optional[List[str]] = None,
) -> None: ) -> None:
"""Runs training and eval on the given TrainPipeline """
Runs training and evaluation on the given TrainPipeline.
Args: Args:
dataset: data iterator for the training set model (torch.nn.Module): The neural network model to train.
evaluation_iterators: data iterators for the different evaluation sets optimizer (torch.optim.Optimizer): The optimizer for model optimization.
scheduler: optional learning rate scheduler device (str): The target device for model training (e.g., 'cuda' or 'cpu').
output_transform_for_metrics: optional transformation functions to transorm the model save_dir (str): The directory to save model checkpoints and logs.
output and labels into a format the metrics can understand logging_interval (int): Interval for logging training progress.
train_steps (int): The number of training steps to perform.
checkpoint_frequency (int): Frequency of saving model checkpoints.
dataset (Iterable): Data iterator for the training set.
worker_batch_size (int): Batch size for data loading workers.
num_workers (Optional[int]): Number of data loading workers (default: 0).
enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
logger_initializer (Optional[Callable]): A logger initializer function (default: None).
scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
tables_to_log (Optional[List[str]]): List of tables to log (default: None).
""" """
train_pipeline = TrainPipelineSparseDist( train_pipeline = TrainPipelineSparseDist(
@ -262,6 +307,15 @@ def log_eval_results(
partition_name: str, partition_name: str,
step: int, step: int,
): ):
"""
Logs evaluation results and optionally records them using a provided logger.
Args:
results (Any): The evaluation results to log.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
step (int): The current step in the evaluation.
"""
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results) results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
logging.info(f"Step: {step}, evaluation ({partition_name}).") logging.info(f"Step: {step}, evaluation ({partition_name}).")
for metric_name, metric_value in results.items(): for metric_name, metric_value in results.items():
@ -285,6 +339,23 @@ def only_evaluate(
partition_name: str, partition_name: str,
metrics: Optional[tm.MetricCollection] = None, metrics: Optional[tm.MetricCollection] = None,
): ):
"""
Performs evaluation on a given dataset partition.
Args:
model (torch.nn.Module): The neural network model for evaluation.
optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
save_dir (str): The directory containing model checkpoints.
num_train_steps (int): The total number of training steps.
dataset (Iterable): Data iterator for evaluation.
eval_batch_size (int): Batch size for evaluation.
num_eval_steps (int): The number of evaluation steps to perform.
eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
eval_logger (Callable): A logger for recording evaluation results.
partition_name (str): The name of the evaluation partition.
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
"""
logging.info(f"Evaluating on partition {partition_name}.") logging.info(f"Evaluating on partition {partition_name}.")
logging.info("Computing metrics:") logging.info("Computing metrics:")
logging.info(metrics) logging.info(metrics)

View File

@ -28,6 +28,18 @@ def train(
*args, *args,
**kwargs, **kwargs,
) -> None: ) -> None:
"""
Debugging training loop. Do not use for actual model training.
Args:
model (torch.nn.Module): The neural network model.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
train_steps (int): The number of training steps to perform.
dataset (Iterable): Data iterator for training data.
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
*args: Additional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).
"""
logging.warning("Running debug training loop, don't use for model training.") logging.warning("Running debug training loop, don't use for model training.")

View File

@ -10,7 +10,10 @@ import torch
def _maybe_warn(reduction: str): def _maybe_warn(reduction: str):
""" """
Warning for reduction different than mean. Emit a warning if the reduction method is different from 'mean'.
Args:
reduction (str): The reduction method being used.
""" """
if reduction != "mean": if reduction != "mean":
logging.warn( logging.warn(
@ -24,6 +27,16 @@ def build_loss(
loss_type: LossType, loss_type: LossType,
reduction="mean", reduction="mean",
): ):
"""
Build a loss function based on the specified loss type and reduction method.
Args:
loss_type (LossType): The type of loss to build.
reduction (str): The reduction method for the loss (default: 'mean').
Returns:
Callable: A loss function that takes logits and labels as input.
"""
_maybe_warn(reduction) _maybe_warn(reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type] f = _LOSS_TYPE_TO_FUNCTION[loss_type]
@ -36,9 +49,13 @@ def build_loss(
def get_global_loss_detached(local_loss, reduction="mean"): def get_global_loss_detached(local_loss, reduction="mean"):
""" """
Perform all_reduce to obtain the global loss function using the provided reduction. Perform all_reduce to obtain the global loss function using the provided reduction.
:param local_loss: The local loss of the current rank.
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP. Args:
:return: The reduced & detached global loss. local_loss (torch.Tensor): The local loss of the current rank.
reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.
Returns:
torch.Tensor: The reduced and detached global loss.
""" """
if reduction != "mean": if reduction != "mean":
logging.warn( logging.warn(
@ -66,6 +83,19 @@ def build_multi_task_loss(
global_reduction="mean", global_reduction="mean",
pos_weights=None, pos_weights=None,
): ):
"""
Build a multi-task loss function based on the specified loss type and configurations.
Args:
loss_type (LossType): The type of loss to build.
tasks (typing.List[str]): List of task names.
task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
global_reduction (str): Reduction method for the global loss (default: 'mean').
pos_weights (Optional): Positive class weights for tasks (default: None).
Returns:
Callable: A multi-task loss function that takes logits, labels, and weights as input.
"""
_maybe_warn(global_reduction) _maybe_warn(global_reduction)
_maybe_warn(task_loss_reduction) _maybe_warn(task_loss_reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type] f = _LOSS_TYPE_TO_FUNCTION[loss_type]

View File

@ -36,9 +36,24 @@ import torchmetrics
class MetricMixin: class MetricMixin:
@abstractmethod @abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
"""
Abstract method to transform model outputs into a dictionary of metrics.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
Returns:
Dict: A dictionary of computed metrics.
"""
... ...
def update(self, outputs: Dict[str, torch.Tensor]): def update(self, outputs: Dict[str, torch.Tensor]):
"""
Update the metrics based on model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
"""
results = self.transform(outputs) results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification. # Do not try to update if any tensor is empty as a result of stratification.
for value in results.values(): for value in results.values():
@ -49,6 +64,13 @@ class MetricMixin:
class TaskMixin: class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs): def __init__(self, task_idx: int = -1, **kwargs):
"""
Initialize a TaskMixin instance.
Args:
task_idx (int): Index of the task associated with this mixin (default: -1).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs) super().__init__(**kwargs)
self._task_idx = task_idx self._task_idx = task_idx
@ -59,13 +81,31 @@ class StratifyMixin:
stratifier=None, stratifier=None,
**kwargs, **kwargs,
): ):
"""
Initialize a StratifyMixin instance.
Args:
stratifier: A stratifier for filtering outputs (default: None).
**kwargs: Additional keyword arguments.
"""
super().__init__(**kwargs) super().__init__(**kwargs)
self._stratifier = stratifier self._stratifier = stratifier
def maybe_apply_stratification( def maybe_apply_stratification(
self, outputs: Dict[str, torch.Tensor], value_names: List[str] self, outputs: Dict[str, torch.Tensor], value_names: List[str]
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" """
Apply stratification to filter examples in the outputs.
Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
value_names (List[str]): Names of values to filter.
Returns:
Dict[str, torch.Tensor]: Filtered outputs.
"""
outputs = outputs.copy() outputs = outputs.copy()
if not self._stratifier: if not self._stratifier:
return outputs return outputs
@ -84,11 +124,19 @@ class StratifyMixin:
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable): def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
"""Returns new class using MetricMixin and given base_metric. """
Returns a new class using MetricMixin and the given base_metric.
Functionally the same using inheritance, just saves some lines of code Functionally the same as using inheritance, but it saves some lines of code
if no need for class attributes. if there's no need for class attributes.
Args:
base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
transform (Callable): The transformation function to prepend to the metric.
Returns:
Type: A new class that includes MetricMixin and the provided base_metric
with the specified transformation method.
""" """
def transform_method(_self, *args, **kwargs): def transform_method(_self, *args, **kwargs):

View File

@ -15,6 +15,16 @@ def probs_and_labels(
outputs: Dict[str, torch.Tensor], outputs: Dict[str, torch.Tensor],
task_idx: int, task_idx: int,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""
Extract probabilities and labels from model outputs.
Args:
outputs (Dict[str, torch.Tensor]): Model outputs.
task_idx (int): Index of the task.
Returns:
Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
"""
preds = outputs["probabilities"] preds = outputs["probabilities"]
target = outputs["labels"] target = outputs["labels"]
if task_idx >= 0: if task_idx >= 0:
@ -28,6 +38,11 @@ def probs_and_labels(
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
def transform(self, outputs): def transform(self, outputs):
"""
Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.
This metric counts values after potential stratification and task selection.
"""
outputs = self.maybe_apply_stratification(outputs, ["labels"]) outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"] value = outputs["labels"]
if self._task_idx >= 0: if self._task_idx >= 0:
@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value after potential stratification and task selection.
"""
def transform(self, outputs): def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"]) outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"] value = outputs["labels"]
@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
This metric calculates the mean metric value using probabilities after potential stratification and task selection.
"""
def transform(self, outputs): def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities"]) outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
value = outputs["probabilities"] value = outputs["probabilities"]
@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
"""
Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.
This metric computes precision after potential stratification and task selection.
"""
def transform(self, outputs): def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx) return probs_and_labels(outputs, self._task_idx)
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
"""
Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.
This metric computes recall after potential stratification and task selection.
"""
def transform(self, outputs): def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx) return probs_and_labels(outputs, self._task_idx)
@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
""" """
AUC (Area Under the ROC Curve) metric class.
This metric computes the AUC metric based on the logits and labels in the model outputs.
Args:
num_samples (int): The number of samples used to compute AUC.
**kwargs: Additional keyword arguments.
Based on: Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420 https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
""" """
@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
""" """
The ranks of all positives PosRanks metric class.
Based on:
This metric computes the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73 https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
""" """
@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
""" """
The reciprocal of the ranks of all ReciprocalRank metric class.
Based on:
This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels
in the model outputs.
Args:
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74 https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
""" """
@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
""" """
The fraction of positives that rank in the top K among their negatives HitAtK metric class.
Note that this is basically precision@k
Based on: This metric computes the fraction of positive examples that rank in the top K among their negatives,
which is equivalent to precision@K.
Args:
k (int): The value of K.
**kwargs: Additional keyword arguments.
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75 https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
""" """

View File

@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric
@dataclass @dataclass
class MockStratifierConfig: class MockStratifierConfig:
"""
Configuration dataclass for mocking a stratifier.
Args:
name (str): The name of the stratifier.
index (int): The index of the stratifier.
value (int): The value of the stratifier.
"""
name: str name: str
index: int index: int
value: int value: int
class Count(MetricMixin, SumMetric): class Count(MetricMixin, SumMetric):
"""
Count metric class that inherits from MetricMixin and SumMetric.
This metric counts occurrences.
"""
def transform(self, outputs): def transform(self, outputs):
return {"value": 1} return {"value": 1}
@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
def test_count_metric(): def test_count_metric():
"""
Test function for the Count metric.
It checks if the Count metric correctly counts the number of examples.
"""
num_examples = 123 num_examples = 123
examples = [ examples = [
{"stuff": 0}, {"stuff": 0},
@ -36,6 +56,12 @@ def test_count_metric():
def test_collections(): def test_collections():
"""
Test function for metric collections.
It tests if metric collections correctly aggregate metrics.
"""
max_metric = Max() max_metric = Max()
count_metric = Count() count_metric = Count()
metric = MetricCollection([max_metric, count_metric]) metric = MetricCollection([max_metric, count_metric])
@ -51,6 +77,12 @@ def test_collections():
def test_task_dependent_ctr(): def test_task_dependent_ctr():
"""
Test function for task-dependent Ctr (Click-Through Rate) metric.
It checks if the Ctr metric computes the correct value for different tasks.
"""
num_examples = 144 num_examples = 144
batch_size = 1024 batch_size = 1024
outputs = [ outputs = [
@ -69,6 +101,13 @@ def test_task_dependent_ctr():
def test_stratified_ctr(): def test_stratified_ctr():
"""
Test function for the Stratified Ctr (Click-Through Rate) metric.
It checks if the Stratified Ctr metric computes the correct value for different tasks
and stratified samples.
"""
outputs = [ outputs = [
{ {
"stuff": 0, "stuff": 0,
@ -114,6 +153,12 @@ def test_stratified_ctr():
def test_auc(): def test_auc():
"""
Test function for the AUC (Area Under the Curve) metric.
It checks if the AUC metric correctly computes the Area Under the ROC Curve.
"""
num_samples = 10000 num_samples = 10000
metric = core_metrics.Auc(num_samples) metric = core_metrics.Auc(num_samples)
target = torch.tensor([0, 0, 1, 1, 1]) target = torch.tensor([0, 0, 1, 1, 1])
@ -131,6 +176,12 @@ def test_auc():
def test_pos_rank(): def test_pos_rank():
"""
Test function for the PosRanks metric.
It checks if the PosRanks metric correctly computes the ranks of positive samples.
"""
metric = core_metrics.PosRanks() metric = core_metrics.PosRanks()
target = torch.tensor([0, 0, 1, 1, 1]) target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@ -147,6 +198,12 @@ def test_pos_rank():
def test_reciprocal_rank(): def test_reciprocal_rank():
"""
Test function for the Reciprocal Rank metric.
It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
"""
metric = core_metrics.ReciprocalRank() metric = core_metrics.ReciprocalRank()
target = torch.tensor([0, 0, 1, 1, 1]) target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
@ -163,6 +220,12 @@ def test_reciprocal_rank():
def test_hit_k(): def test_hit_k():
"""
Test function for the Hit@K metric.
It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
"""
hit1_metric = core_metrics.HitAtK(1) hit1_metric = core_metrics.HitAtK(1)
target = torch.tensor([0, 0, 1, 1, 1]) target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5]) preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])

View File

@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel
@dataclass @dataclass
class MockDataclassBatch(DataclassBatch): class MockDataclassBatch(DataclassBatch):
"""
Mock data class batch for testing purposes.
This class represents a batch of data with continuous features and labels.
Attributes:
continuous_features (torch.Tensor): Tensor containing continuous feature data.
labels (torch.Tensor): Tensor containing label data.
"""
continuous_features: torch.Tensor continuous_features: torch.Tensor
labels: torch.Tensor labels: torch.Tensor
class MockModule(torch.nn.Module): class MockModule(torch.nn.Module):
"""
Mock PyTorch module for testing purposes.
This module defines a simple neural network model with a linear layer
followed by a BCEWithLogitsLoss loss function.
Attributes:
model (torch.nn.Linear): The linear model layer.
loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.model = torch.nn.Linear(10, 1) self.model = torch.nn.Linear(10, 1)
self.loss_fn = torch.nn.BCEWithLogitsLoss() self.loss_fn = torch.nn.BCEWithLogitsLoss()
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the mock module.
Args:
batch (MockDataclassBatch): Input data batch with continuous features and labels.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
"""
pred = self.model(batch.continuous_features) pred = self.model(batch.continuous_features)
loss = self.loss_fn(pred, batch.labels) loss = self.loss_fn(pred, batch.labels)
return (loss, pred) return (loss, pred)
def create_batch(bsz: int): def create_batch(bsz: int):
"""
Create a mock data batch with random continuous features and labels.
Args:
bsz (int): Batch size.
Returns:
MockDataclassBatch: A batch of data with continuous features and labels.
"""
return MockDataclassBatch( return MockDataclassBatch(
continuous_features=torch.rand(bsz, 10).float(), continuous_features=torch.rand(bsz, 10).float(),
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(), labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
@ -35,6 +72,13 @@ def create_batch(bsz: int):
def test_sparse_pipeline(): def test_sparse_pipeline():
"""
Test function for the sparse pipeline with distributed model parallelism.
This function tests the behavior of the sparse training pipeline using
a mock module and data.
"""
device = torch.device("cpu") device = torch.device("cpu")
model = MockModule().to(device) model = MockModule().to(device)
@ -65,6 +109,15 @@ def test_sparse_pipeline():
def test_amp(): def test_amp():
"""
Test automatic mixed-precision (AMP) training with the sparse pipeline.
This function tests the behavior of the sparse training pipeline with
automatic mixed-precision (AMP) enabled, using a mock module and data.
AMP allows for faster training by using lower-precision data types, such as
torch.bfloat16, while maintaining model accuracy.
"""
device = torch.device("cpu") device = torch.device("cpu")
model = MockModule().to(device) model = MockModule().to(device)

View File

@ -39,12 +39,42 @@ Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]): class TrainPipeline(abc.ABC, Generic[In, Out]):
"""
Abstract base class for training pipelines.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
"""
@abc.abstractmethod @abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out: def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Make progress in the training pipeline.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
pass pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
"""
Move a batch of data to a specified device.
Args:
batch (In): The input batch.
device (torch.device): The target device.
non_blocking (bool): If True, move the data asynchronously.
Returns:
In: The batch of data on the target device.
"""
assert isinstance( assert isinstance(
batch, (torch.Tensor, Pipelineable) batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface" ), f"{type(batch)} must implement Pipelineable interface"
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
"""
Wait for a batch of data on a specified stream.
Args:
batch (In): The input batch.
stream (Optional[Stream]): The CUDA stream to wait for.
Note:
This function is used for managing asynchronous CUDA operations.
"""
if stream is None: if stream is None:
return return
torch.cuda.current_stream().wait_stream(stream) torch.cuda.current_stream().wait_stream(stream)
@ -76,6 +116,21 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
stream, namely, the current (default) stream and `self._memcpy_stream`. For each stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization. memory, and the default stream runs forward, backward, and optimization.
Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.
Methods:
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
Initialize the TrainPipelineBase.
_connect(dataloader_iter: Iterator[In]) -> None:
Establish a connection to the data loader and move the input data to the GPU.
progress(dataloader_iter: Iterator[In]) -> Out:
Execute a training iteration, including forward and backward passes.
""" """
def __init__( def __init__(
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
device: torch.device, device: torch.device,
) -> None: ) -> None:
"""
Initialize the TrainPipelineBase.
Args:
model (torch.nn.Module): The PyTorch model to be trained.
optimizer (torch.optim.Optimizer): The optimizer used for training.
device (torch.device): The target device for training (CPU or GPU).
"""
self._model = model self._model = model
self._optimizer = optimizer self._optimizer = optimizer
self._device = device self._device = device
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = False self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None: def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Establish a connection to the data loader and move the input data to the GPU.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
"""
cur_batch = next(dataloader_iter) cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream): with torch.cuda.stream(self._memcpy_stream):
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
self._connected = True self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out: def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Execute a training iteration, including forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator over input data.
Returns:
Out: The output data.
"""
if not self._connected: if not self._connected:
self._connect(dataloader_iter) self._connect(dataloader_iter)
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
class Tracer(torch.fx.Tracer): class Tracer(torch.fx.Tracer):
"""
Custom tracer class for PyTorch models.
This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.
Attributes:
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
_leaf_modules (List[str]): List of qualified names of leaf modules.
"""
# Disable proxying buffers during tracing. Ideally, proxying buffers would # Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which # be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do # causes errors during tracing. If those models can be rewritten to not do
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
proxy_buffer_attributes = False proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
"""
Initialize the Tracer.
Args:
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
"""
super().__init__() super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
Check if a module is a leaf module during tracing.
Args:
m (torch.nn.Module): The PyTorch module.
module_qualified_name (str): The qualified name of the module.
Returns:
bool: True if the module is considered a leaf module, False otherwise.
"""
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules: if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True return True
return super().is_leaf_module(m, module_qualified_name) return super().is_leaf_module(m, module_qualified_name)
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
@dataclass @dataclass
class TrainPipelineContext: class TrainPipelineContext:
"""
Dataclass to store information related to the training pipeline context.
Attributes:
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
feature_processor_forwards (List[Any]): A list of feature processor forwards.
"""
# pyre-ignore [4] # pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
@ -166,6 +279,14 @@ class TrainPipelineContext:
@dataclass @dataclass
class ArgInfo: class ArgInfo:
"""
Dataclass to store information about arguments in the training pipeline.
Attributes:
input_attrs (List[str]): List of attribute names of the input batch.
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
"""
# attributes of input batch, e.g. batch.attr1.attr2 call # attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"] # will produce ["attr1", "attr2"]
input_attrs: List[str] input_attrs: List[str]
@ -177,6 +298,16 @@ class ArgInfo:
class PipelinedForward: class PipelinedForward:
"""
Represents a pipelined forward pass operation.
Attributes:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
def __init__( def __init__(
self, self,
name: str, name: str,
@ -185,6 +316,16 @@ class PipelinedForward:
context: TrainPipelineContext, context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream], dist_stream: Optional[torch.cuda.streams.Stream],
) -> None: ) -> None:
"""
Initialize a PipelinedForward instance.
Args:
name (str): The name of the forward pass.
args (List[ArgInfo]): List of argument information for the forward pass.
module (ShardedModule): The sharded module associated with the forward pass.
context (TrainPipelineContext): The training pipeline context.
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
"""
self._name = name self._name = name
self._args = args self._args = args
self._module = module self._module = module
@ -193,6 +334,16 @@ class PipelinedForward:
# pyre-ignore [2, 24] # pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable: def __call__(self, *input, **kwargs) -> Awaitable:
"""
Perform the pipelined forward pass operation.
Args:
*input: Variable-length positional arguments.
**kwargs: Variable-length keyword arguments.
Returns:
Awaitable: An awaitable object representing the forward pass result.
"""
assert self._name in self._context.input_dist_requests assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name] request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable) assert isinstance(request, Awaitable)
@ -230,10 +381,22 @@ class PipelinedForward:
@property @property
def name(self) -> str: def name(self) -> str:
"""
Get the name of the forward pass.
Returns:
str: The name of the forward pass.
"""
return self._name return self._name
@property @property
def args(self) -> List[ArgInfo]: def args(self) -> List[ArgInfo]:
"""
Get the list of argument information for the forward pass.
Returns:
List[ArgInfo]: List of argument information.
"""
return self._args return self._args
@ -242,6 +405,17 @@ def _start_data_dist(
batch: In, batch: In,
context: TrainPipelineContext, context: TrainPipelineContext,
) -> None: ) -> None:
"""
Start data distribution for a list of pipelined modules.
Args:
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
batch (In): The input batch.
context (TrainPipelineContext): The training pipeline context.
Returns:
None: This function doesn't return a value.
"""
context.input_dist_requests.clear() context.input_dist_requests.clear()
context.module_contexts.clear() context.module_contexts.clear()
for module in pipelined_modules: for module in pipelined_modules:
@ -288,6 +462,14 @@ def _get_node_args_helper(
""" """
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found. It also counts the number of (args + kwargs) found.
Args:
arguments: The arguments to process.
num_found: The current count of arguments found.
feature_processor_arguments: Optional list of feature processor arguments.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
""" """
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))] arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
@ -332,6 +514,16 @@ def _get_node_args_helper(
def _get_node_args( def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]: ) -> Tuple[List[ArgInfo], int]:
"""
Get argument information for a given node.
Args:
node (Node): The node to process.
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
"""
num_found = 0 num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper( pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes node.args, num_found, feature_processor_nodes
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
path: str, path: str,
unsharded_module_names: Set[str], unsharded_module_names: Set[str],
) -> bool: ) -> bool:
"""
Get the names of unsharded modules in a model.
Args:
model (torch.nn.Module): The model to analyze.
path (str): The current path in the model hierarchy.
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
Returns:
bool: True if any sharded modules were found in the hierarchy, False otherwise.
"""
sharded_children = set() sharded_children = set()
for name, child in model.named_children(): for name, child in model.named_children():
curr_path = path + name curr_path = path + name
@ -375,7 +578,13 @@ def _get_unsharded_module_names_helper(
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]: def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
""" """
Returns a list of top level modules do not contain any sharded sub modules. Returns a list of top-level modules that do not contain any sharded sub-modules.
Args:
model (torch.nn.Module): The model to analyze.
Returns:
List[str]: A list of top-level module names without sharded sub-modules.
""" """
unsharded_module_names: Set[str] = set() unsharded_module_names: Set[str] = set()
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
context: TrainPipelineContext, context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream], dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]: ) -> List[ShardedModule]:
"""
Rewrites the model to enable pipelined execution for selected sharded modules.
This function traces the input model using a custom tracer and identifies sharded modules
that can be pipelined. It then creates PipelinedForward objects for these modules,
which enable pipelining during training.
Args:
model (torch.nn.Module): The input model to be rewritten.
context (TrainPipelineContext): The context containing information needed for pipelining.
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
Returns:
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
"""
# Get underlying nn.Module # Get underlying nn.Module
if isinstance(model, DistributedModelParallel): if isinstance(model, DistributedModelParallel):
@ -455,6 +679,18 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
Input model must be symbolically traceable with the exception of `ShardedModule` and Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules. `DistributedDataParallel` modules.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
Attributes:
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
""" """
synced_pipeline_id: Dict[int, int] = {} synced_pipeline_id: Dict[int, int] = {}
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
enable_grad_scaling: bool = True, enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None, grad_accum: Optional[int] = None,
) -> None: ) -> None:
"""
Initializes the training pipeline.
Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
"""
self._model = model self._model = model
self._optimizer = optimizer self._optimizer = optimizer
self._device = device self._device = device
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
self._grad_accum = grad_accum self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None: def _connect(self, dataloader_iter: Iterator[In]) -> None:
"""
Connects the training pipeline to data and prepares for forward and backward passes.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
"""
# batch 1 # batch 1
with torch.cuda.stream(self._memcpy_stream): with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter) batch_i = next(dataloader_iter)
@ -524,12 +778,19 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out: def progress(self, dataloader_iter: Iterator[In]) -> Out:
""" """
Progresses through the training pipeline, performing forward and backward passes.
NOTE: This method has been updated to perform gradient accumulation. NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress` optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th (congruent to training steps), and then update/reset on every `_grad_accum`th
step. step.
Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.
Returns:
Out: The output of the forward pass.
""" """
should_step_optimizer = ( should_step_optimizer = (
self._grad_accum is not None self._grad_accum is not None

View File

@ -9,35 +9,72 @@ FLIGHT_SERVER_PORT: int = 2222
def on_kf(): def on_kf():
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
Returns:
bool: True if running on KF, False otherwise.
"""
return "SPEC_TYPE" in os.environ return "SPEC_TYPE" in os.environ
def has_readers(): def has_readers():
"""Check if the current task has dataset workers.
Returns:
bool: True if the task has dataset workers, False otherwise.
"""
if on_kf(): if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env["dataset_worker"] is not None return machines_config_env.get("dataset_worker") is not None
return os.environ.get("HAS_READERS", "False") == "True" return os.environ.get("HAS_READERS", "False") == "True"
def get_task_type(): def get_task_type():
"""Get the type of the current task.
Returns:
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
"""
if on_kf(): if on_kf():
return os.environ["SPEC_TYPE"] return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"] return os.environ["TASK_TYPE"]
def is_chief() -> bool: def is_chief() -> bool:
"""Check if the current task is the 'chief'.
Returns:
bool: True if the current task is the 'chief', False otherwise.
"""
return get_task_type() == "chief" return get_task_type() == "chief"
def is_reader() -> bool: def is_reader() -> bool:
"""Check if the current task is a 'datasetworker'.
Returns:
bool: True if the current task is a 'datasetworker', False otherwise.
"""
return get_task_type() == "datasetworker" return get_task_type() == "datasetworker"
def is_dispatcher() -> bool: def is_dispatcher() -> bool:
"""Check if the current task is a 'datasetdispatcher'.
Returns:
bool: True if the current task is a 'datasetdispatcher', False otherwise.
"""
return get_task_type() == "datasetdispatcher" return get_task_type() == "datasetdispatcher"
def get_task_index(): def get_task_index():
"""Get the index of the current task.
Returns:
int: Task index.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf(): if on_kf():
pod_name = os.environ["MY_POD_NAME"] pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1]) return int(pod_name.split("-")[-1])
@ -46,12 +83,24 @@ def get_task_index():
def get_reader_port(): def get_reader_port():
"""Get the port used by readers.
Returns:
int: Reader port.
"""
if on_kf(): if on_kf():
return KF_DDS_PORT return KF_DDS_PORT
return SLURM_DDS_PORT return SLURM_DDS_PORT
def get_dds(): def get_dds():
"""Get the Distributed Data Service (DDS) address.
Returns:
str: DDS address in the format 'grpc://host:port'.
Raises:
ValueError: If the job does not have DDS.
"""
if not has_readers(): if not has_readers():
return None return None
dispatcher_address = get_dds_dispatcher_address() dispatcher_address = get_dds_dispatcher_address()
@ -62,6 +111,11 @@ def get_dds():
def get_dds_dispatcher_address(): def get_dds_dispatcher_address():
"""Get the DDS dispatcher address.
Returns:
str: DDS dispatcher address in the format 'host:port'.
"""
if not has_readers(): if not has_readers():
return None return None
if on_kf(): if on_kf():
@ -73,6 +127,11 @@ def get_dds_dispatcher_address():
def get_dds_worker_address(): def get_dds_worker_address():
"""Get the DDS worker address.
Returns:
str: DDS worker address in the format 'host:port'.
"""
if not has_readers(): if not has_readers():
return None return None
if on_kf(): if on_kf():
@ -85,15 +144,27 @@ def get_dds_worker_address():
def get_num_readers(): def get_num_readers():
"""Get the number of dataset workers.
Returns:
int: Number of dataset workers.
"""
if not has_readers(): if not has_readers():
return 0 return 0
if on_kf(): if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env["num_dataset_workers"] or 0) return int(machines_config_env.get("num_dataset_workers") or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
def get_flight_server_addresses(): def get_flight_server_addresses():
"""Get Flight server addresses for dataset workers.
Returns:
List[str]: List of Flight server addresses in the format 'grpc://host:port'.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf(): if on_kf():
job_name = os.environ["JOB_NAME"] job_name = os.environ["JOB_NAME"]
return [ return [
@ -105,4 +176,9 @@ def get_flight_server_addresses():
def get_dds_journaling_dir(): def get_dds_journaling_dir():
"""Get the DDS journaling directory.
Returns:
str: DDS journaling directory.
"""
return os.environ.get("DATASET_JOURNALING_DIR", None) return os.environ.get("DATASET_JOURNALING_DIR", None)

View File

@ -14,17 +14,20 @@ def update_mean(
weight: torch.Tensor, weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Update the mean according to Welford formula: Update the mean according to the Welford formula.
This function updates the mean and the weighted sum of values using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information. See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args: Args:
current_mean: The value of the current accumulated mean. current_mean (torch.Tensor): The value of the current accumulated mean.
current_weight_sum: The current weighted sum. current_weight_sum (torch.Tensor): The current weighted sum.
value: The new value that needs to be added to get a new mean. value (torch.Tensor): The new value that needs to be added to get a new mean.
weight: The weights for the new value. weight (torch.Tensor): The weights for the new value.
Returns: The updated mean and updated weighted sum.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
""" """
weight = torch.broadcast_to(weight, value.shape) weight = torch.broadcast_to(weight, value.shape)
@ -38,11 +41,14 @@ def update_mean(
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
""" """
Merge the state from multiple workers. Merge the state from multiple workers.
This function merges the state from multiple workers to compute the accumulated mean.
Args: Args:
state: A tensor with the first dimension indicating workers. state (torch.Tensor): A tensor with the first dimension indicating workers.
Returns: The accumulated mean from all workers.
Returns:
torch.Tensor: The accumulated mean from all workers.
""" """
mean, weight_sum = update_mean( mean, weight_sum = update_mean(
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@ -55,11 +61,19 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
class StableMean(torchmetrics.Metric): class StableMean(torchmetrics.Metric):
""" """
This implements a numerical stable mean metrics computation using Welford algorithm according to A numerical stable mean metric using the Welford algorithm.
This class implements a numerical stable mean metrics computation using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
For example when using float32, the algorithm will give a valid output even if the "sum" is larger For example when using float32, the algorithm will give a valid output even if the "sum" is larger
than the maximum float32 as far as the mean is within the limit of float32. than the maximum float32 as far as the mean is within the limit of float32.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information. See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args:
**kwargs: Additional parameters supported by all torchmetrics.Metric.
Attributes:
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -75,11 +89,11 @@ class StableMean(torchmetrics.Metric):
) )
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None: def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
""" """Update the current mean.
Update the current mean.
Args: Args:
value: Value to update the mean with. value (torch.Tensor): Value to update the mean with.
weight: weight to use. Shape should be broadcastable to that of value. weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
""" """
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
) )
def compute(self) -> torch.Tensor: def compute(self) -> torch.Tensor:
""" """Compute and return the accumulated mean.
Compute and return the accumulated mean.
Returns:
torch.Tensor: The accumulated mean.
""" """
return self.mean_and_weight_sum[0] return self.mean_and_weight_sum[0]

View File

@ -29,6 +29,9 @@ def _compute_helper(
equal_predictions_as_incorrect: For positive & negative labels having identical scores, equal_predictions_as_incorrect: For positive & negative labels having identical scores,
we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise, we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise,
we assume that they are correct prediction (i.e weight = 0). we assume that they are correct prediction (i.e weight = 0).
Returns:
torch.Tensor: The computed AUROC
""" """
dim = 0 dim = 0
@ -52,22 +55,32 @@ def _compute_helper(
class AUROCWithMWU(torchmetrics.Metric): class AUROCWithMWU(torchmetrics.Metric):
""" """
AUROC using Mann-Whitney U-test. AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.
This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve. See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return Note: Currently, this implementation only supports binary classification.
the correct AUROC even if the predicted probabilities are all close to 0.
Currently only support binary classification. Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
""" """
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs): def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
""" """
Initializes the AUROCWithMWU metric.
Args: Args:
label_threshold: Labels strictly above this threshold are considered positive labels, label_threshold (float): Threshold for classifying labels as positive or negative.
otherwise, they are considered negative. Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class: If True, an error will be raise if negative or positive class is missing. raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, we will simply log a warning. Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric. **kwargs: Additional parameters supported by all torchmetrics.Metric.
""" """
super().__init__(**kwargs) super().__init__(**kwargs)

View File

@ -14,7 +14,24 @@ from absl import logging as logging
def setup_absl_logging(): def setup_absl_logging():
"""Make sure that absl logging pushes to stdout rather than stderr.""" """
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
This function ensures that log messages generated by the absl-py library are written to stdout
rather than stderr. It also applies a custom log message format that includes module, function,
line number, log level, and the log message content.
Note:
This function should be called once at the beginning of your script or application to
configure absl-py logging.
Example:
To use this function, simply call it at the start of your script:
```
setup_absl_logging()
```
"""
logging.get_absl_handler().python_handler.stream = sys.stdout logging.get_absl_handler().python_handler.stream = sys.stdout
formatter = py_logging.Formatter( formatter = py_logging.Formatter(
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"

View File

@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
class Testtlogging(unittest.TestCase): class Testtlogging(unittest.TestCase):
def test_warn_once(self): def test_warn_once(self):
"""
Test that warning messages are logged only once when using the assertLogs context manager.
This unit test checks the behavior of the logging system when warning messages are issued
multiple times within the same context. It uses the assertLogs context manager to capture
log messages at the INFO level and verifies that warning messages are logged only once.
Example:
To use this test case, call it using a test runner like unittest:
```
python -m unittest your_test_module.TestLogging.test_warn_once
```
"""
with self.assertLogs(level="INFO") as captured_logs: with self.assertLogs(level="INFO") as captured_logs:
logging.info("first info") logging.info("first info")
logging.warning("first warning") logging.warning("first warning")

View File

@ -18,7 +18,35 @@ import torch.distributed as dist
def rank_specific(logger): def rank_specific(logger):
"""Ensures that we only override a given logger once.""" """
Customize logger behavior based on the distributed environment and rank.
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
depending on the rank or limit the number of redundant logs.
Args:
logger: The logger object to customize.
Returns:
The customized logger.
Example:
To use this function with the `logging` module:
```python
import logging
from rank_specific_logging import rank_specific
logger = logging.getLogger(__name__)
rank_specific(logger)
```
Customization:
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
- The 'warning' method is limited to logging a single redundant warning.
- Logging from rank -1 is redirected to include the rank information.
"""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger return logger

View File

@ -54,12 +54,20 @@ def maybe_shard_model(
model, model,
device: torch.device, device: torch.device,
): ):
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment. """
Set up and apply DistributedModelParallel to a model if running in a distributed environment.
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
DistributedModelParallel. DistributedModelParallel.
If not in a distributed environment, returns model directly. If not in a distributed environment, returns the model directly.
Args:
model: The PyTorch model.
device: The target device (e.g., 'cuda').
Returns:
The model wrapped with DistributedModelParallel if in a distributed environment, else the original model.
""" """
if dist.is_initialized(): if dist.is_initialized():
logging.info("***** Wrapping in DistributedModelParallel *****") logging.info("***** Wrapping in DistributedModelParallel *****")
@ -74,13 +82,14 @@ def maybe_shard_model(
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None: def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
"""Handy function to log the content of EBC embedding layer. """
Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer.
Only works for single GPU machines. Only works for single GPU machines.
Args: Args:
weight_name: name of tensor, as defined in model weight_name: Name of the tensor, as defined in the model.
table_name: name of the EBC table the weight is taken from table_name: Name of the EBC table the weight is taken from.
weight_tensor: embedding weight tensor weight_tensor: Embedding weight tensor.
""" """
logging.info(f"{weight_name}, {table_name}", rank=-1) logging.info(f"{weight_name}, {table_name}", rank=-1)
logging.info(f"{weight_tensor.metadata()}", rank=-1) logging.info(f"{weight_tensor.metadata()}", rank=-1)

View File

@ -8,11 +8,60 @@ import pydantic
class PiecewiseConstant(base_config.BaseConfig): class PiecewiseConstant(base_config.BaseConfig):
"""
Configuration for a piecewise constant learning rate schedule.
This configuration class allows you to specify a piecewise constant learning rate schedule
by defining boundaries and corresponding learning rate values.
Attributes:
learning_rate_boundaries (List[int], optional): List of step boundaries at which
the learning rate will change. If None, no boundaries are defined.
learning_rate_values (List[float], optional): List of learning rate values
corresponding to the boundaries. If None, no values are defined.
Example:
To configure a piecewise constant learning rate schedule, create an instance of this class
and set the attributes accordingly. For example:
```python
piecewise_lr = PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
```
Note:
The number of learning rate values should be one more than the number of boundaries.
"""
learning_rate_boundaries: typing.List[int] = pydantic.Field(None) learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
learning_rate_values: typing.List[float] = pydantic.Field(None) learning_rate_values: typing.List[float] = pydantic.Field(None)
class LinearRampToConstant(base_config.BaseConfig): class LinearRampToConstant(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to constant learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero to a constant value over a specified number of steps.
Attributes:
learning_rate (float): The final constant learning rate.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
Example:
To configure a linear ramp-up to a constant learning rate, create an instance of this class
and set the attributes accordingly. For example:
```python
linear_ramp_lr = LinearRampToConstant(
learning_rate=0.1,
num_ramp_steps=1000
)
```
"""
learning_rate: float learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field( num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
description="Number of steps to ramp this up from zero." description="Number of steps to ramp this up from zero."
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
class LinearRampToCosine(base_config.BaseConfig): class LinearRampToCosine(base_config.BaseConfig):
"""
Configuration for a linear ramp-up to cosine decay learning rate schedule.
This configuration class allows you to specify a learning rate schedule that ramps up linearly
from zero, then decays following a cosine schedule to a final constant learning rate.
Attributes:
learning_rate (float): The initial learning rate at the start of ramp-up.
final_learning_rate (float): The final constant learning rate after decay.
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
final_num_steps (PositiveInt): Final number of steps where decay stops.
Example:
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
class and set the attributes accordingly. For example:
```python
ramp_to_cosine_lr = LinearRampToCosine(
learning_rate=0.01,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
```
"""
learning_rate: float learning_rate: float
final_learning_rate: float final_learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field( num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
class LearningRate(base_config.BaseConfig): class LearningRate(base_config.BaseConfig):
"""
Learning rate configuration for training.
This configuration class allows you to specify different learning rate schedules
for your training process.
Attributes:
constant (float, optional): Constant learning rate to be used throughout training.
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
and then decays following a cosine schedule.
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
linearly and then remains constant.
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
boundaries with corresponding values.
Example:
To configure a learning rate schedule, create an instance of this class and set the
attributes accordingly. For example:
```python
learning_rate = LearningRate(
constant=0.01,
linear_ramp_to_cosine=LinearRampToCosine(
learning_rate=0.1,
final_learning_rate=0.001,
num_ramp_steps=1000,
final_num_steps=5000
)
)
```
Note:
Each learning rate schedule attribute can be set to `None` if not needed.
"""
constant: float = pydantic.Field(None, one_of="lr") constant: float = pydantic.Field(None, one_of="lr")
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr") linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr") linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
class OptimizerAlgorithmConfig(base_config.BaseConfig): class OptimizerAlgorithmConfig(base_config.BaseConfig):
"""Base class for optimizer configurations.""" """
Base class for optimizer configurations.
This base configuration class provides a structure for specifying various optimizer-related
settings, including the learning rate and different learning rate schedules.
Attributes:
lr (float): The base learning rate used by the optimizer.
Subclasses should inherit from this base class and define additional attributes specific to
the optimizer algorithm they represent.
Example:
To create a custom optimizer configuration, create a subclass of this base class and
define the necessary attributes. For example:
```python
class MyOptimizerConfig(OptimizerAlgorithmConfig):
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
```
Note:
This base class does not include specific optimizer settings. Subclasses should define
the optimizer-specific attributes as needed.
"""
lr: float lr: float
... ...
class AdamConfig(OptimizerAlgorithmConfig): class AdamConfig(OptimizerAlgorithmConfig):
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam """
Configuration for the Adam optimizer.
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
Attributes:
lr (float): The learning rate for optimization.
betas (Tuple[float, float], optional): Coefficients used for computing running averages
of gradient and squared gradient. Defaults to (0.9, 0.999).
eps (float, optional): A small constant added to the denominator for numerical stability.
Defaults to 1e-7.
Example:
To configure the Adam optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
adam_optimizer = AdamConfig(
lr=0.001,
betas=(0.9, 0.999),
eps=1e-8
)
```
See Also:
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
"""
lr: float lr: float
betas: typing.Tuple[float, float] = [0.9, 0.999] betas: typing.Tuple[float, float] = [0.9, 0.999]
eps: float = 1e-7 # Numerical stability in denominator. eps: float = 1e-7 # Numerical stability in denominator.
class SgdConfig(OptimizerAlgorithmConfig): class SgdConfig(OptimizerAlgorithmConfig):
"""
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
Attributes:
lr (float): The learning rate for optimization.
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
Example:
To configure the SGD optimizer, create an instance of this class and set the attributes
accordingly. For example:
```python
sgd_optimizer = SgdConfig(
lr=0.01,
momentum=0.9
)
```
"""
lr: float lr: float
momentum: float = 0.0 momentum: float = 0.0
class AdagradConfig(OptimizerAlgorithmConfig): class AdagradConfig(OptimizerAlgorithmConfig):
"""
Configuration for the optimizer used during training.
This configuration class allows you to specify the optimizer for training, including
options for various optimizer algorithms.
Attributes:
learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None.
adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None.
sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
Defaults to None.
adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None.
Example:
To configure the optimizer for training, create an instance of this class and set the
attributes accordingly. For example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
```
"""
lr: float lr: float
eps: float = 0 eps: float = 0
class OptimizerConfig(base_config.BaseConfig): class OptimizerConfig(base_config.BaseConfig):
"""
Configuration for defining different optimizer algorithms and their parameters.
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
along with their respective hyperparameters.
Args:
learning_rate (LearningRate): The learning rate configuration, which can include
constant learning rates or other learning rate schedules.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Example:
```python
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.001),
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
)
```
Attributes:
learning_rate (LearningRate): The learning rate configuration.
adam (AdamConfig): Configuration for the Adam optimizer.
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
Note:
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
`OptimizerConfig` instance.
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `AdamConfig`: Configuration for the Adam optimizer.
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
- `AdagradConfig`: Configuration for the Adagrad optimizer.
"""
learning_rate: LearningRate = pydantic.Field( learning_rate: LearningRate = pydantic.Field(
None, None,
description="Constant learning rates", description="Constant learning rates",
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
"""
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
This function extracts and returns the specific optimizer algorithm configuration
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
Args:
optimizer_config (OptimizerConfig): The optimizer configuration object containing
one of the optimizer algorithm configurations.
Returns:
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
configuration extracted from `optimizer_config`.
Raises:
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
Example:
```python
optimizer_config = OptimizerConfig(
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
)
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
# `algorithm_config` will be an instance of `AdamConfig`.
```
"""
if optimizer_config.adam is not None: if optimizer_config.adam is not None:
return optimizer_config.adam return optimizer_config.adam
elif optimizer_config.sgd is not None: elif optimizer_config.sgd is not None:

View File

@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step): def compute_lr(lr_config, step):
"""Compute a learning rate.""" """
Compute the learning rate based on the specified learning rate configuration.
This function calculates the learning rate according to the given configuration, which can include
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
Args:
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
step (int): The current training step or iteration.
Returns:
float: The computed learning rate for the current step.
Raises:
ValueError: If the `lr_config` is invalid or contains conflicting options.
Example:
```python
lr_schedule = LearningRate(
constant=0.001,
piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000, 3000],
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
)
)
current_step = 2500
learning_rate = compute_lr(lr_schedule, current_step)
```
"""
if lr_config.constant is not None: if lr_config.constant is not None:
return lr_config.constant return lr_config.constant
elif lr_config.piecewise_constant is not None: elif lr_config.piecewise_constant is not None:
@ -46,10 +74,53 @@ def compute_lr(lr_config, step):
class LRShim(_LRScheduler): class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler. """
Learning Rate Scheduler Shim to adjust learning rates during training.
This adheres to the torch.optim scheduler API and can be plugged anywhere that This class acts as a shim to apply different learning rates to individual parameter groups
e.g. exponential decay can be used. within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
optimizers, allowing fine-grained control over learning rates based on configuration.
Args:
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
corresponding learning rate configurations.
last_epoch (int, optional): The index of the last epoch. Default is -1.
verbose (bool, optional): If True, prints a warning message when accessing learning rates
using the deprecated `get_lr()` method. Default is False.
Raises:
ValueError: If the number of parameter groups in the optimizer does not match the number
of learning rate configurations provided.
Note:
To obtain the last computed learning rates, please use `get_last_lr()`.
Example:
```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
lr_schedule = {
'main': LearningRate(constant=0.01),
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
learning_rate_boundaries=[1000, 2000],
learning_rate_values=[0.01, 0.001]
))
}
lr_shim = LRShim(optimizer, lr_schedule)
for epoch in range(num_epochs):
# Train the model
train(...)
# Update learning rates at the end of each epoch
lr_shim.step(epoch)
final_lr_main = lr_shim.get_last_lr()['main']
final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary']
```
See Also:
- `LearningRate`: Configuration for specifying learning rates.
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
""" """
def __init__( def __init__(
@ -95,8 +166,41 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
def build_optimizer( def build_optimizer(
model: torch.nn.Module, optimizer_config: OptimizerConfig model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]: ) -> Tuple[Optimizer, _LRScheduler]:
"""Builds an optimizer and LR scheduler from an OptimizerConfig. """
Note: use this when you want the same optimizer and learning rate schedule for all your parameters. Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
Args:
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
algorithm and learning rate settings.
Returns:
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
objects.
Note:
This function is intended for cases where you want the same optimizer and learning rate
schedule for all model parameters.
Example:
```python
model = MyModel()
optimizer_config = OptimizerConfig(
learning_rate=LearningRate(constant=0.01),
sgd=SgdConfig(lr=0.01, momentum=0.9)
)
optimizer, scheduler = build_optimizer(model, optimizer_config)
for epoch in range(num_epochs):
# Train the model with the optimizer
train(model, optimizer, ...)
# Update learning rates at the end of each epoch
scheduler.step(epoch)
```
See Also:
- `OptimizerConfig`: Configuration for specifying optimizer settings.
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
""" """
optimizer_class = get_optimizer_class(optimizer_config) optimizer_class = get_optimizer_class(optimizer_config)
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict()) optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())

View File

@ -9,6 +9,36 @@ import pydantic
class TrainingConfig(config_mod.BaseConfig): class TrainingConfig(config_mod.BaseConfig):
"""
Configuration settings for the training process.
This class defines various training-related settings, including the directory to save checkpoints, the number
of training steps, logging intervals, and other training parameters.
Attributes:
save_dir (str): The directory where checkpoints and training artifacts will be saved.
num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
Example:
To configure training with checkpoints saved every 1000 steps, use the following settings:
```python
TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
)
```
"""
save_dir: str = "/tmp/model" save_dir: str = "/tmp/model"
num_train_steps: pydantic.PositiveInt = 1000000 num_train_steps: pydantic.PositiveInt = 1000000
initial_checkpoint_dir: str = pydantic.Field( initial_checkpoint_dir: str = pydantic.Field(
@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):
class RecapConfig(config_mod.BaseConfig): class RecapConfig(config_mod.BaseConfig):
"""
Configuration settings for the Recap model training process.
This class defines the overall configuration for the training process of a Recap model. It includes settings for
training, model architecture, data, optimization, and evaluation.
Attributes:
training (TrainingConfig): Configuration settings for the training process.
model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
train_data (data_config.RecapDataConfig): Configuration settings for training data.
validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
which_metrics (Optional[str]): Optional specification of which metrics to pick.
Note:
This class encapsulates all the necessary configurations to train a Recap model. It defines settings for
training, the model architecture, data loading, optimization, and evaluation.
Example:
To configure a Recap model training process, use the following settings:
```python
RecapConfig(
training=TrainingConfig(
save_dir="/tmp/model",
num_train_steps=1000000,
checkpoint_every_n=1000,
train_log_every_n=1000,
),
model=model_config.ModelConfig(...),
train_data=data_config.RecapDataConfig(...),
validation_data={"dev": data_config.RecapDataConfig(...)},
optimizer=optimizer_config.RecapOptimizerConfig(...),
)
```
"""
training: TrainingConfig = pydantic.Field(TrainingConfig()) training: TrainingConfig = pydantic.Field(TrainingConfig())
model: model_config.ModelConfig model: model_config.ModelConfig
train_data: data_config.RecapDataConfig train_data: data_config.RecapDataConfig

View File

@ -11,8 +11,10 @@ class ExplicitDateInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_date and days of data.""" """Arguments to select train/validation data using end_date and days of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.") data_root: str = pydantic.Field(..., description="Data path prefix.")
end_date: str = pydantic.Field(..., description="Data end date, inclusive.") end_date: str = pydantic.Field(...,
days: int = pydantic.Field(..., description="Number of days of data for dataset.") description="Data end date, inclusive.")
days: int = pydantic.Field(...,
description="Number of days of data for dataset.")
num_missing_days_tol: int = pydantic.Field( num_missing_days_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_days_tol days of missing data." 0, description="We tolerate <= num_missing_days_tol days of missing data."
) )
@ -22,8 +24,10 @@ class ExplicitDatetimeInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_datetime and hours of data.""" """Arguments to select train/validation data using end_datetime and hours of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.") data_root: str = pydantic.Field(..., description="Data path prefix.")
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.") end_datetime: str = pydantic.Field(...,
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.") description="Data end datetime, inclusive.")
hours: int = pydantic.Field(...,
description="Number of hours of data for dataset.")
num_missing_hours_tol: int = pydantic.Field( num_missing_hours_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_hours_tol hours of missing data." 0, description="We tolerate <= num_missing_hours_tol hours of missing data."
) )
@ -42,7 +46,8 @@ class DatasetConfig(base_config.BaseConfig):
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
None, one_of="date_inputs_format" None, one_of="date_inputs_format"
) )
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format") explicit_date_inputs: ExplicitDateInputs = pydantic.Field(
None, one_of="date_inputs_format")
global_batch_size: pydantic.PositiveInt global_batch_size: pydantic.PositiveInt
@ -52,7 +57,8 @@ class DatasetConfig(base_config.BaseConfig):
repeat_files: bool = pydantic.Field( repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to." True, description="DEPRICATED. Files are repeated no matter what this is set to."
) )
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size") file_batch_size: pydantic.PositiveInt = pydantic.Field(
16, description="File batch size")
cache: bool = pydantic.Field( cache: bool = pydantic.Field(
False, False,
@ -70,7 +76,8 @@ class DatasetConfig(base_config.BaseConfig):
) )
# tf.data.Dataset options # tf.data.Dataset options
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.") examples_shuffle_buffer_size: int = pydantic.Field(
1024, description="Size of shuffle buffers.")
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of parallel calls." None, description="Number of parallel calls."
) )
@ -125,7 +132,8 @@ class TaskData(base_config.BaseConfig):
class SegDenseSchema(base_config.BaseConfig): class SegDenseSchema(base_config.BaseConfig):
schema_path: str = pydantic.Field(..., description="Path to feature config json.") schema_path: str = pydantic.Field(...,
description="Path to feature config json.")
features: typing.List[str] = pydantic.Field( features: typing.List[str] = pydantic.Field(
[], [],
description="List of features (in addition to the renamed features) to read from schema path above.", description="List of features (in addition to the renamed features) to read from schema path above.",
@ -192,8 +200,10 @@ class DownsampleNegatives(base_config.BaseConfig):
class Preprocess(base_config.BaseConfig): class Preprocess(base_config.BaseConfig):
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.") truncate_and_slice: TruncateAndSlice = pydantic.Field(
downcast: DownCast = pydantic.Field(None, description="Down cast to features.") None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(
None, description="Down cast to features.")
rectify_labels: RectifyLabels = pydantic.Field( rectify_labels: RectifyLabels = pydantic.Field(
None, description="Rectify labels for a given overlap window" None, description="Rectify labels for a given overlap window"
) )
@ -242,5 +252,6 @@ class RecapDataConfig(DatasetConfig):
if values.get("evaluation_tasks") is not None: if values.get("evaluation_tasks") is not None:
for task in values["evaluation_tasks"]: for task in values["evaluation_tasks"]:
if task not in values["tasks"]: if task not in values["tasks"]:
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}") raise KeyError(
f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
return values return values

View File

@ -43,7 +43,8 @@ class RecapBatch(DataclassBatch):
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch: def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
"""Converts a torch data loader output into `RecapBatch`.""" """Converts a torch data loader output into `RecapBatch`."""
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x) x = tf.nest.map_structure(functools.partial(
sparse_or_dense_tf_to_torch, pin_memory=False), x)
try: try:
features_in, labels = x features_in, labels = x
except ValueError: except ValueError:
@ -53,25 +54,29 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
sparse_features = keyed_jagged_tensor_from_tensors_dict({}) sparse_features = keyed_jagged_tensor_from_tensors_dict({})
if sparse_feature_names: if sparse_feature_names:
sparse_features = keyed_jagged_tensor_from_tensors_dict( sparse_features = keyed_jagged_tensor_from_tensors_dict(
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names} {embedding_name: features_in[embedding_name]
for embedding_name in sparse_feature_names}
) )
user_embedding, user_eng_embedding, author_embedding = None, None, None user_embedding, user_eng_embedding, author_embedding = None, None, None
if "user_embedding" in features_in: if "user_embedding" in features_in:
if sparse_feature_names and "meta__user_id" in sparse_feature_names: if sparse_feature_names and "meta__user_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported") raise ValueError(
"Only one source of embedding for user is supported")
else: else:
user_embedding = features_in["user_embedding"] user_embedding = features_in["user_embedding"]
if "user_eng_embedding" in features_in: if "user_eng_embedding" in features_in:
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names: if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported") raise ValueError(
"Only one source of embedding for user is supported")
else: else:
user_eng_embedding = features_in["user_eng_embedding"] user_eng_embedding = features_in["user_eng_embedding"]
if "author_embedding" in features_in: if "author_embedding" in features_in:
if sparse_feature_names and "meta__author_id" in sparse_feature_names: if sparse_feature_names and "meta__author_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported") raise ValueError(
"Only one source of embedding for user is supported")
else: else:
author_embedding = features_in["author_embedding"] author_embedding = features_in["author_embedding"]
@ -84,7 +89,8 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
user_eng_embedding=user_eng_embedding, user_eng_embedding=user_eng_embedding,
author_embedding=author_embedding, author_embedding=author_embedding,
labels=labels, labels=labels,
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels) # Defaults to torch.ones_like(labels)
weights=features_in.get("weights", None),
) )
@ -119,18 +125,21 @@ def _add_weights(inputs, tasks: Dict[str, TaskData]):
float_label = tf.cast(label, tf.float32) float_label = tf.cast(label, tf.float32)
weights.append( weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate float_label / task.pos_downsampling_rate +
(1.0 - float_label) / task.neg_downsampling_rate
) )
# Ensure we are batch-major (assumes we batch before this call). # Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0) inputs["weights"] = tf.squeeze(tf.transpose(
tf.convert_to_tensor(weights)), axis=0)
return inputs return inputs
def get_datetimes(explicit_datetime_inputs): def get_datetimes(explicit_datetime_inputs):
"""Compute list datetime strings for train/validation data.""" """Compute list datetime strings for train/validation data."""
datetime_format = "%Y/%m/%d/%H" datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format) end = datetime.strptime(
explicit_datetime_inputs.end_datetime, datetime_format)
dates = sorted( dates = sorted(
[ [
(end - timedelta(hours=i + 1)).strftime(datetime_format) (end - timedelta(hours=i + 1)).strftime(datetime_format)
@ -158,7 +167,8 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
""" """
datetimes = get_datetimes(explicit_datetime_inputs) datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes] folders = [os.path.join(explicit_datetime_inputs.data_root, datetime)
for datetime in datetimes]
data_files = [] data_files = []
num_hours_missing = 0 num_hours_missing = 0
for folder in folders: for folder in folders:
@ -176,20 +186,34 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
) )
except tf.errors.NotFoundError as e: except tf.errors.NotFoundError as e:
num_hours_missing += 1 num_hours_missing += 1
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}") logging.warning(
f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
return sorted(data_files), num_hours_missing return sorted(data_files), num_hours_missing
def _map_output_for_inference( def _map_output_for_inference(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
): ):
"""
Map the input data for inference.
Args:
inputs (dict): Input data dictionary.
tasks (Dict[str, TaskData]): Dictionary of task data configurations.
preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
add_weights (bool, optional): Whether to add weights. Defaults to False.
Returns:
dict: Modified input data dictionary for inference.
"""
if preprocessor: if preprocessor:
raise ValueError("No preprocessor should be used at inference time.") raise ValueError("No preprocessor should be used at inference time.")
if add_weights: if add_weights:
raise NotImplementedError() raise NotImplementedError()
# Add zero weights. # Add zero weights.
inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1)) inputs["weights"] = tf.zeros_like(
tf.expand_dims(inputs["continuous"][:, 0], -1))
for label in tasks: for label in tasks:
del inputs[label] del inputs[label]
return inputs return inputs
@ -198,6 +222,18 @@ def _map_output_for_inference(
def _map_output_for_train_eval( def _map_output_for_train_eval(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
): ):
"""
Map the input data for training and evaluation.
Args:
inputs (dict): Input data dictionary.
tasks (Dict[str, TaskData]): Dictionary of task data configurations.
preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
add_weights (bool, optional): Whether to add weights. Defaults to False.
Returns:
Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation.
"""
if add_weights: if add_weights:
inputs = _add_weights_based_on_sampling_rates(inputs, tasks) inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
@ -205,7 +241,8 @@ def _map_output_for_train_eval(
if preprocessor: if preprocessor:
inputs = preprocessor(inputs) inputs = preprocessor(inputs)
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1]) label_values = tf.squeeze(
tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
for label in tasks: for label in tasks:
del inputs[label] del inputs[label]
@ -231,11 +268,13 @@ def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
float_label = tf.cast(label, tf.float32) float_label = tf.cast(label, tf.float32)
weights.append( weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate float_label / task.pos_downsampling_rate +
(1.0 - float_label) / task.neg_downsampling_rate
) )
# Ensure we are batch-major (assumes we batch before this call). # Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0) inputs["weights"] = tf.squeeze(tf.transpose(
tf.convert_to_tensor(weights)), axis=0)
return inputs return inputs
@ -249,6 +288,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
repeat: bool = False, repeat: bool = False,
vocab_mapper: tf.keras.Model = None, vocab_mapper: tf.keras.Model = None,
): ):
"""
Create a RecapDataset for training or inference.
Args:
data_config (RecapDataConfig): Data configuration.
dataset_service (str, optional): Dataset service identifier. Defaults to None.
mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN.
compression (str, optional): Compression type. Defaults to "AUTO".
repeat (bool, optional): Whether to repeat the dataset. Defaults to False.
vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None.
"""
logging.info("***** Labels *****") logging.info("***** Labels *****")
logging.info(list(data_config.tasks.keys())) logging.info(list(data_config.tasks.keys()))
@ -263,7 +313,8 @@ class RecapDataset(torch.utils.data.IterableDataset):
preprocessor = None preprocessor = None
self._batch_size_multiplier = 1 self._batch_size_multiplier = 1
if data_config.preprocess: if data_config.preprocess:
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode) preprocessor = preprocessors.build_preprocess(
data_config.preprocess, mode=mode)
if data_config.preprocess.downsample_negatives: if data_config.preprocess.downsample_negatives:
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
@ -295,13 +346,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
add_weights=should_add_weights, add_weights=should_add_weights,
) )
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None sparse_feature_names = list(
vocab_mapper.vocabs.keys()) if vocab_mapper else None
self._tf_dataset = self._create_tf_dataset() self._tf_dataset = self._create_tf_dataset()
self._init_tensor_spec() self._init_tensor_spec()
def _init_tensor_spec(self): def _init_tensor_spec(self):
"""
Initialize the tensor specification for the dataset.
"""
def _tensor_spec_to_torch_shape(spec): def _tensor_spec_to_torch_shape(spec):
if spec.shape is None: if spec.shape is None:
return None return None
@ -313,8 +368,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
) )
def _create_tf_dataset(self): def _create_tf_dataset(self):
"""
Create a TensorFlow dataset from the data files.
Returns:
tf.data.Dataset: TensorFlow dataset.
"""
if hasattr(self, "_tf_dataset"): if hasattr(self, "_tf_dataset"):
raise ValueError("Do not call `_create_tf_dataset` more than once.") raise ValueError(
"Do not call `_create_tf_dataset` more than once.")
world_size = dist.get_world_size() if dist.is_initialized() else 1 world_size = dist.get_world_size() if dist.is_initialized() else 1
per_replica_bsz = ( per_replica_bsz = (
@ -371,6 +433,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
return dataset.prefetch(world_size * 2) return dataset.prefetch(world_size * 2)
def _create_base_tf_dataset(self, batch_size: int): def _create_base_tf_dataset(self, batch_size: int):
"""
Create the base TensorFlow dataset.
Args:
batch_size (int): Batch size.
Returns:
tf.data.Dataset: Base TensorFlow dataset.
"""
if self._data_config.inputs: if self._data_config.inputs:
glob = self._data_config.inputs glob = self._data_config.inputs
filenames = sorted(tf.io.gfile.glob(glob)) filenames = sorted(tf.io.gfile.glob(glob))
@ -411,16 +482,27 @@ class RecapDataset(torch.utils.data.IterableDataset):
logging.info(f"Retaining only {len(filenames)} files.") logging.info(f"Retaining only {len(filenames)} files.")
filenames_ds = ( filenames_ds = (
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames)) tf.data.Dataset.from_tensor_slices(
filenames).shuffle(len(filenames))
# Because of drop_remainder, if our dataset does not fill # Because of drop_remainder, if our dataset does not fill
# up a batch, it will emit nothing without this repeat. # up a batch, it will emit nothing without this repeat.
.repeat(-1) .repeat(-1)
) )
if self._data_config.file_batch_size: if self._data_config.file_batch_size:
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size) filenames_ds = filenames_ds.batch(
self._data_config.file_batch_size)
def per_shard_dataset(filename): def per_shard_dataset(filename):
"""
Create a TensorFlow dataset for a single shard file.
Args:
filename (str): Path to the shard file.
Returns:
tf.data.Dataset: TensorFlow dataset for the shard file.
"""
ds = tf.data.TFRecordDataset([filename], compression_type="GZIP") ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
return ds.prefetch(4) return ds.prefetch(4)
@ -471,6 +553,12 @@ class RecapDataset(torch.utils.data.IterableDataset):
yield to_batch(x) yield to_batch(x)
def to_dataloader(self) -> Dict[str, torch.Tensor]: def to_dataloader(self) -> Dict[str, torch.Tensor]:
"""
Convert the RecapDataset to a PyTorch DataLoader.
Returns:
torch.utils.data.DataLoader: PyTorch DataLoader for the dataset.
"""
return torch.utils.data.DataLoader(self, batch_size=None) return torch.utils.data.DataLoader(self, batch_size=None)
def __iter__(self): def __iter__(self):

View File

@ -17,6 +17,16 @@ FLAGS = flags.FLAGS
def _generate_random_example( def _generate_random_example(
tf_example_schema: Dict[str, tf.io.FixedLenFeature] tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]: ) -> Dict[str, tf.Tensor]:
"""
Generate a random example based on the provided TensorFlow example schema.
Args:
tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example.
Returns:
Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema.
"""
example = {} example = {}
for feature_name, feature_spec in tf_example_schema.items(): for feature_name, feature_spec in tf_example_schema.items():
dtype = feature_spec.dtype dtype = feature_spec.dtype
@ -33,14 +43,43 @@ def _generate_random_example(
def _float_feature(value): def _float_feature(value):
"""
Create a TensorFlow float feature.
Args:
value: A float or list of floats.
Returns:
tf.train.Feature: A TensorFlow feature containing the float value(s).
"""
return tf.train.Feature(float_list=tf.train.FloatList(value=value)) return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value): def _int64_feature(value):
"""
Create a TensorFlow int64 feature.
Args:
value: An integer or list of integers.
Returns:
tf.train.Feature: A TensorFlow feature containing the int64 value(s).
"""
return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes: def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
"""
Serialize a dictionary of TensorFlow tensors into a binary string.
Args:
x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors.
Returns:
bytes: The serialized binary string.
"""
feature = {} feature = {}
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature} serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
for feature_name, tensor in x.items(): for feature_name, tensor in x.items():
@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
def generate_data(data_path: str, config: recap_config_mod.RecapConfig): def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
"""
Generate random data based on the provided configuration and save it as a TFRecord file.
Args:
data_path (str): The path where the TFRecord file will be saved.
config (recap_config_mod.RecapConfig): The configuration for generating the random data.
"""
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f: with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"] seg_dense_schema = json.load(f)["schema"]
@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
def _generate_data_main(unused_argv): def _generate_data_main(unused_argv):
"""
Main function to generate random data according to the provided configuration.
Args:
unused_argv: Unused command-line arguments.
"""
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
# Find the path where to put the data # Find the path where to put the data

View File

@ -9,9 +9,20 @@ import numpy as np
class TruncateAndSlice(tf.keras.Model): class TruncateAndSlice(tf.keras.Model):
"""Class for truncating and slicing.""" """
A class for truncating and slicing input features based on the provided configuration.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
def __init__(self, truncate_and_slice_config): def __init__(self, truncate_and_slice_config):
"""
Initializes the TruncateAndSlice model.
Args:
truncate_and_slice_config: A configuration object specifying how to truncate and slice features.
"""
super().__init__() super().__init__()
self._truncate_and_slice_config = truncate_and_slice_config self._truncate_and_slice_config = truncate_and_slice_config
@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model):
self._binary_mask = None self._binary_mask = None
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Applies truncation and slicing to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of truncated and sliced input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
if self._truncate_and_slice_config.continuous_feature_truncation: if self._truncate_and_slice_config.continuous_feature_truncation:
logging.info("Truncating continuous") logging.info("Truncating continuous")
@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model):
class DownCast(tf.keras.Model): class DownCast(tf.keras.Model):
"""Class for Down casting dataset before serialization and transferring to training host. """
Depends on the data type and the actual data range, the down casting can be lossless or not. A class for downcasting dataset before serialization and transferring to the training host.
It is strongly recommended to compare the metrics before and after down casting.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
""" """
def __init__(self, downcast_config): def __init__(self, downcast_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__() super().__init__()
self.config = downcast_config self.config = downcast_config
self._type_map = { self._type_map = {
@ -65,6 +98,17 @@ class DownCast(tf.keras.Model):
} }
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
for feature, type_str in self.config.features.items(): for feature, type_str in self.config.features.items():
assert type_str in self._type_map assert type_str in self._type_map
@ -78,14 +122,39 @@ class DownCast(tf.keras.Model):
class RectifyLabels(tf.keras.Model): class RectifyLabels(tf.keras.Model):
"""Class for rectifying labels""" """
A class for downcasting dataset before serialization and transferring to the training host.
Depending on the data type and the actual data range, the downcasting can be lossless or not.
It is strongly recommended to compare the metrics before and after downcasting.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
def __init__(self, rectify_label_config): def __init__(self, rectify_label_config):
"""
Initializes the DownCast model.
Args:
downcast_config: A configuration object specifying the features and their target data types.
"""
super().__init__() super().__init__()
self._config = rectify_label_config self._config = rectify_label_config
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000) self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Applies downcasting to the input features based on the configuration.
Args:
inputs: A dictionary of input features.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of downcasted input features.
"""
served_ts_field = self._config.served_timestamp_field served_ts_field = self._config.served_timestamp_field
impressed_ts_field = self._config.impressed_timestamp_field impressed_ts_field = self._config.impressed_timestamp_field
@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model):
class ExtractFeatures(tf.keras.Model): class ExtractFeatures(tf.keras.Model):
"""Class for extracting individual features from dense tensors by their index.""" """
A class for rectifying labels based on specified conditions.
This class is used to adjust label values in a dataset based on configured conditions involving timestamps.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
"""
def __init__(self, extract_features_config): def __init__(self, extract_features_config):
"""
Initializes the RectifyLabels model.
Args:
rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings.
"""
super().__init__() super().__init__()
self._config = extract_features_config self._config = extract_features_config
def call(self, inputs, training=None, mask=None): def call(self, inputs, training=None, mask=None):
"""
Rectifies label values based on the specified conditions.
Args:
inputs: A dictionary of input features including timestamp fields and labels.
training: A boolean indicating whether the model is in training mode.
mask: A mask tensor.
Returns:
A dictionary of input features with rectified label values.
"""
for row in self._config.extract_feature_table: for row in self._config.extract_feature_table:
inputs[row.name] = inputs[row.source_tensor][:, row.index] inputs[row.name] = inputs[row.source_tensor][:, row.index]
@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model):
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN): def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
"""Builds a preprocess model to apply all preprocessing stages.""" """
Builds a preprocess model to apply all preprocessing stages.
Args:
preprocess_config: A configuration object specifying the preprocessing parameters.
mode: A mode indicating the current job mode (TRAIN or INFERENCE).
Returns:
A preprocess model that applies all specified preprocessing stages.
"""
if mode == config_mod.JobMode.INFERENCE: if mode == config_mod.JobMode.INFERENCE:
logging.info("Not building preprocessors for dataloading since we are in Inference mode.") logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
return None return None

View File

@ -8,7 +8,8 @@ import tensorflow as tf
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""} DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string} DTYPE_MAP = {"int64_list": tf.int64,
"float_list": tf.float32, "bytes_list": tf.string}
def create_tf_example_schema( def create_tf_example_schema(
@ -27,7 +28,8 @@ def create_tf_example_schema(
segdense_config = data_config.seg_dense_schema segdense_config = data_config.seg_dense_schema
labels = list(data_config.tasks.keys()) labels = list(data_config.tasks.keys())
used_features = ( used_features = (
segdense_config.features + list(segdense_config.renamed_features.values()) + labels segdense_config.features +
list(segdense_config.renamed_features.values()) + labels
) )
logging.info(used_features) logging.info(used_features)
@ -40,19 +42,22 @@ def create_tf_example_schema(
dtype = entry["dtype"] dtype = entry["dtype"]
if feature_name in labels: if feature_name in labels:
logging.info(f"Label: feature name is {feature_name} type is {dtype}") logging.info(
f"Label: feature name is {feature_name} type is {dtype}")
tfe_schema[feature_name] = tf.io.FixedLenFeature( tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
) )
elif length == -1: elif length == -1:
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype]) tfe_schema[feature_name] = tf.io.VarLenFeature(
DTYPE_MAP[dtype])
else: else:
tfe_schema[feature_name] = tf.io.FixedLenFeature( tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
) )
for feature_name in used_features: for feature_name in used_features:
if feature_name not in tfe_schema: if feature_name not in tfe_schema:
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.") raise ValueError(
f"{feature_name} missing from schema: {segdense_config.schema_path}.")
return tfe_schema return tfe_schema
@ -82,7 +87,8 @@ def parse_tf_example(
Returns: Returns:
Dictionary of tensors to be used as model input. Dictionary of tensors to be used as model input.
""" """
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema) inputs = tf.io.parse_example(
serialized=serialized_example, features=tfe_schema)
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
inputs[new_feature_name] = inputs.pop(old_feature_name) inputs[new_feature_name] = inputs.pop(old_feature_name)
@ -90,7 +96,8 @@ def parse_tf_example(
# This should not actually be used except for experimentation with low precision floats. # This should not actually be used except for experimentation with low precision floats.
if "mask_mantissa_features" in seg_dense_schema_config: if "mask_mantissa_features" in seg_dense_schema_config:
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length) inputs[feature_name] = mask_mantissa(
inputs[feature_name], mask_length)
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
# at TF level. # at TF level.

View File

@ -9,44 +9,59 @@ def keyed_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor] tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor": ) -> "torchrec.KeyedTensor":
""" """
Convert a dictionary of torch tensor to torchrec keyed tensor Convert a dictionary of torch tensors to a torchrec KeyedTensor.
Args: Args:
tensor_map: tensor_map: A mapping of tensor names to torch tensors.
Returns: Returns:
A torchrec KeyedTensor.
""" """
keys = list(tensor_map.keys()) keys = list(tensor_map.keys())
# We expect batch size to be first dim. However, if we get a shape [Batch_size], # We expect batch size to be first dim. However, if we get a shape [Batch_size],
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
# [Batch_size x 1]. # [Batch_size x 1].
values = [ values = [
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1) tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(
tensor_map[key], -1)
for key in keys for key in keys
] ]
return torchrec.KeyedTensor.from_tensor_list(keys, values) return torchrec.KeyedTensor.from_tensor_list(keys, values)
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute a jagged tensor from a torch tensor.
Args:
tensor: Input torch tensor.
Returns:
A tuple containing the values and lengths of the jagged tensor.
"""
if tensor.is_sparse: if tensor.is_sparse:
x = tensor.coalesce() # Ensure that the indices are ordered. x = tensor.coalesce() # Ensure that the indices are ordered.
lengths = torch.bincount(x.indices()[0]) lengths = torch.bincount(x.indices()[0])
values = x.values() values = x.values()
else: else:
values = tensor values = tensor
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device) lengths = torch.ones(
tensor.shape[0], dtype=torch.int32, device=tensor.device)
return values, lengths return values, lengths
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor": def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
""" """
Convert a torch tensor to torchrec jagged tensor. Convert a torch tensor to a torchrec jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors.
dense_shape of the sparse tensor can be arbitrary. For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary.
Args: Args:
tensor: a torch (sparse) tensor. tensor: A torch (sparse) tensor.
Returns: Returns:
A torchrec JaggedTensor.
""" """
values, lengths = _compute_jagged_tensor_from_tensor(tensor) values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths) return torchrec.JaggedTensor(values=values, lengths=lengths)
@ -56,15 +71,16 @@ def keyed_jagged_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor] tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor": ) -> "torchrec.KeyedJaggedTensor":
""" """
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor. Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
dense_shape of the sparse tensor can be arbitrary. For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary.
Args: Args:
tensor_map: tensor_map: A mapping of tensor names to torch tensors.
Returns: Returns:
A torchrec KeyedJaggedTensor.
""" """
if not tensor_map: if not tensor_map:
@ -91,10 +107,29 @@ def keyed_jagged_tensor_from_tensors_dict(
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray: def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
"""
Convert a TensorFlow tensor to a NumPy array.
Args:
tf_tensor: TensorFlow tensor.
Returns:
NumPy array.
"""
return tf_tensor._numpy() # noqa return tf_tensor._numpy() # noqa
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor: def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
"""
Convert a dense TensorFlow tensor to a PyTorch tensor.
Args:
tensor: Dense TensorFlow tensor.
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
tensor = _tf_to_numpy(tensor) tensor = _tf_to_numpy(tensor)
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
if tensor.dtype.name == "bfloat16": if tensor.dtype.name == "bfloat16":
@ -109,6 +144,16 @@ def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
def sparse_or_dense_tf_to_torch( def sparse_or_dense_tf_to_torch(
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor: ) -> torch.Tensor:
"""
Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor.
Args:
tensor: TensorFlow tensor (sparse or dense).
pin_memory: Whether to pin the tensor in memory (for CUDA).
Returns:
PyTorch tensor.
"""
if isinstance(tensor, tf.SparseTensor): if isinstance(tensor, tf.SparseTensor):
tensor = torch.sparse_coo_tensor( tensor = torch.sparse_coo_tensor(
_dense_tf_to_torch(tensor.indices, pin_memory).t(), _dense_tf_to_torch(tensor.indices, pin_memory).t(),

View File

@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
class EmbeddingOptimizerConfig(base_config.BaseConfig): class EmbeddingOptimizerConfig(base_config.BaseConfig):
"""Configuration for the optimizer used for embedding tables."""
learning_rate: optimizer_config.LearningRate = pydantic.Field( learning_rate: optimizer_config.LearningRate = pydantic.Field(
None, description="learning rate scheduler for the EBC" None, description="learning rate scheduler for the EBC"
) )
@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):
class StratifierConfig(base_config.BaseConfig): class StratifierConfig(base_config.BaseConfig):
"""Configuration for Stratifier."""
name: str name: str
index: int index: int
value: int value: int
@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
serving time due to size (>>1 GB). serving time due to size (>>1 GB).
This small embeddings table uses the same optimizer as the rest of the model.""" This small embeddings table uses the same optimizer as the rest of the model.
"""
tables: List[SmallEmbeddingBagConfig] = pydantic.Field( tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
..., description="list of embedding tables" ..., description="list of embedding tables"

View File

@ -34,6 +34,33 @@ FLAGS = flags.FLAGS
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None): def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
"""
Main function to run the training of a ranking model.
This function initializes and runs the training process for a ranking model based on the provided configuration.
Args:
unused_argv (str): Unused argument.
data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).
Returns:
None
Raises:
AssertionError: If the configuration or input data is not valid.
Note:
This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
training loop.
Example:
To run the training process, use the following command:
```
python run_training.py --config_path=config.yaml
```
"""
print("#" * 100) print("#" * 100)
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)

View File

@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DenseLayerConfig(base_config.BaseConfig): class DenseLayerConfig(base_config.BaseConfig):
"""Configuration for the dense layer."""
layer_size: pydantic.PositiveInt layer_size: pydantic.PositiveInt
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.") dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):
class DoubleNormLogConfig(base_config.BaseConfig): class DoubleNormLogConfig(base_config.BaseConfig):
"""Configuration for the double norm log transform."""
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None) batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
clip_magnitude: float = pydantic.Field( clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values." 5.0, description="Threshold to clip the normalized input values."
@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):
class ClipLog1pAbsConfig(base_config.BaseConfig): class ClipLog1pAbsConfig(base_config.BaseConfig):
"""Configuration for the clip log transform."""
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field( clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
3e38, description="Threshold to clip the input values." 3e38, description="Threshold to clip the input values."
) )
class ZScoreLogConfig(base_config.BaseConfig): class ZScoreLogConfig(base_config.BaseConfig):
"""Configuration for the z-score log transform."""
analysis_path: str analysis_path: str
schema_path: str = pydantic.Field( schema_path: str = pydantic.Field(
None, None,
@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):
class MaskBlockConfig(base_config.BaseConfig): class MaskBlockConfig(base_config.BaseConfig):
"""Config for MaskNet block."""
output_size: int output_size: int
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field( reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
None, one_of="aggregation_size" None, one_of="aggregation_size"
@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):
class MaskNetConfig(base_config.BaseConfig): class MaskNetConfig(base_config.BaseConfig):
"""Config for MaskNet model."""
mask_blocks: List[MaskBlockConfig] mask_blocks: List[MaskBlockConfig]
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel") mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.") use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):
class DLRMConfig(base_config.BaseConfig): class DLRMConfig(base_config.BaseConfig):
"""Config for DLRM model."""
bottom_mlp: MlpConfig = pydantic.Field( bottom_mlp: MlpConfig = pydantic.Field(
..., ...,
description="Bottom mlp, the output to be combined with sparse features and feed to interaction", description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):
class TaskModel(base_config.BaseConfig): class TaskModel(base_config.BaseConfig):
"""Configuration for a single task."""
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture") mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture") dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture") dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):
class MultiTaskType(str, enum.Enum): class MultiTaskType(str, enum.Enum):
"""Type of multi task architecture."""
SHARE_NONE = "share_none" # Tasks are separate. SHARE_NONE = "share_none" # Tasks are separate.
SHARE_ALL = "share_all" # Tasks share same backbone. SHARE_ALL = "share_all" # Tasks share same backbone.
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions. SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):
@pydantic.root_validator() @pydantic.root_validator()
def _validate_mtl(cls, values): def _validate_mtl(cls, values):
"""Validate the multi task architecture."""
if values.get("multi_task_type", None) is None: if values.get("multi_task_type", None) is None:
return values return values
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]: elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:

View File

@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int): def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
""" "Builds a model for a single task""" """
Build a model for a single task based on the provided configuration.
Args:
task (model_config_mod.TaskModel): The task model configuration.
input_shape (int): The input shape for the model.
Returns:
torch.nn.Module: The constructed model for the single task.
Raises:
ValueError: If the task configuration is not recognized.
"""
if task.mlp_config: if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config) return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config: elif task.dcn_config:
@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)
class MultiTaskRankingModel(torch.nn.Module): class MultiTaskRankingModel(torch.nn.Module):
"""Multi-task ranking model.""" """
Multi-task ranking model that handles multiple ranking tasks simultaneously.
This model takes various input features and predicts rankings for multiple
tasks using shared or separate towers.
"""
def __init__( def __init__(
self, self,
@ -47,11 +64,17 @@ class MultiTaskRankingModel(torch.nn.Module):
data_config: RecapDataConfig, data_config: RecapDataConfig,
return_backbone: bool = False, return_backbone: bool = False,
): ):
"""Constructor for Multi task learning. """
Constructor for Multi-task ranking model.
Args:
input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
config (ModelConfig): The model configuration.
data_config (RecapDataConfig): The data configuration.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Assumptions made: Assumptions made:
1. Tasks specified in data config match model architecture. 1. Tasks specified in data config match model architecture.
These are all validated in config. These are all validated in config.
""" """
super().__init__() super().__init__()
@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
labels: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None,
weights: Optional[torch.Tensor] = None, weights: Optional[torch.Tensor] = None,
): ):
"""
Forward pass of the Multi-task ranking model.
Args:
continuous_features (torch.Tensor): Continuous input features.
binary_features (torch.Tensor): Binary input features.
discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
sparse_features ([type], optional): Sparse input features. Defaults to None.
user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
"""
concat_dense_features = [ concat_dense_features = [
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features) self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
] ]
@ -270,6 +310,20 @@ def create_ranking_model(
data_config=None, data_config=None,
return_backbone=False, return_backbone=False,
): ):
"""
Creates a ranking model based on the provided specifications and configuration.
Args:
data_spec: The input data specifications.
config (config_mod.RecapConfig): The model configuration.
device (torch.device): The device where the model should be placed.
loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
data_config: The data configuration. Defaults to None.
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
Returns:
torch.nn.Module: The created ranking model.
"""
if list(config.model.tasks.values())[0].dlrm_config: if list(config.model.tasks.values())[0].dlrm_config:
raise NotImplementedError() raise NotImplementedError()

View File

@ -11,21 +11,52 @@ import torch
def log_transform(x: torch.Tensor) -> torch.Tensor: def log_transform(x: torch.Tensor) -> torch.Tensor:
"""Safe log transform that works across both negative, zero, and positive floats.""" """
Safe log transform that works across both negative, zero, and positive floats.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log1p applied to absolute values.
"""
return torch.sign(x) * torch.log1p(torch.abs(x)) return torch.sign(x) * torch.log1p(torch.abs(x))
class BatchNorm(torch.nn.Module): class BatchNorm(torch.nn.Module):
def __init__(self, num_features: int, config: BatchNormConfig): def __init__(self, num_features: int, config: BatchNormConfig):
"""
Batch normalization layer.
Args:
num_features (int): Number of input features.
config (BatchNormConfig): Configuration for batch normalization.
"""
super().__init__() super().__init__()
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum) self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the batch normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after batch normalization.
"""
return self.layer(x) return self.layer(x)
class LayerNorm(torch.nn.Module): class LayerNorm(torch.nn.Module):
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig): def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
"""
Layer normalization layer.
Args:
normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
config (LayerNormConfig): Configuration for layer normalization.
"""
super().__init__() super().__init__()
if config.axis != -1: if config.axis != -1:
raise NotImplementedError raise NotImplementedError
@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the layer normalization layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after layer normalization.
"""
return self.layer(x) return self.layer(x)
@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
super().__init__() super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that applies a log transformation to the input tensor.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with log applied to absolute values.
"""
return log_transform(x) return log_transform(x)
class InputNonFinite(torch.nn.Module): class InputNonFinite(torch.nn.Module):
def __init__(self, fill_value: float = 0): def __init__(self, fill_value: float = 0):
"""
Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.
Args:
fill_value (float): The value to fill non-finite elements with. Default is 0.
"""
super().__init__() super().__init__()
self.register_buffer( self.register_buffer(
@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that replaces non-finite values in the input tensor with the specified fill value.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with non-finite values replaced.
"""
return torch.where(torch.isfinite(x), x, self.fill_value) return torch.where(torch.isfinite(x), x, self.fill_value)
class Clamp(torch.nn.Module): class Clamp(torch.nn.Module):
def __init__(self, min_value: float, max_value: float): def __init__(self, min_value: float, max_value: float):
"""
Applies element-wise clamping to a tensor, ensuring that values are within a specified range.
Args:
min_value (float): The minimum value to clamp elements to.
max_value (float): The maximum value to clamp elements to.
"""
super().__init__() super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time). # Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict. # Will also be part of state_dict.
@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass that clamps the input tensor element-wise within the specified range.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with elements clamped within the specified range.
"""
return torch.clamp(x, min=self.min_value, max=self.max_value) return torch.clamp(x, min=self.min_value, max=self.max_value)
class DoubleNormLog(torch.nn.Module): class DoubleNormLog(torch.nn.Module):
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.""" """
Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.
Args:
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.
Attributes:
_before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
batch normalization (optional), and clamping.
layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
"""
def __init__( def __init__(
self, self,
input_shapes: Mapping[str, Sequence[int]], input_shapes: Mapping[str, Sequence[int]],
@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
def forward( def forward(
self, continuous_features: torch.Tensor, binary_features: torch.Tensor self, continuous_features: torch.Tensor, binary_features: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
"""
Forward pass that processes continuous and binary features using batch normalization, log transformation,
optional batch normalization (if configured), clamping, and layer normalization (if configured).
Args:
continuous_features (torch.Tensor): Input tensor of continuous features.
binary_features (torch.Tensor): Input tensor of binary features.
Returns:
torch.Tensor: Transformed tensor containing both continuous and binary features.
"""
x = self._before_concat_layers(continuous_features) x = self._before_concat_layers(continuous_features)
x = torch.cat([x, binary_features], dim=1) x = torch.cat([x, binary_features], dim=1)
if self.layer_norm: if self.layer_norm:
@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
def build_features_preprocessor( def build_features_preprocessor(
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]] config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
): ):
"""Trivial right now, but we will change in the future.""" """
Build a feature preprocessor module based on the provided configuration.
Trivial right now, but we will change in the future.
Args:
config (FeaturizationConfig): Configuration for feature preprocessing.
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
Returns:
DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
"""
return DoubleNormLog(input_shapes, config.double_norm_log_config) return DoubleNormLog(input_shapes, config.double_norm_log_config)

View File

@ -6,15 +6,85 @@ import torch
def _init_weights(module): def _init_weights(module):
"""Initializes weights
Example
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear): if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight) torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0) torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module): class MaskBlock(torch.nn.Module):
"""
MaskBlock module in a mask-based neural network.
This module represents a MaskBlock, which applies a masking operation to the input data and then
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Example:
To create and use a MaskBlock within a MaskNet, follow these steps:
```python
# Define the configuration for the MaskBlock
mask_block_config = MaskBlockConfig(
input_layer_norm=True, # Apply input layer normalization
reduction_factor=0.5 # Reduce input dimensionality by 50%
)
# Create an instance of the MaskBlock
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
mask_input = torch.randn(batch_size, 32)
# Perform a forward pass through the MaskBlock
output = mask_block(input_data, mask_input)
```
Note:
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
operation that combines the input and mask input. Then, it passes the result through a hidden layer
with optional dimensionality reduction.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__( def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None: ) -> None:
"""
Initializes the MaskBlock module.
Args:
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
input_dim (int): Dimensionality of the input data.
mask_input_dim (int): Dimensionality of the mask input.
Returns:
None
"""
super(MaskBlock, self).__init__() super(MaskBlock, self).__init__()
self.mask_block_config = mask_block_config self.mask_block_config = mask_block_config
output_size = mask_block_config.output_size output_size = mask_block_config.output_size
@ -25,11 +95,13 @@ class MaskBlock(torch.nn.Module):
self._input_layer_norm = None self._input_layer_norm = None
if mask_block_config.reduction_factor: if mask_block_config.reduction_factor:
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor) aggregation_size = int(
mask_input_dim * mask_block_config.reduction_factor)
elif mask_block_config.aggregation_size is not None: elif mask_block_config.aggregation_size is not None:
aggregation_size = mask_block_config.aggregation_size aggregation_size = mask_block_config.aggregation_size
else: else:
raise ValueError("Need one of reduction factor or aggregation size.") raise ValueError(
"Need one of reduction factor or aggregation size.")
self._mask_layer = torch.nn.Sequential( self._mask_layer = torch.nn.Sequential(
torch.nn.Linear(mask_input_dim, aggregation_size), torch.nn.Linear(mask_input_dim, aggregation_size),
@ -42,14 +114,79 @@ class MaskBlock(torch.nn.Module):
self._layer_norm = torch.nn.LayerNorm(output_size) self._layer_norm = torch.nn.LayerNorm(output_size)
def forward(self, net: torch.Tensor, mask_input: torch.Tensor): def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
"""
Performs a forward pass through the MaskBlock.
Args:
net (torch.Tensor): Input data tensor.
mask_input (torch.Tensor): Mask input tensor.
Returns:
torch.Tensor: Output tensor of the MaskBlock.
"""
if self._input_layer_norm: if self._input_layer_norm:
net = self._input_layer_norm(net) net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input)) hidden_layer_output = self._hidden_layer(
net * self._mask_layer(mask_input))
return self._layer_norm(hidden_layer_output) return self._layer_norm(hidden_layer_output)
class MaskNet(torch.nn.Module): class MaskNet(torch.nn.Module):
"""
MaskNet module in a mask-based neural network.
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
create mask-based neural networks with parallel or stacked MaskBlocks.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Example:
To create and use a MaskNet, you can follow these steps:
```python
# Define the configuration for the MaskNet
mask_net_config = MaskNetConfig(
use_parallel=True, # Use parallel MaskBlocks
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
)
# Create an instance of the MaskNet
mask_net = MaskNet(mask_net_config, in_features=64)
# Generate input tensors
input_data = torch.randn(batch_size, 64)
# Perform a forward pass through the MaskNet
outputs = mask_net(input_data)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
"""
Initializes the MaskNet module.
Args:
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
in_features (int): Dimensionality of the input data.
Returns:
None
"""
super().__init__() super().__init__()
self.mask_net_config = mask_net_config self.mask_net_config = mask_net_config
mask_blocks = [] mask_blocks = []
@ -57,26 +194,38 @@ class MaskNet(torch.nn.Module):
if mask_net_config.use_parallel: if mask_net_config.use_parallel:
total_output_mask_blocks = 0 total_output_mask_blocks = 0
for mask_block_config in mask_net_config.mask_blocks: for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features)) mask_blocks.append(
MaskBlock(mask_block_config, in_features, in_features))
total_output_mask_blocks += mask_block_config.output_size total_output_mask_blocks += mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks) self._mask_blocks = torch.nn.ModuleList(mask_blocks)
else: else:
input_size = in_features input_size = in_features
for mask_block_config in mask_net_config.mask_blocks: for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features)) mask_blocks.append(
MaskBlock(mask_block_config, input_size, in_features))
input_size = mask_block_config.output_size input_size = mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks) self._mask_blocks = torch.nn.ModuleList(mask_blocks)
total_output_mask_blocks = mask_block_config.output_size total_output_mask_blocks = mask_block_config.output_size
if mask_net_config.mlp: if mask_net_config.mlp:
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp) self._dense_layers = mlp.Mlp(
total_output_mask_blocks, mask_net_config.mlp)
self.out_features = mask_net_config.mlp.layer_sizes[-1] self.out_features = mask_net_config.mlp.layer_sizes[-1]
else: else:
self.out_features = total_output_mask_blocks self.out_features = total_output_mask_blocks
self.shared_size = total_output_mask_blocks self.shared_size = total_output_mask_blocks
def forward(self, inputs: torch.Tensor): def forward(self, inputs: torch.Tensor):
"""
Performs a forward pass through the MaskNet.
Args:
inputs (torch.Tensor): Input data tensor.
Returns:
torch.Tensor: Output tensor of the MaskNet.
"""
if self.mask_net_config.use_parallel: if self.mask_net_config.use_parallel:
mask_outputs = [] mask_outputs = []
for mask_layer in self._mask_blocks: for mask_layer in self._mask_blocks:
@ -94,5 +243,6 @@ class MaskNet(torch.nn.Module):
for mask_layer in self._mask_blocks: for mask_layer in self._mask_blocks:
net = mask_layer(net=net, mask_input=inputs) net = mask_layer(net=net, mask_input=inputs)
# Share the output of the stacked MaskBlocks. # Share the output of the stacked MaskBlocks.
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"] output = net if self.mask_net_config.mlp is None else self._dense_layers[
net]["output"]
return {"output": output, "shared_layer": net} return {"output": output, "shared_layer": net}

View File

@ -7,13 +7,83 @@ from absl import logging
def _init_weights(module): def _init_weights(module):
"""Initializes weights
Example
-------
```python
import torch
import torch.nn as nn
# Define a simple linear layer
linear_layer = nn.Linear(64, 32)
# Initialize the weights and biases using _init_weights
_init_weights(linear_layer)
```
"""
if isinstance(module, torch.nn.Linear): if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight) torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0) torch.nn.init.constant_(module.bias, 0)
class Mlp(torch.nn.Module): class Mlp(torch.nn.Module):
"""
Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.
This module defines an MLP with customizable layers and activation functions. It is suitable for various
applications such as deep learning for tabular data, feature extraction, and more.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Example:
To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:
```python
# Define the configuration for the MLP
mlp_config = MlpConfig(
layer_sizes=[128, 64], # Specify the sizes of hidden layers
batch_norm=True, # Enable batch normalization
dropout=0.2, # Apply dropout with a rate of 0.2
final_layer_activation=True # Apply ReLU activation to the final layer
)
# Create an instance of the MLP module
mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)
# Generate an input tensor
input_tensor = torch.randn(batch_size, input_dim)
# Perform a forward pass through the MLP
outputs = mlp_model(input_tensor)
# Access the output and shared layer
output = outputs["output"]
shared_layer = outputs["shared_layer"]
```
Note:
The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
enabling batch normalization and dropout, and choosing the activation function for the final layer.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__(self, in_features: int, mlp_config: MlpConfig): def __init__(self, in_features: int, mlp_config: MlpConfig):
"""
Initializes the Mlp module.
Args:
in_features (int): The number of input features or input dimensions.
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
Returns:
None
"""
super().__init__() super().__init__()
self._mlp_config = mlp_config self._mlp_config = mlp_config
input_size = in_features input_size = in_features
@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
self.layers.apply(_init_weights) self.layers.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs a forward pass through the MLP.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: Output tensor of the MLP.
"""
net = x net = x
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
net = layer(net) net = layer(net)
@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):
@property @property
def shared_size(self): def shared_size(self):
"""
Returns the size of the shared layer in the MLP.
Returns:
int: Size of the shared layer.
"""
return self._mlp_config.layer_sizes[-1] return self._mlp_config.layer_sizes[-1]
@property @property
def out_features(self): def out_features(self):
"""
Returns the number of output features from the MLP.
Returns:
int: Number of output features.
"""
return self._mlp_config.layer_sizes[-1] return self._mlp_config.layer_sizes[-1]

View File

@ -5,6 +5,54 @@ from absl import logging
class ModelAndLoss(torch.nn.Module): class ModelAndLoss(torch.nn.Module):
"""
PyTorch module that combines a neural network model and loss function.
This module wraps a neural network model and facilitates the forward pass through the model
while also calculating the loss based on the model's predictions and provided labels.
Args:
model: The torch module to wrap.
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification. Each stratifier config includes the name and index of discrete features
to emit for stratification.
Example:
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
and loss function as arguments:
```python
# Create a neural network model
model = YourNeuralNetworkModel()
# Define a loss function
loss_fn = torch.nn.CrossEntropyLoss()
# Create an instance of ModelAndLoss
model_and_loss = ModelAndLoss(model, loss_fn)
# Generate a batch of training data (e.g., RecapBatch)
batch = generate_training_batch()
# Perform a forward pass through the model and calculate the loss
loss, outputs = model_and_loss(batch)
# You can now backpropagate and optimize using the computed loss
loss.backward()
optimizer.step()
```
Note:
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
calculating loss, making it easier to integrate the model into your training loop. Additionally,
it supports the addition of stratifiers for metrics stratification, if needed.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__( def __init__(
self, self,
model, model,
@ -12,10 +60,13 @@ class ModelAndLoss(torch.nn.Module):
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
) -> None: ) -> None:
""" """
Initializes the ModelAndLoss module.
Args: Args:
model: torch module to wrap. model: The torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels. loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification. stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
for metrics stratification.
""" """
super().__init__() super().__init__()
self.model = model self.model = model
@ -42,13 +93,16 @@ class ModelAndLoss(torch.nn.Module):
labels=batch.labels, labels=batch.labels,
weights=batch.weights, weights=batch.weights,
) )
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) losses = self.loss_fn(
outputs["logits"], batch.labels.float(), batch.weights.float())
if self.stratifiers: if self.stratifiers:
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}") logging.info(
f"***** Adding stratifiers *****\n {self.stratifiers}")
outputs["stratifiers"] = {} outputs["stratifiers"] = {}
for stratifier in self.stratifiers: for stratifier in self.stratifiers:
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index] outputs["stratifiers"][stratifier.name] = batch.discrete_features[:,
stratifier.index]
# In general, we can have a large number of losses returned by our loss function. # In general, we can have a large number of losses returned by our loss function.
if isinstance(losses, dict): if isinstance(losses, dict):

View File

@ -2,11 +2,58 @@ import torch
class NumericCalibration(torch.nn.Module): class NumericCalibration(torch.nn.Module):
"""
Numeric calibration module for adjusting probability scores.
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
for tasks such as binary classification.
Args:
pos_downsampling_rate (float): The downsampling rate for positive samples.
neg_downsampling_rate (float): The downsampling rate for negative samples.
Example:
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
scores like this:
```python
# Create a NumericCalibration instance with downsampling rates
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
# Generate probability scores (e.g., from a neural network)
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
# Apply numeric calibration to adjust the probabilities
calibrated_probs = calibration(raw_probs)
# The `calibrated_probs` now contains the adjusted probability scores
```
Note:
The `NumericCalibration` module is used to adjust probability scores to account for differences in
the number of positive and negative samples in a dataset. It can help improve the calibration of
probability estimates in imbalanced classification problems.
Warning:
This class is intended for internal use within neural network architectures and should not be
directly accessed or modified by external code.
"""
def __init__( def __init__(
self, self,
pos_downsampling_rate: float, pos_downsampling_rate: float,
neg_downsampling_rate: float, neg_downsampling_rate: float,
): ):
"""
Apply numeric calibration to probability scores.
Args:
probs (torch.Tensor): Probability scores to be calibrated.
Returns:
torch.Tensor: Calibrated probability scores.
"""
super().__init__() super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time). # Using buffer to make sure they are on correct device (and not moved every time).

View File

@ -9,12 +9,60 @@ import pydantic
class RecapAdamConfig(base_config.BaseConfig): class RecapAdamConfig(base_config.BaseConfig):
"""
Configuration settings for the Adam optimizer used in Recap.
Args:
beta_1 (float): Momentum term (default: 0.9).
beta_2 (float): Exponential weighted decay factor (default: 0.999).
epsilon (float): Numerical stability in the denominator (default: 1e-7).
Example:
To define an Adam optimizer configuration for Recap, use:
```python
adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
```
Note:
This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
beta_1: float = 0.9 # Momentum term. beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor. beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator. epsilon: float = 1e-7 # Numerical stability in denominator.
class MultiTaskLearningRates(base_config.BaseConfig): class MultiTaskLearningRates(base_config.BaseConfig):
"""
Configuration settings for multiple learning rates in Recap.
Args:
tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).
Example:
To define multiple learning rates for different towers in Recap, use:
```python
multi_task_lr = MultiTaskLearningRates(
tower_learning_rates={
'task1': learning_rate1,
'task2': learning_rate2,
},
backbone_learning_rate=backbone_lr,
)
```
Note:
This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.
Warning:
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
"""
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field( tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model." description="Learning rates for different towers of the model."
) )
@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):
class RecapOptimizerConfig(base_config.BaseConfig): class RecapOptimizerConfig(base_config.BaseConfig):
"""
Configuration settings for the Recap optimizer.
Args:
multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
adam (RecapAdamConfig): Configuration settings for the Adam optimizer.
Example:
To define an optimizer configuration for training with Recap, use:
```python
optimizer_config = RecapOptimizerConfig(
multi_task_learning_rates=multi_task_lr,
single_task_learning_rate=single_task_lr,
adam=adam_config,
)
```
Warning:
This class is intended for internal use to configure the optimizer settings within Recap and should not be
directly accessed by external code.
"""
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field( multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr" None, description="Multiple learning rates for different tasks.", one_of="lr"
) )

View File

@ -23,11 +23,29 @@ _DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler): class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""Shim to get learning rates into a LRScheduler. """
A shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
e.g. exponential decay can be used. learning rate schedules, such as exponential decay.
Args:
optimizer: The optimizer to which this scheduler is applied.
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
emb_learning_rate: The learning rate for embeddings (optional).
last_epoch (int): The index of the last epoch (default: -1).
verbose (bool): If True, print warnings for deprecated functions (default: False).
Example:
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
```python
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
```
Warning:
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
be directly accessed by external code.
""" """
def __init__( def __init__(
@ -80,14 +98,24 @@ def build_optimizer(
optimizer_config: config.OptimizerConfig, optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None, emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
): ):
"""Builds an optimizer and scheduler. """
Build an optimizer and scheduler for training.
Args: Args:
model: A torch model, probably with DDP/DMP. model: The torch model, possibly with DDP/DMP.
optimizer_config: An OptimizerConfig object that specifies learning rates per tower. optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
emb_optimizer_config: Configuration settings for embedding optimization (optional).
Returns: Returns:
A torch.optim instance, and a scheduler instance. torch.optim.Optimizer: The optimizer for training.
RecapLRShim: The learning rate scheduler for the optimizer.
Example:
To build an optimizer and scheduler for training, use:
```python
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
```
""" """
optimizer_fn = functools.partial( optimizer_fn = functools.partial(
torch.optim.Adam, torch.optim.Adam,

View File

@ -4,6 +4,17 @@ import pydantic
class TwhinDataConfig(base_config.BaseConfig): class TwhinDataConfig(base_config.BaseConfig):
"""
Configuration for Twhin model training data.
Args:
data_root (str): The root directory for the training data.
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
global_negatives (int): The number of global negatives.
in_batch_negatives (int): The number of in-batch negatives.
limit (pydantic.PositiveInt): The limit on the number of data points to use.
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
"""
data_root: str data_root: str
per_replica_batch_size: pydantic.PositiveInt per_replica_batch_size: pydantic.PositiveInt
global_negatives: int global_negatives: int

View File

@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig): def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
"""
Create a dataset for Twhin model training.
Args:
data_config (TwhinDataConfig): The data configuration for the dataset.
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
Returns:
EdgesDataset: The dataset for Twhin model training.
"""
tables = model_config.embeddings.tables tables = model_config.embeddings.tables
table_sizes = {table.name: table.num_embeddings for table in tables} table_sizes = {table.name: table.num_embeddings for table in tables}
relations = model_config.relations relations = model_config.relations

View File

@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@dataclass @dataclass
class EdgeBatch(DataclassBatch): class EdgeBatch(DataclassBatch):
"""
Batch data structure for edge-based models.
Args:
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
labels (torch.Tensor): Tensor containing labels.
rels (torch.Tensor): Tensor containing relation information.
weights (torch.Tensor): Tensor containing weights.
"""
nodes: KeyedJaggedTensor nodes: KeyedJaggedTensor
labels: torch.Tensor labels: torch.Tensor
rels: torch.Tensor rels: torch.Tensor
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
class EdgesDataset(Dataset): class EdgesDataset(Dataset):
"""
Dataset for edge-based models.
Args:
file_pattern (str): The file pattern for the dataset.
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
relations (List[Relation]): A list of relations between tables.
lhs_column_name (str): The name of the left-hand-side column.
rhs_column_name (str): The name of the right-hand-side column.
rel_column_name (str): The name of the relation column.
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
"""
rng = np.random.default_rng() rng = np.random.default_rng()
def __init__( def __init__(
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
super().__init__(file_pattern=file_pattern, **dataset_kwargs) super().__init__(file_pattern=file_pattern, **dataset_kwargs)
def pa_to_batch(self, batch: pa.RecordBatch): def pa_to_batch(self, batch: pa.RecordBatch):
"""
Converts a pyarrow RecordBatch to an EdgeBatch.
Args:
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
Returns:
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
"""
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy()) lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy()) rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy()) rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
"""Process edges that contain lhs index, rhs index, relation index. """Process edges that contain lhs index, rhs index, relation index.
Args:
lhs (torch.Tensor): Tensor containing left-hand-side indices.
rhs (torch.Tensor): Tensor containing right-hand-side indices.
rel (torch.Tensor): Tensor containing relation indices.
Returns:
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
Example: Example:
``` ```
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths) return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
def to_batches(self): def to_batches(self):
"""
Converts data to batches.
Yields:
pa.RecordBatch: A pyarrow RecordBatch containing data.
"""
ds = super().to_batches() ds = super().to_batches()
batch_size = self._dataset_kwargs["batch_size"] batch_size = self._dataset_kwargs["batch_size"]

View File

@ -23,6 +23,25 @@ RELATIONS = [
def test_gen(): def test_gen():
"""Test function for generating edge-based datasets and dataloaders.
This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
instance and a dataloader for it.
The test includes the following steps:
1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
2. Write the synthetic data to a Parquet file.
3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
4. Initialize the local file system for the dataset.
5. Create a dataloader for the dataset and retrieve the first batch.
6. Assert that the labels in the batch are positive.
7. Verify that the positive examples in the batch match the expected values.
This function serves as a test case for the data generation and dataset creation process.
Raises:
AssertionError: If any of the test assertions fail.
"""
import os import os
import tempfile import tempfile

View File

@ -10,8 +10,29 @@ from pydantic import validator
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
"""
Configuration class for Twhin model embeddings.
This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
for all tables in the Twhin model embeddings configuration match.
Attributes:
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
"""
@validator("tables") @validator("tables")
def embedding_dims_match(cls, tables): def embedding_dims_match(cls, tables):
"""
Validate that embedding dimensions and data types match for all tables.
Args:
tables (List[TableConfig]): List of table configurations.
Returns:
List[TableConfig]: The list of validated table configurations.
Raises:
AssertionError: If embedding dimensions or data types do not match.
"""
embedding_dim = tables[0].embedding_dim embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type data_type = tables[0].data_type
for table in tables: for table in tables:
@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
class Operator(str, enum.Enum): class Operator(str, enum.Enum):
"""
Enumeration of operator types.
This enumeration defines different types of operators that can be applied to Twhin model relations.
"""
TRANSLATION = "translation" TRANSLATION = "translation"
class Relation(pydantic.BaseModel): class Relation(pydantic.BaseModel):
"""graph relationship properties and operator""" """
Configuration class for graph relationships in the Twhin model.
This class defines properties and operators for graph relationships in the Twhin model.
Attributes:
name (str): The name of the relationship.
lhs (str): The name of the entity on the left-hand side of the relation.
rhs (str): The name of the entity on the right-hand side of the relation.
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
"""
name: str = pydantic.Field(..., description="Relationship name.") name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field( lhs: str = pydantic.Field(
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):
class TwhinModelConfig(base_config.BaseConfig): class TwhinModelConfig(base_config.BaseConfig):
"""
Configuration class for the Twhin model.
This class defines configuration options specific to the Twhin model.
Attributes:
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
relations (List[Relation]): List of graph relationship configurations.
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
"""
embeddings: TwhinEmbeddingsConfig embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation] relations: typing.List[Relation]
translation_optimizer: OptimizerConfig translation_optimizer: OptimizerConfig
@validator("relations", each_item=True) @validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs): def valid_node_types(cls, relation, values, **kwargs):
"""
Validate that the specified node types in relations are valid table names in embeddings.
Args:
relation (Relation): A single relation configuration.
values (dict): The values dictionary containing the "embeddings" configuration.
Returns:
Relation: The validated relation configuration.
Raises:
AssertionError: If the specified node types are not valid table names in embeddings.
"""
table_names = [table.name for table in values["embeddings"].tables] table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}" assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}" assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"

View File

@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa
class TwhinModel(nn.Module): class TwhinModel(nn.Module):
"""
Twhin model for graph-based entity embeddings and translation.
This class defines the Twhin model, which is used for learning embeddings of entities in a graph
and applying translations to these embeddings based on graph relationships.
Args:
model_config (TwhinModelConfig): Configuration for the Twhin model.
data_config (TwhinDataConfig): Configuration for the data used by the model.
Attributes:
batch_size (int): The batch size used for training.
table_names (List[str]): Names of tables in the model's embeddings.
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
embedding_dim (int): Dimensionality of entity embeddings.
num_tables (int): Number of tables in the model's embeddings.
in_batch_negatives (int): Number of in-batch negative samples to use during training.
global_negatives (int): Number of global negative samples to use during training.
num_relations (int): Number of graph relationships in the model.
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
"""
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig): def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__() super().__init__()
self.batch_size = data_config.per_replica_batch_size self.batch_size = data_config.per_replica_batch_size
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
) )
def forward(self, batch: EdgeBatch): def forward(self, batch: EdgeBatch):
"""
Forward pass of the Twhin model.
Args:
batch (EdgeBatch): Input batch containing graph edge information.
Returns:
dict: A dictionary containing model output with "logits" and "probabilities".
- "logits" (torch.Tensor): Logit scores.
- "probabilities" (torch.Tensor): Sigmoid probabilities.
"""
# B x D # B x D
trans_embs = self.all_trans_embs.data[batch.rels] trans_embs = self.all_trans_embs.data[batch.rels]
@ -98,6 +130,18 @@ class TwhinModel(nn.Module):
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
"""
Apply optimizers to the Twhin model's embeddings.
This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.
Args:
model (TwhinModel): The Twhin model to apply optimizers to.
model_config (TwhinModelConfig): Configuration for the Twhin model.
Returns:
TwhinModel: The Twhin model with optimizers applied to its embeddings.
"""
for table in model_config.embeddings.tables: for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer) optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict() optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
@ -124,9 +168,13 @@ class TwhinModelAndLoss(torch.nn.Module):
device: torch.device, device: torch.device,
) -> None: ) -> None:
""" """
Initialize a TwhinModelAndLoss module.
Args: Args:
model: torch module to wrap. model: The torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels. loss_fn: A function for calculating loss, should accept logits and labels.
data_config: Configuration for Twhin data.
device: The torch device to use for calculations.
""" """
super().__init__() super().__init__()
self.model = model self.model = model
@ -136,13 +184,20 @@ class TwhinModelAndLoss(torch.nn.Module):
self.device = device self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn. """
Run the model forward and calculate the loss according to the given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types. generic python types. see https://pytorch.org/docs/stable/onnx.html#types
Args:
batch ("RecapBatch"): The input batch for model inference.
Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
additional outputs including logits, labels, and weights.
""" """
outputs = self.model(batch) outputs = self.model(batch)
logits = outputs["logits"] logits = outputs["logits"]

View File

@ -18,6 +18,12 @@ EMB_DIM = 128
def twhin_model_config() -> TwhinModelConfig: def twhin_model_config() -> TwhinModelConfig:
"""
Create a configuration for the Twhin model.
Returns:
TwhinModelConfig: The Twhin model configuration.
"""
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01)) sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
def twhin_data_config() -> TwhinDataConfig: def twhin_data_config() -> TwhinDataConfig:
"""
Create a configuration for the Twhin data.
Returns:
TwhinDataConfig: The Twhin data configuration.
"""
data_config = TwhinDataConfig( data_config = TwhinDataConfig(
data_root="/", data_root="/",
per_replica_batch_size=10, per_replica_batch_size=10,
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
def test_twhin_model(): def test_twhin_model():
"""
Test the Twhin model creation and optimization.
This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
the device placement of model parameters.
Returns:
None
"""
model_config = twhin_model_config() model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits loss_fn = F.binary_cross_entropy_with_logits
@ -84,6 +105,23 @@ def test_twhin_model():
def test_unequal_dims(): def test_unequal_dims():
"""
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.
This function tests whether the validation logic correctly raises a `ValidationError` when
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.
The test includes the following steps:
1. Create two embedding configurations with different embedding dimensions.
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.
This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
in the `TwhinEmbeddingsConfig` for all tables.
Raises:
AssertionError: If the expected `ValidationError` is not raised.
"""
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05)) sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig( table0 = EmbeddingBagConfig(

View File

@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
def _lr_from_config(optimizer_config): def _lr_from_config(optimizer_config):
"""Get the learning rate from an optimizer configuration.
Args:
optimizer_config: Optimizer configuration.
Returns:
Learning rate from the optimizer configuration.
"""
if optimizer_config.learning_rate is not None: if optimizer_config.learning_rate is not None:
return optimizer_config.learning_rate return optimizer_config.learning_rate
else: else:
@ -28,7 +36,7 @@ def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
Args: Args:
model: TwhinModel to build optimizer for. model: TwhinModel to build optimizer for.
config: TwhinConfig for model. config: TwhinModelConfig for model.
Returns: Returns:
Optimizer for model. Optimizer for model.

View File

@ -37,6 +37,12 @@ def run(
all_config: TwhinConfig, all_config: TwhinConfig,
save_dir: Optional[str] = None, save_dir: Optional[str] = None,
): ):
"""Run the training process for TwhinModel.
Args:
all_config (TwhinConfig): The configuration for the entire Twhin model.
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
"""
train_dataset = create_dataset(all_config.train_data, all_config.model) train_dataset = create_dataset(all_config.train_data, all_config.model)
if env.is_reader(): if env.is_reader():
@ -80,6 +86,11 @@ def run(
def main(argv): def main(argv):
"""Main entry point for the Twhin training script.
Args:
argv: Command-line arguments.
"""
logging.info("Starting") logging.info("Starting")
logging.info(f"parsing config from {FLAGS.config_yaml_path}...") logging.info(f"parsing config from {FLAGS.config_yaml_path}...")

View File

@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging
class _Reader(pa.flight.FlightServerBase): class _Reader(pa.flight.FlightServerBase):
"""Distributed reader flight server wrapping a dataset.""" """
Distributed reader flight server wrapping a dataset.
This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
for efficient data access.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
Attributes:
_location (str): The location of the Flight server.
_ds (Dataset): The dataset wrapped by the Flight server.
Methods:
do_get(_, __): Handles Flight requests for data retrieval.
Note:
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
This class allows clients to retrieve data from the dataset using Flight.
"""
def __init__(self, location: str, ds: "Dataset"): def __init__(self, location: str, ds: "Dataset"):
"""
Initialize a new _Reader instance.
Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
"""
super().__init__(location=location) super().__init__(location=location)
self._location = location self._location = location
self._ds = ds self._ds = ds
def do_get(self, _, __): def do_get(self, _, __):
"""
Handle Flight requests for data retrieval.
This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.
Args:
_: Unused argument.
__: Unused argument.
Returns:
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.
Note:
An updated schema (to account for column selection) must be given to the stream.
"""
# NB: An updated schema (to account for column selection) has to be given the stream. # NB: An updated schema (to account for column selection) has to be given the stream.
schema = next(iter(self._ds.to_batches())).schema schema = next(iter(self._ds.to_batches())).schema
batches = self._ds.to_batches() batches = self._ds.to_batches()
@ -46,12 +90,48 @@ class _Reader(pa.flight.FlightServerBase):
class Dataset(torch.utils.data.IterableDataset): class Dataset(torch.utils.data.IterableDataset):
"""
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.
This class enables efficient loading of data from Parquet files using PyArrow.
It is designed to be used as an IterableDataset in PyTorch for training and inference.
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
for more details.
Attributes:
LOCATION (str): The default location for the Flight server used for data distribution.
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
_fs: The filesystem object used for file operations.
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
_files (list): A list of file paths matching the glob pattern.
_schema (pa.Schema): The schema of the Parquet dataset.
Methods:
serve(): Start serving the dataset using a Flight server.
to_batches(): Generate batches of data from the Parquet dataset.
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.
Note:
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
to create DataLoader instances for training or inference.
"""
LOCATION = "grpc://0.0.0.0:2222" LOCATION = "grpc://0.0.0.0:2222"
def __init__(self, file_pattern: str, **dataset_kwargs) -> None: def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
"""Specify batch size and column to select for. """
Initialize a new Dataset instance. Specify batch size and column to select for.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset. Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
""" """
self._file_pattern = file_pattern self._file_pattern = file_pattern
self._fs = infer_fs(self._file_pattern) self._fs = infer_fs(self._file_pattern)
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
self._validate_columns() self._validate_columns()
def _validate_columns(self): def _validate_columns(self):
"""
Validate the specified columns against the dataset schema.
Raises:
Exception: If any specified columns are not found in the dataset schema.
"""
columns = set(self._dataset_kwargs.get("columns", [])) columns = set(self._dataset_kwargs.get("columns", []))
wrong_columns = set(columns) - set(self._schema.names) wrong_columns = set(columns) - set(self._schema.names)
if wrong_columns: if wrong_columns:
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.") raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")
def serve(self): def serve(self):
"""Start serving the dataset using a Flight server."""
self.reader = _Reader(location=self.LOCATION, ds=self) self.reader = _Reader(location=self.LOCATION, ds=self)
self.reader.serve() self.reader.serve()
def _create_dataset(self): def _create_dataset(self):
"""Create a PyArrow dataset for data retrieval."""
return pads.dataset( return pads.dataset(
source=random.sample(self._files, len(self._files))[0], source=random.sample(self._files, len(self._files))[0],
format="parquet", format="parquet",
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
@abc.abstractmethod @abc.abstractmethod
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
"""
Convert a Parquet RecordBatch to a custom data batch.
Args:
batch (pa.RecordBatch): A batch of data from the Parquet dataset.
Returns:
DataclassBatch: A custom data batch used in PyTorch training.
Raises:
NotImplementedError: This method must be implemented in derived classes.
"""
raise NotImplementedError raise NotImplementedError
def dataloader(self, remote: bool = False): def dataloader(self, remote: bool = False):
"""
Create a PyTorch DataLoader for iterating through the dataset.
Args:
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.
Returns:
DataLoader: A PyTorch DataLoader for iterating through the dataset.
Note:
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
"""
if not remote: if not remote:
return map(self.pa_to_batch, self.to_batches()) return map(self.pa_to_batch, self.to_batches())
readers = get_readers(2) readers = get_readers(2)
@ -117,6 +230,25 @@ GRPC_OPTIONS = [
def get_readers(num_readers_per_worker: int): def get_readers(num_readers_per_worker: int):
"""
Get Flight readers for distributed data loading.
This function retrieves Flight readers for distributed data loading in a PyTorch environment.
Args:
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.
Returns:
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.
Note:
Flight readers are used to fetch data in a distributed manner for efficient data loading.
Example:
To obtain Flight readers, use the following code:
>>> readers = get_readers(num_readers_per_worker=2)
"""
addresses = env.get_flight_server_addresses() addresses = env.get_flight_server_addresses()
readers = [] readers = []

View File

@ -21,6 +21,16 @@ import torch.distributed as dist
def maybe_start_dataset_service(): def maybe_start_dataset_service():
"""
Start the dataset service if readers are available and required dependencies are met.
This function checks if readers are available and if the required TensorFlow version is >= 2.5.
If both conditions are met and the current environment is the dispatcher or reader, it starts
the TensorFlow dataset service.
Raises:
Exception: If the required TensorFlow version is not met (>= 2.5).
"""
if not env.has_readers(): if not env.has_readers():
return return
@ -59,6 +69,24 @@ def maybe_start_dataset_service():
def register_dataset( def register_dataset(
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
): ):
"""
Register a dataset with the distributed dataset service.
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
and job name to all processes in the distributed environment.
Args:
dataset (tf.data.Dataset): The dataset to be registered.
dataset_service (str): The name of the dataset service.
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
Returns:
Tuple[int, str]: A tuple containing the dataset ID and job name.
Note:
This function should be called on the rank 0 process.
"""
if dist.get_rank() == 0: if dist.get_rank() == 0:
dataset_id = _register_dataset( dataset_id = _register_dataset(
service=dataset_service, service=dataset_service,
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
compression: Optional[str] = "AUTO", compression: Optional[str] = "AUTO",
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset: ) -> tf.data.Dataset:
"""
Distribute a dataset from a registered dataset ID.
This function consumes a dataset from the distributed dataset service using the provided dataset ID
and job name. It also supports prefetching for improved performance.
Args:
dataset_service (str): The name of the dataset service.
dataset_id (int): The ID of the dataset to be consumed.
job_name (Optional[str]): The name of the job associated with the dataset (optional).
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).
Returns:
tf.data.Dataset: The distributed dataset.
"""
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
dataset = _from_dataset_id( dataset = _from_dataset_id(
processing_mode="parallel_epochs", processing_mode="parallel_epochs",
@ -97,13 +142,26 @@ def distribute_from_dataset_id(
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
"""Torch-compatible and distributed-training-aware dataset service distributor. """
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.
- rank 0 process will register the given dataset. This function is used to distribute a dataset in a distributed training environment. It performs the
- rank 0 process will broadcast job name and dataset id. following steps:
- all rank processes will consume from the same job/dataset. - On the rank 0 process, it registers the given dataset with the distributed dataset service.
- It broadcasts the job name and dataset ID to all rank processes.
- All rank processes then consume the same dataset from the distributed dataset service.
Without this, dataset workers will try to serve 1 job per rank process and OOM. Args:
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
Returns:
tf.data.Dataset: The distributed TensorFlow dataset.
Note:
- If there are no reader processes in the distributed environment, the original dataset is returned
without any distribution.
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
issues caused by each rank process trying to serve one job.
""" """
if not env.has_readers(): if not env.has_readers():

View File

@ -12,6 +12,17 @@ import torch
def create_dataset(tmpdir): def create_dataset(tmpdir):
"""
Create a mock dataset for testing.
This function creates a mock dataset using PyArrow and Parquet for testing purposes.
Args:
tmpdir: A temporary directory where the dataset will be created.
Returns:
MockDataset: A mock dataset for testing.
"""
table = pa.table( table = pa.table(
{ {
@ -34,6 +45,14 @@ def create_dataset(tmpdir):
def test_dataset(tmpdir): def test_dataset(tmpdir):
"""
Test the created dataset.
This function tests the created mock dataset and checks if it behaves as expected.
Args:
tmpdir: A temporary directory used for testing.
"""
ds = create_dataset(tmpdir) ds = create_dataset(tmpdir)
batch = next(iter(ds.dataloader(remote=False))) batch = next(iter(ds.dataloader(remote=False)))
assert batch.batch_size == 2 assert batch.batch_size == 2
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
reason="Multiprocessing doesn't work on github yet.", reason="Multiprocessing doesn't work on github yet.",
) )
def test_distributed_dataset(tmpdir): def test_distributed_dataset(tmpdir):
"""
Test the distributed dataset.
This function tests the distributed version of the mock dataset using multiprocessing.
Args:
tmpdir: A temporary directory used for testing.
"""
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"} MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
def _client(): def _client():

View File

@ -11,10 +11,54 @@ import torch
def roundrobin(*iterables): def roundrobin(*iterables):
"""Round robin through provided iterables, useful for simple load balancing. """
Iterate through provided iterables in a round-robin fashion.
This function takes multiple iterables and returns an iterator that yields elements from
each iterable in a round-robin manner. It continues cycling through the iterables until
all of them are exhausted.
Adapted from https://docs.python.org/3/library/itertools.html. Adapted from https://docs.python.org/3/library/itertools.html.
Args:
*iterables: One or more iterable objects to iterate through.
Yields:
Elements from the provided iterables in a round-robin fashion.
Raises:
StopIteration: If all provided iterables are exhausted.
Example:
```python
iterable1 = [1, 2, 3]
iterable2 = ['a', 'b', 'c']
iterable3 = [0.1, 0.2, 0.3]
for item in roundrobin(iterable1, iterable2, iterable3):
print(item)
# Output:
# 1
# 'a'
# 0.1
# 2
# 'b'
# 0.2
# 3
# 'c'
# 0.3
```
Note:
- If one of the provided iterables is shorter than the others, the function will
continue iterating through the remaining iterables until all are exhausted.
- If an iterable raises an exception during iteration, a warning message is logged,
and the function continues with the next iterable.
See Also:
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
""" """
num_active = len(iterables) num_active = len(iterables)
nexts = itertools.cycle(iter(it).__next__ for it in iterables) nexts = itertools.cycle(iter(it).__next__ for it in iterables)
@ -35,6 +79,48 @@ def roundrobin(*iterables):
def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]): def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
"""
Monitor the speed and progress of data loading using a data loader.
This function iterates through a data loader for a specified number of steps or until
the end of the data loader is reached, periodically logging progress information.
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Example:
```python
import torch
from torch.utils.data import DataLoader
# Create a data loader (replace with your own DataLoader configuration)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Monitor data loading speed and progress
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
```
Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.
Note:
- The function logs information about elapsed time, the number of examples processed,
and the processing speed in examples per second.
- If `peek` is provided, batch contents will be logged for inspection at the specified
frequency.
See Also:
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
iterating through datasets.
"""
num_examples = 0 num_examples = 0
prev = time.perf_counter() prev = time.perf_counter()
for idx, batch in enumerate(data_loader): for idx, batch in enumerate(data_loader):
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]
def pa_to_torch(array: pa.array) -> torch.Tensor: def pa_to_torch(array: pa.array) -> torch.Tensor:
"""
Convert a PyArrow Array to a PyTorch Tensor.
Args:
array (pa.array): The PyArrow Array to convert.
Returns:
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.
Example:
```python
import pyarrow as pa
import torch
# Create a PyArrow Array
arrow_array = pa.array([1, 2, 3])
# Convert it to a PyTorch Tensor
torch_tensor = pa_to_torch(arrow_array)
```
"""
return torch.from_numpy(array.to_numpy()) return torch.from_numpy(array.to_numpy())
def create_default_pa_to_batch(schema) -> DataclassBatch: def create_default_pa_to_batch(schema) -> DataclassBatch:
""" """ """
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.
Args:
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.
Returns:
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.
Example:
```python
import pyarrow as pa
from dataclass_batch import DataclassBatch
# Define a PyArrow schema
schema = pa.schema([
("feature1", pa.float64()),
("feature2", pa.int64()),
("label", pa.int64()),
])
# Create the conversion function
pa_to_batch = create_default_pa_to_batch(schema)
# Create a PyArrow RecordBatch
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
"feature1": [1.0, 2.0, None],
"feature2": [10, 20, 30],
"label": [0, 1, None],
}))
# Convert the RecordBatch to a custom DataclassBatch
custom_batch = pa_to_batch(record_batch)
```
"""
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema) _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)
def get_imputation_value(pa_type): def get_imputation_value(pa_type):

View File

@ -38,6 +38,15 @@ import pyarrow.parquet as pq
def _create_dataset(path: str): def _create_dataset(path: str):
"""
Create a PyArrow dataset from Parquet files located at the specified path.
Args:
path (str): The path to the Parquet files.
Returns:
pyarrow.dataset.Dataset: The PyArrow dataset.
"""
fs = infer_fs(path) fs = infer_fs(path)
files = fs.glob(path) files = fs.glob(path)
return pads.dataset(files, format="parquet", filesystem=fs) return pads.dataset(files, format="parquet", filesystem=fs)
@ -47,12 +56,27 @@ class PqReader:
def __init__( def __init__(
self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None
): ):
"""
Initialize a Parquet Reader.
Args:
path (str): The path to the Parquet files.
num (int): The maximum number of rows to read.
batch_size (int): The batch size for reading data.
columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns).
"""
self._ds = _create_dataset(path) self._ds = _create_dataset(path)
self._batch_size = batch_size self._batch_size = batch_size
self._num = num self._num = num
self._columns = columns self._columns = columns
def __iter__(self): def __iter__(self):
"""
Iterate through the Parquet data and yield batches of rows.
Yields:
pyarrow.RecordBatch: A batch of rows.
"""
batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns) batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns)
rows_seen = 0 rows_seen = 0
for count, record in enumerate(batches): for count, record in enumerate(batches):
@ -62,6 +86,12 @@ class PqReader:
rows_seen += record.data.num_rows rows_seen += record.data.num_rows
def _head(self): def _head(self):
"""
Get the first `num` rows of the Parquet data.
Returns:
pyarrow.RecordBatch: A batch of rows.
"""
total_read = self._num * self.bytes_per_row total_read = self._num * self.bytes_per_row
if total_read >= int(500e6): if total_read >= int(500e6):
raise Exception( raise Exception(
@ -71,6 +101,12 @@ class PqReader:
@property @property
def bytes_per_row(self) -> int: def bytes_per_row(self) -> int:
"""
Calculate the estimated bytes per row in the dataset.
Returns:
int: The estimated bytes per row.
"""
nbits = 0 nbits = 0
for t in self._ds.schema.types: for t in self._ds.schema.types:
try: try:
@ -81,17 +117,22 @@ class PqReader:
return nbits // 8 return nbits // 8
def schema(self): def schema(self):
"""
Display the schema of the Parquet dataset.
"""
print(f"\n# Schema\n{self._ds.schema}") print(f"\n# Schema\n{self._ds.schema}")
def head(self): def head(self):
"""Displays first --num rows.""" """
Display the first `num` rows of the Parquet data as a pandas DataFrame.
"""
print(self._head().to_pandas()) print(self._head().to_pandas())
def distinct(self): def distinct(self):
"""Displays unique values seen in specified columns in the first `--num` rows. """
Display unique values seen in specified columns in the first `num` rows.
Useful for getting an approximate vocabulary for certain columns. Useful for getting an approximate vocabulary for certain columns.
""" """
for col_name, column in zip(self._head().column_names, self._head().columns): for col_name, column in zip(self._head().column_names, self._head().columns):
print(col_name) print(col_name)