mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-24 12:51:10 +01:00
I just put switch case instead of elif
I am a student so i just proposed my idea , might be stupid and less but that what i understood rightnow , also it assume that only one condition can be true
This commit is contained in:
parent
78c3235eee
commit
5954f73305
@ -13,37 +13,27 @@ from torch.optim.lr_scheduler import _LRScheduler
|
||||
from tml.ml_logging.torch_logging import logging
|
||||
|
||||
|
||||
def compute_lr(lr_config, step):
|
||||
"""Compute a learning rate."""
|
||||
if lr_config.constant is not None:
|
||||
return lr_config.constant
|
||||
elif lr_config.piecewise_constant is not None:
|
||||
return lr_config.piecewise_constant.learning_rate_values[
|
||||
def get_learning_rate(lr_config: LearningRateConfig, step: int) -> float:
|
||||
switcher = {
|
||||
lr_config.constant is not None: lambda: lr_config.constant,
|
||||
lr_config.piecewise_constant is not None: lambda: lr_config.piecewise_constant.learning_rate_values[
|
||||
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
|
||||
]
|
||||
elif lr_config.linear_ramp_to_constant is not None:
|
||||
slope = (
|
||||
lr_config.linear_ramp_to_constant.learning_rate
|
||||
/ lr_config.linear_ramp_to_constant.num_ramp_steps
|
||||
)
|
||||
return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
|
||||
elif lr_config.linear_ramp_to_cosine is not None:
|
||||
cfg = lr_config.linear_ramp_to_cosine
|
||||
if step < cfg.num_ramp_steps:
|
||||
slope = cfg.learning_rate / cfg.num_ramp_steps
|
||||
return slope * step
|
||||
elif step <= cfg.final_num_steps:
|
||||
return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
|
||||
1.0
|
||||
+ math.cos(
|
||||
math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return cfg.final_learning_rate
|
||||
else:
|
||||
raise ValueError(f"No option selected in lr_config, passed {lr_config}")
|
||||
|
||||
],
|
||||
lr_config.linear_ramp_to_constant is not None: lambda: min(lr_config.linear_ramp_to_constant.learning_rate,
|
||||
(lr_config.linear_ramp_to_constant.learning_rate
|
||||
/ lr_config.linear_ramp_to_constant.num_ramp_steps)
|
||||
* step),
|
||||
lr_config.linear_ramp_to_cosine is not None: lambda: (lr_config.linear_ramp_to_cosine.final_learning_rate
|
||||
+ (lr_config.linear_ramp_to_cosine.learning_rate
|
||||
- lr_config.linear_ramp_to_cosine.final_learning_rate)
|
||||
* 0.5 * (1.0 + math.cos(
|
||||
math.pi * (step - lr_config.linear_ramp_to_cosine.num_ramp_steps)
|
||||
/ (lr_config.linear_ramp_to_cosine.final_num_steps - lr_config.linear_ramp_to_cosine.num_ramp_steps))))
|
||||
}
|
||||
func = switcher.get(True, lambda: f"No option selected in lr_config, passed {lr_config}")
|
||||
return func()
|
||||
return cfg.final_learning_rate
|
||||
}
|
||||
|
||||
class LRShim(_LRScheduler):
|
||||
"""Shim to get learning rates into a LRScheduler.
|
||||
|
Loading…
x
Reference in New Issue
Block a user