diff --git a/common/batch.py b/common/batch.py index cfd0ac0..f72b762 100644 --- a/common/batch.py +++ b/common/batch.py @@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable class BatchBase(Pipelineable, abc.ABC): + """ + A base class for batches used in pipelines. + + Attributes: + None + + """ @abc.abstractmethod def as_dict(self) -> Dict: + """ + Convert the batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + + """ raise NotImplementedError def to(self, device: torch.device, non_blocking: bool = False): + """ + Move the batch to the specified device. + + Args: + device (torch.device): The target device. + non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False. + + Returns: + BatchBase: A new batch on the target device. + + """ args = {} for feature_name, feature_value in self.as_dict().items(): args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking) return self.__class__(**args) def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + """ + Record a CUDA stream for all tensors in the batch. + + Args: + stream (torch.cuda.streams.Stream): The CUDA stream to record. + + Returns: + None + + """ for feature_value in self.as_dict().values(): feature_value.record_stream(stream) def pin_memory(self): + """ + Pin memory for all tensors in the batch. + + Returns: + BatchBase: A new batch with pinned memory. + + """ args = {} for feature_name, feature_value in self.as_dict().items(): args[feature_name] = feature_value.pin_memory() return self.__class__(**args) def __repr__(self) -> str: + """ + Generate a string representation of the batch. + + Returns: + str: A string representation of the batch. + + """ def obj2str(v): return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" @@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC): @property def batch_size(self) -> int: + """ + Get the batch size from the tensors in the batch. + + Returns: + int: The batch size. + + Raises: + Exception: If the batch size cannot be determined from the tensors. + + """ for tensor in self.as_dict().values(): if tensor is None: continue @@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC): @dataclass class DataclassBatch(BatchBase): + """ + A batch class that uses dataclasses to define its fields. + + Attributes: + None + + """ @classmethod def feature_names(cls): + """ + Get the feature names of the dataclass. + + Returns: + List[str]: A list of feature names. + + """ return list(cls.__dataclass_fields__.keys()) def as_dict(self): + """ + Convert the dataclass batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + """ return { feature_name: getattr(self, feature_name) for feature_name in self.feature_names() @@ -64,7 +147,18 @@ class DataclassBatch(BatchBase): @staticmethod def from_schema(name: str, schema): - """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor.""" + """ + Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor. + + Args: + name (str): The name of the custom batch class. + schema: The schema or structure of the batch. + + Returns: + Type[DataclassBatch]: A custom batch class. 
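+
+    Example:
+        Illustrative sketch only; ``my_schema`` and its column names are
+        hypothetical, but any schema object exposing a ``names`` sequence
+        (e.g. a pyarrow schema) should work:
+
+        ```python
+        MyBatch = DataclassBatch.from_schema("MyBatch", schema=my_schema)
+        # assuming my_schema.names == ["user_id", "label"]
+        batch = MyBatch(user_id=torch.zeros(8), label=torch.zeros(8))
+        batch.batch_size  # 8
+        ```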
+ + """ + return dataclasses.make_dataclass( cls_name=name, fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names], @@ -73,6 +167,17 @@ class DataclassBatch(BatchBase): @staticmethod def from_fields(name: str, fields: dict): + """ + Create a custom batch subclass from a set of fields. + + Args: + name (str): The name of the custom batch class. + fields (dict): A dictionary specifying the fields and their types. + + Returns: + Type[DataclassBatch]: A custom batch class. + + """ return dataclasses.make_dataclass( cls_name=name, fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()], @@ -81,5 +186,19 @@ class DataclassBatch(BatchBase): class DictionaryBatch(BatchBase, dict): + """ + A batch class that represents data as a dictionary. + + Attributes: + None + + """ def as_dict(self) -> Dict: + """ + Convert the dictionary batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + """ return self diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py index 2703efd..ad1fac1 100644 --- a/common/checkpointing/snapshot.py +++ b/common/checkpointing/snapshot.py @@ -13,23 +13,30 @@ GCS_PREFIX = "gs://" class Snapshot: - """Checkpoints using torchsnapshot. - - Also saves step to be updated by the training loop. - """ + Checkpoints using torchsnapshot. Also saves step to be updated by the training loop. + """ def __init__(self, save_dir: str, state: Dict[str, Any]) -> None: + """ + Initializes a Snapshot object. + + Args: + save_dir (str): Directory where checkpoints will be saved. + state (Dict[str, Any]): State dictionary containing checkpoint information. + """ self.save_dir = save_dir self.state = state self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0) @property def step(self): + """Get the current training step.""" return self.state["extra_state"]["step"] @step.setter def step(self, step: int) -> None: + """Set the current training step.""" self.state["extra_state"]["step"] = step @property @@ -41,7 +48,15 @@ class Snapshot: self.state["extra_state"]["walltime"] = walltime def save(self, global_step: int) -> "PendingSnapshot": - """Saves checkpoint with given global_step.""" + """ + Saves a checkpoint with a given global step. + + Args: + global_step (int): The global step to associate with the checkpoint. + + Returns: + PendingSnapshot: A pending snapshot object. + """ path = os.path.join(self.save_dir, str(global_step)) logging.info(f"Saving snapshot global_step {global_step} to {path}.") start_time = time.time() @@ -58,7 +73,12 @@ class Snapshot: return snapshot def restore(self, checkpoint: str) -> None: - """Restores a given checkpoint.""" + """ + Restores a given checkpoint. + + Args: + checkpoint (str): Path to the checkpoint to restore. + """ snapshot = torchsnapshot.Snapshot(path=checkpoint) logging.info(f"Restoring snapshot from {snapshot.path}.") start_time = time.time() @@ -83,12 +103,17 @@ class Snapshot: global_step: Optional[int] = None, missing_ok: bool = False, ) -> torchsnapshot.Snapshot: - """Get torch stateless snapshot, without actually loading it. - Args: - snapshot_path: path to the model snapshot - global_step: restores from this checkpoint if specified. - missing_ok: if True and checkpoints do not exist, returns without restoration. """ + Get a torch stateless snapshot, without actually loading it. + + Args: + snapshot_path (str): Path to the model snapshot. 
+ global_step (int, optional): Restores from this checkpoint if specified. + missing_ok (bool): If True and checkpoints do not exist, returns without restoration. + + Returns: + torchsnapshot.Snapshot: A torch snapshot object. + """ path = get_checkpoint(snapshot_path, global_step, missing_ok) logging.info(f"Loading snapshot from {path}.") return torchsnapshot.Snapshot(path=path) @@ -100,13 +125,14 @@ class Snapshot: snapshot_emb_name: str, weight_tensor, ) -> None: - """Loads pretrained embedding from the snapshot to the model. - Utilise partial lodaing meachanism from torchsnapshot. - Args: - embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC). - snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC. - weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded. """ + Loads pretrained embedding from the snapshot to the model. + + Args: + embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC). + snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC. + weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded. + """ start_time = time.time() manifest = embedding_snapshot.get_manifest() for path in manifest.keys(): @@ -209,7 +235,22 @@ def get_checkpoint( def get_checkpoints(save_dir: str) -> List[str]: - """Gets all checkpoints that have been fully written.""" + """ + Get a list of fully written checkpoints in the specified directory. + + This function retrieves a list of fully written checkpoints in the given directory. + Checkpoints that are considered fully written include those that have a + corresponding snapshot metadata file. + + Args: + save_dir (str): The directory where checkpoints are stored. + + Returns: + List[str]: A list of fully written checkpoint paths. + + Note: + Checkpoints are sorted by their numeric filenames in ascending order. + """ checkpoints = [] fs = infer_fs(save_dir) if fs.exists(save_dir): @@ -232,6 +273,18 @@ def wait_for_evaluators( global_step: int, timeout: int, ) -> None: + """ + Waits for all evaluators to finish and checks for their completion status. + + Args: + save_dir (str): Directory where checkpoints are saved. + partition_names (List[str]): List of partition names to check for completion. + global_step (int): The global step for which to wait for evaluators. + timeout (int): Maximum time in seconds to wait for evaluators to finish. + + Returns: + None: This function returns nothing but logs the progress and results. + """ logging.info("Waiting for all evaluators to finish.") start_time = time.time() diff --git a/common/device.py b/common/device.py index 33d1bb8..b51c638 100644 --- a/common/device.py +++ b/common/device.py @@ -5,6 +5,15 @@ import torch.distributed as dist def maybe_setup_tensorflow(): + """ + Try to import TensorFlow and disable GPU devices if TensorFlow is available. + + This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch. + + Returns: + None + + """ try: import tensorflow as tf except ImportError: @@ -14,6 +23,18 @@ def maybe_setup_tensorflow(): def setup_and_get_device(tf_ok: bool = True) -> torch.device: + """ + Set up the distributed environment and get the appropriate torch device. 
+ + This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank. + + Args: + tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True. + + Returns: + torch.device: The torch device for the current process. + + """ if tf_ok: maybe_setup_tensorflow() diff --git a/common/filesystem/util.py b/common/filesystem/util.py index 5148a0d..30360f2 100644 --- a/common/filesystem/util.py +++ b/common/filesystem/util.py @@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem() def infer_fs(path: str): + """ + Infer the file system (fs) type based on the given path. + + Args: + path (str): The file path. + + Returns: + str: The inferred file system type ("gs://" for Google Cloud Storage, "hdfs://" for Hadoop Distributed File System, or "local" for local file system). + + Raises: + NotImplementedError: If the path indicates Hadoop Distributed File System (HDFS) which is not yet supported. + """ if path.startswith("gs://"): return GCS_FS elif path.startswith("hdfs://"): @@ -18,8 +30,26 @@ def infer_fs(path: str): def is_local_fs(fs): + """ + Check if the given file system is the local file system. + + Args: + fs (str): The file system type to check. + + Returns: + bool: True if the file system is the local file system, False otherwise. + """ return fs == LOCAL_FS def is_gcs_fs(fs): + """ + Check if the given file system is Google Cloud Storage (GCS). + + Args: + fs (str): The file system type to check. + + Returns: + bool: True if the file system is GCS, False otherwise. + """ return fs == GCS_FS diff --git a/common/log_weights.py b/common/log_weights.py index d2c8e21..cecf42b 100644 --- a/common/log_weights.py +++ b/common/log_weights.py @@ -49,15 +49,25 @@ def log_ebc_norms( ebc_keys: List[str], sample_size: int = 4_000_000, ) -> Dict[str, torch.Tensor]: - """Logs the norms of the embedding tables as specified by ebc_keys. - As of now, log average norm per rank. - - Args: - model_state_dict: model.state_dict() - ebc_keys: list of embedding keys from state_dict to log. Must contain full name, - i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight - sample_size: Limits number of rows per rank to compute average on to avoid OOM. """ + Logs the norms of the embedding tables as specified by ebc_keys. + As of now, log average norm per rank. + + Args: + model_state_dict (dict): The state dictionary of the model. + ebc_keys (List[str]): List of embedding keys from state_dict to log. + Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight". + sample_size (int, optional): Limits the number of rows per rank to compute average on + to avoid out-of-memory (OOM) errors. Defaults to 4,000,000. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables. + The keys are in the format "{emb_key}-norm-{idx}". + + Note: + This function computes and logs the average norm of embedding tables across ranks. + It gathers the norms from all ranks and returns them as a dictionary. 
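+
+  Example:
+      Illustrative sketch only; assumes ``torch.distributed`` is already
+      initialized and that ``model`` (hypothetical here) contains the listed
+      embedding key in its ``state_dict()``:
+
+      ```python
+      norms = log_ebc_norms(
+          model_state_dict=model.state_dict(),
+          ebc_keys=["model.embeddings.ebc.embedding_bags.meta__user_id.weight"],
+      )
+      # e.g. {"model.embeddings.ebc.embedding_bags.meta__user_id.weight-norm-0": tensor(...)}
+      ```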
+ """ norm_logs = dict() for emb_key in ebc_keys: norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}")) diff --git a/common/modules/embedding/embedding.py b/common/modules/embedding/embedding.py index b0a085e..579faa5 100644 --- a/common/modules/embedding/embedding.py +++ b/common/modules/embedding/embedding.py @@ -11,6 +11,20 @@ import numpy as np class LargeEmbeddings(nn.Module): + """ + A module for handling large embeddings. + + Args: + large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings. + + Attributes: + ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings. + surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models. + + Note: + The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving. + + """ def __init__( self, large_embeddings_config: LargeEmbeddingsConfig, @@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module): self, sparse_features: KeyedJaggedTensor, ) -> KeyedTensor: + """ + Forward pass of the LargeEmbeddings module. + + Args: + sparse_features (KeyedJaggedTensor): Sparse input features. + + Returns: + KeyedTensor: The output of the forward pass, a KeyedTensor. + + """ pooled_embs = self.ebc(sparse_features) # a KeyedTensor diff --git a/common/run_training.py b/common/run_training.py index fafb140..20a3d34 100644 --- a/common/run_training.py +++ b/common/run_training.py @@ -11,6 +11,12 @@ import torch.distributed.run def is_distributed_worker(): + """ + Checks if the current process is a distributed worker. + + Returns: + bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False. + """ world_size = os.environ.get("WORLD_SIZE", None) rank = os.environ.get("RANK", None) return world_size is not None and rank is not None @@ -25,25 +31,36 @@ def maybe_run_training( is_chief: Optional[bool] = False, **training_kwargs, ): - """Wrapper function for single node, multi-GPU Pytorch training. - - If the necessary distributed Pytorch environment variables - (WORLD_SIZE, RANK) have been set, then this function executes - `train_fn(**training_kwargs)`. - - Otherwise, this function calls torchrun and points at the calling module - `module_name`. After this call, the necessary environment variables are set - and training will commence. - - Args: - train_fn: The function that is responsible for training - module_name: The name of the module that this function was called from; - used to indicate torchrun entrypoint. - nproc_per_node: Number of workers per node; supported values. - num_nodes: Number of nodes, otherwise inferred from environment. - is_chief: If process is running on chief. - set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH. """ + Wrapper function for single node, multi-GPU PyTorch training. + + If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes + `train_fn(**training_kwargs)`. + + Otherwise, this function calls torchrun and points at the calling module + `module_name`. After this call, the necessary environment variables are set + and training will commence. + + Args: + train_fn (callable): The function responsible for training. + module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint. + nproc_per_node (int, optional): Number of workers per node. 
Defaults to None. + num_nodes (int, optional): Number of nodes. Defaults to None. + is_chief (bool, optional): If the process is running on the chief node. Defaults to False. + set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False. + **training_kwargs: Additional keyword arguments to pass to the `train_fn`. + + Note: + This function checks if the current process is a distributed worker by examining the environment variables. + If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary + environment variables and launches the training process using torchrun. + + Example: + To run training on a single node with 4 GPUs, you can use: + ``` + maybe_run_training(train_function, __name__, nproc_per_node=4) + ``` + """ machines = utils.machine_from_env() if num_nodes is None: diff --git a/common/testing_utils.py b/common/testing_utils.py index 69e14de..6feaf6b 100644 --- a/common/testing_utils.py +++ b/common/testing_utils.py @@ -19,6 +19,23 @@ MOCK_ENV = { @contextmanager def mock_pg(): + """ + A context manager for mocking the distributed process group for testing purposes. + + This context manager temporarily sets environment variables to mock the distributed process group + and initializes it using the Gloo backend. It is useful for testing distributed training without + actually launching multiple processes. + + Example: + ```python + with mock_pg(): + # Your distributed training code here + ``` + + Note: + This context manager should be used within a testing environment to simulate distributed training + without actually creating multiple processes. + """ with patch.dict(os.environ, MOCK_ENV): try: dist.init_process_group( diff --git a/common/utils.py b/common/utils.py index 09d4464..e509ffb 100644 --- a/common/utils.py +++ b/common/utils.py @@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig) def _read_file(f): + """ + Read the contents of a file using fsspec. + + Args: + f: File path or URL. + + Returns: + The contents of the file. + """ with fsspec.open(f) as f: return f.read() @@ -21,18 +30,24 @@ def setup_configuration( yaml_path: str, substitute_env_variable: bool = False, ) -> Tuple[C, str]: - """Resolves a config at a yaml path. - - Args: - config_type: Pydantic config class to load. - yaml_path: yaml path of the config file. - substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their - environment variable value whenever possible. If an environment variable doesn't exist, - the string is left unchanged. - - Returns: - The pydantic config object. """ + Load a Pydantic config object from a YAML file and optionally substitute environment variables. + + Args: + config_type: Pydantic config class to load. + yaml_path: YAML path of the config file. + substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR} + with their environment variable values whenever possible. + If an environment variable doesn't exist, the string is left unchanged. + + Returns: + A tuple containing the Pydantic config object and the resolved YAML content. 
+ + Example: + ```python + config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True) + ``` + """ def _substitute(s): if substitute_env_variable: diff --git a/common/wandb.py b/common/wandb.py index eb8cceb..9808db0 100644 --- a/common/wandb.py +++ b/common/wandb.py @@ -6,6 +6,33 @@ import pydantic class WandbConfig(base_config.BaseConfig): + """ + Configuration for integrating with Weights and Biases (WandB). + + Attributes: + host (str): Host of the Weights and Biases instance, passed to login. + key_path (str): Path to the key file. + name (str): Name of the experiment, passed to init. + entity (str): Name of the user/service account, passed to init. + project (str): Name of the WandB project, passed to init. + tags (List[str]): List of tags, passed to init. + notes (str): Notes, passed to init. + metadata (Dict[str, Any]): Additional metadata to log. + + Example: + ```python + wandb_config = WandbConfig( + host="https://wandb.example.com", + key_path="/path/to/key", + name="experiment_1", + entity="user123", + project="my_project", + tags=["experiment", "ml"], + notes="This is a test experiment.", + metadata={"version": "1.0"} + ) + ``` + """ host: str = pydantic.Field( "https://https--wandb--prod--wandb.service.qus1.twitter.biz/", description="Host of Weights and Biases instance, passed to login.", diff --git a/core/config/base_config.py b/core/config/base_config.py index 1c15d04..76a2022 100644 --- a/core/config/base_config.py +++ b/core/config/base_config.py @@ -8,59 +8,69 @@ import pydantic class BaseConfig(pydantic.BaseModel): - """Base class for all derived config classes. + """Base class for all derived config classes. - This class provides some convenient functionality: - - Disallows extra fields when constructing an object. User error - should be reduced by exact arguments. - - "one_of" fields. A subclass can group optional fields and enforce - that only one of the fields be set. For example: + This class provides convenient functionality and constraints for derived config classes: - ``` + - Disallows extra fields when constructing an object. User errors due to extraneous arguments + are minimized. + - "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields + be set. For example: + + ```python class ExampleConfig(BaseConfig): - x: int = Field(None, one_of="group_1") - y: int = Field(None, one_of="group_1") + x: int = Field(None, one_of="group_1") + y: int = Field(None, one_of="group_1") - ExampleConfig(x=1) # ok - ExampleConfig(y=1) # ok - ExampleConfig(x=1, y=1) # throws error + ExampleConfig(x=1) # OK + ExampleConfig(y=1) # OK + ExampleConfig(x=1, y=1) # Raises an error ``` - """ - class Config: - """Forbids extras.""" + Attributes: + Config (class): Configuration options for this class, forbidding extra fields. - extra = pydantic.Extra.forbid # noqa + Methods: + _field_data_map(cls, field_data_name): Create a map of fields with the provided field data. + _one_of_check(cls, values): Validate that all 'one of' fields appear exactly once. + _at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once. + pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging. 
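+
+  Example ("at_most_one_of" grouping; the class and field names below are
+  hypothetical, mirroring the "one_of" example above — both fields may be
+  left unset, but not set together):
+
+  ```python
+  class AnotherConfig(BaseConfig):
+      a: int = Field(None, at_most_one_of="group_2")
+      b: str = Field(None, at_most_one_of="group_2")
+
+  AnotherConfig()            # OK
+  AnotherConfig(a=1)         # OK
+  AnotherConfig(a=1, b="x")  # Raises an error
+  ```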
- @classmethod - @functools.lru_cache() - def _field_data_map(cls, field_data_name): - """Create a map of fields with provided the field data.""" - schema = cls.schema() - one_of = collections.defaultdict(list) - for field, fdata in schema["properties"].items(): - if field_data_name in fdata: - one_of[fdata[field_data_name]].append(field) - return one_of + """ + class Config: + """Configuration options that forbid extra fields.""" + extra = pydantic.Extra.forbid # noqa - @pydantic.root_validator - def _one_of_check(cls, values): - """Validate that all 'one of' fields are appear exactly once.""" - one_of_map = cls._field_data_map("one_of") - for one_of, field_names in one_of_map.items(): - if sum([values.get(n, None) is not None for n in field_names]) != 1: - raise ValueError(f"Exactly one of {','.join(field_names)} required.") - return values + @classmethod + @functools.lru_cache() + def _field_data_map(cls, field_data_name): + """Create a map of fields with the provided field data.""" + schema = cls.schema() + one_of = collections.defaultdict(list) + for field, fdata in schema["properties"].items(): + if field_data_name in fdata: + one_of[fdata[field_data_name]].append(field) + return one_of - @pydantic.root_validator - def _at_most_one_of_check(cls, values): - """Validate that all 'at_most_one_of' fields appear at most once.""" - at_most_one_of_map = cls._field_data_map("at_most_one_of") - for one_of, field_names in at_most_one_of_map.items(): - if sum([values.get(n, None) is not None for n in field_names]) > 1: - raise ValueError(f"At most one of {','.join(field_names)} can be set.") - return values + @pydantic.root_validator + def _one_of_check(cls, values): + """Validate that all 'one of' fields appear exactly once.""" + one_of_map = cls._field_data_map("one_of") + for one_of, field_names in one_of_map.items(): + if sum([values.get(n, None) is not None for n in field_names]) != 1: + raise ValueError(f"Exactly one of {','.join(field_names)} required.") + return values - def pretty_print(self) -> str: - """Return a human legible (yaml) representation of the config useful for logging.""" - return yaml.dump(self.dict()) + @pydantic.root_validator + def _at_most_one_of_check(cls, values): + """Validate that all 'at_most_one_of' fields appear at most once.""" + at_most_one_of_map = cls._field_data_map("at_most_one_of") + for one_of, field_names in at_most_one_of_map.items(): + if sum([values.get(n, None) is not None for n in field_names]) > 1: + raise ValueError(f"At most one of {','.join(field_names)} can be set.") + + return values + + def pretty_print(self) -> str: + """Return a human-readable (YAML) representation of the config useful for logging.""" + return yaml.dump(self.dict()) diff --git a/core/config/base_config_test.py b/core/config/base_config_test.py index e623f05..c4319e5 100644 --- a/core/config/base_config_test.py +++ b/core/config/base_config_test.py @@ -6,33 +6,67 @@ import pydantic class BaseConfigTest(TestCase): - def test_extra_forbidden(self): - class Config(BaseConfig): - x: int + """ + Unit tests for the BaseConfig class. + """ - Config(x=1) - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y=2) + def test_extra_forbidden(self): + """ + Test that extra fields are forbidden when creating a Config instance. 
- def test_one_of(self): - class Config(BaseConfig): - x: int = pydantic.Field(None, one_of="f") - y: int = pydantic.Field(None, one_of="f") + This test case checks whether the `BaseConfig` class correctly raises a + `pydantic.ValidationError` when extra fields are provided when creating a + `Config` instance. - with self.assertRaises(pydantic.ValidationError): - Config() - Config(x=1) - Config(y=1) - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y=3) + Raises: + AssertionError: If the test fails. + """ + class Config(BaseConfig): + x: int - def test_at_most_one_of(self): - class Config(BaseConfig): - x: int = pydantic.Field(None, at_most_one_of="f") - y: str = pydantic.Field(None, at_most_one_of="f") + Config(x=1) + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y=2) - Config() - Config(x=1) - Config(y="a") - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y="a") + def test_one_of(self): + """ + Test the use of the `one_of` attribute for fields in a Config instance. + + This test case checks the behavior of the `one_of` attribute in a `Config` + instance. It verifies that the `pydantic.Field` correctly enforces the + specified constraint. + + Raises: + AssertionError: If the test fails. + """ + class Config(BaseConfig): + x: int = pydantic.Field(None, one_of="f") + y: int = pydantic.Field(None, one_of="f") + + with self.assertRaises(pydantic.ValidationError): + Config() + Config(x=1) + Config(y=1) + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y=3) + + def test_at_most_one_of(self): + """ + Test the use of the `at_most_one_of` attribute for fields in a Config instance. + + This test case checks the behavior of the `at_most_one_of` attribute in a + `Config` instance. It verifies that the `pydantic.Field` enforces the + constraint where at most one of the specified fields can be provided. + + Raises: + AssertionError: If the test fails. + """ + class Config(BaseConfig): + x: int = pydantic.Field(None, at_most_one_of="f") + y: str = pydantic.Field(None, at_most_one_of="f") + + Config() + Config(x=1) + Config(y="a") + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y="a") diff --git a/core/config/config_load.py b/core/config/config_load.py index 709da41..f04c783 100644 --- a/core/config/config_load.py +++ b/core/config/config_load.py @@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str): - """Recommend method to load a config file (a yaml file) and parse it. - - Because we have a shared filesystem the recommended route to running jobs it put modified config - files with the desired parameters somewhere on the filesytem and run jobs pointing to them. """ + Recommend method to Load and parse a configuration from a YAML file. + + This function loads a configuration from a YAML file, parses it, and returns an instance of the + specified config type. + + Because we have a shared filesystem the recommended route to running jobs it put modified config + files with the desired parameters somewhere on the filesytem and run jobs pointing to them. + + Args: + config_type (Type[BaseConfig]): The Pydantic config class to load. + yaml_path (str): The path to the YAML configuration file. + + Returns: + BaseConfig: An instance of the specified config type populated with values from the YAML file. 
+ + Example: + Suppose you have a YAML file 'my_config.yaml' containing the following: + + ```yaml + x: 42 + y: "hello" + ``` + + You can load and parse it using this function as follows: + + ```python + my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml') + ``` + + Note: + This function performs environment variable substitution in the YAML file. It replaces + occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable + values. If an environment variable does not exist, the string is left unchanged. + + """ def _substitute(s): return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) diff --git a/core/config/test_config_load.py b/core/config/test_config_load.py index 0fc88e2..1ad3942 100644 --- a/core/config/test_config_load.py +++ b/core/config/test_config_load.py @@ -8,11 +8,27 @@ import pydantic class _PointlessConfig(BaseConfig): + a: int user: str def test_load_config_from_yaml(tmp_path): + """Test loading a configuration from a YAML file and verifying its values. + + This test function checks the functionality of the `load_config_from_yaml` function by creating + a temporary YAML configuration file, loading it, and asserting that the loaded config object + has the expected values. + + Args: + tmp_path: A temporary directory provided by the `pytest` framework. + + Test Steps: + 1. Create a temporary YAML file containing configuration data. + 2. Use the `load_config_from_yaml` function to load the configuration from the YAML file. + 3. Assert that the loaded configuration object has the expected values. + + """ yaml_path = tmp_path.joinpath("test.yaml").as_posix() with open(yaml_path, "w") as yaml_file: yaml_file.write("""a: 3\nuser: ${USER}\n""") diff --git a/core/config/training.py b/core/config/training.py index a0c54b4..acdbde7 100644 --- a/core/config/training.py +++ b/core/config/training.py @@ -9,6 +9,14 @@ import pydantic class RuntimeConfig(base_config.BaseConfig): + """ + Configuration for runtime settings. + + Attributes: + - wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration. + - enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices. + - enable_amp (bool): Enable automatic mixed precision. + """ wandb: WandbConfig = pydantic.Field(None) enable_tensorfloat32: bool = pydantic.Field( False, description="Use tensorfloat32 if on Ampere devices." @@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig): class TrainingConfig(base_config.BaseConfig): + """ + Configuration for training settings. + + Attributes: + - save_dir (str): Directory to save checkpoints. + - num_train_steps (pydantic.PositiveInt): Number of training steps. + - initial_checkpoint_dir (str): Directory of initial checkpoints (optional). + - checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps. + - checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional). + - train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps. + - num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used. + - eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps. + - eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds. + - gradient_accumulation (int): Number of replica steps to accumulate gradients (optional). + - num_epochs (pydantic.PositiveInt): Number of training epochs. 
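+
+  Example (illustrative; values other than the documented defaults are
+  hypothetical, and omitted fields are assumed to keep their defaults):
+
+  ```python
+  training_config = TrainingConfig(
+      save_dir="/tmp/model",
+      num_train_steps=10000,
+      checkpoint_every_n=1000,
+      train_log_every_n=100,
+  )
+  ```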
+ """ save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.") num_train_steps: pydantic.PositiveInt = 10000 initial_checkpoint_dir: str = pydantic.Field( diff --git a/core/custom_training_loop.py b/core/custom_training_loop.py index 0241145..73b2cf1 100644 --- a/core/custom_training_loop.py +++ b/core/custom_training_loop.py @@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable): def _get_step_fn(pipeline, data_iterator, training: bool): + """ + Returns a function to perform a single evaluation step. + + Args: + pipeline (Pipeline): The pipeline object containing the model. + data_iterator (Iterator): The data iterator for evaluation. + training (bool): Flag indicating if the model should be in training mode. + + Returns: + function: A function that performs a single evaluation step. + """ def step_fn(): + """ + Perform a single evaluation step. + + Returns: + Any: The evaluation results after a single step. + """ # It turns out that model.train() and model.eval() simply switch a single field inside the model # class,so it's somewhat safer to wrap in here. if training: @@ -69,7 +86,21 @@ def _run_evaluation( eval_batch_size: int, logger=None, ): - """Runs the evaluation loop over all evaluation iterators.""" + """ + Run the evaluation loop over all evaluation iterators. + + Args: + pipeline (Pipeline): The pipeline object containing the model. + dataset (Dataset): The dataset to evaluate. + eval_steps (int): The number of evaluation steps to perform. + metrics (tm.MetricCollection): A collection of evaluation metrics. + eval_batch_size (int): Batch size for evaluation. + logger (Optional[Logger]): A logger for recording evaluation progress (default: None). + + Returns: + dict: A dictionary containing the computed evaluation metrics. + """ + dataset = get_new_iterator(dataset) step_fn = _get_step_fn(pipeline, dataset, training=False) last_time = datetime.datetime.now() @@ -109,15 +140,29 @@ def train( parameters_to_log: Optional[Dict[str, Callable]] = None, tables_to_log: Optional[List[str]] = None, ) -> None: - """Runs training and eval on the given TrainPipeline - - Args: - dataset: data iterator for the training set - evaluation_iterators: data iterators for the different evaluation sets - scheduler: optional learning rate scheduler - output_transform_for_metrics: optional transformation functions to transorm the model - output and labels into a format the metrics can understand """ + Runs training and evaluation on the given TrainPipeline. + + Args: + model (torch.nn.Module): The neural network model to train. + optimizer (torch.optim.Optimizer): The optimizer for model optimization. + device (str): The target device for model training (e.g., 'cuda' or 'cpu'). + save_dir (str): The directory to save model checkpoints and logs. + logging_interval (int): Interval for logging training progress. + train_steps (int): The number of training steps to perform. + checkpoint_frequency (int): Frequency of saving model checkpoints. + dataset (Iterable): Data iterator for the training set. + worker_batch_size (int): Batch size for data loading workers. + num_workers (Optional[int]): Number of data loading workers (default: 0). + enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False). + initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None). + gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None). 
+ logger_initializer (Optional[Callable]): A logger initializer function (default: None). + scheduler (_LRScheduler): Optional learning rate scheduler (default: None). + metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None). + parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None). + tables_to_log (Optional[List[str]]): List of tables to log (default: None). + """ train_pipeline = TrainPipelineSparseDist( model=model, @@ -262,6 +307,15 @@ def log_eval_results( partition_name: str, step: int, ): + """ + Logs evaluation results and optionally records them using a provided logger. + + Args: + results (Any): The evaluation results to log. + eval_logger (Callable): A logger for recording evaluation results. + partition_name (str): The name of the evaluation partition. + step (int): The current step in the evaluation. + """ results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results) logging.info(f"Step: {step}, evaluation ({partition_name}).") for metric_name, metric_value in results.items(): @@ -285,6 +339,23 @@ def only_evaluate( partition_name: str, metrics: Optional[tm.MetricCollection] = None, ): + """ + Performs evaluation on a given dataset partition. + + Args: + model (torch.nn.Module): The neural network model for evaluation. + optimizer (torch.optim.Optimizer): The optimizer used during evaluation. + device (str): The target device for evaluation (e.g., 'cuda' or 'cpu'). + save_dir (str): The directory containing model checkpoints. + num_train_steps (int): The total number of training steps. + dataset (Iterable): Data iterator for evaluation. + eval_batch_size (int): Batch size for evaluation. + num_eval_steps (int): The number of evaluation steps to perform. + eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds. + eval_logger (Callable): A logger for recording evaluation results. + partition_name (str): The name of the evaluation partition. + metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None). + """ logging.info(f"Evaluating on partition {partition_name}.") logging.info("Computing metrics:") logging.info(metrics) diff --git a/core/debug_training_loop.py b/core/debug_training_loop.py index 610eea9..bced83f 100644 --- a/core/debug_training_loop.py +++ b/core/debug_training_loop.py @@ -28,6 +28,18 @@ def train( *args, **kwargs, ) -> None: + """ + Debugging training loop. Do not use for actual model training. + + Args: + model (torch.nn.Module): The neural network model. + optimizer (torch.optim.Optimizer): The optimizer for model optimization. + train_steps (int): The number of training steps to perform. + dataset (Iterable): Data iterator for training data. + scheduler (_LRScheduler, optional): Learning rate scheduler (default: None). + *args: Additional arguments (ignored). + **kwargs: Additional keyword arguments (ignored). + """ logging.warning("Running debug training loop, don't use for model training.") diff --git a/core/losses.py b/core/losses.py index 6ef9a4a..7cf0cf1 100644 --- a/core/losses.py +++ b/core/losses.py @@ -10,8 +10,11 @@ import torch def _maybe_warn(reduction: str): """ - Warning for reduction different than mean. - """ + Emit a warning if the reduction method is different from 'mean'. + + Args: + reduction (str): The reduction method being used. 
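+
+  Example (illustrative; the function only logs, it does not raise):
+
+  ```python
+  _maybe_warn("sum")   # emits a warning about gradient behavior under DDP
+  _maybe_warn("mean")  # no warning
+  ```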
+ """ if reduction != "mean": logging.warn( f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal," @@ -24,6 +27,16 @@ def build_loss( loss_type: LossType, reduction="mean", ): + """ + Build a loss function based on the specified loss type and reduction method. + + Args: + loss_type (LossType): The type of loss to build. + reduction (str): The reduction method for the loss (default: 'mean'). + + Returns: + Callable: A loss function that takes logits and labels as input. + """ _maybe_warn(reduction) f = _LOSS_TYPE_TO_FUNCTION[loss_type] @@ -35,11 +48,15 @@ def build_loss( def get_global_loss_detached(local_loss, reduction="mean"): """ - Perform all_reduce to obtain the global loss function using the provided reduction. - :param local_loss: The local loss of the current rank. - :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP. - :return: The reduced & detached global loss. - """ + Perform all_reduce to obtain the global loss function using the provided reduction. + + Args: + local_loss (torch.Tensor): The local loss of the current rank. + reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP. + + Returns: + torch.Tensor: The reduced and detached global loss. + """ if reduction != "mean": logging.warn( f"The reduction used in this function should be the same as the one used by " @@ -66,6 +83,19 @@ def build_multi_task_loss( global_reduction="mean", pos_weights=None, ): + """ + Build a multi-task loss function based on the specified loss type and configurations. + + Args: + loss_type (LossType): The type of loss to build. + tasks (typing.List[str]): List of task names. + task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean'). + global_reduction (str): Reduction method for the global loss (default: 'mean'). + pos_weights (Optional): Positive class weights for tasks (default: None). + + Returns: + Callable: A multi-task loss function that takes logits, labels, and weights as input. + """ _maybe_warn(global_reduction) _maybe_warn(task_loss_reduction) f = _LOSS_TYPE_TO_FUNCTION[loss_type] diff --git a/core/metric_mixin.py b/core/metric_mixin.py index a716ca7..def38cb 100644 --- a/core/metric_mixin.py +++ b/core/metric_mixin.py @@ -36,9 +36,24 @@ import torchmetrics class MetricMixin: @abstractmethod def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: + """ + Abstract method to transform model outputs into a dictionary of metrics. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + + Returns: + Dict: A dictionary of computed metrics. + """ ... def update(self, outputs: Dict[str, torch.Tensor]): + """ + Update the metrics based on model outputs. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + """ results = self.transform(outputs) # Do not try to update if any tensor is empty as a result of stratification. for value in results.values(): @@ -49,6 +64,13 @@ class MetricMixin: class TaskMixin: def __init__(self, task_idx: int = -1, **kwargs): + """ + Initialize a TaskMixin instance. + + Args: + task_idx (int): Index of the task associated with this mixin (default: -1). + **kwargs: Additional keyword arguments. + """ super().__init__(**kwargs) self._task_idx = task_idx @@ -59,13 +81,31 @@ class StratifyMixin: stratifier=None, **kwargs, ): + """ + Initialize a StratifyMixin instance. + + Args: + stratifier: A stratifier for filtering outputs (default: None). + **kwargs: Additional keyword arguments. 
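+
+    Example (illustrative; ``Ctr`` is one of the metrics built on this mixin,
+    and ``my_stratifier`` stands for any object carrying the stratifier
+    feature name and indicator value):
+
+    ```python
+    plain_ctr = Ctr()                               # no filtering
+    stratified_ctr = Ctr(stratifier=my_stratifier)  # keeps only matching examples
+    ```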
+ """ super().__init__(**kwargs) self._stratifier = stratifier def maybe_apply_stratification( self, outputs: Dict[str, torch.Tensor], value_names: List[str] ) -> Dict[str, torch.Tensor]: - """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" + """ + Apply stratification to filter examples in the outputs. + + Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + value_names (List[str]): Names of values to filter. + + Returns: + Dict[str, torch.Tensor]: Filtered outputs. + """ outputs = outputs.copy() if not self._stratifier: return outputs @@ -84,12 +124,20 @@ class StratifyMixin: def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable): - """Returns new class using MetricMixin and given base_metric. - - Functionally the same using inheritance, just saves some lines of code - if no need for class attributes. - """ + Returns a new class using MetricMixin and the given base_metric. + + Functionally the same as using inheritance, but it saves some lines of code + if there's no need for class attributes. + + Args: + base_metric (torchmetrics.Metric): The base metric class to prepend the transform to. + transform (Callable): The transformation function to prepend to the metric. + + Returns: + Type: A new class that includes MetricMixin and the provided base_metric + with the specified transformation method. + """ def transform_method(_self, *args, **kwargs): return transform(*args, **kwargs) diff --git a/core/metrics.py b/core/metrics.py index 2384e4d..0a9c38b 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -15,6 +15,16 @@ def probs_and_labels( outputs: Dict[str, torch.Tensor], task_idx: int, ) -> Dict[str, torch.Tensor]: + """ + Extract probabilities and labels from model outputs. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + task_idx (int): Index of the task. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors. + """ preds = outputs["probabilities"] target = outputs["labels"] if task_idx >= 0: @@ -28,6 +38,11 @@ def probs_and_labels( class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): def transform(self, outputs): + """ + Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric. + + This metric counts values after potential stratification and task selection. + """ outputs = self.maybe_apply_stratification(outputs, ["labels"]) value = outputs["labels"] if self._task_idx >= 0: @@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): + """ + Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric. + + This metric calculates the mean metric value after potential stratification and task selection. + """ + def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["labels"]) value = outputs["labels"] @@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): + """ + Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric. + + This metric calculates the mean metric value using probabilities after potential stratification and task selection. 
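+
+  Example (illustrative; with the default task index and no stratifier the
+  metric reduces to a plain mean over the predicted probabilities):
+
+  ```python
+  pctr = Pctr()
+  pctr.update({"probabilities": torch.tensor([0.2, 0.8, 0.5])})
+  pctr.compute()  # tensor(0.5)
+  ```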
+ """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities"]) value = outputs["probabilities"] @@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): + """ + Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision. + + This metric computes precision after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) return probs_and_labels(outputs, self._task_idx) class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): + """ + Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall. + + This metric computes recall after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) return probs_and_labels(outputs, self._task_idx) @@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC): class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ + AUC (Area Under the ROC Curve) metric class. + + This metric computes the AUC metric based on the logits and labels in the model outputs. + + Args: + num_samples (int): The number of samples used to compute AUC. + **kwargs: Additional keyword arguments. + Based on: https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420 """ @@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The ranks of all positives - Based on: + PosRanks metric class. + + This metric computes the ranks of all positive examples based on the logits and labels + in the model outputs. + + Args: + **kwargs: Additional keyword arguments. + https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73 """ @@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The reciprocal of the ranks of all - Based on: + ReciprocalRank metric class. + + This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels + in the model outputs. + + Args: + **kwargs: Additional keyword arguments. https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74 """ @@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The fraction of positives that rank in the top K among their negatives - Note that this is basically precision@k - Based on: + HitAtK metric class. + + This metric computes the fraction of positive examples that rank in the top K among their negatives, + which is equivalent to precision@K. + + Args: + k (int): The value of K. + **kwargs: Additional keyword arguments. 
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75 """ diff --git a/core/test_metrics.py b/core/test_metrics.py index ac29819..c9a34e8 100644 --- a/core/test_metrics.py +++ b/core/test_metrics.py @@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric @dataclass class MockStratifierConfig: + """ + Configuration dataclass for mocking a stratifier. + + Args: + name (str): The name of the stratifier. + index (int): The index of the stratifier. + value (int): The value of the stratifier. + """ name: str index: int value: int class Count(MetricMixin, SumMetric): + """ + Count metric class that inherits from MetricMixin and SumMetric. + + This metric counts occurrences. + + """ def transform(self, outputs): return {"value": 1} @@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]}) def test_count_metric(): + """ + Test function for the Count metric. + + It checks if the Count metric correctly counts the number of examples. + + """ num_examples = 123 examples = [ {"stuff": 0}, @@ -36,6 +56,12 @@ def test_count_metric(): def test_collections(): + """ + Test function for metric collections. + + It tests if metric collections correctly aggregate metrics. + + """ max_metric = Max() count_metric = Count() metric = MetricCollection([max_metric, count_metric]) @@ -51,6 +77,12 @@ def test_collections(): def test_task_dependent_ctr(): + """ + Test function for task-dependent Ctr (Click-Through Rate) metric. + + It checks if the Ctr metric computes the correct value for different tasks. + + """ num_examples = 144 batch_size = 1024 outputs = [ @@ -69,6 +101,13 @@ def test_task_dependent_ctr(): def test_stratified_ctr(): + """ + Test function for the Stratified Ctr (Click-Through Rate) metric. + + It checks if the Stratified Ctr metric computes the correct value for different tasks + and stratified samples. + + """ outputs = [ { "stuff": 0, @@ -114,6 +153,12 @@ def test_stratified_ctr(): def test_auc(): + """ + Test function for the AUC (Area Under the Curve) metric. + + It checks if the AUC metric correctly computes the Area Under the ROC Curve. + + """ num_samples = 10000 metric = core_metrics.Auc(num_samples) target = torch.tensor([0, 0, 1, 1, 1]) @@ -131,6 +176,12 @@ def test_auc(): def test_pos_rank(): + """ + Test function for the PosRanks metric. + + It checks if the PosRanks metric correctly computes the ranks of positive samples. + + """ metric = core_metrics.PosRanks() target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) @@ -147,6 +198,12 @@ def test_pos_rank(): def test_reciprocal_rank(): + """ + Test function for the Reciprocal Rank metric. + + It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks. + + """ metric = core_metrics.ReciprocalRank() target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) @@ -163,6 +220,12 @@ def test_reciprocal_rank(): def test_hit_k(): + """ + Test function for the Hit@K metric. + + It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives. 
+ + """ hit1_metric = core_metrics.HitAtK(1) target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5]) diff --git a/core/test_train_pipeline.py b/core/test_train_pipeline.py index 8e2f6f5..ee9f401 100644 --- a/core/test_train_pipeline.py +++ b/core/test_train_pipeline.py @@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel @dataclass class MockDataclassBatch(DataclassBatch): + """ + Mock data class batch for testing purposes. + + This class represents a batch of data with continuous features and labels. + + Attributes: + continuous_features (torch.Tensor): Tensor containing continuous feature data. + labels (torch.Tensor): Tensor containing label data. + """ continuous_features: torch.Tensor labels: torch.Tensor class MockModule(torch.nn.Module): + """ + Mock PyTorch module for testing purposes. + + This module defines a simple neural network model with a linear layer + followed by a BCEWithLogitsLoss loss function. + + Attributes: + model (torch.nn.Linear): The linear model layer. + loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function. + """ def __init__(self) -> None: super().__init__() self.model = torch.nn.Linear(10, 1) self.loss_fn = torch.nn.BCEWithLogitsLoss() def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the mock module. + + Args: + batch (MockDataclassBatch): Input data batch with continuous features and labels. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions. + """ pred = self.model(batch.continuous_features) loss = self.loss_fn(pred, batch.labels) return (loss, pred) def create_batch(bsz: int): + """ + Create a mock data batch with random continuous features and labels. + + Args: + bsz (int): Batch size. + + Returns: + MockDataclassBatch: A batch of data with continuous features and labels. + """ return MockDataclassBatch( continuous_features=torch.rand(bsz, 10).float(), labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(), @@ -35,6 +72,13 @@ def create_batch(bsz: int): def test_sparse_pipeline(): + """ + Test function for the sparse pipeline with distributed model parallelism. + + This function tests the behavior of the sparse training pipeline using + a mock module and data. + """ + device = torch.device("cpu") model = MockModule().to(device) @@ -65,6 +109,15 @@ def test_sparse_pipeline(): def test_amp(): + """ + Test automatic mixed-precision (AMP) training with the sparse pipeline. + + This function tests the behavior of the sparse training pipeline with + automatic mixed-precision (AMP) enabled, using a mock module and data. + + AMP allows for faster training by using lower-precision data types, such as + torch.bfloat16, while maintaining model accuracy. + """ device = torch.device("cpu") model = MockModule().to(device) diff --git a/core/train_pipeline.py b/core/train_pipeline.py index cde587e..3209988 100644 --- a/core/train_pipeline.py +++ b/core/train_pipeline.py @@ -39,12 +39,42 @@ Out = TypeVar("Out") class TrainPipeline(abc.ABC, Generic[In, Out]): + """ + Abstract base class for training pipelines. + + Attributes: + In (TypeVar): Input data type. + Out (TypeVar): Output data type. + + Methods: + progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline. + """ @abc.abstractmethod def progress(self, dataloader_iter: Iterator[In]) -> Out: + """ + Make progress in the training pipeline. 
+ + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + + Returns: + Out: The output data. + """ pass def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: + """ + Move a batch of data to a specified device. + + Args: + batch (In): The input batch. + device (torch.device): The target device. + non_blocking (bool): If True, move the data asynchronously. + + Returns: + In: The batch of data on the target device. + """ assert isinstance( batch, (torch.Tensor, Pipelineable) ), f"{type(batch)} must implement Pipelineable interface" @@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: + """ + Wait for a batch of data on a specified stream. + + Args: + batch (In): The input batch. + stream (Optional[Stream]): The CUDA stream to wait for. + + Note: + This function is used for managing asynchronous CUDA operations. + """ if stream is None: return torch.cuda.current_stream().wait_stream(stream) @@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N class TrainPipelineBase(TrainPipeline[In, Out]): """ - This class runs training iterations using a pipeline of two stages, each as a CUDA - stream, namely, the current (default) stream and `self._memcpy_stream`. For each - iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU - memory, and the default stream runs forward, backward, and optimization. - """ + This class runs training iterations using a pipeline of two stages, each as a CUDA + stream, namely, the current (default) stream and `self._memcpy_stream`. For each + iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU + memory, and the default stream runs forward, backward, and optimization. + + Attributes: + In (TypeVar): Input data type. + Out (TypeVar): Output data type. + + Methods: + __init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None: + Initialize the TrainPipelineBase. + + _connect(dataloader_iter: Iterator[In]) -> None: + Establish a connection to the data loader and move the input data to the GPU. + + progress(dataloader_iter: Iterator[In]) -> Out: + Execute a training iteration, including forward and backward passes. + + """ def __init__( self, @@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]): optimizer: torch.optim.Optimizer, device: torch.device, ) -> None: + """ + Initialize the TrainPipelineBase. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + optimizer (torch.optim.Optimizer): The optimizer used for training. + device (torch.device): The target device for training (CPU or GPU). + """ self._model = model self._optimizer = optimizer self._device = device @@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]): self._connected = False def _connect(self, dataloader_iter: Iterator[In]) -> None: + """ + Establish a connection to the data loader and move the input data to the GPU. + + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + """ cur_batch = next(dataloader_iter) self._cur_batch = cur_batch with torch.cuda.stream(self._memcpy_stream): @@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]): self._connected = True def progress(self, dataloader_iter: Iterator[In]) -> Out: + """ + Execute a training iteration, including forward and backward passes. 
+ + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + + Returns: + Out: The output data. + """ if not self._connected: self._connect(dataloader_iter) @@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]): class Tracer(torch.fx.Tracer): + """ + Custom tracer class for PyTorch models. + + This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings. + + Attributes: + proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing. + _leaf_modules (List[str]): List of qualified names of leaf modules. + """ + # Disable proxying buffers during tracing. Ideally, proxying buffers would # be disabled, but some models are currently mutating buffer values, which # causes errors during tracing. If those models can be rewritten to not do @@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer): proxy_buffer_attributes = False def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: + """ + Initialize the Tracer. + + Args: + leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing. + """ super().__init__() self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + Check if a module is a leaf module during tracing. + + Args: + m (torch.nn.Module): The PyTorch module. + module_qualified_name (str): The qualified name of the module. + + Returns: + bool: True if the module is considered a leaf module, False otherwise. + """ if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules: return True return super().is_leaf_module(m, module_qualified_name) @@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer): @dataclass class TrainPipelineContext: + """ + Dataclass to store information related to the training pipeline context. + + Attributes: + input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests. + module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts. + feature_processor_forwards (List[Any]): A list of feature processor forwards. + """ + # pyre-ignore [4] input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) @@ -166,6 +279,14 @@ class TrainPipelineContext: @dataclass class ArgInfo: + """ + Dataclass to store information about arguments in the training pipeline. + + Attributes: + input_attrs (List[str]): List of attribute names of the input batch. + is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem. + name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments. + """ # attributes of input batch, e.g. batch.attr1.attr2 call # will produce ["attr1", "attr2"] input_attrs: List[str] @@ -177,6 +298,16 @@ class ArgInfo: class PipelinedForward: + """ + Represents a pipelined forward pass operation. + + Attributes: + name (str): The name of the forward pass. + args (List[ArgInfo]): List of argument information for the forward pass. + module (ShardedModule): The sharded module associated with the forward pass. + context (TrainPipelineContext): The training pipeline context. + dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing. 
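+
+  Example (conceptual sketch of the call flow, not the actual body of
+  `__call__`; `name`, `module`, and `context` refer to the attributes above):
+
+  ```python
+  # the pipeline already kicked off module.input_dist() for this batch, so the
+  # patched forward only waits on that request and runs the rest of the module
+  request = context.input_dist_requests[name]
+  dist_input = request.wait()
+  output = module.compute_and_output_dist(context.module_contexts[name], dist_input)
+  ```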
+ """ def __init__( self, name: str, @@ -185,6 +316,16 @@ class PipelinedForward: context: TrainPipelineContext, dist_stream: Optional[torch.cuda.streams.Stream], ) -> None: + """ + Initialize a PipelinedForward instance. + + Args: + name (str): The name of the forward pass. + args (List[ArgInfo]): List of argument information for the forward pass. + module (ShardedModule): The sharded module associated with the forward pass. + context (TrainPipelineContext): The training pipeline context. + dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing. + """ self._name = name self._args = args self._module = module @@ -193,6 +334,16 @@ class PipelinedForward: # pyre-ignore [2, 24] def __call__(self, *input, **kwargs) -> Awaitable: + """ + Perform the pipelined forward pass operation. + + Args: + *input: Variable-length positional arguments. + **kwargs: Variable-length keyword arguments. + + Returns: + Awaitable: An awaitable object representing the forward pass result. + """ assert self._name in self._context.input_dist_requests request = self._context.input_dist_requests[self._name] assert isinstance(request, Awaitable) @@ -230,10 +381,22 @@ class PipelinedForward: @property def name(self) -> str: + """ + Get the name of the forward pass. + + Returns: + str: The name of the forward pass. + """ return self._name @property def args(self) -> List[ArgInfo]: + """ + Get the list of argument information for the forward pass. + + Returns: + List[ArgInfo]: List of argument information. + """ return self._args @@ -242,6 +405,17 @@ def _start_data_dist( batch: In, context: TrainPipelineContext, ) -> None: + """ + Start data distribution for a list of pipelined modules. + + Args: + pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules. + batch (In): The input batch. + context (TrainPipelineContext): The training pipeline context. + + Returns: + None: This function doesn't return a value. + """ context.input_dist_requests.clear() context.module_contexts.clear() for module in pipelined_modules: @@ -286,9 +460,17 @@ def _get_node_args_helper( feature_processor_arguments: Optional[List[Node]] = None, ) -> Tuple[List[ArgInfo], int]: """ - Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. - It also counts the number of (args + kwargs) found. - """ + Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. + It also counts the number of (args + kwargs) found. + + Args: + arguments: The arguments to process. + num_found: The current count of arguments found. + feature_processor_arguments: Optional list of feature processor arguments. + + Returns: + Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found. + """ arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))] for arg, arg_info in zip(arguments, arg_info_list): @@ -332,6 +514,16 @@ def _get_node_args_helper( def _get_node_args( node: Node, feature_processor_nodes: Optional[List[Node]] = None ) -> Tuple[List[ArgInfo], int]: + """ + Get argument information for a given node. + + Args: + node (Node): The node to process. + feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes. + + Returns: + Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found. 
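+
+  Example (illustrative sketch; the input expression is made up, only the
+  `ArgInfo` field names are real):
+
+  ```python
+  # an argument built from `batch.features["user_id"]` is summarized roughly as
+  ArgInfo(
+      input_attrs=["features", "user_id"],
+      is_getitems=[False, True],
+      name=None,  # positional argument
+  )
+  ```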
+ """ num_found = 0 pos_arg_info_list, num_found = _get_node_args_helper( node.args, num_found, feature_processor_nodes @@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper( path: str, unsharded_module_names: Set[str], ) -> bool: + """ + Get the names of unsharded modules in a model. + + Args: + model (torch.nn.Module): The model to analyze. + path (str): The current path in the model hierarchy. + unsharded_module_names (Set[str]): A set to store the names of unsharded modules. + + Returns: + bool: True if any sharded modules were found in the hierarchy, False otherwise. + """ sharded_children = set() for name, child in model.named_children(): curr_path = path + name @@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper( def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]: """ - Returns a list of top level modules do not contain any sharded sub modules. - """ + Returns a list of top-level modules that do not contain any sharded sub-modules. + + Args: + model (torch.nn.Module): The model to analyze. + + Returns: + List[str]: A list of top-level module names without sharded sub-modules. + """ unsharded_module_names: Set[str] = set() _get_unsharded_module_names_helper( @@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901 context: TrainPipelineContext, dist_stream: Optional[torch.cuda.streams.Stream], ) -> List[ShardedModule]: + """ + Rewrites the model to enable pipelined execution for selected sharded modules. + + This function traces the input model using a custom tracer and identifies sharded modules + that can be pipelined. It then creates PipelinedForward objects for these modules, + which enable pipelining during training. + + Args: + model (torch.nn.Module): The input model to be rewritten. + context (TrainPipelineContext): The context containing information needed for pipelining. + dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution. + + Returns: + List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution. + """ # Get underlying nn.Module if isinstance(model, DistributedModelParallel): @@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901 class TrainPipelineSparseDist(TrainPipeline[In, Out]): """ This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with - forward and backward. This helps hide the all2all latency while preserving the - training forward / backward ordering. + forward and backward. This helps hide the all2all latency while preserving the + training forward / backward ordering. - stage 3: forward, backward - uses default CUDA stream - stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream - stage 1: device transfer - uses memcpy CUDA stream + stage 3: forward, backward - uses default CUDA stream + stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream + stage 1: device transfer - uses memcpy CUDA stream - `ShardedModule.input_dist()` is only done for top-level modules in the call graph. - To be considered a top-level module, a module can only depend on 'getattr' calls on - input. + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. - Input model must be symbolically traceable with the exception of `ShardedModule` and - `DistributedDataParallel` modules. - """ + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. 
+ + Args: + model (torch.nn.Module): The input model to be used for training. + optimizer (torch.optim.Optimizer): The optimizer for updating model parameters. + device (torch.device): The device where training will be performed. + enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False. + enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True. + grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None. + + Attributes: + synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines. + + """ synced_pipeline_id: Dict[int, int] = {} @@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): enable_grad_scaling: bool = True, grad_accum: Optional[int] = None, ) -> None: + """ + Initializes the training pipeline. + + Args: + model (torch.nn.Module): The input model to be used for training. + optimizer (torch.optim.Optimizer): The optimizer for updating model parameters. + device (torch.device): The device where training will be performed. + enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False. + enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True. + grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None. + """ self._model = model self._optimizer = optimizer self._device = device @@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): self._grad_accum = grad_accum def _connect(self, dataloader_iter: Iterator[In]) -> None: + """ + Connects the training pipeline to data and prepares for forward and backward passes. + + Args: + dataloader_iter (Iterator[In]): An iterator providing input data batches. + """ + # batch 1 with torch.cuda.stream(self._memcpy_stream): batch_i = next(dataloader_iter) @@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): def progress(self, dataloader_iter: Iterator[In]) -> Out: """ - NOTE: This method has been updated to perform gradient accumulation. - If `_grad_accum` is set, then loss values are scaled by this amount and - optimizer update/reset is skipped for `_grad_accum` calls of `progress` - (congruent to training steps), and then update/reset on every `_grad_accum`th - step. + Progresses through the training pipeline, performing forward and backward passes. - """ + NOTE: This method has been updated to perform gradient accumulation. + If `_grad_accum` is set, then loss values are scaled by this amount and + optimizer update/reset is skipped for `_grad_accum` calls of `progress` + (congruent to training steps), and then update/reset on every `_grad_accum`th + step. + + Args: + dataloader_iter (Iterator[In]): An iterator providing input data batches. + + Returns: + Out: The output of the forward pass. + """ should_step_optimizer = ( self._grad_accum is not None and self._progress_calls > 0 @@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): def _sync_pipeline(self) -> None: """ - Syncs `PipelinedForward` for sharded modules with context and dist stream of the - current train pipeline. Used when switching between train pipelines for the same - model. + Syncs `PipelinedForward` for sharded modules with context and dist stream of the + current train pipeline. Used when switching between train pipelines for the same + model. 
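+
+    Example (rough sketch of the situation this handles; the pipeline names
+    are illustrative and the call is normally made internally):
+
+    ```python
+    # two pipelines drive the same sharded model, e.g. one for train and one
+    # for eval; the pipeline that resumes re-points the patched forwards at
+    # its own context and data-dist stream before continuing
+    train_pipeline.progress(train_iter)
+    eval_pipeline._sync_pipeline()   # forwards now use eval's context/stream
+    eval_pipeline.progress(eval_iter)
+    ```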
""" for module in self._pipelined_modules: module.forward._context = self._context diff --git a/machines/environment.py b/machines/environment.py index c7319d6..a0bfb98 100644 --- a/machines/environment.py +++ b/machines/environment.py @@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222 def on_kf(): - return "SPEC_TYPE" in os.environ + """Check if the code is running on Kubernetes with Kubeflow (KF) environment. + + Returns: + bool: True if running on KF, False otherwise. + """ + return "SPEC_TYPE" in os.environ def has_readers(): - if on_kf(): - machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) - return machines_config_env["dataset_worker"] is not None - return os.environ.get("HAS_READERS", "False") == "True" + """Check if the current task has dataset workers. + + Returns: + bool: True if the task has dataset workers, False otherwise. + """ + if on_kf(): + machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) + return machines_config_env.get("dataset_worker") is not None + return os.environ.get("HAS_READERS", "False") == "True" def get_task_type(): - if on_kf(): - return os.environ["SPEC_TYPE"] - return os.environ["TASK_TYPE"] + """Get the type of the current task. + + Returns: + str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'. + """ + if on_kf(): + return os.environ["SPEC_TYPE"] + return os.environ["TASK_TYPE"] def is_chief() -> bool: - return get_task_type() == "chief" + """Check if the current task is the 'chief'. + + Returns: + bool: True if the current task is the 'chief', False otherwise. + """ + return get_task_type() == "chief" def is_reader() -> bool: - return get_task_type() == "datasetworker" + """Check if the current task is a 'datasetworker'. + + Returns: + bool: True if the current task is a 'datasetworker', False otherwise. + """ + return get_task_type() == "datasetworker" def is_dispatcher() -> bool: - return get_task_type() == "datasetdispatcher" + """Check if the current task is a 'datasetdispatcher'. + + Returns: + bool: True if the current task is a 'datasetdispatcher', False otherwise. + """ + return get_task_type() == "datasetdispatcher" def get_task_index(): - if on_kf(): - pod_name = os.environ["MY_POD_NAME"] - return int(pod_name.split("-")[-1]) - else: - raise NotImplementedError + """Get the index of the current task. + + Returns: + int: Task index. + Raises: + NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment. + """ + if on_kf(): + pod_name = os.environ["MY_POD_NAME"] + return int(pod_name.split("-")[-1]) + else: + raise NotImplementedError def get_reader_port(): - if on_kf(): - return KF_DDS_PORT - return SLURM_DDS_PORT + """Get the port used by readers. + + Returns: + int: Reader port. + """ + if on_kf(): + return KF_DDS_PORT + return SLURM_DDS_PORT def get_dds(): - if not has_readers(): - return None - dispatcher_address = get_dds_dispatcher_address() - if dispatcher_address: - return f"grpc://{dispatcher_address}" - else: - raise ValueError("Job does not have DDS.") + """Get the Distributed Data Service (DDS) address. + + Returns: + str: DDS address in the format 'grpc://host:port'. + Raises: + ValueError: If the job does not have DDS. 
+ """ + if not has_readers(): + return None + dispatcher_address = get_dds_dispatcher_address() + if dispatcher_address: + return f"grpc://{dispatcher_address}" + else: + raise ValueError("Job does not have DDS.") def get_dds_dispatcher_address(): - if not has_readers(): - return None - if on_kf(): - job_name = os.environ["JOB_NAME"] - dds_host = f"{job_name}-datasetdispatcher-0" - else: - dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"] - return f"{dds_host}:{get_reader_port()}" + """Get the DDS dispatcher address. + + Returns: + str: DDS dispatcher address in the format 'host:port'. + """ + if not has_readers(): + return None + if on_kf(): + job_name = os.environ["JOB_NAME"] + dds_host = f"{job_name}-datasetdispatcher-0" + else: + dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"] + return f"{dds_host}:{get_reader_port()}" def get_dds_worker_address(): - if not has_readers(): - return None - if on_kf(): - job_name = os.environ["JOB_NAME"] - task_index = get_task_index() - return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}" - else: - node = os.environ["SLURMD_NODENAME"] - return f"{node}:{get_reader_port()}" + """Get the DDS worker address. + + Returns: + str: DDS worker address in the format 'host:port'. + """ + if not has_readers(): + return None + if on_kf(): + job_name = os.environ["JOB_NAME"] + task_index = get_task_index() + return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}" + else: + node = os.environ["SLURMD_NODENAME"] + return f"{node}:{get_reader_port()}" def get_num_readers(): - if not has_readers(): - return 0 - if on_kf(): - machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) - return int(machines_config_env["num_dataset_workers"] or 0) - return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) + """Get the number of dataset workers. + + Returns: + int: Number of dataset workers. + """ + if not has_readers(): + return 0 + if on_kf(): + machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) + return int(machines_config_env.get("num_dataset_workers") or 0) + return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) def get_flight_server_addresses(): - if on_kf(): - job_name = os.environ["JOB_NAME"] - return [ - f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}" - for task_index in range(get_num_readers()) - ] - else: - raise NotImplementedError + """Get Flight server addresses for dataset workers. + + Returns: + List[str]: List of Flight server addresses in the format 'grpc://host:port'. + Raises: + NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment. + """ + if on_kf(): + job_name = os.environ["JOB_NAME"] + return [ + f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}" + for task_index in range(get_num_readers()) + ] + else: + raise NotImplementedError def get_dds_journaling_dir(): - return os.environ.get("DATASET_JOURNALING_DIR", None) + """Get the DDS journaling directory. + + Returns: + str: DDS journaling directory. + """ + return os.environ.get("DATASET_JOURNALING_DIR", None) diff --git a/metrics/aggregation.py b/metrics/aggregation.py index 57f75d5..13a0e09 100644 --- a/metrics/aggregation.py +++ b/metrics/aggregation.py @@ -14,18 +14,21 @@ def update_mean( weight: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Update the mean according to Welford formula: - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. + Update the mean according to the Welford formula. 
+ + This function updates the mean and the weighted sum of values using the Welford algorithm. + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. See also https://nullbuffer.com/articles/welford_algorithm.html for more information. - Args: - current_mean: The value of the current accumulated mean. - current_weight_sum: The current weighted sum. - value: The new value that needs to be added to get a new mean. - weight: The weights for the new value. - Returns: The updated mean and updated weighted sum. + Args: + current_mean (torch.Tensor): The value of the current accumulated mean. + current_weight_sum (torch.Tensor): The current weighted sum. + value (torch.Tensor): The new value that needs to be added to get a new mean. + weight (torch.Tensor): The weights for the new value. - """ + Returns: + Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum. + """ weight = torch.broadcast_to(weight, value.shape) # Avoiding (on purpose) in-place operation when using += in case @@ -38,12 +41,15 @@ def update_mean( def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: """ Merge the state from multiple workers. - Args: - state: A tensor with the first dimension indicating workers. - Returns: The accumulated mean from all workers. + This function merges the state from multiple workers to compute the accumulated mean. - """ + Args: + state (torch.Tensor): A tensor with the first dimension indicating workers. + + Returns: + torch.Tensor: The accumulated mean from all workers. + """ mean, weight_sum = update_mean( current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), @@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: class StableMean(torchmetrics.Metric): """ - This implements a numerical stable mean metrics computation using Welford algorithm according to - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. + A numerical stable mean metric using the Welford algorithm. + + This class implements a numerical stable mean metrics computation using the Welford algorithm. + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. For example when using float32, the algorithm will give a valid output even if the "sum" is larger than the maximum float32 as far as the mean is within the limit of float32. See also https://nullbuffer.com/articles/welford_algorithm.html for more information. - """ + + Args: + **kwargs: Additional parameters supported by all torchmetrics.Metric. + + Attributes: + mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum. + """ def __init__(self, **kwargs): """ @@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric): ) def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None: - """ - Update the current mean. - Args: - value: Value to update the mean with. - weight: weight to use. Shape should be broadcastable to that of value. - """ + """Update the current mean. + + Args: + value (torch.Tensor): Value to update the mean with. + weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value. 
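+
+    Example (illustrative sketch; the numbers are made up):
+
+    ```python
+    metric = StableMean()
+    metric.update(torch.tensor([1.0, 3.0]))
+    metric.update(torch.tensor(5.0), weight=2.0)
+    metric.compute()   # weighted mean: (1 + 3 + 2 * 5) / 4 = 3.5
+    ```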
+ """ mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] if not isinstance(weight, torch.Tensor): @@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric): ) def compute(self) -> torch.Tensor: - """ - Compute and return the accumulated mean. - """ + """Compute and return the accumulated mean. + + Returns: + torch.Tensor: The accumulated mean. + """ return self.mean_and_weight_sum[0] diff --git a/metrics/auroc.py b/metrics/auroc.py index 6979c20..66e39f2 100644 --- a/metrics/auroc.py +++ b/metrics/auroc.py @@ -29,6 +29,9 @@ def _compute_helper( equal_predictions_as_incorrect: For positive & negative labels having identical scores, we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise, we assume that they are correct prediction (i.e weight = 0). + + Returns: + torch.Tensor: The computed AUROC """ dim = 0 @@ -52,24 +55,34 @@ def _compute_helper( class AUROCWithMWU(torchmetrics.Metric): """ - AUROC using Mann-Whitney U-test. - See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve. + AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test. - This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return - the correct AUROC even if the predicted probabilities are all close to 0. - Currently only support binary classification. - """ + This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate) + scenarios. It returns the correct AUROC even when predicted probabilities are close to 0. + + See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve. + + Note: Currently, this implementation only supports binary classification. + + Args: + label_threshold (float): Threshold for classifying labels as positive or negative. + Labels above this threshold are considered positive, and those below are considered negative. + raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing. + Otherwise, a warning is logged, and AUROC is computed. + **kwargs: Additional parameters supported by all torchmetrics.Metric. + """ def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs): """ + Initializes the AUROCWithMWU metric. - Args: - label_threshold: Labels strictly above this threshold are considered positive labels, - otherwise, they are considered negative. - raise_missing_class: If True, an error will be raise if negative or positive class is missing. - Otherwise, we will simply log a warning. - **kwargs: Additional parameters supported by all torchmetrics.Metric. - """ + Args: + label_threshold (float): Threshold for classifying labels as positive or negative. + Labels above this threshold are considered positive, and those below are considered negative. + raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing. + Otherwise, a warning is logged, and AUROC is computed. + **kwargs: Additional parameters supported by all torchmetrics.Metric. 
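+
+    Example (hand-worked sketch of the Mann-Whitney view of AUROC; the
+    tie weighting of 0.5 is an assumption, see `_compute_helper`):
+
+    ```python
+    pos_scores = [0.4, 0.9]   # predictions for positive labels
+    neg_scores = [0.1, 0.5]   # predictions for negative labels
+    pairs = [(p, n) for p in pos_scores for n in neg_scores]
+    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in pairs)
+    auroc = wins / len(pairs)   # 0.75: chance a random positive outscores a random negative
+    ```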
+ """ super().__init__(**kwargs) self.add_state("predictions", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") diff --git a/ml_logging/absl_logging.py b/ml_logging/absl_logging.py index e4d9e9e..07be8bb 100644 --- a/ml_logging/absl_logging.py +++ b/ml_logging/absl_logging.py @@ -14,7 +14,24 @@ from absl import logging as logging def setup_absl_logging(): - """Make sure that absl logging pushes to stdout rather than stderr.""" + """ + Configure absl-py logging to direct log messages to stdout and apply a custom log message format. + + This function ensures that log messages generated by the absl-py library are written to stdout + rather than stderr. It also applies a custom log message format that includes module, function, + line number, log level, and the log message content. + + Note: + This function should be called once at the beginning of your script or application to + configure absl-py logging. + + Example: + To use this function, simply call it at the start of your script: + ``` + setup_absl_logging() + ``` + + """ logging.get_absl_handler().python_handler.stream = sys.stdout formatter = py_logging.Formatter( fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" diff --git a/ml_logging/test_torch_logging.py b/ml_logging/test_torch_logging.py index e3a2e90..36c141d 100644 --- a/ml_logging/test_torch_logging.py +++ b/ml_logging/test_torch_logging.py @@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging class Testtlogging(unittest.TestCase): def test_warn_once(self): + """ + Test that warning messages are logged only once when using the assertLogs context manager. + + This unit test checks the behavior of the logging system when warning messages are issued + multiple times within the same context. It uses the assertLogs context manager to capture + log messages at the INFO level and verifies that warning messages are logged only once. + + Example: + To use this test case, call it using a test runner like unittest: + ``` + python -m unittest your_test_module.TestLogging.test_warn_once + ``` + + """ + with self.assertLogs(level="INFO") as captured_logs: logging.info("first info") logging.warning("first warning") diff --git a/ml_logging/torch_logging.py b/ml_logging/torch_logging.py index e791c46..2f60114 100644 --- a/ml_logging/torch_logging.py +++ b/ml_logging/torch_logging.py @@ -18,7 +18,35 @@ import torch.distributed as dist def rank_specific(logger): - """Ensures that we only override a given logger once.""" + """ + Customize logger behavior based on the distributed environment and rank. + + This function allows for customizing the behavior of a logger based on the distributed environment and the rank + of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages + depending on the rank or limit the number of redundant logs. + + Args: + logger: The logger object to customize. + + Returns: + The customized logger. + + Example: + To use this function with the `logging` module: + ```python + import logging + from rank_specific_logging import rank_specific + + logger = logging.getLogger(__name__) + rank_specific(logger) + ``` + + Customization: + - Messages are only logged if the distributed environment is not initialized or if the rank matches. + - The 'warning' method is limited to logging a single redundant warning. + - Logging from rank -1 is redirected to include the rank information. 
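+
+  A rough sketch of the rank gate the overridden methods apply before
+  delegating to the original logger (assumed behavior, not the actual code):
+
+  ```python
+  import torch.distributed as dist
+
+  def _should_log(limit_to_rank=None):
+      if not dist.is_initialized():
+          return True
+      return limit_to_rank is None or dist.get_rank() == limit_to_rank
+  ```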
+ + """ if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): return logger diff --git a/optimizers/config.py b/optimizers/config.py index f5011f0..dc5c94a 100644 --- a/optimizers/config.py +++ b/optimizers/config.py @@ -8,11 +8,60 @@ import pydantic class PiecewiseConstant(base_config.BaseConfig): + """ + Configuration for a piecewise constant learning rate schedule. + + This configuration class allows you to specify a piecewise constant learning rate schedule + by defining boundaries and corresponding learning rate values. + + Attributes: + learning_rate_boundaries (List[int], optional): List of step boundaries at which + the learning rate will change. If None, no boundaries are defined. + learning_rate_values (List[float], optional): List of learning rate values + corresponding to the boundaries. If None, no values are defined. + + Example: + To configure a piecewise constant learning rate schedule, create an instance of this class + and set the attributes accordingly. For example: + + ```python + piecewise_lr = PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ``` + + Note: + The number of learning rate values should be one more than the number of boundaries. + + """ learning_rate_boundaries: typing.List[int] = pydantic.Field(None) learning_rate_values: typing.List[float] = pydantic.Field(None) class LinearRampToConstant(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to constant learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero to a constant value over a specified number of steps. + + Attributes: + learning_rate (float): The final constant learning rate. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + + Example: + To configure a linear ramp-up to a constant learning rate, create an instance of this class + and set the attributes accordingly. For example: + + ```python + linear_ramp_lr = LinearRampToConstant( + learning_rate=0.1, + num_ramp_steps=1000 + ) + ``` + + """ learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( description="Number of steps to ramp this up from zero." @@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig): class LinearRampToCosine(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to cosine decay learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero, then decays following a cosine schedule to a final constant learning rate. + + Attributes: + learning_rate (float): The initial learning rate at the start of ramp-up. + final_learning_rate (float): The final constant learning rate after decay. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + final_num_steps (PositiveInt): Final number of steps where decay stops. + + Example: + To configure a linear ramp-up to cosine decay learning rate, create an instance of this + class and set the attributes accordingly. 
For example: + + ```python + ramp_to_cosine_lr = LinearRampToCosine( + learning_rate=0.01, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ``` + + """ learning_rate: float final_learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( @@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig): class LearningRate(base_config.BaseConfig): + """ + Learning rate configuration for training. + + This configuration class allows you to specify different learning rate schedules + for your training process. + + Attributes: + constant (float, optional): Constant learning rate to be used throughout training. + linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly + and then decays following a cosine schedule. + linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up + linearly and then remains constant. + piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified + boundaries with corresponding values. + + Example: + To configure a learning rate schedule, create an instance of this class and set the + attributes accordingly. For example: + + ```python + learning_rate = LearningRate( + constant=0.01, + linear_ramp_to_cosine=LinearRampToCosine( + learning_rate=0.1, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ) + ``` + + Note: + Each learning rate schedule attribute can be set to `None` if not needed. + + """ constant: float = pydantic.Field(None, one_of="lr") linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr") linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr") @@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig): class OptimizerAlgorithmConfig(base_config.BaseConfig): - """Base class for optimizer configurations.""" + """ + Base class for optimizer configurations. + + This base configuration class provides a structure for specifying various optimizer-related + settings, including the learning rate and different learning rate schedules. + + Attributes: + lr (float): The base learning rate used by the optimizer. + + Subclasses should inherit from this base class and define additional attributes specific to + the optimizer algorithm they represent. + + Example: + To create a custom optimizer configuration, create a subclass of this base class and + define the necessary attributes. For example: + + ```python + class MyOptimizerConfig(OptimizerAlgorithmConfig): + momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.") + ``` + + Note: + This base class does not include specific optimizer settings. Subclasses should define + the optimizer-specific attributes as needed. + + """ lr: float ... class AdamConfig(OptimizerAlgorithmConfig): - # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam + """ + Configuration for the Adam optimizer. + + This configuration class allows you to specify the hyperparameters for the Adam optimizer. + + Attributes: + lr (float): The learning rate for optimization. + betas (Tuple[float, float], optional): Coefficients used for computing running averages + of gradient and squared gradient. Defaults to (0.9, 0.999). + eps (float, optional): A small constant added to the denominator for numerical stability. + Defaults to 1e-7. + + Example: + To configure the Adam optimizer, create an instance of this class and set the attributes + accordingly. 
For example: + + ```python + adam_optimizer = AdamConfig( + lr=0.001, + betas=(0.9, 0.999), + eps=1e-8 + ) + ``` + + See Also: + [PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam) + + """ lr: float betas: typing.Tuple[float, float] = [0.9, 0.999] eps: float = 1e-7 # Numerical stability in denominator. class SgdConfig(OptimizerAlgorithmConfig): + """ + Configuration for the Stochastic Gradient Descent (SGD) optimizer. + + This configuration class allows you to specify the hyperparameters for the SGD optimizer. + + Attributes: + lr (float): The learning rate for optimization. + momentum (float, optional): The momentum factor for SGD. Defaults to 0.0. + + Example: + To configure the SGD optimizer, create an instance of this class and set the attributes + accordingly. For example: + + ```python + sgd_optimizer = SgdConfig( + lr=0.01, + momentum=0.9 + ) + ``` + + """ lr: float momentum: float = 0.0 class AdagradConfig(OptimizerAlgorithmConfig): + """ + Configuration for the optimizer used during training. + + This configuration class allows you to specify the optimizer for training, including + options for various optimizer algorithms. + + Attributes: + learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None. + adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None. + sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer. + Defaults to None. + adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None. + + Example: + To configure the optimizer for training, create an instance of this class and set the + attributes accordingly. For example: + + ```python + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.001), + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8) + ) + ``` + + """ lr: float eps: float = 0 class OptimizerConfig(base_config.BaseConfig): + """ + Configuration for defining different optimizer algorithms and their parameters. + + This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad, + along with their respective hyperparameters. + + Args: + learning_rate (LearningRate): The learning rate configuration, which can include + constant learning rates or other learning rate schedules. + adam (AdamConfig): Configuration for the Adam optimizer. + sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer. + adagrad (AdagradConfig): Configuration for the Adagrad optimizer. + + Example: + ```python + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.001), + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8), + ) + ``` + + Attributes: + learning_rate (LearningRate): The learning rate configuration. + adam (AdamConfig): Configuration for the Adam optimizer. + sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer. + adagrad (AdagradConfig): Configuration for the Adagrad optimizer. + + Note: + You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an + `OptimizerConfig` instance. + + See Also: + - `LearningRate`: Configuration for specifying learning rates. + - `AdamConfig`: Configuration for the Adam optimizer. + - `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer. + - `AdagradConfig`: Configuration for the Adagrad optimizer. 
+ + """ learning_rate: LearningRate = pydantic.Field( None, description="Constant learning rates", @@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig): def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): + """ + Get the optimizer algorithm configuration from the given `OptimizerConfig`. + + This function extracts and returns the specific optimizer algorithm configuration + (e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`. + + Args: + optimizer_config (OptimizerConfig): The optimizer configuration object containing + one of the optimizer algorithm configurations. + + Returns: + Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm + configuration extracted from `optimizer_config`. + + Raises: + ValueError: If no optimizer algorithm is selected in `optimizer_config`. + + Example: + ```python + optimizer_config = OptimizerConfig( + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8) + ) + algorithm_config = get_optimizer_algorithm_config(optimizer_config) + # `algorithm_config` will be an instance of `AdamConfig`. + ``` + + """ if optimizer_config.adam is not None: return optimizer_config.adam elif optimizer_config.sgd is not None: diff --git a/optimizers/optimizer.py b/optimizers/optimizer.py index 4517368..e930f73 100644 --- a/optimizers/optimizer.py +++ b/optimizers/optimizer.py @@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging def compute_lr(lr_config, step): - """Compute a learning rate.""" + """ + Compute the learning rate based on the specified learning rate configuration. + + This function calculates the learning rate according to the given configuration, which can include + constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing. + + Args: + lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule. + step (int): The current training step or iteration. + + Returns: + float: The computed learning rate for the current step. + + Raises: + ValueError: If the `lr_config` is invalid or contains conflicting options. + + Example: + ```python + lr_schedule = LearningRate( + constant=0.001, + piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ) + current_step = 2500 + learning_rate = compute_lr(lr_schedule, current_step) + ``` + """ if lr_config.constant is not None: return lr_config.constant elif lr_config.piecewise_constant is not None: @@ -46,11 +74,54 @@ def compute_lr(lr_config, step): class LRShim(_LRScheduler): - """Shim to get learning rates into a LRScheduler. - - This adheres to the torch.optim scheduler API and can be plugged anywhere that - e.g. exponential decay can be used. """ + Learning Rate Scheduler Shim to adjust learning rates during training. + + This class acts as a shim to apply different learning rates to individual parameter groups + within an optimizer. It adheres to the torch.optim scheduler API and can be used with various + optimizers, allowing fine-grained control over learning rates based on configuration. + + Args: + optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted. + lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their + corresponding learning rate configurations. + last_epoch (int, optional): The index of the last epoch. Default is -1. 
+ verbose (bool, optional): If True, prints a warning message when accessing learning rates + using the deprecated `get_lr()` method. Default is False. + + Raises: + ValueError: If the number of parameter groups in the optimizer does not match the number + of learning rate configurations provided. + + Note: + To obtain the last computed learning rates, please use `get_last_lr()`. + + Example: + ```python + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + lr_schedule = { + 'main': LearningRate(constant=0.01), + 'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000], + learning_rate_values=[0.01, 0.001] + )) + } + lr_shim = LRShim(optimizer, lr_schedule) + + for epoch in range(num_epochs): + # Train the model + train(...) + # Update learning rates at the end of each epoch + lr_shim.step(epoch) + + final_lr_main = lr_shim.get_last_lr()['main'] + final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary'] + ``` + + See Also: + - `LearningRate`: Configuration for specifying learning rates. + - `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules. + """ def __init__( self, @@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig): def build_optimizer( model: torch.nn.Module, optimizer_config: OptimizerConfig ) -> Tuple[Optimizer, _LRScheduler]: - """Builds an optimizer and LR scheduler from an OptimizerConfig. - Note: use this when you want the same optimizer and learning rate schedule for all your parameters. """ + Build an optimizer and learning rate scheduler based on the provided optimizer configuration. + + Args: + model (torch.nn.Module): The PyTorch model for which the optimizer will be created. + optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer + algorithm and learning rate settings. + + Returns: + Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler + objects. + + Note: + This function is intended for cases where you want the same optimizer and learning rate + schedule for all model parameters. + + Example: + ```python + model = MyModel() + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.01), + sgd=SgdConfig(lr=0.01, momentum=0.9) + ) + optimizer, scheduler = build_optimizer(model, optimizer_config) + + for epoch in range(num_epochs): + # Train the model with the optimizer + train(model, optimizer, ...) + # Update learning rates at the end of each epoch + scheduler.step(epoch) + ``` + + See Also: + - `OptimizerConfig`: Configuration for specifying optimizer settings. + - `LRShim`: Learning rate scheduler shim for fine-grained learning rate control. + """ optimizer_class = get_optimizer_class(optimizer_config) optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict()) # We're passing everything in as one group here diff --git a/projects/twhin/data/config.py b/projects/twhin/data/config.py index 55e3cec..05c70c4 100644 --- a/projects/twhin/data/config.py +++ b/projects/twhin/data/config.py @@ -4,6 +4,17 @@ import pydantic class TwhinDataConfig(base_config.BaseConfig): + """ + Configuration for Twhin model training data. + + Args: + data_root (str): The root directory for the training data. + per_replica_batch_size (pydantic.PositiveInt): Batch size per replica. + global_negatives (int): The number of global negatives. + in_batch_negatives (int): The number of in-batch negatives. + limit (pydantic.PositiveInt): The limit on the number of data points to use. 
+ offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None. + """ data_root: str per_replica_batch_size: pydantic.PositiveInt global_negatives: int diff --git a/projects/twhin/data/data.py b/projects/twhin/data/data.py index 37bf8f5..b0afe7f 100644 --- a/projects/twhin/data/data.py +++ b/projects/twhin/data/data.py @@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig): + """ + Create a dataset for Twhin model training. + + Args: + data_config (TwhinDataConfig): The data configuration for the dataset. + model_config (TwhinModelConfig): The model configuration containing embeddings and relations. + + Returns: + EdgesDataset: The dataset for Twhin model training. + """ tables = model_config.embeddings.tables table_sizes = {table.name: table.num_embeddings for table in tables} relations = model_config.relations diff --git a/projects/twhin/data/edges.py b/projects/twhin/data/edges.py index f7864b1..778a188 100644 --- a/projects/twhin/data/edges.py +++ b/projects/twhin/data/edges.py @@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @dataclass class EdgeBatch(DataclassBatch): + """ + Batch data structure for edge-based models. + + Args: + nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings. + labels (torch.Tensor): Tensor containing labels. + rels (torch.Tensor): Tensor containing relation information. + weights (torch.Tensor): Tensor containing weights. + """ nodes: KeyedJaggedTensor labels: torch.Tensor rels: torch.Tensor @@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch): class EdgesDataset(Dataset): + """ + Dataset for edge-based models. + + Args: + file_pattern (str): The file pattern for the dataset. + table_sizes (Dict[str, int]): A dictionary of table names and their sizes. + relations (List[Relation]): A list of relations between tables. + lhs_column_name (str): The name of the left-hand-side column. + rhs_column_name (str): The name of the right-hand-side column. + rel_column_name (str): The name of the relation column. + **dataset_kwargs: Additional keyword arguments for the parent Dataset class. + """ rng = np.random.default_rng() def __init__( @@ -56,6 +77,15 @@ class EdgesDataset(Dataset): super().__init__(file_pattern=file_pattern, **dataset_kwargs) def pa_to_batch(self, batch: pa.RecordBatch): + """ + Converts a pyarrow RecordBatch to an EdgeBatch. + + Args: + batch (pa.RecordBatch): A pyarrow RecordBatch containing data. + + Returns: + EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights. + """ lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy()) rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy()) rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy()) @@ -74,6 +104,14 @@ class EdgesDataset(Dataset): ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: """Process edges that contain lhs index, rhs index, relation index. + + Args: + lhs (torch.Tensor): Tensor containing left-hand-side indices. + rhs (torch.Tensor): Tensor containing right-hand-side indices. + rel (torch.Tensor): Tensor containing relation indices. + + Returns: + Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs. Example: ``` @@ -147,6 +185,12 @@ class EdgesDataset(Dataset): return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths) def to_batches(self): + """ + Converts data to batches. 
+ + Yields: + pa.RecordBatch: A pyarrow RecordBatch containing data. + """ ds = super().to_batches() batch_size = self._dataset_kwargs["batch_size"] diff --git a/projects/twhin/models/config.py b/projects/twhin/models/config.py index aafc9e6..c7b2e1c 100644 --- a/projects/twhin/models/config.py +++ b/projects/twhin/models/config.py @@ -10,8 +10,29 @@ from pydantic import validator class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): + """ + Configuration class for Twhin model embeddings. + + This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types + for all tables in the Twhin model embeddings configuration match. + + Attributes: + tables (List[TableConfig]): A list of table configurations for the model's embeddings. + """ @validator("tables") def embedding_dims_match(cls, tables): + """ + Validate that embedding dimensions and data types match for all tables. + + Args: + tables (List[TableConfig]): List of table configurations. + + Returns: + List[TableConfig]: The list of validated table configurations. + + Raises: + AssertionError: If embedding dimensions or data types do not match. + """ embedding_dim = tables[0].embedding_dim data_type = tables[0].data_type for table in tables: @@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): class Operator(str, enum.Enum): + """ + Enumeration of operator types. + + This enumeration defines different types of operators that can be applied to Twhin model relations. + """ TRANSLATION = "translation" class Relation(pydantic.BaseModel): - """graph relationship properties and operator""" + """ + Configuration class for graph relationships in the Twhin model. + + This class defines properties and operators for graph relationships in the Twhin model. + + Attributes: + name (str): The name of the relationship. + lhs (str): The name of the entity on the left-hand side of the relation. + rhs (str): The name of the entity on the right-hand side of the relation. + operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product. + """ name: str = pydantic.Field(..., description="Relationship name.") lhs: str = pydantic.Field( @@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel): class TwhinModelConfig(base_config.BaseConfig): + """ + Configuration class for the Twhin model. + + This class defines configuration options specific to the Twhin model. + + Attributes: + embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings. + relations (List[Relation]): List of graph relationship configurations. + translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation. + """ embeddings: TwhinEmbeddingsConfig relations: typing.List[Relation] translation_optimizer: OptimizerConfig @validator("relations", each_item=True) def valid_node_types(cls, relation, values, **kwargs): + """ + Validate that the specified node types in relations are valid table names in embeddings. + + Args: + relation (Relation): A single relation configuration. + values (dict): The values dictionary containing the "embeddings" configuration. + + Returns: + Relation: The validated relation configuration. + + Raises: + AssertionError: If the specified node types are not valid table names in embeddings. 
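+
+    Example (illustrative; the table names "user" and "tweet" are assumed to
+    exist in `embeddings.tables`):
+
+    ```python
+    # passes validation only if both "user" and "tweet" are embedding tables
+    Relation(name="authored", lhs="user", rhs="tweet", operator=Operator.TRANSLATION)
+    ```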
+ """ table_names = [table.name for table in values["embeddings"].tables] assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}" assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}" diff --git a/projects/twhin/models/models.py b/projects/twhin/models/models.py index b1dfaa7..0b5f1cb 100644 --- a/projects/twhin/models/models.py +++ b/projects/twhin/models/models.py @@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa class TwhinModel(nn.Module): + """ + Twhin model for graph-based entity embeddings and translation. + + This class defines the Twhin model, which is used for learning embeddings of entities in a graph + and applying translations to these embeddings based on graph relationships. + + Args: + model_config (TwhinModelConfig): Configuration for the Twhin model. + data_config (TwhinDataConfig): Configuration for the data used by the model. + + Attributes: + batch_size (int): The batch size used for training. + table_names (List[str]): Names of tables in the model's embeddings. + large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings. + embedding_dim (int): Dimensionality of entity embeddings. + num_tables (int): Number of tables in the model's embeddings. + in_batch_negatives (int): Number of in-batch negative samples to use during training. + global_negatives (int): Number of global negative samples to use during training. + num_relations (int): Number of graph relationships in the model. + all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings. + + """ def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig): super().__init__() self.batch_size = data_config.per_replica_batch_size @@ -31,7 +53,17 @@ class TwhinModel(nn.Module): ) def forward(self, batch: EdgeBatch): + """ + Forward pass of the Twhin model. + Args: + batch (EdgeBatch): Input batch containing graph edge information. + + Returns: + dict: A dictionary containing model output with "logits" and "probabilities". + - "logits" (torch.Tensor): Logit scores. + - "probabilities" (torch.Tensor): Sigmoid probabilities. + """ # B x D trans_embs = self.all_trans_embs.data[batch.rels] @@ -98,6 +130,18 @@ class TwhinModel(nn.Module): def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): + """ + Apply optimizers to the Twhin model's embeddings. + + This function applies optimizers to the embeddings of the Twhin model based on the provided configuration. + + Args: + model (TwhinModel): The Twhin model to apply optimizers to. + model_config (TwhinModelConfig): Configuration for the Twhin model. + + Returns: + TwhinModel: The Twhin model with optimizers applied to its embeddings. + """ for table in model_config.embeddings.tables: optimizer_class = get_optimizer_class(table.optimizer) optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict() @@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module): device: torch.device, ) -> None: """ - Args: - model: torch module to wrap. - loss_fn: Function for calculating loss, should accept logits and labels. - """ + Initialize a TwhinModelAndLoss module. + + Args: + model: The torch module to wrap. + loss_fn: A function for calculating loss, should accept logits and labels. + data_config: Configuration for Twhin data. + device: The torch device to use for calculations. 
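+
+    Example:
+      A minimal construction sketch; `model`, `data_config`, and `batch` are
+      placeholders for objects built elsewhere, and the loss function shown is
+      the one used in the project's tests:
+
+      ```python
+      model_and_loss = TwhinModelAndLoss(
+        model=model,
+        loss_fn=F.binary_cross_entropy_with_logits,
+        data_config=data_config,
+        device=torch.device("cpu"),
+      )
+      loss, outputs = model_and_loss(batch)  # see `forward` below
+      ```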
+    """
     super().__init__()
     self.model = model
     self.loss_fn = loss_fn
@@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
     self.device = device
 
   def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
-    """Runs model forward and calculates loss according to given loss_fn.
-
-    NOTE: The input signature here needs to be a Pipelineable object for
-    prefetching purposes during training using torchrec's pipeline. However
-    the underlying model signature needs to be exportable to onnx, requiring
-    generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
-    """
+    """Run the model forward and calculate the loss according to the given loss_fn.
+
+    NOTE: The input signature here needs to be a Pipelineable object for
+    prefetching purposes during training using torchrec's pipeline. However
+    the underlying model signature needs to be exportable to onnx, requiring
+    generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
+
+    Args:
+      batch ("RecapBatch"): The input batch for model inference.
+
+    Returns:
+      Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
+        additional outputs including logits, labels, and weights.
+    """
     outputs = self.model(batch)
     logits = outputs["logits"]
diff --git a/projects/twhin/models/test_models.py b/projects/twhin/models/test_models.py
index f8203ef..05e9b2b 100644
--- a/projects/twhin/models/test_models.py
+++ b/projects/twhin/models/test_models.py
@@ -18,6 +18,12 @@ EMB_DIM = 128
 
 
 def twhin_model_config() -> TwhinModelConfig:
+  """
+  Create a configuration for the Twhin model.
+
+  Returns:
+    TwhinModelConfig: The Twhin model configuration.
+  """
   sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
   sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
 
@@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:
 
 
 def twhin_data_config() -> TwhinDataConfig:
+  """
+  Create a configuration for the Twhin data.
+
+  Returns:
+    TwhinDataConfig: The Twhin data configuration.
+  """
   data_config = TwhinDataConfig(
     data_root="/",
     per_replica_batch_size=10,
@@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:
 
 
 def test_twhin_model():
+  """
+  Test the Twhin model creation and optimization.
+
+  This function creates a Twhin model using the specified configuration and tests its optimization. It also checks
+  the device placement of model parameters.
+
+  Returns:
+    None
+  """
   model_config = twhin_model_config()
   loss_fn = F.binary_cross_entropy_with_logits
 
diff --git a/projects/twhin/optimizer.py b/projects/twhin/optimizer.py
index 7c43b56..e39e0d0 100644
--- a/projects/twhin/optimizer.py
+++ b/projects/twhin/optimizer.py
@@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"
 
 
 def _lr_from_config(optimizer_config):
+  """Get the learning rate from an optimizer configuration.
+
+  Args:
+    optimizer_config: Optimizer configuration.
+
+  Returns:
+    Learning rate from the optimizer configuration.
+  """
   if optimizer_config.learning_rate is not None:
     return optimizer_config.learning_rate
   else:
@@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):
 def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
   """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.
 
-  Args:
-    model: TwhinModel to build optimizer for.
-    config: TwhinConfig for model.
+  Args:
+    model: TwhinModel to build optimizer for.
+    config: TwhinModelConfig for model.
 
-  Returns:
-    Optimizer for model.
-  """
+  Returns:
+    Optimizer for model.
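+
+  Example:
+    A usage sketch; the config and data-config objects are placeholders built
+    elsewhere, and `apply_optimizers` is the companion helper from
+    `projects/twhin/models/models.py`:
+
+    ```python
+    model = TwhinModel(model_config, data_config)
+    model = apply_optimizers(model, model_config)
+    optimizer = build_optimizer(model=model, config=model_config)
+    ```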
+ """ translation_optimizer_fn = functools.partial( get_optimizer_class(config.translation_optimizer), **get_optimizer_algorithm_config(config.translation_optimizer).dict(), diff --git a/projects/twhin/run.py b/projects/twhin/run.py index 12081e0..6124bc1 100644 --- a/projects/twhin/run.py +++ b/projects/twhin/run.py @@ -37,6 +37,12 @@ def run( all_config: TwhinConfig, save_dir: Optional[str] = None, ): + """Run the training process for TwhinModel. + + Args: + all_config (TwhinConfig): The configuration for the entire Twhin model. + save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None. + """ train_dataset = create_dataset(all_config.train_data, all_config.model) if env.is_reader(): @@ -80,6 +86,11 @@ def run( def main(argv): + """Main entry point for the Twhin training script. + + Args: + argv: Command-line arguments. + """ logging.info("Starting") logging.info(f"parsing config from {FLAGS.config_yaml_path}...") diff --git a/reader/dataset.py b/reader/dataset.py index 6e811cc..5d873cd 100644 --- a/reader/dataset.py +++ b/reader/dataset.py @@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging class _Reader(pa.flight.FlightServerBase): - """Distributed reader flight server wrapping a dataset.""" + """ + Distributed reader flight server wrapping a dataset. + + This class implements a Flight server that wraps a dataset, allowing clients to retrieve data + from the dataset over the Flight protocol. It is designed to be used in a distributed environment + for efficient data access. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + + Attributes: + _location (str): The location of the Flight server. + _ds (Dataset): The dataset wrapped by the Flight server. + + Methods: + do_get(_, __): Handles Flight requests for data retrieval. + + Note: + Flight is an Apache Arrow project that provides a framework for efficient data transfer. + This class allows clients to retrieve data from the dataset using Flight. + + """ def __init__(self, location: str, ds: "Dataset"): + """ + Initialize a new _Reader instance. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + """ super().__init__(location=location) self._location = location self._ds = ds def do_get(self, _, __): + """ + Handle Flight requests for data retrieval. + + This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol. + + Args: + _: Unused argument. + __: Unused argument. + + Returns: + pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset. + + Note: + An updated schema (to account for column selection) must be given to the stream. + """ # NB: An updated schema (to account for column selection) has to be given the stream. schema = next(iter(self._ds.to_batches())).schema batches = self._ds.to_batches() @@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase): class Dataset(torch.utils.data.IterableDataset): + """ + A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading. + + This class enables efficient loading of data from Parquet files using PyArrow. + It is designed to be used as an IterableDataset in PyTorch for training and inference. + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. 
+ Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset + for more details. + + Attributes: + LOCATION (str): The default location for the Flight server used for data distribution. + _file_pattern (str): The glob pattern specifying Parquet files in the dataset. + _fs: The filesystem object used for file operations. + _dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method. + _files (list): A list of file paths matching the glob pattern. + _schema (pa.Schema): The schema of the Parquet dataset. + + Methods: + serve(): Start serving the dataset using a Flight server. + to_batches(): Generate batches of data from the Parquet dataset. + pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch. + dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset. + + Note: + This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch + to create DataLoader instances for training or inference. + """ LOCATION = "grpc://0.0.0.0:2222" def __init__(self, file_pattern: str, **dataset_kwargs) -> None: - """Specify batch size and column to select for. + """ + Initialize a new Dataset instance. Specify batch size and column to select for. Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset. - """ + + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. + """ self._file_pattern = file_pattern self._fs = infer_fs(self._file_pattern) self._dataset_kwargs = dataset_kwargs @@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset): self._validate_columns() def _validate_columns(self): + """ + Validate the specified columns against the dataset schema. + + Raises: + Exception: If any specified columns are not found in the dataset schema. + """ columns = set(self._dataset_kwargs.get("columns", [])) wrong_columns = set(columns) - set(self._schema.names) if wrong_columns: raise Exception(f"Specified columns {list(wrong_columns)} not in schema.") def serve(self): + """Start serving the dataset using a Flight server.""" self.reader = _Reader(location=self.LOCATION, ds=self) self.reader.serve() def _create_dataset(self): + """Create a PyArrow dataset for data retrieval.""" + return pads.dataset( source=random.sample(self._files, len(self._files))[0], format="parquet", @@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset): @abc.abstractmethod def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: + """ + Convert a Parquet RecordBatch to a custom data batch. + + Args: + batch (pa.RecordBatch): A batch of data from the Parquet dataset. + + Returns: + DataclassBatch: A custom data batch used in PyTorch training. + + Raises: + NotImplementedError: This method must be implemented in derived classes. + """ raise NotImplementedError def dataloader(self, remote: bool = False): + """ + Create a PyTorch DataLoader for iterating through the dataset. + + Args: + remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training. + + Returns: + DataLoader: A PyTorch DataLoader for iterating through the dataset. + + Note: + If `remote` is True, a remote DataLoader is created for distributed training using Flight. 
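+
+    Example:
+      A local (non-remote) iteration sketch. `MyDataset` is a hypothetical
+      subclass that implements `pa_to_batch`; the file pattern and batch size
+      are illustrative:
+
+      ```python
+      ds = MyDataset(file_pattern="data/*.parquet", batch_size=1024)
+      for batch in ds.dataloader(remote=False):
+        ...  # each item is a DataclassBatch
+      ```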
+ """ if not remote: return map(self.pa_to_batch, self.to_batches()) readers = get_readers(2) @@ -117,6 +230,25 @@ GRPC_OPTIONS = [ def get_readers(num_readers_per_worker: int): + """ + Get Flight readers for distributed data loading. + + This function retrieves Flight readers for distributed data loading in a PyTorch environment. + + Args: + num_readers_per_worker (int): The number of Flight readers to retrieve per worker. + + Returns: + List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading. + + Note: + Flight readers are used to fetch data in a distributed manner for efficient data loading. + + Example: + To obtain Flight readers, use the following code: + + >>> readers = get_readers(num_readers_per_worker=2) + """ addresses = env.get_flight_server_addresses() readers = [] diff --git a/reader/dds.py b/reader/dds.py index 7b49893..54b2c02 100644 --- a/reader/dds.py +++ b/reader/dds.py @@ -21,6 +21,16 @@ import torch.distributed as dist def maybe_start_dataset_service(): + """ + Start the dataset service if readers are available and required dependencies are met. + + This function checks if readers are available and if the required TensorFlow version is >= 2.5. + If both conditions are met and the current environment is the dispatcher or reader, it starts + the TensorFlow dataset service. + + Raises: + Exception: If the required TensorFlow version is not met (>= 2.5). + """ if not env.has_readers(): return @@ -59,6 +69,24 @@ def maybe_start_dataset_service(): def register_dataset( dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" ): + """ + Register a dataset with the distributed dataset service. + + This function registers a dataset with the distributed dataset service and broadcasts the dataset ID + and job name to all processes in the distributed environment. + + Args: + dataset (tf.data.Dataset): The dataset to be registered. + dataset_service (str): The name of the dataset service. + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + + Returns: + Tuple[int, str]: A tuple containing the dataset ID and job name. + + Note: + This function should be called on the rank 0 process. + + """ if dist.get_rank() == 0: dataset_id = _register_dataset( service=dataset_service, @@ -82,6 +110,23 @@ def distribute_from_dataset_id( compression: Optional[str] = "AUTO", prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, ) -> tf.data.Dataset: + """ + Distribute a dataset from a registered dataset ID. + + This function consumes a dataset from the distributed dataset service using the provided dataset ID + and job name. It also supports prefetching for improved performance. + + Args: + dataset_service (str): The name of the dataset service. + dataset_id (int): The ID of the dataset to be consumed. + job_name (Optional[str]): The name of the job associated with the dataset (optional). + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE). + + Returns: + tf.data.Dataset: The distributed dataset. + + """ logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") dataset = _from_dataset_id( processing_mode="parallel_epochs", @@ -97,15 +142,28 @@ def distribute_from_dataset_id( def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Torch-compatible and distributed-training-aware dataset service distributor. 
-
-  - rank 0 process will register the given dataset.
-  - rank 0 process will broadcast job name and dataset id.
-  - all rank processes will consume from the same job/dataset.
-
-  Without this, dataset workers will try to serve 1 job per rank process and OOM.
-  """
+  """Distribute a TensorFlow dataset for Torch-compatible and distributed-training-aware consumption.
+
+  This function is used to distribute a dataset in a distributed training environment. It performs the
+  following steps:
+  - On the rank 0 process, it registers the given dataset with the distributed dataset service.
+  - It broadcasts the job name and dataset ID to all rank processes.
+  - All rank processes then consume the same dataset from the distributed dataset service.
+
+  Args:
+    dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.
+
+  Returns:
+    tf.data.Dataset: The distributed TensorFlow dataset.
+
+  Note:
+    - If there are no reader processes in the distributed environment, the original dataset is returned
+      without any distribution.
+    - This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
+      issues caused by each rank process trying to serve one job.
+
+  """
   if not env.has_readers():
     return dataset
   dataset_service = env.get_dds()
diff --git a/reader/test_dataset.py b/reader/test_dataset.py
index 85e756f..d9400e4 100644
--- a/reader/test_dataset.py
+++ b/reader/test_dataset.py
@@ -12,6 +12,17 @@ import torch
 
 
 def create_dataset(tmpdir):
+  """
+  Create a mock dataset for testing.
+
+  This function creates a mock dataset using PyArrow and Parquet for testing purposes.
+
+  Args:
+    tmpdir: A temporary directory where the dataset will be created.
+
+  Returns:
+    MockDataset: A mock dataset for testing.
+  """
 
   table = pa.table(
     {
@@ -34,6 +45,14 @@ def create_dataset(tmpdir):
 
 
 def test_dataset(tmpdir):
+  """
+  Test the created dataset.
+
+  This function tests the created mock dataset and checks if it behaves as expected.
+
+  Args:
+    tmpdir: A temporary directory used for testing.
+  """
   ds = create_dataset(tmpdir)
   batch = next(iter(ds.dataloader(remote=False)))
   assert batch.batch_size == 2
@@ -46,6 +65,14 @@ def test_dataset(tmpdir):
   reason="Multiprocessing doesn't work on github yet.",
 )
 def test_distributed_dataset(tmpdir):
+  """
+  Test the distributed dataset.
+
+  This function tests the distributed version of the mock dataset using multiprocessing.
+
+  Args:
+    tmpdir: A temporary directory used for testing.
+  """
   MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}
 
   def _client():
diff --git a/reader/utils.py b/reader/utils.py
index fc0e34c..c5c9721 100644
--- a/reader/utils.py
+++ b/reader/utils.py
@@ -11,11 +11,55 @@ import torch
 
 
 def roundrobin(*iterables):
-  """Round robin through provided iterables, useful for simple load balancing.
-
-  Adapted from https://docs.python.org/3/library/itertools.html.
-  """
+  """Iterate through provided iterables in a round-robin fashion.
+
+  This function takes multiple iterables and returns an iterator that yields elements from
+  each iterable in a round-robin manner. It continues cycling through the iterables until
+  all of them are exhausted.
+
+  Adapted from https://docs.python.org/3/library/itertools.html.
+
+  Args:
+    *iterables: One or more iterable objects to iterate through.
+
+  Yields:
+    Elements from the provided iterables in a round-robin fashion.
+
+  Raises:
+    StopIteration: If all provided iterables are exhausted.
+ + Example: + ```python + iterable1 = [1, 2, 3] + iterable2 = ['a', 'b', 'c'] + iterable3 = [0.1, 0.2, 0.3] + + for item in roundrobin(iterable1, iterable2, iterable3): + print(item) + + # Output: + # 1 + # 'a' + # 0.1 + # 2 + # 'b' + # 0.2 + # 3 + # 'c' + # 0.3 + ``` + + Note: + - If one of the provided iterables is shorter than the others, the function will + continue iterating through the remaining iterables until all are exhausted. + - If an iterable raises an exception during iteration, a warning message is logged, + and the function continues with the next iterable. + + See Also: + - `itertools.cycle`: A function that repeatedly cycles through elements of an iterable. + - `itertools.islice`: A function to slice an iterable to limit the number of iterations. + """ num_active = len(iterables) nexts = itertools.cycle(iter(it).__next__ for it in iterables) while num_active: @@ -35,6 +79,48 @@ def roundrobin(*iterables): def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]): + """ + Monitor the speed and progress of data loading using a data loader. + + This function iterates through a data loader for a specified number of steps or until + the end of the data loader is reached, periodically logging progress information. + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Example: + ```python + import torch + from torch.utils.data import DataLoader + + # Create a data loader (replace with your own DataLoader configuration) + data_loader = DataLoader(dataset, batch_size=32, shuffle=True) + + # Monitor data loading speed and progress + speed_check(data_loader, max_steps=1000, frequency=50, peek=500) + ``` + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Note: + - The function logs information about elapsed time, the number of examples processed, + and the processing speed in examples per second. + - If `peek` is provided, batch contents will be logged for inspection at the specified + frequency. + + See Also: + - `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and + iterating through datasets. + """ num_examples = 0 prev = time.perf_counter() for idx, batch in enumerate(data_loader): @@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int] def pa_to_torch(array: pa.array) -> torch.Tensor: + """ + Convert a PyArrow Array to a PyTorch Tensor. + + Args: + array (pa.array): The PyArrow Array to convert. + + Returns: + torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array. + + Example: + ```python + import pyarrow as pa + import torch + + # Create a PyArrow Array + arrow_array = pa.array([1, 2, 3]) + + # Convert it to a PyTorch Tensor + torch_tensor = pa_to_torch(arrow_array) + ``` + """ return torch.from_numpy(array.to_numpy()) def create_default_pa_to_batch(schema) -> DataclassBatch: - """ """ + """ + Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data. 
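+
+  The returned callable produces instances of a `DataclassBatch` subclass generated
+  from `schema` via `DataclassBatch.from_schema`; where a column contains null values,
+  a per-type default chosen by the nested `get_imputation_value` helper is imputed
+  in its place.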
+ + Args: + schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch. + + Returns: + callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch. + + Example: + ```python + import pyarrow as pa + from dataclass_batch import DataclassBatch + + # Define a PyArrow schema + schema = pa.schema([ + ("feature1", pa.float64()), + ("feature2", pa.int64()), + ("label", pa.int64()), + ]) + + # Create the conversion function + pa_to_batch = create_default_pa_to_batch(schema) + + # Create a PyArrow RecordBatch + record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({ + "feature1": [1.0, 2.0, None], + "feature2": [10, 20, 30], + "label": [0, 1, None], + })) + + # Convert the RecordBatch to a custom DataclassBatch + custom_batch = pa_to_batch(record_batch) + ``` + """ _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema) def get_imputation_value(pa_type):