From fc707b8cfbaa8ac007efac5ddb53d65602d88510 Mon Sep 17 00:00:00 2001
From: Godwinh19
Date: Sat, 1 Apr 2023 01:48:15 +0100
Subject: [PATCH] add tests and remove redundant imports

---
 optimizers/test_optimizer.py | 44 ++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 optimizers/test_optimizer.py

diff --git a/optimizers/test_optimizer.py b/optimizers/test_optimizer.py
new file mode 100644
index 0000000..5ef1107
--- /dev/null
+++ b/optimizers/test_optimizer.py
@@ -0,0 +1,44 @@
+import unittest
+import torch
+from tml.optimizers.config import LearningRate, OptimizerConfig
+from tml.optimizers.optimizer import build_optimizer, compute_lr, LRShim
+
+
+class TestComputeLR(unittest.TestCase):
+    def test_constant_lr(self):
+        lr_config = LearningRate(constant=0.1)
+        lr = compute_lr(lr_config, step=0)
+        self.assertAlmostEqual(lr, 0.1)
+
+    def test_piecewise_constant_lr(self):
+        lr_config = LearningRate(piecewise_constant={"learning_rate_boundaries": [10, 20], "learning_rate_values": [0.1, 0.01, 0.001]})
+        lr = compute_lr(lr_config, step=5)
+        self.assertAlmostEqual(lr, 0.1)
+        lr = compute_lr(lr_config, step=15)
+        self.assertAlmostEqual(lr, 0.01)
+        lr = compute_lr(lr_config, step=25)
+        self.assertAlmostEqual(lr, 0.001)
+
+
+class TestLRShim(unittest.TestCase):
+    def setUp(self):
+        self.optimizer = torch.optim.SGD([torch.nn.Parameter(torch.randn(10, 10))], lr=0.1)
+        self.lr_dict = {"ALL_PARAMS": LearningRate(constant=0.1)}
+
+    def test_get_lr(self):
+        lr_scheduler = LRShim(self.optimizer, self.lr_dict)
+        lr = lr_scheduler.get_lr()
+        self.assertAlmostEqual(lr[0], 0.1)
+
+
+class TestBuildOptimizer(unittest.TestCase):
+    def test_build_optimizer(self):
+        model = torch.nn.Linear(10, 1)
+        optimizer_config = OptimizerConfig(sgd={"lr": 0.1})
+        optimizer, scheduler = build_optimizer(model, optimizer_config)
+        self.assertIsInstance(optimizer, torch.optim.SGD)
+        self.assertIsInstance(scheduler, LRShim)
+
+
+if __name__ == "__main__":
+    unittest.main()
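
Note on the piecewise-constant expectations in test_piecewise_constant_lr: the assertions assume the schedule returns the value of the interval the step falls into (step < 10 -> 0.1, 10 <= step < 20 -> 0.01, step >= 20 -> 0.001). A minimal reference sketch of that lookup, independent of the actual tml.optimizers implementation and using hypothetical names:

    def piecewise_constant_lr(step, boundaries, values):
        # values has one more entry than boundaries; return the value of the
        # first interval whose upper boundary is still above the step.
        for boundary, value in zip(boundaries, values):
            if step < boundary:
                return value
        return values[-1]

    assert piecewise_constant_lr(5, [10, 20], [0.1, 0.01, 0.001]) == 0.1
    assert piecewise_constant_lr(15, [10, 20], [0.1, 0.01, 0.001]) == 0.01
    assert piecewise_constant_lr(25, [10, 20], [0.1, 0.01, 0.001]) == 0.001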