From 84b232bf7b911e863cad419e42ce7190c92ab67e Mon Sep 17 00:00:00 2001 From: "Marcus Eagan (marcussorealheis)" Date: Sat, 1 Apr 2023 15:17:16 -0700 Subject: [PATCH] store returned metrics in variables prior to return statement. --- metrics/aggregation.py | 6 ++++-- metrics/auroc.py | 4 +++- metrics/rce.py | 4 +++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/metrics/aggregation.py b/metrics/aggregation.py index 57f75d5..a3c4e79 100644 --- a/metrics/aggregation.py +++ b/metrics/aggregation.py @@ -50,7 +50,8 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: value=state[:, 0], weight=state[:, 1], ) - return torch.stack([mean, weight_sum]) + merged_accumulated_mean = torch.stack([mean, weight_sum]) + return merged_accumulated_mean class StableMean(torchmetrics.Metric): @@ -94,4 +95,5 @@ class StableMean(torchmetrics.Metric): """ Compute and return the accumulated mean. """ - return self.mean_and_weight_sum[0] + accumulated_mean = self.mean_and_weight_sum[0] + return accumulated_mean diff --git a/metrics/auroc.py b/metrics/auroc.py index 6979c20..6e0418f 100644 --- a/metrics/auroc.py +++ b/metrics/auroc.py @@ -159,4 +159,6 @@ class AUROCWithMWU(torchmetrics.Metric): ) # Compute auroc with the weight set to 1/2 when positive & negative have identical scores. - return auroc_le - (auroc_le - auroc_lt) / 2.0 + auroc = auroc_le - (auroc_le - auroc_lt) / 2.0 + + return auroc \ No newline at end of file diff --git a/metrics/rce.py b/metrics/rce.py index b6ada1d..5d1e092 100644 --- a/metrics/rce.py +++ b/metrics/rce.py @@ -178,7 +178,9 @@ class RCE(torchmetrics.Metric): pred_ce = self.binary_cross_entropy.compute() - return (1.0 - (pred_ce / baseline_ce)) * 100 + rce = (1.0 - (pred_ce / baseline_ce)) * 100 + + return rce def reset(self): """