mirror of https://github.com/twitter/the-algorithm-ml.git (synced 2024-12-23 06:41:49 +01:00)

added some more extra explanation and code refactor using IDE settings

This commit is contained in:
parent 78c3235eee
commit f8500d8727
@@ -1,3 +1,10 @@
"""
This module defines a custom learning rate scheduler for PyTorch, driven by configurable learning rate settings.
The scheduler can compute learning rates according to several strategies,
such as constant, piecewise constant, linear ramp to constant, and linear ramp to cosine.
The module also provides a utility function to build an optimizer and a learning rate scheduler from an OptimizerConfig.
"""

from typing import Dict, Tuple
import math
import bisect
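For illustration, here is a minimal, self-contained sketch of the "linear ramp to cosine" strategy named in the module docstring. It is not the implementation in this file; the function name and parameters (warmup_steps, total_steps, base_lr, final_lr) are placeholders rather than the configuration fields this module uses.

import math

def example_ramp_to_cosine_lr(step, warmup_steps=1000, total_steps=10000,
                              base_lr=0.1, final_lr=0.001):
  """Illustrative only: linear warmup from 0 to base_lr, then cosine decay to final_lr."""
  if step < warmup_steps:
    # Linear ramp from 0 up to base_lr over the warmup period.
    return base_lr * step / warmup_steps
  # Cosine decay from base_lr down to final_lr over the remaining steps.
  progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
  return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * progress))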
@@ -14,7 +21,8 @@ from tml.ml_logging.torch_logging import logging


def compute_lr(lr_config, step):
  """Compute a learning rate based on the learning rate configuration (lr_config)
  and the current training step (step). It handles multiple learning rate strategies."""
  if lr_config.constant is not None:
    return lr_config.constant
  elif lr_config.piecewise_constant is not None:
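As a small aside on the bisect import above, a piecewise constant schedule can be evaluated as a sorted-boundary lookup. This is a hedged sketch, not this module's code; the function name and its boundary and value arguments are illustrative.

import bisect

def example_piecewise_constant_lr(step, boundaries, values):
  """Illustrative only: len(values) must be len(boundaries) + 1."""
  # bisect_right counts how many boundaries the current step has passed,
  # which indexes the learning rate for the interval containing `step`.
  return values[bisect.bisect_right(boundaries, step)]

# Example: 0.1 until step 1000, 0.01 until step 5000, 0.001 afterwards.
# example_piecewise_constant_lr(1500, [1000, 5000], [0.1, 0.01, 0.001]) returns 0.01.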
@@ -50,6 +58,10 @@ class LRShim(_LRScheduler):

  This adheres to the torch.optim scheduler API and can be plugged anywhere that
  e.g. exponential decay can be used.

  This class inherits from PyTorch's _LRScheduler and acts as a shim to compute learning rates according to the
  specified configurations. It also checks if the number of parameter groups in the optimizer matches
  the length of the learning rate dictionary.
  """

  def __init__(
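To make the "adheres to the torch.optim scheduler API" point concrete, here is a minimal sketch of that API: subclass _LRScheduler and override get_lr(). This is not LRShim itself; the class below is an illustrative stand-in with an assumed constructor.

import torch
from torch.optim.lr_scheduler import _LRScheduler

class ConstantShim(_LRScheduler):
  """Illustrative only: returns a fixed lr for every parameter group."""

  def __init__(self, optimizer, lr=0.01, last_epoch=-1):
    # Set attributes before super().__init__, which triggers an initial get_lr() call.
    self.lr = lr
    super().__init__(optimizer, last_epoch)

  def get_lr(self):
    # One entry per optimizer parameter group, as the scheduler API expects.
    return [self.lr for _ in self.optimizer.param_groups]

# Usage mirrors any built-in scheduler such as ExponentialLR:
#   model = torch.nn.Linear(4, 2)
#   opt = torch.optim.SGD(model.parameters(), lr=0.01)
#   sched = ConstantShim(opt, lr=0.01)
#   ... forward/backward ...; opt.step(); sched.step()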
@@ -84,6 +96,9 @@ class LRShim(_LRScheduler):


def get_optimizer_class(optimizer_config: OptimizerConfig):
  """
  This function returns the appropriate PyTorch optimizer class based on the given optimizer_config. It supports Adam, SGD, and Adagrad optimizers.
  """
  if optimizer_config.adam is not None:
    return torch.optim.Adam
  elif optimizer_config.sgd is not None:
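A hedged sketch of the same dispatch idea with a self-contained stand-in config; ExampleOptimizerConfig and the helper name below are illustrative, not the real tml OptimizerConfig or this file's code.

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ExampleOptimizerConfig:
  """Illustrative stand-in: exactly one optimizer field is expected to be set."""
  adam: Optional[dict] = None
  sgd: Optional[dict] = None
  adagrad: Optional[dict] = None

def example_get_optimizer_class(config: ExampleOptimizerConfig):
  # Pick the optimizer class that matches the populated config field.
  if config.adam is not None:
    return torch.optim.Adam
  elif config.sgd is not None:
    return torch.optim.SGD
  elif config.adagrad is not None:
    return torch.optim.Adagrad
  raise ValueError("No optimizer selected in config.")

# example_get_optimizer_class(ExampleOptimizerConfig(sgd={"lr": 0.01})) returns torch.optim.SGD.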
@@ -97,6 +112,10 @@ def build_optimizer(
) -> Tuple[Optimizer, _LRScheduler]:
  """Builds an optimizer and LR scheduler from an OptimizerConfig.
  Note: use this when you want the same optimizer and learning rate schedule for all your parameters.

  This function takes a PyTorch model and an OptimizerConfig object as input and returns a tuple containing the created
  optimizer and learning rate scheduler. It creates an optimizer instance of the selected optimizer class and
  initializes the LRShim scheduler with the specified learning rate configurations.
  """
  optimizer_class = get_optimizer_class(optimizer_config)
  optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
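# A hedged sketch, not part of the diff: as written above, the constructor keyword
# arguments are always read from optimizer_config.sgd, which only lines up when SGD is
# the configured optimizer. One way to pick the kwargs for whichever optimizer is set
# is sketched below; the helper name is illustrative, and the config fields (adam, sgd,
# adagrad) mirror those referenced elsewhere in this file.
def example_get_optimizer_kwargs(optimizer_config):
  for candidate in (optimizer_config.adam, optimizer_config.sgd, optimizer_config.adagrad):
    if candidate is not None:
      # The config objects expose .dict(), as used in the line above.
      return candidate.dict()
  raise ValueError("No optimizer selected in config.")

# optimizer = optimizer_class(model.parameters(), **example_get_optimizer_kwargs(optimizer_config))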