From 85c472e57df8a64ea9a00c5b9842cd91e1b34d12 Mon Sep 17 00:00:00 2001 From: Sharansrj567 Date: Sat, 1 Apr 2023 14:45:17 +0530 Subject: [PATCH] Refactor twml metrics module for better performance and complexity --- twml/twml/contrib/metrics/metrics.py | 90 ++++++++++++++-------------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/twml/twml/contrib/metrics/metrics.py b/twml/twml/contrib/metrics/metrics.py index dea1a5273..a49ce03cf 100644 --- a/twml/twml/contrib/metrics/metrics.py +++ b/twml/twml/contrib/metrics/metrics.py @@ -20,29 +20,29 @@ from twml.metrics import get_multi_binary_class_metric_fn # checkstyle: noqa -def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, predcols=None): +def get_partial_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1, pred_cols=None): def get_eval_metric_ops(graph_output, labels, weights): - if predcols is None: + if pred_cols is None: preds = graph_output['output'] else: - if isinstance(predcols, int): - predcol_list=[predcols] + if isinstance(pred_cols, int): + pred_col_list=[pred_cols] else: - predcol_list=list(predcols) - for col in predcol_list: + pred_col_list=list(pred_cols) + for col in pred_col_list: assert 0 <= col < graph_output['output'].shape[class_dim], 'Invalid Prediction Column Index !' - preds = tf.gather(graph_output['output'], indices=predcol_list, axis=class_dim) # [batchSz, num_col] - labels = tf.gather(labels, indices=predcol_list, axis=class_dim) # [batchSz, num_col] + preds = tf.gather(graph_output['output'], indices=pred_col_list, axis=class_dim) # [batchSz, num_col] + labels = tf.gather(labels, indices=pred_col_list, axis=class_dim) # [batchSz, num_col] - predInfo = {'output': preds} + pred_info = {'output': preds} if 'threshold' in graph_output: - predInfo['threshold'] = graph_output['threshold'] + pred_info['threshold'] = graph_output['threshold'] if 'hard_output' in graph_output: - predInfo['hard_output'] = graph_output['hard_output'] + pred_info['hard_output'] = graph_output['hard_output'] metrics_op = get_multi_binary_class_metric_fn(metrics, classes, class_dim) - metrics_op_res = metrics_op(predInfo, labels, weights) + metrics_op_res = metrics_op(pred_info, labels, weights) return metrics_op_res return get_eval_metric_ops @@ -68,31 +68,29 @@ DEFAULT_NUMERIC_METRICS = ['mean_numeric_label_topk', 'mean_gated_numeric_label_ -def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, topK=(5,5,5), predcol=None, labelcol=None): +def get_metric_topK_fn_helper(target_metrics, supported_metrics_op, metrics=None, top_k=(5, 5, 5), pred_col=None, label_col=None): """ - :param targetMetrics: Target Metric List - :param supportedMetrics_op: Supported Metric Operators Dict - :param metrics: Metric Set to evaluate - :param topK: (topK_min, topK_max, topK_delta) Tuple - :param predcol: Prediction Column Index - :param labelcol: Label Column Index + :param target_metrics: Target metric list + :param supported_metrics_op: Supported metric operators as a dictionary + :param metrics: Metric set to evaluate + :param top_k: A tuple of (minimum top_k, maximum top_k, top_k delta) + :param pred_col: Prediction column index + :param label_col: Label column index :return: """ - # pylint: disable=dict-keys-not-iterating - if targetMetrics is None or supportedMetrics_op is None: - raise ValueError("Invalid Target Metric List/op !") + if not target_metrics or not supported_metrics_op: + raise ValueError("Invalid target metric list/op!") - targetMetrics = set([m.lower() for m in targetMetrics]) - if metrics is None: - metrics = list(targetMetrics) + target_metrics = {m.lower() for m in target_metrics} + if not metrics: + metrics = list(target_metrics) else: - metrics = [m.lower() for m in metrics if m.lower() in targetMetrics] + metrics = [m.lower() for m in metrics if m.lower() in target_metrics] - num_k = int((topK[1]-topK[0])/topK[2]+1) - topK_list = [topK[0]+d*topK[2] for d in range(num_k)] + num_k = (top_k[1] - top_k[0]) // top_k[2] + 1 + topK_list = [top_k[0] + d * top_k[2] for d in range(num_k)] if 1 not in topK_list: - topK_list = [1] + topK_list - + topK_list.insert(0, 1) def get_eval_metric_ops(graph_output, labels, weights): """ @@ -105,13 +103,13 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, """ eval_metric_ops = OrderedDict() - if predcol is None: + if pred_col is None: pred = graph_output['output'] else: - assert 0 <= predcol < graph_output['output'].shape[1], 'Invalid Prediction Column Index !' - assert labelcol is not None - pred = tf.reshape(graph_output['output'][:, predcol], shape=[-1, 1]) - labels = tf.reshape(labels[:, labelcol], shape=[-1, 1]) + assert 0 <= pred_col < graph_output['output'].shape[1], 'Invalid Prediction Column Index !' + assert label_col is not None + pred = tf.reshape(graph_output['output'][:, pred_col], shape=[-1, 1]) + labels = tf.reshape(labels[:, label_col], shape=[-1, 1]) numOut = graph_output['output'].shape[1] pred_score = tf.reshape(graph_output['output'][:, numOut-1], shape=[-1, 1]) @@ -119,8 +117,8 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, for metric_name in metrics: metric_name = metric_name.lower() # metric name are case insensitive. - if metric_name in supportedMetrics_op: - metric_factory = supportedMetrics_op.get(metric_name) + if metric_name in supported_metrics_op: + metric_factory = supported_metrics_op.get(metric_name) if 'topk' not in metric_name: value_op, update_op = metric_factory( @@ -150,14 +148,14 @@ def get_metric_topK_fn_helper(targetMetrics, supportedMetrics_op, metrics=None, -def get_numeric_metric_fn(metrics=None, topK=(5,5,5), predcol=None, labelcol=None): +def get_numeric_metric_fn(metrics=None, topK=(5,5,5), pred_col=None, label_col=None): if metrics is None: metrics = list(DEFAULT_NUMERIC_METRICS) metrics = list(set(metrics)) metric_op = get_metric_topK_fn_helper(targetMetrics=list(DEFAULT_NUMERIC_METRICS), - supportedMetrics_op=SUPPORTED_NUMERIC_METRICS, - metrics=metrics, topK=topK, predcol=predcol, labelcol=labelcol) + supported_metrics_op=SUPPORTED_NUMERIC_METRICS, + metrics=metrics, topK=topK, pred_col=pred_col, label_col=label_col) return metric_op @@ -168,16 +166,16 @@ def get_single_binary_task_metric_fn(metrics, classnames, topK=(5,5,5), use_topK labels: [BatchSz, 2] [Task1, NumericLabel] """ def get_eval_metric_ops(graph_output, labels, weights): - metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames) + metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames) classnames_unw = ['unweighted_'+cs for cs in classnames] - metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=0, classes=classnames_unw) + metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=0, classes=classnames_unw) metrics_base_res = metric_op_base(graph_output, labels, weights) metrics_unw_res = metric_op_unw(graph_output, labels, None) metrics_base_res.update(metrics_unw_res) if use_topK: - metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=0, labelcol=1) + metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=0, label_col=1) metrics_numeric_res = metric_op_numeric(graph_output, labels, weights) metrics_base_res.update(metrics_numeric_res) return metrics_base_res @@ -192,16 +190,16 @@ def get_dual_binary_tasks_metric_fn(metrics, classnames, topK=(5,5,5), use_topK= """ def get_eval_metric_ops(graph_output, labels, weights): - metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames) + metric_op_base = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames) classnames_unw = ['unweighted_'+cs for cs in classnames] - metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, predcols=[0, 1], classes=classnames_unw) + metric_op_unw = get_partial_multi_binary_class_metric_fn(metrics, pred_cols=[0, 1], classes=classnames_unw) metrics_base_res = metric_op_base(graph_output, labels, weights) metrics_unw_res = metric_op_unw(graph_output, labels, None) metrics_base_res.update(metrics_unw_res) if use_topK: - metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, predcol=2, labelcol=2) + metric_op_numeric = get_numeric_metric_fn(metrics=None, topK=topK, pred_col=2, label_col=2) metrics_numeric_res = metric_op_numeric(graph_output, labels, weights) metrics_base_res.update(metrics_numeric_res) return metrics_base_res