mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 06:41:49 +01:00
added some more extra explanation and code refactor using IDE settings
This commit is contained in:
parent
78c3235eee
commit
f8500d8727
@ -1,3 +1,10 @@
|
|||||||
|
"""
|
||||||
|
This code defines a custom learning rate scheduler for PyTorch by using different learning rate configurations.
|
||||||
|
The scheduler can compute learning rates based on various strategies,
|
||||||
|
such as constant, piecewise constant, linear ramp to constant, and linear ramp to cosine.
|
||||||
|
The code also provides a utility function to build an optimizer and a learning rate scheduler from an OptimizerConfig.
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
import math
|
import math
|
||||||
import bisect
|
import bisect
|
||||||
@ -14,7 +21,8 @@ from tml.ml_logging.torch_logging import logging
|
|||||||
|
|
||||||
|
|
||||||
def compute_lr(lr_config, step):
|
def compute_lr(lr_config, step):
|
||||||
"""Compute a learning rate."""
|
"""This function computes the learning rate based on the learning rate configuration (lr_config)
|
||||||
|
and the current training step (step). It handles multiple learning rate strategies."""
|
||||||
if lr_config.constant is not None:
|
if lr_config.constant is not None:
|
||||||
return lr_config.constant
|
return lr_config.constant
|
||||||
elif lr_config.piecewise_constant is not None:
|
elif lr_config.piecewise_constant is not None:
|
||||||
@ -50,6 +58,10 @@ class LRShim(_LRScheduler):
|
|||||||
|
|
||||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
||||||
e.g. exponential decay can be used.
|
e.g. exponential decay can be used.
|
||||||
|
|
||||||
|
This class inherits from PyTorch's _LRScheduler and acts as a shim to compute learning rates according to the
|
||||||
|
specified configurations. It also checks if the number of parameter groups in the optimizer matches
|
||||||
|
the length of the learning rate dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -84,6 +96,9 @@ class LRShim(_LRScheduler):
|
|||||||
|
|
||||||
|
|
||||||
def get_optimizer_class(optimizer_config: OptimizerConfig):
|
def get_optimizer_class(optimizer_config: OptimizerConfig):
|
||||||
|
"""
|
||||||
|
This function returns the appropriate PyTorch optimizer class based on the given optimizer_config. It supports Adam, SGD, and Adagrad optimizers.
|
||||||
|
"""
|
||||||
if optimizer_config.adam is not None:
|
if optimizer_config.adam is not None:
|
||||||
return torch.optim.Adam
|
return torch.optim.Adam
|
||||||
elif optimizer_config.sgd is not None:
|
elif optimizer_config.sgd is not None:
|
||||||
@ -97,6 +112,10 @@ def build_optimizer(
|
|||||||
) -> Tuple[Optimizer, _LRScheduler]:
|
) -> Tuple[Optimizer, _LRScheduler]:
|
||||||
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
|
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
|
||||||
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
|
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
|
||||||
|
|
||||||
|
This function takes a PyTorch model and an OptimizerConfig object as input and returns a tuple containing the created
|
||||||
|
optimizer and learning rate scheduler. It creates an optimizer instance of the selected optimizer class and
|
||||||
|
initializes the LRShim scheduler with the specified learning rate configurations.
|
||||||
"""
|
"""
|
||||||
optimizer_class = get_optimizer_class(optimizer_config)
|
optimizer_class = get_optimizer_class(optimizer_config)
|
||||||
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
|
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
|
||||||
|
Loading…
Reference in New Issue
Block a user