mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-24 21:01:11 +01:00
I just put switch case instead of elif
I am a student so i just proposed my idea , might be stupid and less but that what i understood rightnow , also it assume that only one condition can be true
This commit is contained in:
parent
78c3235eee
commit
5954f73305
@ -13,37 +13,27 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||||||
from tml.ml_logging.torch_logging import logging
|
from tml.ml_logging.torch_logging import logging
|
||||||
|
|
||||||
|
|
||||||
def compute_lr(lr_config, step):
|
def get_learning_rate(lr_config: LearningRateConfig, step: int) -> float:
|
||||||
"""Compute a learning rate."""
|
switcher = {
|
||||||
if lr_config.constant is not None:
|
lr_config.constant is not None: lambda: lr_config.constant,
|
||||||
return lr_config.constant
|
lr_config.piecewise_constant is not None: lambda: lr_config.piecewise_constant.learning_rate_values[
|
||||||
elif lr_config.piecewise_constant is not None:
|
|
||||||
return lr_config.piecewise_constant.learning_rate_values[
|
|
||||||
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
|
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
|
||||||
]
|
],
|
||||||
elif lr_config.linear_ramp_to_constant is not None:
|
lr_config.linear_ramp_to_constant is not None: lambda: min(lr_config.linear_ramp_to_constant.learning_rate,
|
||||||
slope = (
|
(lr_config.linear_ramp_to_constant.learning_rate
|
||||||
lr_config.linear_ramp_to_constant.learning_rate
|
/ lr_config.linear_ramp_to_constant.num_ramp_steps)
|
||||||
/ lr_config.linear_ramp_to_constant.num_ramp_steps
|
* step),
|
||||||
)
|
lr_config.linear_ramp_to_cosine is not None: lambda: (lr_config.linear_ramp_to_cosine.final_learning_rate
|
||||||
return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
|
+ (lr_config.linear_ramp_to_cosine.learning_rate
|
||||||
elif lr_config.linear_ramp_to_cosine is not None:
|
- lr_config.linear_ramp_to_cosine.final_learning_rate)
|
||||||
cfg = lr_config.linear_ramp_to_cosine
|
* 0.5 * (1.0 + math.cos(
|
||||||
if step < cfg.num_ramp_steps:
|
math.pi * (step - lr_config.linear_ramp_to_cosine.num_ramp_steps)
|
||||||
slope = cfg.learning_rate / cfg.num_ramp_steps
|
/ (lr_config.linear_ramp_to_cosine.final_num_steps - lr_config.linear_ramp_to_cosine.num_ramp_steps))))
|
||||||
return slope * step
|
}
|
||||||
elif step <= cfg.final_num_steps:
|
func = switcher.get(True, lambda: f"No option selected in lr_config, passed {lr_config}")
|
||||||
return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
|
return func()
|
||||||
1.0
|
|
||||||
+ math.cos(
|
|
||||||
math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return cfg.final_learning_rate
|
return cfg.final_learning_rate
|
||||||
else:
|
}
|
||||||
raise ValueError(f"No option selected in lr_config, passed {lr_config}")
|
|
||||||
|
|
||||||
|
|
||||||
class LRShim(_LRScheduler):
|
class LRShim(_LRScheduler):
|
||||||
"""Shim to get learning rates into a LRScheduler.
|
"""Shim to get learning rates into a LRScheduler.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user