# Source: the-algorithm-ml/common/checkpointing/snapshot.py
# (page metadata: 314 lines, 10 KiB, Python)

import os
import time
from typing import Any, Dict, List, Optional
from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs
import torchsnapshot
# Subdirectory inside each checkpoint directory that holds the "<partition>_DONE"
# marker files written by `mark_done_eval` and checked by `is_done_eval`.
DONE_EVAL_SUBDIR = "evaled_by"
# Scheme prepended to entries returned by `fs.ls` when the filesystem is GCS,
# so that listed checkpoint paths are usable as full URIs.
GCS_PREFIX = "gs://"
class Snapshot:
    """Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.

    The ``step`` and ``walltime`` values live in ``state["extra_state"]`` so that they are
    written to and restored from every snapshot alongside the rest of the app state.
    """

    def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
        """Initializes a Snapshot object.

        Args:
          save_dir: Directory where checkpoints will be saved. Each checkpoint is written to
            ``<save_dir>/<global_step>``.
          state: App state handed to torchsnapshot. An ``"extra_state"`` entry holding
            ``step=0`` and ``walltime=0.0`` is installed, overwriting any existing entry.
        """
        self.save_dir = save_dir
        self.state = state
        self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

    @property
    def step(self) -> int:
        """Get the current training step."""
        return self.state["extra_state"]["step"]

    @step.setter
    def step(self, step: int) -> None:
        """Set the current training step."""
        self.state["extra_state"]["step"] = step

    @property
    def walltime(self) -> float:
        """Get the wall time value stored in ``extra_state``."""
        return self.state["extra_state"]["walltime"]

    @walltime.setter
    def walltime(self, walltime: float) -> None:
        """Set the wall time value stored in ``extra_state``."""
        self.state["extra_state"]["walltime"] = walltime

    def save(self, global_step: int) -> "PendingSnapshot":
        """Starts saving a checkpoint for ``global_step`` asynchronously.

        Args:
          global_step: The global step to associate with the checkpoint. It is used as the
            checkpoint's directory name under ``save_dir``.

        Returns:
          The pending snapshot returned by ``torchsnapshot.Snapshot.async_take``. Storage I/O
          may still be in progress when this method returns.
        """
        path = os.path.join(self.save_dir, str(global_step))
        logging.info(f"Saving snapshot global_step {global_step} to {path}.")
        start_time = time.time()
        # Take the snapshot asynchronously. The snapshot is consistent: state changes made after
        # this call returns have no effect on it, while storage I/O proceeds in the background.
        snapshot = torchsnapshot.Snapshot.async_take(
            app_state=self.state,
            path=path,
            # NOTE: `replicated=["**"]` is intentionally not passed. DistributedModelParallel
            # model saving errors with it on multi-GPU; without it, CPU, single-GPU and
            # multi-GPU training all checkpoint successfully.
        )
        # The elapsed time covers only the async_take call, not the background I/O.
        logging.info(f"Snapshot saved to {snapshot.path} ({time.time() - start_time:.05}s)")
        return snapshot

    def restore(self, checkpoint: str) -> None:
        """Restores ``self.state`` in place from a given checkpoint.

        Checkpoints written before ``walltime`` was tracked are still accepted. In that case
        only ``step`` is restored and ``walltime`` is reset to 0.0.

        Args:
          checkpoint: Path to the checkpoint to restore.
        """
        snapshot = torchsnapshot.Snapshot(path=checkpoint)
        logging.info(f"Restoring snapshot from {snapshot.path}.")
        start_time = time.time()
        # We can remove the try-except when we are confident that we no longer need to restore
        # from checkpoints from before walltime was added.
        try:
            # Checkpoints that do not have extra_state[walltime] fail here.
            snapshot.restore(self.state)
        except RuntimeError:
            # extra_state[walltime] does not exist in the checkpoint, but step should be there,
            # so restore with a step-only extra_state.
            self.state["extra_state"] = torchsnapshot.StateDict(step=0)
            snapshot.restore(self.state)
            # Re-add walltime so the rest of the code can rely on its presence.
            self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)
        logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s)")

    @classmethod
    def get_torch_snapshot(
        cls,
        snapshot_path: str,
        global_step: Optional[int] = None,
        missing_ok: bool = False,
    ) -> torchsnapshot.Snapshot:
        """Get a torch stateless snapshot, without actually loading it.

        Args:
          snapshot_path: Directory containing step-numbered checkpoints.
          global_step: Use this checkpoint if specified; otherwise use the latest one.
          missing_ok: If True and no checkpoints exist, do not raise. The resolved path is
            then ``""``.

        Returns:
          A ``torchsnapshot.Snapshot`` pointing at the resolved checkpoint path.
        """
        path = get_checkpoint(snapshot_path, global_step, missing_ok)
        logging.info(f"Loading snapshot from {path}.")
        return torchsnapshot.Snapshot(path=path)

    @classmethod
    def load_snapshot_to_weight(
        cls,
        embedding_snapshot: torchsnapshot.Snapshot,
        snapshot_emb_name: str,
        weight_tensor,
    ) -> None:
        """Loads pretrained embedding from the snapshot to the model.

        The manifest entry to load is the last one whose key starts with ``"0"`` and contains
        ``snapshot_emb_name``.

        Args:
          embedding_snapshot: Snapshot containing pretrained embeddings (EBC).
          snapshot_emb_name: Name of the layer in the snapshot model containing the EBC.
          weight_tensor: Embeddings tensor of the current model where the embeddings will be
            loaded in place. It must provide ``metadata()``, which is used for logging.

        Raises:
          ValueError: If no manifest entry matches ``snapshot_emb_name``.
        """
        start_time = time.time()
        manifest = embedding_snapshot.get_manifest()
        # NOTE(review): the "0" prefix presumably selects rank 0's entries; confirm against
        # the torchsnapshot manifest layout.
        snapshot_path_to_load = None
        for path in manifest.keys():
            if path.startswith("0") and snapshot_emb_name in path:
                snapshot_path_to_load = path
        if snapshot_path_to_load is None:
            # Previously this fell through to an UnboundLocalError; fail with a clear message.
            raise ValueError(
                f"No manifest entry starting with '0' and containing '{snapshot_emb_name}' "
                f"found in embedding snapshot."
            )
        embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)
        logging.info(
            f"Loaded embedding snapshot from {snapshot_path_to_load}: {time.time() - start_time:.05}s",
            rank=-1,
        )
        logging.info(f"Snapshot loaded to {weight_tensor.metadata()}", rank=-1)
def _eval_subdir(checkpoint_path: str) -> str:
    """Return the directory under ``checkpoint_path`` that stores eval-done markers."""
    marker_dir = os.path.join(checkpoint_path, DONE_EVAL_SUBDIR)
    return marker_dir
def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
    """Return the marker-file path signalling that ``eval_partition`` evaluated a checkpoint."""
    marker_dir = _eval_subdir(checkpoint_path)
    marker_name = f"{eval_partition}_DONE"
    return os.path.join(marker_dir, marker_name)
def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
    """Return whether ``eval_partition`` has been marked as evaluated for a checkpoint.

    Args:
      checkpoint_path: Path of a single checkpoint directory (as yielded by
        ``checkpoints_iterator``).
      eval_partition: Name of the evaluation partition.

    Returns:
      True if the ``<eval_partition>_DONE`` marker written by ``mark_done_eval`` exists.
    """
    # Query the filesystem hosting the checkpoint, mirroring `mark_done_eval`.
    # `get_checkpoint` returns a plain `str` (which has no `exists`), so it cannot be used here.
    return infer_fs(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
def mark_done_eval(checkpoint_path: str, eval_partition: str):
    """Record that ``eval_partition`` finished evaluating the checkpoint at ``checkpoint_path``.

    Touches the ``<eval_partition>_DONE`` marker file on whichever filesystem hosts the
    checkpoint.
    """
    fs = infer_fs(checkpoint_path)
    marker = _eval_done_path(checkpoint_path, eval_partition)
    fs.touch(marker)
def step_from_checkpoint(checkpoint: str) -> int:
    """Return the global step encoded as the last path component of ``checkpoint``."""
    _, leaf = os.path.split(checkpoint)
    return int(leaf)
def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
    """Simplified equivalent of tf.train.checkpoints_iterator.

    Yields each newly observed latest checkpoint in ``save_dir``. Iteration stops once no new
    checkpoint appears within ``timeout`` seconds of the previous one.

    Args:
      save_dir: directory containing step-numbered checkpoints.
      seconds_to_sleep: time between polling calls.
      timeout: how long to wait for a new checkpoint.
    """

    def _wait_for_new(previous: Optional[str] = None):
        # Poll until a checkpoint different from `previous` exists, or the deadline passes.
        deadline = time.time() + timeout
        while True:
            latest = get_checkpoint(save_dir, missing_ok=True)
            if latest and latest != previous:
                logging.info(f"Found latest checkpoint {latest}.")
                return latest
            if time.time() + seconds_to_sleep > deadline:
                logging.info(
                    f"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s."
                )
                return None
            logging.info(f"Waiting for next available checkpoint from {save_dir}.")
            time.sleep(seconds_to_sleep)

    current = _wait_for_new()
    while current:
        yield current
        current = _wait_for_new(current)
def get_checkpoint(
    save_dir: str,
    global_step: Optional[int] = None,
    missing_ok: bool = False,
) -> str:
    """Gets latest checkpoint or checkpoint at specified global_step.

    Args:
      save_dir: directory containing step-numbered checkpoints.
      global_step: Finds this checkpoint if specified; otherwise the latest is returned.
      missing_ok: if True and checkpoints do not exist, returns ``""`` instead of raising.

    Raises:
      Exception: if no checkpoints exist (and ``missing_ok`` is False), or if no checkpoint
        matches ``global_step``.
    """
    available = get_checkpoints(save_dir)
    if not available:
        if missing_ok:
            logging.info(f"No checkpoints found for restoration at {save_dir}.")
            return ""
        raise Exception(f"No checkpoints found at {save_dir}")

    if global_step is None:
        # get_checkpoints sorts ascending by step, so the last entry is the newest.
        return available[-1]

    logging.info(f"Found checkpoints: {available}")
    # Lazily scan in ascending order and stop at the first checkpoint with the requested step.
    matches = (path for path in available if step_from_checkpoint(path) == global_step)
    chosen = next(matches, None)
    if chosen is None:
        raise Exception(f"Desired checkpoint at {global_step} not found in {save_dir}")
    return chosen
def get_checkpoints(save_dir: str) -> List[str]:
    """List fully written checkpoints in ``save_dir``, ordered by step (ascending).

    A checkpoint counts as fully written only when its torchsnapshot metadata file exists.
    Entries are expected to be named by their integer global step.

    Args:
      save_dir: The directory where checkpoints are stored.

    Returns:
      Fully written checkpoint paths, or an empty list if ``save_dir`` does not exist.
      On GCS, paths carry the ``gs://`` prefix.
    """
    fs = infer_fs(save_dir)
    if not fs.exists(save_dir):
        return []

    prefix = GCS_PREFIX if is_gcs_fs(fs) else ""

    def _is_complete(candidate: str) -> bool:
        # The metadata file is written last, so its presence marks a finished checkpoint.
        return fs.exists(f"{candidate}/{torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME}")

    listed = [f"{prefix}{entry}" for entry in fs.ls(save_dir, detail=False)]
    complete = [candidate for candidate in listed if _is_complete(candidate)]
    complete.sort(key=lambda candidate: int(os.path.basename(candidate)))
    return complete
def wait_for_evaluators(
    save_dir: str,
    partition_names: List[str],
    global_step: int,
    timeout: int,
) -> None:
    """Waits until every evaluator has marked the checkpoint for ``global_step`` as done.

    Polls checkpoints in ``save_dir`` (via ``checkpoints_iterator``) until the checkpoint for
    ``global_step`` appears. It then polls every 10 seconds until each partition in
    ``partition_names`` has a done-marker, or until ``timeout`` seconds have elapsed.

    Args:
      save_dir: Directory where checkpoints are saved.
      partition_names: Partition names to check for completion. The caller's list is not
        modified.
      global_step: The global step for which to wait for evaluators.
      timeout: Maximum time in seconds to wait for evaluators to finish.

    Returns:
      None. Progress and results are only logged.
    """
    logging.info("Waiting for all evaluators to finish.")
    start_time = time.time()
    # Work on a copy: popping from the caller's list would silently empty it.
    pending = list(partition_names)
    for checkpoint in checkpoints_iterator(save_dir):
        step = step_from_checkpoint(checkpoint)
        logging.info(f"Considering checkpoint {checkpoint} for global step {global_step}.")
        if step == global_step:
            while pending:
                if is_done_eval(checkpoint, pending[-1]):
                    finished = pending.pop()
                    logging.info(
                        f"Checkpoint {checkpoint} marked as finished eval for partition {finished} at step {step}, still waiting for {pending}."
                    )
                    # Check the next partition right away instead of sleeping after progress.
                    continue
                if time.time() - start_time >= timeout:
                    logging.warning(
                        f"Not all evaluators finished after waiting for {time.time() - start_time}"
                    )
                    return
                time.sleep(10)
            logging.info("All evaluators finished.")
            return
        if time.time() - start_time >= timeout:
            logging.warning(f"Not all evaluators finished after waiting for {time.time() - start_time}")
            return
    # checkpoints_iterator stopped (no new checkpoint within its own timeout) before the
    # checkpoint for `global_step` was seen; previously this exit was silent.
    logging.warning(
        f"Stopped waiting for evaluators: no checkpoint for global step {global_step} appeared in {save_dir}."
    )