Early exit

used the any function to short-circuit the loop as soon as an empty tensor is found.
This commit is contained in:
Darshan P 2023-04-03 16:00:13 +00:00
parent c4c5072402
commit 1c0d0499ff

View File

@ -34,17 +34,17 @@ import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
...
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
...
def update(self, outputs: Dict[str, torch.Tensor]):
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()):
return
super().update(**results)
def update(self, outputs: Dict[str, torch.Tensor]):
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
if torch.is_tensor(value) and not value.nelement():
return
super().update(**results)
class TaskMixin: