Mirror of https://github.com/twitter/the-algorithm.git, synced 2025-01-02 23:51:53 +01:00
Refactor twml metrics module for better performance and complexity
This commit is contained in:
parent ec83d01dca
commit 85c472e57d
@@ -20,29 +20,29 @@ from twml.metrics import get_multi_binary_class_metric_fn
 
 
 # checkstyle: noqa
-def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, predcols=None):
+def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, pred_cols=None):
 
   def get_eval_metric_ops(graph_output, labels, weights):
-    if predcols is None:
+    if pred_cols is None:
       preds = graph_output['output']
     else:
-      if isinstance(predcols, int):
-        predcol_list=[predcols]
+      if isinstance(pred_cols, int):
+        pred_col_list = [pred_cols]
       else:
-        predcol_list=list(predcols)
-      for col in predcol_list:
+        pred_col_list = list(pred_cols)
+      for col in pred_col_list:
         assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !'
-      preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim)  # [batchSz, num_col]
-      labels = tf.gather(labels, indices=predcol_list, axis=class_dim)  # [batchSz, num_col]
+      preds = tf.gather(graph_output['output'], indices=pred_col_list, axis=class_dim)  # [batchSz, num_col]
+      labels = tf.gather(labels, indices=pred_col_list, axis=class_dim)  # [batchSz, num_col]
 
-    predInfo = {'output': preds}
+    pred_info = {'output': preds}
     if 'threshold' in graph_output:
-      predInfo['threshold'] = graph_output['threshold']
+      pred_info['threshold'] = graph_output['threshold']
     if 'hard_output' in graph_output:
-      predInfo['hard_output'] = graph_output['hard_output']
+      pred_info['hard_output'] = graph_output['hard_output']
 
     metrics_op = get_multi_binary_class_metric_fn(metrics, classes, class_dim)
-    metrics_op_res = metrics_op(predInfo, labels, weights)
+    metrics_op_res = metrics_op(pred_info, labels, weights)
     return metrics_op_res
 
   return get_eval_metric_ops
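
A minimal sketch (not part of the commit; tensor values are made up for illustration) of what the renamed pred_cols handling reduces to: tf.gather selects the requested output columns along class_dim.

import tensorflow as tf

graph_out = tf.constant([[0.1, 0.7, 0.2],
                         [0.4, 0.3, 0.3]])  # hypothetical [batch, num_classes] output
pred_col_list = [0, 2]                      # columns selected for evaluation
preds = tf.gather(graph_out, indices=pred_col_list, axis=1)  # shape [2, 2]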
@@ -68,31 +68,29 @@ DEFAULT_NUMERIC_METRICS = ['mean_numeric_label_topk', 'mean_gated_numeric_label_
 
 
 
-def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, topK=(5,5,5), predcol=None, labelcol=None):
+def get_metric_topK_fn_helper(target_metrics, supported_metrics_op, metrics=None, top_k=(5, 5, 5), pred_col=None, label_col=None):
   """
-  :param targetMetrics: Target Metric List
-  :param supportedMetrics_op: Supported Metric Operators Dict
-  :param metrics: Metric Set to evaluate
-  :param topK: (topK_min, topK_max, topK_delta) Tuple
-  :param predcol: Prediction Column Index
-  :param labelcol: Label Column Index
+  :param target_metrics: Target metric list
+  :param supported_metrics_op: Supported metric operators as a dictionary
+  :param metrics: Metric set to evaluate
+  :param top_k: A tuple of (minimum top_k, maximum top_k, top_k delta)
+  :param pred_col: Prediction column index
+  :param label_col: Label column index
   :return:
   """
-  # pylint: disable=dict-keys-not-iterating
-  if targetMetrics is None or supportedMetrics_op is None:
-    raise ValueError("Invalid Target Metric List/op !")
+  if not target_metrics or not supported_metrics_op:
+    raise ValueError("Invalid target metric list/op!")
 
-  targetMetrics = set([m.lower() for m in targetMetrics])
-  if metrics is None:
-    metrics = list(targetMetrics)
+  target_metrics = {m.lower() for m in target_metrics}
+  if not metrics:
+    metrics = list(target_metrics)
   else:
-    metrics = [m.lower() for m in metrics if m.lower() in targetMetrics]
+    metrics = [m.lower() for m in metrics if m.lower() in target_metrics]
 
-  num_k = int((topK[1]-topK[0])/topK[2]+1)
-  topK_list = [topK[0]+d*topK[2] for d in range(num_k)]
+  num_k = (top_k[1] - top_k[0]) // top_k[2] + 1
+  topK_list = [top_k[0] + d * top_k[2] for d in range(num_k)]
   if 1 not in topK_list:
-    topK_list = [1] + topK_list
+    topK_list.insert(0, 1)
 
 
   def get_eval_metric_ops(graph_output, labels, weights):
     """
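
A worked example of the top-k expansion above, using illustrative values rather than the function's (5, 5, 5) default: the integer division gives num_k = (25 - 5) // 5 + 1 = 5, the k values are [5, 10, 15, 20, 25], and 1 is prepended because it is missing.

top_k = (5, 25, 5)
num_k = (top_k[1] - top_k[0]) // top_k[2] + 1
topK_list = [top_k[0] + d * top_k[2] for d in range(num_k)]
if 1 not in topK_list:
  topK_list.insert(0, 1)
print(topK_list)  # [1, 5, 10, 15, 20, 25]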
@@ -105,13 +103,13 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
     """
     eval_metric_ops = OrderedDict()
 
-    if predcol is None:
+    if pred_col is None:
       pred = graph_output['output']
     else:
-      assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !'
-      assert labelcol is not None
-      pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1])
-      labels = tf.reshape(labels[:, labelcol], shape=[-1, 1])
+      assert 0 <= pred_col < graph_output['output'].shape[1], 'Invalid Prediction Column Index !'
+      assert label_col is not None
+      pred = tf.reshape(graph_output['output'][:, pred_col], shape=[-1, 1])
+      labels = tf.reshape(labels[:, label_col], shape=[-1, 1])
       numOut = graph_output['output'].shape[1]
       pred_score = tf.reshape(graph_output['output'][:, numOut-1], shape=[-1, 1])
 
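
A self-contained sketch (made-up tensors, not twml code) of the single-column slicing above: one prediction column and one label column are pulled out and reshaped to [batch, 1], and the last output column is treated as the score.

import tensorflow as tf

output = tf.constant([[0.2, 0.8], [0.6, 0.4]])  # hypothetical [batch, 2] model output
labels = tf.constant([[1.0, 3.0], [0.0, 7.0]])  # hypothetical [batch, 2] labels
pred_col, label_col = 0, 1
pred = tf.reshape(output[:, pred_col], shape=[-1, 1])     # [batch, 1]
labels = tf.reshape(labels[:, label_col], shape=[-1, 1])  # [batch, 1]
pred_score = tf.reshape(output[:, output.shape[1] - 1], shape=[-1, 1])  # last column as score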
@@ -119,8 +117,8 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
     for metric_name in metrics:
       metric_name = metric_name.lower()  # metric name are case insensitive.
 
-      if metric_name in supportedMetrics_op:
-        metric_factory = supportedMetrics_op.get(metric_name)
+      if metric_name in supported_metrics_op:
+        metric_factory = supported_metrics_op.get(metric_name)
 
         if 'topk' not in metric_name:
           value_op, update_op = metric_factory(
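
A minimal, runnable sketch (hypothetical factory dict, not the real SUPPORTED_NUMERIC_METRICS) of the supported_metrics_op lookup pattern in this loop: metric names map to factory callables, lookups are case-insensitive, and unsupported names are skipped.

supported_metrics_op = {'mean': lambda values: sum(values) / len(values)}  # hypothetical factories
for metric_name in ['MEAN', 'not_supported']:
  metric_name = metric_name.lower()  # metric names are case-insensitive
  if metric_name in supported_metrics_op:
    metric_factory = supported_metrics_op.get(metric_name)
    print(metric_factory([1.0, 2.0, 3.0]))  # -> 2.0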
@@ -150,14 +148,14 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
 
 
 
-def get_numeric_metric_fn(metrics=None, topK=(5,5,5), predcol=None, labelcol=None):
+def get_numeric_metric_fn(metrics=None, topK=(5,5,5), pred_col=None, label_col=None):
   if metrics is None:
     metrics = list(DEFAULT_NUMERIC_METRICS)
   metrics = list(set(metrics))
 
-  metric_op = get_metric_topK_fn_helper(targetMetrics=list(DEFAULT_NUMERIC_METRICS),
-                                        supportedMetrics_op=SUPPORTED_NUMERIC_METRICS,
-                                        metrics=metrics, topK=topK, predcol=predcol, labelcol=labelcol)
+  metric_op = get_metric_topK_fn_helper(target_metrics=list(DEFAULT_NUMERIC_METRICS),
+                                        supported_metrics_op=SUPPORTED_NUMERIC_METRICS,
+                                        metrics=metrics, top_k=topK, pred_col=pred_col, label_col=label_col)
   return metric_op
 
 
@@ -168,16 +166,16 @@ def get_single_binary_task_metric_fn(metrics, classnames, topK=(5,5,5), use_topK
     labels: [BatchSz, 2] [Task1, NumericLabel]
   """
   def get_eval_metric_ops(graph_output, labels, weights):
-    metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames)
+    metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames)
     classnames_unw = ['unweighted_'+cs for cs in classnames]
-    metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames_unw)
+    metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames_unw)
 
     metrics_base_res = metric_op_base(graph_output, labels, weights)
     metrics_unw_res = metric_op_unw(graph_output, labels, None)
     metrics_base_res.update(metrics_unw_res)
 
     if use_topK:
-      metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=0, labelcol=1)
+      metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=0, label_col=1)
       metrics_numeric_res = metric_op_numeric(graph_output, labels, weights)
       metrics_base_res.update(metrics_numeric_res)
     return metrics_base_res
@@ -192,16 +190,16 @@ def get_dual_binary_tasks_metric_fn(metrics, classnames, topK=(5,5,5), use_topK=
   """
   def get_eval_metric_ops(graph_output, labels, weights):
 
-    metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames)
+    metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames)
     classnames_unw = ['unweighted_'+cs for cs in classnames]
-    metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames_unw)
+    metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames_unw)
 
     metrics_base_res = metric_op_base(graph_output, labels, weights)
     metrics_unw_res = metric_op_unw(graph_output, labels, None)
     metrics_base_res.update(metrics_unw_res)
 
     if use_topK:
-      metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=2, labelcol=2)
+      metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=2, label_col=2)
       metrics_numeric_res = metric_op_numeric(graph_output, labels, weights)
       metrics_base_res.update(metrics_numeric_res)
     return metrics_base_res