Refactor twml metrics module to improve performance and reduce complexity

This commit is contained in:
Sharansrj567 2023-04-01 14:45:17 +05:30
parent ec83d01dca
commit 85c472e57d

View File

@ -20,29 +20,29 @@ from twml.metrics import get_multi_binary_class_metric_fn
# checkstyle: noqa # checkstyle: noqa
def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, predcols=None): def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, pred_cols=None):
def get_eval_metric_ops(graph_output, labels, weights): def get_eval_metric_ops(graph_output, labels, weights):
if predcols is None: if pred_cols is None:
preds = graph_output['output'] preds = graph_output['output']
else: else:
if isinstance(predcols, int): if isinstance(pred_cols, int):
predcol_list=[predcols] pred_col_list=[pred_cols]
else: else:
predcol_list=list(predcols) pred_col_list=list(pred_cols)
for col in predcol_list: for col in pred_col_list:
assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !' assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !'
preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim) # [batchSz, num_col] preds = tf.gather(graph_output['output'], indices=pred_col_list, axis=class_dim) # [batchSz, num_col]
labels = tf.gather(labels, indices=predcol_list, axis=class_dim) # [batchSz, num_col] labels = tf.gather(labels, indices=pred_col_list, axis=class_dim) # [batchSz, num_col]
predInfo = {'output': preds} pred_info = {'output': preds}
if 'threshold' in graph_output: if 'threshold' in graph_output:
predInfo['threshold'] = graph_output['threshold'] pred_info['threshold'] = graph_output['threshold']
if 'hard_output' in graph_output: if 'hard_output' in graph_output:
predInfo['hard_output'] = graph_output['hard_output'] pred_info['hard_output'] = graph_output['hard_output']
metrics_op = get_multi_binary_class_metric_fn(metrics, classes, class_dim) metrics_op = get_multi_binary_class_metric_fn(metrics, classes, class_dim)
metrics_op_res = metrics_op(predInfo, labels, weights) metrics_op_res = metrics_op(pred_info, labels, weights)
return metrics_op_res return metrics_op_res
return get_eval_metric_ops return get_eval_metric_ops
@ -68,31 +68,29 @@ DEFAULT_NUMERIC_METRICS = ['mean_numeric_label_topk', 'mean_gated_numeric_label_
def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, topK=(5,5,5), predcol=None, labelcol=None): def get_metric_topK_fn_helper(target_metrics, supported_metrics_op, metrics=None, top_k=(5, 5, 5), pred_col=None, label_col=None):
""" """
:param targetMetrics: Target Metric List :param target_metrics: Target metric list
:param supportedMetrics_op: Supported Metric Operators Dict :param supported_metrics_op: Supported metric operators as a dictionary
:param metrics: Metric Set to evaluate :param metrics: Metric set to evaluate
:param topK: (topK_min, topK_max, topK_delta) Tuple :param top_k: A tuple of (minimum top_k, maximum top_k, top_k delta)
:param predcol: Prediction Column Index :param pred_col: Prediction column index
:param labelcol: Label Column Index :param label_col: Label column index
:return: :return:
""" """
# pylint: disable=dict-keys-not-iterating if not target_metrics or not supported_metrics_op:
if targetMetrics is None or supportedMetrics_op is None: raise ValueError("Invalid target metric list/op!")
raise ValueError("Invalid Target Metric List/op !")
targetMetrics = set([m.lower() for m in targetMetrics]) target_metrics = {m.lower() for m in target_metrics}
if metrics is None: if not metrics:
metrics = list(targetMetrics) metrics = list(target_metrics)
else: else:
metrics = [m.lower() for m in metrics if m.lower() in targetMetrics] metrics = [m.lower() for m in metrics if m.lower() in target_metrics]
num_k = int((topK[1]-topK[0])/topK[2]+1) num_k = (top_k[1] - top_k[0]) // top_k[2] + 1
topK_list = [topK[0]+d*topK[2] for d in range(num_k)] topK_list = [top_k[0] + d * top_k[2] for d in range(num_k)]
if 1 not in topK_list: if 1 not in topK_list:
topK_list = [1] + topK_list topK_list.insert(0, 1)
def get_eval_metric_ops(graph_output, labels, weights): def get_eval_metric_ops(graph_output, labels, weights):
""" """
@ -105,13 +103,13 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
""" """
eval_metric_ops = OrderedDict() eval_metric_ops = OrderedDict()
if predcol is None: if pred_col is None:
pred = graph_output['output'] pred = graph_output['output']
else: else:
assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !' assert 0 <= pred_col < graph_output['output'].shape[1], 'Invalid Prediction Column Index !'
assert labelcol is not None assert label_col is not None
pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1]) pred = tf.reshape(graph_output['output'][:, pred_col], shape=[-1, 1])
labels = tf.reshape(labels[:, labelcol], shape=[-1, 1]) labels = tf.reshape(labels[:, label_col], shape=[-1, 1])
numOut = graph_output['output'].shape[1] numOut = graph_output['output'].shape[1]
pred_score = tf.reshape(graph_output['output'][:, numOut-1], shape=[-1, 1]) pred_score = tf.reshape(graph_output['output'][:, numOut-1], shape=[-1, 1])
@ -119,8 +117,8 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
for metric_name in metrics: for metric_name in metrics:
metric_name = metric_name.lower() # metric name are case insensitive. metric_name = metric_name.lower() # metric name are case insensitive.
if metric_name in supportedMetrics_op: if metric_name in supported_metrics_op:
metric_factory = supportedMetrics_op.get(metric_name) metric_factory = supported_metrics_op.get(metric_name)
if 'topk' not in metric_name: if 'topk' not in metric_name:
value_op, update_op = metric_factory( value_op, update_op = metric_factory(
@ -150,14 +148,14 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None,
def get_numeric_metric_fn(metrics=None, topK=(5,5,5), predcol=None, labelcol=None): def get_numeric_metric_fn(metrics=None, topK=(5,5,5), pred_col=None, label_col=None):
if metrics is None: if metrics is None:
metrics = list(DEFAULT_NUMERIC_METRICS) metrics = list(DEFAULT_NUMERIC_METRICS)
metrics = list(set(metrics)) metrics = list(set(metrics))
metric_op = get_metric_topK_fn_helper(targetMetrics=list(DEFAULT_NUMERIC_METRICS), metric_op = get_metric_topK_fn_helper(targetMetrics=list(DEFAULT_NUMERIC_METRICS),
supportedMetrics_op=SUPPORTED_NUMERIC_METRICS, supported_metrics_op=SUPPORTED_NUMERIC_METRICS,
metrics=metrics, topK=topK, predcol=predcol, labelcol=labelcol) metrics=metrics, topK=topK, pred_col=pred_col, label_col=label_col)
return metric_op return metric_op
@ -168,16 +166,16 @@ def get_single_binary_task_metric_fn(metrics, classnames, topK=(5,5,5), use_topK
labels: [BatchSz, 2] [Task1, NumericLabel] labels: [BatchSz, 2] [Task1, NumericLabel]
""" """
def get_eval_metric_ops(graph_output, labels, weights): def get_eval_metric_ops(graph_output, labels, weights):
metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames) metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames)
classnames_unw = ['unweighted_'+cs for cs in classnames] classnames_unw = ['unweighted_'+cs for cs in classnames]
metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames_unw) metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames_unw)
metrics_base_res = metric_op_base(graph_output, labels, weights) metrics_base_res = metric_op_base(graph_output, labels, weights)
metrics_unw_res = metric_op_unw(graph_output, labels, None) metrics_unw_res = metric_op_unw(graph_output, labels, None)
metrics_base_res.update(metrics_unw_res) metrics_base_res.update(metrics_unw_res)
if use_topK: if use_topK:
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=0, labelcol=1) metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=0, label_col=1)
metrics_numeric_res = metric_op_numeric(graph_output, labels, weights) metrics_numeric_res = metric_op_numeric(graph_output, labels, weights)
metrics_base_res.update(metrics_numeric_res) metrics_base_res.update(metrics_numeric_res)
return metrics_base_res return metrics_base_res
@ -192,16 +190,16 @@ def get_dual_binary_tasks_metric_fn(metrics, classnames, topK=(5,5,5), use_topK=
""" """
def get_eval_metric_ops(graph_output, labels, weights): def get_eval_metric_ops(graph_output, labels, weights):
metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames) metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames)
classnames_unw = ['unweighted_'+cs for cs in classnames] classnames_unw = ['unweighted_'+cs for cs in classnames]
metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames_unw) metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames_unw)
metrics_base_res = metric_op_base(graph_output, labels, weights) metrics_base_res = metric_op_base(graph_output, labels, weights)
metrics_unw_res = metric_op_unw(graph_output, labels, None) metrics_unw_res = metric_op_unw(graph_output, labels, None)
metrics_base_res.update(metrics_unw_res) metrics_base_res.update(metrics_unw_res)
if use_topK: if use_topK:
metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=2, labelcol=2) metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=2, label_col=2)
metrics_numeric_res = metric_op_numeric(graph_output, labels, weights) metrics_numeric_res = metric_op_numeric(graph_output, labels, weights)
metrics_base_res.update(metrics_numeric_res) metrics_base_res.update(metrics_numeric_res)
return metrics_base_res return metrics_base_res