""" Mixin that requires a transform to munge output dictionary of tensors a model produces to a form that the torchmetrics.Metric.update expects. By unifying on our signature for `update`, we can also now use torchmetrics.MetricCollection which requires all metrics have the same call signature. To use, override this with a transform that munges `outputs` into a kwargs dict that the inherited metric.update accepts. Here are two examples of how to extend torchmetrics.SumMetric so that it accepts an output dictionary of tensors and munges it to what SumMetric expects (single `value`) for its update method. 1. Using as a mixin to inherit from or define a new metric class. class Count(MetricMixin, SumMetric): def transform(self, outputs): return {'value': 1} 2. Redefine an existing metric class. SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1}) """ from abc import abstractmethod from typing import Callable, Dict, List from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] import torch import torchmetrics class MetricMixin: @abstractmethod def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: ... def update(self, outputs: Dict[str, torch.Tensor]): results = self.transform(outputs) # Do not try to update if any tensor is empty as a result of stratification. for value in results.values(): if torch.is_tensor(value) and not value.nelement(): return super().update(**results) class TaskMixin: def __init__(self, task_idx: int = -1, **kwargs): super().__init__(**kwargs) self._task_idx = task_idx class StratifyMixin: def __init__( self, stratifier=None, **kwargs, ): super().__init__(**kwargs) self._stratifier = stratifier def maybe_apply_stratification( self, outputs: Dict[str, torch.Tensor], value_names: List[str] ) -> Dict[str, torch.Tensor]: """Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.""" outputs = outputs.copy() if not self._stratifier: return outputs stratifiers = outputs.get("stratifiers") if not stratifiers: return outputs if stratifiers.get(self._stratifier.name) is None: return outputs mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value) target_slice = torch.squeeze(mask.nonzero(), -1) for value_name in value_names: target = outputs[value_name] outputs[value_name] = torch.index_select(target, 0, target_slice) return outputs def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable): """Returns new class using MetricMixin and given base_metric. Functionally the same using inheritance, just saves some lines of code if no need for class attributes. """ def transform_method(_self, *args, **kwargs): return transform(*args, **kwargs) return type( base_metric.__name__, ( MetricMixin, base_metric, ), {"transform": transform_method}, )