Early exit

used the any function to short-circuit the loop as soon as an empty tensor is found.
This commit is contained in:
Darshan P 2023-04-03 16:00:13 +00:00
parent c4c5072402
commit 1c0d0499ff

View File

@ -35,18 +35,18 @@ import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
...
def update(self, outputs: Dict[str, torch.Tensor]):
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
if torch.is_tensor(value) and not value.nelement():
if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()):
return
super().update(**results)
class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs):
super().__init__(**kwargs)