diff --git a/optimizers/config.py b/optimizers/config.py index f5011f0..a1df58e 100644 --- a/optimizers/config.py +++ b/optimizers/config.py @@ -72,11 +72,7 @@ class OptimizerConfig(base_config.BaseConfig): def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): - if optimizer_config.adam is not None: - return optimizer_config.adam - elif optimizer_config.sgd is not None: - return optimizer_config.sgd - elif optimizer_config.adagrad is not None: - return optimizer_config.adagrad - else: - raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}") + for optz in (optimizer_config.adam, optimizer_config.sgd, optimizer_config.adagrad): + if optz is not None: + return optz + raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")