From deec9a820eabc22940c17eb6bef89086c3dfb8d1 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Mon, 11 Sep 2023 21:31:42 +0530 Subject: [PATCH] new --- common/batch.py | 121 +++++++++++++++++++++++++- common/device.py | 21 +++++ common/filesystem/util.py | 30 +++++++ common/log_weights.py | 26 ++++-- common/modules/embedding/embedding.py | 24 +++++ common/run_training.py | 53 +++++++---- common/testing_utils.py | 17 ++++ common/utils.py | 37 +++++--- common/wandb.py | 27 ++++++ core/config/base_config_test.py | 63 ++++++++------ core/config/training.py | 24 +++++ 11 files changed, 380 insertions(+), 63 deletions(-) diff --git a/common/batch.py b/common/batch.py index cfd0ac0..f72b762 100644 --- a/common/batch.py +++ b/common/batch.py @@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable class BatchBase(Pipelineable, abc.ABC): + """ + A base class for batches used in pipelines. + + Attributes: + None + + """ @abc.abstractmethod def as_dict(self) -> Dict: + """ + Convert the batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + + """ raise NotImplementedError def to(self, device: torch.device, non_blocking: bool = False): + """ + Move the batch to the specified device. + + Args: + device (torch.device): The target device. + non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False. + + Returns: + BatchBase: A new batch on the target device. + + """ args = {} for feature_name, feature_value in self.as_dict().items(): args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking) return self.__class__(**args) def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + """ + Record a CUDA stream for all tensors in the batch. + + Args: + stream (torch.cuda.streams.Stream): The CUDA stream to record. + + Returns: + None + + """ for feature_value in self.as_dict().values(): feature_value.record_stream(stream) def pin_memory(self): + """ + Pin memory for all tensors in the batch. + + Returns: + BatchBase: A new batch with pinned memory. + + """ args = {} for feature_name, feature_value in self.as_dict().items(): args[feature_name] = feature_value.pin_memory() return self.__class__(**args) def __repr__(self) -> str: + """ + Generate a string representation of the batch. + + Returns: + str: A string representation of the batch. + + """ def obj2str(v): return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" @@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC): @property def batch_size(self) -> int: + """ + Get the batch size from the tensors in the batch. + + Returns: + int: The batch size. + + Raises: + Exception: If the batch size cannot be determined from the tensors. + + """ for tensor in self.as_dict().values(): if tensor is None: continue @@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC): @dataclass class DataclassBatch(BatchBase): + """ + A batch class that uses dataclasses to define its fields. + + Attributes: + None + + """ @classmethod def feature_names(cls): + """ + Get the feature names of the dataclass. + + Returns: + List[str]: A list of feature names. + + """ return list(cls.__dataclass_fields__.keys()) def as_dict(self): + """ + Convert the dataclass batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. 
+
+    """
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
@@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):
 
   @staticmethod
   def from_schema(name: str, schema):
-    """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
+    """
+    Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
+
+    Args:
+      name (str): The name of the custom batch class.
+      schema: The schema or structure of the batch.
+
+    Returns:
+      Type[DataclassBatch]: A custom batch class.
+
+    """
+
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):
 
   @staticmethod
   def from_fields(name: str, fields: dict):
+    """
+    Create a custom batch subclass from a set of fields.
+
+    Args:
+      name (str): The name of the custom batch class.
+      fields (dict): A dictionary specifying the fields and their types.
+
+    Returns:
+      Type[DataclassBatch]: A custom batch class.
+
+    """
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):
 
 
 class DictionaryBatch(BatchBase, dict):
+  """
+  A batch class that represents data as a dictionary.
+
+  Attributes:
+    None
+
+  """
   def as_dict(self) -> Dict:
+    """
+    Convert the dictionary batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+
+    """
     return self
diff --git a/common/device.py b/common/device.py
index 33d1bb8..b51c638 100644
--- a/common/device.py
+++ b/common/device.py
@@ -5,6 +5,15 @@ import torch.distributed as dist
 
 
 def maybe_setup_tensorflow():
+  """
+  Try to import TensorFlow and disable GPU devices if TensorFlow is available.
+
+  This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
+
+  Returns:
+    None
+
+  """
   try:
     import tensorflow as tf
   except ImportError:
@@ -14,6 +23,18 @@ def maybe_setup_tensorflow():
 
 
 def setup_and_get_device(tf_ok: bool = True) -> torch.device:
+  """
+  Set up the distributed environment and get the appropriate torch device.
+
+  This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
+
+  Args:
+    tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
+
+  Returns:
+    torch.device: The torch device for the current process.
+
+  """
   if tf_ok:
     maybe_setup_tensorflow()
 
diff --git a/common/filesystem/util.py b/common/filesystem/util.py
index 5148a0d..30360f2 100644
--- a/common/filesystem/util.py
+++ b/common/filesystem/util.py
@@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()
 
 
 def infer_fs(path: str):
+  """
+  Infer the file system (fs) to use based on the given path.
+
+  Args:
+    path (str): The file path.
+
+  Returns:
+    The inferred filesystem object: GCS_FS for "gs://" paths, otherwise LOCAL_FS (the local filesystem).
+
+  Raises:
+    NotImplementedError: If the path is an HDFS path ("hdfs://"), which is not yet supported.
+  """
   if path.startswith("gs://"):
     return GCS_FS
   elif path.startswith("hdfs://"):
@@ -18,8 +30,26 @@ def infer_fs(path: str):
 
 
 def is_local_fs(fs):
+  """
+  Check if the given file system is the local file system.
+
+  Args:
+    fs: The filesystem object to check.
+
+  Returns:
+    bool: True if the file system is the local file system, False otherwise.
+  """
   return fs == LOCAL_FS
 
 
 def is_gcs_fs(fs):
+  """
+  Check if the given file system is Google Cloud Storage (GCS).
+
+  Args:
+    fs: The filesystem object to check.
+
+  Returns:
+    bool: True if the file system is GCS, False otherwise.
+  """
   return fs == GCS_FS
diff --git a/common/log_weights.py b/common/log_weights.py
index d2c8e21..cecf42b 100644
--- a/common/log_weights.py
+++ b/common/log_weights.py
@@ -49,15 +49,25 @@ def log_ebc_norms(
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
-  """Logs the norms of the embedding tables as specified by ebc_keys.
-  As of now, log average norm per rank.
-
-  Args:
-    model_state_dict: model.state_dict()
-    ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
-      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
-    sample_size: Limits number of rows per rank to compute average on to avoid OOM.
   """
+  Logs the norms of the embedding tables as specified by ebc_keys.
+  As of now, logs the average norm per rank.
+
+  Args:
+    model_state_dict (dict): The state dictionary of the model.
+    ebc_keys (List[str]): List of embedding keys from state_dict to log.
+      Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
+    sample_size (int, optional): Limits the number of rows per rank used to compute the average
+      to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
+
+  Returns:
+    Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
+      The keys are in the format "{emb_key}-norm-{idx}".
+
+  Note:
+    This function computes and logs the average norm of the embedding tables across ranks.
+    It gathers the norms from all ranks and returns them as a dictionary.
+  """
   norm_logs = dict()
   for emb_key in ebc_keys:
     norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
diff --git a/common/modules/embedding/embedding.py b/common/modules/embedding/embedding.py
index b0a085e..579faa5 100644
--- a/common/modules/embedding/embedding.py
+++ b/common/modules/embedding/embedding.py
@@ -11,6 +11,20 @@ import numpy as np
 
 
 class LargeEmbeddings(nn.Module):
+  """
+  A module for handling large embeddings.
+
+  Args:
+    large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
+
+  Attributes:
+    ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
+    surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models.
+
+  Note:
+    The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving.
+
+  """
   def __init__(
     self,
     large_embeddings_config: LargeEmbeddingsConfig,
@@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
     self,
     sparse_features: KeyedJaggedTensor,
   ) -> KeyedTensor:
+    """
+    Forward pass of the LargeEmbeddings module.
+
+    Args:
+      sparse_features (KeyedJaggedTensor): Sparse input features.
+
+    Returns:
+      KeyedTensor: The pooled embeddings computed from the sparse input features.
+ + """ pooled_embs = self.ebc(sparse_features) # a KeyedTensor diff --git a/common/run_training.py b/common/run_training.py index fafb140..20a3d34 100644 --- a/common/run_training.py +++ b/common/run_training.py @@ -11,6 +11,12 @@ import torch.distributed.run def is_distributed_worker(): + """ + Checks if the current process is a distributed worker. + + Returns: + bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False. + """ world_size = os.environ.get("WORLD_SIZE", None) rank = os.environ.get("RANK", None) return world_size is not None and rank is not None @@ -25,25 +31,36 @@ def maybe_run_training( is_chief: Optional[bool] = False, **training_kwargs, ): - """Wrapper function for single node, multi-GPU Pytorch training. - - If the necessary distributed Pytorch environment variables - (WORLD_SIZE, RANK) have been set, then this function executes - `train_fn(**training_kwargs)`. - - Otherwise, this function calls torchrun and points at the calling module - `module_name`. After this call, the necessary environment variables are set - and training will commence. - - Args: - train_fn: The function that is responsible for training - module_name: The name of the module that this function was called from; - used to indicate torchrun entrypoint. - nproc_per_node: Number of workers per node; supported values. - num_nodes: Number of nodes, otherwise inferred from environment. - is_chief: If process is running on chief. - set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH. """ + Wrapper function for single node, multi-GPU PyTorch training. + + If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes + `train_fn(**training_kwargs)`. + + Otherwise, this function calls torchrun and points at the calling module + `module_name`. After this call, the necessary environment variables are set + and training will commence. + + Args: + train_fn (callable): The function responsible for training. + module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint. + nproc_per_node (int, optional): Number of workers per node. Defaults to None. + num_nodes (int, optional): Number of nodes. Defaults to None. + is_chief (bool, optional): If the process is running on the chief node. Defaults to False. + set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False. + **training_kwargs: Additional keyword arguments to pass to the `train_fn`. + + Note: + This function checks if the current process is a distributed worker by examining the environment variables. + If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary + environment variables and launches the training process using torchrun. + + Example: + To run training on a single node with 4 GPUs, you can use: + ``` + maybe_run_training(train_function, __name__, nproc_per_node=4) + ``` + """ machines = utils.machine_from_env() if num_nodes is None: diff --git a/common/testing_utils.py b/common/testing_utils.py index 69e14de..6feaf6b 100644 --- a/common/testing_utils.py +++ b/common/testing_utils.py @@ -19,6 +19,23 @@ MOCK_ENV = { @contextmanager def mock_pg(): + """ + A context manager for mocking the distributed process group for testing purposes. 
+ + This context manager temporarily sets environment variables to mock the distributed process group + and initializes it using the Gloo backend. It is useful for testing distributed training without + actually launching multiple processes. + + Example: + ```python + with mock_pg(): + # Your distributed training code here + ``` + + Note: + This context manager should be used within a testing environment to simulate distributed training + without actually creating multiple processes. + """ with patch.dict(os.environ, MOCK_ENV): try: dist.init_process_group( diff --git a/common/utils.py b/common/utils.py index 09d4464..e509ffb 100644 --- a/common/utils.py +++ b/common/utils.py @@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig) def _read_file(f): + """ + Read the contents of a file using fsspec. + + Args: + f: File path or URL. + + Returns: + The contents of the file. + """ with fsspec.open(f) as f: return f.read() @@ -21,18 +30,24 @@ def setup_configuration( yaml_path: str, substitute_env_variable: bool = False, ) -> Tuple[C, str]: - """Resolves a config at a yaml path. - - Args: - config_type: Pydantic config class to load. - yaml_path: yaml path of the config file. - substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their - environment variable value whenever possible. If an environment variable doesn't exist, - the string is left unchanged. - - Returns: - The pydantic config object. """ + Load a Pydantic config object from a YAML file and optionally substitute environment variables. + + Args: + config_type: Pydantic config class to load. + yaml_path: YAML path of the config file. + substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR} + with their environment variable values whenever possible. + If an environment variable doesn't exist, the string is left unchanged. + + Returns: + A tuple containing the Pydantic config object and the resolved YAML content. + + Example: + ```python + config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True) + ``` + """ def _substitute(s): if substitute_env_variable: diff --git a/common/wandb.py b/common/wandb.py index eb8cceb..9808db0 100644 --- a/common/wandb.py +++ b/common/wandb.py @@ -6,6 +6,33 @@ import pydantic class WandbConfig(base_config.BaseConfig): + """ + Configuration for integrating with Weights and Biases (WandB). + + Attributes: + host (str): Host of the Weights and Biases instance, passed to login. + key_path (str): Path to the key file. + name (str): Name of the experiment, passed to init. + entity (str): Name of the user/service account, passed to init. + project (str): Name of the WandB project, passed to init. + tags (List[str]): List of tags, passed to init. + notes (str): Notes, passed to init. + metadata (Dict[str, Any]): Additional metadata to log. 
+ + Example: + ```python + wandb_config = WandbConfig( + host="https://wandb.example.com", + key_path="/path/to/key", + name="experiment_1", + entity="user123", + project="my_project", + tags=["experiment", "ml"], + notes="This is a test experiment.", + metadata={"version": "1.0"} + ) + ``` + """ host: str = pydantic.Field( "https://https--wandb--prod--wandb.service.qus1.twitter.biz/", description="Host of Weights and Biases instance, passed to login.", diff --git a/core/config/base_config_test.py b/core/config/base_config_test.py index e623f05..e322116 100644 --- a/core/config/base_config_test.py +++ b/core/config/base_config_test.py @@ -6,33 +6,46 @@ import pydantic class BaseConfigTest(TestCase): - def test_extra_forbidden(self): - class Config(BaseConfig): - x: int + """ + Unit tests for the BaseConfig class. + """ - Config(x=1) - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y=2) + def test_extra_forbidden(self): + """ + Test that extra fields are forbidden when creating a Config instance. + """ + class Config(BaseConfig): + x: int - def test_one_of(self): - class Config(BaseConfig): - x: int = pydantic.Field(None, one_of="f") - y: int = pydantic.Field(None, one_of="f") + Config(x=1) + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y=2) - with self.assertRaises(pydantic.ValidationError): - Config() - Config(x=1) - Config(y=1) - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y=3) + def test_one_of(self): + """ + Test the use of the `one_of` attribute for fields in a Config instance. + """ + class Config(BaseConfig): + x: int = pydantic.Field(None, one_of="f") + y: int = pydantic.Field(None, one_of="f") - def test_at_most_one_of(self): - class Config(BaseConfig): - x: int = pydantic.Field(None, at_most_one_of="f") - y: str = pydantic.Field(None, at_most_one_of="f") + with self.assertRaises(pydantic.ValidationError): + Config() + Config(x=1) + Config(y=1) + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y=3) - Config() - Config(x=1) - Config(y="a") - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y="a") + def test_at_most_one_of(self): + """ + Test the use of the `at_most_one_of` attribute for fields in a Config instance. + """ + class Config(BaseConfig): + x: int = pydantic.Field(None, at_most_one_of="f") + y: str = pydantic.Field(None, at_most_one_of="f") + + Config() + Config(x=1) + Config(y="a") + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y="a") diff --git a/core/config/training.py b/core/config/training.py index a0c54b4..acdbde7 100644 --- a/core/config/training.py +++ b/core/config/training.py @@ -9,6 +9,14 @@ import pydantic class RuntimeConfig(base_config.BaseConfig): + """ + Configuration for runtime settings. + + Attributes: + - wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration. + - enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices. + - enable_amp (bool): Enable automatic mixed precision. + """ wandb: WandbConfig = pydantic.Field(None) enable_tensorfloat32: bool = pydantic.Field( False, description="Use tensorfloat32 if on Ampere devices." @@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig): class TrainingConfig(base_config.BaseConfig): + """ + Configuration for training settings. + + Attributes: + - save_dir (str): Directory to save checkpoints. + - num_train_steps (pydantic.PositiveInt): Number of training steps. + - initial_checkpoint_dir (str): Directory of initial checkpoints (optional). 
+ - checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps. + - checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional). + - train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps. + - num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used. + - eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps. + - eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds. + - gradient_accumulation (int): Number of replica steps to accumulate gradients (optional). + - num_epochs (pydantic.PositiveInt): Number of training epochs. + """ save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.") num_train_steps: pydantic.PositiveInt = 10000 initial_checkpoint_dir: str = pydantic.Field(
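
For readers skimming the new batch.py docstrings, here is a minimal usage sketch of the documented `DataclassBatch` API (`from_fields`, `as_dict`, `batch_size`, `to`). It is illustrative only and not part of the patch; the `common.batch` import path and the field names are assumptions made for the example.

```python
# Illustrative sketch only -- not part of this patch.
# Assumes the module is importable as `common.batch` and that torch is installed.
import torch

from common.batch import DataclassBatch

# Build a custom batch class from a name -> type mapping (hypothetical fields).
ExampleBatch = DataclassBatch.from_fields(
  "ExampleBatch",
  {"continuous_features": torch.Tensor, "labels": torch.Tensor},
)

batch = ExampleBatch(
  continuous_features=torch.randn(8, 4),
  labels=torch.zeros(8, 1),
)

print(sorted(batch.as_dict().keys()))  # ['continuous_features', 'labels']
print(batch.batch_size)                # 8, taken from the first non-None tensor

# `to` returns a new batch with every tensor moved to the target device.
if torch.cuda.is_available():
  batch = batch.to(torch.device("cuda:0"), non_blocking=True)
```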