I just put switch case instead of elif

I am a student so i just proposed my idea , might be stupid and less but that what i understood rightnow , also it assume that only one condition can be true
This commit is contained in:
aayankk 2023-04-02 03:47:30 +05:00 committed by GitHub
parent 78c3235eee
commit 5954f73305
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,37 +13,27 @@ from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step): def get_learning_rate(lr_config: LearningRateConfig, step: int) -> float:
"""Compute a learning rate.""" switcher = {
if lr_config.constant is not None: lr_config.constant is not None: lambda: lr_config.constant,
return lr_config.constant lr_config.piecewise_constant is not None: lambda: lr_config.piecewise_constant.learning_rate_values[
elif lr_config.piecewise_constant is not None: bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
return lr_config.piecewise_constant.learning_rate_values[ ],
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step) lr_config.linear_ramp_to_constant is not None: lambda: min(lr_config.linear_ramp_to_constant.learning_rate,
] (lr_config.linear_ramp_to_constant.learning_rate
elif lr_config.linear_ramp_to_constant is not None: / lr_config.linear_ramp_to_constant.num_ramp_steps)
slope = ( * step),
lr_config.linear_ramp_to_constant.learning_rate lr_config.linear_ramp_to_cosine is not None: lambda: (lr_config.linear_ramp_to_cosine.final_learning_rate
/ lr_config.linear_ramp_to_constant.num_ramp_steps + (lr_config.linear_ramp_to_cosine.learning_rate
) - lr_config.linear_ramp_to_cosine.final_learning_rate)
return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step) * 0.5 * (1.0 + math.cos(
elif lr_config.linear_ramp_to_cosine is not None: math.pi * (step - lr_config.linear_ramp_to_cosine.num_ramp_steps)
cfg = lr_config.linear_ramp_to_cosine / (lr_config.linear_ramp_to_cosine.final_num_steps - lr_config.linear_ramp_to_cosine.num_ramp_steps))))
if step < cfg.num_ramp_steps: }
slope = cfg.learning_rate / cfg.num_ramp_steps func = switcher.get(True, lambda: f"No option selected in lr_config, passed {lr_config}")
return slope * step return func()
elif step <= cfg.final_num_steps: return cfg.final_learning_rate
return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * ( }
1.0
+ math.cos(
math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
)
)
else:
return cfg.final_learning_rate
else:
raise ValueError(f"No option selected in lr_config, passed {lr_config}")
class LRShim(_LRScheduler): class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler. """Shim to get learning rates into a LRScheduler.