the-algorithm/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/metrics.py

121 lines
4.4 KiB
Python

# checkstyle: noqa
import tensorflow.compat.v1 as tf
from collections import OrderedDict
from .constants import EB_SCORE_IDX
from .lolly.data_helpers import get_lolly_scores
import twml
def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1):
"""
This function was copied from twml/metrics.py with the following adjustments:
- Override example weights with the ones set in graph_output.
- Tile labels in order to support per engagement metrics for both TF and Lolly scores.
- Add lolly_tf_score_MSE metric.
Note: All custom lines have a comment that starts with 'Added'
"""
# pylint: disable=invalid-name,dict-keys-not-iterating
if metrics is None:
# remove expensive metrics by default for faster eval
metrics = list(twml.metrics.SUPPORTED_BINARY_CLASS_METRICS.keys())
metrics.remove('pr_curve')
def get_eval_metric_ops(graph_output, labels, weights):
"""
graph_output:
dict that is returned by build_graph given input features.
labels:
target labels associated to batch.
weights:
weights of the samples..
"""
# Added to support the example weights overriding.
weights = graph_output["weights"]
# Added to support per engagement metrics for both TF and Lolly scores.
labels = tf.tile(labels, [1, 2])
eval_metric_ops = OrderedDict()
preds = graph_output['output']
threshold = graph_output['threshold'] if 'threshold' in graph_output else 0.5
hard_preds = graph_output.get('hard_output')
if not hard_preds:
hard_preds = tf.greater_equal(preds, threshold)
shape = labels.get_shape()
# basic sanity check: multi_metric dimension must exist
assert len(shape) > class_dim, "Dimension specified by class_dim does not exist."
num_labels = shape[class_dim]
# If we are doing multi-class / multi-label metric, the number of classes / labels must
# be know at graph construction time. This dimension cannot have size None.
assert num_labels is not None, "The multi-metric dimension cannot be None."
assert classes is None or len(classes) == num_labels, (
"Number of classes must match the number of labels")
weights_shape = weights.get_shape() if weights is not None else None
if weights_shape is None:
num_weights = None
elif len(weights_shape) > 1:
num_weights = weights_shape[class_dim]
else:
num_weights = 1
for i in range(num_labels):
# add metrics to eval_metric_ops dict
for metric_name in metrics:
metric_name = metric_name.lower() # metric name are case insensitive.
class_metric_name = metric_name + "_" + (classes[i] if classes is not None else str(i))
if class_metric_name in eval_metric_ops:
# avoid adding duplicate metrics.
continue
class_labels = tf.gather(labels, indices=[i], axis=class_dim)
class_preds = tf.gather(preds, indices=[i], axis=class_dim)
class_hard_preds = tf.gather(hard_preds, indices=[i], axis=class_dim)
if num_weights is None:
class_weights = None
elif num_weights == num_labels:
class_weights = tf.gather(weights, indices=[i], axis=class_dim)
elif num_weights == 1:
class_weights = weights
else:
raise ValueError("num_weights (%d) and num_labels (%d) do not match"
% (num_weights, num_labels))
metric_factory, requires_threshold = twml.metrics.SUPPORTED_BINARY_CLASS_METRICS.get(metric_name)
if metric_factory:
value_op, update_op = metric_factory(
labels=class_labels,
predictions=(class_hard_preds if requires_threshold else class_preds),
weights=class_weights, name=class_metric_name)
eval_metric_ops[class_metric_name] = (value_op, update_op)
else:
raise ValueError('Cannot find the metric named ' + metric_name)
# Added to compare TF and Lolly scores.
eval_metric_ops["lolly_tf_score_MSE"] = get_mse(graph_output["output"], labels)
return eval_metric_ops
return get_eval_metric_ops
def get_mse(predictions, labels):
lolly_scores = get_lolly_scores(labels)
tf_scores = predictions[:, EB_SCORE_IDX]
squared_lolly_tf_score_diff = tf.square(tf.subtract(tf_scores, lolly_scores))
value_op = tf.reduce_mean(squared_lolly_tf_score_diff, name="value_op")
update_op = tf.reduce_mean(squared_lolly_tf_score_diff, name="update_op")
return value_op, update_op