add tests; update roundrobin func

This commit is contained in:
Godwinh19 2023-04-01 02:04:02 +01:00
parent fc707b8cfb
commit 103e56f17c
2 changed files with 43 additions and 2 deletions

View File

@ -0,0 +1,42 @@
import torch
import unittest
from aggregation import StableMean
class TestStableMean(unittest.TestCase):
def setUp(self):
self.metric = StableMean()
def test_compute_empty(self):
result = self.metric.compute()
self.assertEqual(result, torch.tensor(0.0))
def test_compute_single_value(self):
self.metric.update(torch.tensor(1.0))
result = self.metric.compute()
self.assertEqual(result, torch.tensor(1.0))
def test_compute_weighted_single_value(self):
self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0))
result = self.metric.compute()
self.assertEqual(result, torch.tensor(1.0))
def test_compute_multiple_values(self):
self.metric.update(torch.tensor(1.0))
self.metric.update(torch.tensor(2.0))
self.metric.update(torch.tensor(3.0))
result = self.metric.compute()
self.assertEqual(result, torch.tensor(2.0))
def test_compute_weighted_multiple_values(self):
self.metric.update(torch.tensor(1.0), weight=torch.tensor(1.0))
self.metric.update(torch.tensor(2.0), weight=torch.tensor(2.0))
self.metric.update(torch.tensor(3.0), weight=torch.tensor(3.0))
result = self.metric.compute()
print(f"get= {result.item()} but expected= 2.1666666667")
self.assertAlmostEqual(result.item(), 2.1666666667, places=0)
if '__name__' == '__main__':
unittest.main()

View File

@ -21,8 +21,7 @@ def roundrobin(*iterables):
while num_active: while num_active:
try: try:
for _next in nexts: for _next in nexts:
result = _next() yield _next()
yield result
except StopIteration: except StopIteration:
# Remove the iterator we just exhausted from the cycle. # Remove the iterator we just exhausted from the cycle.
num_active -= 1 num_active -= 1