I just put switch case instead of elif

I am a student so i just proposed my idea , might be stupid and less but that what i understood rightnow , also it assume that only one condition can be true
This commit is contained in:
aayankk 2023-04-02 03:47:30 +05:00 committed by GitHub
parent 78c3235eee
commit 5954f73305
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,37 +13,27 @@ from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
"""Compute a learning rate."""
if lr_config.constant is not None:
return lr_config.constant
elif lr_config.piecewise_constant is not None:
return lr_config.piecewise_constant.learning_rate_values[
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
]
elif lr_config.linear_ramp_to_constant is not None:
slope = (
lr_config.linear_ramp_to_constant.learning_rate
/ lr_config.linear_ramp_to_constant.num_ramp_steps
)
return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
elif lr_config.linear_ramp_to_cosine is not None:
cfg = lr_config.linear_ramp_to_cosine
if step < cfg.num_ramp_steps:
slope = cfg.learning_rate / cfg.num_ramp_steps
return slope * step
elif step <= cfg.final_num_steps:
return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
1.0
+ math.cos(
math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
)
)
else:
return cfg.final_learning_rate
else:
raise ValueError(f"No option selected in lr_config, passed {lr_config}")
def get_learning_rate(lr_config: LearningRateConfig, step: int) -> float:
switcher = {
lr_config.constant is not None: lambda: lr_config.constant,
lr_config.piecewise_constant is not None: lambda: lr_config.piecewise_constant.learning_rate_values[
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
],
lr_config.linear_ramp_to_constant is not None: lambda: min(lr_config.linear_ramp_to_constant.learning_rate,
(lr_config.linear_ramp_to_constant.learning_rate
/ lr_config.linear_ramp_to_constant.num_ramp_steps)
* step),
lr_config.linear_ramp_to_cosine is not None: lambda: (lr_config.linear_ramp_to_cosine.final_learning_rate
+ (lr_config.linear_ramp_to_cosine.learning_rate
- lr_config.linear_ramp_to_cosine.final_learning_rate)
* 0.5 * (1.0 + math.cos(
math.pi * (step - lr_config.linear_ramp_to_cosine.num_ramp_steps)
/ (lr_config.linear_ramp_to_cosine.final_num_steps - lr_config.linear_ramp_to_cosine.num_ramp_steps))))
}
func = switcher.get(True, lambda: f"No option selected in lr_config, passed {lr_config}")
return func()
return cfg.final_learning_rate
}
class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler.