# Mirror of https://github.com/twitter/the-algorithm-ml.git
# (synced 2025-01-27 05:55:27 +01:00; 43 lines, 1.4 KiB, Python)
import torch
|
|
import unittest
|
|
from aggregation import StableMean
|
|
|
|
|
|
class TestStableMean(unittest.TestCase):
    """Unit tests for the ``StableMean`` metric's ``update``/``compute`` cycle."""

    def setUp(self):
        """Create a fresh metric so no state leaks between test cases."""
        self.metric = StableMean()

    def test_compute_empty(self):
        """With no updates, ``compute`` is expected to return 0.0."""
        result = self.metric.compute()
        self.assertEqual(result, torch.tensor(0.0))

    def test_compute_single_value(self):
        """A single unweighted update yields that same value."""
        self.metric.update(torch.tensor(1.0))
        result = self.metric.compute()
        self.assertEqual(result, torch.tensor(1.0))

    def test_compute_weighted_single_value(self):
        """The weight of a lone sample does not change the mean."""
        self.metric.update(torch.tensor(1.0), weight=torch.tensor(2.0))
        result = self.metric.compute()
        self.assertEqual(result, torch.tensor(1.0))

    def test_compute_multiple_values(self):
        """Unweighted updates of 1, 2 and 3 average to 2."""
        for value in (1.0, 2.0, 3.0):
            self.metric.update(torch.tensor(value))
        result = self.metric.compute()
        self.assertEqual(result, torch.tensor(2.0))

    def test_compute_weighted_multiple_values(self):
        """Values 1, 2, 3 with weights 1, 2, 3 give the weighted mean 14 / 6."""
        samples = ((1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
        for value, weight in samples:
            self.metric.update(torch.tensor(value), weight=torch.tensor(weight))
        result = self.metric.compute()

        # Derive the expectation from the inputs instead of hard-coding it:
        # sum(v * w) / sum(w) = (1 + 4 + 9) / 6 ~= 2.3333.
        expected = sum(v * w for v, w in samples) / sum(w for _, w in samples)
        # places=5 leaves room for float32 rounding while still rejecting a
        # wrong mean; places=0 would accept anything within about +/- 0.5.
        self.assertAlmostEqual(result.item(), expected, places=5)
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
# The comparison must use the module's __name__ variable; comparing the
# string literal '__name__' is always False and never starts the runner.
if __name__ == "__main__":
    unittest.main()
|