store returned metrics in variables prior to return statement.

This commit is contained in:
Marcus Eagan (marcussorealheis) 2023-04-01 15:17:16 -07:00
parent 78c3235eee
commit 84b232bf7b
No known key found for this signature in database
GPG Key ID: DBC835F366A6E6EC
3 changed files with 10 additions and 4 deletions

View File

@ -50,7 +50,8 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
value=state[:, 0], value=state[:, 0],
weight=state[:, 1], weight=state[:, 1],
) )
return torch.stack([mean, weight_sum]) merged_accumulated_mean = torch.stack([mean, weight_sum])
return merged_accumulated_mean
class StableMean(torchmetrics.Metric): class StableMean(torchmetrics.Metric):
@ -94,4 +95,5 @@ class StableMean(torchmetrics.Metric):
""" """
Compute and return the accumulated mean. Compute and return the accumulated mean.
""" """
return self.mean_and_weight_sum[0] accumulated_mean = self.mean_and_weight_sum[0]
return accumulated_mean

View File

@ -159,4 +159,6 @@ class AUROCWithMWU(torchmetrics.Metric):
) )
# Compute auroc with the weight set to 1/2 when positive & negative have identical scores. # Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
return auroc_le - (auroc_le - auroc_lt) / 2.0 auroc = auroc_le - (auroc_le - auroc_lt) / 2.0
return auroc

View File

@ -178,7 +178,9 @@ class RCE(torchmetrics.Metric):
pred_ce = self.binary_cross_entropy.compute() pred_ce = self.binary_cross_entropy.compute()
return (1.0 - (pred_ce / baseline_ce)) * 100 rce = (1.0 - (pred_ce / baseline_ce)) * 100
return rce
def reset(self): def reset(self):
""" """