Added more explanation and a code refactor using IDE settings

Fudeel 2023-04-04 01:00:43 +02:00
parent 78c3235eee
commit f8500d8727


@ -1,3 +1,10 @@
"""
This module defines a custom learning rate scheduler for PyTorch, driven by learning rate configurations.
The scheduler can compute learning rates according to several strategies:
constant, piecewise constant, linear ramp to constant, and linear ramp to cosine.
The module also provides a utility function that builds an optimizer and a learning rate scheduler from an OptimizerConfig.
"""
from typing import Dict, Tuple
import math
import bisect
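
As a rough illustration of the two ramp strategies named in the module docstring, a hedged sketch follows; it is not the implementation in this file, and the function names, parameters, and warmup/decay shapes are assumptions.

import math

def linear_ramp_to_constant_lr(step: int, max_lr: float, num_ramp_steps: int) -> float:
  # Hypothetical sketch: ramp linearly from 0 up to max_lr, then hold max_lr.
  if step < num_ramp_steps:
    return max_lr * step / num_ramp_steps
  return max_lr

def linear_ramp_to_cosine_lr(step: int, max_lr: float, final_lr: float,
                             num_ramp_steps: int, total_steps: int) -> float:
  # Hypothetical sketch: linear warmup to max_lr, then cosine decay toward final_lr.
  if step < num_ramp_steps:
    return max_lr * step / num_ramp_steps
  progress = (step - num_ramp_steps) / max(1, total_steps - num_ramp_steps)
  return final_lr + 0.5 * (max_lr - final_lr) * (1.0 + math.cos(math.pi * progress))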
@ -14,7 +21,8 @@ from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
  """Compute the learning rate for the current training step (step) from the
  learning rate configuration (lr_config), handling multiple learning rate strategies."""
  if lr_config.constant is not None:
    return lr_config.constant
  elif lr_config.piecewise_constant is not None:
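
A minimal usage sketch for compute_lr, assuming a stand-in object that exposes the fields read above; the real configuration class in tml is not shown in this hunk.

from types import SimpleNamespace

# Stand-in for the learning rate config; only fields visible in the diff are set.
lr_config = SimpleNamespace(constant=1e-3, piecewise_constant=None)

# With the constant strategy the step is ignored and the fixed value is returned.
lr = compute_lr(lr_config, step=100)  # -> 1e-3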
@ -50,6 +58,10 @@ class LRShim(_LRScheduler):
  This adheres to the torch.optim scheduler API and can be plugged anywhere that
  e.g. exponential decay can be used.

  This class inherits from PyTorch's _LRScheduler and acts as a shim that computes
  learning rates according to the specified configurations. It also checks that the
  number of parameter groups in the optimizer matches the length of the learning
  rate dictionary.
  """

  def __init__(
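
A hedged sketch of how LRShim is meant to be plugged in, assuming a constructor roughly of the form LRShim(optimizer, lr_dict) with one learning rate config per parameter group; the actual signature is cut off in this hunk, and my_lr_config is a placeholder.

import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # a single parameter group

# Assumed call shape: one lr_dict entry per optimizer parameter group;
# per the docstring above, LRShim raises if the two counts do not match.
scheduler = LRShim(optimizer, {"ALL_PARAMS": my_lr_config})  # my_lr_config: placeholder learning rate config

# Standard torch.optim scheduler API: step the scheduler after the optimizer.
optimizer.step()
scheduler.step()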
@ -84,6 +96,9 @@ class LRShim(_LRScheduler):
def get_optimizer_class(optimizer_config: OptimizerConfig):
  """Return the PyTorch optimizer class selected by the given optimizer_config.
  Supports the Adam, SGD, and Adagrad optimizers.
  """
  if optimizer_config.adam is not None:
    return torch.optim.Adam
  elif optimizer_config.sgd is not None:
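
A small dispatch sketch for get_optimizer_class; optimizer_config and model are placeholders here, and the Adagrad branch is taken from the docstring rather than from the lines shown above.

optimizer_class = get_optimizer_class(optimizer_config)
# optimizer_config.adam is not None     -> torch.optim.Adam
# optimizer_config.sgd is not None      -> torch.optim.SGD
# optimizer_config.adagrad is not None  -> torch.optim.Adagrad (per the docstring)
optimizer = optimizer_class(model.parameters(), lr=1e-3)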
@ -97,6 +112,10 @@ def build_optimizer(
) -> Tuple[Optimizer, _LRScheduler]:
  """Builds an optimizer and LR scheduler from an OptimizerConfig.

  Note: use this when you want the same optimizer and learning rate schedule for all your parameters.

  Takes a PyTorch model and an OptimizerConfig and returns a tuple of the created
  optimizer and learning rate scheduler: an instance of the selected optimizer class
  is created, and the LRShim scheduler is initialized with the specified learning
  rate configurations.
  """
  optimizer_class = get_optimizer_class(optimizer_config)
  optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
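
An end-to-end usage sketch for build_optimizer. Note that the last line shown above reads optimizer_config.sgd unconditionally, so the sketch assumes an OptimizerConfig whose sgd sub-config is set; the config construction itself is a placeholder and not part of this commit.

import torch
import torch.nn as nn

model = nn.Linear(16, 1)

# optimizer_config: an OptimizerConfig with its learning rate and sgd sections populated
# (placeholder; the concrete config fields are defined elsewhere in tml).
optimizer, lr_scheduler = build_optimizer(model, optimizer_config)

for step in range(100):
  loss = model(torch.randn(8, 16)).pow(2).mean()
  loss.backward()
  optimizer.step()
  lr_scheduler.step()   # advances the configured learning rate schedule
  optimizer.zero_grad()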