This commit is contained in:
babuloseo 2023-07-17 21:42:36 -05:00 committed by GitHub
commit 315aeace47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,5 @@
# checkstyle: noqa # checkstyle: noqa
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from collections import OrderedDict
from .constants import EB_SCORE_IDX from .constants import EB_SCORE_IDX
from .lolly.data_helpers import get_lolly_scores from .lolly.data_helpers import get_lolly_scores
@ -35,7 +34,7 @@ def get_multi_binary_class_metric_fn(metrics, classes=None, class_dim=1):
# Added to support per engagement metrics for both TF and Lolly scores. # Added to support per engagement metrics for both TF and Lolly scores.
labels = tf.tile(labels, [1, 2]) labels = tf.tile(labels, [1, 2])
eval_metric_ops = OrderedDict() eval_metric_ops = dict()
preds = graph_output['output'] preds = graph_output['output']