diff --git a/metrics/test_aggregation.py b/metrics/test_aggregation.py new file mode 100644 index 0000000..c35c989 --- /dev/null +++ b/metrics/test_aggregation.py @@ -0,0 +1,42 @@ +import torch +import unittest +from aggregation import StableMean + + +class TestStableMean(unittest.TestCase): + + def setUp(self): + self.metric = StableMean() + + def test_compute_empty(self): + result = self.metric.compute() + self.assertEqual(result, torch.tensor(0.0)) + + def test_compute_single_value(self): + self.metric.update(torch.tensor(1.0)) + result = self.metric.compute() + self.assertEqual(result, torch.tensor(1.0)) + + def test_compute_weighted_single_value(self): + self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0)) + result = self.metric.compute() + self.assertEqual(result, torch.tensor(1.0)) + + def test_compute_multiple_values(self): + self.metric.update(torch.tensor(1.0)) + self.metric.update(torch.tensor(2.0)) + self.metric.update(torch.tensor(3.0)) + result = self.metric.compute() + self.assertEqual(result, torch.tensor(2.0)) + + def test_compute_weighted_multiple_values(self): + self.metric.update(torch.tensor(1.0), weight=torch.tensor(1.0)) + self.metric.update(torch.tensor(2.0), weight=torch.tensor(2.0)) + self.metric.update(torch.tensor(3.0), weight=torch.tensor(3.0)) + result = self.metric.compute() + print(f"get= {result.item()} but expected= 2.1666666667") + self.assertAlmostEqual(result.item(), 2.1666666667, places=0) + + +if '__name__' == '__main__': + unittest.main() diff --git a/reader/utils.py b/reader/utils.py index fc0e34c..0de5868 100644 --- a/reader/utils.py +++ b/reader/utils.py @@ -21,8 +21,7 @@ def roundrobin(*iterables): while num_active: try: for _next in nexts: - result = _next() - yield result + yield _next() except StopIteration: # Remove the iterator we just exhausted from the cycle. num_active -= 1