mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-27 05:55:27 +01:00
commit deec9a820e ("new")
parent 9bb0986079
common/batch.py (121 changed lines)
@@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable
 class BatchBase(Pipelineable, abc.ABC):
+  """
+  A base class for batches used in pipelines.
+  """
+
   @abc.abstractmethod
   def as_dict(self) -> Dict:
+    """
+    Convert the batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+
+    Raises:
+      NotImplementedError: If the method is not implemented in a subclass.
+    """
     raise NotImplementedError

   def to(self, device: torch.device, non_blocking: bool = False):
+    """
+    Move the batch to the specified device.
+
+    Args:
+      device (torch.device): The target device.
+      non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.
+
+    Returns:
+      BatchBase: A new batch on the target device.
+    """
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
     return self.__class__(**args)

   def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
+    """
+    Record a CUDA stream for all tensors in the batch.
+
+    Args:
+      stream (torch.cuda.streams.Stream): The CUDA stream to record.
+    """
     for feature_value in self.as_dict().values():
       feature_value.record_stream(stream)

   def pin_memory(self):
+    """
+    Pin memory for all tensors in the batch.
+
+    Returns:
+      BatchBase: A new batch with pinned memory.
+    """
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.pin_memory()
     return self.__class__(**args)

   def __repr__(self) -> str:
+    """
+    Generate a string representation of the batch.
+
+    Returns:
+      str: A string representation of the batch.
+    """
     def obj2str(v):
       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):
   @property
   def batch_size(self) -> int:
+    """
+    Get the batch size from the tensors in the batch.
+
+    Returns:
+      int: The batch size.
+
+    Raises:
+      Exception: If the batch size cannot be determined from the tensors.
+    """
     for tensor in self.as_dict().values():
       if tensor is None:
         continue
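Taken together, these methods give every batch a uniform dict/device contract. Below is a minimal sketch of how a concrete subclass behaves; the `ToyBatch` class and the `tml.common.batch` import path are illustrative assumptions, not part of this diff:

```python
import dataclasses

import torch

from tml.common.batch import DataclassBatch  # import path assumed from the repo layout


@dataclasses.dataclass
class ToyBatch(DataclassBatch):  # hypothetical subclass for illustration
  user_emb: torch.Tensor
  label: torch.Tensor


batch = ToyBatch(user_emb=torch.randn(8, 16), label=torch.zeros(8))
print(batch.as_dict().keys())  # dict_keys(['user_emb', 'label'])
print(batch.batch_size)        # 8, per the batch_size property's contract

# Device transfer returns a new batch of the same class; pinning first
# enables asynchronous host-to-device copies.
if torch.cuda.is_available():
  batch = batch.pin_memory().to(torch.device("cuda"), non_blocking=True)
```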
@@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):
 @dataclass
 class DataclassBatch(BatchBase):
+  """
+  A batch class that uses dataclasses to define its fields.
+  """
+
   @classmethod
   def feature_names(cls):
+    """
+    Get the feature names of the dataclass.
+
+    Returns:
+      List[str]: A list of feature names.
+    """
     return list(cls.__dataclass_fields__.keys())

   def as_dict(self):
+    """
+    Convert the dataclass batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+    """
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
@@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):
   @staticmethod
   def from_schema(name: str, schema):
-    """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
+    """
+    Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
+
+    Args:
+      name (str): The name of the custom batch class.
+      schema: The schema or structure of the batch.
+
+    Returns:
+      Type[DataclassBatch]: A custom batch class.
+    """
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):
   @staticmethod
   def from_fields(name: str, fields: dict):
+    """
+    Create a custom batch subclass from a set of fields.
+
+    Args:
+      name (str): The name of the custom batch class.
+      fields (dict): A dictionary specifying the fields and their types.
+
+    Returns:
+      Type[DataclassBatch]: A custom batch class.
+    """
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
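For schemas that are only known at runtime, `from_fields` builds the dataclass on the fly. A short sketch under the same import-path assumption (`TowerBatch` and its fields are hypothetical):

```python
import torch

from tml.common.batch import DataclassBatch  # import path assumed from the repo layout

# Build a batch class dynamically instead of hand-writing a dataclass.
TowerBatch = DataclassBatch.from_fields(
  "TowerBatch",
  {"continuous": torch.Tensor, "binary": torch.Tensor},
)

batch = TowerBatch(continuous=torch.randn(4, 32), binary=torch.ones(4, 8))
print(type(batch).__name__)   # TowerBatch
print(batch.feature_names())  # ['continuous', 'binary']
```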
@@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):
 class DictionaryBatch(BatchBase, dict):
+  """
+  A batch class that represents data as a dictionary.
+  """
+
   def as_dict(self) -> Dict:
+    """
+    Convert the dictionary batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+    """
     return self
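Because `DictionaryBatch` subclasses `dict`, `as_dict` can return `self`, and the inherited `to`/`pin_memory` rebuild the batch through the dict constructor. A sketch, again assuming the import path:

```python
import torch

from tml.common.batch import DictionaryBatch  # import path assumed from the repo layout

batch = DictionaryBatch({"user_id": torch.arange(4), "clicks": torch.zeros(4)})
assert batch.as_dict() is batch  # the batch is its own dict representation

moved = batch.to(torch.device("cpu"))  # returns a new DictionaryBatch
assert isinstance(moved, DictionaryBatch)
```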
@@ -5,6 +5,15 @@ import torch.distributed as dist
 def maybe_setup_tensorflow():
+  """
+  Try to import TensorFlow and disable its GPU devices if it is available.
+
+  This checks whether TensorFlow is installed and, if so, hides GPU devices
+  from TensorFlow to avoid conflicts with PyTorch.
+  """
   try:
     import tensorflow as tf
   except ImportError:
@@ -14,6 +23,18 @@ def maybe_setup_tensorflow():
 def setup_and_get_device(tf_ok: bool = True) -> torch.device:
+  """
+  Set up the distributed environment and get the appropriate torch device.
+
+  This initializes the distributed environment via PyTorch's
+  `dist.init_process_group` and selects the torch device based on GPU
+  availability and the local rank.
+
+  Args:
+    tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable
+      TensorFlow GPU devices. Defaults to True.
+
+  Returns:
+    torch.device: The torch device for the current process.
+  """
   if tf_ok:
     maybe_setup_tensorflow()

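A sketch of the intended call-site usage, meant to run once at process startup (the import path is assumed from the repo layout):

```python
import torch

from tml.common.device import setup_and_get_device  # import path assumed

# Initializes the distributed environment and selects this process's device
# from GPU availability and local rank, per the docstring above.
device = setup_and_get_device(tf_ok=False)
model = torch.nn.Linear(4, 2).to(device)
```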
@@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()
 def infer_fs(path: str):
+  """
+  Infer the file system from the given path.
+
+  Args:
+    path (str): The file path.
+
+  Returns:
+    The file system object for the path: GCS_FS for "gs://" paths, otherwise
+    the local file system.
+
+  Raises:
+    NotImplementedError: If the path points at HDFS ("hdfs://"), which is not
+      yet supported.
+  """
   if path.startswith("gs://"):
     return GCS_FS
   elif path.startswith("hdfs://"):
@@ -18,8 +30,26 @@ def infer_fs(path: str):
 def is_local_fs(fs):
+  """
+  Check if the given file system is the local file system.
+
+  Args:
+    fs: The file system object to check.
+
+  Returns:
+    bool: True if the file system is the local file system, False otherwise.
+  """
   return fs == LOCAL_FS


 def is_gcs_fs(fs):
+  """
+  Check if the given file system is Google Cloud Storage (GCS).
+
+  Args:
+    fs: The file system object to check.
+
+  Returns:
+    bool: True if the file system is GCS, False otherwise.
+  """
   return fs == GCS_FS
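How the three helpers compose (import path assumed; the paths are illustrative):

```python
from tml.common.filesystem import infer_fs, is_gcs_fs, is_local_fs  # import path assumed

ckpt_fs = infer_fs("gs://my-bucket/checkpoints/model.pt")
assert is_gcs_fs(ckpt_fs) and not is_local_fs(ckpt_fs)

local_fs = infer_fs("/tmp/model/checkpoint.pt")  # no scheme prefix -> local file system
assert is_local_fs(local_fs)

# infer_fs("hdfs://cluster/path") raises NotImplementedError for now.
```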
@@ -49,14 +49,24 @@ def log_ebc_norms(
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
-  """Logs the norms of the embedding tables as specified by ebc_keys.
-  As of now, log average norm per rank.
+  """
+  Logs the norms of the embedding tables as specified by ebc_keys.
+  As of now, logs the average norm per rank.

   Args:
-    model_state_dict: model.state_dict()
-    ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
-      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
-    sample_size: Limits number of rows per rank to compute average on to avoid OOM.
+    model_state_dict (dict): The state dictionary of the model.
+    ebc_keys (List[str]): List of embedding keys from state_dict to log. Each key must
+      contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
+    sample_size (int, optional): Limits the number of rows per rank used to compute the
+      average, to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
+
+  Returns:
+    Dict[str, torch.Tensor]: A dictionary of the computed norms of the embedding tables,
+      gathered across ranks, with keys in the format "{emb_key}-norm-{idx}".
   """
   norm_logs = dict()
   for emb_key in ebc_keys:
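A call-shape sketch for `log_ebc_norms`; the state-dict key is illustrative and the import path is assumed. The function gathers norms across ranks, so it needs an initialized process group:

```python
import torch
import torch.distributed as dist

from tml.common.log_weights import log_ebc_norms  # import path assumed

# A minimal stand-in for model.state_dict(); the key mimics the full-path
# format the docstring describes.
state_dict = {
  "model.embeddings.ebc.embedding_bags.meta__user_id.weight": torch.randn(1000, 64),
}

if dist.is_initialized():  # e.g. under torchrun, or inside mock_pg() below
  norms = log_ebc_norms(state_dict, ebc_keys=list(state_dict.keys()))
  print(norms)  # keys look like "{emb_key}-norm-{idx}"
```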
@@ -11,6 +11,20 @@ import numpy as np
 class LargeEmbeddings(nn.Module):
+  """
+  A module for handling large embeddings.
+
+  Args:
+    large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
+
+  Attributes:
+    ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
+    surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery
+      on large embedding models to prepare them for serving.
+  """
+
   def __init__(
     self,
     large_embeddings_config: LargeEmbeddingsConfig,
@@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
     self,
     sparse_features: KeyedJaggedTensor,
   ) -> KeyedTensor:
+    """
+    Forward pass of the LargeEmbeddings module.
+
+    Args:
+      sparse_features (KeyedJaggedTensor): Sparse input features.
+
+    Returns:
+      KeyedTensor: The pooled embeddings.
+    """
     pooled_embs = self.ebc(sparse_features)

     # a KeyedTensor
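A sketch of the input shape the forward pass expects, using torchrec's `KeyedJaggedTensor`. The feature keys and id values are illustrative, and `large_embeddings` is assumed to be an already-initialized `LargeEmbeddings` instance:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features over a batch of two examples; `lengths` gives the number of
# ids each example contributes to each feature.
sparse_features = KeyedJaggedTensor.from_lengths_sync(
  keys=["meta__user_id", "meta__author_id"],
  values=torch.tensor([101, 202, 303, 404, 505]),
  lengths=torch.tensor([1, 2, 1, 1]),
)

# pooled = large_embeddings(sparse_features)  # KeyedTensor of pooled embeddings
```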
@@ -11,6 +11,12 @@ import torch.distributed.run
 def is_distributed_worker():
+  """
+  Checks if the current process is a distributed worker.
+
+  Returns:
+    bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
+  """
   world_size = os.environ.get("WORLD_SIZE", None)
   rank = os.environ.get("RANK", None)
   return world_size is not None and rank is not None
@@ -25,10 +31,10 @@ def maybe_run_training(
   is_chief: Optional[bool] = False,
   **training_kwargs,
 ):
-  """Wrapper function for single node, multi-GPU Pytorch training.
+  """
+  Wrapper function for single-node, multi-GPU PyTorch training.

-  If the necessary distributed Pytorch environment variables
-  (WORLD_SIZE, RANK) have been set, then this function executes
-  `train_fn(**training_kwargs)`.
+  If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK)
+  have been set, then this function executes `train_fn(**training_kwargs)`.

   Otherwise, this function calls torchrun and points at the calling module
@@ -36,13 +42,24 @@ def maybe_run_training(
   and training will commence.

   Args:
-    train_fn: The function that is responsible for training
-    module_name: The name of the module that this function was called from;
-      used to indicate torchrun entrypoint.
-    nproc_per_node: Number of workers per node; supported values.
-    num_nodes: Number of nodes, otherwise inferred from environment.
-    is_chief: If process is running on chief.
-    set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
+    train_fn (callable): The function responsible for training.
+    module_name (str): The name of the module this function was called from;
+      used as the torchrun entrypoint.
+    nproc_per_node (int, optional): Number of workers per node. Defaults to None.
+    num_nodes (int, optional): Number of nodes, otherwise inferred from the
+      environment. Defaults to None.
+    is_chief (bool, optional): Whether the process is running on the chief node.
+      Defaults to False.
+    set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in
+      the subprocess. Defaults to False.
+    **training_kwargs: Additional keyword arguments passed to `train_fn`.
+
+  Example:
+    To run training on a single node with 4 GPUs:
+
+    ```
+    maybe_run_training(train_function, __name__, nproc_per_node=4)
+    ```
   """

   machines = utils.machine_from_env()
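Putting the wrapper together as a script entrypoint, echoing the docstring's example (the training function body and import path are illustrative):

```python
from tml.common.run_training import maybe_run_training  # import path assumed


def train_function():
  # Real code would build the model, data loaders, and training loop here.
  print("training on one worker")


if __name__ == "__main__":
  # Outside torchrun, this re-launches the current module via torchrun with
  # 4 workers; inside torchrun (WORLD_SIZE/RANK set), it calls train_function().
  maybe_run_training(train_function, __name__, nproc_per_node=4)
```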
@@ -19,6 +19,23 @@ MOCK_ENV = {
 @contextmanager
 def mock_pg():
+  """
+  A context manager that mocks the distributed process group for testing.
+
+  It temporarily sets environment variables to mock the distributed process group and
+  initializes it using the Gloo backend, so distributed training code can be tested
+  without actually launching multiple processes.
+
+  Example:
+    ```python
+    with mock_pg():
+      # Your distributed training code here
+    ```
+  """
   with patch.dict(os.environ, MOCK_ENV):
     try:
       dist.init_process_group(
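A sketch of a test exercising a collective under the mocked group (import path assumed; with a single mocked rank, `all_reduce` leaves the tensor unchanged):

```python
import torch
import torch.distributed as dist

from tml.common.testing_utils import mock_pg  # import path assumed


def test_all_reduce_under_mock_pg():
  with mock_pg():
    t = torch.tensor([1.0, 2.0])
    dist.all_reduce(t)  # runs the Gloo code path without extra processes
    assert torch.equal(t, torch.tensor([1.0, 2.0]))
```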
@@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)
 def _read_file(f):
+  """
+  Read the contents of a file using fsspec.
+
+  Args:
+    f: File path or URL.
+
+  Returns:
+    The contents of the file.
+  """
   with fsspec.open(f) as f:
     return f.read()

@@ -21,17 +30,23 @@ def setup_configuration(
   yaml_path: str,
   substitute_env_variable: bool = False,
 ) -> Tuple[C, str]:
-  """Resolves a config at a yaml path.
+  """
+  Load a Pydantic config object from a YAML file, optionally substituting environment variables.

   Args:
     config_type: Pydantic config class to load.
-    yaml_path: yaml path of the config file.
-    substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
-      environment variable value whenever possible. If an environment variable doesn't exist,
-      the string is left unchanged.
+    yaml_path: YAML path of the config file.
+    substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
+      with their environment variable values whenever possible. If an environment
+      variable doesn't exist, the string is left unchanged.

   Returns:
-    The pydantic config object.
+    A tuple containing the Pydantic config object and the resolved YAML content.
+
+  Example:
+    ```python
+    config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
+    ```
   """

   def _substitute(s):
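A sketch of environment-variable substitution in practice; `MyConfig`, the YAML file, and the import paths are all illustrative assumptions:

```python
import os

import pydantic

from tml.core.config.base_config import BaseConfig            # import path assumed
from tml.core.config.config_load import setup_configuration   # import path assumed


class MyConfig(BaseConfig):  # hypothetical config for illustration
  save_dir: str = pydantic.Field("/tmp/model")


os.environ["SAVE_DIR"] = "/tmp/run1"
# Assuming config.yaml contains:  save_dir: ${SAVE_DIR}
config, resolved_yaml = setup_configuration(
  MyConfig, "config.yaml", substitute_env_variable=True
)
print(config.save_dir)  # /tmp/run1
```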
@@ -6,6 +6,33 @@ import pydantic
 class WandbConfig(base_config.BaseConfig):
+  """
+  Configuration for integrating with Weights and Biases (WandB).
+
+  Attributes:
+    host (str): Host of the Weights and Biases instance, passed to login.
+    key_path (str): Path to the key file.
+    name (str): Name of the experiment, passed to init.
+    entity (str): Name of the user/service account, passed to init.
+    project (str): Name of the WandB project, passed to init.
+    tags (List[str]): List of tags, passed to init.
+    notes (str): Notes, passed to init.
+    metadata (Dict[str, Any]): Additional metadata to log.
+
+  Example:
+    ```python
+    wandb_config = WandbConfig(
+      host="https://wandb.example.com",
+      key_path="/path/to/key",
+      name="experiment_1",
+      entity="user123",
+      project="my_project",
+      tags=["experiment", "ml"],
+      notes="This is a test experiment.",
+      metadata={"version": "1.0"},
+    )
+    ```
+  """
+
   host: str = pydantic.Field(
     "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
     description="Host of Weights and Biases instance, passed to login.",
@@ -6,7 +6,14 @@ import pydantic
 class BaseConfigTest(TestCase):
+  """
+  Unit tests for the BaseConfig class.
+  """
+
   def test_extra_forbidden(self):
+    """
+    Test that extra fields are forbidden when creating a Config instance.
+    """
     class Config(BaseConfig):
       x: int

@@ -15,6 +22,9 @@ class BaseConfigTest(TestCase):
       Config(x=1, y=2)

   def test_one_of(self):
+    """
+    Test the use of the `one_of` attribute for fields in a Config instance.
+    """
     class Config(BaseConfig):
       x: int = pydantic.Field(None, one_of="f")
       y: int = pydantic.Field(None, one_of="f")
@@ -27,6 +37,9 @@ class BaseConfigTest(TestCase):
       Config(x=1, y=3)

   def test_at_most_one_of(self):
+    """
+    Test the use of the `at_most_one_of` attribute for fields in a Config instance.
+    """
     class Config(BaseConfig):
       x: int = pydantic.Field(None, at_most_one_of="f")
       y: str = pydantic.Field(None, at_most_one_of="f")
@@ -9,6 +9,14 @@ import pydantic
 class RuntimeConfig(base_config.BaseConfig):
+  """
+  Configuration for runtime settings.
+
+  Attributes:
+    wandb (Optional[WandbConfig]): Configuration for WandB (Weights and Biases) integration.
+    enable_tensorfloat32 (bool): Enable TensorFloat32 on Ampere devices.
+    enable_amp (bool): Enable automatic mixed precision.
+  """
+
   wandb: WandbConfig = pydantic.Field(None)
   enable_tensorfloat32: bool = pydantic.Field(
     False, description="Use tensorfloat32 if on Ampere devices."
@@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):
 class TrainingConfig(base_config.BaseConfig):
+  """
+  Configuration for training settings.
+
+  Attributes:
+    save_dir (str): Directory to save checkpoints.
+    num_train_steps (pydantic.PositiveInt): Number of training steps.
+    initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
+    checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every n steps.
+    checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
+    train_log_every_n (pydantic.PositiveInt): Log training information every n steps.
+    num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset is used.
+    eval_log_every_n (pydantic.PositiveInt): Log evaluation information every n steps.
+    eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
+    gradient_accumulation (int): Number of replica steps to accumulate gradients over (optional).
+    num_epochs (pydantic.PositiveInt): Number of training epochs.
+  """
+
   save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
   num_train_steps: pydantic.PositiveInt = 10000
   initial_checkpoint_dir: str = pydantic.Field(
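A sketch of constructing these configs directly (import path assumed; the field values are illustrative and unspecified fields fall back to their defaults):

```python
from tml.core.config.training import RuntimeConfig, TrainingConfig  # import path assumed

training = TrainingConfig(
  save_dir="/tmp/model",
  num_train_steps=100_000,
  checkpoint_every_n=10_000,
  num_eval_steps=-1,  # < 0 evaluates over the entire dataset
)
runtime = RuntimeConfig(enable_tensorfloat32=True)
print(training.save_dir, runtime.enable_tensorfloat32)
```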