From 1c0d0499ff6dbbde3c4e68f584457330344ee555 Mon Sep 17 00:00:00 2001 From: Darshan P Date: Mon, 3 Apr 2023 16:00:13 +0000 Subject: [PATCH] Early exit used the any function to short-circuit the loop as soon as an empty tensor is found. --- core/metric_mixin.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/core/metric_mixin.py b/core/metric_mixin.py index a716ca7..b9a133d 100644 --- a/core/metric_mixin.py +++ b/core/metric_mixin.py @@ -34,17 +34,17 @@ import torchmetrics class MetricMixin: - @abstractmethod - def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: - ... + @abstractmethod + def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + ... + + def update(self, outputs: Dict[str, torch.Tensor]): + results = self.transform(outputs) + # Do not try to update if any tensor is empty as a result of stratification. + if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()): + return + super().update(**results) - def update(self, outputs: Dict[str, torch.Tensor]): - results = self.transform(outputs) - # Do not try to update if any tensor is empty as a result of stratification. - for value in results.values(): - if torch.is_tensor(value) and not value.nelement(): - return - super().update(**results) class TaskMixin: