mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-09-27 17:38:43 +02:00
Added type hinting and code formatting.
This commit is contained in:
parent
90d7ea370e
commit
68351724e9
@ -1,9 +1,12 @@
|
|||||||
# checkstyle: noqa
|
# checkstyle: noqa
|
||||||
|
from typing import Dict, Final
|
||||||
|
from tensorflow import Tensor
|
||||||
|
from twml import DefaultSubcommandArgParse
|
||||||
import tensorflow.compat.v1 as tf
|
import tensorflow.compat.v1 as tf
|
||||||
from .constants import INDEX_BY_LABEL, LABEL_NAMES
|
from .constants import INDEX_BY_LABEL, LABEL_NAMES
|
||||||
|
|
||||||
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
|
# TODO: Read these from command line arguments, since they specify the existing example weights in the input data.
|
||||||
DEFAULT_WEIGHT_BY_LABEL = {
|
DEFAULT_WEIGHT_BY_LABEL: Final[Dict[str, float]] = {
|
||||||
"is_clicked": 0.3,
|
"is_clicked": 0.3,
|
||||||
"is_favorited": 1.0,
|
"is_favorited": 1.0,
|
||||||
"is_open_linked": 0.1,
|
"is_open_linked": 0.1,
|
||||||
@ -11,19 +14,19 @@ DEFAULT_WEIGHT_BY_LABEL = {
|
|||||||
"is_profile_clicked": 1.0,
|
"is_profile_clicked": 1.0,
|
||||||
"is_replied": 9.0,
|
"is_replied": 9.0,
|
||||||
"is_retweeted": 1.0,
|
"is_retweeted": 1.0,
|
||||||
"is_video_playback_50": 0.01
|
"is_video_playback_50": 0.01,
|
||||||
}
|
}
|
||||||
|
|
||||||
def add_weight_arguments(parser):
|
def add_weight_arguments(parser: DefaultSubcommandArgParse) -> DefaultSubcommandArgParse:
|
||||||
for label_name in LABEL_NAMES:
|
for label_name in LABEL_NAMES:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
_make_weight_cli_argument_name(label_name),
|
_make_weight_cli_argument_name(label_name),
|
||||||
type=float,
|
type=float,
|
||||||
default=DEFAULT_WEIGHT_BY_LABEL[label_name],
|
default=DEFAULT_WEIGHT_BY_LABEL[label_name],
|
||||||
dest=_make_weight_param_name(label_name)
|
dest=_make_weight_param_name(label_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_weights_tensor(input_weights, label, params):
|
def make_weights_tensor(input_weights, label, params) -> Tensor:
|
||||||
'''
|
'''
|
||||||
Replaces the weights for each positive engagement and keeps the input weights for negative examples.
|
Replaces the weights for each positive engagement and keeps the input weights for negative examples.
|
||||||
'''
|
'''
|
||||||
@ -32,12 +35,15 @@ def make_weights_tensor(input_weights, label, params):
|
|||||||
index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name]
|
index, default_weight = INDEX_BY_LABEL[label_name], DEFAULT_WEIGHT_BY_LABEL[label_name]
|
||||||
weight_param_name =_make_weight_param_name(label_name)
|
weight_param_name =_make_weight_param_name(label_name)
|
||||||
weight_tensors.append(
|
weight_tensors.append(
|
||||||
tf.reshape(tf.math.scalar_mul(getattr(params, weight_param_name) - default_weight, label[:, index]), [-1, 1])
|
tf.reshape(
|
||||||
|
tf.math.scalar_mul(
|
||||||
|
getattr(params, weight_param_name) - default_weight, label[:, index]
|
||||||
|
), [-1, 1])
|
||||||
)
|
)
|
||||||
return tf.math.accumulate_n(weight_tensors)
|
return tf.math.accumulate_n(weight_tensors)
|
||||||
|
|
||||||
def _make_weight_cli_argument_name(label_name):
|
def _make_weight_cli_argument_name(label_name: str) -> str:
|
||||||
return f"--weight.{label_name}"
|
return f"--weight.{label_name}"
|
||||||
|
|
||||||
def _make_weight_param_name(label_name):
|
def _make_weight_param_name(label_name: str) -> str:
|
||||||
return f"weight_{label_name}"
|
return f"weight_{label_name}"
|
||||||
|
Loading…
Reference in New Issue
Block a user