From 103e56f17c0a765d7a0a6de7cc2075e93605832c Mon Sep 17 00:00:00 2001 From: Godwinh19 Date: Sat, 1 Apr 2023 02:04:02 +0100 Subject: [PATCH] add tests; update roundrobin func --- metrics/test_aggregation.py | 42 +++++++++++++++++++++++++++++++++++++ reader/utils.py | 3 +-- 2 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 metrics/test_aggregation.py diff --git a/metrics/test_aggregation.py b/metrics/test_aggregation.py new file mode 100644 index 0000000..c35c989 --- /dev/null +++ b/metrics/test_aggregation.py @@ -0,0 +1,42 @@ +import torch +import unittest +from aggregation import StableMean + + +class TestStableMean(unittest.TestCase): + + def setUp(self): + self.metric = StableMean() + + def test_compute_empty(self): + result = self.metric.compute() + self.assertEqual(result, torch.tensor(0.0)) + + def test_compute_single_value(self): + self.metric.update(torch.tensor(1.0)) + result = self.metric.compute() + self.assertEqual(result, torch.tensor(1.0)) + + def test_compute_weighted_single_value(self): + self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0)) + result = self.metric.compute() + self.assertEqual(result, torch.tensor(1.0)) + + def test_compute_multiple_values(self): + self.metric.update(torch.tensor(1.0)) + self.metric.update(torch.tensor(2.0)) + self.metric.update(torch.tensor(3.0)) + result = self.metric.compute() + self.assertEqual(result, torch.tensor(2.0)) + + def test_compute_weighted_multiple_values(self): + self.metric.update(torch.tensor(1.0), weight=torch.tensor(1.0)) + self.metric.update(torch.tensor(2.0), weight=torch.tensor(2.0)) + self.metric.update(torch.tensor(3.0), weight=torch.tensor(3.0)) + result = self.metric.compute() + print(f"get= {result.item()} but expected= 2.1666666667") + self.assertAlmostEqual(result.item(), 2.1666666667, places=0) + + +if '__name__' == '__main__': + unittest.main() diff --git a/reader/utils.py b/reader/utils.py index fc0e34c..0de5868 100644 --- a/reader/utils.py +++ b/reader/utils.py @@ -21,8 +21,7 @@ def roundrobin(*iterables): while num_active: try: for _next in nexts: - result = _next() - yield result + yield _next() except StopIteration: # Remove the iterator we just exhausted from the cycle. num_active -= 1