diff --git a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py index cf0c38ecc..934efa279 100644 --- a/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py +++ b/src/python/twitter/deepbird/projects/timelines/scripts/models/earlybird/example_weights.py @@ -1,9 +1,12 @@ # checkstyle: noqa +from typing import Dict, Final +from tensorflow import Tensor +from twml import DefaultSubcommandArgParse import tensorflow.compat.v1 as tf from .constants import INDEX_BY_LABEL, LABEL_NAMES # TODO: Read these from command line arguments, since they specify the existing example weights in the input data. -DEFAULT_WEIGHT_BY_LABEL = { +DEFAULT_WEIGHT_BY_LABEL: Final[Dict[str, float]] = { "is_clicked": 0.3, "is_favorited": 1.0, "is_open_linked": 0.1, @@ -11,19 +14,19 @@ DEFAULT_WEIGHT_BY_LABEL = { "is_profile_clicked": 1.0, "is_replied": 9.0, "is_retweeted": 1.0, - "is_video_playback_50": 0.01 + "is_video_playback_50": 0.01, } -def add_weight_arguments(parser): +def add_weight_arguments(parser: DefaultSubcommandArgParse) -> DefaultSubcommandArgParse: for label_name in LABEL_NAMES: parser.add_argument( _make_weight_cli_argument_name(label_name), type=float, default=DEFAULT_WEIGHT_BY_LABEL[label_name], - dest=_make_weight_param_name(label_name) + dest=_make_weight_param_name(label_name), ) -def make_weights_tensor(input_weights, label, params): +def make_weights_tensor(input_weights, label, params) -> Tensor: ''' Replaces the weights for each positive engagement and keeps the input weights for negative examples. ''' @@ -32,12 +35,15 @@ def make_weights_tensor(input_weights, label, params): index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name] weight_param_name =_make_weight_param_name(label_name) weight_tensors.append( - tf.reshape(tf.math.scalar_mul(getattr(params, weight_param_name) - default_weight, label[:, index]), [-1, 1]) + tf.reshape( + tf.math.scalar_mul( + getattr(params, weight_param_name) - default_weight, label[:, index] + ), [-1, 1]) ) return tf.math.accumulate_n(weight_tensors) -def _make_weight_cli_argument_name(label_name): +def _make_weight_cli_argument_name(label_name: str) -> str: return f"--weight.{label_name}" -def _make_weight_param_name(label_name): +def _make_weight_param_name(label_name: str) -> str: return f"weight_{label_name}"