diff --git a/core/metric_mixin.py b/core/metric_mixin.py index a716ca7..b9a133d 100644 --- a/core/metric_mixin.py +++ b/core/metric_mixin.py @@ -34,17 +34,17 @@ import torchmetrics class MetricMixin: - @abstractmethod - def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: - ... + @abstractmethod + def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + ... + + def update(self, outputs: Dict[str, torch.Tensor]): + results = self.transform(outputs) + # Do not try to update if any tensor is empty as a result of stratification. + if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()): + return + super().update(**results) - def update(self, outputs: Dict[str, torch.Tensor]): - results = self.transform(outputs) - # Do not try to update if any tensor is empty as a result of stratification. - for value in results.values(): - if torch.is_tensor(value) and not value.nelement(): - return - super().update(**results) class TaskMixin: