diff --git a/core/custom_training_loop.py b/core/custom_training_loop.py index 0241145..57aba03 100644 --- a/core/custom_training_loop.py +++ b/core/custom_training_loop.py @@ -8,14 +8,22 @@ Features (go/100_enablements): - go/dataset-service 0-copy integration """ -import datetime import os -from typing import Callable, Dict, Iterable, List, Mapping, Optional +import datetime +from typing import ( + Callable, + Dict, + Iterable, + List, + Mapping, + Optional +) + from tml.common import log_weights -import tml.common.checkpointing.snapshot as snapshot_lib from tml.core.losses import get_global_loss_detached +import tml.common.checkpointing.snapshot as snapshot_lib from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] from tml.core.train_pipeline import TrainPipelineSparseDist