diff --git a/common/batch.py b/common/batch.py index cfd0ac0..f72b762 100644 --- a/common/batch.py +++ b/common/batch.py @@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable class BatchBase(Pipelineable, abc.ABC): + """ + A base class for batches used in pipelines. + + Attributes: + None + + """ @abc.abstractmethod def as_dict(self) -> Dict: + """ + Convert the batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + Raises: + NotImplementedError: If the method is not implemented in a subclass. + + """ raise NotImplementedError def to(self, device: torch.device, non_blocking: bool = False): + """ + Move the batch to the specified device. + + Args: + device (torch.device): The target device. + non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False. + + Returns: + BatchBase: A new batch on the target device. + + """ args = {} for feature_name, feature_value in self.as_dict().items(): args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking) return self.__class__(**args) def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + """ + Record a CUDA stream for all tensors in the batch. + + Args: + stream (torch.cuda.streams.Stream): The CUDA stream to record. + + Returns: + None + + """ for feature_value in self.as_dict().values(): feature_value.record_stream(stream) def pin_memory(self): + """ + Pin memory for all tensors in the batch. + + Returns: + BatchBase: A new batch with pinned memory. + + """ args = {} for feature_name, feature_value in self.as_dict().items(): args[feature_name] = feature_value.pin_memory() return self.__class__(**args) def __repr__(self) -> str: + """ + Generate a string representation of the batch. + + Returns: + str: A string representation of the batch. + + """ def obj2str(v): return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" @@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC): @property def batch_size(self) -> int: + """ + Get the batch size from the tensors in the batch. + + Returns: + int: The batch size. + + Raises: + Exception: If the batch size cannot be determined from the tensors. + + """ for tensor in self.as_dict().values(): if tensor is None: continue @@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC): @dataclass class DataclassBatch(BatchBase): + """ + A batch class that uses dataclasses to define its fields. + + Attributes: + None + + """ @classmethod def feature_names(cls): + """ + Get the feature names of the dataclass. + + Returns: + List[str]: A list of feature names. + + """ return list(cls.__dataclass_fields__.keys()) def as_dict(self): + """ + Convert the dataclass batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + """ return { feature_name: getattr(self, feature_name) for feature_name in self.feature_names() @@ -64,7 +147,18 @@ class DataclassBatch(BatchBase): @staticmethod def from_schema(name: str, schema): - """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor.""" + """ + Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor. + + Args: + name (str): The name of the custom batch class. + schema: The schema or structure of the batch. + + Returns: + Type[DataclassBatch]: A custom batch class. + + """ + return dataclasses.make_dataclass( cls_name=name, fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names], @@ -73,6 +167,17 @@ class DataclassBatch(BatchBase): @staticmethod def from_fields(name: str, fields: dict): + """ + Create a custom batch subclass from a set of fields. + + Args: + name (str): The name of the custom batch class. + fields (dict): A dictionary specifying the fields and their types. + + Returns: + Type[DataclassBatch]: A custom batch class. + + """ return dataclasses.make_dataclass( cls_name=name, fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()], @@ -81,5 +186,19 @@ class DataclassBatch(BatchBase): class DictionaryBatch(BatchBase, dict): + """ + A batch class that represents data as a dictionary. + + Attributes: + None + + """ def as_dict(self) -> Dict: + """ + Convert the dictionary batch into a dictionary representation. + + Returns: + Dict: A dictionary representation of the batch. + + """ return self diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py index 2703efd..ad1fac1 100644 --- a/common/checkpointing/snapshot.py +++ b/common/checkpointing/snapshot.py @@ -13,23 +13,30 @@ GCS_PREFIX = "gs://" class Snapshot: - """Checkpoints using torchsnapshot. - - Also saves step to be updated by the training loop. - """ + Checkpoints using torchsnapshot. Also saves step to be updated by the training loop. + """ def __init__(self, save_dir: str, state: Dict[str, Any]) -> None: + """ + Initializes a Snapshot object. + + Args: + save_dir (str): Directory where checkpoints will be saved. + state (Dict[str, Any]): State dictionary containing checkpoint information. + """ self.save_dir = save_dir self.state = state self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0) @property def step(self): + """Get the current training step.""" return self.state["extra_state"]["step"] @step.setter def step(self, step: int) -> None: + """Set the current training step.""" self.state["extra_state"]["step"] = step @property @@ -41,7 +48,15 @@ class Snapshot: self.state["extra_state"]["walltime"] = walltime def save(self, global_step: int) -> "PendingSnapshot": - """Saves checkpoint with given global_step.""" + """ + Saves a checkpoint with a given global step. + + Args: + global_step (int): The global step to associate with the checkpoint. + + Returns: + PendingSnapshot: A pending snapshot object. + """ path = os.path.join(self.save_dir, str(global_step)) logging.info(f"Saving snapshot global_step {global_step} to {path}.") start_time = time.time() @@ -58,7 +73,12 @@ class Snapshot: return snapshot def restore(self, checkpoint: str) -> None: - """Restores a given checkpoint.""" + """ + Restores a given checkpoint. + + Args: + checkpoint (str): Path to the checkpoint to restore. + """ snapshot = torchsnapshot.Snapshot(path=checkpoint) logging.info(f"Restoring snapshot from {snapshot.path}.") start_time = time.time() @@ -83,12 +103,17 @@ class Snapshot: global_step: Optional[int] = None, missing_ok: bool = False, ) -> torchsnapshot.Snapshot: - """Get torch stateless snapshot, without actually loading it. - Args: - snapshot_path: path to the model snapshot - global_step: restores from this checkpoint if specified. - missing_ok: if True and checkpoints do not exist, returns without restoration. """ + Get a torch stateless snapshot, without actually loading it. + + Args: + snapshot_path (str): Path to the model snapshot. + global_step (int, optional): Restores from this checkpoint if specified. + missing_ok (bool): If True and checkpoints do not exist, returns without restoration. + + Returns: + torchsnapshot.Snapshot: A torch snapshot object. + """ path = get_checkpoint(snapshot_path, global_step, missing_ok) logging.info(f"Loading snapshot from {path}.") return torchsnapshot.Snapshot(path=path) @@ -100,13 +125,14 @@ class Snapshot: snapshot_emb_name: str, weight_tensor, ) -> None: - """Loads pretrained embedding from the snapshot to the model. - Utilise partial lodaing meachanism from torchsnapshot. - Args: - embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC). - snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC. - weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded. """ + Loads pretrained embedding from the snapshot to the model. + + Args: + embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC). + snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC. + weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded. + """ start_time = time.time() manifest = embedding_snapshot.get_manifest() for path in manifest.keys(): @@ -209,7 +235,22 @@ def get_checkpoint( def get_checkpoints(save_dir: str) -> List[str]: - """Gets all checkpoints that have been fully written.""" + """ + Get a list of fully written checkpoints in the specified directory. + + This function retrieves a list of fully written checkpoints in the given directory. + Checkpoints that are considered fully written include those that have a + corresponding snapshot metadata file. + + Args: + save_dir (str): The directory where checkpoints are stored. + + Returns: + List[str]: A list of fully written checkpoint paths. + + Note: + Checkpoints are sorted by their numeric filenames in ascending order. + """ checkpoints = [] fs = infer_fs(save_dir) if fs.exists(save_dir): @@ -232,6 +273,18 @@ def wait_for_evaluators( global_step: int, timeout: int, ) -> None: + """ + Waits for all evaluators to finish and checks for their completion status. + + Args: + save_dir (str): Directory where checkpoints are saved. + partition_names (List[str]): List of partition names to check for completion. + global_step (int): The global step for which to wait for evaluators. + timeout (int): Maximum time in seconds to wait for evaluators to finish. + + Returns: + None: This function returns nothing but logs the progress and results. + """ logging.info("Waiting for all evaluators to finish.") start_time = time.time() diff --git a/common/device.py b/common/device.py index 33d1bb8..b51c638 100644 --- a/common/device.py +++ b/common/device.py @@ -5,6 +5,15 @@ import torch.distributed as dist def maybe_setup_tensorflow(): + """ + Try to import TensorFlow and disable GPU devices if TensorFlow is available. + + This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch. + + Returns: + None + + """ try: import tensorflow as tf except ImportError: @@ -14,6 +23,18 @@ def maybe_setup_tensorflow(): def setup_and_get_device(tf_ok: bool = True) -> torch.device: + """ + Set up the distributed environment and get the appropriate torch device. + + This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank. + + Args: + tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True. + + Returns: + torch.device: The torch device for the current process. + + """ if tf_ok: maybe_setup_tensorflow() diff --git a/common/filesystem/util.py b/common/filesystem/util.py index 5148a0d..30360f2 100644 --- a/common/filesystem/util.py +++ b/common/filesystem/util.py @@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem() def infer_fs(path: str): + """ + Infer the file system (fs) type based on the given path. + + Args: + path (str): The file path. + + Returns: + str: The inferred file system type ("gs://" for Google Cloud Storage, "hdfs://" for Hadoop Distributed File System, or "local" for local file system). + + Raises: + NotImplementedError: If the path indicates Hadoop Distributed File System (HDFS) which is not yet supported. + """ if path.startswith("gs://"): return GCS_FS elif path.startswith("hdfs://"): @@ -18,8 +30,26 @@ def infer_fs(path: str): def is_local_fs(fs): + """ + Check if the given file system is the local file system. + + Args: + fs (str): The file system type to check. + + Returns: + bool: True if the file system is the local file system, False otherwise. + """ return fs == LOCAL_FS def is_gcs_fs(fs): + """ + Check if the given file system is Google Cloud Storage (GCS). + + Args: + fs (str): The file system type to check. + + Returns: + bool: True if the file system is GCS, False otherwise. + """ return fs == GCS_FS diff --git a/common/log_weights.py b/common/log_weights.py index d2c8e21..cecf42b 100644 --- a/common/log_weights.py +++ b/common/log_weights.py @@ -49,15 +49,25 @@ def log_ebc_norms( ebc_keys: List[str], sample_size: int = 4_000_000, ) -> Dict[str, torch.Tensor]: - """Logs the norms of the embedding tables as specified by ebc_keys. - As of now, log average norm per rank. - - Args: - model_state_dict: model.state_dict() - ebc_keys: list of embedding keys from state_dict to log. Must contain full name, - i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight - sample_size: Limits number of rows per rank to compute average on to avoid OOM. """ + Logs the norms of the embedding tables as specified by ebc_keys. + As of now, log average norm per rank. + + Args: + model_state_dict (dict): The state dictionary of the model. + ebc_keys (List[str]): List of embedding keys from state_dict to log. + Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight". + sample_size (int, optional): Limits the number of rows per rank to compute average on + to avoid out-of-memory (OOM) errors. Defaults to 4,000,000. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables. + The keys are in the format "{emb_key}-norm-{idx}". + + Note: + This function computes and logs the average norm of embedding tables across ranks. + It gathers the norms from all ranks and returns them as a dictionary. + """ norm_logs = dict() for emb_key in ebc_keys: norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}")) diff --git a/common/modules/embedding/embedding.py b/common/modules/embedding/embedding.py index b0a085e..579faa5 100644 --- a/common/modules/embedding/embedding.py +++ b/common/modules/embedding/embedding.py @@ -11,6 +11,20 @@ import numpy as np class LargeEmbeddings(nn.Module): + """ + A module for handling large embeddings. + + Args: + large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings. + + Attributes: + ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings. + surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models. + + Note: + The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving. + + """ def __init__( self, large_embeddings_config: LargeEmbeddingsConfig, @@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module): self, sparse_features: KeyedJaggedTensor, ) -> KeyedTensor: + """ + Forward pass of the LargeEmbeddings module. + + Args: + sparse_features (KeyedJaggedTensor): Sparse input features. + + Returns: + KeyedTensor: The output of the forward pass, a KeyedTensor. + + """ pooled_embs = self.ebc(sparse_features) # a KeyedTensor diff --git a/common/run_training.py b/common/run_training.py index fafb140..20a3d34 100644 --- a/common/run_training.py +++ b/common/run_training.py @@ -11,6 +11,12 @@ import torch.distributed.run def is_distributed_worker(): + """ + Checks if the current process is a distributed worker. + + Returns: + bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False. + """ world_size = os.environ.get("WORLD_SIZE", None) rank = os.environ.get("RANK", None) return world_size is not None and rank is not None @@ -25,25 +31,36 @@ def maybe_run_training( is_chief: Optional[bool] = False, **training_kwargs, ): - """Wrapper function for single node, multi-GPU Pytorch training. - - If the necessary distributed Pytorch environment variables - (WORLD_SIZE, RANK) have been set, then this function executes - `train_fn(**training_kwargs)`. - - Otherwise, this function calls torchrun and points at the calling module - `module_name`. After this call, the necessary environment variables are set - and training will commence. - - Args: - train_fn: The function that is responsible for training - module_name: The name of the module that this function was called from; - used to indicate torchrun entrypoint. - nproc_per_node: Number of workers per node; supported values. - num_nodes: Number of nodes, otherwise inferred from environment. - is_chief: If process is running on chief. - set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH. """ + Wrapper function for single node, multi-GPU PyTorch training. + + If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes + `train_fn(**training_kwargs)`. + + Otherwise, this function calls torchrun and points at the calling module + `module_name`. After this call, the necessary environment variables are set + and training will commence. + + Args: + train_fn (callable): The function responsible for training. + module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint. + nproc_per_node (int, optional): Number of workers per node. Defaults to None. + num_nodes (int, optional): Number of nodes. Defaults to None. + is_chief (bool, optional): If the process is running on the chief node. Defaults to False. + set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False. + **training_kwargs: Additional keyword arguments to pass to the `train_fn`. + + Note: + This function checks if the current process is a distributed worker by examining the environment variables. + If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary + environment variables and launches the training process using torchrun. + + Example: + To run training on a single node with 4 GPUs, you can use: + ``` + maybe_run_training(train_function, __name__, nproc_per_node=4) + ``` + """ machines = utils.machine_from_env() if num_nodes is None: diff --git a/common/testing_utils.py b/common/testing_utils.py index 69e14de..6feaf6b 100644 --- a/common/testing_utils.py +++ b/common/testing_utils.py @@ -19,6 +19,23 @@ MOCK_ENV = { @contextmanager def mock_pg(): + """ + A context manager for mocking the distributed process group for testing purposes. + + This context manager temporarily sets environment variables to mock the distributed process group + and initializes it using the Gloo backend. It is useful for testing distributed training without + actually launching multiple processes. + + Example: + ```python + with mock_pg(): + # Your distributed training code here + ``` + + Note: + This context manager should be used within a testing environment to simulate distributed training + without actually creating multiple processes. + """ with patch.dict(os.environ, MOCK_ENV): try: dist.init_process_group( diff --git a/common/utils.py b/common/utils.py index 09d4464..e509ffb 100644 --- a/common/utils.py +++ b/common/utils.py @@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig) def _read_file(f): + """ + Read the contents of a file using fsspec. + + Args: + f: File path or URL. + + Returns: + The contents of the file. + """ with fsspec.open(f) as f: return f.read() @@ -21,18 +30,24 @@ def setup_configuration( yaml_path: str, substitute_env_variable: bool = False, ) -> Tuple[C, str]: - """Resolves a config at a yaml path. - - Args: - config_type: Pydantic config class to load. - yaml_path: yaml path of the config file. - substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their - environment variable value whenever possible. If an environment variable doesn't exist, - the string is left unchanged. - - Returns: - The pydantic config object. """ + Load a Pydantic config object from a YAML file and optionally substitute environment variables. + + Args: + config_type: Pydantic config class to load. + yaml_path: YAML path of the config file. + substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR} + with their environment variable values whenever possible. + If an environment variable doesn't exist, the string is left unchanged. + + Returns: + A tuple containing the Pydantic config object and the resolved YAML content. + + Example: + ```python + config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True) + ``` + """ def _substitute(s): if substitute_env_variable: diff --git a/common/wandb.py b/common/wandb.py index eb8cceb..9808db0 100644 --- a/common/wandb.py +++ b/common/wandb.py @@ -6,6 +6,33 @@ import pydantic class WandbConfig(base_config.BaseConfig): + """ + Configuration for integrating with Weights and Biases (WandB). + + Attributes: + host (str): Host of the Weights and Biases instance, passed to login. + key_path (str): Path to the key file. + name (str): Name of the experiment, passed to init. + entity (str): Name of the user/service account, passed to init. + project (str): Name of the WandB project, passed to init. + tags (List[str]): List of tags, passed to init. + notes (str): Notes, passed to init. + metadata (Dict[str, Any]): Additional metadata to log. + + Example: + ```python + wandb_config = WandbConfig( + host="https://wandb.example.com", + key_path="/path/to/key", + name="experiment_1", + entity="user123", + project="my_project", + tags=["experiment", "ml"], + notes="This is a test experiment.", + metadata={"version": "1.0"} + ) + ``` + """ host: str = pydantic.Field( "https://https--wandb--prod--wandb.service.qus1.twitter.biz/", description="Host of Weights and Biases instance, passed to login.", diff --git a/core/config/base_config.py b/core/config/base_config.py index 1c15d04..76a2022 100644 --- a/core/config/base_config.py +++ b/core/config/base_config.py @@ -8,59 +8,69 @@ import pydantic class BaseConfig(pydantic.BaseModel): - """Base class for all derived config classes. + """Base class for all derived config classes. - This class provides some convenient functionality: - - Disallows extra fields when constructing an object. User error - should be reduced by exact arguments. - - "one_of" fields. A subclass can group optional fields and enforce - that only one of the fields be set. For example: + This class provides convenient functionality and constraints for derived config classes: - ``` + - Disallows extra fields when constructing an object. User errors due to extraneous arguments + are minimized. + - "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields + be set. For example: + + ```python class ExampleConfig(BaseConfig): - x: int = Field(None, one_of="group_1") - y: int = Field(None, one_of="group_1") + x: int = Field(None, one_of="group_1") + y: int = Field(None, one_of="group_1") - ExampleConfig(x=1) # ok - ExampleConfig(y=1) # ok - ExampleConfig(x=1, y=1) # throws error + ExampleConfig(x=1) # OK + ExampleConfig(y=1) # OK + ExampleConfig(x=1, y=1) # Raises an error ``` - """ - class Config: - """Forbids extras.""" + Attributes: + Config (class): Configuration options for this class, forbidding extra fields. - extra = pydantic.Extra.forbid # noqa + Methods: + _field_data_map(cls, field_data_name): Create a map of fields with the provided field data. + _one_of_check(cls, values): Validate that all 'one of' fields appear exactly once. + _at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once. + pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging. - @classmethod - @functools.lru_cache() - def _field_data_map(cls, field_data_name): - """Create a map of fields with provided the field data.""" - schema = cls.schema() - one_of = collections.defaultdict(list) - for field, fdata in schema["properties"].items(): - if field_data_name in fdata: - one_of[fdata[field_data_name]].append(field) - return one_of + """ + class Config: + """Configuration options that forbid extra fields.""" + extra = pydantic.Extra.forbid # noqa - @pydantic.root_validator - def _one_of_check(cls, values): - """Validate that all 'one of' fields are appear exactly once.""" - one_of_map = cls._field_data_map("one_of") - for one_of, field_names in one_of_map.items(): - if sum([values.get(n, None) is not None for n in field_names]) != 1: - raise ValueError(f"Exactly one of {','.join(field_names)} required.") - return values + @classmethod + @functools.lru_cache() + def _field_data_map(cls, field_data_name): + """Create a map of fields with the provided field data.""" + schema = cls.schema() + one_of = collections.defaultdict(list) + for field, fdata in schema["properties"].items(): + if field_data_name in fdata: + one_of[fdata[field_data_name]].append(field) + return one_of - @pydantic.root_validator - def _at_most_one_of_check(cls, values): - """Validate that all 'at_most_one_of' fields appear at most once.""" - at_most_one_of_map = cls._field_data_map("at_most_one_of") - for one_of, field_names in at_most_one_of_map.items(): - if sum([values.get(n, None) is not None for n in field_names]) > 1: - raise ValueError(f"At most one of {','.join(field_names)} can be set.") - return values + @pydantic.root_validator + def _one_of_check(cls, values): + """Validate that all 'one of' fields appear exactly once.""" + one_of_map = cls._field_data_map("one_of") + for one_of, field_names in one_of_map.items(): + if sum([values.get(n, None) is not None for n in field_names]) != 1: + raise ValueError(f"Exactly one of {','.join(field_names)} required.") + return values - def pretty_print(self) -> str: - """Return a human legible (yaml) representation of the config useful for logging.""" - return yaml.dump(self.dict()) + @pydantic.root_validator + def _at_most_one_of_check(cls, values): + """Validate that all 'at_most_one_of' fields appear at most once.""" + at_most_one_of_map = cls._field_data_map("at_most_one_of") + for one_of, field_names in at_most_one_of_map.items(): + if sum([values.get(n, None) is not None for n in field_names]) > 1: + raise ValueError(f"At most one of {','.join(field_names)} can be set.") + + return values + + def pretty_print(self) -> str: + """Return a human-readable (YAML) representation of the config useful for logging.""" + return yaml.dump(self.dict()) diff --git a/core/config/base_config_test.py b/core/config/base_config_test.py index e623f05..c4319e5 100644 --- a/core/config/base_config_test.py +++ b/core/config/base_config_test.py @@ -6,33 +6,67 @@ import pydantic class BaseConfigTest(TestCase): - def test_extra_forbidden(self): - class Config(BaseConfig): - x: int + """ + Unit tests for the BaseConfig class. + """ - Config(x=1) - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y=2) + def test_extra_forbidden(self): + """ + Test that extra fields are forbidden when creating a Config instance. - def test_one_of(self): - class Config(BaseConfig): - x: int = pydantic.Field(None, one_of="f") - y: int = pydantic.Field(None, one_of="f") + This test case checks whether the `BaseConfig` class correctly raises a + `pydantic.ValidationError` when extra fields are provided when creating a + `Config` instance. - with self.assertRaises(pydantic.ValidationError): - Config() - Config(x=1) - Config(y=1) - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y=3) + Raises: + AssertionError: If the test fails. + """ + class Config(BaseConfig): + x: int - def test_at_most_one_of(self): - class Config(BaseConfig): - x: int = pydantic.Field(None, at_most_one_of="f") - y: str = pydantic.Field(None, at_most_one_of="f") + Config(x=1) + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y=2) - Config() - Config(x=1) - Config(y="a") - with self.assertRaises(pydantic.ValidationError): - Config(x=1, y="a") + def test_one_of(self): + """ + Test the use of the `one_of` attribute for fields in a Config instance. + + This test case checks the behavior of the `one_of` attribute in a `Config` + instance. It verifies that the `pydantic.Field` correctly enforces the + specified constraint. + + Raises: + AssertionError: If the test fails. + """ + class Config(BaseConfig): + x: int = pydantic.Field(None, one_of="f") + y: int = pydantic.Field(None, one_of="f") + + with self.assertRaises(pydantic.ValidationError): + Config() + Config(x=1) + Config(y=1) + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y=3) + + def test_at_most_one_of(self): + """ + Test the use of the `at_most_one_of` attribute for fields in a Config instance. + + This test case checks the behavior of the `at_most_one_of` attribute in a + `Config` instance. It verifies that the `pydantic.Field` enforces the + constraint where at most one of the specified fields can be provided. + + Raises: + AssertionError: If the test fails. + """ + class Config(BaseConfig): + x: int = pydantic.Field(None, at_most_one_of="f") + y: str = pydantic.Field(None, at_most_one_of="f") + + Config() + Config(x=1) + Config(y="a") + with self.assertRaises(pydantic.ValidationError): + Config(x=1, y="a") diff --git a/core/config/config_load.py b/core/config/config_load.py index 709da41..f04c783 100644 --- a/core/config/config_load.py +++ b/core/config/config_load.py @@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str): - """Recommend method to load a config file (a yaml file) and parse it. - - Because we have a shared filesystem the recommended route to running jobs it put modified config - files with the desired parameters somewhere on the filesytem and run jobs pointing to them. """ + Recommend method to Load and parse a configuration from a YAML file. + + This function loads a configuration from a YAML file, parses it, and returns an instance of the + specified config type. + + Because we have a shared filesystem the recommended route to running jobs it put modified config + files with the desired parameters somewhere on the filesytem and run jobs pointing to them. + + Args: + config_type (Type[BaseConfig]): The Pydantic config class to load. + yaml_path (str): The path to the YAML configuration file. + + Returns: + BaseConfig: An instance of the specified config type populated with values from the YAML file. + + Example: + Suppose you have a YAML file 'my_config.yaml' containing the following: + + ```yaml + x: 42 + y: "hello" + ``` + + You can load and parse it using this function as follows: + + ```python + my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml') + ``` + + Note: + This function performs environment variable substitution in the YAML file. It replaces + occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable + values. If an environment variable does not exist, the string is left unchanged. + + """ def _substitute(s): return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) diff --git a/core/config/test_config_load.py b/core/config/test_config_load.py index 0fc88e2..1ad3942 100644 --- a/core/config/test_config_load.py +++ b/core/config/test_config_load.py @@ -8,11 +8,27 @@ import pydantic class _PointlessConfig(BaseConfig): + a: int user: str def test_load_config_from_yaml(tmp_path): + """Test loading a configuration from a YAML file and verifying its values. + + This test function checks the functionality of the `load_config_from_yaml` function by creating + a temporary YAML configuration file, loading it, and asserting that the loaded config object + has the expected values. + + Args: + tmp_path: A temporary directory provided by the `pytest` framework. + + Test Steps: + 1. Create a temporary YAML file containing configuration data. + 2. Use the `load_config_from_yaml` function to load the configuration from the YAML file. + 3. Assert that the loaded configuration object has the expected values. + + """ yaml_path = tmp_path.joinpath("test.yaml").as_posix() with open(yaml_path, "w") as yaml_file: yaml_file.write("""a: 3\nuser: ${USER}\n""") diff --git a/core/config/training.py b/core/config/training.py index a0c54b4..acdbde7 100644 --- a/core/config/training.py +++ b/core/config/training.py @@ -9,6 +9,14 @@ import pydantic class RuntimeConfig(base_config.BaseConfig): + """ + Configuration for runtime settings. + + Attributes: + - wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration. + - enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices. + - enable_amp (bool): Enable automatic mixed precision. + """ wandb: WandbConfig = pydantic.Field(None) enable_tensorfloat32: bool = pydantic.Field( False, description="Use tensorfloat32 if on Ampere devices." @@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig): class TrainingConfig(base_config.BaseConfig): + """ + Configuration for training settings. + + Attributes: + - save_dir (str): Directory to save checkpoints. + - num_train_steps (pydantic.PositiveInt): Number of training steps. + - initial_checkpoint_dir (str): Directory of initial checkpoints (optional). + - checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps. + - checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional). + - train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps. + - num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used. + - eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps. + - eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds. + - gradient_accumulation (int): Number of replica steps to accumulate gradients (optional). + - num_epochs (pydantic.PositiveInt): Number of training epochs. + """ save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.") num_train_steps: pydantic.PositiveInt = 10000 initial_checkpoint_dir: str = pydantic.Field( diff --git a/core/custom_training_loop.py b/core/custom_training_loop.py index 0241145..73b2cf1 100644 --- a/core/custom_training_loop.py +++ b/core/custom_training_loop.py @@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable): def _get_step_fn(pipeline, data_iterator, training: bool): + """ + Returns a function to perform a single evaluation step. + + Args: + pipeline (Pipeline): The pipeline object containing the model. + data_iterator (Iterator): The data iterator for evaluation. + training (bool): Flag indicating if the model should be in training mode. + + Returns: + function: A function that performs a single evaluation step. + """ def step_fn(): + """ + Perform a single evaluation step. + + Returns: + Any: The evaluation results after a single step. + """ # It turns out that model.train() and model.eval() simply switch a single field inside the model # class,so it's somewhat safer to wrap in here. if training: @@ -69,7 +86,21 @@ def _run_evaluation( eval_batch_size: int, logger=None, ): - """Runs the evaluation loop over all evaluation iterators.""" + """ + Run the evaluation loop over all evaluation iterators. + + Args: + pipeline (Pipeline): The pipeline object containing the model. + dataset (Dataset): The dataset to evaluate. + eval_steps (int): The number of evaluation steps to perform. + metrics (tm.MetricCollection): A collection of evaluation metrics. + eval_batch_size (int): Batch size for evaluation. + logger (Optional[Logger]): A logger for recording evaluation progress (default: None). + + Returns: + dict: A dictionary containing the computed evaluation metrics. + """ + dataset = get_new_iterator(dataset) step_fn = _get_step_fn(pipeline, dataset, training=False) last_time = datetime.datetime.now() @@ -109,15 +140,29 @@ def train( parameters_to_log: Optional[Dict[str, Callable]] = None, tables_to_log: Optional[List[str]] = None, ) -> None: - """Runs training and eval on the given TrainPipeline - - Args: - dataset: data iterator for the training set - evaluation_iterators: data iterators for the different evaluation sets - scheduler: optional learning rate scheduler - output_transform_for_metrics: optional transformation functions to transorm the model - output and labels into a format the metrics can understand """ + Runs training and evaluation on the given TrainPipeline. + + Args: + model (torch.nn.Module): The neural network model to train. + optimizer (torch.optim.Optimizer): The optimizer for model optimization. + device (str): The target device for model training (e.g., 'cuda' or 'cpu'). + save_dir (str): The directory to save model checkpoints and logs. + logging_interval (int): Interval for logging training progress. + train_steps (int): The number of training steps to perform. + checkpoint_frequency (int): Frequency of saving model checkpoints. + dataset (Iterable): Data iterator for the training set. + worker_batch_size (int): Batch size for data loading workers. + num_workers (Optional[int]): Number of data loading workers (default: 0). + enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False). + initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None). + gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None). + logger_initializer (Optional[Callable]): A logger initializer function (default: None). + scheduler (_LRScheduler): Optional learning rate scheduler (default: None). + metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None). + parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None). + tables_to_log (Optional[List[str]]): List of tables to log (default: None). + """ train_pipeline = TrainPipelineSparseDist( model=model, @@ -262,6 +307,15 @@ def log_eval_results( partition_name: str, step: int, ): + """ + Logs evaluation results and optionally records them using a provided logger. + + Args: + results (Any): The evaluation results to log. + eval_logger (Callable): A logger for recording evaluation results. + partition_name (str): The name of the evaluation partition. + step (int): The current step in the evaluation. + """ results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results) logging.info(f"Step: {step}, evaluation ({partition_name}).") for metric_name, metric_value in results.items(): @@ -285,6 +339,23 @@ def only_evaluate( partition_name: str, metrics: Optional[tm.MetricCollection] = None, ): + """ + Performs evaluation on a given dataset partition. + + Args: + model (torch.nn.Module): The neural network model for evaluation. + optimizer (torch.optim.Optimizer): The optimizer used during evaluation. + device (str): The target device for evaluation (e.g., 'cuda' or 'cpu'). + save_dir (str): The directory containing model checkpoints. + num_train_steps (int): The total number of training steps. + dataset (Iterable): Data iterator for evaluation. + eval_batch_size (int): Batch size for evaluation. + num_eval_steps (int): The number of evaluation steps to perform. + eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds. + eval_logger (Callable): A logger for recording evaluation results. + partition_name (str): The name of the evaluation partition. + metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None). + """ logging.info(f"Evaluating on partition {partition_name}.") logging.info("Computing metrics:") logging.info(metrics) diff --git a/core/debug_training_loop.py b/core/debug_training_loop.py index 610eea9..bced83f 100644 --- a/core/debug_training_loop.py +++ b/core/debug_training_loop.py @@ -28,6 +28,18 @@ def train( *args, **kwargs, ) -> None: + """ + Debugging training loop. Do not use for actual model training. + + Args: + model (torch.nn.Module): The neural network model. + optimizer (torch.optim.Optimizer): The optimizer for model optimization. + train_steps (int): The number of training steps to perform. + dataset (Iterable): Data iterator for training data. + scheduler (_LRScheduler, optional): Learning rate scheduler (default: None). + *args: Additional arguments (ignored). + **kwargs: Additional keyword arguments (ignored). + """ logging.warning("Running debug training loop, don't use for model training.") diff --git a/core/losses.py b/core/losses.py index 6ef9a4a..7cf0cf1 100644 --- a/core/losses.py +++ b/core/losses.py @@ -10,8 +10,11 @@ import torch def _maybe_warn(reduction: str): """ - Warning for reduction different than mean. - """ + Emit a warning if the reduction method is different from 'mean'. + + Args: + reduction (str): The reduction method being used. + """ if reduction != "mean": logging.warn( f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal," @@ -24,6 +27,16 @@ def build_loss( loss_type: LossType, reduction="mean", ): + """ + Build a loss function based on the specified loss type and reduction method. + + Args: + loss_type (LossType): The type of loss to build. + reduction (str): The reduction method for the loss (default: 'mean'). + + Returns: + Callable: A loss function that takes logits and labels as input. + """ _maybe_warn(reduction) f = _LOSS_TYPE_TO_FUNCTION[loss_type] @@ -35,11 +48,15 @@ def build_loss( def get_global_loss_detached(local_loss, reduction="mean"): """ - Perform all_reduce to obtain the global loss function using the provided reduction. - :param local_loss: The local loss of the current rank. - :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP. - :return: The reduced & detached global loss. - """ + Perform all_reduce to obtain the global loss function using the provided reduction. + + Args: + local_loss (torch.Tensor): The local loss of the current rank. + reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP. + + Returns: + torch.Tensor: The reduced and detached global loss. + """ if reduction != "mean": logging.warn( f"The reduction used in this function should be the same as the one used by " @@ -66,6 +83,19 @@ def build_multi_task_loss( global_reduction="mean", pos_weights=None, ): + """ + Build a multi-task loss function based on the specified loss type and configurations. + + Args: + loss_type (LossType): The type of loss to build. + tasks (typing.List[str]): List of task names. + task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean'). + global_reduction (str): Reduction method for the global loss (default: 'mean'). + pos_weights (Optional): Positive class weights for tasks (default: None). + + Returns: + Callable: A multi-task loss function that takes logits, labels, and weights as input. + """ _maybe_warn(global_reduction) _maybe_warn(task_loss_reduction) f = _LOSS_TYPE_TO_FUNCTION[loss_type] diff --git a/core/metric_mixin.py b/core/metric_mixin.py index a716ca7..def38cb 100644 --- a/core/metric_mixin.py +++ b/core/metric_mixin.py @@ -36,9 +36,24 @@ import torchmetrics class MetricMixin: @abstractmethod def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: + """ + Abstract method to transform model outputs into a dictionary of metrics. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + + Returns: + Dict: A dictionary of computed metrics. + """ ... def update(self, outputs: Dict[str, torch.Tensor]): + """ + Update the metrics based on model outputs. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + """ results = self.transform(outputs) # Do not try to update if any tensor is empty as a result of stratification. for value in results.values(): @@ -49,6 +64,13 @@ class MetricMixin: class TaskMixin: def __init__(self, task_idx: int = -1, **kwargs): + """ + Initialize a TaskMixin instance. + + Args: + task_idx (int): Index of the task associated with this mixin (default: -1). + **kwargs: Additional keyword arguments. + """ super().__init__(**kwargs) self._task_idx = task_idx @@ -59,13 +81,31 @@ class StratifyMixin: stratifier=None, **kwargs, ): + """ + Initialize a StratifyMixin instance. + + Args: + stratifier: A stratifier for filtering outputs (default: None). + **kwargs: Additional keyword arguments. + """ super().__init__(**kwargs) self._stratifier = stratifier def maybe_apply_stratification( self, outputs: Dict[str, torch.Tensor], value_names: List[str] ) -> Dict[str, torch.Tensor]: - """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" + """ + Apply stratification to filter examples in the outputs. + + Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + value_names (List[str]): Names of values to filter. + + Returns: + Dict[str, torch.Tensor]: Filtered outputs. + """ outputs = outputs.copy() if not self._stratifier: return outputs @@ -84,12 +124,20 @@ class StratifyMixin: def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable): - """Returns new class using MetricMixin and given base_metric. - - Functionally the same using inheritance, just saves some lines of code - if no need for class attributes. - """ + Returns a new class using MetricMixin and the given base_metric. + + Functionally the same as using inheritance, but it saves some lines of code + if there's no need for class attributes. + + Args: + base_metric (torchmetrics.Metric): The base metric class to prepend the transform to. + transform (Callable): The transformation function to prepend to the metric. + + Returns: + Type: A new class that includes MetricMixin and the provided base_metric + with the specified transformation method. + """ def transform_method(_self, *args, **kwargs): return transform(*args, **kwargs) diff --git a/core/metrics.py b/core/metrics.py index 2384e4d..0a9c38b 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -15,6 +15,16 @@ def probs_and_labels( outputs: Dict[str, torch.Tensor], task_idx: int, ) -> Dict[str, torch.Tensor]: + """ + Extract probabilities and labels from model outputs. + + Args: + outputs (Dict[str, torch.Tensor]): Model outputs. + task_idx (int): Index of the task. + + Returns: + Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors. + """ preds = outputs["probabilities"] target = outputs["labels"] if task_idx >= 0: @@ -28,6 +38,11 @@ def probs_and_labels( class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): def transform(self, outputs): + """ + Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric. + + This metric counts values after potential stratification and task selection. + """ outputs = self.maybe_apply_stratification(outputs, ["labels"]) value = outputs["labels"] if self._task_idx >= 0: @@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): + """ + Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric. + + This metric calculates the mean metric value after potential stratification and task selection. + """ + def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["labels"]) value = outputs["labels"] @@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): + """ + Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric. + + This metric calculates the mean metric value using probabilities after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities"]) value = outputs["probabilities"] @@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): + """ + Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision. + + This metric computes precision after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) return probs_and_labels(outputs, self._task_idx) class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): + """ + Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall. + + This metric computes recall after potential stratification and task selection. + """ def transform(self, outputs): outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"]) return probs_and_labels(outputs, self._task_idx) @@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC): class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ + AUC (Area Under the ROC Curve) metric class. + + This metric computes the AUC metric based on the logits and labels in the model outputs. + + Args: + num_samples (int): The number of samples used to compute AUC. + **kwargs: Additional keyword arguments. + Based on: https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420 """ @@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The ranks of all positives - Based on: + PosRanks metric class. + + This metric computes the ranks of all positive examples based on the logits and labels + in the model outputs. + + Args: + **kwargs: Additional keyword arguments. + https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73 """ @@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The reciprocal of the ranks of all - Based on: + ReciprocalRank metric class. + + This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels + in the model outputs. + + Args: + **kwargs: Additional keyword arguments. https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74 """ @@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): """ - The fraction of positives that rank in the top K among their negatives - Note that this is basically precision@k - Based on: + HitAtK metric class. + + This metric computes the fraction of positive examples that rank in the top K among their negatives, + which is equivalent to precision@K. + + Args: + k (int): The value of K. + **kwargs: Additional keyword arguments. https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75 """ diff --git a/core/test_metrics.py b/core/test_metrics.py index ac29819..c9a34e8 100644 --- a/core/test_metrics.py +++ b/core/test_metrics.py @@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric @dataclass class MockStratifierConfig: + """ + Configuration dataclass for mocking a stratifier. + + Args: + name (str): The name of the stratifier. + index (int): The index of the stratifier. + value (int): The value of the stratifier. + """ name: str index: int value: int class Count(MetricMixin, SumMetric): + """ + Count metric class that inherits from MetricMixin and SumMetric. + + This metric counts occurrences. + + """ def transform(self, outputs): return {"value": 1} @@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]}) def test_count_metric(): + """ + Test function for the Count metric. + + It checks if the Count metric correctly counts the number of examples. + + """ num_examples = 123 examples = [ {"stuff": 0}, @@ -36,6 +56,12 @@ def test_count_metric(): def test_collections(): + """ + Test function for metric collections. + + It tests if metric collections correctly aggregate metrics. + + """ max_metric = Max() count_metric = Count() metric = MetricCollection([max_metric, count_metric]) @@ -51,6 +77,12 @@ def test_collections(): def test_task_dependent_ctr(): + """ + Test function for task-dependent Ctr (Click-Through Rate) metric. + + It checks if the Ctr metric computes the correct value for different tasks. + + """ num_examples = 144 batch_size = 1024 outputs = [ @@ -69,6 +101,13 @@ def test_task_dependent_ctr(): def test_stratified_ctr(): + """ + Test function for the Stratified Ctr (Click-Through Rate) metric. + + It checks if the Stratified Ctr metric computes the correct value for different tasks + and stratified samples. + + """ outputs = [ { "stuff": 0, @@ -114,6 +153,12 @@ def test_stratified_ctr(): def test_auc(): + """ + Test function for the AUC (Area Under the Curve) metric. + + It checks if the AUC metric correctly computes the Area Under the ROC Curve. + + """ num_samples = 10000 metric = core_metrics.Auc(num_samples) target = torch.tensor([0, 0, 1, 1, 1]) @@ -131,6 +176,12 @@ def test_auc(): def test_pos_rank(): + """ + Test function for the PosRanks metric. + + It checks if the PosRanks metric correctly computes the ranks of positive samples. + + """ metric = core_metrics.PosRanks() target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) @@ -147,6 +198,12 @@ def test_pos_rank(): def test_reciprocal_rank(): + """ + Test function for the Reciprocal Rank metric. + + It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks. + + """ metric = core_metrics.ReciprocalRank() target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5]) @@ -163,6 +220,12 @@ def test_reciprocal_rank(): def test_hit_k(): + """ + Test function for the Hit@K metric. + + It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives. + + """ hit1_metric = core_metrics.HitAtK(1) target = torch.tensor([0, 0, 1, 1, 1]) preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5]) diff --git a/core/test_train_pipeline.py b/core/test_train_pipeline.py index 8e2f6f5..ee9f401 100644 --- a/core/test_train_pipeline.py +++ b/core/test_train_pipeline.py @@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel @dataclass class MockDataclassBatch(DataclassBatch): + """ + Mock data class batch for testing purposes. + + This class represents a batch of data with continuous features and labels. + + Attributes: + continuous_features (torch.Tensor): Tensor containing continuous feature data. + labels (torch.Tensor): Tensor containing label data. + """ continuous_features: torch.Tensor labels: torch.Tensor class MockModule(torch.nn.Module): + """ + Mock PyTorch module for testing purposes. + + This module defines a simple neural network model with a linear layer + followed by a BCEWithLogitsLoss loss function. + + Attributes: + model (torch.nn.Linear): The linear model layer. + loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function. + """ def __init__(self) -> None: super().__init__() self.model = torch.nn.Linear(10, 1) self.loss_fn = torch.nn.BCEWithLogitsLoss() def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the mock module. + + Args: + batch (MockDataclassBatch): Input data batch with continuous features and labels. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions. + """ pred = self.model(batch.continuous_features) loss = self.loss_fn(pred, batch.labels) return (loss, pred) def create_batch(bsz: int): + """ + Create a mock data batch with random continuous features and labels. + + Args: + bsz (int): Batch size. + + Returns: + MockDataclassBatch: A batch of data with continuous features and labels. + """ return MockDataclassBatch( continuous_features=torch.rand(bsz, 10).float(), labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(), @@ -35,6 +72,13 @@ def create_batch(bsz: int): def test_sparse_pipeline(): + """ + Test function for the sparse pipeline with distributed model parallelism. + + This function tests the behavior of the sparse training pipeline using + a mock module and data. + """ + device = torch.device("cpu") model = MockModule().to(device) @@ -65,6 +109,15 @@ def test_sparse_pipeline(): def test_amp(): + """ + Test automatic mixed-precision (AMP) training with the sparse pipeline. + + This function tests the behavior of the sparse training pipeline with + automatic mixed-precision (AMP) enabled, using a mock module and data. + + AMP allows for faster training by using lower-precision data types, such as + torch.bfloat16, while maintaining model accuracy. + """ device = torch.device("cpu") model = MockModule().to(device) diff --git a/core/train_pipeline.py b/core/train_pipeline.py index cde587e..3209988 100644 --- a/core/train_pipeline.py +++ b/core/train_pipeline.py @@ -39,12 +39,42 @@ Out = TypeVar("Out") class TrainPipeline(abc.ABC, Generic[In, Out]): + """ + Abstract base class for training pipelines. + + Attributes: + In (TypeVar): Input data type. + Out (TypeVar): Output data type. + + Methods: + progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline. + """ @abc.abstractmethod def progress(self, dataloader_iter: Iterator[In]) -> Out: + """ + Make progress in the training pipeline. + + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + + Returns: + Out: The output data. + """ pass def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: + """ + Move a batch of data to a specified device. + + Args: + batch (In): The input batch. + device (torch.device): The target device. + non_blocking (bool): If True, move the data asynchronously. + + Returns: + In: The batch of data on the target device. + """ assert isinstance( batch, (torch.Tensor, Pipelineable) ), f"{type(batch)} must implement Pipelineable interface" @@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: + """ + Wait for a batch of data on a specified stream. + + Args: + batch (In): The input batch. + stream (Optional[Stream]): The CUDA stream to wait for. + + Note: + This function is used for managing asynchronous CUDA operations. + """ if stream is None: return torch.cuda.current_stream().wait_stream(stream) @@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N class TrainPipelineBase(TrainPipeline[In, Out]): """ - This class runs training iterations using a pipeline of two stages, each as a CUDA - stream, namely, the current (default) stream and `self._memcpy_stream`. For each - iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU - memory, and the default stream runs forward, backward, and optimization. - """ + This class runs training iterations using a pipeline of two stages, each as a CUDA + stream, namely, the current (default) stream and `self._memcpy_stream`. For each + iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU + memory, and the default stream runs forward, backward, and optimization. + + Attributes: + In (TypeVar): Input data type. + Out (TypeVar): Output data type. + + Methods: + __init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None: + Initialize the TrainPipelineBase. + + _connect(dataloader_iter: Iterator[In]) -> None: + Establish a connection to the data loader and move the input data to the GPU. + + progress(dataloader_iter: Iterator[In]) -> Out: + Execute a training iteration, including forward and backward passes. + + """ def __init__( self, @@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]): optimizer: torch.optim.Optimizer, device: torch.device, ) -> None: + """ + Initialize the TrainPipelineBase. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + optimizer (torch.optim.Optimizer): The optimizer used for training. + device (torch.device): The target device for training (CPU or GPU). + """ self._model = model self._optimizer = optimizer self._device = device @@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]): self._connected = False def _connect(self, dataloader_iter: Iterator[In]) -> None: + """ + Establish a connection to the data loader and move the input data to the GPU. + + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + """ cur_batch = next(dataloader_iter) self._cur_batch = cur_batch with torch.cuda.stream(self._memcpy_stream): @@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]): self._connected = True def progress(self, dataloader_iter: Iterator[In]) -> Out: + """ + Execute a training iteration, including forward and backward passes. + + Args: + dataloader_iter (Iterator[In]): An iterator over input data. + + Returns: + Out: The output data. + """ if not self._connected: self._connect(dataloader_iter) @@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]): class Tracer(torch.fx.Tracer): + """ + Custom tracer class for PyTorch models. + + This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings. + + Attributes: + proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing. + _leaf_modules (List[str]): List of qualified names of leaf modules. + """ + # Disable proxying buffers during tracing. Ideally, proxying buffers would # be disabled, but some models are currently mutating buffer values, which # causes errors during tracing. If those models can be rewritten to not do @@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer): proxy_buffer_attributes = False def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: + """ + Initialize the Tracer. + + Args: + leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing. + """ super().__init__() self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + Check if a module is a leaf module during tracing. + + Args: + m (torch.nn.Module): The PyTorch module. + module_qualified_name (str): The qualified name of the module. + + Returns: + bool: True if the module is considered a leaf module, False otherwise. + """ if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules: return True return super().is_leaf_module(m, module_qualified_name) @@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer): @dataclass class TrainPipelineContext: + """ + Dataclass to store information related to the training pipeline context. + + Attributes: + input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests. + module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts. + feature_processor_forwards (List[Any]): A list of feature processor forwards. + """ + # pyre-ignore [4] input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) module_contexts: Dict[str, Multistreamable] = field(default_factory=dict) @@ -166,6 +279,14 @@ class TrainPipelineContext: @dataclass class ArgInfo: + """ + Dataclass to store information about arguments in the training pipeline. + + Attributes: + input_attrs (List[str]): List of attribute names of the input batch. + is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem. + name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments. + """ # attributes of input batch, e.g. batch.attr1.attr2 call # will produce ["attr1", "attr2"] input_attrs: List[str] @@ -177,6 +298,16 @@ class ArgInfo: class PipelinedForward: + """ + Represents a pipelined forward pass operation. + + Attributes: + name (str): The name of the forward pass. + args (List[ArgInfo]): List of argument information for the forward pass. + module (ShardedModule): The sharded module associated with the forward pass. + context (TrainPipelineContext): The training pipeline context. + dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing. + """ def __init__( self, name: str, @@ -185,6 +316,16 @@ class PipelinedForward: context: TrainPipelineContext, dist_stream: Optional[torch.cuda.streams.Stream], ) -> None: + """ + Initialize a PipelinedForward instance. + + Args: + name (str): The name of the forward pass. + args (List[ArgInfo]): List of argument information for the forward pass. + module (ShardedModule): The sharded module associated with the forward pass. + context (TrainPipelineContext): The training pipeline context. + dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing. + """ self._name = name self._args = args self._module = module @@ -193,6 +334,16 @@ class PipelinedForward: # pyre-ignore [2, 24] def __call__(self, *input, **kwargs) -> Awaitable: + """ + Perform the pipelined forward pass operation. + + Args: + *input: Variable-length positional arguments. + **kwargs: Variable-length keyword arguments. + + Returns: + Awaitable: An awaitable object representing the forward pass result. + """ assert self._name in self._context.input_dist_requests request = self._context.input_dist_requests[self._name] assert isinstance(request, Awaitable) @@ -230,10 +381,22 @@ class PipelinedForward: @property def name(self) -> str: + """ + Get the name of the forward pass. + + Returns: + str: The name of the forward pass. + """ return self._name @property def args(self) -> List[ArgInfo]: + """ + Get the list of argument information for the forward pass. + + Returns: + List[ArgInfo]: List of argument information. + """ return self._args @@ -242,6 +405,17 @@ def _start_data_dist( batch: In, context: TrainPipelineContext, ) -> None: + """ + Start data distribution for a list of pipelined modules. + + Args: + pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules. + batch (In): The input batch. + context (TrainPipelineContext): The training pipeline context. + + Returns: + None: This function doesn't return a value. + """ context.input_dist_requests.clear() context.module_contexts.clear() for module in pipelined_modules: @@ -286,9 +460,17 @@ def _get_node_args_helper( feature_processor_arguments: Optional[List[Node]] = None, ) -> Tuple[List[ArgInfo], int]: """ - Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. - It also counts the number of (args + kwargs) found. - """ + Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. + It also counts the number of (args + kwargs) found. + + Args: + arguments: The arguments to process. + num_found: The current count of arguments found. + feature_processor_arguments: Optional list of feature processor arguments. + + Returns: + Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found. + """ arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))] for arg, arg_info in zip(arguments, arg_info_list): @@ -332,6 +514,16 @@ def _get_node_args_helper( def _get_node_args( node: Node, feature_processor_nodes: Optional[List[Node]] = None ) -> Tuple[List[ArgInfo], int]: + """ + Get argument information for a given node. + + Args: + node (Node): The node to process. + feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes. + + Returns: + Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found. + """ num_found = 0 pos_arg_info_list, num_found = _get_node_args_helper( node.args, num_found, feature_processor_nodes @@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper( path: str, unsharded_module_names: Set[str], ) -> bool: + """ + Get the names of unsharded modules in a model. + + Args: + model (torch.nn.Module): The model to analyze. + path (str): The current path in the model hierarchy. + unsharded_module_names (Set[str]): A set to store the names of unsharded modules. + + Returns: + bool: True if any sharded modules were found in the hierarchy, False otherwise. + """ sharded_children = set() for name, child in model.named_children(): curr_path = path + name @@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper( def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]: """ - Returns a list of top level modules do not contain any sharded sub modules. - """ + Returns a list of top-level modules that do not contain any sharded sub-modules. + + Args: + model (torch.nn.Module): The model to analyze. + + Returns: + List[str]: A list of top-level module names without sharded sub-modules. + """ unsharded_module_names: Set[str] = set() _get_unsharded_module_names_helper( @@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901 context: TrainPipelineContext, dist_stream: Optional[torch.cuda.streams.Stream], ) -> List[ShardedModule]: + """ + Rewrites the model to enable pipelined execution for selected sharded modules. + + This function traces the input model using a custom tracer and identifies sharded modules + that can be pipelined. It then creates PipelinedForward objects for these modules, + which enable pipelining during training. + + Args: + model (torch.nn.Module): The input model to be rewritten. + context (TrainPipelineContext): The context containing information needed for pipelining. + dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution. + + Returns: + List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution. + """ # Get underlying nn.Module if isinstance(model, DistributedModelParallel): @@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901 class TrainPipelineSparseDist(TrainPipeline[In, Out]): """ This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with - forward and backward. This helps hide the all2all latency while preserving the - training forward / backward ordering. + forward and backward. This helps hide the all2all latency while preserving the + training forward / backward ordering. - stage 3: forward, backward - uses default CUDA stream - stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream - stage 1: device transfer - uses memcpy CUDA stream + stage 3: forward, backward - uses default CUDA stream + stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream + stage 1: device transfer - uses memcpy CUDA stream - `ShardedModule.input_dist()` is only done for top-level modules in the call graph. - To be considered a top-level module, a module can only depend on 'getattr' calls on - input. + `ShardedModule.input_dist()` is only done for top-level modules in the call graph. + To be considered a top-level module, a module can only depend on 'getattr' calls on + input. - Input model must be symbolically traceable with the exception of `ShardedModule` and - `DistributedDataParallel` modules. - """ + Input model must be symbolically traceable with the exception of `ShardedModule` and + `DistributedDataParallel` modules. + + Args: + model (torch.nn.Module): The input model to be used for training. + optimizer (torch.optim.Optimizer): The optimizer for updating model parameters. + device (torch.device): The device where training will be performed. + enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False. + enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True. + grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None. + + Attributes: + synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines. + + """ synced_pipeline_id: Dict[int, int] = {} @@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): enable_grad_scaling: bool = True, grad_accum: Optional[int] = None, ) -> None: + """ + Initializes the training pipeline. + + Args: + model (torch.nn.Module): The input model to be used for training. + optimizer (torch.optim.Optimizer): The optimizer for updating model parameters. + device (torch.device): The device where training will be performed. + enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False. + enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True. + grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None. + """ self._model = model self._optimizer = optimizer self._device = device @@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): self._grad_accum = grad_accum def _connect(self, dataloader_iter: Iterator[In]) -> None: + """ + Connects the training pipeline to data and prepares for forward and backward passes. + + Args: + dataloader_iter (Iterator[In]): An iterator providing input data batches. + """ + # batch 1 with torch.cuda.stream(self._memcpy_stream): batch_i = next(dataloader_iter) @@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): def progress(self, dataloader_iter: Iterator[In]) -> Out: """ - NOTE: This method has been updated to perform gradient accumulation. - If `_grad_accum` is set, then loss values are scaled by this amount and - optimizer update/reset is skipped for `_grad_accum` calls of `progress` - (congruent to training steps), and then update/reset on every `_grad_accum`th - step. + Progresses through the training pipeline, performing forward and backward passes. - """ + NOTE: This method has been updated to perform gradient accumulation. + If `_grad_accum` is set, then loss values are scaled by this amount and + optimizer update/reset is skipped for `_grad_accum` calls of `progress` + (congruent to training steps), and then update/reset on every `_grad_accum`th + step. + + Args: + dataloader_iter (Iterator[In]): An iterator providing input data batches. + + Returns: + Out: The output of the forward pass. + """ should_step_optimizer = ( self._grad_accum is not None and self._progress_calls > 0 @@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): def _sync_pipeline(self) -> None: """ - Syncs `PipelinedForward` for sharded modules with context and dist stream of the - current train pipeline. Used when switching between train pipelines for the same - model. + Syncs `PipelinedForward` for sharded modules with context and dist stream of the + current train pipeline. Used when switching between train pipelines for the same + model. """ for module in self._pipelined_modules: module.forward._context = self._context diff --git a/machines/environment.py b/machines/environment.py index c7319d6..a0bfb98 100644 --- a/machines/environment.py +++ b/machines/environment.py @@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222 def on_kf(): - return "SPEC_TYPE" in os.environ + """Check if the code is running on Kubernetes with Kubeflow (KF) environment. + + Returns: + bool: True if running on KF, False otherwise. + """ + return "SPEC_TYPE" in os.environ def has_readers(): - if on_kf(): - machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) - return machines_config_env["dataset_worker"] is not None - return os.environ.get("HAS_READERS", "False") == "True" + """Check if the current task has dataset workers. + + Returns: + bool: True if the task has dataset workers, False otherwise. + """ + if on_kf(): + machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) + return machines_config_env.get("dataset_worker") is not None + return os.environ.get("HAS_READERS", "False") == "True" def get_task_type(): - if on_kf(): - return os.environ["SPEC_TYPE"] - return os.environ["TASK_TYPE"] + """Get the type of the current task. + + Returns: + str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'. + """ + if on_kf(): + return os.environ["SPEC_TYPE"] + return os.environ["TASK_TYPE"] def is_chief() -> bool: - return get_task_type() == "chief" + """Check if the current task is the 'chief'. + + Returns: + bool: True if the current task is the 'chief', False otherwise. + """ + return get_task_type() == "chief" def is_reader() -> bool: - return get_task_type() == "datasetworker" + """Check if the current task is a 'datasetworker'. + + Returns: + bool: True if the current task is a 'datasetworker', False otherwise. + """ + return get_task_type() == "datasetworker" def is_dispatcher() -> bool: - return get_task_type() == "datasetdispatcher" + """Check if the current task is a 'datasetdispatcher'. + + Returns: + bool: True if the current task is a 'datasetdispatcher', False otherwise. + """ + return get_task_type() == "datasetdispatcher" def get_task_index(): - if on_kf(): - pod_name = os.environ["MY_POD_NAME"] - return int(pod_name.split("-")[-1]) - else: - raise NotImplementedError + """Get the index of the current task. + + Returns: + int: Task index. + Raises: + NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment. + """ + if on_kf(): + pod_name = os.environ["MY_POD_NAME"] + return int(pod_name.split("-")[-1]) + else: + raise NotImplementedError def get_reader_port(): - if on_kf(): - return KF_DDS_PORT - return SLURM_DDS_PORT + """Get the port used by readers. + + Returns: + int: Reader port. + """ + if on_kf(): + return KF_DDS_PORT + return SLURM_DDS_PORT def get_dds(): - if not has_readers(): - return None - dispatcher_address = get_dds_dispatcher_address() - if dispatcher_address: - return f"grpc://{dispatcher_address}" - else: - raise ValueError("Job does not have DDS.") + """Get the Distributed Data Service (DDS) address. + + Returns: + str: DDS address in the format 'grpc://host:port'. + Raises: + ValueError: If the job does not have DDS. + """ + if not has_readers(): + return None + dispatcher_address = get_dds_dispatcher_address() + if dispatcher_address: + return f"grpc://{dispatcher_address}" + else: + raise ValueError("Job does not have DDS.") def get_dds_dispatcher_address(): - if not has_readers(): - return None - if on_kf(): - job_name = os.environ["JOB_NAME"] - dds_host = f"{job_name}-datasetdispatcher-0" - else: - dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"] - return f"{dds_host}:{get_reader_port()}" + """Get the DDS dispatcher address. + + Returns: + str: DDS dispatcher address in the format 'host:port'. + """ + if not has_readers(): + return None + if on_kf(): + job_name = os.environ["JOB_NAME"] + dds_host = f"{job_name}-datasetdispatcher-0" + else: + dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"] + return f"{dds_host}:{get_reader_port()}" def get_dds_worker_address(): - if not has_readers(): - return None - if on_kf(): - job_name = os.environ["JOB_NAME"] - task_index = get_task_index() - return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}" - else: - node = os.environ["SLURMD_NODENAME"] - return f"{node}:{get_reader_port()}" + """Get the DDS worker address. + + Returns: + str: DDS worker address in the format 'host:port'. + """ + if not has_readers(): + return None + if on_kf(): + job_name = os.environ["JOB_NAME"] + task_index = get_task_index() + return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}" + else: + node = os.environ["SLURMD_NODENAME"] + return f"{node}:{get_reader_port()}" def get_num_readers(): - if not has_readers(): - return 0 - if on_kf(): - machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) - return int(machines_config_env["num_dataset_workers"] or 0) - return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) + """Get the number of dataset workers. + + Returns: + int: Number of dataset workers. + """ + if not has_readers(): + return 0 + if on_kf(): + machines_config_env = json.loads(os.environ["MACHINES_CONFIG"]) + return int(machines_config_env.get("num_dataset_workers") or 0) + return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(",")) def get_flight_server_addresses(): - if on_kf(): - job_name = os.environ["JOB_NAME"] - return [ - f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}" - for task_index in range(get_num_readers()) - ] - else: - raise NotImplementedError + """Get Flight server addresses for dataset workers. + + Returns: + List[str]: List of Flight server addresses in the format 'grpc://host:port'. + Raises: + NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment. + """ + if on_kf(): + job_name = os.environ["JOB_NAME"] + return [ + f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}" + for task_index in range(get_num_readers()) + ] + else: + raise NotImplementedError def get_dds_journaling_dir(): - return os.environ.get("DATASET_JOURNALING_DIR", None) + """Get the DDS journaling directory. + + Returns: + str: DDS journaling directory. + """ + return os.environ.get("DATASET_JOURNALING_DIR", None) diff --git a/metrics/aggregation.py b/metrics/aggregation.py index 57f75d5..13a0e09 100644 --- a/metrics/aggregation.py +++ b/metrics/aggregation.py @@ -14,18 +14,21 @@ def update_mean( weight: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Update the mean according to Welford formula: - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. + Update the mean according to the Welford formula. + + This function updates the mean and the weighted sum of values using the Welford algorithm. + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. See also https://nullbuffer.com/articles/welford_algorithm.html for more information. - Args: - current_mean: The value of the current accumulated mean. - current_weight_sum: The current weighted sum. - value: The new value that needs to be added to get a new mean. - weight: The weights for the new value. - Returns: The updated mean and updated weighted sum. + Args: + current_mean (torch.Tensor): The value of the current accumulated mean. + current_weight_sum (torch.Tensor): The current weighted sum. + value (torch.Tensor): The new value that needs to be added to get a new mean. + weight (torch.Tensor): The weights for the new value. - """ + Returns: + Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum. + """ weight = torch.broadcast_to(weight, value.shape) # Avoiding (on purpose) in-place operation when using += in case @@ -38,12 +41,15 @@ def update_mean( def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: """ Merge the state from multiple workers. - Args: - state: A tensor with the first dimension indicating workers. - Returns: The accumulated mean from all workers. + This function merges the state from multiple workers to compute the accumulated mean. - """ + Args: + state (torch.Tensor): A tensor with the first dimension indicating workers. + + Returns: + torch.Tensor: The accumulated mean from all workers. + """ mean, weight_sum = update_mean( current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device), @@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: class StableMean(torchmetrics.Metric): """ - This implements a numerical stable mean metrics computation using Welford algorithm according to - https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. + A numerical stable mean metric using the Welford algorithm. + + This class implements a numerical stable mean metrics computation using the Welford algorithm. + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version. For example when using float32, the algorithm will give a valid output even if the "sum" is larger than the maximum float32 as far as the mean is within the limit of float32. See also https://nullbuffer.com/articles/welford_algorithm.html for more information. - """ + + Args: + **kwargs: Additional parameters supported by all torchmetrics.Metric. + + Attributes: + mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum. + """ def __init__(self, **kwargs): """ @@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric): ) def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None: - """ - Update the current mean. - Args: - value: Value to update the mean with. - weight: weight to use. Shape should be broadcastable to that of value. - """ + """Update the current mean. + + Args: + value (torch.Tensor): Value to update the mean with. + weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value. + """ mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] if not isinstance(weight, torch.Tensor): @@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric): ) def compute(self) -> torch.Tensor: - """ - Compute and return the accumulated mean. - """ + """Compute and return the accumulated mean. + + Returns: + torch.Tensor: The accumulated mean. + """ return self.mean_and_weight_sum[0] diff --git a/metrics/auroc.py b/metrics/auroc.py index 6979c20..66e39f2 100644 --- a/metrics/auroc.py +++ b/metrics/auroc.py @@ -29,6 +29,9 @@ def _compute_helper( equal_predictions_as_incorrect: For positive & negative labels having identical scores, we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise, we assume that they are correct prediction (i.e weight = 0). + + Returns: + torch.Tensor: The computed AUROC """ dim = 0 @@ -52,24 +55,34 @@ def _compute_helper( class AUROCWithMWU(torchmetrics.Metric): """ - AUROC using Mann-Whitney U-test. - See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve. + AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test. - This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return - the correct AUROC even if the predicted probabilities are all close to 0. - Currently only support binary classification. - """ + This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate) + scenarios. It returns the correct AUROC even when predicted probabilities are close to 0. + + See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve. + + Note: Currently, this implementation only supports binary classification. + + Args: + label_threshold (float): Threshold for classifying labels as positive or negative. + Labels above this threshold are considered positive, and those below are considered negative. + raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing. + Otherwise, a warning is logged, and AUROC is computed. + **kwargs: Additional parameters supported by all torchmetrics.Metric. + """ def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs): """ + Initializes the AUROCWithMWU metric. - Args: - label_threshold: Labels strictly above this threshold are considered positive labels, - otherwise, they are considered negative. - raise_missing_class: If True, an error will be raise if negative or positive class is missing. - Otherwise, we will simply log a warning. - **kwargs: Additional parameters supported by all torchmetrics.Metric. - """ + Args: + label_threshold (float): Threshold for classifying labels as positive or negative. + Labels above this threshold are considered positive, and those below are considered negative. + raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing. + Otherwise, a warning is logged, and AUROC is computed. + **kwargs: Additional parameters supported by all torchmetrics.Metric. + """ super().__init__(**kwargs) self.add_state("predictions", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") diff --git a/ml_logging/absl_logging.py b/ml_logging/absl_logging.py index e4d9e9e..07be8bb 100644 --- a/ml_logging/absl_logging.py +++ b/ml_logging/absl_logging.py @@ -14,7 +14,24 @@ from absl import logging as logging def setup_absl_logging(): - """Make sure that absl logging pushes to stdout rather than stderr.""" + """ + Configure absl-py logging to direct log messages to stdout and apply a custom log message format. + + This function ensures that log messages generated by the absl-py library are written to stdout + rather than stderr. It also applies a custom log message format that includes module, function, + line number, log level, and the log message content. + + Note: + This function should be called once at the beginning of your script or application to + configure absl-py logging. + + Example: + To use this function, simply call it at the start of your script: + ``` + setup_absl_logging() + ``` + + """ logging.get_absl_handler().python_handler.stream = sys.stdout formatter = py_logging.Formatter( fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s" diff --git a/ml_logging/test_torch_logging.py b/ml_logging/test_torch_logging.py index e3a2e90..36c141d 100644 --- a/ml_logging/test_torch_logging.py +++ b/ml_logging/test_torch_logging.py @@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging class Testtlogging(unittest.TestCase): def test_warn_once(self): + """ + Test that warning messages are logged only once when using the assertLogs context manager. + + This unit test checks the behavior of the logging system when warning messages are issued + multiple times within the same context. It uses the assertLogs context manager to capture + log messages at the INFO level and verifies that warning messages are logged only once. + + Example: + To use this test case, call it using a test runner like unittest: + ``` + python -m unittest your_test_module.TestLogging.test_warn_once + ``` + + """ + with self.assertLogs(level="INFO") as captured_logs: logging.info("first info") logging.warning("first warning") diff --git a/ml_logging/torch_logging.py b/ml_logging/torch_logging.py index e791c46..2f60114 100644 --- a/ml_logging/torch_logging.py +++ b/ml_logging/torch_logging.py @@ -18,7 +18,35 @@ import torch.distributed as dist def rank_specific(logger): - """Ensures that we only override a given logger once.""" + """ + Customize logger behavior based on the distributed environment and rank. + + This function allows for customizing the behavior of a logger based on the distributed environment and the rank + of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages + depending on the rank or limit the number of redundant logs. + + Args: + logger: The logger object to customize. + + Returns: + The customized logger. + + Example: + To use this function with the `logging` module: + ```python + import logging + from rank_specific_logging import rank_specific + + logger = logging.getLogger(__name__) + rank_specific(logger) + ``` + + Customization: + - Messages are only logged if the distributed environment is not initialized or if the rank matches. + - The 'warning' method is limited to logging a single redundant warning. + - Logging from rank -1 is redirected to include the rank information. + + """ if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): return logger diff --git a/model.py b/model.py index 9df13bc..b8941dd 100644 --- a/model.py +++ b/model.py @@ -54,13 +54,21 @@ def maybe_shard_model( model, device: torch.device, ): - """Set up and apply DistributedModelParallel to a model if running in a distributed environment. + """ + Set up and apply DistributedModelParallel to a model if running in a distributed environment. If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies DistributedModelParallel. - If not in a distributed environment, returns model directly. - """ + If not in a distributed environment, returns the model directly. + + Args: + model: The PyTorch model. + device: The target device (e.g., 'cuda'). + + Returns: + The model wrapped with DistributedModelParallel if in a distributed environment, else the original model. + """ if dist.is_initialized(): logging.info("***** Wrapping in DistributedModelParallel *****") logging.info(f"Model before wrapping: {model}") @@ -74,14 +82,15 @@ def maybe_shard_model( def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None: - """Handy function to log the content of EBC embedding layer. - Only works for single GPU machines. - - Args: - weight_name: name of tensor, as defined in model - table_name: name of the EBC table the weight is taken from - weight_tensor: embedding weight tensor """ + Handy function to log the content of an EBC (Embedding Bag Concatenation) embedding layer. + Only works for single GPU machines. + + Args: + weight_name: Name of the tensor, as defined in the model. + table_name: Name of the EBC table the weight is taken from. + weight_tensor: Embedding weight tensor. + """ logging.info(f"{weight_name}, {table_name}", rank=-1) logging.info(f"{weight_tensor.metadata()}", rank=-1) output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0")) diff --git a/optimizers/config.py b/optimizers/config.py index f5011f0..dc5c94a 100644 --- a/optimizers/config.py +++ b/optimizers/config.py @@ -8,11 +8,60 @@ import pydantic class PiecewiseConstant(base_config.BaseConfig): + """ + Configuration for a piecewise constant learning rate schedule. + + This configuration class allows you to specify a piecewise constant learning rate schedule + by defining boundaries and corresponding learning rate values. + + Attributes: + learning_rate_boundaries (List[int], optional): List of step boundaries at which + the learning rate will change. If None, no boundaries are defined. + learning_rate_values (List[float], optional): List of learning rate values + corresponding to the boundaries. If None, no values are defined. + + Example: + To configure a piecewise constant learning rate schedule, create an instance of this class + and set the attributes accordingly. For example: + + ```python + piecewise_lr = PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ``` + + Note: + The number of learning rate values should be one more than the number of boundaries. + + """ learning_rate_boundaries: typing.List[int] = pydantic.Field(None) learning_rate_values: typing.List[float] = pydantic.Field(None) class LinearRampToConstant(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to constant learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero to a constant value over a specified number of steps. + + Attributes: + learning_rate (float): The final constant learning rate. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + + Example: + To configure a linear ramp-up to a constant learning rate, create an instance of this class + and set the attributes accordingly. For example: + + ```python + linear_ramp_lr = LinearRampToConstant( + learning_rate=0.1, + num_ramp_steps=1000 + ) + ``` + + """ learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( description="Number of steps to ramp this up from zero." @@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig): class LinearRampToCosine(base_config.BaseConfig): + """ + Configuration for a linear ramp-up to cosine decay learning rate schedule. + + This configuration class allows you to specify a learning rate schedule that ramps up linearly + from zero, then decays following a cosine schedule to a final constant learning rate. + + Attributes: + learning_rate (float): The initial learning rate at the start of ramp-up. + final_learning_rate (float): The final constant learning rate after decay. + num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero. + final_num_steps (PositiveInt): Final number of steps where decay stops. + + Example: + To configure a linear ramp-up to cosine decay learning rate, create an instance of this + class and set the attributes accordingly. For example: + + ```python + ramp_to_cosine_lr = LinearRampToCosine( + learning_rate=0.01, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ``` + + """ learning_rate: float final_learning_rate: float num_ramp_steps: pydantic.PositiveInt = pydantic.Field( @@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig): class LearningRate(base_config.BaseConfig): + """ + Learning rate configuration for training. + + This configuration class allows you to specify different learning rate schedules + for your training process. + + Attributes: + constant (float, optional): Constant learning rate to be used throughout training. + linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly + and then decays following a cosine schedule. + linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up + linearly and then remains constant. + piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified + boundaries with corresponding values. + + Example: + To configure a learning rate schedule, create an instance of this class and set the + attributes accordingly. For example: + + ```python + learning_rate = LearningRate( + constant=0.01, + linear_ramp_to_cosine=LinearRampToCosine( + learning_rate=0.1, + final_learning_rate=0.001, + num_ramp_steps=1000, + final_num_steps=5000 + ) + ) + ``` + + Note: + Each learning rate schedule attribute can be set to `None` if not needed. + + """ constant: float = pydantic.Field(None, one_of="lr") linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr") linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr") @@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig): class OptimizerAlgorithmConfig(base_config.BaseConfig): - """Base class for optimizer configurations.""" + """ + Base class for optimizer configurations. + + This base configuration class provides a structure for specifying various optimizer-related + settings, including the learning rate and different learning rate schedules. + + Attributes: + lr (float): The base learning rate used by the optimizer. + + Subclasses should inherit from this base class and define additional attributes specific to + the optimizer algorithm they represent. + + Example: + To create a custom optimizer configuration, create a subclass of this base class and + define the necessary attributes. For example: + + ```python + class MyOptimizerConfig(OptimizerAlgorithmConfig): + momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.") + ``` + + Note: + This base class does not include specific optimizer settings. Subclasses should define + the optimizer-specific attributes as needed. + + """ lr: float ... class AdamConfig(OptimizerAlgorithmConfig): - # see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam + """ + Configuration for the Adam optimizer. + + This configuration class allows you to specify the hyperparameters for the Adam optimizer. + + Attributes: + lr (float): The learning rate for optimization. + betas (Tuple[float, float], optional): Coefficients used for computing running averages + of gradient and squared gradient. Defaults to (0.9, 0.999). + eps (float, optional): A small constant added to the denominator for numerical stability. + Defaults to 1e-7. + + Example: + To configure the Adam optimizer, create an instance of this class and set the attributes + accordingly. For example: + + ```python + adam_optimizer = AdamConfig( + lr=0.001, + betas=(0.9, 0.999), + eps=1e-8 + ) + ``` + + See Also: + [PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam) + + """ lr: float betas: typing.Tuple[float, float] = [0.9, 0.999] eps: float = 1e-7 # Numerical stability in denominator. class SgdConfig(OptimizerAlgorithmConfig): + """ + Configuration for the Stochastic Gradient Descent (SGD) optimizer. + + This configuration class allows you to specify the hyperparameters for the SGD optimizer. + + Attributes: + lr (float): The learning rate for optimization. + momentum (float, optional): The momentum factor for SGD. Defaults to 0.0. + + Example: + To configure the SGD optimizer, create an instance of this class and set the attributes + accordingly. For example: + + ```python + sgd_optimizer = SgdConfig( + lr=0.01, + momentum=0.9 + ) + ``` + + """ lr: float momentum: float = 0.0 class AdagradConfig(OptimizerAlgorithmConfig): + """ + Configuration for the optimizer used during training. + + This configuration class allows you to specify the optimizer for training, including + options for various optimizer algorithms. + + Attributes: + learning_rate (LearningRate, optional): Learning rate configuration. Defaults to None. + adam (AdamConfig, optional): Configuration for the Adam optimizer. Defaults to None. + sgd (SgdConfig, optional): Configuration for the Stochastic Gradient Descent (SGD) optimizer. + Defaults to None. + adagrad (AdagradConfig, optional): Configuration for the Adagrad optimizer. Defaults to None. + + Example: + To configure the optimizer for training, create an instance of this class and set the + attributes accordingly. For example: + + ```python + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.001), + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8) + ) + ``` + + """ lr: float eps: float = 0 class OptimizerConfig(base_config.BaseConfig): + """ + Configuration for defining different optimizer algorithms and their parameters. + + This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad, + along with their respective hyperparameters. + + Args: + learning_rate (LearningRate): The learning rate configuration, which can include + constant learning rates or other learning rate schedules. + adam (AdamConfig): Configuration for the Adam optimizer. + sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer. + adagrad (AdagradConfig): Configuration for the Adagrad optimizer. + + Example: + ```python + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.001), + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8), + ) + ``` + + Attributes: + learning_rate (LearningRate): The learning rate configuration. + adam (AdamConfig): Configuration for the Adam optimizer. + sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer. + adagrad (AdagradConfig): Configuration for the Adagrad optimizer. + + Note: + You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an + `OptimizerConfig` instance. + + See Also: + - `LearningRate`: Configuration for specifying learning rates. + - `AdamConfig`: Configuration for the Adam optimizer. + - `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer. + - `AdagradConfig`: Configuration for the Adagrad optimizer. + + """ learning_rate: LearningRate = pydantic.Field( None, description="Constant learning rates", @@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig): def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): + """ + Get the optimizer algorithm configuration from the given `OptimizerConfig`. + + This function extracts and returns the specific optimizer algorithm configuration + (e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`. + + Args: + optimizer_config (OptimizerConfig): The optimizer configuration object containing + one of the optimizer algorithm configurations. + + Returns: + Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm + configuration extracted from `optimizer_config`. + + Raises: + ValueError: If no optimizer algorithm is selected in `optimizer_config`. + + Example: + ```python + optimizer_config = OptimizerConfig( + adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8) + ) + algorithm_config = get_optimizer_algorithm_config(optimizer_config) + # `algorithm_config` will be an instance of `AdamConfig`. + ``` + + """ if optimizer_config.adam is not None: return optimizer_config.adam elif optimizer_config.sgd is not None: diff --git a/optimizers/optimizer.py b/optimizers/optimizer.py index 4517368..e930f73 100644 --- a/optimizers/optimizer.py +++ b/optimizers/optimizer.py @@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging def compute_lr(lr_config, step): - """Compute a learning rate.""" + """ + Compute the learning rate based on the specified learning rate configuration. + + This function calculates the learning rate according to the given configuration, which can include + constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing. + + Args: + lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule. + step (int): The current training step or iteration. + + Returns: + float: The computed learning rate for the current step. + + Raises: + ValueError: If the `lr_config` is invalid or contains conflicting options. + + Example: + ```python + lr_schedule = LearningRate( + constant=0.001, + piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000, 3000], + learning_rate_values=[0.1, 0.05, 0.01, 0.001] + ) + ) + current_step = 2500 + learning_rate = compute_lr(lr_schedule, current_step) + ``` + """ if lr_config.constant is not None: return lr_config.constant elif lr_config.piecewise_constant is not None: @@ -46,11 +74,54 @@ def compute_lr(lr_config, step): class LRShim(_LRScheduler): - """Shim to get learning rates into a LRScheduler. - - This adheres to the torch.optim scheduler API and can be plugged anywhere that - e.g. exponential decay can be used. """ + Learning Rate Scheduler Shim to adjust learning rates during training. + + This class acts as a shim to apply different learning rates to individual parameter groups + within an optimizer. It adheres to the torch.optim scheduler API and can be used with various + optimizers, allowing fine-grained control over learning rates based on configuration. + + Args: + optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted. + lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their + corresponding learning rate configurations. + last_epoch (int, optional): The index of the last epoch. Default is -1. + verbose (bool, optional): If True, prints a warning message when accessing learning rates + using the deprecated `get_lr()` method. Default is False. + + Raises: + ValueError: If the number of parameter groups in the optimizer does not match the number + of learning rate configurations provided. + + Note: + To obtain the last computed learning rates, please use `get_last_lr()`. + + Example: + ```python + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + lr_schedule = { + 'main': LearningRate(constant=0.01), + 'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant( + learning_rate_boundaries=[1000, 2000], + learning_rate_values=[0.01, 0.001] + )) + } + lr_shim = LRShim(optimizer, lr_schedule) + + for epoch in range(num_epochs): + # Train the model + train(...) + # Update learning rates at the end of each epoch + lr_shim.step(epoch) + + final_lr_main = lr_shim.get_last_lr()['main'] + final_lr_auxiliary = lr_shim.get_last_lr()['auxiliary'] + ``` + + See Also: + - `LearningRate`: Configuration for specifying learning rates. + - `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules. + """ def __init__( self, @@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig): def build_optimizer( model: torch.nn.Module, optimizer_config: OptimizerConfig ) -> Tuple[Optimizer, _LRScheduler]: - """Builds an optimizer and LR scheduler from an OptimizerConfig. - Note: use this when you want the same optimizer and learning rate schedule for all your parameters. """ + Build an optimizer and learning rate scheduler based on the provided optimizer configuration. + + Args: + model (torch.nn.Module): The PyTorch model for which the optimizer will be created. + optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer + algorithm and learning rate settings. + + Returns: + Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler + objects. + + Note: + This function is intended for cases where you want the same optimizer and learning rate + schedule for all model parameters. + + Example: + ```python + model = MyModel() + optimizer_config = OptimizerConfig( + learning_rate=LearningRate(constant=0.01), + sgd=SgdConfig(lr=0.01, momentum=0.9) + ) + optimizer, scheduler = build_optimizer(model, optimizer_config) + + for epoch in range(num_epochs): + # Train the model with the optimizer + train(model, optimizer, ...) + # Update learning rates at the end of each epoch + scheduler.step(epoch) + ``` + + See Also: + - `OptimizerConfig`: Configuration for specifying optimizer settings. + - `LRShim`: Learning rate scheduler shim for fine-grained learning rate control. + """ optimizer_class = get_optimizer_class(optimizer_config) optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict()) # We're passing everything in as one group here diff --git a/projects/home/recap/config.py b/projects/home/recap/config.py index 261d7e5..563bf1a 100644 --- a/projects/home/recap/config.py +++ b/projects/home/recap/config.py @@ -9,6 +9,36 @@ import pydantic class TrainingConfig(config_mod.BaseConfig): + """ + Configuration settings for the training process. + + This class defines various training-related settings, including the directory to save checkpoints, the number + of training steps, logging intervals, and other training parameters. + + Attributes: + save_dir (str): The directory where checkpoints and training artifacts will be saved. + num_train_steps (pydantic.PositiveInt): The total number of training steps to run. + initial_checkpoint_dir (str): The directory containing initial checkpoints (optional). + checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training. + checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional). + train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress. + num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset. + eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress. + eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation. + gradient_accumulation (int): Number of replica steps to accumulate gradients (optional). + + Example: + To configure training with checkpoints saved every 1000 steps, use the following settings: + + ```python + TrainingConfig( + save_dir="/tmp/model", + num_train_steps=1000000, + checkpoint_every_n=1000, + train_log_every_n=1000, + ) + ``` + """ save_dir: str = "/tmp/model" num_train_steps: pydantic.PositiveInt = 1000000 initial_checkpoint_dir: str = pydantic.Field( @@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig): class RecapConfig(config_mod.BaseConfig): + """ + Configuration settings for the Recap model training process. + + This class defines the overall configuration for the training process of a Recap model. It includes settings for + training, model architecture, data, optimization, and evaluation. + + Attributes: + training (TrainingConfig): Configuration settings for the training process. + model (model_config.ModelConfig): Configuration settings for the Recap model architecture. + train_data (data_config.RecapDataConfig): Configuration settings for training data. + validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data. + optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization. + which_metrics (Optional[str]): Optional specification of which metrics to pick. + + Note: + This class encapsulates all the necessary configurations to train a Recap model. It defines settings for + training, the model architecture, data loading, optimization, and evaluation. + + Example: + To configure a Recap model training process, use the following settings: + + ```python + RecapConfig( + training=TrainingConfig( + save_dir="/tmp/model", + num_train_steps=1000000, + checkpoint_every_n=1000, + train_log_every_n=1000, + ), + model=model_config.ModelConfig(...), + train_data=data_config.RecapDataConfig(...), + validation_data={"dev": data_config.RecapDataConfig(...)}, + optimizer=optimizer_config.RecapOptimizerConfig(...), + ) + ``` + """ training: TrainingConfig = pydantic.Field(TrainingConfig()) model: model_config.ModelConfig train_data: data_config.RecapDataConfig diff --git a/projects/home/recap/data/config.py b/projects/home/recap/data/config.py index 27ef3ed..c5ee4c0 100644 --- a/projects/home/recap/data/config.py +++ b/projects/home/recap/data/config.py @@ -8,239 +8,250 @@ import pydantic class ExplicitDateInputs(base_config.BaseConfig): - """Arguments to select train/validation data using end_date and days of data.""" + """Arguments to select train/validation data using end_date and days of data.""" - data_root: str = pydantic.Field(..., description="Data path prefix.") - end_date: str = pydantic.Field(..., description="Data end date, inclusive.") - days: int = pydantic.Field(..., description="Number of days of data for dataset.") - num_missing_days_tol: int = pydantic.Field( - 0, description="We tolerate <= num_missing_days_tol days of missing data." - ) + data_root: str = pydantic.Field(..., description="Data path prefix.") + end_date: str = pydantic.Field(..., + description="Data end date, inclusive.") + days: int = pydantic.Field(..., + description="Number of days of data for dataset.") + num_missing_days_tol: int = pydantic.Field( + 0, description="We tolerate <= num_missing_days_tol days of missing data." + ) class ExplicitDatetimeInputs(base_config.BaseConfig): - """Arguments to select train/validation data using end_datetime and hours of data.""" + """Arguments to select train/validation data using end_datetime and hours of data.""" - data_root: str = pydantic.Field(..., description="Data path prefix.") - end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.") - hours: int = pydantic.Field(..., description="Number of hours of data for dataset.") - num_missing_hours_tol: int = pydantic.Field( - 0, description="We tolerate <= num_missing_hours_tol hours of missing data." - ) + data_root: str = pydantic.Field(..., description="Data path prefix.") + end_datetime: str = pydantic.Field(..., + description="Data end datetime, inclusive.") + hours: int = pydantic.Field(..., + description="Number of hours of data for dataset.") + num_missing_hours_tol: int = pydantic.Field( + 0, description="We tolerate <= num_missing_hours_tol hours of missing data." + ) class DdsCompressionOption(str, Enum): - """The only valid compression option is 'AUTO'""" + """The only valid compression option is 'AUTO'""" - AUTO = "AUTO" + AUTO = "AUTO" class DatasetConfig(base_config.BaseConfig): - inputs: str = pydantic.Field( - None, description="A glob for selecting data.", one_of="date_inputs_format" - ) - explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( - None, one_of="date_inputs_format" - ) - explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format") + inputs: str = pydantic.Field( + None, description="A glob for selecting data.", one_of="date_inputs_format" + ) + explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field( + None, one_of="date_inputs_format" + ) + explicit_date_inputs: ExplicitDateInputs = pydantic.Field( + None, one_of="date_inputs_format") - global_batch_size: pydantic.PositiveInt + global_batch_size: pydantic.PositiveInt - num_files_to_keep: pydantic.PositiveInt = pydantic.Field( - None, description="Number of shards to keep." - ) - repeat_files: bool = pydantic.Field( - True, description="DEPRICATED. Files are repeated no matter what this is set to." - ) - file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size") + num_files_to_keep: pydantic.PositiveInt = pydantic.Field( + None, description="Number of shards to keep." + ) + repeat_files: bool = pydantic.Field( + True, description="DEPRICATED. Files are repeated no matter what this is set to." + ) + file_batch_size: pydantic.PositiveInt = pydantic.Field( + 16, description="File batch size") - cache: bool = pydantic.Field( - False, - description="Cache dataset in memory. Careful to only use this when you" - " have enough memory to fit entire dataset.", - ) + cache: bool = pydantic.Field( + False, + description="Cache dataset in memory. Careful to only use this when you" + " have enough memory to fit entire dataset.", + ) - data_service_dispatcher: str = pydantic.Field(None) - ignore_data_errors: bool = pydantic.Field( - False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs." - ) - dataset_service_compression: DdsCompressionOption = pydantic.Field( - None, - description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'", - ) + data_service_dispatcher: str = pydantic.Field(None) + ignore_data_errors: bool = pydantic.Field( + False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs." + ) + dataset_service_compression: DdsCompressionOption = pydantic.Field( + None, + description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'", + ) - # tf.data.Dataset options - examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.") - map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( - None, description="Number of parallel calls." - ) - interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( - None, description="Number of shards to interleave." - ) + # tf.data.Dataset options + examples_shuffle_buffer_size: int = pydantic.Field( + 1024, description="Size of shuffle buffers.") + map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( + None, description="Number of parallel calls." + ) + interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field( + None, description="Number of shards to interleave." + ) class TruncateAndSlice(base_config.BaseConfig): - # Apply truncation and then slice. - continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field( - None, description="Experimental. Truncates continuous features to this amount for efficiency." - ) - binary_feature_truncation: pydantic.PositiveInt = pydantic.Field( - None, description="Experimental. Truncates binary features to this amount for efficiency." - ) + # Apply truncation and then slice. + continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field( + None, description="Experimental. Truncates continuous features to this amount for efficiency." + ) + binary_feature_truncation: pydantic.PositiveInt = pydantic.Field( + None, description="Experimental. Truncates binary features to this amount for efficiency." + ) - continuous_feature_mask_path: str = pydantic.Field( - None, description="Path of mask used to slice input continuous features." - ) - binary_feature_mask_path: str = pydantic.Field( - None, description="Path of mask used to slice input binary features." - ) + continuous_feature_mask_path: str = pydantic.Field( + None, description="Path of mask used to slice input continuous features." + ) + binary_feature_mask_path: str = pydantic.Field( + None, description="Path of mask used to slice input binary features." + ) class DataType(str, Enum): - BFLOAT16 = "bfloat16" - BOOL = "bool" + BFLOAT16 = "bfloat16" + BOOL = "bool" - FLOAT32 = "float32" - FLOAT16 = "float16" + FLOAT32 = "float32" + FLOAT16 = "float16" - UINT8 = "uint8" + UINT8 = "uint8" class DownCast(base_config.BaseConfig): - # Apply down casting to selected features. - features: typing.Dict[str, DataType] = pydantic.Field( - None, description="Map features to down cast data types." - ) + # Apply down casting to selected features. + features: typing.Dict[str, DataType] = pydantic.Field( + None, description="Map features to down cast data types." + ) class TaskData(base_config.BaseConfig): - pos_downsampling_rate: float = pydantic.Field( - 1.0, - description="Downsampling rate of positives used to generate dataset.", - ) - neg_downsampling_rate: float = pydantic.Field( - 1.0, - description="Downsampling rate of negatives used to generate dataset.", - ) + pos_downsampling_rate: float = pydantic.Field( + 1.0, + description="Downsampling rate of positives used to generate dataset.", + ) + neg_downsampling_rate: float = pydantic.Field( + 1.0, + description="Downsampling rate of negatives used to generate dataset.", + ) class SegDenseSchema(base_config.BaseConfig): - schema_path: str = pydantic.Field(..., description="Path to feature config json.") - features: typing.List[str] = pydantic.Field( - [], - description="List of features (in addition to the renamed features) to read from schema path above.", - ) - renamed_features: typing.Dict[str, str] = pydantic.Field( - {}, description="Dictionary of renamed features." - ) - mask_mantissa_features: typing.Dict[str, int] = pydantic.Field( - {}, - description="(experimental) Number of mantissa bits to mask to simulate lower precision data.", - ) + schema_path: str = pydantic.Field(..., + description="Path to feature config json.") + features: typing.List[str] = pydantic.Field( + [], + description="List of features (in addition to the renamed features) to read from schema path above.", + ) + renamed_features: typing.Dict[str, str] = pydantic.Field( + {}, description="Dictionary of renamed features." + ) + mask_mantissa_features: typing.Dict[str, int] = pydantic.Field( + {}, + description="(experimental) Number of mantissa bits to mask to simulate lower precision data.", + ) class RectifyLabels(base_config.BaseConfig): - label_rectification_window_in_hours: float = pydantic.Field( - 3.0, description="overlap time in hours for which to flip labels" - ) - served_timestamp_field: str = pydantic.Field( - ..., description="input field corresponding to served time" - ) - impressed_timestamp_field: str = pydantic.Field( - ..., description="input field corresponding to impressed time" - ) - label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field( - ..., description="label to the input field corresponding to engagement time" - ) + label_rectification_window_in_hours: float = pydantic.Field( + 3.0, description="overlap time in hours for which to flip labels" + ) + served_timestamp_field: str = pydantic.Field( + ..., description="input field corresponding to served time" + ) + impressed_timestamp_field: str = pydantic.Field( + ..., description="input field corresponding to impressed time" + ) + label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field( + ..., description="label to the input field corresponding to engagement time" + ) class ExtractFeaturesRow(base_config.BaseConfig): - name: str = pydantic.Field( - ..., - description="name of the new field name to be created", - ) - source_tensor: str = pydantic.Field( - ..., - description="name of the dense tensor to look for the feature", - ) - index: int = pydantic.Field( - ..., - description="index of the feature in the dense tensor", - ) + name: str = pydantic.Field( + ..., + description="name of the new field name to be created", + ) + source_tensor: str = pydantic.Field( + ..., + description="name of the dense tensor to look for the feature", + ) + index: int = pydantic.Field( + ..., + description="index of the feature in the dense tensor", + ) class ExtractFeatures(base_config.BaseConfig): - extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field( - [], - description="list of features to be extracted with their name, source tensor and index", - ) + extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field( + [], + description="list of features to be extracted with their name, source tensor and index", + ) class DownsampleNegatives(base_config.BaseConfig): - batch_multiplier: int = pydantic.Field( - None, - description="batch multiplier", - ) - engagements_list: typing.List[str] = pydantic.Field( - [], - description="engagements with kept positives", - ) - num_engagements: int = pydantic.Field( - ..., - description="number engagements used in the model, including ones excluded in engagements_list", - ) + batch_multiplier: int = pydantic.Field( + None, + description="batch multiplier", + ) + engagements_list: typing.List[str] = pydantic.Field( + [], + description="engagements with kept positives", + ) + num_engagements: int = pydantic.Field( + ..., + description="number engagements used in the model, including ones excluded in engagements_list", + ) class Preprocess(base_config.BaseConfig): - truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.") - downcast: DownCast = pydantic.Field(None, description="Down cast to features.") - rectify_labels: RectifyLabels = pydantic.Field( - None, description="Rectify labels for a given overlap window" - ) - extract_features: ExtractFeatures = pydantic.Field( - None, description="Extract features from dense tensors." - ) - downsample_negatives: DownsampleNegatives = pydantic.Field( - None, description="Downsample negatives." - ) + truncate_and_slice: TruncateAndSlice = pydantic.Field( + None, description="Truncation and slicing.") + downcast: DownCast = pydantic.Field( + None, description="Down cast to features.") + rectify_labels: RectifyLabels = pydantic.Field( + None, description="Rectify labels for a given overlap window" + ) + extract_features: ExtractFeatures = pydantic.Field( + None, description="Extract features from dense tensors." + ) + downsample_negatives: DownsampleNegatives = pydantic.Field( + None, description="Downsample negatives." + ) class Sampler(base_config.BaseConfig): - """Assumes function is defined in data/samplers.py. + """Assumes function is defined in data/samplers.py. - Only use this for quick experimentation. - If samplers are useful, we should sample from upstream data generation. + Only use this for quick experimentation. + If samplers are useful, we should sample from upstream data generation. - DEPRICATED, DO NOT USE. - """ + DEPRICATED, DO NOT USE. + """ - name: str - kwargs: typing.Dict + name: str + kwargs: typing.Dict class RecapDataConfig(DatasetConfig): - seg_dense_schema: SegDenseSchema + seg_dense_schema: SegDenseSchema - tasks: typing.Dict[str, TaskData] = pydantic.Field( - description="Description of individual tasks in this dataset." - ) - evaluation_tasks: typing.List[str] = pydantic.Field( - [], description="If specified, lists the tasks we're generating metrics for." - ) + tasks: typing.Dict[str, TaskData] = pydantic.Field( + description="Description of individual tasks in this dataset." + ) + evaluation_tasks: typing.List[str] = pydantic.Field( + [], description="If specified, lists the tasks we're generating metrics for." + ) - preprocess: Preprocess = pydantic.Field( - None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference." - ) + preprocess: Preprocess = pydantic.Field( + None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference." + ) - sampler: Sampler = pydantic.Field( - None, - description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""", - ) + sampler: Sampler = pydantic.Field( + None, + description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""", + ) - @pydantic.root_validator() - def _validate_evaluation_tasks(cls, values): - if values.get("evaluation_tasks") is not None: - for task in values["evaluation_tasks"]: - if task not in values["tasks"]: - raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}") - return values + @pydantic.root_validator() + def _validate_evaluation_tasks(cls, values): + if values.get("evaluation_tasks") is not None: + for task in values["evaluation_tasks"]: + if task not in values["tasks"]: + raise KeyError( + f"Evaluation task {task} must be in tasks. Received {values['tasks']}") + return values diff --git a/projects/home/recap/data/dataset.py b/projects/home/recap/data/dataset.py index 3478c68..4d13ac5 100644 --- a/projects/home/recap/data/dataset.py +++ b/projects/home/recap/data/dataset.py @@ -11,8 +11,8 @@ from tml.projects.home.recap.data import preprocessors from tml.projects.home.recap.config import JobMode from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn from tml.projects.home.recap.data.util import ( - keyed_jagged_tensor_from_tensors_dict, - sparse_or_dense_tf_to_torch, + keyed_jagged_tensor_from_tensors_dict, + sparse_or_dense_tf_to_torch, ) from absl import logging import torch.distributed as dist @@ -20,458 +20,546 @@ import torch.distributed as dist @dataclass class RecapBatch(DataclassBatch): - """Holds features and labels from the Recap dataset.""" + """Holds features and labels from the Recap dataset.""" - continuous_features: torch.Tensor - binary_features: torch.Tensor - discrete_features: torch.Tensor - sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821 - labels: torch.Tensor - user_embedding: torch.Tensor = None - user_eng_embedding: torch.Tensor = None - author_embedding: torch.Tensor = None - weights: torch.Tensor = None + continuous_features: torch.Tensor + binary_features: torch.Tensor + discrete_features: torch.Tensor + sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821 + labels: torch.Tensor + user_embedding: torch.Tensor = None + user_eng_embedding: torch.Tensor = None + author_embedding: torch.Tensor = None + weights: torch.Tensor = None - def __post_init__(self): - if self.weights is None: - self.weights = torch.ones_like(self.labels) - for feature_name, feature_value in self.as_dict().items(): - if ("embedding" in feature_name) and (feature_value is None): - setattr(self, feature_name, torch.empty([0, 0])) + def __post_init__(self): + if self.weights is None: + self.weights = torch.ones_like(self.labels) + for feature_name, feature_value in self.as_dict().items(): + if ("embedding" in feature_name) and (feature_value is None): + setattr(self, feature_name, torch.empty([0, 0])) def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch: - """Converts a torch data loader output into `RecapBatch`.""" + """Converts a torch data loader output into `RecapBatch`.""" - x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x) - try: - features_in, labels = x - except ValueError: - # For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple - features_in, labels = x, None + x = tf.nest.map_structure(functools.partial( + sparse_or_dense_tf_to_torch, pin_memory=False), x) + try: + features_in, labels = x + except ValueError: + # For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple + features_in, labels = x, None - sparse_features = keyed_jagged_tensor_from_tensors_dict({}) - if sparse_feature_names: - sparse_features = keyed_jagged_tensor_from_tensors_dict( - {embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names} + sparse_features = keyed_jagged_tensor_from_tensors_dict({}) + if sparse_feature_names: + sparse_features = keyed_jagged_tensor_from_tensors_dict( + {embedding_name: features_in[embedding_name] + for embedding_name in sparse_feature_names} + ) + + user_embedding, user_eng_embedding, author_embedding = None, None, None + if "user_embedding" in features_in: + if sparse_feature_names and "meta__user_id" in sparse_feature_names: + raise ValueError( + "Only one source of embedding for user is supported") + else: + user_embedding = features_in["user_embedding"] + + if "user_eng_embedding" in features_in: + if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names: + raise ValueError( + "Only one source of embedding for user is supported") + else: + user_eng_embedding = features_in["user_eng_embedding"] + + if "author_embedding" in features_in: + if sparse_feature_names and "meta__author_id" in sparse_feature_names: + raise ValueError( + "Only one source of embedding for user is supported") + else: + author_embedding = features_in["author_embedding"] + + return RecapBatch( + continuous_features=features_in["continuous"], + binary_features=features_in["binary"], + discrete_features=features_in["discrete"], + sparse_features=sparse_features, + user_embedding=user_embedding, + user_eng_embedding=user_eng_embedding, + author_embedding=author_embedding, + labels=labels, + # Defaults to torch.ones_like(labels) + weights=features_in.get("weights", None), ) - user_embedding, user_eng_embedding, author_embedding = None, None, None - if "user_embedding" in features_in: - if sparse_feature_names and "meta__user_id" in sparse_feature_names: - raise ValueError("Only one source of embedding for user is supported") - else: - user_embedding = features_in["user_embedding"] - - if "user_eng_embedding" in features_in: - if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names: - raise ValueError("Only one source of embedding for user is supported") - else: - user_eng_embedding = features_in["user_eng_embedding"] - - if "author_embedding" in features_in: - if sparse_feature_names and "meta__author_id" in sparse_feature_names: - raise ValueError("Only one source of embedding for user is supported") - else: - author_embedding = features_in["author_embedding"] - - return RecapBatch( - continuous_features=features_in["continuous"], - binary_features=features_in["binary"], - discrete_features=features_in["discrete"], - sparse_features=sparse_features, - user_embedding=user_embedding, - user_eng_embedding=user_eng_embedding, - author_embedding=author_embedding, - labels=labels, - weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels) - ) - def _chain(param, f1, f2): - """ - Reduce multiple functions into one chained function - _chain(x, f1, f2) -> f2(f1(x)) - """ - output = param - fns = [f1, f2] - for f in fns: - output = f(output) - return output + """ + Reduce multiple functions into one chained function + _chain(x, f1, f2) -> f2(f1(x)) + """ + output = param + fns = [f1, f2] + for f in fns: + output = f(output) + return output def _add_weights(inputs, tasks: Dict[str, TaskData]): - """Adds weights based on label sampling for positive and negatives. + """Adds weights based on label sampling for positive and negatives. - This is useful for numeric calibration etc. This mutates inputs. + This is useful for numeric calibration etc. This mutates inputs. - Args: - inputs: A dictionary of strings to tensor-like structures. - tasks: A dict of string (label) to `TaskData` specifying inputs. + Args: + inputs: A dictionary of strings to tensor-like structures. + tasks: A dict of string (label) to `TaskData` specifying inputs. - Returns: - A tuple of features and labels; weights are added to features. - """ + Returns: + A tuple of features and labels; weights are added to features. + """ - weights = [] - for key, task in tasks.items(): - label = inputs[key] - float_label = tf.cast(label, tf.float32) + weights = [] + for key, task in tasks.items(): + label = inputs[key] + float_label = tf.cast(label, tf.float32) - weights.append( - float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate - ) + weights.append( + float_label / task.pos_downsampling_rate + + (1.0 - float_label) / task.neg_downsampling_rate + ) - # Ensure we are batch-major (assumes we batch before this call). - inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0) - return inputs + # Ensure we are batch-major (assumes we batch before this call). + inputs["weights"] = tf.squeeze(tf.transpose( + tf.convert_to_tensor(weights)), axis=0) + return inputs def get_datetimes(explicit_datetime_inputs): - """Compute list datetime strings for train/validation data.""" - datetime_format = "%Y/%m/%d/%H" - end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format) - dates = sorted( - [ - (end - timedelta(hours=i + 1)).strftime(datetime_format) - for i in range(int(explicit_datetime_inputs.hours)) - ] - ) - return dates + """Compute list datetime strings for train/validation data.""" + datetime_format = "%Y/%m/%d/%H" + end = datetime.strptime( + explicit_datetime_inputs.end_datetime, datetime_format) + dates = sorted( + [ + (end - timedelta(hours=i + 1)).strftime(datetime_format) + for i in range(int(explicit_datetime_inputs.hours)) + ] + ) + return dates def get_explicit_datetime_inputs_files(explicit_datetime_inputs): - """ - Compile list of files for training/validation. + """ + Compile list of files for training/validation. - Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data. - For each hour of data, if the directory is missing or empty, we increment a counter to keep - track of the number of missing data hours. - Returns only files with a `.gz` extension. + Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data. + For each hour of data, if the directory is missing or empty, we increment a counter to keep + track of the number of missing data hours. + Returns only files with a `.gz` extension. - Args: - explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object + Args: + explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object - Returns: - data_files: Sorted list of files to read corresponding to data at the desired datetimes - num_hours_missing: Number of hours that we are missing data + Returns: + data_files: Sorted list of files to read corresponding to data at the desired datetimes + num_hours_missing: Number of hours that we are missing data - """ - datetimes = get_datetimes(explicit_datetime_inputs) - folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes] - data_files = [] - num_hours_missing = 0 - for folder in folders: - try: - files = tf.io.gfile.listdir(folder) - if not files: - logging.warning(f"{folder} contained no data files") - num_hours_missing += 1 - data_files.extend( - [ - os.path.join(folder, filename) - for filename in files - if filename.rsplit(".", 1)[-1].lower() == "gz" - ] - ) - except tf.errors.NotFoundError as e: - num_hours_missing += 1 - logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}") - return sorted(data_files), num_hours_missing + """ + datetimes = get_datetimes(explicit_datetime_inputs) + folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) + for datetime in datetimes] + data_files = [] + num_hours_missing = 0 + for folder in folders: + try: + files = tf.io.gfile.listdir(folder) + if not files: + logging.warning(f"{folder} contained no data files") + num_hours_missing += 1 + data_files.extend( + [ + os.path.join(folder, filename) + for filename in files + if filename.rsplit(".", 1)[-1].lower() == "gz" + ] + ) + except tf.errors.NotFoundError as e: + num_hours_missing += 1 + logging.warning( + f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}") + return sorted(data_files), num_hours_missing def _map_output_for_inference( - inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False + inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False ): - if preprocessor: - raise ValueError("No preprocessor should be used at inference time.") - if add_weights: - raise NotImplementedError() + """ + Map the input data for inference. - # Add zero weights. - inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1)) - for label in tasks: - del inputs[label] - return inputs + Args: + inputs (dict): Input data dictionary. + tasks (Dict[str, TaskData]): Dictionary of task data configurations. + preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None. + add_weights (bool, optional): Whether to add weights. Defaults to False. + + Returns: + dict: Modified input data dictionary for inference. + """ + if preprocessor: + raise ValueError("No preprocessor should be used at inference time.") + if add_weights: + raise NotImplementedError() + + # Add zero weights. + inputs["weights"] = tf.zeros_like( + tf.expand_dims(inputs["continuous"][:, 0], -1)) + for label in tasks: + del inputs[label] + return inputs def _map_output_for_train_eval( - inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False + inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False ): - if add_weights: - inputs = _add_weights_based_on_sampling_rates(inputs, tasks) + """ + Map the input data for training and evaluation. - # Warning this has to happen first as it changes the input - if preprocessor: - inputs = preprocessor(inputs) + Args: + inputs (dict): Input data dictionary. + tasks (Dict[str, TaskData]): Dictionary of task data configurations. + preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None. + add_weights (bool, optional): Whether to add weights. Defaults to False. - label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1]) + Returns: + Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation. + """ + if add_weights: + inputs = _add_weights_based_on_sampling_rates(inputs, tasks) - for label in tasks: - del inputs[label] + # Warning this has to happen first as it changes the input + if preprocessor: + inputs = preprocessor(inputs) - return inputs, label_values + label_values = tf.squeeze( + tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1]) + + for label in tasks: + del inputs[label] + + return inputs, label_values def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]): - """Adds weights based on label sampling for positive and negatives. + """Adds weights based on label sampling for positive and negatives. - This is useful for numeric calibration etc. This mutates inputs. + This is useful for numeric calibration etc. This mutates inputs. - Args: - inputs: A dictionary of strings to tensor-like structures. - tasks: A dict of string (label) to `TaskData` specifying inputs. + Args: + inputs: A dictionary of strings to tensor-like structures. + tasks: A dict of string (label) to `TaskData` specifying inputs. - Returns: - A tuple of features and labels; weights are added to features. - """ - weights = [] - for key, task in tasks.items(): - label = inputs[key] - float_label = tf.cast(label, tf.float32) + Returns: + A tuple of features and labels; weights are added to features. + """ + weights = [] + for key, task in tasks.items(): + label = inputs[key] + float_label = tf.cast(label, tf.float32) - weights.append( - float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate - ) + weights.append( + float_label / task.pos_downsampling_rate + + (1.0 - float_label) / task.neg_downsampling_rate + ) - # Ensure we are batch-major (assumes we batch before this call). - inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0) - return inputs + # Ensure we are batch-major (assumes we batch before this call). + inputs["weights"] = tf.squeeze(tf.transpose( + tf.convert_to_tensor(weights)), axis=0) + return inputs class RecapDataset(torch.utils.data.IterableDataset): - def __init__( - self, - data_config: RecapDataConfig, - dataset_service: Optional[str] = None, - mode: JobMode = JobMode.TRAIN, - compression: Optional[str] = "AUTO", - repeat: bool = False, - vocab_mapper: tf.keras.Model = None, - ): - logging.info("***** Labels *****") - logging.info(list(data_config.tasks.keys())) + def __init__( + self, + data_config: RecapDataConfig, + dataset_service: Optional[str] = None, + mode: JobMode = JobMode.TRAIN, + compression: Optional[str] = "AUTO", + repeat: bool = False, + vocab_mapper: tf.keras.Model = None, + ): + """ + Create a RecapDataset for training or inference. - self._data_config = data_config - self._parse_fn = get_seg_dense_parse_fn(data_config) - self._mode = mode - self._repeat = repeat - self._num_concurrent_iterators = 1 - self._vocab_mapper = vocab_mapper - self.dataset_service = dataset_service + Args: + data_config (RecapDataConfig): Data configuration. + dataset_service (str, optional): Dataset service identifier. Defaults to None. + mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN. + compression (str, optional): Compression type. Defaults to "AUTO". + repeat (bool, optional): Whether to repeat the dataset. Defaults to False. + vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None. + """ + logging.info("***** Labels *****") + logging.info(list(data_config.tasks.keys())) - preprocessor = None - self._batch_size_multiplier = 1 - if data_config.preprocess: - preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode) - if data_config.preprocess.downsample_negatives: - self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier + self._data_config = data_config + self._parse_fn = get_seg_dense_parse_fn(data_config) + self._mode = mode + self._repeat = repeat + self._num_concurrent_iterators = 1 + self._vocab_mapper = vocab_mapper + self.dataset_service = dataset_service - self._preprocessor = preprocessor + preprocessor = None + self._batch_size_multiplier = 1 + if data_config.preprocess: + preprocessor = preprocessors.build_preprocess( + data_config.preprocess, mode=mode) + if data_config.preprocess.downsample_negatives: + self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier - if mode == JobMode.INFERENCE: - if preprocessor is not None: - raise ValueError("Expect no preprocessor at inference time.") - should_add_weights = False - output_map_fn = _map_output_for_inference # (features,) - else: - # Only add weights if there is a reason to! If all weights will - # be equal to 1.0, save bandwidth between DDS and Chief by simply - # relying on the fact that weights default to 1.0 in `RecapBatch` - # WARNING: Weights may still be added as a side effect of a preprocessor - # such as `DownsampleNegatives`. - should_add_weights = any( - [ - task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0 - for task_cfg in data_config.tasks.values() - ] - ) - output_map_fn = _map_output_for_train_eval # (features, labels) + self._preprocessor = preprocessor - self._output_map_fn = functools.partial( - output_map_fn, - tasks=data_config.tasks, - preprocessor=preprocessor, - add_weights=should_add_weights, - ) + if mode == JobMode.INFERENCE: + if preprocessor is not None: + raise ValueError("Expect no preprocessor at inference time.") + should_add_weights = False + output_map_fn = _map_output_for_inference # (features,) + else: + # Only add weights if there is a reason to! If all weights will + # be equal to 1.0, save bandwidth between DDS and Chief by simply + # relying on the fact that weights default to 1.0 in `RecapBatch` + # WARNING: Weights may still be added as a side effect of a preprocessor + # such as `DownsampleNegatives`. + should_add_weights = any( + [ + task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0 + for task_cfg in data_config.tasks.values() + ] + ) + output_map_fn = _map_output_for_train_eval # (features, labels) - sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None - - self._tf_dataset = self._create_tf_dataset() - - self._init_tensor_spec() - - def _init_tensor_spec(self): - def _tensor_spec_to_torch_shape(spec): - if spec.shape is None: - return None - shape = [x if x is not None else -1 for x in spec.shape] - return torch.Size(shape) - - self.torch_element_spec = tf.nest.map_structure( - _tensor_spec_to_torch_shape, self._tf_dataset.element_spec - ) - - def _create_tf_dataset(self): - if hasattr(self, "_tf_dataset"): - raise ValueError("Do not call `_create_tf_dataset` more than once.") - - world_size = dist.get_world_size() if dist.is_initialized() else 1 - per_replica_bsz = ( - self._batch_size_multiplier * self._data_config.global_batch_size // world_size - ) - - dataset: tf.data.Dataset = self._create_base_tf_dataset( - batch_size=per_replica_bsz, - ) - - if self._repeat: - logging.info("Repeating dataset") - dataset = dataset.repeat() - - if self.dataset_service: - if self._num_concurrent_iterators > 1: - if not self.machines_config: - raise ValueError( - "Must supply a machine_config for autotuning in order to use >1 concurrent iterators" - ) - dataset = dataset_lib.with_auto_tune_budget( - dataset, - machine_config=self.machines_config.chief, - num_concurrent_iterators=self.num_concurrent_iterators, - on_chief=False, + self._output_map_fn = functools.partial( + output_map_fn, + tasks=data_config.tasks, + preprocessor=preprocessor, + add_weights=should_add_weights, ) - self.dataset_id, self.job_name = register_dataset( - dataset=dataset, dataset_service=self.dataset_service, compression=self.compression - ) - dataset = distribute_from_dataset_id( - dataset_id=self.dataset_id, # type: ignore[arg-type] - job_name=self.job_name, - dataset_service=self.dataset_service, - compression=self.compression, - ) + sparse_feature_names = list( + vocab_mapper.vocabs.keys()) if vocab_mapper else None - elif self._num_concurrent_iterators > 1: - if not self.machines_config: - raise ValueError( - "Must supply a machine_config for autotuning in order to use >1 concurrent iterators" + self._tf_dataset = self._create_tf_dataset() + + self._init_tensor_spec() + + def _init_tensor_spec(self): + """ + Initialize the tensor specification for the dataset. + """ + def _tensor_spec_to_torch_shape(spec): + if spec.shape is None: + return None + shape = [x if x is not None else -1 for x in spec.shape] + return torch.Size(shape) + + self.torch_element_spec = tf.nest.map_structure( + _tensor_spec_to_torch_shape, self._tf_dataset.element_spec ) - dataset = dataset_lib.with_auto_tune_budget( - dataset, - machine_config=self.machines_config.chief, - num_concurrent_iterators=self._num_concurrent_iterators, - on_chief=True, - ) - # Vocabulary mapping happens on the training node, not in dds because of size. - if self._vocab_mapper: - dataset = dataset.map(self._vocab_mapper) + def _create_tf_dataset(self): + """ + Create a TensorFlow dataset from the data files. - return dataset.prefetch(world_size * 2) + Returns: + tf.data.Dataset: TensorFlow dataset. + """ + if hasattr(self, "_tf_dataset"): + raise ValueError( + "Do not call `_create_tf_dataset` more than once.") - def _create_base_tf_dataset(self, batch_size: int): - if self._data_config.inputs: - glob = self._data_config.inputs - filenames = sorted(tf.io.gfile.glob(glob)) - elif self._data_config.explicit_datetime_inputs: - num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol - filenames, num_hours_missing = get_explicit_datetime_inputs_files( - self._data_config.explicit_datetime_inputs, - increment="hourly", - ) - if num_hours_missing > num_missing_hours_tol: - raise ValueError( - f"We are missing {num_hours_missing} hours of data" - f"more than tolerance {num_missing_hours_tol}." + world_size = dist.get_world_size() if dist.is_initialized() else 1 + per_replica_bsz = ( + self._batch_size_multiplier * self._data_config.global_batch_size // world_size ) - elif self._data_config.explicit_date_inputs: - num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol - filenames, num_days_missing = get_explicit_datetime_inputs_files( - self._data_config.explicit_date_inputs, - increment="daily", - ) - if num_days_missing > num_missing_days_tol: - raise ValueError( - f"We are missing {num_days_missing} days of data" - f"more than tolerance {num_missing_days_tol}." + + dataset: tf.data.Dataset = self._create_base_tf_dataset( + batch_size=per_replica_bsz, ) - else: - raise ValueError( - "Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config" - ) - num_files = len(filenames) - logging.info(f"Found {num_files} data files") - if num_files < 1: - raise ValueError("No data files found") + if self._repeat: + logging.info("Repeating dataset") + dataset = dataset.repeat() - if self._data_config.num_files_to_keep is not None: - filenames = filenames[: self._data_config.num_files_to_keep] - logging.info(f"Retaining only {len(filenames)} files.") + if self.dataset_service: + if self._num_concurrent_iterators > 1: + if not self.machines_config: + raise ValueError( + "Must supply a machine_config for autotuning in order to use >1 concurrent iterators" + ) + dataset = dataset_lib.with_auto_tune_budget( + dataset, + machine_config=self.machines_config.chief, + num_concurrent_iterators=self.num_concurrent_iterators, + on_chief=False, + ) - filenames_ds = ( - tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames)) - # Because of drop_remainder, if our dataset does not fill - # up a batch, it will emit nothing without this repeat. - .repeat(-1) - ) + self.dataset_id, self.job_name = register_dataset( + dataset=dataset, dataset_service=self.dataset_service, compression=self.compression + ) + dataset = distribute_from_dataset_id( + dataset_id=self.dataset_id, # type: ignore[arg-type] + job_name=self.job_name, + dataset_service=self.dataset_service, + compression=self.compression, + ) - if self._data_config.file_batch_size: - filenames_ds = filenames_ds.batch(self._data_config.file_batch_size) + elif self._num_concurrent_iterators > 1: + if not self.machines_config: + raise ValueError( + "Must supply a machine_config for autotuning in order to use >1 concurrent iterators" + ) + dataset = dataset_lib.with_auto_tune_budget( + dataset, + machine_config=self.machines_config.chief, + num_concurrent_iterators=self._num_concurrent_iterators, + on_chief=True, + ) - def per_shard_dataset(filename): - ds = tf.data.TFRecordDataset([filename], compression_type="GZIP") - return ds.prefetch(4) + # Vocabulary mapping happens on the training node, not in dds because of size. + if self._vocab_mapper: + dataset = dataset.map(self._vocab_mapper) - ds = filenames_ds.interleave( - per_shard_dataset, - block_length=4, - deterministic=False, - num_parallel_calls=self._data_config.interleave_num_parallel_calls - or tf.data.experimental.AUTOTUNE, - ) + return dataset.prefetch(world_size * 2) - # Combine functions into one map call to reduce overhead. - map_fn = functools.partial( - _chain, - f1=self._parse_fn, - f2=self._output_map_fn, - ) + def _create_base_tf_dataset(self, batch_size: int): + """ + Create the base TensorFlow dataset. - # Shuffle -> Batch -> Parse is the correct ordering - # Shuffling needs to be performed before batching otherwise there is not much point - # Batching happens before parsing because tf.Example parsing is actually vectorized - # and works much faster overall on batches of data. - ds = ( - # DANGER DANGER: there is a default shuffle size here. - ds.shuffle(self._data_config.examples_shuffle_buffer_size) - .batch(batch_size=batch_size, drop_remainder=True) - .map( - map_fn, - num_parallel_calls=self._data_config.map_num_parallel_calls - or tf.data.experimental.AUTOTUNE, - ) - ) + Args: + batch_size (int): Batch size. - if self._data_config.cache: - ds = ds.cache() + Returns: + tf.data.Dataset: Base TensorFlow dataset. + """ + if self._data_config.inputs: + glob = self._data_config.inputs + filenames = sorted(tf.io.gfile.glob(glob)) + elif self._data_config.explicit_datetime_inputs: + num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol + filenames, num_hours_missing = get_explicit_datetime_inputs_files( + self._data_config.explicit_datetime_inputs, + increment="hourly", + ) + if num_hours_missing > num_missing_hours_tol: + raise ValueError( + f"We are missing {num_hours_missing} hours of data" + f"more than tolerance {num_missing_hours_tol}." + ) + elif self._data_config.explicit_date_inputs: + num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol + filenames, num_days_missing = get_explicit_datetime_inputs_files( + self._data_config.explicit_date_inputs, + increment="daily", + ) + if num_days_missing > num_missing_days_tol: + raise ValueError( + f"We are missing {num_days_missing} days of data" + f"more than tolerance {num_missing_days_tol}." + ) + else: + raise ValueError( + "Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config" + ) - if self._data_config.ignore_data_errors: - ds = ds.apply(tf.data.experimental.ignore_errors()) + num_files = len(filenames) + logging.info(f"Found {num_files} data files") + if num_files < 1: + raise ValueError("No data files found") - options = tf.data.Options() - options.experimental_deterministic = False - ds = ds.with_options(options) + if self._data_config.num_files_to_keep is not None: + filenames = filenames[: self._data_config.num_files_to_keep] + logging.info(f"Retaining only {len(filenames)} files.") - return ds + filenames_ds = ( + tf.data.Dataset.from_tensor_slices( + filenames).shuffle(len(filenames)) + # Because of drop_remainder, if our dataset does not fill + # up a batch, it will emit nothing without this repeat. + .repeat(-1) + ) - def _gen(self): - for x in self._tf_dataset: - yield to_batch(x) + if self._data_config.file_batch_size: + filenames_ds = filenames_ds.batch( + self._data_config.file_batch_size) - def to_dataloader(self) -> Dict[str, torch.Tensor]: - return torch.utils.data.DataLoader(self, batch_size=None) + def per_shard_dataset(filename): + """ + Create a TensorFlow dataset for a single shard file. - def __iter__(self): - return iter(self._gen()) + Args: + filename (str): Path to the shard file. + + Returns: + tf.data.Dataset: TensorFlow dataset for the shard file. + """ + ds = tf.data.TFRecordDataset([filename], compression_type="GZIP") + return ds.prefetch(4) + + ds = filenames_ds.interleave( + per_shard_dataset, + block_length=4, + deterministic=False, + num_parallel_calls=self._data_config.interleave_num_parallel_calls + or tf.data.experimental.AUTOTUNE, + ) + + # Combine functions into one map call to reduce overhead. + map_fn = functools.partial( + _chain, + f1=self._parse_fn, + f2=self._output_map_fn, + ) + + # Shuffle -> Batch -> Parse is the correct ordering + # Shuffling needs to be performed before batching otherwise there is not much point + # Batching happens before parsing because tf.Example parsing is actually vectorized + # and works much faster overall on batches of data. + ds = ( + # DANGER DANGER: there is a default shuffle size here. + ds.shuffle(self._data_config.examples_shuffle_buffer_size) + .batch(batch_size=batch_size, drop_remainder=True) + .map( + map_fn, + num_parallel_calls=self._data_config.map_num_parallel_calls + or tf.data.experimental.AUTOTUNE, + ) + ) + + if self._data_config.cache: + ds = ds.cache() + + if self._data_config.ignore_data_errors: + ds = ds.apply(tf.data.experimental.ignore_errors()) + + options = tf.data.Options() + options.experimental_deterministic = False + ds = ds.with_options(options) + + return ds + + def _gen(self): + for x in self._tf_dataset: + yield to_batch(x) + + def to_dataloader(self) -> Dict[str, torch.Tensor]: + """ + Convert the RecapDataset to a PyTorch DataLoader. + + Returns: + torch.utils.data.DataLoader: PyTorch DataLoader for the dataset. + """ + return torch.utils.data.DataLoader(self, batch_size=None) + + def __iter__(self): + return iter(self._gen()) diff --git a/projects/home/recap/data/generate_random_data.py b/projects/home/recap/data/generate_random_data.py index 049385a..61bf010 100644 --- a/projects/home/recap/data/generate_random_data.py +++ b/projects/home/recap/data/generate_random_data.py @@ -17,6 +17,16 @@ FLAGS = flags.FLAGS def _generate_random_example( tf_example_schema: Dict[str, tf.io.FixedLenFeature] ) -> Dict[str, tf.Tensor]: + """ + Generate a random example based on the provided TensorFlow example schema. + + Args: + tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example. + + Returns: + Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema. + """ + example = {} for feature_name, feature_spec in tf_example_schema.items(): dtype = feature_spec.dtype @@ -33,14 +43,43 @@ def _generate_random_example( def _float_feature(value): + """ + Create a TensorFlow float feature. + + Args: + value: A float or list of floats. + + Returns: + tf.train.Feature: A TensorFlow feature containing the float value(s). + """ + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) def _int64_feature(value): + """ + Create a TensorFlow int64 feature. + + Args: + value: An integer or list of integers. + + Returns: + tf.train.Feature: A TensorFlow feature containing the int64 value(s). + """ + return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes: + """ + Serialize a dictionary of TensorFlow tensors into a binary string. + + Args: + x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors. + + Returns: + bytes: The serialized binary string. + """ feature = {} serializers = {tf.float32: _float_feature, tf.int64: _int64_feature} for feature_name, tensor in x.items(): @@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes: def generate_data(data_path: str, config: recap_config_mod.RecapConfig): + """ + Generate random data based on the provided configuration and save it as a TFRecord file. + + Args: + data_path (str): The path where the TFRecord file will be saved. + config (recap_config_mod.RecapConfig): The configuration for generating the random data. + """ + + with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f: seg_dense_schema = json.load(f)["schema"] @@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig): def _generate_data_main(unused_argv): + """ + Main function to generate random data according to the provided configuration. + + Args: + unused_argv: Unused command-line arguments. + """ config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) # Find the path where to put the data diff --git a/projects/home/recap/data/preprocessors.py b/projects/home/recap/data/preprocessors.py index d5720e2..efa3042 100644 --- a/projects/home/recap/data/preprocessors.py +++ b/projects/home/recap/data/preprocessors.py @@ -9,9 +9,20 @@ import numpy as np class TruncateAndSlice(tf.keras.Model): - """Class for truncating and slicing.""" + """ + A class for truncating and slicing input features based on the provided configuration. + + Args: + truncate_and_slice_config: A configuration object specifying how to truncate and slice features. + """ def __init__(self, truncate_and_slice_config): + """ + Initializes the TruncateAndSlice model. + + Args: + truncate_and_slice_config: A configuration object specifying how to truncate and slice features. + """ super().__init__() self._truncate_and_slice_config = truncate_and_slice_config @@ -32,6 +43,17 @@ class TruncateAndSlice(tf.keras.Model): self._binary_mask = None def call(self, inputs, training=None, mask=None): + """ + Applies truncation and slicing to the input features based on the configuration. + + Args: + inputs: A dictionary of input features. + training: A boolean indicating whether the model is in training mode. + mask: A mask tensor. + + Returns: + A dictionary of truncated and sliced input features. + """ outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) if self._truncate_and_slice_config.continuous_feature_truncation: logging.info("Truncating continuous") @@ -51,12 +73,23 @@ class TruncateAndSlice(tf.keras.Model): class DownCast(tf.keras.Model): - """Class for Down casting dataset before serialization and transferring to training host. - Depends on the data type and the actual data range, the down casting can be lossless or not. - It is strongly recommended to compare the metrics before and after down casting. """ + A class for downcasting dataset before serialization and transferring to the training host. + + Depending on the data type and the actual data range, the downcasting can be lossless or not. + It is strongly recommended to compare the metrics before and after downcasting. + + Args: + downcast_config: A configuration object specifying the features and their target data types. + """ def __init__(self, downcast_config): + """ + Initializes the DownCast model. + + Args: + downcast_config: A configuration object specifying the features and their target data types. + """ super().__init__() self.config = downcast_config self._type_map = { @@ -65,6 +98,17 @@ class DownCast(tf.keras.Model): } def call(self, inputs, training=None, mask=None): + """ + Applies downcasting to the input features based on the configuration. + + Args: + inputs: A dictionary of input features. + training: A boolean indicating whether the model is in training mode. + mask: A mask tensor. + + Returns: + A dictionary of downcasted input features. + """ outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs)) for feature, type_str in self.config.features.items(): assert type_str in self._type_map @@ -78,14 +122,39 @@ class DownCast(tf.keras.Model): class RectifyLabels(tf.keras.Model): - """Class for rectifying labels""" + """ + A class for downcasting dataset before serialization and transferring to the training host. + + Depending on the data type and the actual data range, the downcasting can be lossless or not. + It is strongly recommended to compare the metrics before and after downcasting. + + Args: + downcast_config: A configuration object specifying the features and their target data types. + """ def __init__(self, rectify_label_config): + """ + Initializes the DownCast model. + + Args: + downcast_config: A configuration object specifying the features and their target data types. + """ super().__init__() self._config = rectify_label_config self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000) def call(self, inputs, training=None, mask=None): + """ + Applies downcasting to the input features based on the configuration. + + Args: + inputs: A dictionary of input features. + training: A boolean indicating whether the model is in training mode. + mask: A mask tensor. + + Returns: + A dictionary of downcasted input features. + """ served_ts_field = self._config.served_timestamp_field impressed_ts_field = self._config.impressed_timestamp_field @@ -102,13 +171,37 @@ class RectifyLabels(tf.keras.Model): class ExtractFeatures(tf.keras.Model): - """Class for extracting individual features from dense tensors by their index.""" + """ + A class for rectifying labels based on specified conditions. + + This class is used to adjust label values in a dataset based on configured conditions involving timestamps. + + Args: + rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings. + """ def __init__(self, extract_features_config): + """ + Initializes the RectifyLabels model. + + Args: + rectify_label_config: A configuration object specifying the timestamp fields and label-to-engaged timestamp field mappings. + """ super().__init__() self._config = extract_features_config def call(self, inputs, training=None, mask=None): + """ + Rectifies label values based on the specified conditions. + + Args: + inputs: A dictionary of input features including timestamp fields and labels. + training: A boolean indicating whether the model is in training mode. + mask: A mask tensor. + + Returns: + A dictionary of input features with rectified label values. + """ for row in self._config.extract_feature_table: inputs[row.name] = inputs[row.source_tensor][:, row.index] @@ -168,7 +261,16 @@ class DownsampleNegatives(tf.keras.Model): def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN): - """Builds a preprocess model to apply all preprocessing stages.""" + """ + Builds a preprocess model to apply all preprocessing stages. + + Args: + preprocess_config: A configuration object specifying the preprocessing parameters. + mode: A mode indicating the current job mode (TRAIN or INFERENCE). + + Returns: + A preprocess model that applies all specified preprocessing stages. + """ if mode == config_mod.JobMode.INFERENCE: logging.info("Not building preprocessors for dataloading since we are in Inference mode.") return None diff --git a/projects/home/recap/data/tfe_parsing.py b/projects/home/recap/data/tfe_parsing.py index f597746..aff73d2 100644 --- a/projects/home/recap/data/tfe_parsing.py +++ b/projects/home/recap/data/tfe_parsing.py @@ -8,122 +8,129 @@ import tensorflow as tf DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""} -DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string} +DTYPE_MAP = {"int64_list": tf.int64, + "float_list": tf.float32, "bytes_list": tf.string} def create_tf_example_schema( - data_config: recap_data_config.SegDenseSchema, - segdense_schema, + data_config: recap_data_config.SegDenseSchema, + segdense_schema, ): - """Generate schema for deseralizing tf.Example. + """Generate schema for deseralizing tf.Example. - Args: - segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length). - labels: List of strings denoting labels. + Args: + segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length). + labels: List of strings denoting labels. - Returns: - A dictionary schema suitable for deserializing tf.Example. - """ - segdense_config = data_config.seg_dense_schema - labels = list(data_config.tasks.keys()) - used_features = ( - segdense_config.features + list(segdense_config.renamed_features.values()) + labels - ) - logging.info(used_features) + Returns: + A dictionary schema suitable for deserializing tf.Example. + """ + segdense_config = data_config.seg_dense_schema + labels = list(data_config.tasks.keys()) + used_features = ( + segdense_config.features + + list(segdense_config.renamed_features.values()) + labels + ) + logging.info(used_features) - tfe_schema = {} - for entry in segdense_schema: - feature_name = entry["feature_name"] + tfe_schema = {} + for entry in segdense_schema: + feature_name = entry["feature_name"] - if feature_name in used_features: - length = entry["length"] - dtype = entry["dtype"] + if feature_name in used_features: + length = entry["length"] + dtype = entry["dtype"] - if feature_name in labels: - logging.info(f"Label: feature name is {feature_name} type is {dtype}") - tfe_schema[feature_name] = tf.io.FixedLenFeature( - length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] - ) - elif length == -1: - tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype]) - else: - tfe_schema[feature_name] = tf.io.FixedLenFeature( - length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length - ) - for feature_name in used_features: - if feature_name not in tfe_schema: - raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.") - return tfe_schema + if feature_name in labels: + logging.info( + f"Label: feature name is {feature_name} type is {dtype}") + tfe_schema[feature_name] = tf.io.FixedLenFeature( + length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype] + ) + elif length == -1: + tfe_schema[feature_name] = tf.io.VarLenFeature( + DTYPE_MAP[dtype]) + else: + tfe_schema[feature_name] = tf.io.FixedLenFeature( + length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length + ) + for feature_name in used_features: + if feature_name not in tfe_schema: + raise ValueError( + f"{feature_name} missing from schema: {segdense_config.schema_path}.") + return tfe_schema @functools.lru_cache(1) def make_mantissa_mask(mask_length: int) -> tf.Tensor: - """For experimentating with emulating bfloat16 or less precise types.""" - return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32) + """For experimentating with emulating bfloat16 or less precise types.""" + return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32) def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor: - """For experimentating with emulating bfloat16 or less precise types.""" - mask: tf.Tensor = make_mantissa_mask(mask_length) - return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype) + """For experimentating with emulating bfloat16 or less precise types.""" + mask: tf.Tensor = make_mantissa_mask(mask_length) + return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype) def parse_tf_example( - serialized_example, - tfe_schema, - seg_dense_schema_config, + serialized_example, + tfe_schema, + seg_dense_schema_config, ): - """Parse serialized tf.Example into dict of tensors. + """Parse serialized tf.Example into dict of tensors. - Args: - serialized_example: Serialized tf.Example to be parsed. - tfe_schema: Dictionary schema suitable for deserializing tf.Example. + Args: + serialized_example: Serialized tf.Example to be parsed. + tfe_schema: Dictionary schema suitable for deserializing tf.Example. - Returns: - Dictionary of tensors to be used as model input. - """ - inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema) + Returns: + Dictionary of tensors to be used as model input. + """ + inputs = tf.io.parse_example( + serialized=serialized_example, features=tfe_schema) - for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): - inputs[new_feature_name] = inputs.pop(old_feature_name) + for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items(): + inputs[new_feature_name] = inputs.pop(old_feature_name) - # This should not actually be used except for experimentation with low precision floats. - if "mask_mantissa_features" in seg_dense_schema_config: - for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): - inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length) + # This should not actually be used except for experimentation with low precision floats. + if "mask_mantissa_features" in seg_dense_schema_config: + for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items(): + inputs[feature_name] = mask_mantissa( + inputs[feature_name], mask_length) - # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible - # at TF level. - # We should not return empty tensors if we dont use embeddings. - # Otherwise, it breaks numpy->pt conversion - renamed_keys = list(seg_dense_schema_config.renamed_features.keys()) - for renamed_key in renamed_keys: - if "embedding" in renamed_key and (renamed_key not in inputs): - inputs[renamed_key] = tf.zeros([], tf.float32) + # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible + # at TF level. + # We should not return empty tensors if we dont use embeddings. + # Otherwise, it breaks numpy->pt conversion + renamed_keys = list(seg_dense_schema_config.renamed_features.keys()) + for renamed_key in renamed_keys: + if "embedding" in renamed_key and (renamed_key not in inputs): + inputs[renamed_key] = tf.zeros([], tf.float32) - logging.info(f"parsed example and inputs are {inputs}") - return inputs + logging.info(f"parsed example and inputs are {inputs}") + return inputs def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig): - """Placeholder for seg dense. + """Placeholder for seg dense. - In the future, when we use more seg dense variations, we can change this. - """ - with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f: - seg_dense_schema = json.load(f)["schema"] + In the future, when we use more seg dense variations, we can change this. + """ + with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f: + seg_dense_schema = json.load(f)["schema"] - tf_example_schema = create_tf_example_schema( - data_config, - seg_dense_schema, - ) + tf_example_schema = create_tf_example_schema( + data_config, + seg_dense_schema, + ) - logging.info("***** TF Example Schema *****") - logging.info(tf_example_schema) + logging.info("***** TF Example Schema *****") + logging.info(tf_example_schema) - parse = functools.partial( - parse_tf_example, - tfe_schema=tf_example_schema, - seg_dense_schema_config=data_config.seg_dense_schema, - ) - return parse + parse = functools.partial( + parse_tf_example, + tfe_schema=tf_example_schema, + seg_dense_schema_config=data_config.seg_dense_schema, + ) + return parse diff --git a/projects/home/recap/data/util.py b/projects/home/recap/data/util.py index a9fd51e..c5616c4 100644 --- a/projects/home/recap/data/util.py +++ b/projects/home/recap/data/util.py @@ -6,115 +6,160 @@ import tensorflow as tf def keyed_tensor_from_tensors_dict( - tensor_map: Mapping[str, torch.Tensor] + tensor_map: Mapping[str, torch.Tensor] ) -> "torchrec.KeyedTensor": - """ - Convert a dictionary of torch tensor to torchrec keyed tensor - Args: - tensor_map: + """ + Convert a dictionary of torch tensors to a torchrec KeyedTensor. - Returns: + Args: + tensor_map: A mapping of tensor names to torch tensors. - """ - keys = list(tensor_map.keys()) - # We expect batch size to be first dim. However, if we get a shape [Batch_size], - # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is - # [Batch_size x 1]. - values = [ - tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1) - for key in keys - ] - return torchrec.KeyedTensor.from_tensor_list(keys, values) + Returns: + A torchrec KeyedTensor. + """ + keys = list(tensor_map.keys()) + # We expect batch size to be first dim. However, if we get a shape [Batch_size], + # KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is + # [Batch_size x 1]. + values = [ + tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze( + tensor_map[key], -1) + for key in keys + ] + return torchrec.KeyedTensor.from_tensor_list(keys, values) def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if tensor.is_sparse: - x = tensor.coalesce() # Ensure that the indices are ordered. - lengths = torch.bincount(x.indices()[0]) - values = x.values() - else: - values = tensor - lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device) - return values, lengths + """ + Compute a jagged tensor from a torch tensor. + + Args: + tensor: Input torch tensor. + + Returns: + A tuple containing the values and lengths of the jagged tensor. + """ + if tensor.is_sparse: + x = tensor.coalesce() # Ensure that the indices are ordered. + lengths = torch.bincount(x.indices()[0]) + values = x.values() + else: + values = tensor + lengths = torch.ones( + tensor.shape[0], dtype=torch.int32, device=tensor.device) + return values, lengths def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor": - """ - Convert a torch tensor to torchrec jagged tensor. - Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors. - For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the - dense_shape of the sparse tensor can be arbitrary. - Args: - tensor: a torch (sparse) tensor. - Returns: - """ - values, lengths = _compute_jagged_tensor_from_tensor(tensor) - return torchrec.JaggedTensor(values=values, lengths=lengths) + """ + Convert a torch tensor to a torchrec jagged tensor. + + Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x N] for dense tensors. + For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x N], and the dense_shape of the sparse tensor can be arbitrary. + + Args: + tensor: A torch (sparse) tensor. + + Returns: + A torchrec JaggedTensor. + """ + values, lengths = _compute_jagged_tensor_from_tensor(tensor) + return torchrec.JaggedTensor(values=values, lengths=lengths) def keyed_jagged_tensor_from_tensors_dict( - tensor_map: Mapping[str, torch.Tensor] + tensor_map: Mapping[str, torch.Tensor] ) -> "torchrec.KeyedJaggedTensor": - """ - Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor. - Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors. - For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the - dense_shape of the sparse tensor can be arbitrary. - Args: - tensor_map: + """ + Convert a dictionary of (sparse) torch tensors to a torchrec keyed jagged tensor. - Returns: + Note: Currently, this function only supports input tensors with shapes of [Batch_size] or [Batch_size x 1] for dense tensors. + For sparse tensors, the shape of .values() should be [Batch_size] or [Batch_size x 1], and the dense_shape of the sparse tensor can be arbitrary. - """ + Args: + tensor_map: A mapping of tensor names to torch tensors. + + Returns: + A torchrec KeyedJaggedTensor. + """ + + if not tensor_map: + return torchrec.KeyedJaggedTensor( + keys=[], + values=torch.zeros(0, dtype=torch.int), + lengths=torch.zeros(0, dtype=torch.int), + ) + values = [] + lengths = [] + for tensor in tensor_map.values(): + tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor) + values.append(torch.squeeze(tensor_val)) + lengths.append(tensor_len) + + values = torch.cat(values, axis=0) + lengths = torch.cat(lengths, axis=0) - if not tensor_map: return torchrec.KeyedJaggedTensor( - keys=[], - values=torch.zeros(0, dtype=torch.int), - lengths=torch.zeros(0, dtype=torch.int), + keys=list(tensor_map.keys()), + values=values, + lengths=lengths, ) - values = [] - lengths = [] - for tensor in tensor_map.values(): - tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor) - values.append(torch.squeeze(tensor_val)) - lengths.append(tensor_len) - - values = torch.cat(values, axis=0) - lengths = torch.cat(lengths, axis=0) - - return torchrec.KeyedJaggedTensor( - keys=list(tensor_map.keys()), - values=values, - lengths=lengths, - ) def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray: - return tf_tensor._numpy() # noqa + """ + Convert a TensorFlow tensor to a NumPy array. + + Args: + tf_tensor: TensorFlow tensor. + + Returns: + NumPy array. + """ + return tf_tensor._numpy() # noqa def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor: - tensor = _tf_to_numpy(tensor) - # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent - if tensor.dtype.name == "bfloat16": - tensor = tensor.astype(np.float32) + """ + Convert a dense TensorFlow tensor to a PyTorch tensor. - tensor = torch.from_numpy(tensor) - if pin_memory: - tensor = tensor.pin_memory() - return tensor + Args: + tensor: Dense TensorFlow tensor. + pin_memory: Whether to pin the tensor in memory (for CUDA). + + Returns: + PyTorch tensor. + """ + tensor = _tf_to_numpy(tensor) + # Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent + if tensor.dtype.name == "bfloat16": + tensor = tensor.astype(np.float32) + + tensor = torch.from_numpy(tensor) + if pin_memory: + tensor = tensor.pin_memory() + return tensor def sparse_or_dense_tf_to_torch( - tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool + tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool ) -> torch.Tensor: - if isinstance(tensor, tf.SparseTensor): - tensor = torch.sparse_coo_tensor( - _dense_tf_to_torch(tensor.indices, pin_memory).t(), - _dense_tf_to_torch(tensor.values, pin_memory), - torch.Size(_tf_to_numpy(tensor.dense_shape)), - ) - else: - tensor = _dense_tf_to_torch(tensor, pin_memory) - return tensor + """ + Convert a TensorFlow tensor (sparse or dense) to a PyTorch tensor. + + Args: + tensor: TensorFlow tensor (sparse or dense). + pin_memory: Whether to pin the tensor in memory (for CUDA). + + Returns: + PyTorch tensor. + """ + if isinstance(tensor, tf.SparseTensor): + tensor = torch.sparse_coo_tensor( + _dense_tf_to_torch(tensor.indices, pin_memory).t(), + _dense_tf_to_torch(tensor.values, pin_memory), + torch.Size(_tf_to_numpy(tensor.dense_shape)), + ) + else: + tensor = _dense_tf_to_torch(tensor, pin_memory) + return tensor diff --git a/projects/home/recap/embedding/config.py b/projects/home/recap/embedding/config.py index a05ec21..70231ee 100644 --- a/projects/home/recap/embedding/config.py +++ b/projects/home/recap/embedding/config.py @@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig): class EmbeddingOptimizerConfig(base_config.BaseConfig): + """Configuration for the optimizer used for embedding tables.""" learning_rate: optimizer_config.LearningRate = pydantic.Field( None, description="learning rate scheduler for the EBC" ) @@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig): class StratifierConfig(base_config.BaseConfig): + """Configuration for Stratifier.""" name: str index: int value: int @@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig): the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at serving time due to size (>>1 GB). - This small embeddings table uses the same optimizer as the rest of the model.""" + This small embeddings table uses the same optimizer as the rest of the model. + """ tables: List[SmallEmbeddingBagConfig] = pydantic.Field( ..., description="list of embedding tables" diff --git a/projects/home/recap/main.py b/projects/home/recap/main.py index 3416164..5d5936d 100644 --- a/projects/home/recap/main.py +++ b/projects/home/recap/main.py @@ -34,6 +34,33 @@ FLAGS = flags.FLAGS def run(unused_argv: str, data_service_dispatcher: Optional[str] = None): + """ + Main function to run the training of a ranking model. + + This function initializes and runs the training process for a ranking model based on the provided configuration. + + Args: + unused_argv (str): Unused argument. + data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional). + + Returns: + None + + Raises: + AssertionError: If the configuration or input data is not valid. + + Note: + This function serves as the main entry point for training a ranking model. It loads the configuration, sets up + the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the + training loop. + + Example: + To run the training process, use the following command: + + ``` + python run_training.py --config_path=config.yaml + ``` + """ print("#" * 100) config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path) diff --git a/projects/home/recap/model/config.py b/projects/home/recap/model/config.py index 47d0640..357dfa7 100644 --- a/projects/home/recap/model/config.py +++ b/projects/home/recap/model/config.py @@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig): class DenseLayerConfig(base_config.BaseConfig): + """Configuration for the dense layer.""" layer_size: pydantic.PositiveInt dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.") @@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig): class DoubleNormLogConfig(base_config.BaseConfig): + """Configuration for the double norm log transform.""" batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None) clip_magnitude: float = pydantic.Field( 5.0, description="Threshold to clip the normalized input values." @@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig): class ClipLog1pAbsConfig(base_config.BaseConfig): + """Configuration for the clip log transform.""" clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field( 3e38, description="Threshold to clip the input values." ) class ZScoreLogConfig(base_config.BaseConfig): + """Configuration for the z-score log transform.""" analysis_path: str schema_path: str = pydantic.Field( None, @@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig): class MaskBlockConfig(base_config.BaseConfig): + """Config for MaskNet block.""" output_size: int reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field( None, one_of="aggregation_size" @@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig): class MaskNetConfig(base_config.BaseConfig): + """Config for MaskNet model.""" mask_blocks: List[MaskBlockConfig] mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel") use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.") @@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig): class DLRMConfig(base_config.BaseConfig): + """Config for DLRM model.""" bottom_mlp: MlpConfig = pydantic.Field( ..., description="Bottom mlp, the output to be combined with sparse features and feed to interaction", @@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig): class TaskModel(base_config.BaseConfig): + """Configuration for a single task.""" mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture") dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture") dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture") @@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig): class MultiTaskType(str, enum.Enum): + """Type of multi task architecture.""" SHARE_NONE = "share_none" # Tasks are separate. SHARE_ALL = "share_all" # Tasks share same backbone. SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions. @@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig): @pydantic.root_validator() def _validate_mtl(cls, values): + """Validate the multi task architecture.""" if values.get("multi_task_type", None) is None: return values elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]: diff --git a/projects/home/recap/model/entrypoint.py b/projects/home/recap/model/entrypoint.py index 8f4d534..0cbb4af 100644 --- a/projects/home/recap/model/entrypoint.py +++ b/projects/home/recap/model/entrypoint.py @@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name): def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int): - """ "Builds a model for a single task""" + """ + Build a model for a single task based on the provided configuration. + + Args: + task (model_config_mod.TaskModel): The task model configuration. + input_shape (int): The input shape for the model. + + Returns: + torch.nn.Module: The constructed model for the single task. + + Raises: + ValueError: If the task configuration is not recognized. + """ if task.mlp_config: return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config) elif task.dcn_config: @@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int) class MultiTaskRankingModel(torch.nn.Module): - """Multi-task ranking model.""" + """ + Multi-task ranking model that handles multiple ranking tasks simultaneously. + + This model takes various input features and predicts rankings for multiple + tasks using shared or separate towers. + """ def __init__( self, @@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module): data_config: RecapDataConfig, return_backbone: bool = False, ): - """Constructor for Multi task learning. + """ + Constructor for Multi-task ranking model. + + Args: + input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes. + config (ModelConfig): The model configuration. + data_config (RecapDataConfig): The data configuration. + return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False. Assumptions made: - 1. Tasks specified in data config match model architecture. - - These are all validated in config. + 1. Tasks specified in data config match model architecture. + These are all validated in config. """ super().__init__() @@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module): labels: Optional[torch.Tensor] = None, weights: Optional[torch.Tensor] = None, ): + """ + Forward pass of the Multi-task ranking model. + + Args: + continuous_features (torch.Tensor): Continuous input features. + binary_features (torch.Tensor): Binary input features. + discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None. + sparse_features ([type], optional): Sparse input features. Defaults to None. + user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None. + user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None. + author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None. + labels (Optional[torch.Tensor], optional): Target labels. Defaults to None. + weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the model's outputs. + """ concat_dense_features = [ self._preprocessor(continuous_features=continuous_features, binary_features=binary_features) ] @@ -270,6 +310,20 @@ def create_ranking_model( data_config=None, return_backbone=False, ): + """ + Creates a ranking model based on the provided specifications and configuration. + + Args: + data_spec: The input data specifications. + config (config_mod.RecapConfig): The model configuration. + device (torch.device): The device where the model should be placed. + loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None. + data_config: The data configuration. Defaults to None. + return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False. + + Returns: + torch.nn.Module: The created ranking model. + """ if list(config.model.tasks.values())[0].dlrm_config: raise NotImplementedError() diff --git a/projects/home/recap/model/feature_transform.py b/projects/home/recap/model/feature_transform.py index 32fc3bf..1eb57f6 100644 --- a/projects/home/recap/model/feature_transform.py +++ b/projects/home/recap/model/feature_transform.py @@ -11,21 +11,52 @@ import torch def log_transform(x: torch.Tensor) -> torch.Tensor: - """Safe log transform that works across both negative, zero, and positive floats.""" + """ + Safe log transform that works across both negative, zero, and positive floats. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with log1p applied to absolute values. + """ return torch.sign(x) * torch.log1p(torch.abs(x)) class BatchNorm(torch.nn.Module): def __init__(self, num_features: int, config: BatchNormConfig): + """ + Batch normalization layer. + + Args: + num_features (int): Number of input features. + config (BatchNormConfig): Configuration for batch normalization. + """ super().__init__() self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the batch normalization layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after batch normalization. + """ return self.layer(x) class LayerNorm(torch.nn.Module): def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig): + """ + Layer normalization layer. + + Args: + normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor. + config (LayerNormConfig): Configuration for layer normalization. + """ super().__init__() if config.axis != -1: raise NotImplementedError @@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the layer normalization layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after layer normalization. + """ + return self.layer(x) @@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module): super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass that applies a log transformation to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with log applied to absolute values. + """ + return log_transform(x) class InputNonFinite(torch.nn.Module): def __init__(self, fill_value: float = 0): + """ + Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value. + + Args: + fill_value (float): The value to fill non-finite elements with. Default is 0. + """ super().__init__() self.register_buffer( @@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass that replaces non-finite values in the input tensor with the specified fill value. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with non-finite values replaced. + """ return torch.where(torch.isfinite(x), x, self.fill_value) class Clamp(torch.nn.Module): def __init__(self, min_value: float, max_value: float): + """ + Applies element-wise clamping to a tensor, ensuring that values are within a specified range. + + Args: + min_value (float): The minimum value to clamp elements to. + max_value (float): The maximum value to clamp elements to. + """ super().__init__() # Using buffer to make sure they are on correct device (and not moved every time). # Will also be part of state_dict. @@ -74,12 +147,31 @@ class Clamp(torch.nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass that clamps the input tensor element-wise within the specified range. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with elements clamped within the specified range. + """ return torch.clamp(x, min=self.min_value, max=self.max_value) class DoubleNormLog(torch.nn.Module): - """Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.""" + """ + Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features. + Args: + input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes. + config (DoubleNormLogConfig): Configuration for the DoubleNormLog module. + + Attributes: + _before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation, + batch normalization (optional), and clamping. + layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional). + """ def __init__( self, input_shapes: Mapping[str, Sequence[int]], @@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module): def forward( self, continuous_features: torch.Tensor, binary_features: torch.Tensor ) -> torch.Tensor: + """ + Forward pass that processes continuous and binary features using batch normalization, log transformation, + optional batch normalization (if configured), clamping, and layer normalization (if configured). + + Args: + continuous_features (torch.Tensor): Input tensor of continuous features. + binary_features (torch.Tensor): Input tensor of binary features. + + Returns: + torch.Tensor: Transformed tensor containing both continuous and binary features. + """ x = self._before_concat_layers(continuous_features) x = torch.cat([x, binary_features], dim=1) if self.layer_norm: @@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module): def build_features_preprocessor( config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]] ): - """Trivial right now, but we will change in the future.""" + """ + Build a feature preprocessor module based on the provided configuration. + Trivial right now, but we will change in the future. + + Args: + config (FeaturizationConfig): Configuration for feature preprocessing. + input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes. + + Returns: + DoubleNormLog: An instance of the DoubleNormLog feature preprocessor. + """ return DoubleNormLog(input_shapes, config.double_norm_log_config) diff --git a/projects/home/recap/model/mask_net.py b/projects/home/recap/model/mask_net.py index 43ac89c..951007e 100644 --- a/projects/home/recap/model/mask_net.py +++ b/projects/home/recap/model/mask_net.py @@ -6,93 +6,243 @@ import torch def _init_weights(module): - if isinstance(module, torch.nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - torch.nn.init.constant_(module.bias, 0) + """Initializes weights + + Example + + ```python + import torch + import torch.nn as nn + + # Define a simple linear layer + linear_layer = nn.Linear(64, 32) + + # Initialize the weights and biases using _init_weights + _init_weights(linear_layer) + ``` + + """ + if isinstance(module, torch.nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + torch.nn.init.constant_(module.bias, 0) class MaskBlock(torch.nn.Module): - def __init__( - self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int - ) -> None: - super(MaskBlock, self).__init__() - self.mask_block_config = mask_block_config - output_size = mask_block_config.output_size + """ + MaskBlock module in a mask-based neural network. - if mask_block_config.input_layer_norm: - self._input_layer_norm = torch.nn.LayerNorm(input_dim) - else: - self._input_layer_norm = None + This module represents a MaskBlock, which applies a masking operation to the input data and then + passes it through a hidden layer. It is typically used as a building block within a MaskNet. - if mask_block_config.reduction_factor: - aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor) - elif mask_block_config.aggregation_size is not None: - aggregation_size = mask_block_config.aggregation_size - else: - raise ValueError("Need one of reduction factor or aggregation size.") + Args: + mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock. + input_dim (int): Dimensionality of the input data. + mask_input_dim (int): Dimensionality of the mask input. - self._mask_layer = torch.nn.Sequential( - torch.nn.Linear(mask_input_dim, aggregation_size), - torch.nn.ReLU(), - torch.nn.Linear(aggregation_size, input_dim), - ) - self._mask_layer.apply(_init_weights) - self._hidden_layer = torch.nn.Linear(input_dim, output_size) - self._hidden_layer.apply(_init_weights) - self._layer_norm = torch.nn.LayerNorm(output_size) + Example: + To create and use a MaskBlock within a MaskNet, follow these steps: - def forward(self, net: torch.Tensor, mask_input: torch.Tensor): - if self._input_layer_norm: - net = self._input_layer_norm(net) - hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input)) - return self._layer_norm(hidden_layer_output) + ```python + # Define the configuration for the MaskBlock + mask_block_config = MaskBlockConfig( + input_layer_norm=True, # Apply input layer normalization + reduction_factor=0.5 # Reduce input dimensionality by 50% + ) + + # Create an instance of the MaskBlock + mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32) + + # Generate input tensors + input_data = torch.randn(batch_size, 64) + mask_input = torch.randn(batch_size, 32) + + # Perform a forward pass through the MaskBlock + output = mask_block(input_data, mask_input) + ``` + + Note: + The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking + operation that combines the input and mask input. Then, it passes the result through a hidden layer + with optional dimensionality reduction. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__( + self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int + ) -> None: + """ + Initializes the MaskBlock module. + + Args: + mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock. + input_dim (int): Dimensionality of the input data. + mask_input_dim (int): Dimensionality of the mask input. + + Returns: + None + """ + + super(MaskBlock, self).__init__() + self.mask_block_config = mask_block_config + output_size = mask_block_config.output_size + + if mask_block_config.input_layer_norm: + self._input_layer_norm = torch.nn.LayerNorm(input_dim) + else: + self._input_layer_norm = None + + if mask_block_config.reduction_factor: + aggregation_size = int( + mask_input_dim * mask_block_config.reduction_factor) + elif mask_block_config.aggregation_size is not None: + aggregation_size = mask_block_config.aggregation_size + else: + raise ValueError( + "Need one of reduction factor or aggregation size.") + + self._mask_layer = torch.nn.Sequential( + torch.nn.Linear(mask_input_dim, aggregation_size), + torch.nn.ReLU(), + torch.nn.Linear(aggregation_size, input_dim), + ) + self._mask_layer.apply(_init_weights) + self._hidden_layer = torch.nn.Linear(input_dim, output_size) + self._hidden_layer.apply(_init_weights) + self._layer_norm = torch.nn.LayerNorm(output_size) + + def forward(self, net: torch.Tensor, mask_input: torch.Tensor): + """ + Performs a forward pass through the MaskBlock. + + Args: + net (torch.Tensor): Input data tensor. + mask_input (torch.Tensor): Mask input tensor. + + Returns: + torch.Tensor: Output tensor of the MaskBlock. + """ + if self._input_layer_norm: + net = self._input_layer_norm(net) + hidden_layer_output = self._hidden_layer( + net * self._mask_layer(mask_input)) + return self._layer_norm(hidden_layer_output) class MaskNet(torch.nn.Module): - def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): - super().__init__() - self.mask_net_config = mask_net_config - mask_blocks = [] + """ + MaskNet module in a mask-based neural network. - if mask_net_config.use_parallel: - total_output_mask_blocks = 0 - for mask_block_config in mask_net_config.mask_blocks: - mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features)) - total_output_mask_blocks += mask_block_config.output_size - self._mask_blocks = torch.nn.ModuleList(mask_blocks) - else: - input_size = in_features - for mask_block_config in mask_net_config.mask_blocks: - mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features)) - input_size = mask_block_config.output_size + This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to + create mask-based neural networks with parallel or stacked MaskBlocks. - self._mask_blocks = torch.nn.ModuleList(mask_blocks) - total_output_mask_blocks = mask_block_config.output_size + Args: + mask_net_config (config.MaskNetConfig): Configuration for the MaskNet. + in_features (int): Dimensionality of the input data. - if mask_net_config.mlp: - self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp) - self.out_features = mask_net_config.mlp.layer_sizes[-1] - else: - self.out_features = total_output_mask_blocks - self.shared_size = total_output_mask_blocks + Example: + To create and use a MaskNet, you can follow these steps: - def forward(self, inputs: torch.Tensor): - if self.mask_net_config.use_parallel: - mask_outputs = [] - for mask_layer in self._mask_blocks: - mask_outputs.append(mask_layer(mask_input=inputs, net=inputs)) - # Share the outputs of the MaskBlocks. - all_mask_outputs = torch.cat(mask_outputs, dim=1) - output = ( - all_mask_outputs - if self.mask_net_config.mlp is None - else self._dense_layers(all_mask_outputs)["output"] - ) - return {"output": output, "shared_layer": all_mask_outputs} - else: - net = inputs - for mask_layer in self._mask_blocks: - net = mask_layer(net=net, mask_input=inputs) - # Share the output of the stacked MaskBlocks. - output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"] - return {"output": output, "shared_layer": net} + ```python + # Define the configuration for the MaskNet + mask_net_config = MaskNetConfig( + use_parallel=True, # Use parallel MaskBlocks + mlp=MlpConfig(layer_sizes=[128, 64]) # Optional MLP on the outputs + ) + + # Create an instance of the MaskNet + mask_net = MaskNet(mask_net_config, in_features=64) + + # Generate input tensors + input_data = torch.randn(batch_size, 64) + + # Perform a forward pass through the MaskNet + outputs = mask_net(input_data) + + # Access the output and shared layer + output = outputs["output"] + shared_layer = outputs["shared_layer"] + ``` + + Note: + The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked + MaskBlocks. You can also optionally apply an MLP to the outputs for further processing. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int): + """ + Initializes the MaskNet module. + + Args: + mask_net_config (config.MaskNetConfig): Configuration for the MaskNet. + in_features (int): Dimensionality of the input data. + + Returns: + None + """ + + super().__init__() + self.mask_net_config = mask_net_config + mask_blocks = [] + + if mask_net_config.use_parallel: + total_output_mask_blocks = 0 + for mask_block_config in mask_net_config.mask_blocks: + mask_blocks.append( + MaskBlock(mask_block_config, in_features, in_features)) + total_output_mask_blocks += mask_block_config.output_size + self._mask_blocks = torch.nn.ModuleList(mask_blocks) + else: + input_size = in_features + for mask_block_config in mask_net_config.mask_blocks: + mask_blocks.append( + MaskBlock(mask_block_config, input_size, in_features)) + input_size = mask_block_config.output_size + + self._mask_blocks = torch.nn.ModuleList(mask_blocks) + total_output_mask_blocks = mask_block_config.output_size + + if mask_net_config.mlp: + self._dense_layers = mlp.Mlp( + total_output_mask_blocks, mask_net_config.mlp) + self.out_features = mask_net_config.mlp.layer_sizes[-1] + else: + self.out_features = total_output_mask_blocks + self.shared_size = total_output_mask_blocks + + def forward(self, inputs: torch.Tensor): + """ + Performs a forward pass through the MaskNet. + + Args: + inputs (torch.Tensor): Input data tensor. + + Returns: + torch.Tensor: Output tensor of the MaskNet. + """ + if self.mask_net_config.use_parallel: + mask_outputs = [] + for mask_layer in self._mask_blocks: + mask_outputs.append(mask_layer(mask_input=inputs, net=inputs)) + # Share the outputs of the MaskBlocks. + all_mask_outputs = torch.cat(mask_outputs, dim=1) + output = ( + all_mask_outputs + if self.mask_net_config.mlp is None + else self._dense_layers(all_mask_outputs)["output"] + ) + return {"output": output, "shared_layer": all_mask_outputs} + else: + net = inputs + for mask_layer in self._mask_blocks: + net = mask_layer(net=net, mask_input=inputs) + # Share the output of the stacked MaskBlocks. + output = net if self.mask_net_config.mlp is None else self._dense_layers[ + net]["output"] + return {"output": output, "shared_layer": net} diff --git a/projects/home/recap/model/mlp.py b/projects/home/recap/model/mlp.py index 0301cac..22002ae 100644 --- a/projects/home/recap/model/mlp.py +++ b/projects/home/recap/model/mlp.py @@ -7,13 +7,83 @@ from absl import logging def _init_weights(module): + """Initializes weights + + Example + ------- + ```python + import torch + import torch.nn as nn + + # Define a simple linear layer + linear_layer = nn.Linear(64, 32) + + # Initialize the weights and biases using _init_weights + _init_weights(linear_layer) + ``` + + """ if isinstance(module, torch.nn.Linear): torch.nn.init.xavier_uniform_(module.weight) torch.nn.init.constant_(module.bias, 0) class Mlp(torch.nn.Module): + """ + Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch. + + This module defines an MLP with customizable layers and activation functions. It is suitable for various + applications such as deep learning for tabular data, feature extraction, and more. + + Args: + in_features (int): The number of input features or input dimensions. + mlp_config (MlpConfig): Configuration object specifying the MLP's architecture. + + Example: + To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps: + + ```python + # Define the configuration for the MLP + mlp_config = MlpConfig( + layer_sizes=[128, 64], # Specify the sizes of hidden layers + batch_norm=True, # Enable batch normalization + dropout=0.2, # Apply dropout with a rate of 0.2 + final_layer_activation=True # Apply ReLU activation to the final layer + ) + + # Create an instance of the MLP module + mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config) + + # Generate an input tensor + input_tensor = torch.randn(batch_size, input_dim) + + # Perform a forward pass through the MLP + outputs = mlp_model(input_tensor) + + # Access the output and shared layer + output = outputs["output"] + shared_layer = outputs["shared_layer"] + ``` + + Note: + The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes, + enabling batch normalization and dropout, and choosing the activation function for the final layer. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ def __init__(self, in_features: int, mlp_config: MlpConfig): + """ + Initializes the Mlp module. + + Args: + in_features (int): The number of input features or input dimensions. + mlp_config (MlpConfig): Configuration object specifying the MLP's architecture. + + Returns: + None + """ super().__init__() self._mlp_config = mlp_config input_size = in_features @@ -42,6 +112,15 @@ class Mlp(torch.nn.Module): self.layers.apply(_init_weights) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs a forward pass through the MLP. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, in_features). + + Returns: + torch.Tensor: Output tensor of the MLP. + """ net = x for i, layer in enumerate(self.layers): net = layer(net) @@ -51,8 +130,21 @@ class Mlp(torch.nn.Module): @property def shared_size(self): + """ + Returns the size of the shared layer in the MLP. + + Returns: + int: Size of the shared layer. + """ return self._mlp_config.layer_sizes[-1] @property def out_features(self): + """ + Returns the number of output features from the MLP. + + Returns: + int: Number of output features. + """ + return self._mlp_config.layer_sizes[-1] diff --git a/projects/home/recap/model/model_and_loss.py b/projects/home/recap/model/model_and_loss.py index b42da9b..e16ad09 100644 --- a/projects/home/recap/model/model_and_loss.py +++ b/projects/home/recap/model/model_and_loss.py @@ -5,63 +5,117 @@ from absl import logging class ModelAndLoss(torch.nn.Module): - def __init__( - self, - model, - loss_fn: Callable, - stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, - ) -> None: """ - Args: - model: torch module to wrap. - loss_fn: Function for calculating loss, should accept logits and labels. - straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification. - """ - super().__init__() - self.model = model - self.loss_fn = loss_fn - self.stratifiers = stratifiers + PyTorch module that combines a neural network model and loss function. - def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] - """Runs model forward and calculates loss according to given loss_fn. + This module wraps a neural network model and facilitates the forward pass through the model + while also calculating the loss based on the model's predictions and provided labels. - NOTE: The input signature here needs to be a Pipelineable object for - prefetching purposes during training using torchrec's pipeline. However - the underlying model signature needs to be exportable to onnx, requiring - generic python types. see https://pytorch.org/docs/stable/onnx.html#types. + Args: + model: The torch module to wrap. + loss_fn (Callable): Function for calculating the loss, which should accept logits and labels. + stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations + for metrics stratification. Each stratifier config includes the name and index of discrete features + to emit for stratification. - """ - outputs = self.model( - continuous_features=batch.continuous_features, - binary_features=batch.binary_features, - discrete_features=batch.discrete_features, - sparse_features=batch.sparse_features, - user_embedding=batch.user_embedding, - user_eng_embedding=batch.user_eng_embedding, - author_embedding=batch.author_embedding, - labels=batch.labels, - weights=batch.weights, - ) - losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float()) + Example: + To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model + and loss function as arguments: - if self.stratifiers: - logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}") - outputs["stratifiers"] = {} - for stratifier in self.stratifiers: - outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index] + ```python + # Create a neural network model + model = YourNeuralNetworkModel() - # In general, we can have a large number of losses returned by our loss function. - if isinstance(losses, dict): - return losses["loss"], { - **outputs, - **losses, - "labels": batch.labels, - "weights": batch.weights, - } - else: # Assume that this is a float. - return losses, { - **outputs, - "loss": losses, - "labels": batch.labels, - "weights": batch.weights, - } + # Define a loss function + loss_fn = torch.nn.CrossEntropyLoss() + + # Create an instance of ModelAndLoss + model_and_loss = ModelAndLoss(model, loss_fn) + + # Generate a batch of training data (e.g., RecapBatch) + batch = generate_training_batch() + + # Perform a forward pass through the model and calculate the loss + loss, outputs = model_and_loss(batch) + + # You can now backpropagate and optimize using the computed loss + loss.backward() + optimizer.step() + ``` + + Note: + The `ModelAndLoss` class simplifies the process of running forward passes through a model and + calculating loss, making it easier to integrate the model into your training loop. Additionally, + it supports the addition of stratifiers for metrics stratification, if needed. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__( + self, + model, + loss_fn: Callable, + stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None, + ) -> None: + """ + Initializes the ModelAndLoss module. + + Args: + model: The torch module to wrap. + loss_fn (Callable): Function for calculating the loss, which should accept logits and labels. + stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations + for metrics stratification. + """ + super().__init__() + self.model = model + self.loss_fn = loss_fn + self.stratifiers = stratifiers + + def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] + """Runs model forward and calculates loss according to given loss_fn. + + NOTE: The input signature here needs to be a Pipelineable object for + prefetching purposes during training using torchrec's pipeline. However + the underlying model signature needs to be exportable to onnx, requiring + generic python types. see https://pytorch.org/docs/stable/onnx.html#types. + + """ + outputs = self.model( + continuous_features=batch.continuous_features, + binary_features=batch.binary_features, + discrete_features=batch.discrete_features, + sparse_features=batch.sparse_features, + user_embedding=batch.user_embedding, + user_eng_embedding=batch.user_eng_embedding, + author_embedding=batch.author_embedding, + labels=batch.labels, + weights=batch.weights, + ) + losses = self.loss_fn( + outputs["logits"], batch.labels.float(), batch.weights.float()) + + if self.stratifiers: + logging.info( + f"***** Adding stratifiers *****\n {self.stratifiers}") + outputs["stratifiers"] = {} + for stratifier in self.stratifiers: + outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, + stratifier.index] + + # In general, we can have a large number of losses returned by our loss function. + if isinstance(losses, dict): + return losses["loss"], { + **outputs, + **losses, + "labels": batch.labels, + "weights": batch.weights, + } + else: # Assume that this is a float. + return losses, { + **outputs, + "loss": losses, + "labels": batch.labels, + "weights": batch.weights, + } diff --git a/projects/home/recap/model/numeric_calibration.py b/projects/home/recap/model/numeric_calibration.py index 25f516a..a66869f 100644 --- a/projects/home/recap/model/numeric_calibration.py +++ b/projects/home/recap/model/numeric_calibration.py @@ -2,18 +2,65 @@ import torch class NumericCalibration(torch.nn.Module): - def __init__( - self, - pos_downsampling_rate: float, - neg_downsampling_rate: float, - ): - super().__init__() + """ + Numeric calibration module for adjusting probability scores. - # Using buffer to make sure they are on correct device (and not moved every time). - # Will also be part of state_dict. - self.register_buffer( - "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True - ) + This module scales probability scores to correct for imbalanced datasets, where positive and negative samples + may be underrepresented or have different ratios. It is designed to be used as a component in a neural network + for tasks such as binary classification. - def forward(self, probs: torch.Tensor): - return probs * self.ratio / (1.0 - probs + (self.ratio * probs)) + Args: + pos_downsampling_rate (float): The downsampling rate for positive samples. + neg_downsampling_rate (float): The downsampling rate for negative samples. + + Example: + To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability + scores like this: + + ```python + # Create a NumericCalibration instance with downsampling rates + calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2) + + # Generate probability scores (e.g., from a neural network) + raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9]) + + # Apply numeric calibration to adjust the probabilities + calibrated_probs = calibration(raw_probs) + + # The `calibrated_probs` now contains the adjusted probability scores + ``` + + Note: + The `NumericCalibration` module is used to adjust probability scores to account for differences in + the number of positive and negative samples in a dataset. It can help improve the calibration of + probability estimates in imbalanced classification problems. + + Warning: + This class is intended for internal use within neural network architectures and should not be + directly accessed or modified by external code. + """ + + def __init__( + self, + pos_downsampling_rate: float, + neg_downsampling_rate: float, + ): + """ + Apply numeric calibration to probability scores. + + Args: + probs (torch.Tensor): Probability scores to be calibrated. + + Returns: + torch.Tensor: Calibrated probability scores. + """ + super().__init__() + + # Using buffer to make sure they are on correct device (and not moved every time). + # Will also be part of state_dict. + self.register_buffer( + "ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True + ) + + def forward(self, probs: torch.Tensor): + return probs * self.ratio / (1.0 - probs + (self.ratio * probs)) diff --git a/projects/home/recap/optimizer/config.py b/projects/home/recap/optimizer/config.py index f87c591..08e0359 100644 --- a/projects/home/recap/optimizer/config.py +++ b/projects/home/recap/optimizer/config.py @@ -9,12 +9,60 @@ import pydantic class RecapAdamConfig(base_config.BaseConfig): + """ + Configuration settings for the Adam optimizer used in Recap. + + Args: + beta_1 (float): Momentum term (default: 0.9). + beta_2 (float): Exponential weighted decay factor (default: 0.999). + epsilon (float): Numerical stability in the denominator (default: 1e-7). + + Example: + To define an Adam optimizer configuration for Recap, use: + + ```python + adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7) + ``` + + Note: + This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks. + + Warning: + This class is intended for internal use within Recap and should not be directly accessed or modified by external code. + """ + beta_1: float = 0.9 # Momentum term. beta_2: float = 0.999 # Exponential weighted decay factor. epsilon: float = 1e-7 # Numerical stability in denominator. class MultiTaskLearningRates(base_config.BaseConfig): + """ + Configuration settings for multiple learning rates in Recap. + + Args: + tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model. + backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None). + + Example: + To define multiple learning rates for different towers in Recap, use: + + ```python + multi_task_lr = MultiTaskLearningRates( + tower_learning_rates={ + 'task1': learning_rate1, + 'task2': learning_rate2, + }, + backbone_learning_rate=backbone_lr, + ) + ``` + + Note: + This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone. + + Warning: + This class is intended for internal use within Recap and should not be directly accessed or modified by external code. + """ tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field( description="Learning rates for different towers of the model." ) @@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig): class RecapOptimizerConfig(base_config.BaseConfig): + """ + Configuration settings for the Recap optimizer. + + Args: + multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional). + single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional). + adam (RecapAdamConfig): Configuration settings for the Adam optimizer. + + Example: + To define an optimizer configuration for training with Recap, use: + + ```python + optimizer_config = RecapOptimizerConfig( + multi_task_learning_rates=multi_task_lr, + single_task_learning_rate=single_task_lr, + adam=adam_config, + ) + ``` + + Warning: + This class is intended for internal use to configure the optimizer settings within Recap and should not be + directly accessed by external code. + """ + multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field( None, description="Multiple learning rates for different tasks.", one_of="lr" ) diff --git a/projects/home/recap/optimizer/optimizer.py b/projects/home/recap/optimizer/optimizer.py index c5b0cf1..dbffecd 100644 --- a/projects/home/recap/optimizer/optimizer.py +++ b/projects/home/recap/optimizer/optimizer.py @@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc" class RecapLRShim(torch.optim.lr_scheduler._LRScheduler): - """Shim to get learning rates into a LRScheduler. - - This adheres to the torch.optim scheduler API and can be plugged anywhere that - e.g. exponential decay can be used. - """ + A shim to get learning rates into a LRScheduler. + + This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports + learning rate schedules, such as exponential decay. + + Args: + optimizer: The optimizer to which this scheduler is applied. + lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations. + emb_learning_rate: The learning rate for embeddings (optional). + last_epoch (int): The index of the last epoch (default: -1). + verbose (bool): If True, print warnings for deprecated functions (default: False). + + Example: + To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use: + + ```python + scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate) + ``` + + Warning: + This class is intended for internal use to handle learning rate scheduling within Recap training and should not + be directly accessed by external code. + """ def __init__( self, @@ -80,15 +98,25 @@ def build_optimizer( optimizer_config: config.OptimizerConfig, emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None, ): - """Builds an optimizer and scheduler. - - Args: - model: A torch model, probably with DDP/DMP. - optimizer_config: An OptimizerConfig object that specifies learning rates per tower. - - Returns: - A torch.optim instance, and a scheduler instance. """ + Build an optimizer and scheduler for training. + + Args: + model: The torch model, possibly with DDP/DMP. + optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer. + emb_optimizer_config: Configuration settings for embedding optimization (optional). + + Returns: + torch.optim.Optimizer: The optimizer for training. + RecapLRShim: The learning rate scheduler for the optimizer. + + Example: + To build an optimizer and scheduler for training, use: + + ```python + optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config) + ``` + """ optimizer_fn = functools.partial( torch.optim.Adam, lr=_DEFAULT_LR, diff --git a/projects/twhin/data/config.py b/projects/twhin/data/config.py index 55e3cec..05c70c4 100644 --- a/projects/twhin/data/config.py +++ b/projects/twhin/data/config.py @@ -4,6 +4,17 @@ import pydantic class TwhinDataConfig(base_config.BaseConfig): + """ + Configuration for Twhin model training data. + + Args: + data_root (str): The root directory for the training data. + per_replica_batch_size (pydantic.PositiveInt): Batch size per replica. + global_negatives (int): The number of global negatives. + in_batch_negatives (int): The number of in-batch negatives. + limit (pydantic.PositiveInt): The limit on the number of data points to use. + offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None. + """ data_root: str per_replica_batch_size: pydantic.PositiveInt global_negatives: int diff --git a/projects/twhin/data/data.py b/projects/twhin/data/data.py index 37bf8f5..b0afe7f 100644 --- a/projects/twhin/data/data.py +++ b/projects/twhin/data/data.py @@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig): + """ + Create a dataset for Twhin model training. + + Args: + data_config (TwhinDataConfig): The data configuration for the dataset. + model_config (TwhinModelConfig): The model configuration containing embeddings and relations. + + Returns: + EdgesDataset: The dataset for Twhin model training. + """ tables = model_config.embeddings.tables table_sizes = {table.name: table.num_embeddings for table in tables} relations = model_config.relations diff --git a/projects/twhin/data/edges.py b/projects/twhin/data/edges.py index f7864b1..778a188 100644 --- a/projects/twhin/data/edges.py +++ b/projects/twhin/data/edges.py @@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @dataclass class EdgeBatch(DataclassBatch): + """ + Batch data structure for edge-based models. + + Args: + nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings. + labels (torch.Tensor): Tensor containing labels. + rels (torch.Tensor): Tensor containing relation information. + weights (torch.Tensor): Tensor containing weights. + """ nodes: KeyedJaggedTensor labels: torch.Tensor rels: torch.Tensor @@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch): class EdgesDataset(Dataset): + """ + Dataset for edge-based models. + + Args: + file_pattern (str): The file pattern for the dataset. + table_sizes (Dict[str, int]): A dictionary of table names and their sizes. + relations (List[Relation]): A list of relations between tables. + lhs_column_name (str): The name of the left-hand-side column. + rhs_column_name (str): The name of the right-hand-side column. + rel_column_name (str): The name of the relation column. + **dataset_kwargs: Additional keyword arguments for the parent Dataset class. + """ rng = np.random.default_rng() def __init__( @@ -56,6 +77,15 @@ class EdgesDataset(Dataset): super().__init__(file_pattern=file_pattern, **dataset_kwargs) def pa_to_batch(self, batch: pa.RecordBatch): + """ + Converts a pyarrow RecordBatch to an EdgeBatch. + + Args: + batch (pa.RecordBatch): A pyarrow RecordBatch containing data. + + Returns: + EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights. + """ lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy()) rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy()) rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy()) @@ -74,6 +104,14 @@ class EdgesDataset(Dataset): ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: """Process edges that contain lhs index, rhs index, relation index. + + Args: + lhs (torch.Tensor): Tensor containing left-hand-side indices. + rhs (torch.Tensor): Tensor containing right-hand-side indices. + rel (torch.Tensor): Tensor containing relation indices. + + Returns: + Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs. Example: ``` @@ -147,6 +185,12 @@ class EdgesDataset(Dataset): return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths) def to_batches(self): + """ + Converts data to batches. + + Yields: + pa.RecordBatch: A pyarrow RecordBatch containing data. + """ ds = super().to_batches() batch_size = self._dataset_kwargs["batch_size"] diff --git a/projects/twhin/data/test_edges.py b/projects/twhin/data/test_edges.py index 89805be..a2533c6 100644 --- a/projects/twhin/data/test_edges.py +++ b/projects/twhin/data/test_edges.py @@ -23,6 +23,25 @@ RELATIONS = [ def test_gen(): + """Test function for generating edge-based datasets and dataloaders. + + This function generates a synthetic dataset and tests the creation of an `EdgesDataset` + instance and a dataloader for it. + + The test includes the following steps: + 1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns. + 2. Write the synthetic data to a Parquet file. + 3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size. + 4. Initialize the local file system for the dataset. + 5. Create a dataloader for the dataset and retrieve the first batch. + 6. Assert that the labels in the batch are positive. + 7. Verify that the positive examples in the batch match the expected values. + + This function serves as a test case for the data generation and dataset creation process. + + Raises: + AssertionError: If any of the test assertions fail. + """ import os import tempfile diff --git a/projects/twhin/models/config.py b/projects/twhin/models/config.py index aafc9e6..c7b2e1c 100644 --- a/projects/twhin/models/config.py +++ b/projects/twhin/models/config.py @@ -10,8 +10,29 @@ from pydantic import validator class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): + """ + Configuration class for Twhin model embeddings. + + This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types + for all tables in the Twhin model embeddings configuration match. + + Attributes: + tables (List[TableConfig]): A list of table configurations for the model's embeddings. + """ @validator("tables") def embedding_dims_match(cls, tables): + """ + Validate that embedding dimensions and data types match for all tables. + + Args: + tables (List[TableConfig]): List of table configurations. + + Returns: + List[TableConfig]: The list of validated table configurations. + + Raises: + AssertionError: If embedding dimensions or data types do not match. + """ embedding_dim = tables[0].embedding_dim data_type = tables[0].data_type for table in tables: @@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): class Operator(str, enum.Enum): + """ + Enumeration of operator types. + + This enumeration defines different types of operators that can be applied to Twhin model relations. + """ TRANSLATION = "translation" class Relation(pydantic.BaseModel): - """graph relationship properties and operator""" + """ + Configuration class for graph relationships in the Twhin model. + + This class defines properties and operators for graph relationships in the Twhin model. + + Attributes: + name (str): The name of the relationship. + lhs (str): The name of the entity on the left-hand side of the relation. + rhs (str): The name of the entity on the right-hand side of the relation. + operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product. + """ name: str = pydantic.Field(..., description="Relationship name.") lhs: str = pydantic.Field( @@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel): class TwhinModelConfig(base_config.BaseConfig): + """ + Configuration class for the Twhin model. + + This class defines configuration options specific to the Twhin model. + + Attributes: + embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings. + relations (List[Relation]): List of graph relationship configurations. + translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation. + """ embeddings: TwhinEmbeddingsConfig relations: typing.List[Relation] translation_optimizer: OptimizerConfig @validator("relations", each_item=True) def valid_node_types(cls, relation, values, **kwargs): + """ + Validate that the specified node types in relations are valid table names in embeddings. + + Args: + relation (Relation): A single relation configuration. + values (dict): The values dictionary containing the "embeddings" configuration. + + Returns: + Relation: The validated relation configuration. + + Raises: + AssertionError: If the specified node types are not valid table names in embeddings. + """ table_names = [table.name for table in values["embeddings"].tables] assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}" assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}" diff --git a/projects/twhin/models/models.py b/projects/twhin/models/models.py index b1dfaa7..0b5f1cb 100644 --- a/projects/twhin/models/models.py +++ b/projects/twhin/models/models.py @@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa class TwhinModel(nn.Module): + """ + Twhin model for graph-based entity embeddings and translation. + + This class defines the Twhin model, which is used for learning embeddings of entities in a graph + and applying translations to these embeddings based on graph relationships. + + Args: + model_config (TwhinModelConfig): Configuration for the Twhin model. + data_config (TwhinDataConfig): Configuration for the data used by the model. + + Attributes: + batch_size (int): The batch size used for training. + table_names (List[str]): Names of tables in the model's embeddings. + large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings. + embedding_dim (int): Dimensionality of entity embeddings. + num_tables (int): Number of tables in the model's embeddings. + in_batch_negatives (int): Number of in-batch negative samples to use during training. + global_negatives (int): Number of global negative samples to use during training. + num_relations (int): Number of graph relationships in the model. + all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings. + + """ def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig): super().__init__() self.batch_size = data_config.per_replica_batch_size @@ -31,7 +53,17 @@ class TwhinModel(nn.Module): ) def forward(self, batch: EdgeBatch): + """ + Forward pass of the Twhin model. + Args: + batch (EdgeBatch): Input batch containing graph edge information. + + Returns: + dict: A dictionary containing model output with "logits" and "probabilities". + - "logits" (torch.Tensor): Logit scores. + - "probabilities" (torch.Tensor): Sigmoid probabilities. + """ # B x D trans_embs = self.all_trans_embs.data[batch.rels] @@ -98,6 +130,18 @@ class TwhinModel(nn.Module): def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): + """ + Apply optimizers to the Twhin model's embeddings. + + This function applies optimizers to the embeddings of the Twhin model based on the provided configuration. + + Args: + model (TwhinModel): The Twhin model to apply optimizers to. + model_config (TwhinModelConfig): Configuration for the Twhin model. + + Returns: + TwhinModel: The Twhin model with optimizers applied to its embeddings. + """ for table in model_config.embeddings.tables: optimizer_class = get_optimizer_class(table.optimizer) optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict() @@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module): device: torch.device, ) -> None: """ - Args: - model: torch module to wrap. - loss_fn: Function for calculating loss, should accept logits and labels. - """ + Initialize a TwhinModelAndLoss module. + + Args: + model: The torch module to wrap. + loss_fn: A function for calculating loss, should accept logits and labels. + data_config: Configuration for Twhin data. + device: The torch device to use for calculations. + """ super().__init__() self.model = model self.loss_fn = loss_fn @@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module): self.device = device def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] - """Runs model forward and calculates loss according to given loss_fn. - - NOTE: The input signature here needs to be a Pipelineable object for - prefetching purposes during training using torchrec's pipeline. However - the underlying model signature needs to be exportable to onnx, requiring - generic python types. see https://pytorch.org/docs/stable/onnx.html#types. - """ + Run the model forward and calculate the loss according to the given loss_fn. + + NOTE: The input signature here needs to be a Pipelineable object for + prefetching purposes during training using torchrec's pipeline. However + the underlying model signature needs to be exportable to onnx, requiring + generic python types. see https://pytorch.org/docs/stable/onnx.html#types + + Args: + batch ("RecapBatch"): The input batch for model inference. + + Returns: + Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of + additional outputs including logits, labels, and weights. + """ outputs = self.model(batch) logits = outputs["logits"] diff --git a/projects/twhin/models/test_models.py b/projects/twhin/models/test_models.py index f8203ef..f76918c 100644 --- a/projects/twhin/models/test_models.py +++ b/projects/twhin/models/test_models.py @@ -18,6 +18,12 @@ EMB_DIM = 128 def twhin_model_config() -> TwhinModelConfig: + """ + Create a configuration for the Twhin model. + + Returns: + TwhinModelConfig: The Twhin model configuration. + """ sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01)) sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) @@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig: def twhin_data_config() -> TwhinDataConfig: + """ + Create a configuration for the Twhin data. + + Returns: + TwhinDataConfig: The Twhin data configuration. + """ data_config = TwhinDataConfig( data_root="/", per_replica_batch_size=10, @@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig: def test_twhin_model(): + """ + Test the Twhin model creation and optimization. + + This function creates a Twhin model using the specified configuration and tests its optimization. It also checks + the device placement of model parameters. + + Returns: + None + """ model_config = twhin_model_config() loss_fn = F.binary_cross_entropy_with_logits @@ -84,6 +105,23 @@ def test_twhin_model(): def test_unequal_dims(): + """ + Test function for validating unequal embedding dimensions in TwhinEmbeddingsConfig. + + This function tests whether the validation logic correctly raises a `ValidationError` when + embedding dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables. + + The test includes the following steps: + 1. Create two embedding configurations with different embedding dimensions. + 2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions. + 3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match. + + This function serves as a test case to ensure that the validation logic enforces equal embedding dimensions + in the `TwhinEmbeddingsConfig` for all tables. + + Raises: + AssertionError: If the expected `ValidationError` is not raised. + """ sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02)) sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05)) table0 = EmbeddingBagConfig( diff --git a/projects/twhin/optimizer.py b/projects/twhin/optimizer.py index 7c43b56..e39e0d0 100644 --- a/projects/twhin/optimizer.py +++ b/projects/twhin/optimizer.py @@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt" def _lr_from_config(optimizer_config): + """Get the learning rate from an optimizer configuration. + + Args: + optimizer_config: Optimizer configuration. + + Returns: + Learning rate from the optimizer configuration. + """ if optimizer_config.learning_rate is not None: return optimizer_config.learning_rate else: @@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config): def build_optimizer(model: TwhinModel, config: TwhinModelConfig): """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations. - Args: - model: TwhinModel to build optimizer for. - config: TwhinConfig for model. + Args: + model: TwhinModel to build optimizer for. + config: TwhinModelConfig for model. - Returns: - Optimizer for model. - """ + Returns: + Optimizer for model. + """ translation_optimizer_fn = functools.partial( get_optimizer_class(config.translation_optimizer), **get_optimizer_algorithm_config(config.translation_optimizer).dict(), diff --git a/projects/twhin/run.py b/projects/twhin/run.py index 12081e0..6124bc1 100644 --- a/projects/twhin/run.py +++ b/projects/twhin/run.py @@ -37,6 +37,12 @@ def run( all_config: TwhinConfig, save_dir: Optional[str] = None, ): + """Run the training process for TwhinModel. + + Args: + all_config (TwhinConfig): The configuration for the entire Twhin model. + save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None. + """ train_dataset = create_dataset(all_config.train_data, all_config.model) if env.is_reader(): @@ -80,6 +86,11 @@ def run( def main(argv): + """Main entry point for the Twhin training script. + + Args: + argv: Command-line arguments. + """ logging.info("Starting") logging.info(f"parsing config from {FLAGS.config_yaml_path}...") diff --git a/reader/dataset.py b/reader/dataset.py index 6e811cc..5d873cd 100644 --- a/reader/dataset.py +++ b/reader/dataset.py @@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging class _Reader(pa.flight.FlightServerBase): - """Distributed reader flight server wrapping a dataset.""" + """ + Distributed reader flight server wrapping a dataset. + + This class implements a Flight server that wraps a dataset, allowing clients to retrieve data + from the dataset over the Flight protocol. It is designed to be used in a distributed environment + for efficient data access. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + + Attributes: + _location (str): The location of the Flight server. + _ds (Dataset): The dataset wrapped by the Flight server. + + Methods: + do_get(_, __): Handles Flight requests for data retrieval. + + Note: + Flight is an Apache Arrow project that provides a framework for efficient data transfer. + This class allows clients to retrieve data from the dataset using Flight. + + """ def __init__(self, location: str, ds: "Dataset"): + """ + Initialize a new _Reader instance. + + Args: + location (str): The location of the Flight server. + ds (Dataset): The dataset to be wrapped by the Flight server. + """ super().__init__(location=location) self._location = location self._ds = ds def do_get(self, _, __): + """ + Handle Flight requests for data retrieval. + + This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol. + + Args: + _: Unused argument. + __: Unused argument. + + Returns: + pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset. + + Note: + An updated schema (to account for column selection) must be given to the stream. + """ # NB: An updated schema (to account for column selection) has to be given the stream. schema = next(iter(self._ds.to_batches())).schema batches = self._ds.to_batches() @@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase): class Dataset(torch.utils.data.IterableDataset): + """ + A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading. + + This class enables efficient loading of data from Parquet files using PyArrow. + It is designed to be used as an IterableDataset in PyTorch for training and inference. + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. + Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset + for more details. + + Attributes: + LOCATION (str): The default location for the Flight server used for data distribution. + _file_pattern (str): The glob pattern specifying Parquet files in the dataset. + _fs: The filesystem object used for file operations. + _dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method. + _files (list): A list of file paths matching the glob pattern. + _schema (pa.Schema): The schema of the Parquet dataset. + + Methods: + serve(): Start serving the dataset using a Flight server. + to_batches(): Generate batches of data from the Parquet dataset. + pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch. + dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset. + + Note: + This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch + to create DataLoader instances for training or inference. + """ LOCATION = "grpc://0.0.0.0:2222" def __init__(self, file_pattern: str, **dataset_kwargs) -> None: - """Specify batch size and column to select for. + """ + Initialize a new Dataset instance. Specify batch size and column to select for. Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset. - """ + + + Args: + file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset. + **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method. + """ self._file_pattern = file_pattern self._fs = infer_fs(self._file_pattern) self._dataset_kwargs = dataset_kwargs @@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset): self._validate_columns() def _validate_columns(self): + """ + Validate the specified columns against the dataset schema. + + Raises: + Exception: If any specified columns are not found in the dataset schema. + """ columns = set(self._dataset_kwargs.get("columns", [])) wrong_columns = set(columns) - set(self._schema.names) if wrong_columns: raise Exception(f"Specified columns {list(wrong_columns)} not in schema.") def serve(self): + """Start serving the dataset using a Flight server.""" self.reader = _Reader(location=self.LOCATION, ds=self) self.reader.serve() def _create_dataset(self): + """Create a PyArrow dataset for data retrieval.""" + return pads.dataset( source=random.sample(self._files, len(self._files))[0], format="parquet", @@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset): @abc.abstractmethod def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: + """ + Convert a Parquet RecordBatch to a custom data batch. + + Args: + batch (pa.RecordBatch): A batch of data from the Parquet dataset. + + Returns: + DataclassBatch: A custom data batch used in PyTorch training. + + Raises: + NotImplementedError: This method must be implemented in derived classes. + """ raise NotImplementedError def dataloader(self, remote: bool = False): + """ + Create a PyTorch DataLoader for iterating through the dataset. + + Args: + remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training. + + Returns: + DataLoader: A PyTorch DataLoader for iterating through the dataset. + + Note: + If `remote` is True, a remote DataLoader is created for distributed training using Flight. + """ if not remote: return map(self.pa_to_batch, self.to_batches()) readers = get_readers(2) @@ -117,6 +230,25 @@ GRPC_OPTIONS = [ def get_readers(num_readers_per_worker: int): + """ + Get Flight readers for distributed data loading. + + This function retrieves Flight readers for distributed data loading in a PyTorch environment. + + Args: + num_readers_per_worker (int): The number of Flight readers to retrieve per worker. + + Returns: + List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading. + + Note: + Flight readers are used to fetch data in a distributed manner for efficient data loading. + + Example: + To obtain Flight readers, use the following code: + + >>> readers = get_readers(num_readers_per_worker=2) + """ addresses = env.get_flight_server_addresses() readers = [] diff --git a/reader/dds.py b/reader/dds.py index 7b49893..54b2c02 100644 --- a/reader/dds.py +++ b/reader/dds.py @@ -21,6 +21,16 @@ import torch.distributed as dist def maybe_start_dataset_service(): + """ + Start the dataset service if readers are available and required dependencies are met. + + This function checks if readers are available and if the required TensorFlow version is >= 2.5. + If both conditions are met and the current environment is the dispatcher or reader, it starts + the TensorFlow dataset service. + + Raises: + Exception: If the required TensorFlow version is not met (>= 2.5). + """ if not env.has_readers(): return @@ -59,6 +69,24 @@ def maybe_start_dataset_service(): def register_dataset( dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO" ): + """ + Register a dataset with the distributed dataset service. + + This function registers a dataset with the distributed dataset service and broadcasts the dataset ID + and job name to all processes in the distributed environment. + + Args: + dataset (tf.data.Dataset): The dataset to be registered. + dataset_service (str): The name of the dataset service. + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + + Returns: + Tuple[int, str]: A tuple containing the dataset ID and job name. + + Note: + This function should be called on the rank 0 process. + + """ if dist.get_rank() == 0: dataset_id = _register_dataset( service=dataset_service, @@ -82,6 +110,23 @@ def distribute_from_dataset_id( compression: Optional[str] = "AUTO", prefetch: Optional[int] = tf.data.experimental.AUTOTUNE, ) -> tf.data.Dataset: + """ + Distribute a dataset from a registered dataset ID. + + This function consumes a dataset from the distributed dataset service using the provided dataset ID + and job name. It also supports prefetching for improved performance. + + Args: + dataset_service (str): The name of the dataset service. + dataset_id (int): The ID of the dataset to be consumed. + job_name (Optional[str]): The name of the job associated with the dataset (optional). + compression (Optional[str]): The compression type for the dataset (default is "AUTO"). + prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE). + + Returns: + tf.data.Dataset: The distributed dataset. + + """ logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}") dataset = _from_dataset_id( processing_mode="parallel_epochs", @@ -97,15 +142,28 @@ def distribute_from_dataset_id( def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Torch-compatible and distributed-training-aware dataset service distributor. - - - rank 0 process will register the given dataset. - - rank 0 process will broadcast job name and dataset id. - - all rank processes will consume from the same job/dataset. - - Without this, dataset workers will try to serve 1 job per rank process and OOM. - """ + Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption. + + This function is used to distribute a dataset in a distributed training environment. It performs the + following steps: + - On the rank 0 process, it registers the given dataset with the distributed dataset service. + - It broadcasts the job name and dataset ID to all rank processes. + - All rank processes then consume the same dataset from the distributed dataset service. + + Args: + dataset (tf.data.Dataset): The TensorFlow dataset to be distributed. + + Returns: + tf.data.Dataset: The distributed TensorFlow dataset. + + Note: + - If there are no reader processes in the distributed environment, the original dataset is returned + without any distribution. + - This function is intended for use in distributed training environments to prevent out-of-memory (OOM) + issues caused by each rank process trying to serve one job. + + """ if not env.has_readers(): return dataset dataset_service = env.get_dds() diff --git a/reader/test_dataset.py b/reader/test_dataset.py index 85e756f..d9400e4 100644 --- a/reader/test_dataset.py +++ b/reader/test_dataset.py @@ -12,6 +12,17 @@ import torch def create_dataset(tmpdir): + """ + Create a mock dataset for testing. + + This function creates a mock dataset using PyArrow and Parquet for testing purposes. + + Args: + tmpdir: A temporary directory where the dataset will be created. + + Returns: + MockDataset: A mock dataset for testing. + """ table = pa.table( { @@ -34,6 +45,14 @@ def create_dataset(tmpdir): def test_dataset(tmpdir): + """ + Test the created dataset. + + This function tests the created mock dataset and checks if it behaves as expected. + + Args: + tmpdir: A temporary directory used for testing. + """ ds = create_dataset(tmpdir) batch = next(iter(ds.dataloader(remote=False))) assert batch.batch_size == 2 @@ -46,6 +65,14 @@ def test_dataset(tmpdir): reason="Multiprocessing doesn't work on github yet.", ) def test_distributed_dataset(tmpdir): + """ + Test the distributed dataset. + + This function tests the distributed version of the mock dataset using multiprocessing. + + Args: + tmpdir: A temporary directory used for testing. + """ MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"} def _client(): diff --git a/reader/utils.py b/reader/utils.py index fc0e34c..c5c9721 100644 --- a/reader/utils.py +++ b/reader/utils.py @@ -11,11 +11,55 @@ import torch def roundrobin(*iterables): - """Round robin through provided iterables, useful for simple load balancing. - - Adapted from https://docs.python.org/3/library/itertools.html. - """ + Iterate through provided iterables in a round-robin fashion. + + This function takes multiple iterables and returns an iterator that yields elements from + each iterable in a round-robin manner. It continues cycling through the iterables until + all of them are exhausted. + + Adapted from https://docs.python.org/3/library/itertools.html. + + Args: + *iterables: One or more iterable objects to iterate through. + + Yields: + Elements from the provided iterables in a round-robin fashion. + + Raises: + StopIteration: If all provided iterables are exhausted. + + Example: + ```python + iterable1 = [1, 2, 3] + iterable2 = ['a', 'b', 'c'] + iterable3 = [0.1, 0.2, 0.3] + + for item in roundrobin(iterable1, iterable2, iterable3): + print(item) + + # Output: + # 1 + # 'a' + # 0.1 + # 2 + # 'b' + # 0.2 + # 3 + # 'c' + # 0.3 + ``` + + Note: + - If one of the provided iterables is shorter than the others, the function will + continue iterating through the remaining iterables until all are exhausted. + - If an iterable raises an exception during iteration, a warning message is logged, + and the function continues with the next iterable. + + See Also: + - `itertools.cycle`: A function that repeatedly cycles through elements of an iterable. + - `itertools.islice`: A function to slice an iterable to limit the number of iterations. + """ num_active = len(iterables) nexts = itertools.cycle(iter(it).__next__ for it in iterables) while num_active: @@ -35,6 +79,48 @@ def roundrobin(*iterables): def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]): + """ + Monitor the speed and progress of data loading using a data loader. + + This function iterates through a data loader for a specified number of steps or until + the end of the data loader is reached, periodically logging progress information. + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Example: + ```python + import torch + from torch.utils.data import DataLoader + + # Create a data loader (replace with your own DataLoader configuration) + data_loader = DataLoader(dataset, batch_size=32, shuffle=True) + + # Monitor data loading speed and progress + speed_check(data_loader, max_steps=1000, frequency=50, peek=500) + ``` + + Args: + data_loader: The data loader to monitor. + max_steps: The maximum number of steps to iterate through the data loader. + frequency: The frequency (in steps) at which to log progress. + peek (optional): If specified, it indicates the frequency (in steps) at which to log + batch contents for inspection. + + Note: + - The function logs information about elapsed time, the number of examples processed, + and the processing speed in examples per second. + - If `peek` is provided, batch contents will be logged for inspection at the specified + frequency. + + See Also: + - `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and + iterating through datasets. + """ num_examples = 0 prev = time.perf_counter() for idx, batch in enumerate(data_loader): @@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int] def pa_to_torch(array: pa.array) -> torch.Tensor: + """ + Convert a PyArrow Array to a PyTorch Tensor. + + Args: + array (pa.array): The PyArrow Array to convert. + + Returns: + torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array. + + Example: + ```python + import pyarrow as pa + import torch + + # Create a PyArrow Array + arrow_array = pa.array([1, 2, 3]) + + # Convert it to a PyTorch Tensor + torch_tensor = pa_to_torch(arrow_array) + ``` + """ return torch.from_numpy(array.to_numpy()) def create_default_pa_to_batch(schema) -> DataclassBatch: - """ """ + """ + Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data. + + Args: + schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch. + + Returns: + callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch. + + Example: + ```python + import pyarrow as pa + from dataclass_batch import DataclassBatch + + # Define a PyArrow schema + schema = pa.schema([ + ("feature1", pa.float64()), + ("feature2", pa.int64()), + ("label", pa.int64()), + ]) + + # Create the conversion function + pa_to_batch = create_default_pa_to_batch(schema) + + # Create a PyArrow RecordBatch + record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({ + "feature1": [1.0, 2.0, None], + "feature2": [10, 20, 30], + "label": [0, 1, None], + })) + + # Convert the RecordBatch to a custom DataclassBatch + custom_batch = pa_to_batch(record_batch) + ``` + """ _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema) def get_imputation_value(pa_type): diff --git a/tools/pq.py b/tools/pq.py index 24c6345..b18f68a 100644 --- a/tools/pq.py +++ b/tools/pq.py @@ -38,6 +38,15 @@ import pyarrow.parquet as pq def _create_dataset(path: str): + """ + Create a PyArrow dataset from Parquet files located at the specified path. + + Args: + path (str): The path to the Parquet files. + + Returns: + pyarrow.dataset.Dataset: The PyArrow dataset. + """ fs = infer_fs(path) files = fs.glob(path) return pads.dataset(files, format="parquet", filesystem=fs) @@ -47,12 +56,27 @@ class PqReader: def __init__( self, path: str, num: int = 10, batch_size: int = 1024, columns: Optional[List[str]] = None ): + """ + Initialize a Parquet Reader. + + Args: + path (str): The path to the Parquet files. + num (int): The maximum number of rows to read. + batch_size (int): The batch size for reading data. + columns (Optional[List[str]]): A list of column names to read (default is None, which reads all columns). + """ self._ds = _create_dataset(path) self._batch_size = batch_size self._num = num self._columns = columns def __iter__(self): + """ + Iterate through the Parquet data and yield batches of rows. + + Yields: + pyarrow.RecordBatch: A batch of rows. + """ batches = self._ds.to_batches(batch_size=self._batch_size, columns=self._columns) rows_seen = 0 for count, record in enumerate(batches): @@ -62,6 +86,12 @@ class PqReader: rows_seen += record.data.num_rows def _head(self): + """ + Get the first `num` rows of the Parquet data. + + Returns: + pyarrow.RecordBatch: A batch of rows. + """ total_read = self._num * self.bytes_per_row if total_read >= int(500e6): raise Exception( @@ -71,6 +101,12 @@ class PqReader: @property def bytes_per_row(self) -> int: + """ + Calculate the estimated bytes per row in the dataset. + + Returns: + int: The estimated bytes per row. + """ nbits = 0 for t in self._ds.schema.types: try: @@ -81,18 +117,23 @@ class PqReader: return nbits // 8 def schema(self): + """ + Display the schema of the Parquet dataset. + """ print(f"\n# Schema\n{self._ds.schema}") def head(self): - """Displays first --num rows.""" + """ + Display the first `num` rows of the Parquet data as a pandas DataFrame. + """ print(self._head().to_pandas()) def distinct(self): - """Displays unique values seen in specified columns in the first `--num` rows. - - Useful for getting an approximate vocabulary for certain columns. - """ + Display unique values seen in specified columns in the first `num` rows. + + Useful for getting an approximate vocabulary for certain columns. + """ for col_name, column in zip(self._head().column_names, self._head().columns): print(col_name) print("unique:", column.unique().to_pylist())