mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-11-16 00:25:11 +01:00
Added type hinting and code formatting.
This commit is contained in:
parent
90d7ea370e
commit
68351724e9
@ -1,9 +1,12 @@
|
||||
# checkstyle: noqa
|
||||
from typing import Dict, Final
|
||||
from tensorflow import Tensor
|
||||
from twml import DefaultSubcommandArgParse
|
||||
import tensorflow.compat.v1 as tf
|
||||
from .constants import INDEX_BY_LABEL, LABEL_NAMES
|
||||
|
||||
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
|
||||
DEFAULT_WEIGHT_BY_LABEL = {
|
||||
DEFAULT_WEIGHT_BY_LABEL: Final[Dict[str, float]] = {
|
||||
"is_clicked": 0.3,
|
||||
"is_favorited": 1.0,
|
||||
"is_open_linked": 0.1,
|
||||
@ -11,19 +14,19 @@ DEFAULT_WEIGHT_BY_LABEL = {
|
||||
"is_profile_clicked": 1.0,
|
||||
"is_replied": 9.0,
|
||||
"is_retweeted": 1.0,
|
||||
"is_video_playback_50": 0.01
|
||||
"is_video_playback_50": 0.01,
|
||||
}
|
||||
|
||||
def add_weight_arguments(parser):
|
||||
def add_weight_arguments(parser: DefaultSubcommandArgParse) -> DefaultSubcommandArgParse:
|
||||
for label_name in LABEL_NAMES:
|
||||
parser.add_argument(
|
||||
_make_weight_cli_argument_name(label_name),
|
||||
type=float,
|
||||
default=DEFAULT_WEIGHT_BY_LABEL[label_name],
|
||||
dest=_make_weight_param_name(label_name)
|
||||
dest=_make_weight_param_name(label_name),
|
||||
)
|
||||
|
||||
def make_weights_tensor(input_weights, label, params):
|
||||
def make_weights_tensor(input_weights, label, params) -> Tensor:
|
||||
'''
|
||||
Replaces the weights for each positive engagement and keeps the input weights for negative examples.
|
||||
'''
|
||||
@ -32,12 +35,15 @@ def make_weights_tensor(input_weights, label, params):
|
||||
index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name]
|
||||
weight_param_name =_make_weight_param_name(label_name)
|
||||
weight_tensors.append(
|
||||
tf.reshape(tf.math.scalar_mul(getattr(params, weight_param_name) - default_weight, label[:, index]), [-1, 1])
|
||||
tf.reshape(
|
||||
tf.math.scalar_mul(
|
||||
getattr(params, weight_param_name) - default_weight, label[:, index]
|
||||
), [-1, 1])
|
||||
)
|
||||
return tf.math.accumulate_n(weight_tensors)
|
||||
|
||||
def _make_weight_cli_argument_name(label_name):
|
||||
def _make_weight_cli_argument_name(label_name: str) -> str:
|
||||
return f"--weight.{label_name}"
|
||||
|
||||
def _make_weight_param_name(label_name):
|
||||
def _make_weight_param_name(label_name: str) -> str:
|
||||
return f"weight_{label_name}"
|
||||
|
Loading…
Reference in New Issue
Block a user