Added type hinting and code formatting.

This commit is contained in:
Isabelle 2023-05-04 16:48:56 +03:00 committed by GitHub
parent 90d7ea370e
commit 68351724e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,12 @@
# checkstyle: noqa # checkstyle: noqa
from typing import Dict, Final
from tensorflow import Tensor
from twml import DefaultSubcommandArgParse
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from .constants import INDEX_BY_LABEL, LABEL_NAMES from .constants import INDEX_BY_LABEL, LABEL_NAMES
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data. # TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
DEFAULT_WEIGHT_BY_LABEL = { DEFAULT_WEIGHT_BY_LABEL: Final[Dict[str, float]] = {
"is_clicked": 0.3, "is_clicked": 0.3,
"is_favorited": 1.0, "is_favorited": 1.0,
"is_open_linked": 0.1, "is_open_linked": 0.1,
@ -11,19 +14,19 @@ DEFAULT_WEIGHT_BY_LABEL = {
"is_profile_clicked": 1.0, "is_profile_clicked": 1.0,
"is_replied": 9.0, "is_replied": 9.0,
"is_retweeted": 1.0, "is_retweeted": 1.0,
"is_video_playback_50": 0.01 "is_video_playback_50": 0.01,
} }
def add_weight_arguments(parser): def add_weight_arguments(parser: DefaultSubcommandArgParse) -> DefaultSubcommandArgParse:
for label_name in LABEL_NAMES: for label_name in LABEL_NAMES:
parser.add_argument( parser.add_argument(
_make_weight_cli_argument_name(label_name), _make_weight_cli_argument_name(label_name),
type=float, type=float,
default=DEFAULT_WEIGHT_BY_LABEL[label_name], default=DEFAULT_WEIGHT_BY_LABEL[label_name],
dest=_make_weight_param_name(label_name) dest=_make_weight_param_name(label_name),
) )
def make_weights_tensor(input_weights, label, params): def make_weights_tensor(input_weights, label, params) -> Tensor:
''' '''
Replaces the weights for each positive engagement and keeps the input weights for negative examples. Replaces the weights for each positive engagement and keeps the input weights for negative examples.
''' '''
@ -32,12 +35,15 @@ def make_weights_tensor(input_weights, label, params):
index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name] index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name]
weight_param_name =_make_weight_param_name(label_name) weight_param_name =_make_weight_param_name(label_name)
weight_tensors.append( weight_tensors.append(
tf.reshape(tf.math.scalar_mul(getattr(params, weight_param_name) - default_weight, label[:, index]), [-1, 1]) tf.reshape(
tf.math.scalar_mul(
getattr(params, weight_param_name) - default_weight, label[:, index]
), [-1, 1])
) )
return tf.math.accumulate_n(weight_tensors) return tf.math.accumulate_n(weight_tensors)
def _make_weight_cli_argument_name(label_name): def _make_weight_cli_argument_name(label_name: str) -> str:
return f"--weight.{label_name}" return f"--weight.{label_name}"
def _make_weight_param_name(label_name): def _make_weight_param_name(label_name: str) -> str:
return f"weight_{label_name}" return f"weight_{label_name}"