mirror of https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-09 22:39:22 +01:00

update

This commit is contained in:
parent deec9a820e
commit 92eaaad3ed
@@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
 class Snapshot:
-  """Checkpoints using torchsnapshot.
-
-  Also saves step to be updated by the training loop.
-
-  """
+  """
+  Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
+  """
 
   def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
+    """
+    Initializes a Snapshot object.
+
+    Args:
+        save_dir (str): Directory where checkpoints will be saved.
+        state (Dict[str, Any]): State dictionary containing checkpoint information.
+    """
     self.save_dir = save_dir
     self.state = state
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
 
   @property
   def step(self):
+    """Get the current training step."""
     return self.state["extra_state"]["step"]
 
   @step.setter
   def step(self, step: int) -> None:
+    """Set the current training step."""
     self.state["extra_state"]["step"] = step
 
   @property
@@ -41,7 +48,15 @@ class Snapshot:
     self.state["extra_state"]["walltime"] = walltime
 
   def save(self, global_step: int) -> "PendingSnapshot":
-    """Saves checkpoint with given global_step."""
+    """
+    Saves a checkpoint with a given global step.
+
+    Args:
+        global_step (int): The global step to associate with the checkpoint.
+
+    Returns:
+        PendingSnapshot: A pending snapshot object.
+    """
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
     start_time = time.time()
@@ -58,7 +73,12 @@ class Snapshot:
     return snapshot
 
   def restore(self, checkpoint: str) -> None:
-    """Restores a given checkpoint."""
+    """
+    Restores a given checkpoint.
+
+    Args:
+        checkpoint (str): Path to the checkpoint to restore.
+    """
     snapshot = torchsnapshot.Snapshot(path=checkpoint)
     logging.info(f"Restoring snapshot from {snapshot.path}.")
     start_time = time.time()
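For orientation, the pieces above compose as in the following minimal sketch of a training loop. `model`, `optimizer`, `train_one_step`, and the `/tmp/ckpts` path are illustrative placeholders, not names taken from this diff:

```python
# Hedged usage sketch of the Snapshot wrapper defined above.
snapshot = Snapshot(
    save_dir="/tmp/ckpts",  # placeholder directory
    state={"model": model, "optimizer": optimizer},  # placeholder objects
)

for step in range(10_000):
    train_one_step(model, optimizer)  # assumed training helper
    snapshot.step = step  # recorded in the "extra_state" StateDict
    if step % 1_000 == 0:
        # save() returns a PendingSnapshot; the write completes asynchronously.
        pending = snapshot.save(global_step=step)

# Later, restore the chosen checkpoint back into the same state dict.
snapshot.restore("/tmp/ckpts/9000")
```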
@@ -83,12 +103,17 @@ class Snapshot:
     global_step: Optional[int] = None,
     missing_ok: bool = False,
   ) -> torchsnapshot.Snapshot:
-    """Get torch stateless snapshot, without actually loading it.
-    Args:
-      snapshot_path: path to the model snapshot
-      global_step: restores from this checkpoint if specified.
-      missing_ok: if True and checkpoints do not exist, returns without restoration.
-    """
+    """
+    Get a torch stateless snapshot, without actually loading it.
+
+    Args:
+        snapshot_path (str): Path to the model snapshot.
+        global_step (int, optional): Restores from this checkpoint if specified.
+        missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
+
+    Returns:
+        torchsnapshot.Snapshot: A torch snapshot object.
+    """
     path = get_checkpoint(snapshot_path, global_step, missing_ok)
     logging.info(f"Loading snapshot from {path}.")
     return torchsnapshot.Snapshot(path=path)
@@ -100,13 +125,14 @@ class Snapshot:
     snapshot_emb_name: str,
     weight_tensor,
   ) -> None:
-    """Loads pretrained embedding from the snapshot to the model.
-    Utilise partial lodaing meachanism from torchsnapshot.
-    Args:
-      embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
-      snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
-      weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
-    """
+    """
+    Loads pretrained embedding from the snapshot to the model.
+
+    Args:
+        embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
+        snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
+        weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
+    """
     start_time = time.time()
     manifest = embedding_snapshot.get_manifest()
     for path in manifest.keys():
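The partial-loading mechanism the removed docstring alluded to is, as far as one can tell from the signature, torchsnapshot's ability to read a single manifest entry into a pre-allocated tensor. A hedged sketch of that idea (the snapshot path and manifest key are invented; real keys come from `get_manifest()`):

```python
import torch
import torchsnapshot

# Hypothetical snapshot written by a model containing an EmbeddingBagCollection (EBC).
embedding_snapshot = torchsnapshot.Snapshot(path="gs://bucket/snapshots/12345")

# Destination tensor in the *current* model; the shape here is illustrative.
weight_tensor = torch.empty(1_000_000, 128)

# read_object() pulls just this entry, without restoring the whole snapshot.
embedding_snapshot.read_object(
    "0/ebc/embedding_bags.table_0.weight",  # invented manifest key
    obj_out=weight_tensor,
)
```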
@@ -209,7 +235,22 @@ def get_checkpoint(
 
 
 def get_checkpoints(save_dir: str) -> List[str]:
-  """Gets all checkpoints that have been fully written."""
+  """
+  Get a list of fully written checkpoints in the specified directory.
+
+  This function retrieves a list of fully written checkpoints in the given directory.
+  Checkpoints that are considered fully written include those that have a
+  corresponding snapshot metadata file.
+
+  Args:
+      save_dir (str): The directory where checkpoints are stored.
+
+  Returns:
+      List[str]: A list of fully written checkpoint paths.
+
+  Note:
+      Checkpoints are sorted by their numeric filenames in ascending order.
+  """
  checkpoints = []
  fs = infer_fs(save_dir)
  if fs.exists(save_dir):
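The "corresponding snapshot metadata file" mentioned in the new docstring presumably refers to the `.snapshot_metadata` file torchsnapshot writes once a snapshot is complete. The filtering could look roughly like the sketch below; this is an assumption reconstructed from the docstring, not a copy of the function body:

```python
import os
from typing import List

def _finished_checkpoints(fs, save_dir: str) -> List[str]:
    # Assumed logic: a checkpoint directory counts as fully written only
    # once its .snapshot_metadata marker exists.
    finished = [
        path
        for path in fs.ls(save_dir)
        if fs.exists(os.path.join(path, ".snapshot_metadata"))
    ]
    # Sort by the numeric global-step directory name, ascending.
    return sorted(finished, key=lambda p: int(os.path.basename(p)))
```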
@@ -232,6 +273,18 @@ def wait_for_evaluators(
   global_step: int,
   timeout: int,
 ) -> None:
+  """
+  Waits for all evaluators to finish and checks for their completion status.
+
+  Args:
+      save_dir (str): Directory where checkpoints are saved.
+      partition_names (List[str]): List of partition names to check for completion.
+      global_step (int): The global step for which to wait for evaluators.
+      timeout (int): Maximum time in seconds to wait for evaluators to finish.
+
+  Returns:
+      None: This function returns nothing but logs the progress and results.
+  """
   logging.info("Waiting for all evaluators to finish.")
   start_time = time.time()
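The body of `wait_for_evaluators` is not shown in this hunk; based only on the signature and docstring, a plausible shape is a timeout-bounded polling loop over per-partition completion markers. Everything below, including the marker naming, is an assumption rather than the repository's implementation:

```python
import os
import time

def _wait_for_markers(fs, save_dir, partition_names, global_step, timeout):
    # Poll until every partition has signalled completion for this
    # global_step, or until the timeout expires.
    deadline = time.time() + timeout
    remaining = set(partition_names)
    while remaining and time.time() < deadline:
        for name in list(remaining):
            marker = os.path.join(save_dir, str(global_step), f"{name}_DONE")
            if fs.exists(marker):  # hypothetical marker file
                remaining.discard(name)
        time.sleep(10)
    return remaining  # non-empty means some evaluators never finished
```

The remaining hunks of this commit touch the pydantic-based config classes.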
@@ -8,59 +8,69 @@ import pydantic
 
 
 class BaseConfig(pydantic.BaseModel):
   """Base class for all derived config classes.
 
-  This class provides some convenient functionality:
-    - Disallows extra fields when constructing an object. User error
-      should be reduced by exact arguments.
-    - "one_of" fields. A subclass can group optional fields and enforce
-      that only one of the fields be set. For example:
-
-    ```
+  This class provides convenient functionality and constraints for derived config classes:
+
+  - Disallows extra fields when constructing an object. User errors due to extraneous arguments
+    are minimized.
+  - "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
+    be set. For example:
+
+    ```python
     class ExampleConfig(BaseConfig):
-      x: int = Field(None, one_of="group_1")
-      y: int = Field(None, one_of="group_1")
+        x: int = Field(None, one_of="group_1")
+        y: int = Field(None, one_of="group_1")
 
-    ExampleConfig(x=1) # ok
-    ExampleConfig(y=1) # ok
-    ExampleConfig(x=1, y=1) # throws error
+    ExampleConfig(x=1)  # OK
+    ExampleConfig(y=1)  # OK
+    ExampleConfig(x=1, y=1)  # Raises an error
     ```
-  """
 
-  class Config:
-    """Forbids extras."""
-
-    extra = pydantic.Extra.forbid  # noqa
+  Attributes:
+      Config (class): Configuration options for this class, forbidding extra fields.
+
+  Methods:
+      _field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
+      _one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
+      _at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
+      pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
+  """
+
+  class Config:
+    """Configuration options that forbid extra fields."""
+    extra = pydantic.Extra.forbid  # noqa
 
   @classmethod
   @functools.lru_cache()
   def _field_data_map(cls, field_data_name):
-    """Create a map of fields with provided the field data."""
+    """Create a map of fields with the provided field data."""
     schema = cls.schema()
     one_of = collections.defaultdict(list)
     for field, fdata in schema["properties"].items():
       if field_data_name in fdata:
         one_of[fdata[field_data_name]].append(field)
     return one_of
 
   @pydantic.root_validator
   def _one_of_check(cls, values):
-    """Validate that all 'one of' fields are appear exactly once."""
+    """Validate that all 'one of' fields appear exactly once."""
     one_of_map = cls._field_data_map("one_of")
     for one_of, field_names in one_of_map.items():
       if sum([values.get(n, None) is not None for n in field_names]) != 1:
         raise ValueError(f"Exactly one of {','.join(field_names)} required.")
     return values
 
   @pydantic.root_validator
   def _at_most_one_of_check(cls, values):
     """Validate that all 'at_most_one_of' fields appear at most once."""
     at_most_one_of_map = cls._field_data_map("at_most_one_of")
     for one_of, field_names in at_most_one_of_map.items():
       if sum([values.get(n, None) is not None for n in field_names]) > 1:
         raise ValueError(f"At most one of {','.join(field_names)} can be set.")
+
     return values
 
   def pretty_print(self) -> str:
-    """Return a human legible (yaml) representation of the config useful for logging."""
+    """Return a human-readable (YAML) representation of the config useful for logging."""
     return yaml.dump(self.dict())
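The rewritten docstring demonstrates only the `one_of` grouping. Per `_at_most_one_of_check` above, an `at_most_one_of` group behaves the same way except that leaving the whole group unset is also valid. A small sketch with invented field names:

```python
import pydantic

class SamplingConfig(BaseConfig):
    # At most one sampling strategy may be set; setting neither is fine.
    top_k: int = pydantic.Field(None, at_most_one_of="strategy")
    top_p: float = pydantic.Field(None, at_most_one_of="strategy")

SamplingConfig()          # OK: the group may stay empty
SamplingConfig(top_k=40)  # OK
SamplingConfig(top_k=40, top_p=0.9)  # raises: at most one of top_k,top_p
```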
@@ -13,6 +13,13 @@ class BaseConfigTest(TestCase):
   def test_extra_forbidden(self):
+    """
+    Test that extra fields are forbidden when creating a Config instance.
+
+    This test case checks whether the `BaseConfig` class correctly raises a
+    `pydantic.ValidationError` when extra fields are provided when creating a
+    `Config` instance.
+
+    Raises:
+        AssertionError: If the test fails.
+    """
     class Config(BaseConfig):
       x: int
@@ -24,6 +31,13 @@ class BaseConfigTest(TestCase):
   def test_one_of(self):
+    """
+    Test the use of the `one_of` attribute for fields in a Config instance.
+
+    This test case checks the behavior of the `one_of` attribute in a `Config`
+    instance. It verifies that the `pydantic.Field` correctly enforces the
+    specified constraint.
+
+    Raises:
+        AssertionError: If the test fails.
+    """
     class Config(BaseConfig):
       x: int = pydantic.Field(None, one_of="f")
@@ -39,6 +53,13 @@ class BaseConfigTest(TestCase):
   def test_at_most_one_of(self):
+    """
+    Test the use of the `at_most_one_of` attribute for fields in a Config instance.
+
+    This test case checks the behavior of the `at_most_one_of` attribute in a
+    `Config` instance. It verifies that the `pydantic.Field` enforces the
+    constraint where at most one of the specified fields can be provided.
+
+    Raises:
+        AssertionError: If the test fails.
+    """
     class Config(BaseConfig):
       x: int = pydantic.Field(None, at_most_one_of="f")
@@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
 
 
 def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
-  """Recommend method to load a config file (a yaml file) and parse it.
-
-  Because we have a shared filesystem the recommended route to running jobs it put modified config
-  files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
-  """
+  """
+  Recommended method to load and parse a configuration from a YAML file.
+
+  This function loads a configuration from a YAML file, parses it, and returns an instance of the
+  specified config type.
+
+  Because we have a shared filesystem, the recommended route to running jobs is to put modified
+  config files with the desired parameters somewhere on the filesystem and run jobs pointing to
+  them.
+
+  Args:
+      config_type (Type[BaseConfig]): The Pydantic config class to load.
+      yaml_path (str): The path to the YAML configuration file.
+
+  Returns:
+      BaseConfig: An instance of the specified config type populated with values from the YAML file.
+
+  Example:
+      Suppose you have a YAML file 'my_config.yaml' containing the following:
+
+      ```yaml
+      x: 42
+      y: "hello"
+      ```
+
+      You can load and parse it using this function as follows:
+
+      ```python
+      my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
+      ```
+
+  Note:
+      This function performs environment variable substitution in the YAML file. It replaces
+      occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
+      values. If an environment variable does not exist, the string is left unchanged.
+  """
 
   def _substitute(s):
     return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
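The Note's substitution semantics fall directly out of `string.Template.safe_substitute`, which the `_substitute` helper delegates to: known variables are replaced, unknown ones are left verbatim instead of raising. A quick self-contained illustration:

```python
import getpass
import os
import string

template = "user: ${USER}\nmissing: ${NO_SUCH_VAR}\n"
rendered = string.Template(template).safe_substitute(
    os.environ, USER=getpass.getuser()
)
# ${USER} is filled from the explicit override; ${NO_SUCH_VAR} survives
# unchanged because safe_substitute never raises on missing keys.
print(rendered)
```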
@@ -8,11 +8,27 @@ import pydantic
 
 
 class _PointlessConfig(BaseConfig):
 
   a: int
   user: str
 
 
 def test_load_config_from_yaml(tmp_path):
+  """Test loading a configuration from a YAML file and verifying its values.
+
+  This test function checks the functionality of the `load_config_from_yaml` function by creating
+  a temporary YAML configuration file, loading it, and asserting that the loaded config object
+  has the expected values.
+
+  Args:
+      tmp_path: A temporary directory provided by the `pytest` framework.
+
+  Test Steps:
+      1. Create a temporary YAML file containing configuration data.
+      2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
+      3. Assert that the loaded configuration object has the expected values.
+  """
   yaml_path = tmp_path.joinpath("test.yaml").as_posix()
   with open(yaml_path, "w") as yaml_file:
     yaml_file.write("""a: 3\nuser: ${USER}\n""")