rajveer43 2023-09-11 21:31:42 +05:30
parent 9bb0986079
commit deec9a820e
11 changed files with 380 additions and 63 deletions

View File

@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC):
"""
A base class for batches used in pipelines.
Attributes:
None
"""
@abc.abstractmethod
def as_dict(self) -> Dict:
"""
Convert the batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False):
"""
Move the batch to the specified device.
Args:
device (torch.device): The target device.
non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.
Returns:
BatchBase: A new batch on the target device.
"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
"""
Record a CUDA stream for all tensors in the batch.
Args:
stream (torch.cuda.streams.Stream): The CUDA stream to record.
Returns:
None
"""
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self):
"""
Pin memory for all tensors in the batch.
Returns:
BatchBase: A new batch with pinned memory.
"""
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
"""
Generate a string representation of the batch.
Returns:
str: A string representation of the batch.
"""
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):
@property
def batch_size(self) -> int:
"""
Get the batch size from the tensors in the batch.
Returns:
int: The batch size.
Raises:
Exception: If the batch size cannot be determined from the tensors.
"""
for tensor in self.as_dict().values():
if tensor is None:
continue
@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):
@dataclass
class DataclassBatch(BatchBase):
"""
A batch class that uses dataclasses to define its fields.
Attributes:
None
"""
@classmethod
def feature_names(cls):
"""
Get the feature names of the dataclass.
Returns:
List[str]: A list of feature names.
"""
return list(cls.__dataclass_fields__.keys())
def as_dict(self):
"""
Convert the dataclass batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):
@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
"""
Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
Args:
name (str): The name of the custom batch class.
schema: The schema or structure of the batch.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):
@staticmethod
def from_fields(name: str, fields: dict):
"""
Create a custom batch subclass from a set of fields.
Args:
name (str): The name of the custom batch class.
fields (dict): A dictionary specifying the fields and their types.
Returns:
Type[DataclassBatch]: A custom batch class.
"""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):
class DictionaryBatch(BatchBase, dict):
"""
A batch class that represents data as a dictionary.
Attributes:
None
"""
def as_dict(self) -> Dict:
"""
Convert the dictionary batch into a dictionary representation.
Returns:
Dict: A dictionary representation of the batch.
"""
return self
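For context, a minimal sketch of how these batch classes compose; the class name, field names, and shapes below are hypothetical, and it assumes the classes generated by `from_fields` inherit from `DataclassBatch` (as the surrounding code suggests).

```python
import torch

# Hypothetical fields: any mapping of names to tensor types works here.
ExampleBatch = DataclassBatch.from_fields(
    "ExampleBatch",
    {"continuous": torch.Tensor, "labels": torch.Tensor},
)

batch = ExampleBatch(
    continuous=torch.randn(8, 4),
    labels=torch.randint(0, 2, (8,)),
)

print(batch.as_dict().keys())  # dict_keys(['continuous', 'labels'])
print(batch.batch_size)        # 8, assuming batch_size reads the leading dimension

# Move every tensor in the batch to the GPU in one call.
if torch.cuda.is_available():
    batch = batch.to(torch.device("cuda"), non_blocking=True)
```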

View File

@ -5,6 +5,15 @@ import torch.distributed as dist
def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available.
This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
Returns:
None
"""
try:
import tensorflow as tf
except ImportError:
@ -14,6 +23,18 @@ def maybe_setup_tensorflow():
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device.
This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
Args:
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
Returns:
torch.device: The torch device for the current process.
"""
if tf_ok:
maybe_setup_tensorflow()
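A brief usage sketch, assuming the script is launched under torchrun so the distributed environment variables are already set:

```python
import torch

# Optionally keep TensorFlow off the GPUs, initialize the process group,
# and pick the device for this rank.
device = setup_and_get_device(tf_ok=True)

model = torch.nn.Linear(4, 1).to(device)
```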

View File

@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()
def infer_fs(path: str):
"""
Infer the file system (fs) type based on the given path.
Args:
path (str): The file path.
Returns:
The inferred file system object: GCS_FS for paths starting with "gs://", LOCAL_FS for local paths.
Raises:
NotImplementedError: If the path starts with "hdfs://" (Hadoop Distributed File System), which is not yet supported.
"""
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
@ -18,8 +30,26 @@ def infer_fs(path: str):
def is_local_fs(fs):
"""
Check if the given file system is the local file system.
Args:
fs: The file system object to check.
Returns:
bool: True if the file system is the local file system, False otherwise.
"""
return fs == LOCAL_FS
def is_gcs_fs(fs):
"""
Check if the given file system is Google Cloud Storage (GCS).
Args:
fs: The file system object to check.
Returns:
bool: True if the file system is GCS, False otherwise.
"""
return fs == GCS_FS
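A short sketch of how these helpers compose; the bucket and path are made up, and the returned objects are assumed to be fsspec filesystems (as `LocalFileSystem()` above suggests).

```python
# Pick the right filesystem for a path, then branch on its type.
path = "gs://my-bucket/data/part-0.parquet"  # hypothetical path
fs = infer_fs(path)

if is_gcs_fs(fs):
    print("reading from Google Cloud Storage")
elif is_local_fs(fs):
    print("reading from local disk")

# fsspec filesystems expose open(), so the rest of the code is fs-agnostic.
with fs.open(path, "rb") as f:
    header = f.read(4)
```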

View File

@ -49,15 +49,25 @@ def log_ebc_norms(
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict: model.state_dict()
ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
Logs the norms of the embedding tables as specified by ebc_keys.
Currently, only the average norm per rank is logged.
Args:
model_state_dict (dict): The state dictionary of the model.
ebc_keys (List[str]): List of embedding keys from state_dict to log.
Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
sample_size (int, optional): Limits the number of rows per rank to compute average on
to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
Returns:
Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
The keys are in the format "{emb_key}-norm-{idx}".
Note:
This function computes and logs the average norm of embedding tables across ranks.
It gathers the norms from all ranks and returns them as a dictionary.
"""
norm_logs = dict()
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
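A hypothetical call, using the key format from the docstring; it assumes an already-initialized process group with one CUDA device per rank.

```python
# `model` is assumed to be an already-constructed model whose state dict
# contains the named embedding table.
norms = log_ebc_norms(
    model_state_dict=model.state_dict(),
    ebc_keys=["model.embeddings.ebc.embedding_bags.meta__user_id.weight"],
    sample_size=1_000_000,
)
# e.g. {"model.embeddings.ebc.embedding_bags.meta__user_id.weight-norm-0": tensor(...), ...}
```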

View File

@ -11,6 +11,20 @@ import numpy as np
class LargeEmbeddings(nn.Module):
"""
A module for handling large embeddings.
Args:
large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
Attributes:
ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
surgery_cut_point (torch.nn.Identity): An identity hook marking where post-processing surgery is applied to prepare large embedding models for serving.
"""
def __init__(
self,
large_embeddings_config: LargeEmbeddingsConfig,
@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
self,
sparse_features: KeyedJaggedTensor,
) -> KeyedTensor:
"""
Forward pass of the LargeEmbeddings module.
Args:
sparse_features (KeyedJaggedTensor): Sparse input features.
Returns:
KeyedTensor: The output of the forward pass, a KeyedTensor.
"""
pooled_embs = self.ebc(sparse_features)
# a KeyedTensor
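A sketch of the forward call with a toy `KeyedJaggedTensor`; the feature names and id values are hypothetical, and `large_embeddings` is assumed to be an already-constructed, sharded instance.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two examples: the first has ids [101, 102] for "meta__user_id", the second [103];
# each has one id for "meta__item_id".
sparse_features = KeyedJaggedTensor.from_lengths_sync(
    keys=["meta__user_id", "meta__item_id"],
    values=torch.tensor([101, 102, 103, 7, 9]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

pooled = large_embeddings(sparse_features)    # KeyedTensor of pooled embeddings
user_emb = pooled.to_dict()["meta__user_id"]  # shape: (batch_size, embedding_dim)
```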

View File

@ -11,6 +11,12 @@ import torch.distributed.run
def is_distributed_worker():
"""
Checks if the current process is a distributed worker.
Returns:
bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
"""
world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None
@ -25,25 +31,36 @@ def maybe_run_training(
is_chief: Optional[bool] = False,
**training_kwargs,
):
"""Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables
(WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn: The function that is responsible for training
module_name: The name of the module that this function was called from;
used to indicate torchrun entrypoint.
nproc_per_node: Number of workers per node; supported values.
num_nodes: Number of nodes, otherwise inferred from environment.
is_chief: If process is running on chief.
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
"""
Wrapper function for single node, multi-GPU PyTorch training.
If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn (callable): The function responsible for training.
module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint.
nproc_per_node (int, optional): Number of workers per node. Defaults to None.
num_nodes (int, optional): Number of nodes. Defaults to None.
is_chief (bool, optional): If the process is running on the chief node. Defaults to False.
set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
**training_kwargs: Additional keyword arguments to pass to the `train_fn`.
Note:
This function checks if the current process is a distributed worker by examining the environment variables.
If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary
environment variables and launches the training process using torchrun.
Example:
To run training on a single node with 4 GPUs, you can use:
```
maybe_run_training(train_function, __name__, nproc_per_node=4)
```
"""
machines = utils.machine_from_env()
if num_nodes is None:
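A typical entry-point pattern, sketched with hypothetical module and function names:

```python
# my_project/main.py (hypothetical)
def run(lr: float = 1e-3):
    # build model, data, and train on this worker
    ...

if __name__ == "__main__":
    # First invocation: no WORLD_SIZE/RANK set, so this re-launches the module
    # under torchrun. Inside each spawned worker it calls run(lr=1e-3) directly.
    maybe_run_training(run, __name__, nproc_per_node=2, lr=1e-3)
```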

View File

@ -19,6 +19,23 @@ MOCK_ENV = {
@contextmanager
def mock_pg():
"""
A context manager that mocks the distributed process group for testing purposes.
It temporarily sets the environment variables in MOCK_ENV and initializes a process group with the Gloo backend, so distributed training code can be exercised in tests without launching multiple processes.
Example:
```python
with mock_pg():
# Your distributed training code here
```
"""
with patch.dict(os.environ, MOCK_ENV):
try:
dist.init_process_group(
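A plausible completion of the pattern this docstring describes; this is a sketch, not necessarily the commit's exact code, and it assumes MOCK_ENV supplies MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE.

```python
import os
from contextlib import contextmanager
from unittest.mock import patch

import torch.distributed as dist


@contextmanager
def _mock_pg_sketch():
    """Single-process Gloo group behind mocked environment variables."""
    with patch.dict(os.environ, MOCK_ENV):
        try:
            dist.init_process_group(backend="gloo")
            yield
        finally:
            dist.destroy_process_group()
```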

View File

@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)
def _read_file(f):
"""
Read the contents of a file using fsspec.
Args:
f: File path or URL.
Returns:
The contents of the file.
"""
with fsspec.open(f) as f:
return f.read()
@ -21,18 +30,24 @@ def setup_configuration(
yaml_path: str,
substitute_env_variable: bool = False,
) -> Tuple[C, str]:
"""Resolves a config at a yaml path.
Args:
config_type: Pydantic config class to load.
yaml_path: yaml path of the config file.
substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
environment variable value whenever possible. If an environment variable doesn't exist,
the string is left unchanged.
Returns:
The pydantic config object.
"""
Load a Pydantic config object from a YAML file and optionally substitute environment variables.
Args:
config_type: Pydantic config class to load.
yaml_path: YAML path of the config file.
substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
with their environment variable values whenever possible.
If an environment variable doesn't exist, the string is left unchanged.
Returns:
A tuple containing the Pydantic config object and the resolved YAML content.
Example:
```python
config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
```
"""
def _substitute(s):
if substitute_env_variable:
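One way to get the documented behavior (unknown variables left unchanged) is `string.Template.safe_substitute`; the sketch below is an assumption for illustration, not the commit's actual implementation.

```python
import os
import string


def _substitute_sketch(s: str, substitute_env_variable: bool = True) -> str:
    # safe_substitute expands $VAR / ${VAR} from os.environ and leaves
    # unknown variables untouched, matching the docstring's contract.
    if substitute_env_variable:
        return string.Template(s).safe_substitute(os.environ)
    return s
```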

View File

@ -6,6 +6,33 @@ import pydantic
class WandbConfig(base_config.BaseConfig):
"""
Configuration for integrating with Weights and Biases (WandB).
Attributes:
host (str): Host of the Weights and Biases instance, passed to login.
key_path (str): Path to the key file.
name (str): Name of the experiment, passed to init.
entity (str): Name of the user/service account, passed to init.
project (str): Name of the WandB project, passed to init.
tags (List[str]): List of tags, passed to init.
notes (str): Notes, passed to init.
metadata (Dict[str, Any]): Additional metadata to log.
Example:
```python
wandb_config = WandbConfig(
host="https://wandb.example.com",
key_path="/path/to/key",
name="experiment_1",
entity="user123",
project="my_project",
tags=["experiment", "ml"],
notes="This is a test experiment.",
metadata={"version": "1.0"}
)
```
"""
host: str = pydantic.Field(
"https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
description="Host of Weights and Biases instance, passed to login.",

View File

@ -6,33 +6,46 @@ import pydantic
class BaseConfigTest(TestCase):
def test_extra_forbidden(self):
class Config(BaseConfig):
x: int
"""
Unit tests for the BaseConfig class.
"""
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
def test_extra_forbidden(self):
"""
Test that extra fields are forbidden when creating a Config instance.
"""
class Config(BaseConfig):
x: int
def test_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
def test_one_of(self):
"""
Test the use of the `one_of` attribute for fields in a Config instance.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
def test_at_most_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")
def test_at_most_one_of(self):
"""
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")

View File

@ -9,6 +9,14 @@ import pydantic
class RuntimeConfig(base_config.BaseConfig):
"""
Configuration for runtime settings.
Attributes:
wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices.
enable_amp (bool): Enable automatic mixed precision.
"""
wandb: WandbConfig = pydantic.Field(None)
enable_tensorfloat32: bool = pydantic.Field(
False, description="Use tensorfloat32 if on Ampere devices."
@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):
class TrainingConfig(base_config.BaseConfig):
"""
Configuration for training settings.
Attributes:
save_dir (str): Directory to save checkpoints.
num_train_steps (pydantic.PositiveInt): Number of training steps.
initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps.
checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps.
num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used.
eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps.
eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
num_epochs (pydantic.PositiveInt): Number of training epochs.
"""
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
num_train_steps: pydantic.PositiveInt = 10000
initial_checkpoint_dir: str = pydantic.Field(
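A hypothetical construction of these configs in Python; the values are illustrative only, and unspecified fields are assumed to fall back to their defaults.

```python
runtime = RuntimeConfig(enable_tensorfloat32=True)

training = TrainingConfig(
    save_dir="/tmp/model",
    num_train_steps=10_000,
    checkpoint_every_n=1_000,
    train_log_every_n=100,
)
```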