Refactor twml metrics module for better performance and complexity

This commit is contained in:
Sharansrj567 2023-04-01 14:45:17 +05:30
parent ec83d01dca
commit 85c472e57d
1 changed files with 44 additions and 46 deletions

View File

@ -20,29 +20,29 @@ from twml.metrics import get_multi_binary_class_metric_fn
# checkstyle: noqa
def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, predcols=None):
def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, pred_cols=None):
def get_eval_metric_ops(graph_output, labels, weights):
if predcols is None:
if pred_cols is None:
preds = graph_output['output']
else:
if isinstance(predcols, int):
predcol_list=[predcols]
if isinstance(pred_cols, int):
pred_col_list=[pred_cols]
else:
predcol_list=list(predcols)
for col in predcol_list:
pred_col_list=list(pred_cols)
for col in pred_col_list:
assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !'
preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim) # [batchSz, num_col]
labels = tf.gather(labels, indices=predcol_list, axis=class_dim) # [batchSz, num_col]
preds = tf.gather(graph_output['output'], indices=pred_col_list, axis=class_dim) # [batchSz, num_col]
labels = tf.gather(labels, indices=pred_col_list, axis=class_dim) # [batchSz, num_col]
predInfo = {'output': preds}
pred_info = {'output': preds}
if 'threshold' in graph_output:
predInfo['threshold'] = graph_output['threshold']
pred_info['threshold'] = graph_output['threshold']
if 'hard_output' in graph_output:
predInfo['hard_output'] = graph_output['hard_output']
pred_info['hard_output'] = graph_output['hard_output']
metrics_op = get_multi_binary_class_metric_fn(metrics, classes, class_dim)
metrics_op_res = metrics_op(predInfo, labels, weights)
metrics_op_res = metrics_op(pred_info, labels, weights)
return metrics_op_res
return get_eval_metric_ops
@ -68,31 +68,29 @@ DEFAULT_NUMERIC_METRICS = ['mean_numeric_label_topk', 'mean_gated_numeric_label_
def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, topK=(5,5,5), predcol=None, labelcol=None):
def get_metric_topK_fn_helper(target_metrics, supported_metrics_op, metrics=None, top_k=(5, 5, 5), pred_col=None, label_col=None):
"""
:param targetMetrics: Target Metric List
:param supportedMetrics_op: Supported Metric Operators Dict
:param metrics: Metric Set to evaluate
:param topK: (topK_min, topK_max, topK_delta) Tuple
:param predcol: Prediction Column Index
:param labelcol: Label Column Index
:param target_metrics: Target metric list
:param supported_metrics_op: Supported metric operators as a dictionary
:param metrics: Metric set to evaluate
:param top_k: A tuple of (minimum top_k, maximum top_k, top_k delta)
:param pred_col: Prediction column index
:param label_col: Label column index
:return:
"""
# pylint: disable=dict-keys-not-iterating
if targetMetrics is None or supportedMetrics_op is None:
raise ValueError("Invalid Target Metric List/op !")
if not target_metrics or not supported_metrics_op:
raise ValueError("Invalid target metric list/op!")
targetMetrics = set([m.lower() for m in targetMetrics])
if metrics is None:
metrics = list(targetMetrics)
target_metrics = {m.lower() for m in target_metrics}
if not metrics:
metrics = list(target_metrics)
else:
metrics = [m.lower() for m in metrics if m.lower() in targetMetrics]
metrics = [m.lower() for m in metrics if m.lower() in target_metrics]
num_k = int((topK[1]-topK[0])/topK[2]+1)
topK_list = [topK[0]+d*topK[2] for d in range(num_k)]
num_k = (top_k[1] - top_k[0]) // top_k[2] + 1
topK_list = [top_k[0] + d * top_k[2] for d in range(num_k)]
if 1 not in topK_list:
topK_list = [1] + topK_list
topK_list.insert(0, 1)
def get_eval_metric_ops(graph_output, labels, weights):
"""
@ -105,13 +103,13 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
"""
eval_metric_ops = OrderedDict()
if predcol is None:
if pred_col is None:
pred = graph_output['output']
else:
assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !'
assert labelcol is not None
pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1])
labels = tf.reshape(labels[:, labelcol], shape=[-1, 1])
assert 0 <= pred_col < graph_output['output'].shape[1], 'Invalid Prediction Column Index !'
assert label_col is not None
pred = tf.reshape(graph_output['output'][:, pred_col], shape=[-1, 1])
labels = tf.reshape(labels[:, label_col], shape=[-1, 1])
numOut = graph_output['output'].shape[1]
pred_score = tf.reshape(graph_output['output'][:, numOut-1], shape=[-1, 1])
@ -119,8 +117,8 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
for metric_name in metrics:
metric_name = metric_name.lower() # metric name are case insensitive.
if metric_name in supportedMetrics_op:
metric_factory = supportedMetrics_op.get(metric_name)
if metric_name in supported_metrics_op:
metric_factory = supported_metrics_op.get(metric_name)
if 'topk' not in metric_name:
value_op, update_op = metric_factory(
@ -150,14 +148,14 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
def get_numeric_metric_fn(metrics=None, topK=(5,5,5), predcol=None, labelcol=None):
def get_numeric_metric_fn(metrics=None, topK=(5,5,5), pred_col=None, label_col=None):
if metrics is None:
metrics = list(DEFAULT_NUMERIC_METRICS)
metrics = list(set(metrics))
metric_op = get_metric_topK_fn_helper(targetMetrics=list(DEFAULT_NUMERIC_METRICS),
supportedMetrics_op=SUPPORTED_NUMERIC_METRICS,
metrics=metrics, topK=topK, predcol=predcol, labelcol=labelcol)
supported_metrics_op=SUPPORTED_NUMERIC_METRICS,
metrics=metrics, topK=topK, pred_col=pred_col, label_col=label_col)
return metric_op
@ -168,16 +166,16 @@ def get_single_binary_task_metric_fn(metrics, classnames, topK=(5,5,5), use_topK
labels: [BatchSz, 2] [Task1, NumericLabel]
"""
def get_eval_metric_ops(graph_output, labels, weights):
metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames)
metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames)
classnames_unw = ['unweighted_'+cs for cs in classnames]
metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames_unw)
metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames_unw)
metrics_base_res = metric_op_base(graph_output, labels, weights)
metrics_unw_res = metric_op_unw(graph_output, labels, None)
metrics_base_res.update(metrics_unw_res)
if use_topK:
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=0, labelcol=1)
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=0, label_col=1)
metrics_numeric_res = metric_op_numeric(graph_output, labels, weights)
metrics_base_res.update(metrics_numeric_res)
return metrics_base_res
@ -192,16 +190,16 @@ def get_dual_binary_tasks_metric_fn(metrics, classnames, topK=(5,5,5), use_topK=
"""
def get_eval_metric_ops(graph_output, labels, weights):
metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames)
metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames)
classnames_unw = ['unweighted_'+cs for cs in classnames]
metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames_unw)
metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames_unw)
metrics_base_res = metric_op_base(graph_output, labels, weights)
metrics_unw_res = metric_op_unw(graph_output, labels, None)
metrics_base_res.update(metrics_unw_res)
if use_topK:
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=2, labelcol=2)
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=2, label_col=2)
metrics_numeric_res = metric_op_numeric(graph_output, labels, weights)
metrics_base_res.update(metrics_numeric_res)
return metrics_base_res