mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 06:41:49 +01:00
Early exit
used the any function to short-circuit the loop as soon as an empty tensor is found.
This commit is contained in:
parent
c4c5072402
commit
1c0d0499ff
@ -34,17 +34,17 @@ import torchmetrics
|
||||
|
||||
|
||||
class MetricMixin:
|
||||
@abstractmethod
|
||||
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
|
||||
...
|
||||
@abstractmethod
|
||||
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
...
|
||||
|
||||
def update(self, outputs: Dict[str, torch.Tensor]):
|
||||
results = self.transform(outputs)
|
||||
# Do not try to update if any tensor is empty as a result of stratification.
|
||||
if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()):
|
||||
return
|
||||
super().update(**results)
|
||||
|
||||
def update(self, outputs: Dict[str, torch.Tensor]):
|
||||
results = self.transform(outputs)
|
||||
# Do not try to update if any tensor is empty as a result of stratification.
|
||||
for value in results.values():
|
||||
if torch.is_tensor(value) and not value.nelement():
|
||||
return
|
||||
super().update(**results)
|
||||
|
||||
|
||||
class TaskMixin:
|
||||
|
Loading…
Reference in New Issue
Block a user