mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-09 14:39:20 +01:00
add tests; update roundrobin func
This commit is contained in:
parent
fc707b8cfb
commit
103e56f17c
42
metrics/test_aggregation.py
Normal file
42
metrics/test_aggregation.py
Normal file
@ -0,0 +1,42 @@
|
||||
import torch
|
||||
import unittest
|
||||
from aggregation import StableMean
|
||||
|
||||
|
||||
class TestStableMean(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.metric = StableMean()
|
||||
|
||||
def test_compute_empty(self):
|
||||
result = self.metric.compute()
|
||||
self.assertEqual(result, torch.tensor(0.0))
|
||||
|
||||
def test_compute_single_value(self):
|
||||
self.metric.update(torch.tensor(1.0))
|
||||
result = self.metric.compute()
|
||||
self.assertEqual(result, torch.tensor(1.0))
|
||||
|
||||
def test_compute_weighted_single_value(self):
|
||||
self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0))
|
||||
result = self.metric.compute()
|
||||
self.assertEqual(result, torch.tensor(1.0))
|
||||
|
||||
def test_compute_multiple_values(self):
|
||||
self.metric.update(torch.tensor(1.0))
|
||||
self.metric.update(torch.tensor(2.0))
|
||||
self.metric.update(torch.tensor(3.0))
|
||||
result = self.metric.compute()
|
||||
self.assertEqual(result, torch.tensor(2.0))
|
||||
|
||||
def test_compute_weighted_multiple_values(self):
|
||||
self.metric.update(torch.tensor(1.0), weight=torch.tensor(1.0))
|
||||
self.metric.update(torch.tensor(2.0), weight=torch.tensor(2.0))
|
||||
self.metric.update(torch.tensor(3.0), weight=torch.tensor(3.0))
|
||||
result = self.metric.compute()
|
||||
print(f"get= {result.item()} but expected= 2.1666666667")
|
||||
self.assertAlmostEqual(result.item(), 2.1666666667, places=0)
|
||||
|
||||
|
||||
if '__name__' == '__main__':
|
||||
unittest.main()
|
@ -21,8 +21,7 @@ def roundrobin(*iterables):
|
||||
while num_active:
|
||||
try:
|
||||
for _next in nexts:
|
||||
result = _next()
|
||||
yield result
|
||||
yield _next()
|
||||
except StopIteration:
|
||||
# Remove the iterator we just exhausted from the cycle.
|
||||
num_active -= 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user