mirror of https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 13:19:23 +01:00

Merge 0813989fd9 into b85210863f
This commit is contained in: commit 63573c6f22

121  common/batch.py
@@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
  """
  A base class for batches used in pipelines.

  Attributes:
    None
  """

  @abc.abstractmethod
  def as_dict(self) -> Dict:
    """
    Convert the batch into a dictionary representation.

    Returns:
      Dict: A dictionary representation of the batch.

    Raises:
      NotImplementedError: If the method is not implemented in a subclass.
    """
    raise NotImplementedError

  def to(self, device: torch.device, non_blocking: bool = False):
    """
    Move the batch to the specified device.

    Args:
      device (torch.device): The target device.
      non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.

    Returns:
      BatchBase: A new batch on the target device.
    """
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
    return self.__class__(**args)

  def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
    """
    Record a CUDA stream for all tensors in the batch.

    Args:
      stream (torch.cuda.streams.Stream): The CUDA stream to record.

    Returns:
      None
    """
    for feature_value in self.as_dict().values():
      feature_value.record_stream(stream)

  def pin_memory(self):
    """
    Pin memory for all tensors in the batch.

    Returns:
      BatchBase: A new batch with pinned memory.
    """
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.pin_memory()
    return self.__class__(**args)

  def __repr__(self) -> str:
    """
    Generate a string representation of the batch.

    Returns:
      str: A string representation of the batch.
    """
    def obj2str(v):
      return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

@@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):

  @property
  def batch_size(self) -> int:
    """
    Get the batch size from the tensors in the batch.

    Returns:
      int: The batch size.

    Raises:
      Exception: If the batch size cannot be determined from the tensors.
    """
    for tensor in self.as_dict().values():
      if tensor is None:
        continue

@@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):

@dataclass
class DataclassBatch(BatchBase):
  """
  A batch class that uses dataclasses to define its fields.

  Attributes:
    None
  """

  @classmethod
  def feature_names(cls):
    """
    Get the feature names of the dataclass.

    Returns:
      List[str]: A list of feature names.
    """
    return list(cls.__dataclass_fields__.keys())

  def as_dict(self):
    """
    Convert the dataclass batch into a dictionary representation.

    Returns:
      Dict: A dictionary representation of the batch.
    """
    return {
      feature_name: getattr(self, feature_name)
      for feature_name in self.feature_names()

@@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):

  @staticmethod
  def from_schema(name: str, schema):
    """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
    """
    Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.

    Args:
      name (str): The name of the custom batch class.
      schema: The schema or structure of the batch.

    Returns:
      Type[DataclassBatch]: A custom batch class.
    """
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],

@@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):

  @staticmethod
  def from_fields(name: str, fields: dict):
    """
    Create a custom batch subclass from a set of fields.

    Args:
      name (str): The name of the custom batch class.
      fields (dict): A dictionary specifying the fields and their types.

    Returns:
      Type[DataclassBatch]: A custom batch class.
    """
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],

@@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):


class DictionaryBatch(BatchBase, dict):
  """
  A batch class that represents data as a dictionary.

  Attributes:
    None
  """

  def as_dict(self) -> Dict:
    """
    Convert the dictionary batch into a dictionary representation.

    Returns:
      Dict: A dictionary representation of the batch.
    """
    return self
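For reference, a minimal sketch of how the batch helpers above could be used; the subclass name, field names, and shapes are invented, and the import path is assumed from this repo's layout:

```python
import torch
from tml.common.batch import DataclassBatch  # import path assumed

# Build a throwaway batch subclass with two tensor fields (hypothetical names).
ExampleBatch = DataclassBatch.from_fields(
  "ExampleBatch", {"continuous": torch.Tensor, "labels": torch.Tensor}
)

batch = ExampleBatch(continuous=torch.randn(8, 4), labels=torch.zeros(8, 1))
print(batch.batch_size)        # 8, inferred from the first non-None tensor
pinned = batch.pin_memory()    # new ExampleBatch backed by pinned host memory
if torch.cuda.is_available():
  on_gpu = batch.to(torch.device("cuda"), non_blocking=True)
```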
@@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"


class Snapshot:
  """Checkpoints using torchsnapshot.

  Also saves step to be updated by the training loop.

  """
  """
  Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
  """

  def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
    """
    Initializes a Snapshot object.

    Args:
      save_dir (str): Directory where checkpoints will be saved.
      state (Dict[str, Any]): State dictionary containing checkpoint information.
    """
    self.save_dir = save_dir
    self.state = state
    self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

  @property
  def step(self):
    """Get the current training step."""
    return self.state["extra_state"]["step"]

  @step.setter
  def step(self, step: int) -> None:
    """Set the current training step."""
    self.state["extra_state"]["step"] = step

  @property

@@ -41,7 +48,15 @@ class Snapshot:
    self.state["extra_state"]["walltime"] = walltime

  def save(self, global_step: int) -> "PendingSnapshot":
    """Saves checkpoint with given global_step."""
    """
    Saves a checkpoint with a given global step.

    Args:
      global_step (int): The global step to associate with the checkpoint.

    Returns:
      PendingSnapshot: A pending snapshot object.
    """
    path = os.path.join(self.save_dir, str(global_step))
    logging.info(f"Saving snapshot global_step {global_step} to {path}.")
    start_time = time.time()

@@ -58,7 +73,12 @@ class Snapshot:
    return snapshot

  def restore(self, checkpoint: str) -> None:
    """Restores a given checkpoint."""
    """
    Restores a given checkpoint.

    Args:
      checkpoint (str): Path to the checkpoint to restore.
    """
    snapshot = torchsnapshot.Snapshot(path=checkpoint)
    logging.info(f"Restoring snapshot from {snapshot.path}.")
    start_time = time.time()

@@ -83,12 +103,17 @@ class Snapshot:
    global_step: Optional[int] = None,
    missing_ok: bool = False,
  ) -> torchsnapshot.Snapshot:
    """Get torch stateless snapshot, without actually loading it.
    Args:
      snapshot_path: path to the model snapshot
      global_step: restores from this checkpoint if specified.
      missing_ok: if True and checkpoints do not exist, returns without restoration.
    """
    """
    Get a torch stateless snapshot, without actually loading it.

    Args:
      snapshot_path (str): Path to the model snapshot.
      global_step (int, optional): Restores from this checkpoint if specified.
      missing_ok (bool): If True and checkpoints do not exist, returns without restoration.

    Returns:
      torchsnapshot.Snapshot: A torch snapshot object.
    """
    path = get_checkpoint(snapshot_path, global_step, missing_ok)
    logging.info(f"Loading snapshot from {path}.")
    return torchsnapshot.Snapshot(path=path)

@@ -100,13 +125,14 @@ class Snapshot:
    snapshot_emb_name: str,
    weight_tensor,
  ) -> None:
    """Loads pretrained embedding from the snapshot to the model.
    Utilise partial lodaing meachanism from torchsnapshot.
    Args:
      embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
      snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
      weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
    """
    """
    Loads pretrained embedding from the snapshot to the model.

    Args:
      embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
      snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
      weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
    """
    start_time = time.time()
    manifest = embedding_snapshot.get_manifest()
    for path in manifest.keys():

@@ -209,7 +235,22 @@ def get_checkpoint(


def get_checkpoints(save_dir: str) -> List[str]:
  """Gets all checkpoints that have been fully written."""
  """
  Get a list of fully written checkpoints in the specified directory.

  This function retrieves a list of fully written checkpoints in the given directory.
  Checkpoints that are considered fully written include those that have a
  corresponding snapshot metadata file.

  Args:
    save_dir (str): The directory where checkpoints are stored.

  Returns:
    List[str]: A list of fully written checkpoint paths.

  Note:
    Checkpoints are sorted by their numeric filenames in ascending order.
  """
  checkpoints = []
  fs = infer_fs(save_dir)
  if fs.exists(save_dir):

@@ -232,6 +273,18 @@ def wait_for_evaluators(
  global_step: int,
  timeout: int,
) -> None:
  """
  Waits for all evaluators to finish and checks for their completion status.

  Args:
    save_dir (str): Directory where checkpoints are saved.
    partition_names (List[str]): List of partition names to check for completion.
    global_step (int): The global step for which to wait for evaluators.
    timeout (int): Maximum time in seconds to wait for evaluators to finish.

  Returns:
    None: This function returns nothing but logs the progress and results.
  """
  logging.info("Waiting for all evaluators to finish.")
  start_time = time.time()
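For reference, a hedged sketch of how the Snapshot wrapper above is meant to be driven; `model` and `optimizer` are assumed to be torchsnapshot-compatible stateful objects, and the directory is made up:

```python
snapshot = Snapshot(
  save_dir="/tmp/model/checkpoints",
  state={"model": model, "optimizer": optimizer},
)

snapshot.step = 100                        # the training loop updates this property
pending = snapshot.save(global_step=100)   # writes /tmp/model/checkpoints/100, returns a PendingSnapshot handle

# Later, restore the same checkpoint:
snapshot.restore("/tmp/model/checkpoints/100")
```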
@@ -5,6 +5,15 @@ import torch.distributed as dist


def maybe_setup_tensorflow():
  """
  Try to import TensorFlow and disable GPU devices if TensorFlow is available.

  This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.

  Returns:
    None
  """
  try:
    import tensorflow as tf
  except ImportError:

@@ -14,6 +23,18 @@ def maybe_setup_tensorflow():


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
  """
  Set up the distributed environment and get the appropriate torch device.

  This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.

  Args:
    tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.

  Returns:
    torch.device: The torch device for the current process.
  """
  if tf_ok:
    maybe_setup_tensorflow()
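A small illustrative call pattern for the helper above, assuming the usual torchrun environment variables (RANK, WORLD_SIZE, LOCAL_RANK) have been set by the launcher:

```python
device = setup_and_get_device(tf_ok=True)
# e.g. cuda:0 on a GPU worker, cpu otherwise; a model would then be moved with model.to(device)
```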
@@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()


def infer_fs(path: str):
  """
  Infer the file system (fs) based on the given path.

  Args:
    path (str): The file path.

  Returns:
    The inferred filesystem object: the GCS filesystem for "gs://" paths, otherwise the local filesystem.

  Raises:
    NotImplementedError: If the path indicates Hadoop Distributed File System (HDFS), which is not yet supported.
  """
  if path.startswith("gs://"):
    return GCS_FS
  elif path.startswith("hdfs://"):

@@ -18,8 +30,26 @@ def infer_fs(path: str):


def is_local_fs(fs):
  """
  Check if the given file system is the local file system.

  Args:
    fs: The filesystem object to check.

  Returns:
    bool: True if the file system is the local file system, False otherwise.
  """
  return fs == LOCAL_FS


def is_gcs_fs(fs):
  """
  Check if the given file system is Google Cloud Storage (GCS).

  Args:
    fs: The filesystem object to check.

  Returns:
    bool: True if the file system is GCS, False otherwise.
  """
  return fs == GCS_FS
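An illustrative dispatch on the filesystem helpers above; the bucket path is made up:

```python
fs = infer_fs("gs://my-bucket/checkpoints/100")   # returns the GCS fsspec filesystem object
assert is_gcs_fs(fs) and not is_local_fs(fs)
```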
@@ -49,15 +49,25 @@ def log_ebc_norms(
  ebc_keys: List[str],
  sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
  """Logs the norms of the embedding tables as specified by ebc_keys.
  As of now, log average norm per rank.

  Args:
    model_state_dict: model.state_dict()
    ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
    sample_size: Limits number of rows per rank to compute average on to avoid OOM.
  """
  """
  Logs the norms of the embedding tables as specified by ebc_keys.
  As of now, log average norm per rank.

  Args:
    model_state_dict (dict): The state dictionary of the model.
    ebc_keys (List[str]): List of embedding keys from state_dict to log.
      Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
    sample_size (int, optional): Limits the number of rows per rank to compute average on
      to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.

  Returns:
    Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
      The keys are in the format "{emb_key}-norm-{idx}".

  Note:
    This function computes and logs the average norm of embedding tables across ranks.
    It gathers the norms from all ranks and returns them as a dictionary.
  """
  norm_logs = dict()
  for emb_key in ebc_keys:
    norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
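An illustrative call, reusing the example key from the docstring above; `model` is assumed to exist:

```python
norm_logs = log_ebc_norms(
  model_state_dict=model.state_dict(),
  ebc_keys=["model.embeddings.ebc.embedding_bags.meta__user_id.weight"],
)
```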
@@ -11,6 +11,20 @@ import numpy as np


class LargeEmbeddings(nn.Module):
  """
  A module for handling large embeddings.

  Args:
    large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.

  Attributes:
    ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
    surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models.

  Note:
    The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving.
  """

  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,

@@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    """
    Forward pass of the LargeEmbeddings module.

    Args:
      sparse_features (KeyedJaggedTensor): Sparse input features.

    Returns:
      KeyedTensor: The output of the forward pass, a KeyedTensor.
    """
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
|
||||
|
||||
|
||||
def is_distributed_worker():
|
||||
"""
|
||||
Checks if the current process is a distributed worker.
|
||||
|
||||
Returns:
|
||||
bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
|
||||
"""
|
||||
world_size = os.environ.get("WORLD_SIZE", None)
|
||||
rank = os.environ.get("RANK", None)
|
||||
return world_size is not None and rank is not None
|
||||
@ -25,25 +31,36 @@ def maybe_run_training(
|
||||
is_chief: Optional[bool] = False,
|
||||
**training_kwargs,
|
||||
):
|
||||
"""Wrapper function for single node, multi-GPU Pytorch training.
|
||||
|
||||
If the necessary distributed Pytorch environment variables
|
||||
(WORLD_SIZE, RANK) have been set, then this function executes
|
||||
`train_fn(**training_kwargs)`.
|
||||
|
||||
Otherwise, this function calls torchrun and points at the calling module
|
||||
`module_name`. After this call, the necessary environment variables are set
|
||||
and training will commence.
|
||||
|
||||
Args:
|
||||
train_fn: The function that is responsible for training
|
||||
module_name: The name of the module that this function was called from;
|
||||
used to indicate torchrun entrypoint.
|
||||
nproc_per_node: Number of workers per node; supported values.
|
||||
num_nodes: Number of nodes, otherwise inferred from environment.
|
||||
is_chief: If process is running on chief.
|
||||
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
|
||||
"""
|
||||
Wrapper function for single node, multi-GPU PyTorch training.
|
||||
|
||||
If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes
|
||||
`train_fn(**training_kwargs)`.
|
||||
|
||||
Otherwise, this function calls torchrun and points at the calling module
|
||||
`module_name`. After this call, the necessary environment variables are set
|
||||
and training will commence.
|
||||
|
||||
Args:
|
||||
train_fn (callable): The function responsible for training.
|
||||
module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint.
|
||||
nproc_per_node (int, optional): Number of workers per node. Defaults to None.
|
||||
num_nodes (int, optional): Number of nodes. Defaults to None.
|
||||
is_chief (bool, optional): If the process is running on the chief node. Defaults to False.
|
||||
set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
|
||||
**training_kwargs: Additional keyword arguments to pass to the `train_fn`.
|
||||
|
||||
Note:
|
||||
This function checks if the current process is a distributed worker by examining the environment variables.
|
||||
If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary
|
||||
environment variables and launches the training process using torchrun.
|
||||
|
||||
Example:
|
||||
To run training on a single node with 4 GPUs, you can use:
|
||||
```
|
||||
maybe_run_training(train_function, __name__, nproc_per_node=4)
|
||||
```
|
||||
"""
|
||||
|
||||
machines = utils.machine_from_env()
|
||||
if num_nodes is None:
|
||||
|
@ -19,6 +19,23 @@ MOCK_ENV = {
|
||||
|
||||
@contextmanager
|
||||
def mock_pg():
|
||||
"""
|
||||
A context manager for mocking the distributed process group for testing purposes.
|
||||
|
||||
This context manager temporarily sets environment variables to mock the distributed process group
|
||||
and initializes it using the Gloo backend. It is useful for testing distributed training without
|
||||
actually launching multiple processes.
|
||||
|
||||
Example:
|
||||
```python
|
||||
with mock_pg():
|
||||
# Your distributed training code here
|
||||
```
|
||||
|
||||
Note:
|
||||
This context manager should be used within a testing environment to simulate distributed training
|
||||
without actually creating multiple processes.
|
||||
"""
|
||||
with patch.dict(os.environ, MOCK_ENV):
|
||||
try:
|
||||
dist.init_process_group(
|
||||
|
@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)
|
||||
|
||||
|
||||
def _read_file(f):
|
||||
"""
|
||||
Read the contents of a file using fsspec.
|
||||
|
||||
Args:
|
||||
f: File path or URL.
|
||||
|
||||
Returns:
|
||||
The contents of the file.
|
||||
"""
|
||||
with fsspec.open(f) as f:
|
||||
return f.read()
|
||||
|
||||
@ -21,18 +30,24 @@ def setup_configuration(
|
||||
yaml_path: str,
|
||||
substitute_env_variable: bool = False,
|
||||
) -> Tuple[C, str]:
|
||||
"""Resolves a config at a yaml path.
|
||||
|
||||
Args:
|
||||
config_type: Pydantic config class to load.
|
||||
yaml_path: yaml path of the config file.
|
||||
substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
|
||||
environment variable value whenever possible. If an environment variable doesn't exist,
|
||||
the string is left unchanged.
|
||||
|
||||
Returns:
|
||||
The pydantic config object.
|
||||
"""
|
||||
Load a Pydantic config object from a YAML file and optionally substitute environment variables.
|
||||
|
||||
Args:
|
||||
config_type: Pydantic config class to load.
|
||||
yaml_path: YAML path of the config file.
|
||||
substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
|
||||
with their environment variable values whenever possible.
|
||||
If an environment variable doesn't exist, the string is left unchanged.
|
||||
|
||||
Returns:
|
||||
A tuple containing the Pydantic config object and the resolved YAML content.
|
||||
|
||||
Example:
|
||||
```python
|
||||
config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
|
||||
```
|
||||
"""
|
||||
|
||||
def _substitute(s):
|
||||
if substitute_env_variable:
|
||||
|
@ -6,6 +6,33 @@ import pydantic
|
||||
|
||||
|
||||
class WandbConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for integrating with Weights and Biases (WandB).
|
||||
|
||||
Attributes:
|
||||
host (str): Host of the Weights and Biases instance, passed to login.
|
||||
key_path (str): Path to the key file.
|
||||
name (str): Name of the experiment, passed to init.
|
||||
entity (str): Name of the user/service account, passed to init.
|
||||
project (str): Name of the WandB project, passed to init.
|
||||
tags (List[str]): List of tags, passed to init.
|
||||
notes (str): Notes, passed to init.
|
||||
metadata (Dict[str, Any]): Additional metadata to log.
|
||||
|
||||
Example:
|
||||
```python
|
||||
wandb_config = WandbConfig(
|
||||
host="https://wandb.example.com",
|
||||
key_path="/path/to/key",
|
||||
name="experiment_1",
|
||||
entity="user123",
|
||||
project="my_project",
|
||||
tags=["experiment", "ml"],
|
||||
notes="This is a test experiment.",
|
||||
metadata={"version": "1.0"}
|
||||
)
|
||||
```
|
||||
"""
|
||||
host: str = pydantic.Field(
|
||||
"https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
|
||||
description="Host of Weights and Biases instance, passed to login.",
|
||||
|
@ -8,59 +8,69 @@ import pydantic
|
||||
|
||||
|
||||
class BaseConfig(pydantic.BaseModel):
|
||||
"""Base class for all derived config classes.
|
||||
"""Base class for all derived config classes.
|
||||
|
||||
This class provides some convenient functionality:
|
||||
- Disallows extra fields when constructing an object. User error
|
||||
should be reduced by exact arguments.
|
||||
- "one_of" fields. A subclass can group optional fields and enforce
|
||||
that only one of the fields be set. For example:
|
||||
This class provides convenient functionality and constraints for derived config classes:
|
||||
|
||||
```
|
||||
- Disallows extra fields when constructing an object. User errors due to extraneous arguments
|
||||
are minimized.
|
||||
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
|
||||
be set. For example:
|
||||
|
||||
```python
|
||||
class ExampleConfig(BaseConfig):
|
||||
x: int = Field(None, one_of="group_1")
|
||||
y: int = Field(None, one_of="group_1")
|
||||
x: int = Field(None, one_of="group_1")
|
||||
y: int = Field(None, one_of="group_1")
|
||||
|
||||
ExampleConfig(x=1) # ok
|
||||
ExampleConfig(y=1) # ok
|
||||
ExampleConfig(x=1, y=1) # throws error
|
||||
ExampleConfig(x=1) # OK
|
||||
ExampleConfig(y=1) # OK
|
||||
ExampleConfig(x=1, y=1) # Raises an error
|
||||
```
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Forbids extras."""
|
||||
Attributes:
|
||||
Config (class): Configuration options for this class, forbidding extra fields.
|
||||
|
||||
extra = pydantic.Extra.forbid # noqa
|
||||
Methods:
|
||||
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
|
||||
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
|
||||
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
|
||||
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def _field_data_map(cls, field_data_name):
|
||||
"""Create a map of fields with provided the field data."""
|
||||
schema = cls.schema()
|
||||
one_of = collections.defaultdict(list)
|
||||
for field, fdata in schema["properties"].items():
|
||||
if field_data_name in fdata:
|
||||
one_of[fdata[field_data_name]].append(field)
|
||||
return one_of
|
||||
"""
|
||||
class Config:
|
||||
"""Configuration options that forbid extra fields."""
|
||||
extra = pydantic.Extra.forbid # noqa
|
||||
|
||||
@pydantic.root_validator
|
||||
def _one_of_check(cls, values):
|
||||
"""Validate that all 'one of' fields are appear exactly once."""
|
||||
one_of_map = cls._field_data_map("one_of")
|
||||
for one_of, field_names in one_of_map.items():
|
||||
if sum([values.get(n, None) is not None for n in field_names]) != 1:
|
||||
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
|
||||
return values
|
||||
@classmethod
|
||||
@functools.lru_cache()
|
||||
def _field_data_map(cls, field_data_name):
|
||||
"""Create a map of fields with the provided field data."""
|
||||
schema = cls.schema()
|
||||
one_of = collections.defaultdict(list)
|
||||
for field, fdata in schema["properties"].items():
|
||||
if field_data_name in fdata:
|
||||
one_of[fdata[field_data_name]].append(field)
|
||||
return one_of
|
||||
|
||||
@pydantic.root_validator
|
||||
def _at_most_one_of_check(cls, values):
|
||||
"""Validate that all 'at_most_one_of' fields appear at most once."""
|
||||
at_most_one_of_map = cls._field_data_map("at_most_one_of")
|
||||
for one_of, field_names in at_most_one_of_map.items():
|
||||
if sum([values.get(n, None) is not None for n in field_names]) > 1:
|
||||
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
|
||||
return values
|
||||
@pydantic.root_validator
|
||||
def _one_of_check(cls, values):
|
||||
"""Validate that all 'one of' fields appear exactly once."""
|
||||
one_of_map = cls._field_data_map("one_of")
|
||||
for one_of, field_names in one_of_map.items():
|
||||
if sum([values.get(n, None) is not None for n in field_names]) != 1:
|
||||
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
|
||||
return values
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
"""Return a human legible (yaml) representation of the config useful for logging."""
|
||||
return yaml.dump(self.dict())
|
||||
@pydantic.root_validator
|
||||
def _at_most_one_of_check(cls, values):
|
||||
"""Validate that all 'at_most_one_of' fields appear at most once."""
|
||||
at_most_one_of_map = cls._field_data_map("at_most_one_of")
|
||||
for one_of, field_names in at_most_one_of_map.items():
|
||||
if sum([values.get(n, None) is not None for n in field_names]) > 1:
|
||||
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
|
||||
|
||||
return values
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
"""Return a human-readable (YAML) representation of the config useful for logging."""
|
||||
return yaml.dump(self.dict())
|
||||
|
@ -6,33 +6,67 @@ import pydantic
|
||||
|
||||
|
||||
class BaseConfigTest(TestCase):
|
||||
def test_extra_forbidden(self):
|
||||
class Config(BaseConfig):
|
||||
x: int
|
||||
"""
|
||||
Unit tests for the BaseConfig class.
|
||||
"""
|
||||
|
||||
Config(x=1)
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config(x=1, y=2)
|
||||
def test_extra_forbidden(self):
|
||||
"""
|
||||
Test that extra fields are forbidden when creating a Config instance.
|
||||
|
||||
def test_one_of(self):
|
||||
class Config(BaseConfig):
|
||||
x: int = pydantic.Field(None, one_of="f")
|
||||
y: int = pydantic.Field(None, one_of="f")
|
||||
This test case checks whether the `BaseConfig` class correctly raises a
|
||||
`pydantic.ValidationError` when extra fields are provided when creating a
|
||||
`Config` instance.
|
||||
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config()
|
||||
Config(x=1)
|
||||
Config(y=1)
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config(x=1, y=3)
|
||||
Raises:
|
||||
AssertionError: If the test fails.
|
||||
"""
|
||||
class Config(BaseConfig):
|
||||
x: int
|
||||
|
||||
def test_at_most_one_of(self):
|
||||
class Config(BaseConfig):
|
||||
x: int = pydantic.Field(None, at_most_one_of="f")
|
||||
y: str = pydantic.Field(None, at_most_one_of="f")
|
||||
Config(x=1)
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config(x=1, y=2)
|
||||
|
||||
Config()
|
||||
Config(x=1)
|
||||
Config(y="a")
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config(x=1, y="a")
|
||||
def test_one_of(self):
|
||||
"""
|
||||
Test the use of the `one_of` attribute for fields in a Config instance.
|
||||
|
||||
This test case checks the behavior of the `one_of` attribute in a `Config`
|
||||
instance. It verifies that the `pydantic.Field` correctly enforces the
|
||||
specified constraint.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the test fails.
|
||||
"""
|
||||
class Config(BaseConfig):
|
||||
x: int = pydantic.Field(None, one_of="f")
|
||||
y: int = pydantic.Field(None, one_of="f")
|
||||
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config()
|
||||
Config(x=1)
|
||||
Config(y=1)
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config(x=1, y=3)
|
||||
|
||||
def test_at_most_one_of(self):
|
||||
"""
|
||||
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
|
||||
|
||||
This test case checks the behavior of the `at_most_one_of` attribute in a
|
||||
`Config` instance. It verifies that the `pydantic.Field` enforces the
|
||||
constraint where at most one of the specified fields can be provided.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the test fails.
|
||||
"""
|
||||
class Config(BaseConfig):
|
||||
x: int = pydantic.Field(None, at_most_one_of="f")
|
||||
y: str = pydantic.Field(None, at_most_one_of="f")
|
||||
|
||||
Config()
|
||||
Config(x=1)
|
||||
Config(y="a")
|
||||
with self.assertRaises(pydantic.ValidationError):
|
||||
Config(x=1, y="a")
|
||||
|
@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
|
||||
|
||||
|
||||
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
|
||||
"""Recommend method to load a config file (a yaml file) and parse it.
|
||||
|
||||
Because we have a shared filesystem the recommended route to running jobs it put modified config
|
||||
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
|
||||
"""
|
||||
Recommend method to Load and parse a configuration from a YAML file.
|
||||
|
||||
This function loads a configuration from a YAML file, parses it, and returns an instance of the
|
||||
specified config type.
|
||||
|
||||
Because we have a shared filesystem the recommended route to running jobs it put modified config
|
||||
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
|
||||
|
||||
Args:
|
||||
config_type (Type[BaseConfig]): The Pydantic config class to load.
|
||||
yaml_path (str): The path to the YAML configuration file.
|
||||
|
||||
Returns:
|
||||
BaseConfig: An instance of the specified config type populated with values from the YAML file.
|
||||
|
||||
Example:
|
||||
Suppose you have a YAML file 'my_config.yaml' containing the following:
|
||||
|
||||
```yaml
|
||||
x: 42
|
||||
y: "hello"
|
||||
```
|
||||
|
||||
You can load and parse it using this function as follows:
|
||||
|
||||
```python
|
||||
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
|
||||
```
|
||||
|
||||
Note:
|
||||
This function performs environment variable substitution in the YAML file. It replaces
|
||||
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
|
||||
values. If an environment variable does not exist, the string is left unchanged.
|
||||
|
||||
"""
|
||||
|
||||
def _substitute(s):
|
||||
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
|
||||
|
@ -8,11 +8,27 @@ import pydantic
|
||||
|
||||
|
||||
class _PointlessConfig(BaseConfig):
|
||||
|
||||
a: int
|
||||
user: str
|
||||
|
||||
|
||||
def test_load_config_from_yaml(tmp_path):
|
||||
"""Test loading a configuration from a YAML file and verifying its values.
|
||||
|
||||
This test function checks the functionality of the `load_config_from_yaml` function by creating
|
||||
a temporary YAML configuration file, loading it, and asserting that the loaded config object
|
||||
has the expected values.
|
||||
|
||||
Args:
|
||||
tmp_path: A temporary directory provided by the `pytest` framework.
|
||||
|
||||
Test Steps:
|
||||
1. Create a temporary YAML file containing configuration data.
|
||||
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
|
||||
3. Assert that the loaded configuration object has the expected values.
|
||||
|
||||
"""
|
||||
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
|
||||
with open(yaml_path, "w") as yaml_file:
|
||||
yaml_file.write("""a: 3\nuser: ${USER}\n""")
|
||||
|
@ -9,6 +9,14 @@ import pydantic
|
||||
|
||||
|
||||
class RuntimeConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for runtime settings.
|
||||
|
||||
Attributes:
|
||||
- wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
|
||||
- enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices.
|
||||
- enable_amp (bool): Enable automatic mixed precision.
|
||||
"""
|
||||
wandb: WandbConfig = pydantic.Field(None)
|
||||
enable_tensorfloat32: bool = pydantic.Field(
|
||||
False, description="Use tensorfloat32 if on Ampere devices."
|
||||
@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class TrainingConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for training settings.
|
||||
|
||||
Attributes:
|
||||
- save_dir (str): Directory to save checkpoints.
|
||||
- num_train_steps (pydantic.PositiveInt): Number of training steps.
|
||||
- initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
|
||||
- checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps.
|
||||
- checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
|
||||
- train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps.
|
||||
- num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used.
|
||||
- eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps.
|
||||
- eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
|
||||
- gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
|
||||
- num_epochs (pydantic.PositiveInt): Number of training epochs.
|
||||
"""
|
||||
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
|
||||
num_train_steps: pydantic.PositiveInt = 10000
|
||||
initial_checkpoint_dir: str = pydantic.Field(
|
||||
|
@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):
|
||||
|
||||
|
||||
def _get_step_fn(pipeline, data_iterator, training: bool):
|
||||
"""
|
||||
Returns a function to perform a single evaluation step.
|
||||
|
||||
Args:
|
||||
pipeline (Pipeline): The pipeline object containing the model.
|
||||
data_iterator (Iterator): The data iterator for evaluation.
|
||||
training (bool): Flag indicating if the model should be in training mode.
|
||||
|
||||
Returns:
|
||||
function: A function that performs a single evaluation step.
|
||||
"""
|
||||
def step_fn():
|
||||
"""
|
||||
Perform a single evaluation step.
|
||||
|
||||
Returns:
|
||||
Any: The evaluation results after a single step.
|
||||
"""
|
||||
# It turns out that model.train() and model.eval() simply switch a single field inside the model
|
||||
# class,so it's somewhat safer to wrap in here.
|
||||
if training:
|
||||
@ -69,7 +86,21 @@ def _run_evaluation(
|
||||
eval_batch_size: int,
|
||||
logger=None,
|
||||
):
|
||||
"""Runs the evaluation loop over all evaluation iterators."""
|
||||
"""
|
||||
Run the evaluation loop over all evaluation iterators.
|
||||
|
||||
Args:
|
||||
pipeline (Pipeline): The pipeline object containing the model.
|
||||
dataset (Dataset): The dataset to evaluate.
|
||||
eval_steps (int): The number of evaluation steps to perform.
|
||||
metrics (tm.MetricCollection): A collection of evaluation metrics.
|
||||
eval_batch_size (int): Batch size for evaluation.
|
||||
logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the computed evaluation metrics.
|
||||
"""
|
||||
|
||||
dataset = get_new_iterator(dataset)
|
||||
step_fn = _get_step_fn(pipeline, dataset, training=False)
|
||||
last_time = datetime.datetime.now()
|
||||
@ -109,15 +140,29 @@ def train(
|
||||
parameters_to_log: Optional[Dict[str, Callable]] = None,
|
||||
tables_to_log: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Runs training and eval on the given TrainPipeline
|
||||
|
||||
Args:
|
||||
dataset: data iterator for the training set
|
||||
evaluation_iterators: data iterators for the different evaluation sets
|
||||
scheduler: optional learning rate scheduler
|
||||
output_transform_for_metrics: optional transformation functions to transorm the model
|
||||
output and labels into a format the metrics can understand
|
||||
"""
|
||||
Runs training and evaluation on the given TrainPipeline.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The neural network model to train.
|
||||
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
|
||||
device (str): The target device for model training (e.g., 'cuda' or 'cpu').
|
||||
save_dir (str): The directory to save model checkpoints and logs.
|
||||
logging_interval (int): Interval for logging training progress.
|
||||
train_steps (int): The number of training steps to perform.
|
||||
checkpoint_frequency (int): Frequency of saving model checkpoints.
|
||||
dataset (Iterable): Data iterator for the training set.
|
||||
worker_batch_size (int): Batch size for data loading workers.
|
||||
num_workers (Optional[int]): Number of data loading workers (default: 0).
|
||||
enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
|
||||
initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
|
||||
gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
|
||||
logger_initializer (Optional[Callable]): A logger initializer function (default: None).
|
||||
scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
|
||||
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
|
||||
parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
|
||||
tables_to_log (Optional[List[str]]): List of tables to log (default: None).
|
||||
"""
|
||||
|
||||
train_pipeline = TrainPipelineSparseDist(
|
||||
model=model,
|
||||
@ -262,6 +307,15 @@ def log_eval_results(
|
||||
partition_name: str,
|
||||
step: int,
|
||||
):
|
||||
"""
|
||||
Logs evaluation results and optionally records them using a provided logger.
|
||||
|
||||
Args:
|
||||
results (Any): The evaluation results to log.
|
||||
eval_logger (Callable): A logger for recording evaluation results.
|
||||
partition_name (str): The name of the evaluation partition.
|
||||
step (int): The current step in the evaluation.
|
||||
"""
|
||||
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
|
||||
logging.info(f"Step: {step}, evaluation ({partition_name}).")
|
||||
for metric_name, metric_value in results.items():
|
||||
@ -285,6 +339,23 @@ def only_evaluate(
|
||||
partition_name: str,
|
||||
metrics: Optional[tm.MetricCollection] = None,
|
||||
):
|
||||
"""
|
||||
Performs evaluation on a given dataset partition.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The neural network model for evaluation.
|
||||
optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
|
||||
device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
|
||||
save_dir (str): The directory containing model checkpoints.
|
||||
num_train_steps (int): The total number of training steps.
|
||||
dataset (Iterable): Data iterator for evaluation.
|
||||
eval_batch_size (int): Batch size for evaluation.
|
||||
num_eval_steps (int): The number of evaluation steps to perform.
|
||||
eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
|
||||
eval_logger (Callable): A logger for recording evaluation results.
|
||||
partition_name (str): The name of the evaluation partition.
|
||||
metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
|
||||
"""
|
||||
logging.info(f"Evaluating on partition {partition_name}.")
|
||||
logging.info("Computing metrics:")
|
||||
logging.info(metrics)
|
||||
|
@ -28,6 +28,18 @@ def train(
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Debugging training loop. Do not use for actual model training.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The neural network model.
|
||||
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
|
||||
train_steps (int): The number of training steps to perform.
|
||||
dataset (Iterable): Data iterator for training data.
|
||||
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
|
||||
*args: Additional arguments (ignored).
|
||||
**kwargs: Additional keyword arguments (ignored).
|
||||
"""
|
||||
|
||||
logging.warning("Running debug training loop, don't use for model training.")
|
||||
|
||||
|
@@ -10,8 +10,11 @@ import torch

def _maybe_warn(reduction: str):
  """
  Warning for reduction different than mean.
  """
  """
  Emit a warning if the reduction method is different from 'mean'.

  Args:
    reduction (str): The reduction method being used.
  """
  if reduction != "mean":
    logging.warn(
      f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"

@@ -24,6 +27,16 @@ def build_loss(
  loss_type: LossType,
  reduction="mean",
):
  """
  Build a loss function based on the specified loss type and reduction method.

  Args:
    loss_type (LossType): The type of loss to build.
    reduction (str): The reduction method for the loss (default: 'mean').

  Returns:
    Callable: A loss function that takes logits and labels as input.
  """
  _maybe_warn(reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]

@@ -35,11 +48,15 @@ def build_loss(

def get_global_loss_detached(local_loss, reduction="mean"):
  """
  Perform all_reduce to obtain the global loss function using the provided reduction.
  :param local_loss: The local loss of the current rank.
  :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
  :return: The reduced & detached global loss.
  """
  """
  Perform all_reduce to obtain the global loss function using the provided reduction.

  Args:
    local_loss (torch.Tensor): The local loss of the current rank.
    reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.

  Returns:
    torch.Tensor: The reduced and detached global loss.
  """
  if reduction != "mean":
    logging.warn(
      f"The reduction used in this function should be the same as the one used by "

@@ -66,6 +83,19 @@ def build_multi_task_loss(
  global_reduction="mean",
  pos_weights=None,
):
  """
  Build a multi-task loss function based on the specified loss type and configurations.

  Args:
    loss_type (LossType): The type of loss to build.
    tasks (typing.List[str]): List of task names.
    task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
    global_reduction (str): Reduction method for the global loss (default: 'mean').
    pos_weights (Optional): Positive class weights for tasks (default: None).

  Returns:
    Callable: A multi-task loss function that takes logits, labels, and weights as input.
  """
  _maybe_warn(global_reduction)
  _maybe_warn(task_loss_reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
||||
class MetricMixin:
|
||||
@abstractmethod
|
||||
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
|
||||
"""
|
||||
Abstract method to transform model outputs into a dictionary of metrics.
|
||||
|
||||
Args:
|
||||
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||
|
||||
Returns:
|
||||
Dict: A dictionary of computed metrics.
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, outputs: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Update the metrics based on model outputs.
|
||||
|
||||
Args:
|
||||
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||
"""
|
||||
results = self.transform(outputs)
|
||||
# Do not try to update if any tensor is empty as a result of stratification.
|
||||
for value in results.values():
|
||||
@ -49,6 +64,13 @@ class MetricMixin:
|
||||
|
||||
class TaskMixin:
|
||||
def __init__(self, task_idx: int = -1, **kwargs):
|
||||
"""
|
||||
Initialize a TaskMixin instance.
|
||||
|
||||
Args:
|
||||
task_idx (int): Index of the task associated with this mixin (default: -1).
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._task_idx = task_idx
|
||||
|
||||
@ -59,13 +81,31 @@ class StratifyMixin:
|
||||
stratifier=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize a StratifyMixin instance.
|
||||
|
||||
Args:
|
||||
stratifier: A stratifier for filtering outputs (default: None).
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._stratifier = stratifier
|
||||
|
||||
def maybe_apply_stratification(
|
||||
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
|
||||
"""
|
||||
Apply stratification to filter examples in the outputs.
|
||||
|
||||
Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.
|
||||
|
||||
Args:
|
||||
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||
value_names (List[str]): Names of values to filter.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: Filtered outputs.
|
||||
"""
|
||||
outputs = outputs.copy()
|
||||
if not self._stratifier:
|
||||
return outputs
|
||||
@ -84,12 +124,20 @@ class StratifyMixin:
|
||||
|
||||
|
||||
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
|
||||
"""Returns new class using MetricMixin and given base_metric.
|
||||
|
||||
Functionally the same using inheritance, just saves some lines of code
|
||||
if no need for class attributes.
|
||||
|
||||
"""
|
||||
Returns a new class using MetricMixin and the given base_metric.
|
||||
|
||||
Functionally the same as using inheritance, but it saves some lines of code
|
||||
if there's no need for class attributes.
|
||||
|
||||
Args:
|
||||
base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
|
||||
transform (Callable): The transformation function to prepend to the metric.
|
||||
|
||||
Returns:
|
||||
Type: A new class that includes MetricMixin and the provided base_metric
|
||||
with the specified transformation method.
|
||||
"""
|
||||
|
||||
def transform_method(_self, *args, **kwargs):
|
||||
return transform(*args, **kwargs)
|
||||
|
@ -15,6 +15,16 @@ def probs_and_labels(
|
||||
outputs: Dict[str, torch.Tensor],
|
||||
task_idx: int,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Extract probabilities and labels from model outputs.
|
||||
|
||||
Args:
|
||||
outputs (Dict[str, torch.Tensor]): Model outputs.
|
||||
task_idx (int): Index of the task.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
|
||||
"""
|
||||
preds = outputs["probabilities"]
|
||||
target = outputs["labels"]
|
||||
if task_idx >= 0:
|
||||
@ -28,6 +38,11 @@ def probs_and_labels(
|
||||
|
||||
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
|
||||
def transform(self, outputs):
|
||||
"""
|
||||
Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.
|
||||
|
||||
This metric counts values after potential stratification and task selection.
|
||||
"""
|
||||
outputs = self.maybe_apply_stratification(outputs, ["labels"])
|
||||
value = outputs["labels"]
|
||||
if self._task_idx >= 0:
|
||||
@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
|
||||
|
||||
|
||||
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
"""
|
||||
Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
|
||||
|
||||
This metric calculates the mean metric value after potential stratification and task selection.
|
||||
"""
|
||||
|
||||
def transform(self, outputs):
|
||||
outputs = self.maybe_apply_stratification(outputs, ["labels"])
|
||||
value = outputs["labels"]
|
||||
@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
|
||||
|
||||
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
"""
|
||||
Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.
|
||||
|
||||
This metric calculates the mean metric value using probabilities after potential stratification and task selection.
|
||||
"""
|
||||
def transform(self, outputs):
|
||||
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
|
||||
value = outputs["probabilities"]
|
||||
@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
|
||||
|
||||
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
|
||||
"""
|
||||
Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.
|
||||
|
||||
This metric computes precision after potential stratification and task selection.
|
||||
"""
|
||||
def transform(self, outputs):
|
||||
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
|
||||
return probs_and_labels(outputs, self._task_idx)
|
||||
|
||||
|
||||
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
|
||||
"""
|
||||
Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.
|
||||
|
||||
This metric computes recall after potential stratification and task selection.
|
||||
"""
|
||||
def transform(self, outputs):
|
||||
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
|
||||
return probs_and_labels(outputs, self._task_idx)
|
||||
@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
|
||||
|
||||
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
"""
|
||||
AUC (Area Under the ROC Curve) metric class.
|
||||
|
||||
This metric computes the AUC metric based on the logits and labels in the model outputs.
|
||||
|
||||
Args:
|
||||
num_samples (int): The number of samples used to compute AUC.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Based on:
|
||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
|
||||
"""
|
||||
@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
|
||||
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
"""
|
||||
The ranks of all positives
|
||||
Based on:
|
||||
PosRanks metric class.
|
||||
|
||||
This metric computes the ranks of all positive examples based on the logits and labels
|
||||
in the model outputs.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
|
||||
"""
|
||||
|
||||
@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
|
||||
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
"""
|
||||
The reciprocal of the ranks of all
|
||||
Based on:
|
||||
ReciprocalRank metric class.
|
||||
|
||||
This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels
|
||||
in the model outputs.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments.
|
||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
|
||||
"""
|
||||
|
||||
@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
|
||||
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
|
||||
"""
|
||||
The fraction of positives that rank in the top K among their negatives
|
||||
Note that this is basically precision@k
|
||||
Based on:
|
||||
HitAtK metric class.
|
||||
|
||||
This metric computes the fraction of positive examples that rank in the top K among their negatives,
|
||||
which is equivalent to precision@K.
|
||||
|
||||
Args:
|
||||
k (int): The value of K.
|
||||
**kwargs: Additional keyword arguments.
|
||||
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
|
||||
"""
|
||||
|
||||
|
@@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric

@dataclass
class MockStratifierConfig:
  """
  Configuration dataclass for mocking a stratifier.

  Args:
    name (str): The name of the stratifier.
    index (int): The index of the stratifier.
    value (int): The value of the stratifier.
  """
  name: str
  index: int
  value: int


class Count(MetricMixin, SumMetric):
  """
  Count metric class that inherits from MetricMixin and SumMetric.

  This metric counts occurrences.
  """
  def transform(self, outputs):
    return {"value": 1}

@@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})


def test_count_metric():
  """
  Test function for the Count metric.

  It checks if the Count metric correctly counts the number of examples.
  """
  num_examples = 123
  examples = [
    {"stuff": 0},

@@ -36,6 +56,12 @@ def test_count_metric():


def test_collections():
  """
  Test function for metric collections.

  It tests if metric collections correctly aggregate metrics.
  """
  max_metric = Max()
  count_metric = Count()
  metric = MetricCollection([max_metric, count_metric])

@@ -51,6 +77,12 @@ def test_collections():


def test_task_dependent_ctr():
  """
  Test function for task-dependent Ctr (Click-Through Rate) metric.

  It checks if the Ctr metric computes the correct value for different tasks.
  """
  num_examples = 144
  batch_size = 1024
  outputs = [

@@ -69,6 +101,13 @@ def test_task_dependent_ctr():


def test_stratified_ctr():
  """
  Test function for the Stratified Ctr (Click-Through Rate) metric.

  It checks if the Stratified Ctr metric computes the correct value for different tasks
  and stratified samples.
  """
  outputs = [
    {
      "stuff": 0,

@@ -114,6 +153,12 @@ def test_stratified_ctr():


def test_auc():
  """
  Test function for the AUC (Area Under the Curve) metric.

  It checks if the AUC metric correctly computes the Area Under the ROC Curve.
  """
  num_samples = 10000
  metric = core_metrics.Auc(num_samples)
  target = torch.tensor([0, 0, 1, 1, 1])

@@ -131,6 +176,12 @@ def test_auc():


def test_pos_rank():
  """
  Test function for the PosRanks metric.

  It checks if the PosRanks metric correctly computes the ranks of positive samples.
  """
  metric = core_metrics.PosRanks()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])

@@ -147,6 +198,12 @@ def test_pos_rank():


def test_reciprocal_rank():
  """
  Test function for the Reciprocal Rank metric.

  It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
  """
  metric = core_metrics.ReciprocalRank()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])

@@ -163,6 +220,12 @@ def test_reciprocal_rank():


def test_hit_k():
  """
  Test function for the Hit@K metric.

  It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
  """
  hit1_metric = core_metrics.HitAtK(1)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
|
||||
|
||||
@dataclass
|
||||
class MockDataclassBatch(DataclassBatch):
|
||||
"""
|
||||
Mock data class batch for testing purposes.
|
||||
|
||||
This class represents a batch of data with continuous features and labels.
|
||||
|
||||
Attributes:
|
||||
continuous_features (torch.Tensor): Tensor containing continuous feature data.
|
||||
labels (torch.Tensor): Tensor containing label data.
|
||||
"""
|
||||
continuous_features: torch.Tensor
|
||||
labels: torch.Tensor
|
||||
|
||||
|
||||
class MockModule(torch.nn.Module):
|
||||
"""
|
||||
Mock PyTorch module for testing purposes.
|
||||
|
||||
This module defines a simple neural network model with a linear layer
|
||||
followed by a BCEWithLogitsLoss loss function.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Linear): The linear model layer.
|
||||
loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.model = torch.nn.Linear(10, 1)
|
||||
self.loss_fn = torch.nn.BCEWithLogitsLoss()
|
||||
|
||||
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass of the mock module.
|
||||
|
||||
Args:
|
||||
batch (MockDataclassBatch): Input data batch with continuous features and labels.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
|
||||
"""
|
||||
pred = self.model(batch.continuous_features)
|
||||
loss = self.loss_fn(pred, batch.labels)
|
||||
return (loss, pred)
|
||||
|
||||
|
||||
def create_batch(bsz: int):
|
||||
"""
|
||||
Create a mock data batch with random continuous features and labels.
|
||||
|
||||
Args:
|
||||
bsz (int): Batch size.
|
||||
|
||||
Returns:
|
||||
MockDataclassBatch: A batch of data with continuous features and labels.
|
||||
"""
|
||||
return MockDataclassBatch(
|
||||
continuous_features=torch.rand(bsz, 10).float(),
|
||||
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
|
||||
@ -35,6 +72,13 @@ def create_batch(bsz: int):
|
||||
|
||||
|
||||
def test_sparse_pipeline():
|
||||
"""
|
||||
Test function for the sparse pipeline with distributed model parallelism.
|
||||
|
||||
This function tests the behavior of the sparse training pipeline using
|
||||
a mock module and data.
|
||||
"""
|
||||
|
||||
device = torch.device("cpu")
|
||||
model = MockModule().to(device)
|
||||
|
||||
@ -65,6 +109,15 @@ def test_sparse_pipeline():
|
||||
|
||||
|
||||
def test_amp():
|
||||
"""
|
||||
Test automatic mixed-precision (AMP) training with the sparse pipeline.
|
||||
|
||||
This function tests the behavior of the sparse training pipeline with
|
||||
automatic mixed-precision (AMP) enabled, using a mock module and data.
|
||||
|
||||
AMP allows for faster training by using lower-precision data types, such as
|
||||
torch.bfloat16, while maintaining model accuracy.
|
||||
"""
|
||||
device = torch.device("cpu")
|
||||
model = MockModule().to(device)
|
||||
|
||||
|
@ -39,12 +39,42 @@ Out = TypeVar("Out")
|
||||
|
||||
|
||||
class TrainPipeline(abc.ABC, Generic[In, Out]):
|
||||
"""
|
||||
Abstract base class for training pipelines.
|
||||
|
||||
Attributes:
|
||||
In (TypeVar): Input data type.
|
||||
Out (TypeVar): Output data type.
|
||||
|
||||
Methods:
|
||||
progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
"""
|
||||
Make progress in the training pipeline.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||
|
||||
Returns:
|
||||
Out: The output data.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
||||
"""
|
||||
Move a batch of data to a specified device.
|
||||
|
||||
Args:
|
||||
batch (In): The input batch.
|
||||
device (torch.device): The target device.
|
||||
non_blocking (bool): If True, move the data asynchronously.
|
||||
|
||||
Returns:
|
||||
In: The batch of data on the target device.
|
||||
"""
|
||||
assert isinstance(
|
||||
batch, (torch.Tensor, Pipelineable)
|
||||
), f"{type(batch)} must implement Pipelineable interface"
|
||||
@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
|
||||
|
||||
|
||||
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||
"""
|
||||
Wait for a batch of data on a specified stream.
|
||||
|
||||
Args:
|
||||
batch (In): The input batch.
|
||||
stream (Optional[Stream]): The CUDA stream to wait for.
|
||||
|
||||
Note:
|
||||
This function is used for managing asynchronous CUDA operations.
|
||||
"""
|
||||
if stream is None:
|
||||
return
|
||||
torch.cuda.current_stream().wait_stream(stream)
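The helper relies on the standard CUDA two-stream handshake. A standalone sketch of that idiom follows (it assumes a CUDA device and is not this function's exact body):

```python
import torch

# Minimal sketch of the copy-stream / compute-stream handshake.
if torch.cuda.is_available():
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream()

    with torch.cuda.stream(copy_stream):
        # Host-to-device copy issued on the side stream.
        batch = torch.randn(1024, 10, pin_memory=True).to(device, non_blocking=True)

    # Make the default (compute) stream wait until the copy has finished, then tell
    # the caching allocator the tensor is now used on the default stream.
    torch.cuda.current_stream().wait_stream(copy_stream)
    batch.record_stream(torch.cuda.current_stream())
    out = batch.sum()
```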
|
||||
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N
|
||||
|
||||
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.

Attributes:
In (TypeVar): Input data type.
Out (TypeVar): Output data type.

Methods:
__init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
Initialize the TrainPipelineBase.

_connect(dataloader_iter: Iterator[In]) -> None:
Establish a connection to the data loader and move the input data to the GPU.

progress(dataloader_iter: Iterator[In]) -> Out:
Execute a training iteration, including forward and backward passes.

"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
optimizer: torch.optim.Optimizer,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the TrainPipelineBase.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The PyTorch model to be trained.
|
||||
optimizer (torch.optim.Optimizer): The optimizer used for training.
|
||||
device (torch.device): The target device for training (CPU or GPU).
|
||||
"""
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._device = device
|
||||
@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
self._connected = False
|
||||
|
||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||
"""
|
||||
Establish a connection to the data loader and move the input data to the GPU.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||
"""
|
||||
cur_batch = next(dataloader_iter)
|
||||
self._cur_batch = cur_batch
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
self._connected = True
|
||||
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
|
||||
"""
|
||||
Execute a training iteration, including forward and backward passes.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator over input data.
|
||||
|
||||
Returns:
|
||||
Out: The output data.
|
||||
"""
|
||||
if not self._connected:
|
||||
self._connect(dataloader_iter)
|
||||
|
||||
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
|
||||
|
||||
|
||||
class Tracer(torch.fx.Tracer):
"""
Custom tracer class for PyTorch models.

This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.

Attributes:
proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
_leaf_modules (List[str]): List of qualified names of leaf modules.
"""
|
||||
|
||||
# Disable proxying buffers during tracing. Ideally, proxying buffers would
|
||||
# be disabled, but some models are currently mutating buffer values, which
|
||||
# causes errors during tracing. If those models can be rewritten to not do
|
||||
@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
|
||||
proxy_buffer_attributes = False
|
||||
|
||||
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
Initialize the Tracer.
|
||||
|
||||
Args:
|
||||
leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
|
||||
"""
|
||||
super().__init__()
|
||||
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
|
||||
|
||||
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
||||
"""
|
||||
Check if a module is a leaf module during tracing.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): The PyTorch module.
|
||||
module_qualified_name (str): The qualified name of the module.
|
||||
|
||||
Returns:
|
||||
bool: True if the module is considered a leaf module, False otherwise.
|
||||
"""
|
||||
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
|
||||
return True
|
||||
return super().is_leaf_module(m, module_qualified_name)
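How this tracer is typically wired into `torch.fx` — a minimal, hypothetical sketch (the `ToyModel` and the `"sparse"` leaf name are illustrative, not taken from this repo):

```python
import torch
import torch.fx

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sparse = torch.nn.EmbeddingBag(100, 8)
        self.dense = torch.nn.Linear(8, 1)

    def forward(self, ids: torch.Tensor):
        return self.dense(self.sparse(ids))

model = ToyModel()
tracer = Tracer(leaf_modules=["sparse"])   # keep `model.sparse` opaque during tracing
graph = tracer.trace(model)                # standard torch.fx.Tracer API
graph_module = torch.fx.GraphModule(model, graph)
print(graph_module.graph)                  # `sparse` appears as a single call_module node
```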
|
||||
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineContext:
|
||||
"""
|
||||
Dataclass to store information related to the training pipeline context.
|
||||
|
||||
Attributes:
|
||||
input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
|
||||
module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
|
||||
feature_processor_forwards (List[Any]): A list of feature processor forwards.
|
||||
"""
|
||||
|
||||
# pyre-ignore [4]
|
||||
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
|
||||
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
|
||||
@ -166,6 +279,14 @@ class TrainPipelineContext:
|
||||
|
||||
@dataclass
|
||||
class ArgInfo:
|
||||
"""
|
||||
Dataclass to store information about arguments in the training pipeline.
|
||||
|
||||
Attributes:
|
||||
input_attrs (List[str]): List of attribute names of the input batch.
|
||||
is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
|
||||
name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
|
||||
"""
|
||||
# attributes of input batch, e.g. batch.attr1.attr2 call
|
||||
# will produce ["attr1", "attr2"]
|
||||
input_attrs: List[str]
|
||||
@ -177,6 +298,16 @@ class ArgInfo:
|
||||
|
||||
|
||||
class PipelinedForward:
|
||||
"""
|
||||
Represents a pipelined forward pass operation.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the forward pass.
|
||||
args (List[ArgInfo]): List of argument information for the forward pass.
|
||||
module (ShardedModule): The sharded module associated with the forward pass.
|
||||
context (TrainPipelineContext): The training pipeline context.
|
||||
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
@ -185,6 +316,16 @@ class PipelinedForward:
|
||||
context: TrainPipelineContext,
|
||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a PipelinedForward instance.
|
||||
|
||||
Args:
|
||||
name (str): The name of the forward pass.
|
||||
args (List[ArgInfo]): List of argument information for the forward pass.
|
||||
module (ShardedModule): The sharded module associated with the forward pass.
|
||||
context (TrainPipelineContext): The training pipeline context.
|
||||
dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
|
||||
"""
|
||||
self._name = name
|
||||
self._args = args
|
||||
self._module = module
|
||||
@ -193,6 +334,16 @@ class PipelinedForward:
|
||||
|
||||
# pyre-ignore [2, 24]
|
||||
def __call__(self, *input, **kwargs) -> Awaitable:
|
||||
"""
|
||||
Perform the pipelined forward pass operation.
|
||||
|
||||
Args:
|
||||
*input: Variable-length positional arguments.
|
||||
**kwargs: Variable-length keyword arguments.
|
||||
|
||||
Returns:
|
||||
Awaitable: An awaitable object representing the forward pass result.
|
||||
"""
|
||||
assert self._name in self._context.input_dist_requests
|
||||
request = self._context.input_dist_requests[self._name]
|
||||
assert isinstance(request, Awaitable)
|
||||
@ -230,10 +381,22 @@ class PipelinedForward:
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Get the name of the forward pass.
|
||||
|
||||
Returns:
|
||||
str: The name of the forward pass.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def args(self) -> List[ArgInfo]:
|
||||
"""
|
||||
Get the list of argument information for the forward pass.
|
||||
|
||||
Returns:
|
||||
List[ArgInfo]: List of argument information.
|
||||
"""
|
||||
return self._args
|
||||
|
||||
|
||||
@ -242,6 +405,17 @@ def _start_data_dist(
|
||||
batch: In,
|
||||
context: TrainPipelineContext,
|
||||
) -> None:
|
||||
"""
|
||||
Start data distribution for a list of pipelined modules.
|
||||
|
||||
Args:
|
||||
pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
|
||||
batch (In): The input batch.
|
||||
context (TrainPipelineContext): The training pipeline context.
|
||||
|
||||
Returns:
|
||||
None: This function doesn't return a value.
|
||||
"""
|
||||
context.input_dist_requests.clear()
|
||||
context.module_contexts.clear()
|
||||
for module in pipelined_modules:
|
||||
@ -286,9 +460,17 @@ def _get_node_args_helper(
|
||||
feature_processor_arguments: Optional[List[Node]] = None,
|
||||
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.

Args:
arguments: The arguments to process.
num_found: The current count of arguments found.
feature_processor_arguments: Optional list of feature processor arguments.

Returns:
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
"""
|
||||
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
|
||||
for arg, arg_info in zip(arguments, arg_info_list):
|
||||
@ -332,6 +514,16 @@ def _get_node_args_helper(
|
||||
def _get_node_args(
|
||||
node: Node, feature_processor_nodes: Optional[List[Node]] = None
|
||||
) -> Tuple[List[ArgInfo], int]:
|
||||
"""
|
||||
Get argument information for a given node.
|
||||
|
||||
Args:
|
||||
node (Node): The node to process.
|
||||
feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.
|
||||
|
||||
Returns:
|
||||
Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
|
||||
"""
|
||||
num_found = 0
|
||||
pos_arg_info_list, num_found = _get_node_args_helper(
|
||||
node.args, num_found, feature_processor_nodes
|
||||
@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
|
||||
path: str,
|
||||
unsharded_module_names: Set[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Get the names of unsharded modules in a model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to analyze.
|
||||
path (str): The current path in the model hierarchy.
|
||||
unsharded_module_names (Set[str]): A set to store the names of unsharded modules.
|
||||
|
||||
Returns:
|
||||
bool: True if any sharded modules were found in the hierarchy, False otherwise.
|
||||
"""
|
||||
sharded_children = set()
|
||||
for name, child in model.named_children():
|
||||
curr_path = path + name
|
||||
@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(
|
||||
|
||||
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top-level modules that do not contain any sharded sub-modules.

Args:
model (torch.nn.Module): The model to analyze.

Returns:
List[str]: A list of top-level module names without sharded sub-modules.
"""
|
||||
unsharded_module_names: Set[str] = set()
|
||||
_get_unsharded_module_names_helper(
|
||||
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
|
||||
context: TrainPipelineContext,
|
||||
dist_stream: Optional[torch.cuda.streams.Stream],
|
||||
) -> List[ShardedModule]:
|
||||
"""
|
||||
Rewrites the model to enable pipelined execution for selected sharded modules.
|
||||
|
||||
This function traces the input model using a custom tracer and identifies sharded modules
|
||||
that can be pipelined. It then creates PipelinedForward objects for these modules,
|
||||
which enable pipelining during training.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The input model to be rewritten.
|
||||
context (TrainPipelineContext): The context containing information needed for pipelining.
|
||||
dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.
|
||||
|
||||
Returns:
|
||||
List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
|
||||
"""
|
||||
|
||||
# Get underlying nn.Module
|
||||
if isinstance(model, DistributedModelParallel):
|
||||
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
|
||||
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.

stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream

`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.

Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.

Args:
model (torch.nn.Module): The input model to be used for training.
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
device (torch.device): The device where training will be performed.
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.

Attributes:
synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.

"""
|
||||
|
||||
synced_pipeline_id: Dict[int, int] = {}
|
||||
|
||||
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
enable_grad_scaling: bool = True,
|
||||
grad_accum: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the training pipeline.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The input model to be used for training.
|
||||
optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
|
||||
device (torch.device): The device where training will be performed.
|
||||
enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
|
||||
enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
|
||||
grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
|
||||
"""
|
||||
self._model = model
|
||||
self._optimizer = optimizer
|
||||
self._device = device
|
||||
@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
self._grad_accum = grad_accum
|
||||
|
||||
def _connect(self, dataloader_iter: Iterator[In]) -> None:
|
||||
"""
|
||||
Connects the training pipeline to data and prepares for forward and backward passes.
|
||||
|
||||
Args:
|
||||
dataloader_iter (Iterator[In]): An iterator providing input data batches.
|
||||
"""
|
||||
|
||||
# batch 1
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
batch_i = next(dataloader_iter)
|
||||
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
|
||||
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
Progresses through the training pipeline, performing forward and backward passes.

NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.

Args:
dataloader_iter (Iterator[In]): An iterator providing input data batches.

Returns:
Out: The output of the forward pass.
"""
should_step_optimizer = (
|
||||
self._grad_accum is not None
|
||||
and self._progress_calls > 0
|
||||
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
|
||||
|
||||
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
|
||||
module.forward._context = self._context
|
||||
|
@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
|
||||
|
||||
|
||||
def on_kf():
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.

Returns:
bool: True if running on KF, False otherwise.
"""
return "SPEC_TYPE" in os.environ


def has_readers():
"""Check if the current task has dataset workers.

Returns:
bool: True if the task has dataset workers, False otherwise.
"""
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env.get("dataset_worker") is not None
return os.environ.get("HAS_READERS", "False") == "True"


def get_task_type():
"""Get the type of the current task.

Returns:
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
"""
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]


def is_chief() -> bool:
"""Check if the current task is the 'chief'.

Returns:
bool: True if the current task is the 'chief', False otherwise.
"""
return get_task_type() == "chief"


def is_reader() -> bool:
"""Check if the current task is a 'datasetworker'.

Returns:
bool: True if the current task is a 'datasetworker', False otherwise.
"""
return get_task_type() == "datasetworker"


def is_dispatcher() -> bool:
"""Check if the current task is a 'datasetdispatcher'.

Returns:
bool: True if the current task is a 'datasetdispatcher', False otherwise.
"""
return get_task_type() == "datasetdispatcher"


def get_task_index():
"""Get the index of the current task.

Returns:
int: Task index.

Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError


def get_reader_port():
"""Get the port used by readers.

Returns:
int: Reader port.
"""
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT


def get_dds():
"""Get the Distributed Data Service (DDS) address.

Returns:
str: DDS address in the format 'grpc://host:port'.

Raises:
ValueError: If the job does not have DDS.
"""
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")


def get_dds_dispatcher_address():
"""Get the DDS dispatcher address.

Returns:
str: DDS dispatcher address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"


def get_dds_worker_address():
"""Get the DDS worker address.

Returns:
str: DDS worker address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"


def get_num_readers():
"""Get the number of dataset workers.

Returns:
int: Number of dataset workers.
"""
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env.get("num_dataset_workers") or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))


def get_flight_server_addresses():
"""Get Flight server addresses for dataset workers.

Returns:
List[str]: List of Flight server addresses in the format 'grpc://host:port'.

Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError


def get_dds_journaling_dir():
"""Get the DDS journaling directory.

Returns:
str: DDS journaling directory.
"""
return os.environ.get("DATASET_JOURNALING_DIR", None)
|
@ -14,18 +14,21 @@ def update_mean(
|
||||
weight: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the mean according to the Welford formula.

This function updates the mean and the weighted sum of values using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

Args:
current_mean (torch.Tensor): The value of the current accumulated mean.
current_weight_sum (torch.Tensor): The current weighted sum.
value (torch.Tensor): The new value that needs to be added to get a new mean.
weight (torch.Tensor): The weights for the new value.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
"""
weight = torch.broadcast_to(weight, value.shape)
|
||||
|
||||
# Avoiding (on purpose) in-place operation when using += in case
|
||||
@ -38,12 +41,15 @@ def update_mean(
|
||||
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
"""
Merge the state from multiple workers.

This function merges the state from multiple workers to compute the accumulated mean.

Args:
state (torch.Tensor): A tensor with the first dimension indicating workers.

Returns:
torch.Tensor: The accumulated mean from all workers.
"""
mean, weight_sum = update_mean(
|
||||
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
||||
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
||||
@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
class StableMean(torchmetrics.Metric):
"""
A numerical stable mean metric using the Welford algorithm.

This class implements a numerical stable mean metrics computation using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
than the maximum float32 as far as the mean is within the limit of float32.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

Args:
**kwargs: Additional parameters supported by all torchmetrics.Metric.

Attributes:
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
|
||||
)
|
||||
|
||||
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
"""Update the current mean.

Args:
value (torch.Tensor): Value to update the mean with.
weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
"""
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
|
||||
|
||||
if not isinstance(weight, torch.Tensor):
|
||||
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
|
||||
)
|
||||
|
||||
def compute(self) -> torch.Tensor:
"""Compute and return the accumulated mean.

Returns:
torch.Tensor: The accumulated mean.
"""
return self.mean_and_weight_sum[0]
|
||||
|
@ -29,6 +29,9 @@ def _compute_helper(
|
||||
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
|
||||
we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise,
|
||||
we assume that they are correct prediction (i.e weight = 0).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The computed AUROC
|
||||
"""
|
||||
dim = 0
|
||||
|
||||
@ -52,24 +55,34 @@ def _compute_helper(
|
||||
|
||||
class AUROCWithMWU(torchmetrics.Metric):
"""
AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.

This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.

See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.

Note: Currently, this implementation only supports binary classification.

Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
|
||||
|
||||
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
"""
Initializes the AUROCWithMWU metric.

Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
|
||||
self.add_state("predictions", default=[], dist_reduce_fx="cat")
|
||||
self.add_state("target", default=[], dist_reduce_fx="cat")
|
||||
|
@ -14,7 +14,24 @@ from absl import logging as logging
|
||||
|
||||
|
||||
def setup_absl_logging():
|
||||
"""
|
||||
Configure absl-py logging to direct log messages to stdout and apply a custom log message format.
|
||||
|
||||
This function ensures that log messages generated by the absl-py library are written to stdout
|
||||
rather than stderr. It also applies a custom log message format that includes module, function,
|
||||
line number, log level, and the log message content.
|
||||
|
||||
Note:
|
||||
This function should be called once at the beginning of your script or application to
|
||||
configure absl-py logging.
|
||||
|
||||
Example:
|
||||
To use this function, simply call it at the start of your script:
|
||||
```
|
||||
setup_absl_logging()
|
||||
```
|
||||
|
||||
"""
|
||||
logging.get_absl_handler().python_handler.stream = sys.stdout
|
||||
formatter = py_logging.Formatter(
|
||||
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
|
||||
|
@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging
|
||||
|
||||
class Testtlogging(unittest.TestCase):
|
||||
def test_warn_once(self):
|
||||
"""
|
||||
Test that warning messages are logged only once when using the assertLogs context manager.
|
||||
|
||||
This unit test checks the behavior of the logging system when warning messages are issued
|
||||
multiple times within the same context. It uses the assertLogs context manager to capture
|
||||
log messages at the INFO level and verifies that warning messages are logged only once.
|
||||
|
||||
Example:
|
||||
To use this test case, call it using a test runner like unittest:
|
||||
```
|
||||
python -m unittest your_test_module.TestLogging.test_warn_once
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
with self.assertLogs(level="INFO") as captured_logs:
|
||||
logging.info("first info")
|
||||
logging.warning("first warning")
|
||||
|
@ -18,7 +18,35 @@ import torch.distributed as dist
|
||||
|
||||
|
||||
def rank_specific(logger):
|
||||
"""
|
||||
Customize logger behavior based on the distributed environment and rank.
|
||||
|
||||
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
|
||||
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
|
||||
depending on the rank or limit the number of redundant logs.
|
||||
|
||||
Args:
|
||||
logger: The logger object to customize.
|
||||
|
||||
Returns:
|
||||
The customized logger.
|
||||
|
||||
Example:
|
||||
To use this function with the `logging` module:
|
||||
```python
|
||||
import logging
|
||||
from rank_specific_logging import rank_specific
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
rank_specific(logger)
|
||||
```
|
||||
|
||||
Customization:
|
||||
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
|
||||
- The 'warning' method is limited to logging a single redundant warning.
|
||||
- Logging from rank -1 is redirected to include the rank information.
|
||||
|
||||
"""
|
||||
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
|
||||
return logger
|
||||
|
||||
|
@ -8,11 +8,60 @@ import pydantic
|
||||
|
||||
|
||||
class PiecewiseConstant(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for a piecewise constant learning rate schedule.
|
||||
|
||||
This configuration class allows you to specify a piecewise constant learning rate schedule
|
||||
by defining boundaries and corresponding learning rate values.
|
||||
|
||||
Attributes:
|
||||
learning_rate_boundaries (List[int], optional): List of step boundaries at which
|
||||
the learning rate will change. If None, no boundaries are defined.
|
||||
learning_rate_values (List[float], optional): List of learning rate values
|
||||
corresponding to the boundaries. If None, no values are defined.
|
||||
|
||||
Example:
|
||||
To configure a piecewise constant learning rate schedule, create an instance of this class
|
||||
and set the attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
piecewise_lr = PiecewiseConstant(
|
||||
learning_rate_boundaries=[1000, 2000, 3000],
|
||||
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
The number of learning rate values should be one more than the number of boundaries.
|
||||
|
||||
"""
|
||||
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
|
||||
learning_rate_values: typing.List[float] = pydantic.Field(None)
|
||||
|
||||
|
||||
class LinearRampToConstant(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for a linear ramp-up to constant learning rate schedule.
|
||||
|
||||
This configuration class allows you to specify a learning rate schedule that ramps up linearly
|
||||
from zero to a constant value over a specified number of steps.
|
||||
|
||||
Attributes:
|
||||
learning_rate (float): The final constant learning rate.
|
||||
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
|
||||
|
||||
Example:
|
||||
To configure a linear ramp-up to a constant learning rate, create an instance of this class
|
||||
and set the attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
linear_ramp_lr = LinearRampToConstant(
|
||||
learning_rate=0.1,
|
||||
num_ramp_steps=1000
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
learning_rate: float
|
||||
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
||||
description="Number of steps to ramp this up from zero."
|
||||
@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):
|
||||
|
||||
|
||||
class LinearRampToCosine(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for a linear ramp-up to cosine decay learning rate schedule.
|
||||
|
||||
This configuration class allows you to specify a learning rate schedule that ramps up linearly
|
||||
from zero, then decays following a cosine schedule to a final constant learning rate.
|
||||
|
||||
Attributes:
|
||||
learning_rate (float): The initial learning rate at the start of ramp-up.
|
||||
final_learning_rate (float): The final constant learning rate after decay.
|
||||
num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
|
||||
final_num_steps (PositiveInt): Final number of steps where decay stops.
|
||||
|
||||
Example:
|
||||
To configure a linear ramp-up to cosine decay learning rate, create an instance of this
|
||||
class and set the attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
ramp_to_cosine_lr = LinearRampToCosine(
|
||||
learning_rate=0.01,
|
||||
final_learning_rate=0.001,
|
||||
num_ramp_steps=1000,
|
||||
final_num_steps=5000
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
learning_rate: float
|
||||
final_learning_rate: float
|
||||
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
|
||||
@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):
|
||||
|
||||
|
||||
class LearningRate(base_config.BaseConfig):
|
||||
"""
|
||||
Learning rate configuration for training.
|
||||
|
||||
This configuration class allows you to specify different learning rate schedules
|
||||
for your training process.
|
||||
|
||||
Attributes:
|
||||
constant (float, optional): Constant learning rate to be used throughout training.
|
||||
linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
|
||||
and then decays following a cosine schedule.
|
||||
linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
|
||||
linearly and then remains constant.
|
||||
piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
|
||||
boundaries with corresponding values.
|
||||
|
||||
Example:
|
||||
To configure a learning rate schedule, create an instance of this class and set the
|
||||
attributes accordingly. For example:
|
||||
|
||||
```python
|
||||
learning_rate = LearningRate(
|
||||
constant=0.01,
|
||||
linear_ramp_to_cosine=LinearRampToCosine(
|
||||
learning_rate=0.1,
|
||||
final_learning_rate=0.001,
|
||||
num_ramp_steps=1000,
|
||||
final_num_steps=5000
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
Each learning rate schedule attribute can be set to `None` if not needed.
|
||||
|
||||
"""
|
||||
constant: float = pydantic.Field(None, one_of="lr")
|
||||
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
|
||||
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
|
||||
@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):
|
||||
|
||||
|
||||
class OptimizerAlgorithmConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Base class for optimizer configurations.
|
||||
|
||||
This base configuration class provides a structure for specifying various optimizer-related
|
||||
settings, including the learning rate and different learning rate schedules.
|
||||
|
||||
Attributes:
|
||||
lr (float): The base learning rate used by the optimizer.
|
||||
|
||||
Subclasses should inherit from this base class and define additional attributes specific to
|
||||
the optimizer algorithm they represent.
|
||||
|
||||
Example:
|
||||
To create a custom optimizer configuration, create a subclass of this base class and
|
||||
define the necessary attributes. For example:
|
||||
|
||||
```python
|
||||
class MyOptimizerConfig(OptimizerAlgorithmConfig):
|
||||
momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
|
||||
```
|
||||
|
||||
Note:
|
||||
This base class does not include specific optimizer settings. Subclasses should define
|
||||
the optimizer-specific attributes as needed.
|
||||
|
||||
"""
|
||||
|
||||
lr: float
|
||||
...
|
||||
|
||||
|
||||
class AdamConfig(OptimizerAlgorithmConfig):
|
||||
"""
|
||||
Configuration for the Adam optimizer.
|
||||
|
||||
This configuration class allows you to specify the hyperparameters for the Adam optimizer.
|
||||
|
||||
Attributes:
|
||||
lr (float): The learning rate for optimization.
|
||||
betas (Tuple[float, float], optional): Coefficients used for computing running averages
|
||||
of gradient and squared gradient. Defaults to (0.9, 0.999).
|
||||
eps (float, optional): A small constant added to the denominator for numerical stability.
|
||||
Defaults to 1e-7.
|
||||
|
||||
Example:
|
||||
To configure the Adam optimizer, create an instance of this class and set the attributes
|
||||
accordingly. For example:
|
||||
|
||||
```python
|
||||
adam_optimizer = AdamConfig(
|
||||
lr=0.001,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8
|
||||
)
|
||||
```
|
||||
|
||||
See Also:
|
||||
[PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
|
||||
|
||||
"""
|
||||
lr: float
|
||||
betas: typing.Tuple[float, float] = [0.9, 0.999]
|
||||
eps: float = 1e-7 # Numerical stability in denominator.
|
||||
|
||||
|
||||
class SgdConfig(OptimizerAlgorithmConfig):
|
||||
"""
|
||||
Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
|
||||
This configuration class allows you to specify the hyperparameters for the SGD optimizer.
|
||||
|
||||
Attributes:
|
||||
lr (float): The learning rate for optimization.
|
||||
momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.
|
||||
|
||||
Example:
|
||||
To configure the SGD optimizer, create an instance of this class and set the attributes
|
||||
accordingly. For example:
|
||||
|
||||
```python
|
||||
sgd_optimizer = SgdConfig(
|
||||
lr=0.01,
|
||||
momentum=0.9
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
lr: float
|
||||
momentum: float = 0.0
|
||||
|
||||
|
||||
class AdagradConfig(OptimizerAlgorithmConfig):
"""
Configuration for the Adagrad optimizer.

This configuration class allows you to specify the hyperparameters for the Adagrad optimizer.

Attributes:
lr (float): The learning rate for optimization.
eps (float, optional): A small constant added to the denominator for numerical stability.
Defaults to 0.

Example:
To configure the Adagrad optimizer, create an instance of this class and set the attributes
accordingly. For example:

```python
adagrad_optimizer = AdagradConfig(
lr=0.01,
eps=1e-10
)
```

"""
||||
lr: float
|
||||
eps: float = 0
|
||||
|
||||
|
||||
class OptimizerConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for defining different optimizer algorithms and their parameters.
|
||||
|
||||
This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
|
||||
along with their respective hyperparameters.
|
||||
|
||||
Args:
|
||||
learning_rate (LearningRate): The learning rate configuration, which can include
|
||||
constant learning rates or other learning rate schedules.
|
||||
adam (AdamConfig): Configuration for the Adam optimizer.
|
||||
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
|
||||
|
||||
Example:
|
||||
```python
|
||||
optimizer_config = OptimizerConfig(
|
||||
learning_rate=LearningRate(constant=0.001),
|
||||
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
learning_rate (LearningRate): The learning rate configuration.
|
||||
adam (AdamConfig): Configuration for the Adam optimizer.
|
||||
sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
adagrad (AdagradConfig): Configuration for the Adagrad optimizer.
|
||||
|
||||
Note:
|
||||
You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
|
||||
`OptimizerConfig` instance.
|
||||
|
||||
See Also:
|
||||
- `LearningRate`: Configuration for specifying learning rates.
|
||||
- `AdamConfig`: Configuration for the Adam optimizer.
|
||||
- `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
|
||||
- `AdagradConfig`: Configuration for the Adagrad optimizer.
|
||||
|
||||
"""
|
||||
learning_rate: LearningRate = pydantic.Field(
|
||||
None,
|
||||
description="Constant learning rates",
|
||||
@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
|
||||
"""
|
||||
Get the optimizer algorithm configuration from the given `OptimizerConfig`.
|
||||
|
||||
This function extracts and returns the specific optimizer algorithm configuration
|
||||
(e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.
|
||||
|
||||
Args:
|
||||
optimizer_config (OptimizerConfig): The optimizer configuration object containing
|
||||
one of the optimizer algorithm configurations.
|
||||
|
||||
Returns:
|
||||
Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
|
||||
configuration extracted from `optimizer_config`.
|
||||
|
||||
Raises:
|
||||
ValueError: If no optimizer algorithm is selected in `optimizer_config`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
optimizer_config = OptimizerConfig(
|
||||
adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
|
||||
)
|
||||
algorithm_config = get_optimizer_algorithm_config(optimizer_config)
|
||||
# `algorithm_config` will be an instance of `AdamConfig`.
|
||||
```
|
||||
|
||||
"""
|
||||
if optimizer_config.adam is not None:
|
||||
return optimizer_config.adam
|
||||
elif optimizer_config.sgd is not None:
|
||||
|
@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging
|
||||
|
||||
|
||||
def compute_lr(lr_config, step):
|
||||
"""Compute a learning rate."""
|
||||
"""
|
||||
Compute the learning rate based on the specified learning rate configuration.
|
||||
|
||||
This function calculates the learning rate according to the given configuration, which can include
|
||||
constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.
|
||||
|
||||
Args:
|
||||
lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
|
||||
step (int): The current training step or iteration.
|
||||
|
||||
Returns:
|
||||
float: The computed learning rate for the current step.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `lr_config` is invalid or contains conflicting options.
|
||||
|
||||
Example:
|
||||
```python
|
||||
lr_schedule = LearningRate(
|
||||
constant=0.001,
|
||||
piecewise_constant=PiecewiseConstant(
|
||||
learning_rate_boundaries=[1000, 2000, 3000],
|
||||
learning_rate_values=[0.1, 0.05, 0.01, 0.001]
|
||||
)
|
||||
)
|
||||
current_step = 2500
|
||||
learning_rate = compute_lr(lr_schedule, current_step)
|
||||
```
|
||||
"""
|
||||
if lr_config.constant is not None:
|
||||
return lr_config.constant
|
||||
elif lr_config.piecewise_constant is not None:
|
||||
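The piecewise-constant branch is cut off by the hunk boundary above. As an illustration only (not the repository's implementation), a piecewise-constant schedule like the one in the docstring example can be evaluated with a simple boundary lookup:

```python
import bisect

def piecewise_constant_lr(boundaries, values, step):
    # With boundaries [1000, 2000, 3000] and values [0.1, 0.05, 0.01, 0.001],
    # step 2500 falls between 2000 and 3000, so the value 0.01 is selected.
    return values[bisect.bisect_right(boundaries, step)]

assert piecewise_constant_lr([1000, 2000, 3000], [0.1, 0.05, 0.01, 0.001], 2500) == 0.01
```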
@ -46,11 +74,54 @@ def compute_lr(lr_config, step):
|
||||
|
||||
|
||||
class LRShim(_LRScheduler):
|
||||
"""Shim to get learning rates into a LRScheduler.
|
||||
|
||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
||||
e.g. exponential decay can be used.
|
||||
"""
|
||||
Learning Rate Scheduler Shim to adjust learning rates during training.
|
||||
|
||||
This class acts as a shim to apply different learning rates to individual parameter groups
|
||||
within an optimizer. It adheres to the torch.optim scheduler API and can be used with various
|
||||
optimizers, allowing fine-grained control over learning rates based on configuration.
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
|
||||
lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
|
||||
corresponding learning rate configurations.
|
||||
last_epoch (int, optional): The index of the last epoch. Default is -1.
|
||||
verbose (bool, optional): If True, prints a message to stdout for each learning rate update. Default is False.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of parameter groups in the optimizer does not match the number
|
||||
of learning rate configurations provided.
|
||||
|
||||
Note:
|
||||
To obtain the last computed learning rates, please use `get_last_lr()`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
|
||||
lr_schedule = {
|
||||
'main': LearningRate(constant=0.01),
|
||||
'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
|
||||
learning_rate_boundaries=[1000, 2000],
|
||||
learning_rate_values=[0.01, 0.001]
|
||||
))
|
||||
}
|
||||
lr_shim = LRShim(optimizer, lr_schedule)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# Train the model
|
||||
train(...)
|
||||
# Update learning rates at the end of each epoch
lr_shim.step()

final_lrs = lr_shim.get_last_lr()  # one value per optimizer param group
|
||||
```
|
||||
|
||||
See Also:
|
||||
- `LearningRate`: Configuration for specifying learning rates.
|
||||
- `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
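The constructor body is truncated at this hunk boundary. Below is a minimal sketch of the validation that the Raises section above describes, written as an assumption about what such a check would look like rather than as this file's actual code:

```python
# Hypothetical sketch: compare the optimizer's param groups against the configured learning rates.
num_param_groups = sum(1 for _ in optimizer.param_groups)
if num_param_groups != len(lr_dict):
    raise ValueError(
        f"Optimizer has {num_param_groups} param groups, but lr_dict has {len(lr_dict)} entries."
    )
```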
@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
|
||||
def build_optimizer(
|
||||
model: torch.nn.Module, optimizer_config: OptimizerConfig
|
||||
) -> Tuple[Optimizer, _LRScheduler]:
|
||||
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
|
||||
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
|
||||
"""
|
||||
Build an optimizer and learning rate scheduler based on the provided optimizer configuration.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
|
||||
optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
|
||||
algorithm and learning rate settings.
|
||||
|
||||
Returns:
|
||||
Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
|
||||
objects.
|
||||
|
||||
Note:
|
||||
This function is intended for cases where you want the same optimizer and learning rate
|
||||
schedule for all model parameters.
|
||||
|
||||
Example:
|
||||
```python
|
||||
model = MyModel()
|
||||
optimizer_config = OptimizerConfig(
|
||||
learning_rate=LearningRate(constant=0.01),
|
||||
sgd=SgdConfig(lr=0.01, momentum=0.9)
|
||||
)
|
||||
optimizer, scheduler = build_optimizer(model, optimizer_config)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# Train the model with the optimizer
|
||||
train(model, optimizer, ...)
|
||||
# Update learning rates at the end of each epoch
|
||||
scheduler.step()
|
||||
```
|
||||
|
||||
See Also:
|
||||
- `OptimizerConfig`: Configuration for specifying optimizer settings.
|
||||
- `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
|
||||
"""
|
||||
optimizer_class = get_optimizer_class(optimizer_config)
|
||||
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
|
||||
# We're passing everything in as one group here
|
||||
|
@ -9,6 +9,36 @@ import pydantic
|
||||
|
||||
|
||||
class TrainingConfig(config_mod.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the training process.
|
||||
|
||||
This class defines various training-related settings, including the directory to save checkpoints, the number
|
||||
of training steps, logging intervals, and other training parameters.
|
||||
|
||||
Attributes:
|
||||
save_dir (str): The directory where checkpoints and training artifacts will be saved.
|
||||
num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
|
||||
initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
|
||||
checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
|
||||
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
|
||||
train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
|
||||
num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
|
||||
eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
|
||||
eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
|
||||
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
|
||||
|
||||
Example:
|
||||
To configure training with checkpoints saved every 1000 steps, use the following settings:
|
||||
|
||||
```python
|
||||
TrainingConfig(
|
||||
save_dir="/tmp/model",
|
||||
num_train_steps=1000000,
|
||||
checkpoint_every_n=1000,
|
||||
train_log_every_n=1000,
|
||||
)
|
||||
```
|
||||
"""
|
||||
save_dir: str = "/tmp/model"
|
||||
num_train_steps: pydantic.PositiveInt = 1000000
|
||||
initial_checkpoint_dir: str = pydantic.Field(
|
||||
@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):
|
||||
|
||||
|
||||
class RecapConfig(config_mod.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the Recap model training process.
|
||||
|
||||
This class defines the overall configuration for the training process of a Recap model. It includes settings for
|
||||
training, model architecture, data, optimization, and evaluation.
|
||||
|
||||
Attributes:
|
||||
training (TrainingConfig): Configuration settings for the training process.
|
||||
model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
|
||||
train_data (data_config.RecapDataConfig): Configuration settings for training data.
|
||||
validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
|
||||
optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
|
||||
which_metrics (Optional[str]): Optional specification of which metrics to pick.
|
||||
|
||||
Note:
|
||||
This class encapsulates all the necessary configurations to train a Recap model. It defines settings for
|
||||
training, the model architecture, data loading, optimization, and evaluation.
|
||||
|
||||
Example:
|
||||
To configure a Recap model training process, use the following settings:
|
||||
|
||||
```python
|
||||
RecapConfig(
|
||||
training=TrainingConfig(
|
||||
save_dir="/tmp/model",
|
||||
num_train_steps=1000000,
|
||||
checkpoint_every_n=1000,
|
||||
train_log_every_n=1000,
|
||||
),
|
||||
model=model_config.ModelConfig(...),
|
||||
train_data=data_config.RecapDataConfig(...),
|
||||
validation_data={"dev": data_config.RecapDataConfig(...)},
|
||||
optimizer=optimizer_config.RecapOptimizerConfig(...),
|
||||
)
|
||||
```
|
||||
"""
|
||||
training: TrainingConfig = pydantic.Field(TrainingConfig())
|
||||
model: model_config.ModelConfig
|
||||
train_data: data_config.RecapDataConfig
|
||||
|
@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class EmbeddingOptimizerConfig(base_config.BaseConfig):
|
||||
"""Configuration for the optimizer used for embedding tables."""
|
||||
learning_rate: optimizer_config.LearningRate = pydantic.Field(
|
||||
None, description="learning rate scheduler for the EBC"
|
||||
)
|
||||
@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class StratifierConfig(base_config.BaseConfig):
|
||||
"""Configuration for Stratifier."""
|
||||
name: str
|
||||
index: int
|
||||
value: int
|
||||
@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
|
||||
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
|
||||
serving time due to size (>>1 GB).
|
||||
|
||||
This small embeddings table uses the same optimizer as the rest of the model."""
|
||||
This small embeddings table uses the same optimizer as the rest of the model.
|
||||
"""
|
||||
|
||||
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
|
||||
..., description="list of embedding tables"
|
||||
|
@ -34,6 +34,33 @@ FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
|
||||
"""
|
||||
Main function to run the training of a ranking model.
|
||||
|
||||
This function initializes and runs the training process for a ranking model based on the provided configuration.
|
||||
|
||||
Args:
|
||||
unused_argv (str): Unused argument.
|
||||
data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
AssertionError: If the configuration or input data is not valid.
|
||||
|
||||
Note:
|
||||
This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
|
||||
the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
|
||||
training loop.
|
||||
|
||||
Example:
|
||||
To run the training process, use the following command:
|
||||
|
||||
```
|
||||
python run_training.py --config_path=config.yaml
|
||||
```
|
||||
"""
|
||||
print("#" * 100)
|
||||
|
||||
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
|
||||
|
@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class DenseLayerConfig(base_config.BaseConfig):
|
||||
"""Configuration for the dense layer."""
|
||||
layer_size: pydantic.PositiveInt
|
||||
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
|
||||
|
||||
@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class DoubleNormLogConfig(base_config.BaseConfig):
|
||||
"""Configuration for the double norm log transform."""
|
||||
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
|
||||
clip_magnitude: float = pydantic.Field(
|
||||
5.0, description="Threshold to clip the normalized input values."
|
||||
@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class ClipLog1pAbsConfig(base_config.BaseConfig):
|
||||
"""Configuration for the clip log transform."""
|
||||
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
|
||||
3e38, description="Threshold to clip the input values."
|
||||
)
|
||||
|
||||
|
||||
class ZScoreLogConfig(base_config.BaseConfig):
|
||||
"""Configuration for the z-score log transform."""
|
||||
analysis_path: str
|
||||
schema_path: str = pydantic.Field(
|
||||
None,
|
||||
@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class MaskBlockConfig(base_config.BaseConfig):
|
||||
"""Config for MaskNet block."""
|
||||
output_size: int
|
||||
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
|
||||
None, one_of="aggregation_size"
|
||||
@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class MaskNetConfig(base_config.BaseConfig):
|
||||
"""Config for MaskNet model."""
|
||||
mask_blocks: List[MaskBlockConfig]
|
||||
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
|
||||
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
|
||||
@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):
|
||||
|
||||
|
||||
class DLRMConfig(base_config.BaseConfig):
|
||||
"""Config for DLRM model."""
|
||||
bottom_mlp: MlpConfig = pydantic.Field(
|
||||
...,
|
||||
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
|
||||
@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):
|
||||
|
||||
|
||||
class TaskModel(base_config.BaseConfig):
|
||||
"""Configuration for a single task."""
|
||||
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
|
||||
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
|
||||
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
|
||||
@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):
|
||||
|
||||
|
||||
class MultiTaskType(str, enum.Enum):
|
||||
"""Type of multi task architecture."""
|
||||
SHARE_NONE = "share_none" # Tasks are separate.
|
||||
SHARE_ALL = "share_all" # Tasks share same backbone.
|
||||
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
|
||||
@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):
|
||||
|
||||
@pydantic.root_validator()
|
||||
def _validate_mtl(cls, values):
|
||||
"""Validate the multi task architecture."""
|
||||
if values.get("multi_task_type", None) is None:
|
||||
return values
|
||||
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
|
||||
|
@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):
|
||||
|
||||
|
||||
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
|
||||
""" "Builds a model for a single task"""
|
||||
"""
|
||||
Build a model for a single task based on the provided configuration.
|
||||
|
||||
Args:
|
||||
task (model_config_mod.TaskModel): The task model configuration.
|
||||
input_shape (int): The input shape for the model.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: The constructed model for the single task.
|
||||
|
||||
Raises:
|
||||
ValueError: If the task configuration is not recognized.
|
||||
"""
|
||||
if task.mlp_config:
|
||||
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
|
||||
elif task.dcn_config:
|
||||
@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)
|
||||
|
||||
|
||||
class MultiTaskRankingModel(torch.nn.Module):
|
||||
"""Multi-task ranking model."""
|
||||
"""
|
||||
Multi-task ranking model that handles multiple ranking tasks simultaneously.
|
||||
|
||||
This model takes various input features and predicts rankings for multiple
|
||||
tasks using shared or separate towers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module):
|
||||
data_config: RecapDataConfig,
|
||||
return_backbone: bool = False,
|
||||
):
|
||||
"""Constructor for Multi task learning.
|
||||
"""
|
||||
Constructor for Multi-task ranking model.
|
||||
|
||||
Args:
|
||||
input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
|
||||
config (ModelConfig): The model configuration.
|
||||
data_config (RecapDataConfig): The data configuration.
|
||||
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
|
||||
|
||||
Assumptions made:
|
||||
1. Tasks specified in data config match model architecture.
|
||||
|
||||
These are all validated in config.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
weights: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Forward pass of the Multi-task ranking model.
|
||||
|
||||
Args:
|
||||
continuous_features (torch.Tensor): Continuous input features.
|
||||
binary_features (torch.Tensor): Binary input features.
|
||||
discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
|
||||
sparse_features (optional): Sparse input features. Defaults to None.
|
||||
user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
|
||||
user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
|
||||
author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
|
||||
labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
|
||||
weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
|
||||
"""
|
||||
concat_dense_features = [
|
||||
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
|
||||
]
|
||||
@ -270,6 +310,20 @@ def create_ranking_model(
|
||||
data_config=None,
|
||||
return_backbone=False,
|
||||
):
|
||||
"""
|
||||
Creates a ranking model based on the provided specifications and configuration.
|
||||
|
||||
Args:
|
||||
data_spec: The input data specifications.
|
||||
config (config_mod.RecapConfig): The model configuration.
|
||||
device (torch.device): The device where the model should be placed.
|
||||
loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
|
||||
data_config: The data configuration. Defaults to None.
|
||||
return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module: The created ranking model.
|
||||
"""
|
||||
|
||||
if list(config.model.tasks.values())[0].dlrm_config:
|
||||
raise NotImplementedError()
|
||||
|
@ -11,21 +11,52 @@ import torch
|
||||
|
||||
|
||||
def log_transform(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Safe log transform that works across both negative, zero, and positive floats."""
|
||||
"""
|
||||
Safe log transform that works across negative, zero, and positive floats.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with log1p applied to absolute values.
|
||||
"""
|
||||
return torch.sign(x) * torch.log1p(torch.abs(x))
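A quick numerical check of the transform's behavior, included for illustration:

```python
import torch

x = torch.tensor([-10.0, 0.0, 10.0])
y = torch.sign(x) * torch.log1p(torch.abs(x))  # same expression as log_transform
print(y)  # tensor([-2.3979,  0.0000,  2.3979]) -- odd symmetry, finite for all inputs
```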
|
||||
|
||||
|
||||
class BatchNorm(torch.nn.Module):
|
||||
def __init__(self, num_features: int, config: BatchNormConfig):
|
||||
"""
|
||||
Batch normalization layer.
|
||||
|
||||
Args:
|
||||
num_features (int): Number of input features.
|
||||
config (BatchNormConfig): Configuration for batch normalization.
|
||||
"""
|
||||
super().__init__()
|
||||
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the batch normalization layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after batch normalization.
|
||||
"""
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
|
||||
"""
|
||||
Layer normalization layer.
|
||||
|
||||
Args:
|
||||
normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
|
||||
config (LayerNormConfig): Configuration for layer normalization.
|
||||
"""
|
||||
super().__init__()
|
||||
if config.axis != -1:
|
||||
raise NotImplementedError
|
||||
@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the layer normalization layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after layer normalization.
|
||||
"""
|
||||
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that applies a log transformation to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with log applied to absolute values.
|
||||
"""
|
||||
|
||||
return log_transform(x)
|
||||
|
||||
|
||||
class InputNonFinite(torch.nn.Module):
|
||||
def __init__(self, fill_value: float = 0):
|
||||
"""
|
||||
Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.
|
||||
|
||||
Args:
|
||||
fill_value (float): The value to fill non-finite elements with. Default is 0.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer(
|
||||
@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that replaces non-finite values in the input tensor with the specified fill value.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with non-finite values replaced.
|
||||
"""
|
||||
return torch.where(torch.isfinite(x), x, self.fill_value)
|
||||
|
||||
|
||||
class Clamp(torch.nn.Module):
|
||||
def __init__(self, min_value: float, max_value: float):
|
||||
"""
|
||||
Applies element-wise clamping to a tensor, ensuring that values are within a specified range.
|
||||
|
||||
Args:
|
||||
min_value (float): The minimum value to clamp elements to.
|
||||
max_value (float): The maximum value to clamp elements to.
|
||||
"""
|
||||
super().__init__()
|
||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||
# Will also be part of state_dict.
|
||||
@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that clamps the input tensor element-wise within the specified range.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor with elements clamped within the specified range.
|
||||
"""
|
||||
return torch.clamp(x, min=self.min_value, max=self.max_value)
|
||||
|
||||
|
||||
class DoubleNormLog(torch.nn.Module):
|
||||
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
|
||||
"""
|
||||
Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.
|
||||
|
||||
Args:
|
||||
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
|
||||
config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.
|
||||
|
||||
Attributes:
|
||||
_before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
|
||||
batch normalization (optional), and clamping.
|
||||
layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_shapes: Mapping[str, Sequence[int]],
|
||||
@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
|
||||
def forward(
|
||||
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that processes continuous and binary features using batch normalization, log transformation,
|
||||
optional batch normalization (if configured), clamping, and layer normalization (if configured).
|
||||
|
||||
Args:
|
||||
continuous_features (torch.Tensor): Input tensor of continuous features.
|
||||
binary_features (torch.Tensor): Input tensor of binary features.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed tensor containing both continuous and binary features.
|
||||
"""
|
||||
x = self._before_concat_layers(continuous_features)
|
||||
x = torch.cat([x, binary_features], dim=1)
|
||||
if self.layer_norm:
|
||||
@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
|
||||
def build_features_preprocessor(
|
||||
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
|
||||
):
|
||||
"""Trivial right now, but we will change in the future."""
|
||||
"""
|
||||
Build a feature preprocessor module based on the provided configuration.
|
||||
Trivial right now, but we will change in the future.
|
||||
|
||||
Args:
|
||||
config (FeaturizationConfig): Configuration for feature preprocessing.
|
||||
input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
|
||||
|
||||
Returns:
|
||||
DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
|
||||
"""
|
||||
return DoubleNormLog(input_shapes, config.double_norm_log_config)
|
||||
|
@ -6,15 +6,84 @@ import torch
|
||||
|
||||
|
||||
def _init_weights(module):
|
||||
"""Initializes weights
|
||||
|
||||
Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Define a simple linear layer
|
||||
linear_layer = nn.Linear(64, 32)
|
||||
|
||||
# Initialize the weights and biases using _init_weights
|
||||
_init_weights(linear_layer)
|
||||
```
|
||||
|
||||
"""
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
torch.nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class MaskBlock(torch.nn.Module):
|
||||
"""
|
||||
MaskBlock module in a mask-based neural network.
|
||||
|
||||
This module represents a MaskBlock, which applies a masking operation to the input data and then
|
||||
passes it through a hidden layer. It is typically used as a building block within a MaskNet.
|
||||
|
||||
Args:
|
||||
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
||||
input_dim (int): Dimensionality of the input data.
|
||||
mask_input_dim (int): Dimensionality of the mask input.
|
||||
|
||||
Example:
|
||||
To create and use a MaskBlock within a MaskNet, follow these steps:
|
||||
|
||||
```python
|
||||
# Define the configuration for the MaskBlock
|
||||
mask_block_config = MaskBlockConfig(
|
||||
input_layer_norm=True, # Apply input layer normalization
|
||||
reduction_factor=0.5 # Reduce input dimensionality by 50%
|
||||
)
|
||||
|
||||
# Create an instance of the MaskBlock
|
||||
mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)
|
||||
|
||||
# Generate input tensors
|
||||
input_data = torch.randn(batch_size, 64)
|
||||
mask_input = torch.randn(batch_size, 32)
|
||||
|
||||
# Perform a forward pass through the MaskBlock
|
||||
output = mask_block(input_data, mask_input)
|
||||
```
|
||||
|
||||
Note:
|
||||
The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
|
||||
operation that combines the input and mask input. Then, it passes the result through a hidden layer
|
||||
with optional dimensionality reduction.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(
|
||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MaskBlock module.
|
||||
|
||||
Args:
|
||||
mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
|
||||
input_dim (int): Dimensionality of the input data.
|
||||
mask_input_dim (int): Dimensionality of the mask input.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
super(MaskBlock, self).__init__()
|
||||
self.mask_block_config = mask_block_config
|
||||
output_size = mask_block_config.output_size
|
||||
@ -42,6 +111,16 @@ class MaskBlock(torch.nn.Module):
|
||||
self._layer_norm = torch.nn.LayerNorm(output_size)
|
||||
|
||||
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
|
||||
"""
|
||||
Performs a forward pass through the MaskBlock.
|
||||
|
||||
Args:
|
||||
net (torch.Tensor): Input data tensor.
|
||||
mask_input (torch.Tensor): Mask input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of the MaskBlock.
|
||||
"""
|
||||
if self._input_layer_norm:
|
||||
net = self._input_layer_norm(net)
|
||||
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
||||
@ -49,7 +128,60 @@ class MaskBlock(torch.nn.Module):
|
||||
|
||||
|
||||
class MaskNet(torch.nn.Module):
|
||||
"""
|
||||
MaskNet module in a mask-based neural network.
|
||||
|
||||
This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
|
||||
create mask-based neural networks with parallel or stacked MaskBlocks.
|
||||
|
||||
Args:
|
||||
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
||||
in_features (int): Dimensionality of the input data.
|
||||
|
||||
Example:
|
||||
To create and use a MaskNet, you can follow these steps:
|
||||
|
||||
```python
|
||||
# Define the configuration for the MaskNet
|
||||
mask_net_config = MaskNetConfig(
|
||||
use_parallel=True, # Use parallel MaskBlocks
|
||||
mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs
|
||||
)
|
||||
|
||||
# Create an instance of the MaskNet
|
||||
mask_net = MaskNet(mask_net_config, in_features=64)
|
||||
|
||||
# Generate input tensors
|
||||
input_data = torch.randn(batch_size, 64)
|
||||
|
||||
# Perform a forward pass through the MaskNet
|
||||
outputs = mask_net(input_data)
|
||||
|
||||
# Access the output and shared layer
|
||||
output = outputs["output"]
|
||||
shared_layer = outputs["shared_layer"]
|
||||
```
|
||||
|
||||
Note:
|
||||
The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
|
||||
MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||
"""
|
||||
Initializes the MaskNet module.
|
||||
|
||||
Args:
|
||||
mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
|
||||
in_features (int): Dimensionality of the input data.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.mask_net_config = mask_net_config
|
||||
mask_blocks = []
|
||||
@ -77,6 +209,15 @@ class MaskNet(torch.nn.Module):
|
||||
self.shared_size = total_output_mask_blocks
|
||||
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
"""
|
||||
Performs a forward pass through the MaskNet.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): Input data tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of the MaskNet.
|
||||
"""
|
||||
if self.mask_net_config.use_parallel:
|
||||
mask_outputs = []
|
||||
for mask_layer in self._mask_blocks:
|
||||
|
@ -7,13 +7,83 @@ from absl import logging
|
||||
|
||||
|
||||
def _init_weights(module):
|
||||
"""Initializes weights
|
||||
|
||||
Example
|
||||
-------
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Define a simple linear layer
|
||||
linear_layer = nn.Linear(64, 32)
|
||||
|
||||
# Initialize the weights and biases using _init_weights
|
||||
_init_weights(linear_layer)
|
||||
```
|
||||
|
||||
"""
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
torch.nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class Mlp(torch.nn.Module):
|
||||
"""
|
||||
Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.
|
||||
|
||||
This module defines an MLP with customizable layers and activation functions. It is suitable for various
|
||||
applications such as deep learning for tabular data, feature extraction, and more.
|
||||
|
||||
Args:
|
||||
in_features (int): The number of input features or input dimensions.
|
||||
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
|
||||
|
||||
Example:
|
||||
To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:
|
||||
|
||||
```python
|
||||
# Define the configuration for the MLP
|
||||
mlp_config = MlpConfig(
|
||||
layer_sizes=[128, 64], # Specify the sizes of hidden layers
|
||||
batch_norm=True, # Enable batch normalization
|
||||
dropout=0.2, # Apply dropout with a rate of 0.2
|
||||
final_layer_activation=True # Apply ReLU activation to the final layer
|
||||
)
|
||||
|
||||
# Create an instance of the MLP module
|
||||
mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)
|
||||
|
||||
# Generate an input tensor
|
||||
input_tensor = torch.randn(batch_size, input_dim)
|
||||
|
||||
# Perform a forward pass through the MLP
|
||||
outputs = mlp_model(input_tensor)
|
||||
|
||||
# Access the output and shared layer
|
||||
output = outputs["output"]
|
||||
shared_layer = outputs["shared_layer"]
|
||||
```
|
||||
|
||||
Note:
|
||||
The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
|
||||
enabling batch normalization and dropout, and choosing the activation function for the final layer.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(self, in_features: int, mlp_config: MlpConfig):
|
||||
"""
|
||||
Initializes the Mlp module.
|
||||
|
||||
Args:
|
||||
in_features (int): The number of input features or input dimensions.
|
||||
mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
super().__init__()
|
||||
self._mlp_config = mlp_config
|
||||
input_size = in_features
|
||||
@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
|
||||
self.layers.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Performs a forward pass through the MLP.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of the MLP.
|
||||
"""
|
||||
net = x
|
||||
for i, layer in enumerate(self.layers):
|
||||
net = layer(net)
|
||||
@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):
|
||||
|
||||
@property
|
||||
def shared_size(self):
|
||||
"""
|
||||
Returns the size of the shared layer in the MLP.
|
||||
|
||||
Returns:
|
||||
int: Size of the shared layer.
|
||||
"""
|
||||
return self._mlp_config.layer_sizes[-1]
|
||||
|
||||
@property
|
||||
def out_features(self):
|
||||
"""
|
||||
Returns the number of output features from the MLP.
|
||||
|
||||
Returns:
|
||||
int: Number of output features.
|
||||
"""
|
||||
|
||||
return self._mlp_config.layer_sizes[-1]
|
||||
|
@ -5,6 +5,53 @@ from absl import logging
|
||||
|
||||
|
||||
class ModelAndLoss(torch.nn.Module):
|
||||
"""
|
||||
PyTorch module that combines a neural network model and loss function.
|
||||
|
||||
This module wraps a neural network model and facilitates the forward pass through the model
|
||||
while also calculating the loss based on the model's predictions and provided labels.
|
||||
|
||||
Args:
|
||||
model: The torch module to wrap.
|
||||
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
||||
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
||||
for metrics stratification. Each stratifier config includes the name and index of discrete features
|
||||
to emit for stratification.
|
||||
|
||||
Example:
|
||||
To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
|
||||
and loss function as arguments:
|
||||
|
||||
```python
|
||||
# Create a neural network model
|
||||
model = YourNeuralNetworkModel()
|
||||
|
||||
# Define a loss function
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
# Create an instance of ModelAndLoss
|
||||
model_and_loss = ModelAndLoss(model, loss_fn)
|
||||
|
||||
# Generate a batch of training data (e.g., RecapBatch)
|
||||
batch = generate_training_batch()
|
||||
|
||||
# Perform a forward pass through the model and calculate the loss
|
||||
loss, outputs = model_and_loss(batch)
|
||||
|
||||
# You can now backpropagate and optimize using the computed loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
Note:
|
||||
The `ModelAndLoss` class simplifies the process of running forward passes through a model and
|
||||
calculating loss, making it easier to integrate the model into your training loop. Additionally,
|
||||
it supports the addition of stratifiers for metrics stratification, if needed.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@ -12,10 +59,13 @@ class ModelAndLoss(torch.nn.Module):
|
||||
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
model: torch module to wrap.
|
||||
loss_fn: Function for calculating loss, should accept logits and labels.
|
||||
straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification.
|
||||
Initializes the ModelAndLoss module.
|
||||
|
||||
Args:
|
||||
model: The torch module to wrap.
|
||||
loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
|
||||
stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
|
||||
for metrics stratification.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
@ -2,11 +2,57 @@ import torch
|
||||
|
||||
|
||||
class NumericCalibration(torch.nn.Module):
|
||||
"""
|
||||
Numeric calibration module for adjusting probability scores.
|
||||
|
||||
This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
|
||||
may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
|
||||
for tasks such as binary classification.
|
||||
|
||||
Args:
|
||||
pos_downsampling_rate (float): The downsampling rate for positive samples.
|
||||
neg_downsampling_rate (float): The downsampling rate for negative samples.
|
||||
|
||||
Example:
|
||||
To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
|
||||
scores like this:
|
||||
|
||||
```python
|
||||
# Create a NumericCalibration instance with downsampling rates
|
||||
calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)
|
||||
|
||||
# Generate probability scores (e.g., from a neural network)
|
||||
raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
|
||||
|
||||
# Apply numeric calibration to adjust the probabilities
|
||||
calibrated_probs = calibration(raw_probs)
|
||||
|
||||
# The `calibrated_probs` now contains the adjusted probability scores
|
||||
```
|
||||
|
||||
Note:
|
||||
The `NumericCalibration` module is used to adjust probability scores to account for differences in
|
||||
the number of positive and negative samples in a dataset. It can help improve the calibration of
|
||||
probability estimates in imbalanced classification problems.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within neural network architectures and should not be
|
||||
directly accessed or modified by external code.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
pos_downsampling_rate: float,
|
||||
neg_downsampling_rate: float,
|
||||
):
|
||||
"""
|
||||
Apply numeric calibration to probability scores.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability scores to be calibrated.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Calibrated probability scores.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||
|
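The rest of the constructor and the forward pass fall outside this hunk. For intuition only, and not as a claim about this file's exact expression: the textbook correction for a model trained on data where positives were kept with rate `a` (pos_downsampling_rate) and negatives with rate `b` (neg_downsampling_rate) rescales the predicted odds by `b / a`:

```python
import torch

def calibrate(probs: torch.Tensor, pos_rate: float, neg_rate: float) -> torch.Tensor:
    # Hypothetical sketch of the standard downsampling correction:
    # odds_true = odds_downsampled * (neg_rate / pos_rate), converted back to a probability.
    return (probs * neg_rate) / (probs * neg_rate + (1.0 - probs) * pos_rate)

raw = torch.tensor([0.8, 0.6, 0.2, 0.9])
print(calibrate(raw, pos_rate=0.1, neg_rate=0.2))  # scores shift upward when positives were downsampled more heavily
```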
@ -9,12 +9,60 @@ import pydantic
|
||||
|
||||
|
||||
class RecapAdamConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the Adam optimizer used in Recap.
|
||||
|
||||
Args:
|
||||
beta_1 (float): Momentum term (default: 0.9).
|
||||
beta_2 (float): Exponential weighted decay factor (default: 0.999).
|
||||
epsilon (float): Numerical stability in the denominator (default: 1e-7).
|
||||
|
||||
Example:
|
||||
To define an Adam optimizer configuration for Recap, use:
|
||||
|
||||
```python
|
||||
adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
|
||||
```
|
||||
|
||||
Note:
|
||||
This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
|
||||
"""
|
||||
|
||||
beta_1: float = 0.9 # Momentum term.
|
||||
beta_2: float = 0.999 # Exponential weighted decay factor.
|
||||
epsilon: float = 1e-7 # Numerical stability in denominator.
|
||||
|
||||
|
||||
class MultiTaskLearningRates(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration settings for multiple learning rates in Recap.
|
||||
|
||||
Args:
|
||||
tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
|
||||
backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).
|
||||
|
||||
Example:
|
||||
To define multiple learning rates for different towers in Recap, use:
|
||||
|
||||
```python
|
||||
multi_task_lr = MultiTaskLearningRates(
|
||||
tower_learning_rates={
|
||||
'task1': learning_rate1,
|
||||
'task2': learning_rate2,
|
||||
},
|
||||
backbone_learning_rate=backbone_lr,
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
|
||||
"""
|
||||
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
|
||||
description="Learning rates for different towers of the model."
|
||||
)
|
||||
@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):
|
||||
|
||||
|
||||
class RecapOptimizerConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration settings for the Recap optimizer.
|
||||
|
||||
Args:
|
||||
multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
|
||||
single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
|
||||
adam (RecapAdamConfig): Configuration settings for the Adam optimizer.
|
||||
|
||||
Example:
|
||||
To define an optimizer configuration for training with Recap, use:
|
||||
|
||||
```python
|
||||
optimizer_config = RecapOptimizerConfig(
|
||||
multi_task_learning_rates=multi_task_lr,
|
||||
single_task_learning_rate=single_task_lr,
|
||||
adam=adam_config,
|
||||
)
|
||||
```
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use to configure the optimizer settings within Recap and should not be
|
||||
directly accessed by external code.
|
||||
"""
|
||||
|
||||
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
|
||||
None, description="Multiple learning rates for different tasks.", one_of="lr"
|
||||
)
|
||||
|
@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc"
|
||||
|
||||
|
||||
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
|
||||
"""Shim to get learning rates into a LRScheduler.
|
||||
|
||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
||||
e.g. exponential decay can be used.
|
||||
|
||||
"""
|
||||
A shim to get learning rates into an LRScheduler.
|
||||
|
||||
This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
|
||||
learning rate schedules, such as exponential decay.
|
||||
|
||||
Args:
|
||||
optimizer: The optimizer to which this scheduler is applied.
|
||||
lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
|
||||
emb_learning_rate: The learning rate for embeddings (optional).
|
||||
last_epoch (int): The index of the last epoch (default: -1).
|
||||
verbose (bool): If True, print warnings for deprecated functions (default: False).
|
||||
|
||||
Example:
|
||||
To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:
|
||||
|
||||
```python
|
||||
scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
|
||||
```
|
||||
|
||||
Warning:
|
||||
This class is intended for internal use to handle learning rate scheduling within Recap training and should not
|
||||
be directly accessed by external code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -80,15 +98,25 @@ def build_optimizer(
|
||||
optimizer_config: config.OptimizerConfig,
|
||||
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
|
||||
):
|
||||
"""Builds an optimizer and scheduler.
|
||||
|
||||
Args:
|
||||
model: A torch model, probably with DDP/DMP.
|
||||
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
|
||||
|
||||
Returns:
|
||||
A torch.optim instance, and a scheduler instance.
|
||||
"""
|
||||
Build an optimizer and scheduler for training.
|
||||
|
||||
Args:
|
||||
model: The torch model, possibly with DDP/DMP.
|
||||
optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
|
||||
emb_optimizer_config: Configuration settings for embedding optimization (optional).
|
||||
|
||||
Returns:
|
||||
torch.optim.Optimizer: The optimizer for training.
|
||||
RecapLRShim: The learning rate scheduler for the optimizer.
|
||||
|
||||
Example:
|
||||
To build an optimizer and scheduler for training, use:
|
||||
|
||||
```python
|
||||
optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
|
||||
```
|
||||
"""
|
||||
optimizer_fn = functools.partial(
|
||||
torch.optim.Adam,
|
||||
lr=_DEFAULT_LR,
|
||||
|
@ -4,6 +4,17 @@ import pydantic
|
||||
|
||||
|
||||
class TwhinDataConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for Twhin model training data.
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for the training data.
|
||||
per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
|
||||
global_negatives (int): The number of global negatives.
|
||||
in_batch_negatives (int): The number of in-batch negatives.
|
||||
limit (pydantic.PositiveInt): The limit on the number of data points to use.
|
||||
offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
|
||||
"""
|
||||
data_root: str
|
||||
per_replica_batch_size: pydantic.PositiveInt
|
||||
global_negatives: int
|
||||
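A hypothetical construction of this config, using placeholder values and only the fields listed in the docstring above; whether it validates as-is depends on fields that fall outside this hunk.

```python
# Hypothetical values; field names follow the docstring above.
data_config = TwhinDataConfig(
    data_root="/path/to/edge/data",
    per_replica_batch_size=1024,
    global_negatives=10,
    in_batch_negatives=10,
    limit=1_000_000,
)
```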
|
@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset
|
||||
|
||||
|
||||
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
|
||||
"""
|
||||
Create a dataset for Twhin model training.
|
||||
|
||||
Args:
|
||||
data_config (TwhinDataConfig): The data configuration for the dataset.
|
||||
model_config (TwhinModelConfig): The model configuration containing embeddings and relations.
|
||||
|
||||
Returns:
|
||||
EdgesDataset: The dataset for Twhin model training.
|
||||
"""
|
||||
tables = model_config.embeddings.tables
|
||||
table_sizes = {table.name: table.num_embeddings for table in tables}
|
||||
relations = model_config.relations
|
||||
|
@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
|
||||
@dataclass
|
||||
class EdgeBatch(DataclassBatch):
|
||||
"""
|
||||
Batch data structure for edge-based models.
|
||||
|
||||
Args:
|
||||
nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
|
||||
labels (torch.Tensor): Tensor containing labels.
|
||||
rels (torch.Tensor): Tensor containing relation information.
|
||||
weights (torch.Tensor): Tensor containing weights.
|
||||
"""
|
||||
nodes: KeyedJaggedTensor
|
||||
labels: torch.Tensor
|
||||
rels: torch.Tensor
|
||||
@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):
|
||||
|
||||
|
||||
class EdgesDataset(Dataset):
|
||||
"""
|
||||
Dataset for edge-based models.
|
||||
|
||||
Args:
|
||||
file_pattern (str): The file pattern for the dataset.
|
||||
table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
|
||||
relations (List[Relation]): A list of relations between tables.
|
||||
lhs_column_name (str): The name of the left-hand-side column.
|
||||
rhs_column_name (str): The name of the right-hand-side column.
|
||||
rel_column_name (str): The name of the relation column.
|
||||
**dataset_kwargs: Additional keyword arguments for the parent Dataset class.
|
||||
"""
|
||||
rng = np.random.default_rng()
|
||||
|
||||
def __init__(
|
||||
@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
|
||||
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
|
||||
|
||||
def pa_to_batch(self, batch: pa.RecordBatch):
|
||||
"""
|
||||
Converts a pyarrow RecordBatch to an EdgeBatch.
|
||||
|
||||
Args:
|
||||
batch (pa.RecordBatch): A pyarrow RecordBatch containing data.
|
||||
|
||||
Returns:
|
||||
EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
|
||||
"""
|
||||
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
|
||||
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
|
||||
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
|
||||
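As an aside, a minimal pyarrow batch of the shape `pa_to_batch` expects can be built as below; the column names are placeholders, since the actual defaults for `lhs_column_name`, `rhs_column_name`, and `rel_column_name` are not shown in this hunk.

```python
import numpy as np
import pyarrow as pa

# Hypothetical column names for illustration only.
record_batch = pa.RecordBatch.from_arrays(
    [
        pa.array(np.array([0, 1, 2], dtype=np.int64)),  # lhs node indices
        pa.array(np.array([3, 4, 5], dtype=np.int64)),  # rhs node indices
        pa.array(np.array([0, 0, 1], dtype=np.int64)),  # relation indices
    ],
    names=["lhs", "rhs", "rel"],
)
```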
@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
|
||||
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
|
||||
|
||||
"""Process edges that contain lhs index, rhs index, relation index.
|
||||
|
||||
Args:
|
||||
lhs (torch.Tensor): Tensor containing left-hand-side indices.
|
||||
rhs (torch.Tensor): Tensor containing right-hand-side indices.
|
||||
rel (torch.Tensor): Tensor containing relation indices.
|
||||
|
||||
Returns:
|
||||
Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.
|
||||
Example:
|
||||
|
||||
```
|
||||
@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
|
||||
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
|
||||
|
||||
def to_batches(self):
|
||||
"""
|
||||
Converts data to batches.
|
||||
|
||||
Yields:
|
||||
pa.RecordBatch: A pyarrow RecordBatch containing data.
|
||||
"""
|
||||
ds = super().to_batches()
|
||||
batch_size = self._dataset_kwargs["batch_size"]
|
||||
|
||||
|
@ -23,6 +23,25 @@ RELATIONS = [
|
||||
|
||||
|
||||
def test_gen():
|
||||
"""Test function for generating edge-based datasets and dataloaders.
|
||||
|
||||
This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
|
||||
instance and a dataloader for it.
|
||||
|
||||
The test includes the following steps:
|
||||
1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
|
||||
2. Write the synthetic data to a Parquet file.
|
||||
3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
|
||||
4. Initialize the local file system for the dataset.
|
||||
5. Create a dataloader for the dataset and retrieve the first batch.
|
||||
6. Assert that the labels in the batch are positive.
|
||||
7. Verify that the positive examples in the batch match the expected values.
|
||||
|
||||
This function serves as a test case for the data generation and dataset creation process.
|
||||
|
||||
Raises:
|
||||
AssertionError: If any of the test assertions fail.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
@ -10,8 +10,29 @@ from pydantic import validator

class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
"""
Configuration class for Twhin model embeddings.

This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
for all tables in the Twhin model embeddings configuration match.

Attributes:
tables (List[TableConfig]): A list of table configurations for the model's embeddings.
"""
@validator("tables")
def embedding_dims_match(cls, tables):
"""
Validate that embedding dimensions and data types match for all tables.

Args:
tables (List[TableConfig]): List of table configurations.

Returns:
List[TableConfig]: The list of validated table configurations.

Raises:
AssertionError: If embedding dimensions or data types do not match.
"""
embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type
for table in tables:
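
The loop above is cut off by the hunk; a plausible, simplified sketch of such an equality check (an assumption, not the exact elided body) looks like:

```python
# Sketch of an equality check over table configs; the real loop body is elided
# by the hunk above, so treat this as an assumption about its shape.
from collections import namedtuple

Table = namedtuple("Table", ["embedding_dim", "data_type"])

def check_tables_match(tables):
    embedding_dim = tables[0].embedding_dim
    data_type = tables[0].data_type
    for table in tables:
        assert table.embedding_dim == embedding_dim, "embedding dimensions must match"
        assert table.data_type == data_type, "data types must match"
    return tables

check_tables_match([Table(128, "fp32"), Table(128, "fp32")])  # passes
```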

@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):

class Operator(str, enum.Enum):
"""
Enumeration of operator types.

This enumeration defines different types of operators that can be applied to Twhin model relations.
"""
TRANSLATION = "translation"


class Relation(pydantic.BaseModel):
"""
Configuration class for graph relationships in the Twhin model.

This class defines properties and operators for graph relationships in the Twhin model.

Attributes:
name (str): The name of the relationship.
lhs (str): The name of the entity on the left-hand side of the relation.
rhs (str): The name of the entity on the right-hand side of the relation.
operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
"""

name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field(
@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):

class TwhinModelConfig(base_config.BaseConfig):
"""
Configuration class for the Twhin model.

This class defines configuration options specific to the Twhin model.

Attributes:
embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
relations (List[Relation]): List of graph relationship configurations.
translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
"""
embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation]
translation_optimizer: OptimizerConfig

@validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs):
"""
Validate that the specified node types in relations are valid table names in embeddings.

Args:
relation (Relation): A single relation configuration.
values (dict): The values dictionary containing the "embeddings" configuration.

Returns:
Relation: The validated relation configuration.

Raises:
AssertionError: If the specified node types are not valid table names in embeddings.
"""
table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
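
A simplified, self-contained sketch of this cross-field validation using pydantic v1-style validators; the stand-in classes below are not the repo's models:

```python
# Simplified stand-in models showing how a pydantic v1 validator can cross-check
# one field against an earlier one via `values`.
import typing
import pydantic

class _Table(pydantic.BaseModel):
    name: str

class _Embeddings(pydantic.BaseModel):
    tables: typing.List[_Table]

class _Relation(pydantic.BaseModel):
    name: str
    lhs: str
    rhs: str

class _ModelConfig(pydantic.BaseModel):
    embeddings: _Embeddings
    relations: typing.List[_Relation]

    @pydantic.validator("relations", each_item=True)
    def valid_node_types(cls, relation, values, **kwargs):
        # `embeddings` is declared before `relations`, so it is already present
        # in `values` when this validator runs.
        table_names = [table.name for table in values["embeddings"].tables]
        assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
        assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
        return relation

_ModelConfig(
    embeddings=_Embeddings(tables=[_Table(name="user"), _Table(name="tweet")]),
    relations=[_Relation(name="fav", lhs="user", rhs="tweet")],
)
```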

@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa

class TwhinModel(nn.Module):
"""
Twhin model for graph-based entity embeddings and translation.

This class defines the Twhin model, which is used for learning embeddings of entities in a graph
and applying translations to these embeddings based on graph relationships.

Args:
model_config (TwhinModelConfig): Configuration for the Twhin model.
data_config (TwhinDataConfig): Configuration for the data used by the model.

Attributes:
batch_size (int): The batch size used for training.
table_names (List[str]): Names of tables in the model's embeddings.
large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
embedding_dim (int): Dimensionality of entity embeddings.
num_tables (int): Number of tables in the model's embeddings.
in_batch_negatives (int): Number of in-batch negative samples to use during training.
global_negatives (int): Number of global negative samples to use during training.
num_relations (int): Number of graph relationships in the model.
all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
"""
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
)

def forward(self, batch: EdgeBatch):
"""
Forward pass of the Twhin model.

Args:
batch (EdgeBatch): Input batch containing graph edge information.

Returns:
dict: A dictionary containing model output with "logits" and "probabilities".
- "logits" (torch.Tensor): Logit scores.
- "probabilities" (torch.Tensor): Sigmoid probabilities.
"""
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
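
The rest of the forward pass is elided by the hunk. As an assumption only, a common way a translation operator scores an edge is to shift the lhs embedding by the relation's translation vector and dot it with the rhs embedding:

```python
# Assumption: a generic "translation" scoring sketch, NOT the exact elided code.
import torch

B, D = 4, 16                       # batch of 4 edges, 16-dim embeddings
lhs_embs = torch.randn(B, D)
rhs_embs = torch.randn(B, D)
trans_embs = torch.randn(B, D)     # one translation vector per edge's relation

logits = ((lhs_embs + trans_embs) * rhs_embs).sum(dim=-1)   # shape: (B,)
probabilities = torch.sigmoid(logits)
print(logits.shape, probabilities.shape)
```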

@ -98,6 +130,18 @@ class TwhinModel(nn.Module):

def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
"""
Apply optimizers to the Twhin model's embeddings.

This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.

Args:
model (TwhinModel): The Twhin model to apply optimizers to.
model_config (TwhinModelConfig): Configuration for the Twhin model.

Returns:
TwhinModel: The Twhin model with optimizers applied to its embeddings.
"""
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
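
A sketch of resolving per-table optimizer settings into an optimizer class plus kwargs, using plain torch.optim; the torchrec apply_optimizer_in_backward wiring used by the repo is not reproduced here:

```python
# Sketch only: map a per-table config (made-up dicts) to an optimizer class and
# its keyword arguments, the shape of data the loop above produces per table.
import torch

table_optimizer_configs = {
    "user": {"algorithm": "sgd", "lr": 0.01},
    "tweet": {"algorithm": "sgd", "lr": 0.02},
}

_OPTIMIZER_CLASSES = {"sgd": torch.optim.SGD, "adam": torch.optim.Adam}

def resolve(table_config):
    optimizer_class = _OPTIMIZER_CLASSES[table_config["algorithm"]]
    optimizer_kwargs = {k: v for k, v in table_config.items() if k != "algorithm"}
    return optimizer_class, optimizer_kwargs

for name, cfg in table_optimizer_configs.items():
    optimizer_class, optimizer_kwargs = resolve(cfg)
    print(name, optimizer_class.__name__, optimizer_kwargs)
```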

@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
device: torch.device,
) -> None:
"""
Initialize a TwhinModelAndLoss module.

Args:
model: The torch module to wrap.
loss_fn: A function for calculating loss, should accept logits and labels.
data_config: Configuration for Twhin data.
device: The torch device to use for calculations.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
self.device = device

def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""
Run the model forward and calculate the loss according to the given loss_fn.

NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However,
the underlying model signature needs to be exportable to ONNX, requiring
generic Python types. See https://pytorch.org/docs/stable/onnx.html#types.

Args:
batch ("RecapBatch"): The input batch for model inference.

Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
additional outputs including logits, labels, and weights.
"""
outputs = self.model(batch)
logits = outputs["logits"]
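
A small sketch of how such logits feed a loss, using binary_cross_entropy_with_logits with made-up labels:

```python
# Sketch: turning a "logits" output into a scalar loss; the logits and labels
# here are arbitrary example values.
import torch
import torch.nn.functional as F

logits = torch.tensor([0.3, -1.2, 2.0])
labels = torch.tensor([1.0, 0.0, 1.0])

loss = F.binary_cross_entropy_with_logits(logits, labels)
print(loss.item())
```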

@ -18,6 +18,12 @@ EMB_DIM = 128

def twhin_model_config() -> TwhinModelConfig:
"""
Create a configuration for the Twhin model.

Returns:
TwhinModelConfig: The Twhin model configuration.
"""
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))

@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:

def twhin_data_config() -> TwhinDataConfig:
"""
Create a configuration for the Twhin data.

Returns:
TwhinDataConfig: The Twhin data configuration.
"""
data_config = TwhinDataConfig(
data_root="/",
per_replica_batch_size=10,
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:

def test_twhin_model():
"""
Test the Twhin model creation and optimization.

This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
the device placement of model parameters.

Returns:
None
"""
model_config = twhin_model_config()
loss_fn = F.binary_cross_entropy_with_logits

@ -84,6 +105,23 @@ def test_twhin_model():

def test_unequal_dims():
"""
Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig.

This function tests whether the validation logic correctly raises a `ValidationError` when
embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables.

The test includes the following steps:
1. Create two embedding configurations with different embedding dimensions.
2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.

This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions
in the `TwhinEmbeddingsConfig` for all tables.

Raises:
AssertionError: If the expected `ValidationError` is not raised.
"""
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
table0 = EmbeddingBagConfig(

@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"

def _lr_from_config(optimizer_config):
"""Get the learning rate from an optimizer configuration.

Args:
optimizer_config: Optimizer configuration.

Returns:
Learning rate from the optimizer configuration.
"""
if optimizer_config.learning_rate is not None:
return optimizer_config.learning_rate
else:
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
"""Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.

Args:
model: TwhinModel to build optimizer for.
config: TwhinModelConfig for model.

Returns:
Optimizer for model.
"""
translation_optimizer_fn = functools.partial(
get_optimizer_class(config.translation_optimizer),
**get_optimizer_algorithm_config(config.translation_optimizer).dict(),

@ -37,6 +37,12 @@ def run(
all_config: TwhinConfig,
save_dir: Optional[str] = None,
):
"""Run the training process for TwhinModel.

Args:
all_config (TwhinConfig): The configuration for the entire Twhin model.
save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
"""
train_dataset = create_dataset(all_config.train_data, all_config.model)

if env.is_reader():
@ -80,6 +86,11 @@ def run(

def main(argv):
"""Main entry point for the Twhin training script.

Args:
argv: Command-line arguments.
"""
logging.info("Starting")

logging.info(f"parsing config from {FLAGS.config_yaml_path}...")

@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging

class _Reader(pa.flight.FlightServerBase):
"""
Distributed reader flight server wrapping a dataset.

This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
from the dataset over the Flight protocol. It is designed to be used in a distributed environment
for efficient data access.

Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.

Attributes:
_location (str): The location of the Flight server.
_ds (Dataset): The dataset wrapped by the Flight server.

Methods:
do_get(_, __): Handles Flight requests for data retrieval.

Note:
Flight is an Apache Arrow project that provides a framework for efficient data transfer.
This class allows clients to retrieve data from the dataset using Flight.
"""

def __init__(self, location: str, ds: "Dataset"):
"""
Initialize a new _Reader instance.

Args:
location (str): The location of the Flight server.
ds (Dataset): The dataset to be wrapped by the Flight server.
"""
super().__init__(location=location)
self._location = location
self._ds = ds

def do_get(self, _, __):
"""
Handle Flight requests for data retrieval.

This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.

Args:
_: Unused argument.
__: Unused argument.

Returns:
pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.

Note:
An updated schema (to account for column selection) must be given to the stream.
"""
# NB: An updated schema (to account for column selection) has to be given the stream.
schema = next(iter(self._ds.to_batches())).schema
batches = self._ds.to_batches()
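
A minimal Arrow Flight round trip for orientation (assumed pattern, arbitrary port and payload; the real _Reader streams dataset batches rather than one in-memory table):

```python
# Illustrative only: serve a single in-memory Table over Flight and read it back.
import threading
import pyarrow as pa
import pyarrow.flight as flight

class TableServer(flight.FlightServerBase):
    def __init__(self, location: str, table: pa.Table):
        super().__init__(location=location)
        self._table = table

    def do_get(self, context, ticket):
        # RecordBatchStream accepts a Table (or RecordBatchReader) and streams it.
        return flight.RecordBatchStream(self._table)

table = pa.table({"x": [1, 2, 3]})
server = TableServer("grpc://0.0.0.0:8815", table)
threading.Thread(target=server.serve, daemon=True).start()

client = flight.connect("grpc://0.0.0.0:8815")
reader = client.do_get(flight.Ticket(b""))
print(reader.read_all())
server.shutdown()
```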

@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):

class Dataset(torch.utils.data.IterableDataset):
"""
A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.

This class enables efficient loading of data from Parquet files using PyArrow.
It is designed to be used as an IterableDataset in PyTorch for training and inference.

Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
for more details.

Attributes:
LOCATION (str): The default location for the Flight server used for data distribution.
_file_pattern (str): The glob pattern specifying Parquet files in the dataset.
_fs: The filesystem object used for file operations.
_dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
_files (list): A list of file paths matching the glob pattern.
_schema (pa.Schema): The schema of the Parquet dataset.

Methods:
serve(): Start serving the dataset using a Flight server.
to_batches(): Generate batches of data from the Parquet dataset.
pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.

Note:
This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
to create DataLoader instances for training or inference.
"""
LOCATION = "grpc://0.0.0.0:2222"

def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
"""
Initialize a new Dataset instance. Specify batch size and column to select for.

Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.

Args:
file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
**dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
"""
self._file_pattern = file_pattern
self._fs = infer_fs(self._file_pattern)
self._dataset_kwargs = dataset_kwargs
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):

self._validate_columns()

def _validate_columns(self):
"""
Validate the specified columns against the dataset schema.

Raises:
Exception: If any specified columns are not found in the dataset schema.
"""
columns = set(self._dataset_kwargs.get("columns", []))
wrong_columns = set(columns) - set(self._schema.names)
if wrong_columns:
raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")

def serve(self):
"""Start serving the dataset using a Flight server."""
self.reader = _Reader(location=self.LOCATION, ds=self)
self.reader.serve()

def _create_dataset(self):
"""Create a PyArrow dataset for data retrieval."""
return pads.dataset(
source=random.sample(self._files, len(self._files))[0],
format="parquet",
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):

@abc.abstractmethod
def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
"""
Convert a Parquet RecordBatch to a custom data batch.

Args:
batch (pa.RecordBatch): A batch of data from the Parquet dataset.

Returns:
DataclassBatch: A custom data batch used in PyTorch training.

Raises:
NotImplementedError: This method must be implemented in derived classes.
"""
raise NotImplementedError

def dataloader(self, remote: bool = False):
"""
Create a PyTorch DataLoader for iterating through the dataset.

Args:
remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.

Returns:
DataLoader: A PyTorch DataLoader for iterating through the dataset.

Note:
If `remote` is True, a remote DataLoader is created for distributed training using Flight.
"""
if not remote:
return map(self.pa_to_batch, self.to_batches())
readers = get_readers(2)
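
A sketch of what the non-remote path iterates over, using plain pyarrow.dataset on a throwaway Parquet file with arbitrary column names:

```python
# Sketch of the non-remote path: write a tiny Parquet file, then iterate it in
# record batches the way to_batches() does (file name and columns are arbitrary).
import os
import tempfile
import pyarrow as pa
import pyarrow.dataset as pads
import pyarrow.parquet as pq

tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "part-0.parquet")
pq.write_table(pa.table({"lhs": [0, 1, 2, 3], "rhs": [4, 5, 6, 7]}), path)

ds = pads.dataset([path], format="parquet")
for record_batch in ds.to_batches(batch_size=2, columns=["lhs", "rhs"]):
    print(record_batch.num_rows, record_batch.schema.names)
```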

@ -117,6 +230,25 @@ GRPC_OPTIONS = [

def get_readers(num_readers_per_worker: int):
"""
Get Flight readers for distributed data loading.

This function retrieves Flight readers for distributed data loading in a PyTorch environment.

Args:
num_readers_per_worker (int): The number of Flight readers to retrieve per worker.

Returns:
List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.

Note:
Flight readers are used to fetch data in a distributed manner for efficient data loading.

Example:
To obtain Flight readers, use the following code:

>>> readers = get_readers(num_readers_per_worker=2)
"""
addresses = env.get_flight_server_addresses()

readers = []

@ -21,6 +21,16 @@ import torch.distributed as dist

def maybe_start_dataset_service():
"""
Start the dataset service if readers are available and required dependencies are met.

This function checks if readers are available and if the required TensorFlow version is >= 2.5.
If both conditions are met and the current environment is the dispatcher or reader, it starts
the TensorFlow dataset service.

Raises:
Exception: If the required TensorFlow version is not met (>= 2.5).
"""
if not env.has_readers():
return

@ -59,6 +69,24 @@ def maybe_start_dataset_service():
def register_dataset(
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
):
"""
Register a dataset with the distributed dataset service.

This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
and job name to all processes in the distributed environment.

Args:
dataset (tf.data.Dataset): The dataset to be registered.
dataset_service (str): The name of the dataset service.
compression (Optional[str]): The compression type for the dataset (default is "AUTO").

Returns:
Tuple[int, str]: A tuple containing the dataset ID and job name.

Note:
This function should be called on the rank 0 process.
"""
if dist.get_rank() == 0:
dataset_id = _register_dataset(
service=dataset_service,
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
compression: Optional[str] = "AUTO",
prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
"""
Distribute a dataset from a registered dataset ID.

This function consumes a dataset from the distributed dataset service using the provided dataset ID
and job name. It also supports prefetching for improved performance.

Args:
dataset_service (str): The name of the dataset service.
dataset_id (int): The ID of the dataset to be consumed.
job_name (Optional[str]): The name of the job associated with the dataset (optional).
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).

Returns:
tf.data.Dataset: The distributed dataset.
"""
logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
dataset = _from_dataset_id(
processing_mode="parallel_epochs",
@ -97,15 +142,28 @@ def distribute_from_dataset_id(

def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
"""
Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.

This function is used to distribute a dataset in a distributed training environment. It performs the
following steps:
- On the rank 0 process, it registers the given dataset with the distributed dataset service.
- It broadcasts the job name and dataset ID to all rank processes.
- All rank processes then consume the same dataset from the distributed dataset service.

Args:
dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.

Returns:
tf.data.Dataset: The distributed TensorFlow dataset.

Note:
- If there are no reader processes in the distributed environment, the original dataset is returned
without any distribution.
- This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
issues caused by each rank process trying to serve one job.
"""
if not env.has_readers():
return dataset
dataset_service = env.get_dds()
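
A single-process sketch of the tf.data service flow, assuming TF >= 2.5; the torch.distributed broadcast of the dataset id and job name is omitted:

```python
# Single-process sketch: start a local dispatcher and worker, register a dataset,
# then consume it via from_dataset_id.
import tensorflow as tf

dispatcher = tf.data.experimental.service.DispatchServer()
worker = tf.data.experimental.service.WorkerServer(  # kept alive for the demo
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher.target.split("://")[1]
    )
)

dataset = tf.data.Dataset.range(6)
dataset_id = tf.data.experimental.service.register_dataset(
    service=dispatcher.target, dataset=dataset, compression="AUTO"
)
distributed = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=dispatcher.target,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec,
    job_name="shared_job",
).prefetch(tf.data.experimental.AUTOTUNE)

print(sorted(int(x) for x in distributed.as_numpy_iterator()))
```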

@ -12,6 +12,17 @@ import torch

def create_dataset(tmpdir):
"""
Create a mock dataset for testing.

This function creates a mock dataset using PyArrow and Parquet for testing purposes.

Args:
tmpdir: A temporary directory where the dataset will be created.

Returns:
MockDataset: A mock dataset for testing.
"""
table = pa.table(
{
@ -34,6 +45,14 @@ def create_dataset(tmpdir):

def test_dataset(tmpdir):
"""
Test the created dataset.

This function tests the created mock dataset and checks if it behaves as expected.

Args:
tmpdir: A temporary directory used for testing.
"""
ds = create_dataset(tmpdir)
batch = next(iter(ds.dataloader(remote=False)))
assert batch.batch_size == 2
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
reason="Multiprocessing doesn't work on github yet.",
)
def test_distributed_dataset(tmpdir):
"""
Test the distributed dataset.

This function tests the distributed version of the mock dataset using multiprocessing.

Args:
tmpdir: A temporary directory used for testing.
"""
MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}

def _client():

151
reader/utils.py
151
reader/utils.py
@ -11,11 +11,55 @@ import torch

def roundrobin(*iterables):
"""
Iterate through provided iterables in a round-robin fashion.

This function takes multiple iterables and returns an iterator that yields elements from
each iterable in a round-robin manner. It continues cycling through the iterables until
all of them are exhausted.

Adapted from https://docs.python.org/3/library/itertools.html.

Args:
*iterables: One or more iterable objects to iterate through.

Yields:
Elements from the provided iterables in a round-robin fashion.

Raises:
StopIteration: If all provided iterables are exhausted.

Example:
```python
iterable1 = [1, 2, 3]
iterable2 = ['a', 'b', 'c']
iterable3 = [0.1, 0.2, 0.3]

for item in roundrobin(iterable1, iterable2, iterable3):
print(item)

# Output:
# 1
# 'a'
# 0.1
# 2
# 'b'
# 0.2
# 3
# 'c'
# 0.3
```

Note:
- If one of the provided iterables is shorter than the others, the function will
continue iterating through the remaining iterables until all are exhausted.
- If an iterable raises an exception during iteration, a warning message is logged,
and the function continues with the next iterable.

See Also:
- `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
- `itertools.islice`: A function to slice an iterable to limit the number of iterations.
"""
num_active = len(iterables)
nexts = itertools.cycle(iter(it).__next__ for it in iterables)
while num_active:
@ -35,6 +79,48 @@ def roundrobin(*iterables):

def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
"""
Monitor the speed and progress of data loading using a data loader.

This function iterates through a data loader for a specified number of steps or until
the end of the data loader is reached, periodically logging progress information.

Args:
data_loader: The data loader to monitor.
max_steps: The maximum number of steps to iterate through the data loader.
frequency: The frequency (in steps) at which to log progress.
peek (optional): If specified, it indicates the frequency (in steps) at which to log
batch contents for inspection.

Example:
```python
import torch
from torch.utils.data import DataLoader

# Create a data loader (replace with your own DataLoader configuration)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Monitor data loading speed and progress
speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
```

Note:
- The function logs information about elapsed time, the number of examples processed,
and the processing speed in examples per second.
- If `peek` is provided, batch contents will be logged for inspection at the specified
frequency.

See Also:
- `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
iterating through datasets.
"""
num_examples = 0
prev = time.perf_counter()
for idx, batch in enumerate(data_loader):
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]

def pa_to_torch(array: pa.array) -> torch.Tensor:
"""
Convert a PyArrow Array to a PyTorch Tensor.

Args:
array (pa.array): The PyArrow Array to convert.

Returns:
torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.

Example:
```python
import pyarrow as pa
import torch

# Create a PyArrow Array
arrow_array = pa.array([1, 2, 3])

# Convert it to a PyTorch Tensor
torch_tensor = pa_to_torch(arrow_array)
```
"""
return torch.from_numpy(array.to_numpy())


def create_default_pa_to_batch(schema) -> DataclassBatch:
"""
Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.

Args:
schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.

Returns:
callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.

Example:
```python
import pandas as pd
import pyarrow as pa
from dataclass_batch import DataclassBatch

# Define a PyArrow schema
schema = pa.schema([
("feature1", pa.float64()),
("feature2", pa.int64()),
("label", pa.int64()),
])

# Create the conversion function
pa_to_batch = create_default_pa_to_batch(schema)

# Create a PyArrow RecordBatch
record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
"feature1": [1.0, 2.0, None],
"feature2": [10, 20, 30],
"label": [0, 1, None],
}))

# Convert the RecordBatch to a custom DataclassBatch
custom_batch = pa_to_batch(record_batch)
```
"""
_CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)

def get_imputation_value(pa_type):