mirror of https://github.com/twitter/the-algorithm-ml.git (synced 2025-01-25 05:11:10 +01:00)

commit deec9a820e (parent 9bb0986079): "new"
common/batch.py (121 changed lines)
@@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
  """
  A base class for batches used in pipelines.
  """

  @abc.abstractmethod
  def as_dict(self) -> Dict:
    """
    Convert the batch into a dictionary representation.

    Returns:
      Dict: A dictionary representation of the batch.

    Raises:
      NotImplementedError: If the method is not implemented in a subclass.
    """
    raise NotImplementedError

  def to(self, device: torch.device, non_blocking: bool = False):
    """
    Move the batch to the specified device.

    Args:
      device (torch.device): The target device.
      non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.

    Returns:
      BatchBase: A new batch of the same type on the target device.
    """
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
    return self.__class__(**args)

  def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
    """
    Record a CUDA stream for all tensors in the batch.

    Args:
      stream (torch.cuda.streams.Stream): The CUDA stream to record.
    """
    for feature_value in self.as_dict().values():
      feature_value.record_stream(stream)

  def pin_memory(self):
    """
    Pin memory for all tensors in the batch.

    Returns:
      BatchBase: A new batch with pinned memory.
    """
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.pin_memory()
    return self.__class__(**args)

  def __repr__(self) -> str:
    """
    Generate a string representation of the batch.

    Returns:
      str: A string representation of the batch.
    """
    def obj2str(v):
      return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):

  @property
  def batch_size(self) -> int:
    """
    Get the batch size from the tensors in the batch.

    Returns:
      int: The batch size.

    Raises:
      Exception: If the batch size cannot be determined from any tensor in the batch.
    """
    for tensor in self.as_dict().values():
      if tensor is None:
        continue
@@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):


@dataclass
class DataclassBatch(BatchBase):
  """
  A batch class that uses dataclasses to define its fields.
  """

  @classmethod
  def feature_names(cls):
    """
    Get the feature names of the dataclass.

    Returns:
      List[str]: A list of feature names.
    """
    return list(cls.__dataclass_fields__.keys())

  def as_dict(self):
    """
    Convert the dataclass batch into a dictionary representation.

    Returns:
      Dict: A dictionary representation of the batch.
    """
    return {
      feature_name: getattr(self, feature_name)
      for feature_name in self.feature_names()
@@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):

  @staticmethod
  def from_schema(name: str, schema):
    """
    Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.

    Args:
      name (str): The name of the custom batch class.
      schema: The schema or structure of the batch.

    Returns:
      Type[DataclassBatch]: A custom batch class.
    """
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):

  @staticmethod
  def from_fields(name: str, fields: dict):
    """
    Create a custom batch subclass from a set of fields.

    Args:
      name (str): The name of the custom batch class.
      fields (dict): A dictionary specifying the fields and their types.

    Returns:
      Type[DataclassBatch]: A custom batch class.
    """
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):


class DictionaryBatch(BatchBase, dict):
  """
  A batch class that represents data as a dictionary.
  """

  def as_dict(self) -> Dict:
    """
    Convert the dictionary batch into a dictionary representation.

    Returns:
      Dict: A dictionary representation of the batch.
    """
    return self
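A hedged usage sketch of the pieces above (feature names and shapes are illustrative; it assumes `from_fields` passes `bases=(DataclassBatch,)` to `make_dataclass`, which is truncated out of this hunk):

```python
import torch

# Build a concrete batch class with two tensor-valued features.
MyBatch = DataclassBatch.from_fields(
  "MyBatch",
  {"continuous": torch.Tensor, "labels": torch.Tensor},
)

batch = MyBatch(
  continuous=torch.randn(128, 16),
  labels=torch.randint(0, 2, (128,)),
)
print(batch.batch_size)  # 128, read off the first non-None tensor
if torch.cuda.is_available():
  pinned = batch.pin_memory()  # new MyBatch with page-locked storage
  on_gpu = pinned.to(torch.device("cuda"), non_blocking=True)
```

Because `to()` and `pin_memory()` round-trip through `as_dict()` and rebuild via `self.__class__(**args)`, any subclass gets device transfer for free once `as_dict` is defined.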
@@ -5,6 +5,15 @@ import torch.distributed as dist


def maybe_setup_tensorflow():
  """
  Try to import TensorFlow and, if it is available, disable its GPU devices.

  TensorFlow otherwise grabs the GPUs and conflicts with PyTorch.
  """
  try:
    import tensorflow as tf
  except ImportError:
@@ -14,6 +23,18 @@ def maybe_setup_tensorflow():


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
  """
  Set up the distributed environment and return the torch device for this process.

  Initializes the distributed environment via `dist.init_process_group` and picks
  the device based on GPU availability and local rank.

  Args:
    tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` first to disable
      TensorFlow's GPU devices. Defaults to True.

  Returns:
    torch.device: The torch device for the current process.
  """
  if tf_ok:
    maybe_setup_tensorflow()
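A minimal call-site sketch (behavior outside a torchrun launch is assumed to fall back to CPU based on the availability check described above):

```python
import torch

# Select the device for this process. Under torchrun, WORLD_SIZE/RANK/LOCAL_RANK
# are already set; tf_ok=True first neuters TensorFlow's GPU grab.
device = setup_and_get_device(tf_ok=True)
x = torch.ones(2, 2, device=device)  # tensors now land on the selected device
```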
@@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()


def infer_fs(path: str):
  """
  Infer the filesystem from the given path.

  Args:
    path (str): The file path.

  Returns:
    The fsspec filesystem object for the path: GCS_FS for "gs://" paths, or
    LOCAL_FS for local paths.

  Raises:
    NotImplementedError: If the path points at HDFS ("hdfs://"), which is not yet supported.
  """
  if path.startswith("gs://"):
    return GCS_FS
  elif path.startswith("hdfs://"):
@@ -18,8 +30,26 @@ def infer_fs(path: str):


def is_local_fs(fs):
  """
  Check if the given filesystem is the local filesystem.

  Args:
    fs: The filesystem object to check.

  Returns:
    bool: True if the filesystem is the local filesystem, False otherwise.
  """
  return fs == LOCAL_FS


def is_gcs_fs(fs):
  """
  Check if the given filesystem is Google Cloud Storage (GCS).

  Args:
    fs: The filesystem object to check.

  Returns:
    bool: True if the filesystem is GCS, False otherwise.
  """
  return fs == GCS_FS
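A short sketch of how these helpers compose (the bucket path is illustrative; since `infer_fs` returns an fsspec filesystem object, standard fsspec calls work on the result):

```python
path = "gs://my-bucket/data/part-00000.parquet"  # hypothetical path
fs = infer_fs(path)
if is_gcs_fs(fs):
  with fs.open(path, "rb") as f:  # fsspec AbstractFileSystem.open
    magic = f.read(4)
elif is_local_fs(fs):
  size = fs.size(path)
```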
@@ -49,15 +49,25 @@ def log_ebc_norms(
  ebc_keys: List[str],
  sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
  """
  Logs the norms of the embedding tables specified by ebc_keys.
  As of now, logs the average norm per rank, gathering the norms from all ranks
  and returning them as a dictionary.

  Args:
    model_state_dict (dict): The state dictionary of the model.
    ebc_keys (List[str]): List of embedding keys from state_dict to log. Each key must
      contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
    sample_size (int, optional): Limits the number of rows per rank used to compute the
      average, to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.

  Returns:
    Dict[str, torch.Tensor]: A dictionary of computed embedding-table norms, keyed as
      "{emb_key}-norm-{idx}".
  """
  norm_logs = dict()
  for emb_key in ebc_keys:
    norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
@@ -11,6 +11,20 @@ import numpy as np


class LargeEmbeddings(nn.Module):
  """
  A module for handling large embeddings.

  Args:
    large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.

  Attributes:
    ebc (EmbeddingBagCollection): Manages the embedding bags.
    surgery_cut_point (torch.nn.Identity): A hook used for post-processing surgery on
      large embedding models to prepare them for serving.
  """

  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
@@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    """
    Forward pass of the LargeEmbeddings module.

    Args:
      sparse_features (KeyedJaggedTensor): Sparse input features.

    Returns:
      KeyedTensor: The pooled embeddings.
    """
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
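For illustration, a hedged sketch of the input format (`KeyedJaggedTensor.from_lengths_sync` is a standard torchrec constructor; the keys and values here are made up):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Batch of 2 examples, 2 sparse features; lengths give the ragged sizes
# per (feature, example): feature 0 -> [1, 2], feature 1 -> [1, 2].
kjt = KeyedJaggedTensor.from_lengths_sync(
  keys=["meta__user_id", "meta__author_id"],
  values=torch.tensor([1, 2, 3, 4, 5, 6]),
  lengths=torch.tensor([1, 2, 1, 2]),
)
# pooled = large_embeddings(kjt)  # KeyedTensor: one pooled embedding per key
```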
@@ -11,6 +11,12 @@ import torch.distributed.run


def is_distributed_worker():
  """
  Checks if the current process is a distributed worker.

  Returns:
    bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
  """
  world_size = os.environ.get("WORLD_SIZE", None)
  rank = os.environ.get("RANK", None)
  return world_size is not None and rank is not None
@@ -25,25 +31,36 @@ def maybe_run_training(
  is_chief: Optional[bool] = False,
  **training_kwargs,
):
  """
  Wrapper function for single-node, multi-GPU PyTorch training.

  If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK)
  have been set, this function executes `train_fn(**training_kwargs)` directly.
  Otherwise, it calls torchrun pointed at the calling module `module_name`;
  after that call the environment variables are set and training commences.

  Args:
    train_fn (callable): The function responsible for training.
    module_name (str): The name of the module this function was called from; used as the torchrun entrypoint.
    nproc_per_node (int, optional): Number of workers per node. Defaults to None.
    num_nodes (int, optional): Number of nodes; otherwise inferred from the environment. Defaults to None.
    is_chief (bool, optional): Whether the process is running on the chief node. Defaults to False.
    set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
    **training_kwargs: Additional keyword arguments passed to `train_fn`.

  Example:
    To run training on a single node with 4 GPUs:
    ```
    maybe_run_training(train_function, __name__, nproc_per_node=4)
    ```
  """

  machines = utils.machine_from_env()
  if num_nodes is None:
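In practice the docstring's example usually sits behind a `__main__` guard so that torchrun can re-import the module; a sketch with illustrative names:

```python
def train_function(save_dir: str):
  ...  # build the model and data, run the training loop


if __name__ == "__main__":
  # First invocation: WORLD_SIZE/RANK are unset, so this shells out to torchrun,
  # which re-executes this module with the env vars populated; the re-executed
  # workers then fall through to train_fn(**training_kwargs) directly.
  maybe_run_training(train_function, __name__, nproc_per_node=4, save_dir="/tmp/model")
```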
@@ -19,6 +19,23 @@ MOCK_ENV = {


@contextmanager
def mock_pg():
  """
  A context manager that mocks the distributed process group for testing.

  Temporarily sets environment variables (MOCK_ENV) and initializes the process
  group with the Gloo backend, so distributed training code can be tested without
  actually launching multiple processes.

  Example:
    ```python
    with mock_pg():
      # Your distributed training code here
    ```
  """
  with patch.dict(os.environ, MOCK_ENV):
    try:
      dist.init_process_group(
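A hedged sketch of a test that leans on it (assuming MOCK_ENV pins the group to a single Gloo rank):

```python
import torch
import torch.distributed as dist


def test_all_reduce_under_mock_pg():
  with mock_pg():
    t = torch.ones(2)
    dist.all_reduce(t)  # sum across a world of size 1: a no-op
    assert dist.get_world_size() == 1
    assert torch.equal(t, torch.ones(2))
```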
@@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)


def _read_file(f):
  """
  Read the contents of a file using fsspec.

  Args:
    f: File path or URL.

  Returns:
    The contents of the file.
  """
  with fsspec.open(f) as f:
    return f.read()
@@ -21,18 +30,24 @@ def setup_configuration(
  yaml_path: str,
  substitute_env_variable: bool = False,
) -> Tuple[C, str]:
  """
  Load a Pydantic config object from a YAML file, optionally substituting environment variables.

  Args:
    config_type: Pydantic config class to load.
    yaml_path: YAML path of the config file.
    substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
      with their environment variable values whenever possible. If an environment
      variable doesn't exist, the string is left unchanged.

  Returns:
    A tuple containing the Pydantic config object and the resolved YAML content.

  Example:
    ```python
    config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
    ```
  """

  def _substitute(s):
    if substitute_env_variable:
@@ -6,6 +6,33 @@ import pydantic


class WandbConfig(base_config.BaseConfig):
  """
  Configuration for integrating with Weights and Biases (WandB).

  Attributes:
    host (str): Host of the Weights and Biases instance, passed to login.
    key_path (str): Path to the key file.
    name (str): Name of the experiment, passed to init.
    entity (str): Name of the user/service account, passed to init.
    project (str): Name of the WandB project, passed to init.
    tags (List[str]): List of tags, passed to init.
    notes (str): Notes, passed to init.
    metadata (Dict[str, Any]): Additional metadata to log.

  Example:
    ```python
    wandb_config = WandbConfig(
      host="https://wandb.example.com",
      key_path="/path/to/key",
      name="experiment_1",
      entity="user123",
      project="my_project",
      tags=["experiment", "ml"],
      notes="This is a test experiment.",
      metadata={"version": "1.0"}
    )
    ```
  """

  host: str = pydantic.Field(
    "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
    description="Host of Weights and Biases instance, passed to login.",
@@ -6,33 +6,46 @@ import pydantic


class BaseConfigTest(TestCase):
  """
  Unit tests for the BaseConfig class.
  """

  def test_extra_forbidden(self):
    """
    Test that extra fields are forbidden when creating a Config instance.
    """
    class Config(BaseConfig):
      x: int

    Config(x=1)
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y=2)

  def test_one_of(self):
    """
    Test the use of the `one_of` attribute for fields in a Config instance.
    """
    class Config(BaseConfig):
      x: int = pydantic.Field(None, one_of="f")
      y: int = pydantic.Field(None, one_of="f")

    with self.assertRaises(pydantic.ValidationError):
      Config()
    Config(x=1)
    Config(y=1)
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y=3)

  def test_at_most_one_of(self):
    """
    Test the use of the `at_most_one_of` attribute for fields in a Config instance.
    """
    class Config(BaseConfig):
      x: int = pydantic.Field(None, at_most_one_of="f")
      y: str = pydantic.Field(None, at_most_one_of="f")

    Config()
    Config(x=1)
    Config(y="a")
    with self.assertRaises(pydantic.ValidationError):
      Config(x=1, y="a")
@@ -9,6 +9,14 @@ import pydantic


class RuntimeConfig(base_config.BaseConfig):
  """
  Configuration for runtime settings.

  Attributes:
    wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
    enable_tensorfloat32 (bool): Enable TensorFloat-32 on Ampere devices.
    enable_amp (bool): Enable automatic mixed precision.
  """

  wandb: WandbConfig = pydantic.Field(None)
  enable_tensorfloat32: bool = pydantic.Field(
    False, description="Use tensorfloat32 if on Ampere devices."
@@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):


class TrainingConfig(base_config.BaseConfig):
  """
  Configuration for training settings.

  Attributes:
    save_dir (str): Directory to save checkpoints.
    num_train_steps (pydantic.PositiveInt): Number of training steps.
    initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
    checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every n steps.
    checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
    train_log_every_n (pydantic.PositiveInt): Log training information every n steps.
    num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset is used.
    eval_log_every_n (pydantic.PositiveInt): Log evaluation information every n steps.
    eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
    gradient_accumulation (int): Number of replica steps over which to accumulate gradients (optional).
    num_epochs (pydantic.PositiveInt): Number of training epochs.
  """

  save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
  num_train_steps: pydantic.PositiveInt = 10000
  initial_checkpoint_dir: str = pydantic.Field(
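To make the training config concrete, a hedged sketch (it assumes the fields truncated out of this hunk keep optional defaults; the YAML fragment is illustrative and mirrors the Attributes above):

```python
# Illustrative YAML for setup_configuration(TrainingConfig, "config.yaml"):
#
#   save_dir: ${SAVE_DIR}      # resolved when substitute_env_variable=True
#   num_train_steps: 20000
#   checkpoint_every_n: 1000
#
config = TrainingConfig(save_dir="/tmp/model", num_train_steps=20000)
assert config.num_train_steps == 20000
```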