mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-26 21:45:26 +01:00
store returned metrics in variables prior to return statement.
This commit is contained in:
parent
78c3235eee
commit
84b232bf7b
@ -50,7 +50,8 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
||||
value=state[:, 0],
|
||||
weight=state[:, 1],
|
||||
)
|
||||
return torch.stack([mean, weight_sum])
|
||||
merged_accumulated_mean = torch.stack([mean, weight_sum])
|
||||
return merged_accumulated_mean
|
||||
|
||||
|
||||
class StableMean(torchmetrics.Metric):
|
||||
@ -94,4 +95,5 @@ class StableMean(torchmetrics.Metric):
|
||||
"""
|
||||
Compute and return the accumulated mean.
|
||||
"""
|
||||
return self.mean_and_weight_sum[0]
|
||||
accumulated_mean = self.mean_and_weight_sum[0]
|
||||
return accumulated_mean
|
||||
|
@ -159,4 +159,6 @@ class AUROCWithMWU(torchmetrics.Metric):
|
||||
)
|
||||
|
||||
# Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
|
||||
return auroc_le - (auroc_le - auroc_lt) / 2.0
|
||||
auroc = auroc_le - (auroc_le - auroc_lt) / 2.0
|
||||
|
||||
return auroc
|
@ -178,7 +178,9 @@ class RCE(torchmetrics.Metric):
|
||||
|
||||
pred_ce = self.binary_cross_entropy.compute()
|
||||
|
||||
return (1.0 - (pred_ce / baseline_ce)) * 100
|
||||
rce = (1.0 - (pred_ce / baseline_ce)) * 100
|
||||
|
||||
return rce
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user