Early exit

used the any function to short-circuit the loop as soon as an empty tensor is found.
This commit is contained in:
Darshan P 2023-04-03 16:00:13 +00:00
parent c4c5072402
commit 1c0d0499ff

View File

@ -35,18 +35,18 @@ import torchmetrics
class MetricMixin: class MetricMixin:
@abstractmethod @abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
... ...
def update(self, outputs: Dict[str, torch.Tensor]): def update(self, outputs: Dict[str, torch.Tensor]):
results = self.transform(outputs) results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification. # Do not try to update if any tensor is empty as a result of stratification.
for value in results.values(): if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()):
if torch.is_tensor(value) and not value.nelement():
return return
super().update(**results) super().update(**results)
class TaskMixin: class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs): def __init__(self, task_idx: int = -1, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)