mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 23:01:48 +01:00
store returned metrics in variables prior to return statement.
This commit is contained in:
parent
78c3235eee
commit
84b232bf7b
@ -50,7 +50,8 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
|||||||
value=state[:, 0],
|
value=state[:, 0],
|
||||||
weight=state[:, 1],
|
weight=state[:, 1],
|
||||||
)
|
)
|
||||||
return torch.stack([mean, weight_sum])
|
merged_accumulated_mean = torch.stack([mean, weight_sum])
|
||||||
|
return merged_accumulated_mean
|
||||||
|
|
||||||
|
|
||||||
class StableMean(torchmetrics.Metric):
|
class StableMean(torchmetrics.Metric):
|
||||||
@ -94,4 +95,5 @@ class StableMean(torchmetrics.Metric):
|
|||||||
"""
|
"""
|
||||||
Compute and return the accumulated mean.
|
Compute and return the accumulated mean.
|
||||||
"""
|
"""
|
||||||
return self.mean_and_weight_sum[0]
|
accumulated_mean = self.mean_and_weight_sum[0]
|
||||||
|
return accumulated_mean
|
||||||
|
@ -159,4 +159,6 @@ class AUROCWithMWU(torchmetrics.Metric):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
|
# Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
|
||||||
return auroc_le - (auroc_le - auroc_lt) / 2.0
|
auroc = auroc_le - (auroc_le - auroc_lt) / 2.0
|
||||||
|
|
||||||
|
return auroc
|
@ -178,7 +178,9 @@ class RCE(torchmetrics.Metric):
|
|||||||
|
|
||||||
pred_ce = self.binary_cross_entropy.compute()
|
pred_ce = self.binary_cross_entropy.compute()
|
||||||
|
|
||||||
return (1.0 - (pred_ce / baseline_ce)) * 100
|
rce = (1.0 - (pred_ce / baseline_ce)) * 100
|
||||||
|
|
||||||
|
return rce
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user