diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py
index 2703efd..ad1fac1 100644
--- a/common/checkpointing/snapshot.py
+++ b/common/checkpointing/snapshot.py
@@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
 
 
 class Snapshot:
-  """Checkpoints using torchsnapshot.
-
-  Also saves step to be updated by the training loop.
-  """
+  """Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
+  """
 
   def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
+    """
+    Initializes a Snapshot object.
+
+    Args:
+      save_dir (str): Directory where checkpoints will be saved.
+      state (Dict[str, Any]): State dictionary containing checkpoint information.
+    """
     self.save_dir = save_dir
     self.state = state
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
 
   @property
   def step(self):
+    """Get the current training step."""
     return self.state["extra_state"]["step"]
 
   @step.setter
   def step(self, step: int) -> None:
+    """Set the current training step."""
     self.state["extra_state"]["step"] = step
 
   @property
@@ -41,7 +48,15 @@ class Snapshot:
     self.state["extra_state"]["walltime"] = walltime
 
   def save(self, global_step: int) -> "PendingSnapshot":
-    """Saves checkpoint with given global_step."""
+    """
+    Saves a checkpoint with a given global step.
+
+    Args:
+      global_step (int): The global step to associate with the checkpoint.
+
+    Returns:
+      PendingSnapshot: A pending snapshot object.
+    """
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
     start_time = time.time()
@@ -58,7 +73,12 @@ class Snapshot:
     return snapshot
 
   def restore(self, checkpoint: str) -> None:
-    """Restores a given checkpoint."""
+    """
+    Restores a given checkpoint.
+
+    Args:
+      checkpoint (str): Path to the checkpoint to restore.
+    """
     snapshot = torchsnapshot.Snapshot(path=checkpoint)
     logging.info(f"Restoring snapshot from {snapshot.path}.")
     start_time = time.time()
@@ -83,12 +103,17 @@ class Snapshot:
     global_step: Optional[int] = None,
     missing_ok: bool = False,
   ) -> torchsnapshot.Snapshot:
-    """Get torch stateless snapshot, without actually loading it.
-    Args:
-      snapshot_path: path to the model snapshot
-      global_step: restores from this checkpoint if specified.
-      missing_ok: if True and checkpoints do not exist, returns without restoration.
     """
+    Get a torch stateless snapshot, without actually loading it.
+
+    Args:
+      snapshot_path (str): Path to the model snapshot.
+      global_step (int, optional): Restores from this checkpoint if specified.
+      missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
+
+    Returns:
+      torchsnapshot.Snapshot: A torch snapshot object.
+    """
     path = get_checkpoint(snapshot_path, global_step, missing_ok)
     logging.info(f"Loading snapshot from {path}.")
     return torchsnapshot.Snapshot(path=path)
@@ -100,13 +125,14 @@ class Snapshot:
     snapshot_emb_name: str,
     weight_tensor,
   ) -> None:
-    """Loads pretrained embedding from the snapshot to the model.
-    Utilise partial lodaing meachanism from torchsnapshot.
-    Args:
-      embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
-      snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
-      weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
     """
+    Loads pretrained embedding from the snapshot to the model.
+
+    Args:
+      embedding_snapshot (torchsnapshot.Snapshot): Snapshot containing the pretrained embeddings (EBC).
+      snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
+      weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
+    """
     start_time = time.time()
     manifest = embedding_snapshot.get_manifest()
     for path in manifest.keys():
@@ -209,7 +235,22 @@ def get_checkpoint(
 
 
 def get_checkpoints(save_dir: str) -> List[str]:
-  """Gets all checkpoints that have been fully written."""
+  """
+  Get a list of fully written checkpoints in the specified directory.
+
+  This function returns the checkpoints in the given directory that are
+  fully written, i.e. those that have a corresponding snapshot metadata
+  file. Partially written checkpoints are excluded.
+
+  Args:
+    save_dir (str): The directory where checkpoints are stored.
+
+  Returns:
+    List[str]: A list of fully written checkpoint paths.
+
+  Note:
+    Checkpoints are sorted by their numeric filenames in ascending order.
+  """
   checkpoints = []
   fs = infer_fs(save_dir)
   if fs.exists(save_dir):
@@ -232,6 +273,18 @@ def wait_for_evaluators(
   global_step: int,
   timeout: int,
 ) -> None:
+  """
+  Waits for all evaluators to finish and checks their completion status.
+
+  Args:
+    save_dir (str): Directory where checkpoints are saved.
+    partition_names (List[str]): List of partition names to check for completion.
+    global_step (int): The global step for which to wait for evaluators.
+    timeout (int): Maximum time in seconds to wait for evaluators to finish.
+
+  Returns:
+    None: This function returns nothing; progress and results are logged.
+  """
   logging.info("Waiting for all evaluators to finish.")
   start_time = time.time()
 
diff --git a/core/config/base_config.py b/core/config/base_config.py
index 1c15d04..76a2022 100644
--- a/core/config/base_config.py
+++ b/core/config/base_config.py
@@ -8,59 +8,69 @@ import pydantic
 
 
 class BaseConfig(pydantic.BaseModel):
-  """Base class for all derived config classes.
+  """Base class for all derived config classes.
 
-  This class provides some convenient functionality:
-    - Disallows extra fields when constructing an object. User error
-      should be reduced by exact arguments.
-    - "one_of" fields. A subclass can group optional fields and enforce
-      that only one of the fields be set. For example:
+  This class provides convenient functionality and constraints for derived config classes:
 
-  ```
+  - Disallows extra fields when constructing an object. User errors due to extraneous arguments
+    are minimized.
+  - "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
+    be set. For example:
+
+  ```python
   class ExampleConfig(BaseConfig):
-    x: int = Field(None, one_of="group_1")
-    y: int = Field(None, one_of="group_1")
+      x: int = Field(None, one_of="group_1")
+      y: int = Field(None, one_of="group_1")
 
-  ExampleConfig(x=1) # ok
-  ExampleConfig(y=1) # ok
-  ExampleConfig(x=1, y=1) # throws error
+  ExampleConfig(x=1)  # OK
+  ExampleConfig(y=1)  # OK
+  ExampleConfig(x=1, y=1)  # Raises an error
   ```
-  """
 
-  class Config:
-    """Forbids extras."""
+  Attributes:
+    Config (class): Configuration options for this class, forbidding extra fields.
 
-    extra = pydantic.Extra.forbid  # noqa
+  Methods:
+    _field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
+    _one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
+    _at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
+    pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
 
-  @classmethod
-  @functools.lru_cache()
-  def _field_data_map(cls, field_data_name):
-    """Create a map of fields with provided the field data."""
-    schema = cls.schema()
-    one_of = collections.defaultdict(list)
-    for field, fdata in schema["properties"].items():
-      if field_data_name in fdata:
-        one_of[fdata[field_data_name]].append(field)
-    return one_of
+  """
+
+  class Config:
+    """Configuration options that forbid extra fields."""
+
+    extra = pydantic.Extra.forbid  # noqa
 
-  @pydantic.root_validator
-  def _one_of_check(cls, values):
-    """Validate that all 'one of' fields are appear exactly once."""
-    one_of_map = cls._field_data_map("one_of")
-    for one_of, field_names in one_of_map.items():
-      if sum([values.get(n, None) is not None for n in field_names]) != 1:
-        raise ValueError(f"Exactly one of {','.join(field_names)} required.")
-    return values
+  @classmethod
+  @functools.lru_cache()
+  def _field_data_map(cls, field_data_name):
+    """Create a map of fields with the provided field data."""
+    schema = cls.schema()
+    one_of = collections.defaultdict(list)
+    for field, fdata in schema["properties"].items():
+      if field_data_name in fdata:
+        one_of[fdata[field_data_name]].append(field)
+    return one_of
 
-  @pydantic.root_validator
-  def _at_most_one_of_check(cls, values):
-    """Validate that all 'at_most_one_of' fields appear at most once."""
-    at_most_one_of_map = cls._field_data_map("at_most_one_of")
-    for one_of, field_names in at_most_one_of_map.items():
-      if sum([values.get(n, None) is not None for n in field_names]) > 1:
-        raise ValueError(f"At most one of {','.join(field_names)} can be set.")
-    return values
+  @pydantic.root_validator
+  def _one_of_check(cls, values):
+    """Validate that all 'one of' fields appear exactly once."""
+    one_of_map = cls._field_data_map("one_of")
+    for one_of, field_names in one_of_map.items():
+      if sum([values.get(n, None) is not None for n in field_names]) != 1:
+        raise ValueError(f"Exactly one of {','.join(field_names)} required.")
+    return values
 
-  def pretty_print(self) -> str:
-    """Return a human legible (yaml) representation of the config useful for logging."""
-    return yaml.dump(self.dict())
+  @pydantic.root_validator
+  def _at_most_one_of_check(cls, values):
+    """Validate that all 'at_most_one_of' fields appear at most once."""
+    at_most_one_of_map = cls._field_data_map("at_most_one_of")
+    for one_of, field_names in at_most_one_of_map.items():
+      if sum([values.get(n, None) is not None for n in field_names]) > 1:
+        raise ValueError(f"At most one of {','.join(field_names)} can be set.")
+
+    return values
+
+  def pretty_print(self) -> str:
+    """Return a human-readable (YAML) representation of the config useful for logging."""
+    return yaml.dump(self.dict())
diff --git a/core/config/base_config_test.py b/core/config/base_config_test.py
index e322116..c4319e5 100644
--- a/core/config/base_config_test.py
+++ b/core/config/base_config_test.py
@@ -13,6 +13,13 @@ class BaseConfigTest(TestCase):
   def test_extra_forbidden(self):
     """
     Test that extra fields are forbidden when creating a Config instance.
+
+    This test case checks whether the `BaseConfig` class correctly raises a
+    `pydantic.ValidationError` when extra fields are provided while creating a
+    `Config` instance.
+
+    Raises:
+      AssertionError: If the test fails.
""" class Config(BaseConfig): x: int @@ -24,6 +31,13 @@ class BaseConfigTest(TestCase): def test_one_of(self): """ Test the use of the `one_of` attribute for fields in a Config instance. + + This test case checks the behavior of the `one_of` attribute in a `Config` + instance. It verifies that the `pydantic.Field` correctly enforces the + specified constraint. + + Raises: + AssertionError: If the test fails. """ class Config(BaseConfig): x: int = pydantic.Field(None, one_of="f") @@ -39,6 +53,13 @@ class BaseConfigTest(TestCase): def test_at_most_one_of(self): """ Test the use of the `at_most_one_of` attribute for fields in a Config instance. + + This test case checks the behavior of the `at_most_one_of` attribute in a + `Config` instance. It verifies that the `pydantic.Field` enforces the + constraint where at most one of the specified fields can be provided. + + Raises: + AssertionError: If the test fails. """ class Config(BaseConfig): x: int = pydantic.Field(None, at_most_one_of="f") diff --git a/core/config/config_load.py b/core/config/config_load.py index 709da41..f04c783 100644 --- a/core/config/config_load.py +++ b/core/config/config_load.py @@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str): - """Recommend method to load a config file (a yaml file) and parse it. - - Because we have a shared filesystem the recommended route to running jobs it put modified config - files with the desired parameters somewhere on the filesytem and run jobs pointing to them. """ + Recommend method to Load and parse a configuration from a YAML file. + + This function loads a configuration from a YAML file, parses it, and returns an instance of the + specified config type. + + Because we have a shared filesystem the recommended route to running jobs it put modified config + files with the desired parameters somewhere on the filesytem and run jobs pointing to them. + + Args: + config_type (Type[BaseConfig]): The Pydantic config class to load. + yaml_path (str): The path to the YAML configuration file. + + Returns: + BaseConfig: An instance of the specified config type populated with values from the YAML file. + + Example: + Suppose you have a YAML file 'my_config.yaml' containing the following: + + ```yaml + x: 42 + y: "hello" + ``` + + You can load and parse it using this function as follows: + + ```python + my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml') + ``` + + Note: + This function performs environment variable substitution in the YAML file. It replaces + occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable + values. If an environment variable does not exist, the string is left unchanged. + + """ def _substitute(s): return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) diff --git a/core/config/test_config_load.py b/core/config/test_config_load.py index 0fc88e2..1ad3942 100644 --- a/core/config/test_config_load.py +++ b/core/config/test_config_load.py @@ -8,11 +8,27 @@ import pydantic class _PointlessConfig(BaseConfig): + a: int user: str def test_load_config_from_yaml(tmp_path): + """Test loading a configuration from a YAML file and verifying its values. + + This test function checks the functionality of the `load_config_from_yaml` function by creating + a temporary YAML configuration file, loading it, and asserting that the loaded config object + has the expected values. 
+
+  Args:
+    tmp_path: A temporary directory provided by the `pytest` framework.
+
+  Test Steps:
+    1. Create a temporary YAML file containing configuration data.
+    2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
+    3. Assert that the loaded configuration object has the expected values.
+
+  """
   yaml_path = tmp_path.joinpath("test.yaml").as_posix()
   with open(yaml_path, "w") as yaml_file:
     yaml_file.write("""a: 3\nuser: ${USER}\n""")
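A quick usage sketch of the `Snapshot` API documented above. This is illustrative only: the import path `tml.common.checkpointing.snapshot` is inferred from the repo layout, and the `/tmp/ckpts` directory and toy `torch.nn.Linear` model are hypothetical stand-ins for real training state.

```python
# Illustrative sketch, not part of this diff. Assumes the module path below
# and that nn.Module state is snapshot-able via torchsnapshot.
import torch
from tml.common.checkpointing.snapshot import Snapshot, get_checkpoints

model = torch.nn.Linear(4, 2)  # toy stand-in for real model state
snapshot = Snapshot(save_dir="/tmp/ckpts", state={"model": model})

snapshot.step = 100                       # mirrored into state["extra_state"]["step"]
pending = snapshot.save(global_step=100)  # returns a PendingSnapshot handle
print(get_checkpoints("/tmp/ckpts"))      # only fully written checkpoints, ascending
```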
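The `one_of` behavior called out in the `BaseConfig` docstring can be exercised end to end. A minimal sketch, assuming `tml` is importable and pydantic v1 (which the patch's use of `root_validator` and `Extra.forbid` implies); the class and field names are illustrative:

```python
import pydantic
from tml.core.config.base_config import BaseConfig

class ExampleConfig(BaseConfig):
  x: int = pydantic.Field(None, one_of="group_1")
  y: int = pydantic.Field(None, one_of="group_1")

ExampleConfig(x=1)  # OK: exactly one member of group_1 is set
try:
  ExampleConfig(x=1, y=1)  # both members set
except pydantic.ValidationError as err:
  print(err)  # wraps "Exactly one of x,y required."
```

Note that omitting both fields fails the same `!= 1` check, so `one_of` means exactly one, while `at_most_one_of` also permits zero.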
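Likewise, the environment-variable substitution described in the `load_config_from_yaml` docstring can be demonstrated by mirroring `test_load_config_from_yaml`; the `JobConfig` name and `/tmp` path here are hypothetical:

```python
from tml.core.config.base_config import BaseConfig
from tml.core.config.config_load import load_config_from_yaml

class JobConfig(BaseConfig):  # illustrative; mirrors _PointlessConfig in the test
  a: int
  user: str

# Write a config that references an environment variable.
with open("/tmp/test.yaml", "w") as f:
  f.write("a: 3\nuser: ${USER}\n")

config = load_config_from_yaml(JobConfig, "/tmp/test.yaml")
print(config.a, config.user)  # 3, plus the current login name
```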