the-algorithm-ml/core/debug_training_loop.py
"""This is a very limited feature training loop useful for interactive debugging.
It is not intended for actual model tranining (it is not fast, doesn't compile the model).
It does not support checkpointing.
suggested use:
from tml.core import debug_training_loop
debug_training_loop.train(...)
"""
from typing import Iterable, Optional

import torch
from torch.optim.lr_scheduler import _LRScheduler

from tml.ml_logging.torch_logging import logging

def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_steps: int,
    dataset: Iterable,
    scheduler: Optional[_LRScheduler] = None,
    # Accept any further arguments (to stay call-compatible with the real
    # training loop) but ignore them.
    *args,
    **kwargs,
) -> None:
"""
Debugging training loop. Do not use for actual model training.
Args:
model (torch.nn.Module): The neural network model.
optimizer (torch.optim.Optimizer): The optimizer for model optimization.
train_steps (int): The number of training steps to perform.
dataset (Iterable): Data iterator for training data.
scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
*args: Additional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).
"""
    logging.warning("Running debug training loop, don't use for model training.")

    data_iter = iter(dataset)
    # range(0, train_steps + 1) would run one extra step; use range(train_steps).
    for step in range(train_steps):
        x = next(data_iter)
        optimizer.zero_grad()
        # Call the module (rather than .forward directly) so registered hooks run.
        # The model's forward pass is expected to return a (loss, outputs) tuple.
        loss, outputs = model(x)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
        logging.info(f"Step {step} completed. Loss = {loss}")
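

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of driving this loop. The loop assumes the model's forward pass
# returns a (loss, outputs) tuple, so the hypothetical `ToyModel` below
# follows that contract; all names here are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyModel(torch.nn.Module):
        """Tiny linear regression model whose forward returns (loss, outputs)."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, batch):
            features, targets = batch
            outputs = self.linear(features)
            loss = torch.nn.functional.mse_loss(outputs, targets)
            return loss, outputs

    def synthetic_batches():
        # Infinite stream of random (features, targets) batches.
        while True:
            yield torch.randn(8, 4), torch.randn(8, 1)

    toy_model = ToyModel()
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01)
    train(toy_model, toy_optimizer, train_steps=5, dataset=synthetic_batches())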