Added type hinting and code formatting.

This commit is contained in:
Isabelle 2023-05-04 16:48:56 +03:00 committed by GitHub
parent 90d7ea370e
commit 68351724e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,12 @@
# checkstyle: noqa
from typing import Dict, Final
from tensorflow import Tensor
from twml import DefaultSubcommandArgParse
import tensorflow.compat.v1 as tf
from .constants import INDEX_BY_LABEL, LABEL_NAMES
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
DEFAULT_WEIGHT_BY_LABEL = {
DEFAULT_WEIGHT_BY_LABEL: Final[Dict[str, float]] = {
"is_clicked": 0.3,
"is_favorited": 1.0,
"is_open_linked": 0.1,
@ -11,19 +14,19 @@ DEFAULT_WEIGHT_BY_LABEL = {
"is_profile_clicked": 1.0,
"is_replied": 9.0,
"is_retweeted": 1.0,
"is_video_playback_50": 0.01
"is_video_playback_50": 0.01,
}
def add_weight_arguments(parser):
def add_weight_arguments(parser: DefaultSubcommandArgParse) -> DefaultSubcommandArgParse:
for label_name in LABEL_NAMES:
parser.add_argument(
_make_weight_cli_argument_name(label_name),
type=float,
default=DEFAULT_WEIGHT_BY_LABEL[label_name],
dest=_make_weight_param_name(label_name)
dest=_make_weight_param_name(label_name),
)
def make_weights_tensor(input_weights, label, params):
def make_weights_tensor(input_weights, label, params) -> Tensor:
'''
Replaces the weights for each positive engagement and keeps the input weights for negative examples.
'''
@ -32,12 +35,15 @@ def make_weights_tensor(input_weights, label, params):
index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name]
weight_param_name =_make_weight_param_name(label_name)
weight_tensors.append(
tf.reshape(tf.math.scalar_mul(getattr(params, weight_param_name) - default_weight, label[:, index]), [-1, 1])
tf.reshape(
tf.math.scalar_mul(
getattr(params, weight_param_name) - default_weight, label[:, index]
), [-1, 1])
)
return tf.math.accumulate_n(weight_tensors)
def _make_weight_cli_argument_name(label_name):
def _make_weight_cli_argument_name(label_name: str) -> str:
return f"--weight.{label_name}"
def _make_weight_param_name(label_name):
def _make_weight_param_name(label_name: str) -> str:
return f"weight_{label_name}"