mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 14:51:49 +01:00
Early exit
used the any function to short-circuit the loop as soon as an empty tensor is found.
This commit is contained in:
parent
c4c5072402
commit
1c0d0499ff
@ -34,17 +34,17 @@ import torchmetrics
|
|||||||
|
|
||||||
|
|
||||||
class MetricMixin:
|
class MetricMixin:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
|
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def update(self, outputs: Dict[str, torch.Tensor]):
|
||||||
|
results = self.transform(outputs)
|
||||||
|
# Do not try to update if any tensor is empty as a result of stratification.
|
||||||
|
if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()):
|
||||||
|
return
|
||||||
|
super().update(**results)
|
||||||
|
|
||||||
def update(self, outputs: Dict[str, torch.Tensor]):
|
|
||||||
results = self.transform(outputs)
|
|
||||||
# Do not try to update if any tensor is empty as a result of stratification.
|
|
||||||
for value in results.values():
|
|
||||||
if torch.is_tensor(value) and not value.nelement():
|
|
||||||
return
|
|
||||||
super().update(**results)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskMixin:
|
class TaskMixin:
|
||||||
|
Loading…
Reference in New Issue
Block a user