added more explanation and refactored code using IDE settings

Fudeel 2023-04-04 01:00:43 +02:00
parent 78c3235eee
commit f8500d8727


@@ -1,3 +1,10 @@
+"""
+This code defines a custom learning rate scheduler for PyTorch by using different learning rate configurations.
+The scheduler can compute learning rates based on various strategies,
+such as constant, piecewise constant, linear ramp to constant, and linear ramp to cosine.
+The code also provides a utility function to build an optimizer and a learning rate scheduler from an OptimizerConfig.
+"""
 from typing import Dict, Tuple
 import math
 import bisect
@@ -14,92 +21,104 @@ from tml.ml_logging.torch_logging import logging
 def compute_lr(lr_config, step):
-  """Compute a learning rate."""
+  """This function computes the learning rate based on the learning rate configuration (lr_config)
+  and the current training step (step). It handles multiple learning rate strategies."""
   if lr_config.constant is not None:
     return lr_config.constant
   elif lr_config.piecewise_constant is not None:
     return lr_config.piecewise_constant.learning_rate_values[
       bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
     ]
   elif lr_config.linear_ramp_to_constant is not None:
     slope = (
       lr_config.linear_ramp_to_constant.learning_rate
       / lr_config.linear_ramp_to_constant.num_ramp_steps
     )
     return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
   elif lr_config.linear_ramp_to_cosine is not None:
     cfg = lr_config.linear_ramp_to_cosine
     if step < cfg.num_ramp_steps:
       slope = cfg.learning_rate / cfg.num_ramp_steps
       return slope * step
     elif step <= cfg.final_num_steps:
       return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
         1.0
         + math.cos(
           math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
         )
       )
     else:
       return cfg.final_learning_rate
   else:
     raise ValueError(f"No option selected in lr_config, passed {lr_config}")
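Side note (not part of the commit): a minimal sketch of how compute_lr behaves for the linear_ramp_to_cosine strategy. The module path in the import is assumed, types.SimpleNamespace stands in for the real LearningRate config models, and the numeric values are made up for illustration.

from types import SimpleNamespace

from tml.optimizers.optimizer import compute_lr  # module path assumed; adjust to where this file lives

# Stand-in for the LearningRate config: only the fields compute_lr reads are mimicked,
# and the unused strategies are explicitly None so the if/elif chain falls through.
lr_config = SimpleNamespace(
  constant=None,
  piecewise_constant=None,
  linear_ramp_to_constant=None,
  linear_ramp_to_cosine=SimpleNamespace(
    learning_rate=1e-3,        # peak LR reached at the end of the ramp
    num_ramp_steps=100,        # length of the linear warmup
    final_learning_rate=1e-5,  # LR reached at final_num_steps
    final_num_steps=1000,      # end of the cosine decay
  ),
)

for step in (0, 50, 100, 550, 1000, 2000):
  print(step, compute_lr(lr_config, step))
# 0 -> 0.0 (start of ramp), 50 -> 5e-4, 100 -> 1e-3 (peak),
# 550 -> ~5.05e-4 (midway through the cosine), 1000 -> 1e-5,
# 2000 -> 1e-5 (past final_num_steps the LR stays at final_learning_rate).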
 class LRShim(_LRScheduler):
   """Shim to get learning rates into a LRScheduler.
 
   This adheres to the torch.optim scheduler API and can be plugged anywhere that
   e.g. exponential decay can be used.
+
+  This class inherits from PyTorch's _LRScheduler and acts as a shim to compute learning rates according to the
+  specified configurations. It also checks if the number of parameter groups in the optimizer matches
+  the length of the learning rate dictionary.
   """
 
   def __init__(
     self,
     optimizer,
     lr_dict: Dict[str, LearningRate],
     last_epoch=-1,
     verbose=False,
   ):
     self.optimizer = optimizer
     self.lr_dict = lr_dict
     self.group_names = list(self.lr_dict.keys())
     num_param_groups = sum(1 for _, _optim in optimizer._optims for _ in _optim.param_groups)
     if num_param_groups != len(lr_dict):
       raise ValueError(
         f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}."
       )
     super().__init__(optimizer, last_epoch, verbose)
 
   def get_lr(self):
     if not self._get_lr_called_within_step:
       logging.warn(
         "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
         UserWarning,
       )
     return self._get_closed_form_lr()
 
   def _get_closed_form_lr(self):
     return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]
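Side note (not part of the commit): the scheduler's closed form is simply compute_lr applied once per parameter group. A small sketch of that mapping, again with SimpleNamespace standing in for the real config models; the group names "dense" and "sparse" and the constant values are hypothetical.

from types import SimpleNamespace

from tml.optimizers.optimizer import compute_lr  # module path assumed

# lr_dict has one entry per optimizer parameter group; LRShim checks that the group
# count matches len(lr_dict) and then evaluates compute_lr for each entry at
# self.last_epoch, which _LRScheduler advances on every scheduler.step().
lr_dict = {
  "dense": SimpleNamespace(constant=1e-3, piecewise_constant=None,
                           linear_ramp_to_constant=None, linear_ramp_to_cosine=None),
  "sparse": SimpleNamespace(constant=1e-2, piecewise_constant=None,
                            linear_ramp_to_constant=None, linear_ramp_to_cosine=None),
}
last_epoch = 42
print([compute_lr(cfg, last_epoch) for cfg in lr_dict.values()])  # -> [0.001, 0.01]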
 def get_optimizer_class(optimizer_config: OptimizerConfig):
+  """
+  This function returns the appropriate PyTorch optimizer class based on the given optimizer_config.
+  It supports Adam, SGD, and Adagrad optimizers.
+  """
   if optimizer_config.adam is not None:
     return torch.optim.Adam
   elif optimizer_config.sgd is not None:
     return torch.optim.SGD
   elif optimizer_config.adagrad is not None:
     return torch.optim.Adagrad
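Side note (not part of the commit): get_optimizer_class only inspects which optimizer sub-config is set, and as written it silently returns None when none of the three is populated. A tiny sketch with a duck-typed stand-in config (field values hypothetical):

import torch
from types import SimpleNamespace

from tml.optimizers.optimizer import get_optimizer_class  # module path assumed

cfg = SimpleNamespace(adam=None, sgd=SimpleNamespace(lr=0.01), adagrad=None)
assert get_optimizer_class(cfg) is torch.optim.SGD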
 def build_optimizer(
   model: torch.nn.Module, optimizer_config: OptimizerConfig
 ) -> Tuple[Optimizer, _LRScheduler]:
   """Builds an optimizer and LR scheduler from an OptimizerConfig.
 
   Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
+
+  This function takes a PyTorch model and an OptimizerConfig object as input and returns a tuple containing
+  the created optimizer and learning rate scheduler. It creates an optimizer instance of the selected optimizer
+  class and initializes the LRShim scheduler with the specified learning rate configurations.
   """
   optimizer_class = get_optimizer_class(optimizer_config)
   optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
   # We're passing everything in as one group here
   scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
   return optimizer, scheduler
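Side note (not part of the commit): a hedged sketch of how the returned pair is meant to be driven, stepping the scheduler once per training step so that last_epoch tracks the global step passed to compute_lr. Two caveats visible in the code above: build_optimizer always expands optimizer_config.sgd.dict() regardless of which optimizer class was selected, and LRShim.__init__ counts parameter groups via optimizer._optims, which a bare torch.optim optimizer does not expose, so this entry point appears to assume a wrapped/keyed optimizer. Treat the loop below as intended usage, not a verified end-to-end run; optimizer_config is assumed to be built elsewhere.

import torch

model = torch.nn.Linear(16, 1)
optimizer, scheduler = build_optimizer(model, optimizer_config)  # optimizer_config built elsewhere

for step in range(1000):
  optimizer.zero_grad()
  loss = model(torch.randn(8, 16)).mean()
  loss.backward()
  optimizer.step()
  scheduler.step()                    # recomputes per-group LRs via compute_lr(lr_config, step)
  last_lrs = scheduler.get_last_lr()  # standard _LRScheduler accessor for the current LRs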