This commit is contained in:
rajveer43 2023-09-12 18:12:05 +05:30
parent deec9a820e
commit 92eaaad3ed
5 changed files with 198 additions and 67 deletions

View File

@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
class Snapshot: class Snapshot:
"""Checkpoints using torchsnapshot.
Also saves step to be updated by the training loop.
""" """
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
"""
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None: def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
"""
Initializes a Snapshot object.
Args:
save_dir (str): Directory where checkpoints will be saved.
state (Dict[str, Any]): State dictionary containing checkpoint information.
"""
self.save_dir = save_dir self.save_dir = save_dir
self.state = state self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0) self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property @property
def step(self): def step(self):
"""Get the current training step."""
return self.state["extra_state"]["step"] return self.state["extra_state"]["step"]
@step.setter @step.setter
def step(self, step: int) -> None: def step(self, step: int) -> None:
"""Set the current training step."""
self.state["extra_state"]["step"] = step self.state["extra_state"]["step"] = step
@property @property
@ -41,7 +48,15 @@ class Snapshot:
self.state["extra_state"]["walltime"] = walltime self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot": def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step.""" """
Saves a checkpoint with a given global step.
Args:
global_step (int): The global step to associate with the checkpoint.
Returns:
PendingSnapshot: A pending snapshot object.
"""
path = os.path.join(self.save_dir, str(global_step)) path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.") logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time() start_time = time.time()
@ -58,7 +73,12 @@ class Snapshot:
return snapshot return snapshot
def restore(self, checkpoint: str) -> None: def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint.""" """
Restores a given checkpoint.
Args:
checkpoint (str): Path to the checkpoint to restore.
"""
snapshot = torchsnapshot.Snapshot(path=checkpoint) snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.") logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time() start_time = time.time()
@ -83,12 +103,17 @@ class Snapshot:
global_step: Optional[int] = None, global_step: Optional[int] = None,
missing_ok: bool = False, missing_ok: bool = False,
) -> torchsnapshot.Snapshot: ) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it.
Args:
snapshot_path: path to the model snapshot
global_step: restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
""" """
Get a torch stateless snapshot, without actually loading it.
Args:
snapshot_path (str): Path to the model snapshot.
global_step (int, optional): Restores from this checkpoint if specified.
missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
Returns:
torchsnapshot.Snapshot: A torch snapshot object.
"""
path = get_checkpoint(snapshot_path, global_step, missing_ok) path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.") logging.info(f"Loading snapshot from {path}.")
return torchsnapshot.Snapshot(path=path) return torchsnapshot.Snapshot(path=path)
@ -100,13 +125,14 @@ class Snapshot:
snapshot_emb_name: str, snapshot_emb_name: str,
weight_tensor, weight_tensor,
) -> None: ) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilises the partial loading mechanism from torchsnapshot.
Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
""" """
Loads pretrained embedding from the snapshot to the model.
Args:
embedding_snapshot (torchsnapshot.Snapshot): Snapshot containing pretrained embeddings (EBC).
snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
"""
start_time = time.time() start_time = time.time()
manifest = embedding_snapshot.get_manifest() manifest = embedding_snapshot.get_manifest()
for path in manifest.keys(): for path in manifest.keys():
@ -209,7 +235,22 @@ def get_checkpoint(
def get_checkpoints(save_dir: str) -> List[str]: def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written.""" """
Get a list of fully written checkpoints in the specified directory.
This function retrieves a list of fully written checkpoints in the given directory.
Checkpoints that are considered fully written include those that have a
corresponding snapshot metadata file.
Args:
save_dir (str): The directory where checkpoints are stored.
Returns:
List[str]: A list of fully written checkpoint paths.
Note:
Checkpoints are sorted by their numeric filenames in ascending order.
"""
checkpoints = [] checkpoints = []
fs = infer_fs(save_dir) fs = infer_fs(save_dir)
if fs.exists(save_dir): if fs.exists(save_dir):
@ -232,6 +273,18 @@ def wait_for_evaluators(
global_step: int, global_step: int,
timeout: int, timeout: int,
) -> None: ) -> None:
"""
Waits for all evaluators to finish and checks for their completion status.
Args:
save_dir (str): Directory where checkpoints are saved.
partition_names (List[str]): List of partition names to check for completion.
global_step (int): The global step for which to wait for evaluators.
timeout (int): Maximum time in seconds to wait for evaluators to finish.
Returns:
None: This function returns nothing but logs the progress and results.
"""
logging.info("Waiting for all evaluators to finish.") logging.info("Waiting for all evaluators to finish.")
start_time = time.time() start_time = time.time()

View File

@ -8,59 +8,69 @@ import pydantic
class BaseConfig(pydantic.BaseModel): class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes. """Base class for all derived config classes.
This class provides some convenient functionality: This class provides convenient functionality and constraints for derived config classes:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
``` - Disallows extra fields when constructing an object. User errors due to extraneous arguments
are minimized.
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
be set. For example:
```python
class ExampleConfig(BaseConfig): class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1") x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1") y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok ExampleConfig(x=1) # OK
ExampleConfig(y=1) # ok ExampleConfig(y=1) # OK
ExampleConfig(x=1, y=1) # throws error ExampleConfig(x=1, y=1) # Raises an error
``` ```
"""
class Config: Attributes:
"""Forbids extras.""" Config (class): Configuration options for this class, forbidding extra fields.
extra = pydantic.Extra.forbid # noqa Methods:
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
@classmethod """
@functools.lru_cache() class Config:
def _field_data_map(cls, field_data_name): """Configuration options that forbid extra fields."""
"""Create a map of fields with provided the field data.""" extra = pydantic.Extra.forbid # noqa
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator @classmethod
def _one_of_check(cls, values): @functools.lru_cache()
"""Validate that all 'one of' fields are appear exactly once.""" def _field_data_map(cls, field_data_name):
one_of_map = cls._field_data_map("one_of") """Create a map of fields with the provided field data."""
for one_of, field_names in one_of_map.items(): schema = cls.schema()
if sum([values.get(n, None) is not None for n in field_names]) != 1: one_of = collections.defaultdict(list)
raise ValueError(f"Exactly one of {','.join(field_names)} required.") for field, fdata in schema["properties"].items():
return values if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator @pydantic.root_validator
def _at_most_one_of_check(cls, values): def _one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once.""" """Validate that all 'one of' fields appear exactly once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of") one_of_map = cls._field_data_map("one_of")
for one_of, field_names in at_most_one_of_map.items(): for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1: if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.") raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values return values
def pretty_print(self) -> str: @pydantic.root_validator
"""Return a human legible (yaml) representation of the config useful for logging.""" def _at_most_one_of_check(cls, values):
return yaml.dump(self.dict()) """Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human-readable (YAML) representation of the config useful for logging."""
return yaml.dump(self.dict())

View File

@ -13,6 +13,13 @@ class BaseConfigTest(TestCase):
def test_extra_forbidden(self): def test_extra_forbidden(self):
""" """
Test that extra fields are forbidden when creating a Config instance. Test that extra fields are forbidden when creating a Config instance.
This test case checks whether the `BaseConfig` class correctly raises a
`pydantic.ValidationError` when extra fields are provided when creating a
`Config` instance.
Raises:
AssertionError: If the test fails.
""" """
class Config(BaseConfig): class Config(BaseConfig):
x: int x: int
@ -24,6 +31,13 @@ class BaseConfigTest(TestCase):
def test_one_of(self): def test_one_of(self):
""" """
Test the use of the `one_of` attribute for fields in a Config instance. Test the use of the `one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `one_of` attribute in a `Config`
instance. It verifies that the `pydantic.Field` correctly enforces the
specified constraint.
Raises:
AssertionError: If the test fails.
""" """
class Config(BaseConfig): class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f") x: int = pydantic.Field(None, one_of="f")
@ -39,6 +53,13 @@ class BaseConfigTest(TestCase):
def test_at_most_one_of(self): def test_at_most_one_of(self):
""" """
Test the use of the `at_most_one_of` attribute for fields in a Config instance. Test the use of the `at_most_one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `at_most_one_of` attribute in a
`Config` instance. It verifies that the `pydantic.Field` enforces the
constraint where at most one of the specified fields can be provided.
Raises:
AssertionError: If the test fails.
""" """
class Config(BaseConfig): class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f") x: int = pydantic.Field(None, at_most_one_of="f")

View File

@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str): def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommended method to load a config file (a YAML file) and parse it.
Because we have a shared filesystem, the recommended route to running jobs is to put modified config
files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
""" """
Recommended method to load and parse a configuration from a YAML file.
This function loads a configuration from a YAML file, parses it, and returns an instance of the
specified config type.
Because we have a shared filesystem, the recommended route to running jobs is to put modified config
files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
Args:
config_type (Type[BaseConfig]): The Pydantic config class to load.
yaml_path (str): The path to the YAML configuration file.
Returns:
BaseConfig: An instance of the specified config type populated with values from the YAML file.
Example:
Suppose you have a YAML file 'my_config.yaml' containing the following:
```yaml
x: 42
y: "hello"
```
You can load and parse it using this function as follows:
```python
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
```
Note:
This function performs environment variable substitution in the YAML file. It replaces
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
values. If an environment variable does not exist, the string is left unchanged.
"""
def _substitute(s): def _substitute(s):
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())

View File

@ -8,11 +8,27 @@ import pydantic
class _PointlessConfig(BaseConfig): class _PointlessConfig(BaseConfig):
a: int a: int
user: str user: str
def test_load_config_from_yaml(tmp_path): def test_load_config_from_yaml(tmp_path):
"""Test loading a configuration from a YAML file and verifying its values.
This test function checks the functionality of the `load_config_from_yaml` function by creating
a temporary YAML configuration file, loading it, and asserting that the loaded config object
has the expected values.
Args:
tmp_path: A temporary directory provided by the `pytest` framework.
Test Steps:
1. Create a temporary YAML file containing configuration data.
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
3. Assert that the loaded configuration object has the expected values.
"""
yaml_path = tmp_path.joinpath("test.yaml").as_posix() yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file: with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""") yaml_file.write("""a: 3\nuser: ${USER}\n""")