diff --git a/optimizers/optimizer.py b/optimizers/optimizer.py
index 4517368..6475d18 100644
--- a/optimizers/optimizer.py
+++ b/optimizers/optimizer.py
@@ -13,37 +13,42 @@
 from torch.optim.lr_scheduler import _LRScheduler
 
 from tml.ml_logging.torch_logging import logging
 
-def compute_lr(lr_config, step):
-  """Compute a learning rate."""
-  if lr_config.constant is not None:
-    return lr_config.constant
-  elif lr_config.piecewise_constant is not None:
-    return lr_config.piecewise_constant.learning_rate_values[
-      bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
-    ]
-  elif lr_config.linear_ramp_to_constant is not None:
-    slope = (
-      lr_config.linear_ramp_to_constant.learning_rate
-      / lr_config.linear_ramp_to_constant.num_ramp_steps
-    )
-    return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
-  elif lr_config.linear_ramp_to_cosine is not None:
-    cfg = lr_config.linear_ramp_to_cosine
-    if step < cfg.num_ramp_steps:
-      slope = cfg.learning_rate / cfg.num_ramp_steps
-      return slope * step
-    elif step <= cfg.final_num_steps:
-      return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
-        1.0
-        + math.cos(
-          math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
-        )
-      )
-    else:
-      return cfg.final_learning_rate
-  else:
-    raise ValueError(f"No option selected in lr_config, passed {lr_config}")
-
+def get_learning_rate(lr_config: LearningRateConfig, step: int) -> float:
+  """Compute a learning rate by dispatching on the selected config option."""
+
+  def _linear_ramp_to_cosine() -> float:
+    cfg = lr_config.linear_ramp_to_cosine
+    if step < cfg.num_ramp_steps:
+      return cfg.learning_rate / cfg.num_ramp_steps * step
+    if step <= cfg.final_num_steps:
+      return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
+        1.0
+        + math.cos(
+          math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
+        )
+      )
+    return cfg.final_learning_rate
+
+  # Exactly one option is expected to be set in lr_config, so at most one key is True.
+  # The lambdas keep every branch lazy; only the selected branch is evaluated.
+  switcher = {
+    lr_config.constant is not None: lambda: lr_config.constant,
+    lr_config.piecewise_constant is not None: lambda: lr_config.piecewise_constant.learning_rate_values[
+      bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
+    ],
+    lr_config.linear_ramp_to_constant is not None: lambda: min(
+      lr_config.linear_ramp_to_constant.learning_rate,
+      lr_config.linear_ramp_to_constant.learning_rate
+      / lr_config.linear_ramp_to_constant.num_ramp_steps
+      * step,
+    ),
+    lr_config.linear_ramp_to_cosine is not None: _linear_ramp_to_cosine,
+  }
+  func = switcher.get(True)
+  if func is None:
+    raise ValueError(f"No option selected in lr_config, passed {lr_config}")
+  return func()
+
 class LRShim(_LRScheduler):
   """Shim to get learning rates into a LRScheduler.
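
For reference, a minimal sketch of how the rewritten function behaves across the three phases of the ramp-to-cosine schedule. The stand-in dataclasses below are hypothetical (the real LearningRateConfig lives in the repo's config module and is not shown in this hunk); only the attribute names that get_learning_rate reads are reproduced, and the function from the diff above is assumed to be in scope.

import bisect
import math
from dataclasses import dataclass
from typing import Optional

@dataclass
class _RampToCosine:
  # Hypothetical stand-in mirroring the fields the diff reads.
  learning_rate: float
  final_learning_rate: float
  num_ramp_steps: int
  final_num_steps: int

@dataclass
class _LRConfig:
  # Hypothetical stand-in for LearningRateConfig; one option set at a time.
  constant: Optional[float] = None
  piecewise_constant: Optional[object] = None
  linear_ramp_to_constant: Optional[object] = None
  linear_ramp_to_cosine: Optional[_RampToCosine] = None

cfg = _LRConfig(
  linear_ramp_to_cosine=_RampToCosine(
    learning_rate=0.01, final_learning_rate=0.001, num_ramp_steps=100, final_num_steps=1000
  )
)

# step 50: halfway up the linear ramp -> 0.005
# step 100: ramp complete, cosine peak -> 0.01
# step 550: cosine midpoint -> 0.001 + 0.009 * 0.5 = 0.0055
# step 1000 and beyond: decay finished, rate stays at the floor -> 0.001
for step in (50, 100, 550, 1000, 5000):
  print(step, get_learning_rate(cfg, step))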