Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings

twitter-team
2023-03-31 13:05:14 -05:00
commit 78c3235eee
111 changed files with 11876 additions and 0 deletions

File diff suppressed because it is too large

@@ -0,0 +1,48 @@
This project is the "heavy ranker" used on the "For You" timeline. It generates the ranking of Tweets after candidate retrieval and the light ranker (note that the final ordering of Tweets is not simply highest -> lowest score, because other heuristics are applied after scoring).
This model captures the ranking model used for the majority of users of the Twitter "For You" timeline in early March 2023. Because this code needs to run independently of other parts of the Twitter codebase, there may be small differences from the production model.
The model receives various features describing the Tweet and the user whose timeline is being constructed as input (see FEATURES.md for more details). The model outputs multiple binary predictions about how the user will respond if shown the Tweet.
These are:
"recap.engagement.is_favorited": The probability the user will favorite the Tweet.
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": The probability the user will click into the conversation of this Tweet and reply or Like a Tweet.
"recap.engagement.is_good_clicked_convo_desc_v2": The probability the user will click into the conversation of this Tweet and stay there for at least 2 minutes.
"recap.engagement.is_negative_feedback_v2": The probability the user will react negatively (requesting "show less often" on the Tweet or author, block or mute the Tweet author)
"recap.engagement.is_profile_clicked_and_profile_engaged": The probability the user opens the Tweet author profile and Likes or replies to a Tweet.
"recap.engagement.is_replied": The probability the user replies to the Tweet.
"recap.engagement.is_replied_reply_engaged_by_author": The probability the user replies to the Tweet and this reply is engaged by the Tweet author.
"recap.engagement.is_report_tweet_clicked": The probability the user will click Report Tweet.
"recap.engagement.is_retweeted": The probability the user will ReTweet the Tweet.
"recap.engagement.is_video_playback_50": The probability (for a video Tweet) that the user will watch at least half of the video
For ranking the candidates, these predictions are combined into a score by weighting them (see the sketch after this list):
"recap.engagement.is_favorited": 0.5
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": 11* (the maximum prediction from these two "good click" features is used and weighted by 11, the other prediction is ignored).
"recap.engagement.is_good_clicked_convo_desc_v2": 11*
"recap.engagement.is_negative_feedback_v2": -74
"recap.engagement.is_profile_clicked_and_profile_engaged": 12
"recap.engagement.is_replied": 27
"recap.engagement.is_replied_reply_engaged_by_author": 75
"recap.engagement.is_report_tweet_clicked": -369
"recap.engagement.is_retweeted": 1
"recap.engagement.is_video_playback_50": 0.005
We cannot release the real training data due to privacy restrictions. However, we have included a script to generate random data to ensure you can run the model training code.
To try training the model (assuming you have already followed the repo setup instructions and are inside a virtualenv), run
$ ./projects/home/recap/scripts/create_random_data.sh
This will create some random data (in $HOME/tmp/recap_local_random_data). Then run
$ ./projects/home/recap/scripts/run_local.sh
This will train the model (for a small number of iterations). Checkpoints and logs will be written to $HOME/tmp/runs/recap_local_debug.
The model training is configured through a yaml file (./projects/home/recap/config/local_prod.yaml).
The model architecture is a parallel MaskNet (https://arxiv.org/abs/2102.07619).
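For readers unfamiliar with MaskNet, below is a condensed PyTorch sketch of the parallel variant. It is not the tml implementation; the block sizes mirror the mask_blocks/mlp settings in local_prod.yaml further down, and all module and parameter names are illustrative.

import torch
import torch.nn as nn

class MaskBlock(nn.Module):
    """One MaskBlock: an instance-guided mask (a small two-layer MLP over the input)
    is multiplied element-wise into the input, followed by a hidden layer with
    LayerNorm and ReLU (see the MaskNet paper linked above)."""
    def __init__(self, input_dim: int, aggregation_size: int, output_size: int):
        super().__init__()
        self.mask = nn.Sequential(
            nn.Linear(input_dim, aggregation_size),  # "aggregation" layer
            nn.ReLU(),
            nn.Linear(aggregation_size, input_dim),  # "projection" layer
        )
        self.hidden = nn.Linear(input_dim, output_size)
        self.layer_norm = nn.LayerNorm(output_size)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        masked = features * self.mask(features)  # feature-wise multiplication
        return torch.relu(self.layer_norm(self.hidden(masked)))

class ParallelMaskNet(nn.Module):
    """Several MaskBlocks applied to the same input in parallel, concatenated, then an MLP."""
    def __init__(self, input_dim: int, num_blocks: int = 4,
                 aggregation_size: int = 1024, block_output: int = 1024,
                 mlp_hidden: int = 2048):
        super().__init__()
        self.blocks = nn.ModuleList(
            [MaskBlock(input_dim, aggregation_size, block_output) for _ in range(num_blocks)]
        )
        self.mlp = nn.Sequential(nn.Linear(num_blocks * block_output, mlp_hidden), nn.ReLU())

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.mlp(torch.cat([block(features) for block in self.blocks], dim=-1))

In local_prod.yaml below, four such blocks (aggregation_size and output_size 1024) feed a shared 2048-unit MLP, and each engagement task attaches its own small MLP head (256 -> 128 -> 1) to that shared representation.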


@@ -0,0 +1,54 @@
from tml.core import config as config_mod
import tml.projects.home.recap.data.config as data_config
import tml.projects.home.recap.model.config as model_config
import tml.projects.home.recap.optimizer.config as optimizer_config
from enum import Enum
from typing import Dict, Optional
import pydantic
class TrainingConfig(config_mod.BaseConfig):
save_dir: str = "/tmp/model"
num_train_steps: pydantic.PositiveInt = 1000000
initial_checkpoint_dir: str = pydantic.Field(
None, description="Directory of initial checkpoints", at_most_one_of="initialization"
)
checkpoint_every_n: pydantic.PositiveInt = 1000
checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
)
train_log_every_n: pydantic.PositiveInt = 1000
num_eval_steps: int = pydantic.Field(
16384, description="Number of evaluation steps. If < 0 the entire dataset " "will be used."
)
eval_log_every_n: pydantic.PositiveInt = 5000
eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
gradient_accumulation: int = pydantic.Field(
None, description="Number of replica steps to accumulate gradients."
)
class RecapConfig(config_mod.BaseConfig):
training: TrainingConfig = pydantic.Field(TrainingConfig())
model: model_config.ModelConfig
train_data: data_config.RecapDataConfig
validation_data: Dict[str, data_config.RecapDataConfig]
optimizer: optimizer_config.RecapOptimizerConfig
which_metrics: Optional[str] = pydantic.Field(None, description="which metrics to pick.")
# DANGER DANGER! You might expect validators here to ensure that multi task learning setups are
# the same as the data. Unfortunately, this throws opaque errors when the model configuration is
# invalid. In our judgement, that is a more frequent and worse occurrence than tasks not matching
# the data.
class JobMode(str, Enum):
"""Job modes."""
TRAIN = "train"
EVALUATE = "evaluate"
INFERENCE = "inference"


@@ -0,0 +1,114 @@
{
"schema": [
{
"dtype": "int64_list",
"feature_name": "home_recap_2022_discrete__segdense_vals",
"length": 320
},
{
"dtype": "float_list",
"feature_name": "home_recap_2022_cont__segdense_vals",
"length": 6000
},
{
"dtype": "int64_list",
"feature_name": "home_recap_2022_binary__segdense_vals",
"length": 512
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_tweet_detail_dwelled_15_sec",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_profile_clicked_and_profile_engaged",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_replied_reply_engaged_by_author",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_video_playback_50",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_report_tweet_clicked",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_replied",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "meta.author_id",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_negative_feedback_v2",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_retweeted",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_favorited",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "meta.tweet_id",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_good_clicked_convo_desc_v2",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "meta.user_id",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_bookmarked",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_shared",
"length": 1
},
{
"dtype": "float_list",
"feature_name": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings",
"length": 200
},
{
"dtype": "float_list",
"feature_name": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings",
"length": 200
},
{
"dtype": "float_list",
"feature_name": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings",
"length": 200
}
]
}


@@ -0,0 +1,477 @@
training:
num_train_steps: 10
num_eval_steps: 5
checkpoint_every_n: 5
train_log_every_n: 1
eval_log_every_n: 1
save_dir: ${HOME}/tmp/runs/recap_local_debug
eval_timeout_in_s: 7200
model:
backbone:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config:
mask_blocks:
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
mlp:
batch_norm: null
dropout: null
final_layer_activation: true
layer_sizes:
- 2048
use_parallel: true
mlp_config: null
pos_weight: 1.0
featurization_config:
clip_log1p_abs_config: null
double_norm_log_config:
batch_norm_config:
affine: true
momentum: 0.01
clip_magnitude: 5.0
layer_norm_config:
axis: -1
center: true
epsilon: 0.0
scale: true
feature_names_to_concat:
- binary
log1p_abs_config: null
z_score_log_config: null
large_embeddings: null
multi_task_type: share_all
position_debias_config: null
small_embeddings: null
stratifiers: null
tasks:
recap.engagement.is_favorited:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_good_clicked_convo_desc_v2:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_negative_feedback_v2:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout:
rate: 0.1
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_profile_clicked_and_profile_engaged:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_replied:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_replied_reply_engaged_by_author:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_report_tweet_clicked:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout:
rate: 0.2
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_retweeted:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_video_playback_50:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
train_data:
global_batch_size: 128
dataset_service_compression: AUTO
inputs: &data_root "${HOME}/tmp/recap_local_random_data/*.gz"
seg_dense_schema: &seg_dense_schema
schema_path: "${TML_BASE}/projects/home/recap/config/home_recap_2022/segdense.json"
renamed_features:
"continuous": "home_recap_2022_cont__segdense_vals"
"binary": "home_recap_2022_binary__segdense_vals"
"discrete": "home_recap_2022_discrete__segdense_vals"
"author_embedding": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings"
"user_embedding": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings"
"user_eng_embedding": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings"
"meta__author_id": "meta.author_id"
"meta__user_id": "meta.user_id"
"meta__tweet_id": "meta.tweet_id"
tasks: &data_tasks
"recap.engagement.is_bookmarked": {}
"recap.engagement.is_favorited": {}
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": {}
"recap.engagement.is_good_clicked_convo_desc_v2": {}
"recap.engagement.is_negative_feedback_v2": {}
"recap.engagement.is_profile_clicked_and_profile_engaged": {}
"recap.engagement.is_replied": {}
"recap.engagement.is_replied_reply_engaged_by_author": {}
"recap.engagement.is_report_tweet_clicked": {}
"recap.engagement.is_retweeted": {}
"recap.engagement.is_shared": {}
"recap.engagement.is_tweet_detail_dwelled_15_sec": {}
"recap.engagement.is_video_playback_50": {}
preprocess: &preprocess
truncate_and_slice:
continuous_feature_truncation: 2117
binary_feature_truncation: 59
validation_data:
validation: &validation
global_batch_size: &eval_batch_size 128
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks: *data_tasks
preprocess: *preprocess
train:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks: *data_tasks
preprocess: *preprocess
recap.engagement.is_favorited:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_favorited":
pos_downsampling_rate: 0.8387
neg_downsampling_rate: 0.01
evaluation_tasks:
- "recap.engagement.is_favorited"
preprocess: *preprocess
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied":
pos_downsampling_rate: 0.9164
neg_downsampling_rate: 0.00195
evaluation_tasks:
- "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied"
preprocess: *preprocess
recap.engagement.is_good_clicked_convo_desc_v2:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_good_clicked_convo_desc_v2":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.00174
evaluation_tasks:
- "recap.engagement.is_good_clicked_convo_desc_v2"
preprocess: *preprocess
recap.engagement.is_negative_feedback_v2:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_negative_feedback_v2":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.00280
evaluation_tasks:
- "recap.engagement.is_negative_feedback_v2"
preprocess: *preprocess
recap.engagement.is_profile_clicked_and_profile_engaged:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_profile_clicked_and_profile_engaged":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.0015
evaluation_tasks:
- "recap.engagement.is_profile_clicked_and_profile_engaged"
preprocess: *preprocess
recap.engagement.is_replied:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_replied":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.005
evaluation_tasks:
- "recap.engagement.is_replied"
preprocess: *preprocess
recap.engagement.is_replied_reply_engaged_by_author:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_replied_reply_engaged_by_author":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.001
evaluation_tasks:
- "recap.engagement.is_replied_reply_engaged_by_author"
preprocess: *preprocess
recap.engagement.is_report_tweet_clicked:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_report_tweet_clicked":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.000014
evaluation_tasks:
- "recap.engagement.is_report_tweet_clicked"
preprocess: *preprocess
recap.engagement.is_retweeted:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_retweeted":
pos_downsampling_rate: 0.9561
neg_downsampling_rate: 0.004
evaluation_tasks:
- "recap.engagement.is_retweeted"
preprocess: *preprocess
recap.engagement.is_video_playback_50:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_video_playback_50":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.00427
evaluation_tasks:
- "recap.engagement.is_video_playback_50"
preprocess: *preprocess
optimizer:
adam:
beta_1: 0.95
beta_2: 0.999
epsilon: 1.0e-07
multi_task_learning_rates:
backbone_learning_rate:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
tower_learning_rates:
recap.engagement.is_favorited:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0008
num_ramp_steps: 5000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 2000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_good_clicked_convo_desc_v2:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0002
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_negative_feedback_v2:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0005
num_ramp_steps: 5000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_profile_clicked_and_profile_engaged:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0003
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_replied:
constant: null
linear_ramp_to_constant:
learning_rate: 0.001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_replied_reply_engaged_by_author:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_report_tweet_clicked:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 3000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_retweeted:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_video_playback_50:
constant: null
linear_ramp_to_constant:
learning_rate: 0.003
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
single_task_learning_rate: null


@@ -0,0 +1,246 @@
import typing
from enum import Enum
from tml.core import config as base_config
import pydantic
class ExplicitDateInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_date and days of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
num_missing_days_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_days_tol days of missing data."
)
class ExplicitDatetimeInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_datetime and hours of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
num_missing_hours_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
)
class DdsCompressionOption(str, Enum):
"""The only valid compression option is 'AUTO'"""
AUTO = "AUTO"
class DatasetConfig(base_config.BaseConfig):
inputs: str = pydantic.Field(
None, description="A glob for selecting data.", one_of="date_inputs_format"
)
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
None, one_of="date_inputs_format"
)
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
global_batch_size: pydantic.PositiveInt
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Number of shards to keep."
)
repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to."
)
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
cache: bool = pydantic.Field(
False,
description="Cache dataset in memory. Careful to only use this when you"
" have enough memory to fit entire dataset.",
)
data_service_dispatcher: str = pydantic.Field(None)
ignore_data_errors: bool = pydantic.Field(
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
)
dataset_service_compression: DdsCompressionOption = pydantic.Field(
None,
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
)
# tf.data.Dataset options
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of parallel calls."
)
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of shards to interleave."
)
class TruncateAndSlice(base_config.BaseConfig):
# Apply truncation and then slice.
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
None, description="Experimental. Truncates continuous features to this amount for efficiency."
)
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
None, description="Experimental. Truncates binary features to this amount for efficiency."
)
continuous_feature_mask_path: str = pydantic.Field(
None, description="Path of mask used to slice input continuous features."
)
binary_feature_mask_path: str = pydantic.Field(
None, description="Path of mask used to slice input binary features."
)
class DataType(str, Enum):
BFLOAT16 = "bfloat16"
BOOL = "bool"
FLOAT32 = "float32"
FLOAT16 = "float16"
UINT8 = "uint8"
class DownCast(base_config.BaseConfig):
# Apply down casting to selected features.
features: typing.Dict[str, DataType] = pydantic.Field(
None, description="Map features to down cast data types."
)
class TaskData(base_config.BaseConfig):
pos_downsampling_rate: float = pydantic.Field(
1.0,
description="Downsampling rate of positives used to generate dataset.",
)
neg_downsampling_rate: float = pydantic.Field(
1.0,
description="Downsampling rate of negatives used to generate dataset.",
)
class SegDenseSchema(base_config.BaseConfig):
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
features: typing.List[str] = pydantic.Field(
[],
description="List of features (in addition to the renamed features) to read from schema path above.",
)
renamed_features: typing.Dict[str, str] = pydantic.Field(
{}, description="Dictionary of renamed features."
)
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
{},
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
)
class RectifyLabels(base_config.BaseConfig):
label_rectification_window_in_hours: float = pydantic.Field(
3.0, description="overlap time in hours for which to flip labels"
)
served_timestamp_field: str = pydantic.Field(
..., description="input field corresponding to served time"
)
impressed_timestamp_field: str = pydantic.Field(
..., description="input field corresponding to impressed time"
)
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
..., description="label to the input field corresponding to engagement time"
)
class ExtractFeaturesRow(base_config.BaseConfig):
name: str = pydantic.Field(
...,
description="name of the new field name to be created",
)
source_tensor: str = pydantic.Field(
...,
description="name of the dense tensor to look for the feature",
)
index: int = pydantic.Field(
...,
description="index of the feature in the dense tensor",
)
class ExtractFeatures(base_config.BaseConfig):
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
[],
description="list of features to be extracted with their name, source tensor and index",
)
class DownsampleNegatives(base_config.BaseConfig):
batch_multiplier: int = pydantic.Field(
None,
description="batch multiplier",
)
engagements_list: typing.List[str] = pydantic.Field(
[],
description="engagements with kept positives",
)
num_engagements: int = pydantic.Field(
...,
description="number engagements used in the model, including ones excluded in engagements_list",
)
class Preprocess(base_config.BaseConfig):
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
rectify_labels: RectifyLabels = pydantic.Field(
None, description="Rectify labels for a given overlap window"
)
extract_features: ExtractFeatures = pydantic.Field(
None, description="Extract features from dense tensors."
)
downsample_negatives: DownsampleNegatives = pydantic.Field(
None, description="Downsample negatives."
)
class Sampler(base_config.BaseConfig):
"""Assumes function is defined in data/samplers.py.
Only use this for quick experimentation.
If samplers are useful, we should sample from upstream data generation.
DEPRECATED, DO NOT USE.
"""
name: str
kwargs: typing.Dict
class RecapDataConfig(DatasetConfig):
seg_dense_schema: SegDenseSchema
tasks: typing.Dict[str, TaskData] = pydantic.Field(
description="Description of individual tasks in this dataset."
)
evaluation_tasks: typing.List[str] = pydantic.Field(
[], description="If specified, lists the tasks we're generating metrics for."
)
preprocess: Preprocess = pydantic.Field(
None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
)
sampler: Sampler = pydantic.Field(
None,
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
)
@pydantic.root_validator()
def _validate_evaluation_tasks(cls, values):
if values.get("evaluation_tasks") is not None:
for task in values["evaluation_tasks"]:
if task not in values["tasks"]:
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
return values
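The pos_downsampling_rate/neg_downsampling_rate fields in TaskData above are consumed by the dataset code (dataset.py below) to re-weight examples so that losses and metrics reflect the original, un-downsampled distribution. A small worked sketch, using the is_favorited rates from local_prod.yaml as an example:

# Inverse-propensity weighting: an example kept from a downsampled stream is
# up-weighted by 1 / downsampling_rate for its class.
pos_downsampling_rate = 0.8387
neg_downsampling_rate = 0.01

def example_weight(label: float) -> float:
    return label / pos_downsampling_rate + (1.0 - label) / neg_downsampling_rate

print(example_weight(1.0))  # ~1.19: positives were only mildly down-sampled
print(example_weight(0.0))  # 100.0: each kept negative stands in for ~100 dropped negatives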


@@ -0,0 +1,477 @@
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Dict
import functools
import os
from datetime import datetime, timedelta
import torch
import tensorflow as tf
from tml.common.batch import DataclassBatch
from tml.projects.home.recap.data.config import RecapDataConfig, TaskData
from tml.projects.home.recap.data import preprocessors
from tml.projects.home.recap.config import JobMode
from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn
from tml.projects.home.recap.data.util import (
keyed_jagged_tensor_from_tensors_dict,
sparse_or_dense_tf_to_torch,
)
from absl import logging
import torch.distributed as dist
@dataclass
class RecapBatch(DataclassBatch):
"""Holds features and labels from the Recap dataset."""
continuous_features: torch.Tensor
binary_features: torch.Tensor
discrete_features: torch.Tensor
sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821
labels: torch.Tensor
user_embedding: torch.Tensor = None
user_eng_embedding: torch.Tensor = None
author_embedding: torch.Tensor = None
weights: torch.Tensor = None
def __post_init__(self):
if self.weights is None:
self.weights = torch.ones_like(self.labels)
for feature_name, feature_value in self.as_dict().items():
if ("embedding" in feature_name) and (feature_value is None):
setattr(self, feature_name, torch.empty([0, 0]))
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
"""Converts a torch data loader output into `RecapBatch`."""
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
try:
features_in, labels = x
except ValueError:
# For Mode.INFERENCE, we do not expect to receive labels as part of the input tuple
features_in, labels = x, None
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
if sparse_feature_names:
sparse_features = keyed_jagged_tensor_from_tensors_dict(
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
)
user_embedding, user_eng_embedding, author_embedding = None, None, None
if "user_embedding" in features_in:
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
user_embedding = features_in["user_embedding"]
if "user_eng_embedding" in features_in:
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
user_eng_embedding = features_in["user_eng_embedding"]
if "author_embedding" in features_in:
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
author_embedding = features_in["author_embedding"]
return RecapBatch(
continuous_features=features_in["continuous"],
binary_features=features_in["binary"],
discrete_features=features_in["discrete"],
sparse_features=sparse_features,
user_embedding=user_embedding,
user_eng_embedding=user_eng_embedding,
author_embedding=author_embedding,
labels=labels,
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
)
def _chain(param, f1, f2):
"""
Reduce multiple functions into one chained function
_chain(x, f1, f2) -> f2(f1(x))
"""
output = param
fns = [f1, f2]
for f in fns:
output = f(output)
return output
def _add_weights(inputs, tasks: Dict[str, TaskData]):
"""Adds weights based on label sampling for positive and negatives.
This is useful for numeric calibration etc. This mutates inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
return inputs
def get_datetimes(explicit_datetime_inputs):
"""Compute list datetime strings for train/validation data."""
datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
dates = sorted(
[
(end - timedelta(hours=i + 1)).strftime(datetime_format)
for i in range(int(explicit_datetime_inputs.hours))
]
)
return dates
def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
"""
Compile list of files for training/validation.
Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
For each hour of data, if the directory is missing or empty, we increment a counter to keep
track of the number of missing data hours.
Returns only files with a `.gz` extension.
Args:
explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object
Returns:
data_files: Sorted list of files to read corresponding to data at the desired datetimes
num_hours_missing: Number of hours that we are missing data
"""
datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
data_files = []
num_hours_missing = 0
for folder in folders:
try:
files = tf.io.gfile.listdir(folder)
if not files:
logging.warning(f"{folder} contained no data files")
num_hours_missing += 1
data_files.extend(
[
os.path.join(folder, filename)
for filename in files
if filename.rsplit(".", 1)[-1].lower() == "gz"
]
)
except tf.errors.NotFoundError as e:
num_hours_missing += 1
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
return sorted(data_files), num_hours_missing
def _map_output_for_inference(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
if preprocessor:
raise ValueError("No preprocessor should be used at inference time.")
if add_weights:
raise NotImplementedError()
# Add zero weights.
inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
for label in tasks:
del inputs[label]
return inputs
def _map_output_for_train_eval(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
if add_weights:
inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
# Warning this has to happen first as it changes the input
if preprocessor:
inputs = preprocessor(inputs)
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
for label in tasks:
del inputs[label]
return inputs, label_values
def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
"""Adds weights based on label sampling for positive and negatives.
This is useful for numeric calibration etc. This mutates inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
return inputs
class RecapDataset(torch.utils.data.IterableDataset):
def __init__(
self,
data_config: RecapDataConfig,
dataset_service: Optional[str] = None,
mode: JobMode = JobMode.TRAIN,
compression: Optional[str] = "AUTO",
repeat: bool = False,
vocab_mapper: tf.keras.Model = None,
):
logging.info("***** Labels *****")
logging.info(list(data_config.tasks.keys()))
self._data_config = data_config
self._parse_fn = get_seg_dense_parse_fn(data_config)
self._mode = mode
self._repeat = repeat
self._num_concurrent_iterators = 1
self._vocab_mapper = vocab_mapper
self.dataset_service = dataset_service
self.compression = compression
preprocessor = None
self._batch_size_multiplier = 1
if data_config.preprocess:
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
if data_config.preprocess.downsample_negatives:
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
self._preprocessor = preprocessor
if mode == JobMode.INFERENCE:
if preprocessor is not None:
raise ValueError("Expect no preprocessor at inference time.")
should_add_weights = False
output_map_fn = _map_output_for_inference # (features,)
else:
# Only add weights if there is a reason to! If all weights will
# be equal to 1.0, save bandwidth between DDS and Chief by simply
# relying on the fact that weights default to 1.0 in `RecapBatch`
# WARNING: Weights may still be added as a side effect of a preprocessor
# such as `DownsampleNegatives`.
should_add_weights = any(
[
task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0
for task_cfg in data_config.tasks.values()
]
)
output_map_fn = _map_output_for_train_eval # (features, labels)
self._output_map_fn = functools.partial(
output_map_fn,
tasks=data_config.tasks,
preprocessor=preprocessor,
add_weights=should_add_weights,
)
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
self._tf_dataset = self._create_tf_dataset()
self._init_tensor_spec()
def _init_tensor_spec(self):
def _tensor_spec_to_torch_shape(spec):
if spec.shape is None:
return None
shape = [x if x is not None else -1 for x in spec.shape]
return torch.Size(shape)
self.torch_element_spec = tf.nest.map_structure(
_tensor_spec_to_torch_shape, self._tf_dataset.element_spec
)
def _create_tf_dataset(self):
if hasattr(self, "_tf_dataset"):
raise ValueError("Do not call `_create_tf_dataset` more than once.")
world_size = dist.get_world_size() if dist.is_initialized() else 1
per_replica_bsz = (
self._batch_size_multiplier * self._data_config.global_batch_size // world_size
)
dataset: tf.data.Dataset = self._create_base_tf_dataset(
batch_size=per_replica_bsz,
)
if self._repeat:
logging.info("Repeating dataset")
dataset = dataset.repeat()
if self.dataset_service:
if self._num_concurrent_iterators > 1:
if not self.machines_config:
raise ValueError(
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
)
dataset = dataset_lib.with_auto_tune_budget(
dataset,
machine_config=self.machines_config.chief,
num_concurrent_iterators=self._num_concurrent_iterators,
on_chief=False,
)
self.dataset_id, self.job_name = register_dataset(
dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
)
dataset = distribute_from_dataset_id(
dataset_id=self.dataset_id, # type: ignore[arg-type]
job_name=self.job_name,
dataset_service=self.dataset_service,
compression=self.compression,
)
elif self._num_concurrent_iterators > 1:
if not self.machines_config:
raise ValueError(
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
)
dataset = dataset_lib.with_auto_tune_budget(
dataset,
machine_config=self.machines_config.chief,
num_concurrent_iterators=self._num_concurrent_iterators,
on_chief=True,
)
# Vocabulary mapping happens on the training node, not in dds because of size.
if self._vocab_mapper:
dataset = dataset.map(self._vocab_mapper)
return dataset.prefetch(world_size * 2)
def _create_base_tf_dataset(self, batch_size: int):
if self._data_config.inputs:
glob = self._data_config.inputs
filenames = sorted(tf.io.gfile.glob(glob))
elif self._data_config.explicit_datetime_inputs:
num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
filenames, num_hours_missing = get_explicit_datetime_inputs_files(
self._data_config.explicit_datetime_inputs,
increment="hourly",
)
if num_hours_missing > num_missing_hours_tol:
raise ValueError(
f"We are missing {num_hours_missing} hours of data"
f"more than tolerance {num_missing_hours_tol}."
)
elif self._data_config.explicit_date_inputs:
num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
filenames, num_days_missing = get_explicit_datetime_inputs_files(
self._data_config.explicit_date_inputs,
increment="daily",
)
if num_days_missing > num_missing_days_tol:
raise ValueError(
f"We are missing {num_days_missing} days of data"
f"more than tolerance {num_missing_days_tol}."
)
else:
raise ValueError(
"Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
)
num_files = len(filenames)
logging.info(f"Found {num_files} data files")
if num_files < 1:
raise ValueError("No data files found")
if self._data_config.num_files_to_keep is not None:
filenames = filenames[: self._data_config.num_files_to_keep]
logging.info(f"Retaining only {len(filenames)} files.")
filenames_ds = (
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
# Because of drop_remainder, if our dataset does not fill
# up a batch, it will emit nothing without this repeat.
.repeat(-1)
)
if self._data_config.file_batch_size:
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
def per_shard_dataset(filename):
ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
return ds.prefetch(4)
ds = filenames_ds.interleave(
per_shard_dataset,
block_length=4,
deterministic=False,
num_parallel_calls=self._data_config.interleave_num_parallel_calls
or tf.data.experimental.AUTOTUNE,
)
# Combine functions into one map call to reduce overhead.
map_fn = functools.partial(
_chain,
f1=self._parse_fn,
f2=self._output_map_fn,
)
# Shuffle -> Batch -> Parse is the correct ordering
# Shuffling needs to be performed before batching otherwise there is not much point
# Batching happens before parsing because tf.Example parsing is actually vectorized
# and works much faster overall on batches of data.
ds = (
# DANGER DANGER: there is a default shuffle size here.
ds.shuffle(self._data_config.examples_shuffle_buffer_size)
.batch(batch_size=batch_size, drop_remainder=True)
.map(
map_fn,
num_parallel_calls=self._data_config.map_num_parallel_calls
or tf.data.experimental.AUTOTUNE,
)
)
if self._data_config.cache:
ds = ds.cache()
if self._data_config.ignore_data_errors:
ds = ds.apply(tf.data.experimental.ignore_errors())
options = tf.data.Options()
options.experimental_deterministic = False
ds = ds.with_options(options)
return ds
def _gen(self):
for x in self._tf_dataset:
yield to_batch(x)
def to_dataloader(self) -> Dict[str, torch.Tensor]:
return torch.utils.data.DataLoader(self, batch_size=None)
def __iter__(self):
return iter(self._gen())
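A minimal usage sketch of RecapDataset, assuming the random data has been generated and the environment variables referenced by the config are set. The module paths follow the file layout in this commit and are otherwise an assumption; this is not the production training entry point.

from tml.core import config as tml_config_mod
import tml.projects.home.recap.config as recap_config_mod
from tml.projects.home.recap.config import JobMode
from tml.projects.home.recap.data.dataset import RecapDataset

config = tml_config_mod.load_config_from_yaml(
    recap_config_mod.RecapConfig, "./projects/home/recap/config/local_prod.yaml"
)
train_dataset = RecapDataset(
    data_config=config.train_data,
    mode=JobMode.TRAIN,
    compression=config.train_data.dataset_service_compression,
    repeat=True,
)
for batch in train_dataset.to_dataloader():
    # Each batch is a RecapBatch with continuous/binary/discrete features, labels and weights.
    print(batch.labels.shape)
    break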


@@ -0,0 +1,81 @@
import os
import json
from absl import app, flags, logging
import tensorflow as tf
from typing import Dict
from tml.projects.home.recap.data import tfe_parsing
from tml.core import config as tml_config_mod
import tml.projects.home.recap.config as recap_config_mod
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.")
FLAGS = flags.FLAGS
def _generate_random_example(
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
example = {}
for feature_name, feature_spec in tf_example_schema.items():
dtype = feature_spec.dtype
if (dtype == tf.int64) or (dtype == tf.int32):
x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype)
elif (dtype == tf.float32) or (dtype == tf.float64):
x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype)
else:
raise NotImplementedError(f"Unknown type {dtype}")
example[feature_name] = x
return example
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
feature = {}
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
for feature_name, tensor in x.items():
feature[feature_name] = serializers[tensor.dtype](tensor)
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
tf_example_schema = tfe_parsing.create_tf_example_schema(
config.train_data,
seg_dense_schema,
)
record_filename = os.path.join(data_path, "random.tfrecord.gz")
with tf.io.TFRecordWriter(record_filename, "GZIP") as writer:
  for _ in range(FLAGS.n_examples):
    random_example = _generate_random_example(tf_example_schema)
    serialized_example = _serialize_example(random_example)
    writer.write(serialized_example)
def _generate_data_main(unused_argv):
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
# Find the path where to put the data
data_path = os.path.dirname(config.train_data.inputs)
logging.info("Putting random data in %s", data_path)
generate_data(data_path, config)
if __name__ == "__main__":
app.run(_generate_data_main)


@@ -0,0 +1,206 @@
"""
Preprocessors applied on DDS workers in order to modify the dataset on the fly.
Some of these preprocessors are also applied to the model at serving time.
"""
from tml.projects.home.recap import config as config_mod
from absl import logging
import tensorflow as tf
import numpy as np
class TruncateAndSlice(tf.keras.Model):
"""Class for truncating and slicing."""
def __init__(self, truncate_and_slice_config):
super().__init__()
self._truncate_and_slice_config = truncate_and_slice_config
if self._truncate_and_slice_config.continuous_feature_mask_path:
with tf.io.gfile.GFile(
self._truncate_and_slice_config.continuous_feature_mask_path, "rb"
) as f:
self._continuous_mask = np.load(f).nonzero()[0]
logging.info(f"Slicing {np.sum(self._continuous_mask)} continuous features.")
else:
self._continuous_mask = None
if self._truncate_and_slice_config.binary_feature_mask_path:
with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, "rb") as f:
self._binary_mask = np.load(f).nonzero()[0]
logging.info(f"Slicing {np.sum(self._binary_mask)} binary features.")
else:
self._binary_mask = None
def call(self, inputs, training=None, mask=None):
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
if self._truncate_and_slice_config.continuous_feature_truncation:
logging.info("Truncating continuous")
outputs["continuous"] = outputs["continuous"][
:, : self._truncate_and_slice_config.continuous_feature_truncation
]
if self._truncate_and_slice_config.binary_feature_truncation:
logging.info("Truncating binary")
outputs["binary"] = outputs["binary"][
:, : self._truncate_and_slice_config.binary_feature_truncation
]
if self._continuous_mask is not None:
outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1)
if self._binary_mask is not None:
outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1)
return outputs
class DownCast(tf.keras.Model):
"""Class for Down casting dataset before serialization and transferring to training host.
Depends on the data type and the actual data range, the down casting can be lossless or not.
It is strongly recommended to compare the metrics before and after down casting.
"""
def __init__(self, downcast_config):
super().__init__()
self.config = downcast_config
self._type_map = {
"bfloat16": tf.bfloat16,
"bool": tf.bool,
}
def call(self, inputs, training=None, mask=None):
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
for feature, type_str in self.config.features.items():
assert type_str in self._type_map
if type_str == "bfloat16":
logging.warning(
"Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
)
down_cast_data_type = self._type_map[type_str]
outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type)
return outputs
class RectifyLabels(tf.keras.Model):
"""Class for rectifying labels"""
def __init__(self, rectify_label_config):
super().__init__()
self._config = rectify_label_config
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
def call(self, inputs, training=None, mask=None):
served_ts_field = self._config.served_timestamp_field
impressed_ts_field = self._config.impressed_timestamp_field
for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():
impressed = inputs[impressed_ts_field]
served = inputs[served_ts_field]
engaged = inputs[engaged_ts_field]
keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window)
keep = tf.math.logical_and(keep, engaged - served < self._window)
inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label]))
return inputs
class ExtractFeatures(tf.keras.Model):
"""Class for extracting individual features from dense tensors by their index."""
def __init__(self, extract_features_config):
super().__init__()
self._config = extract_features_config
def call(self, inputs, training=None, mask=None):
for row in self._config.extract_feature_table:
inputs[row.name] = inputs[row.source_tensor][:, row.index]
return inputs
class DownsampleNegatives(tf.keras.Model):
"""Class for down-sampling/dropping negatives and updating the weights.
If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]
inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]
when batch_multiplier=2 and engagements_list=['fav']
It supports multiple engagements (union/logical_or is used to aggregate engagements), so we don't
drop positives for any engagement.
"""
def __init__(self, downsample_negatives_config):
super().__init__()
self.config = downsample_negatives_config
def call(self, inputs, training=None, mask=None):
labels = self.config.engagements_list
# union of engagements
mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))
n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))
batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)
negative_weights = tf.math.divide_no_nan(
tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
tf.cast(batch_size - n_positives, tf.float32),
)
new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights
def _split_by_label_concatenate_and_truncate(input_tensor):
# take the positive examples, concatenate them with the negative examples, and truncate
# DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50)
return tf.concat(
[
input_tensor[mask],
input_tensor[tf.math.logical_not(mask)],
],
0,
)[:batch_size]
if "weights" not in inputs:
# add placeholder so logic below applies even if weights aren't present in inputs
inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements])
for tensor in inputs:
if tensor == "weights":
inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1])
inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor])
return inputs
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
"""Builds a preprocess model to apply all preprocessing stages."""
if mode == config_mod.JobMode.INFERENCE:
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
return None
preprocess_models = []
if preprocess_config.downsample_negatives:
preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives))
if preprocess_config.truncate_and_slice:
preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice))
if preprocess_config.downcast:
preprocess_models.append(DownCast(preprocess_config.downcast))
if preprocess_config.rectify_labels:
preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels))
if preprocess_config.extract_features:
preprocess_models.append(ExtractFeatures(preprocess_config.extract_features))
if len(preprocess_models) == 0:
raise ValueError("No known preprocessor.")
class PreprocessModel(tf.keras.Model):
def __init__(self, preprocess_models):
super().__init__()
self.preprocess_models = preprocess_models
def call(self, inputs, training=None, mask=None):
outputs = inputs
for model in self.preprocess_models:
outputs = model(outputs, training, mask)
return outputs
if len(preprocess_models) > 1:
logging.warning(
"With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
)
return PreprocessModel(preprocess_models)
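A worked sketch of the negative re-weighting formula in DownsampleNegatives above, using the numbers from its docstring example (batch_multiplier=2, an incoming batch of 4 examples with one positive):

batch_multiplier = 2
incoming_batch_size = 4                               # over-sized batch read from the data service
batch_size = incoming_batch_size // batch_multiplier  # 2 examples are kept after truncation
n_positives = 1
# Each kept negative must stand in for all the negatives that were dropped:
negative_weight = (batch_multiplier * batch_size - n_positives) / (batch_size - n_positives)
print(negative_weight)  # 3.0, matching the [1.0, 3.0] weights in the docstring example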


@@ -0,0 +1,129 @@
import functools
import json
from tml.projects.home.recap.data import config as recap_data_config
from absl import logging
import tensorflow as tf
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
def create_tf_example_schema(
data_config: recap_data_config.SegDenseSchema,
segdense_schema,
):
"""Generate schema for deseralizing tf.Example.
Args:
segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).
labels: List of strings denoting labels.
Returns:
A dictionary schema suitable for deserializing tf.Example.
"""
segdense_config = data_config.seg_dense_schema
labels = list(data_config.tasks.keys())
used_features = (
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
)
logging.info(used_features)
tfe_schema = {}
for entry in segdense_schema:
feature_name = entry["feature_name"]
if feature_name in used_features:
length = entry["length"]
dtype = entry["dtype"]
if feature_name in labels:
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
)
elif length == -1:
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
else:
tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
)
for feature_name in used_features:
if feature_name not in tfe_schema:
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
return tfe_schema
@functools.lru_cache(1)
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
"""For experimentating with emulating bfloat16 or less precise types."""
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
"""For experimentating with emulating bfloat16 or less precise types."""
mask: tf.Tensor = make_mantissa_mask(mask_length)
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
def parse_tf_example(
serialized_example,
tfe_schema,
seg_dense_schema_config,
):
"""Parse serialized tf.Example into dict of tensors.
Args:
serialized_example: Serialized tf.Example to be parsed.
tfe_schema: Dictionary schema suitable for deserializing tf.Example.
Returns:
Dictionary of tensors to be used as model input.
"""
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
inputs[new_feature_name] = inputs.pop(old_feature_name)
# This should not actually be used except for experimentation with low precision floats.
if "mask_mantissa_features" in seg_dense_schema_config:
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
# at TF level.
# We should not return empty tensors if we don't use embeddings.
# Otherwise, it breaks numpy->pt conversion
renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
for renamed_key in renamed_keys:
if "embedding" in renamed_key and (renamed_key not in inputs):
inputs[renamed_key] = tf.zeros([], tf.float32)
logging.info(f"parsed example and inputs are {inputs}")
return inputs
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
"""Placeholder for seg dense.
In the future, when we use more seg dense variations, we can change this.
"""
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
tf_example_schema = create_tf_example_schema(
data_config,
seg_dense_schema,
)
logging.info("***** TF Example Schema *****")
logging.info(tf_example_schema)
parse = functools.partial(
parse_tf_example,
tfe_schema=tf_example_schema,
seg_dense_schema_config=data_config.seg_dense_schema,
)
return parse


@@ -0,0 +1,120 @@
from typing import Mapping, Tuple, Union
import torch
import torchrec
import numpy as np
import tensorflow as tf
def keyed_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
"""
Convert a dictionary of torch tensor to torchrec keyed tensor
Args:
tensor_map:
Returns:
"""
keys = list(tensor_map.keys())
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
# [Batch_size x 1].
values = [
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
for key in keys
]
return torchrec.KeyedTensor.from_tensor_list(keys, values)
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if tensor.is_sparse:
x = tensor.coalesce() # Ensure that the indices are ordered.
lengths = torch.bincount(x.indices()[0])
values = x.values()
else:
values = tensor
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
return values, lengths
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
"""
Convert a torch tensor to torchrec jagged tensor.
Note: Currently only supports shapes of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
dense_shape of the sparse tensor can be arbitrary.
Args:
tensor: a torch (sparse) tensor.
Returns:
A torchrec.JaggedTensor with values and per-row lengths derived from `tensor`.
"""
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths)
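# Illustrative example (hypothetical values): a sparse tensor with rows of different
# occupancy becomes a jagged tensor whose lengths come from bincount over row indices.
#   sp = torch.sparse_coo_tensor(
#       indices=[[0, 0, 1, 2], [0, 3, 2, 7]], values=[1.0, 2.0, 3.0, 4.0], size=(3, 10)
#   )
#   jt = jagged_tensor_from_tensor(sp)
#   jt.values()   # tensor([1., 2., 3., 4.])
#   jt.lengths()  # tensor([2, 1, 1])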
def keyed_jagged_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
"""
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
Note: Currently only supports shapes of [Batch_size] or [Batch_size x 1] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
dense_shape of the sparse tensor can be arbitrary.
Args:
tensor_map: Mapping of feature name to (sparse) torch tensor.
Returns:
A torchrec.KeyedJaggedTensor with one key per entry in `tensor_map`.
"""
if not tensor_map:
return torchrec.KeyedJaggedTensor(
keys=[],
values=torch.zeros(0, dtype=torch.int),
lengths=torch.zeros(0, dtype=torch.int),
)
values = []
lengths = []
for tensor in tensor_map.values():
tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
values.append(torch.squeeze(tensor_val))
lengths.append(tensor_len)
values = torch.cat(values, axis=0)
lengths = torch.cat(lengths, axis=0)
return torchrec.KeyedJaggedTensor(
keys=list(tensor_map.keys()),
values=values,
lengths=lengths,
)
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
return tf_tensor._numpy() # noqa
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
tensor = _tf_to_numpy(tensor)
# torch.from_numpy does not support bfloat16 arrays; up-cast to float32 to keep the same number of exponent bits.
if tensor.dtype.name == "bfloat16":
tensor = tensor.astype(np.float32)
tensor = torch.from_numpy(tensor)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
def sparse_or_dense_tf_to_torch(
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
if isinstance(tensor, tf.SparseTensor):
tensor = torch.sparse_coo_tensor(
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
_dense_tf_to_torch(tensor.values, pin_memory),
torch.Size(_tf_to_numpy(tensor.dense_shape)),
)
else:
tensor = _dense_tf_to_torch(tensor, pin_memory)
return tensor

View File

@ -0,0 +1,94 @@
from typing import List, Optional
import tml.core.config as base_config
from tml.optimizers import config as optimizer_config
import pydantic
class EmbeddingSnapshot(base_config.BaseConfig):
"""Configuration for Embedding snapshot"""
emb_name: str = pydantic.Field(
..., description="Name of the embedding table from the loaded snapshot"
)
embedding_snapshot_uri: str = pydantic.Field(
..., description="Path to torchsnapshot of the embedding"
)
# https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig
class EmbeddingBagConfig(base_config.BaseConfig):
"""Configuration for EmbeddingBag."""
name: str = pydantic.Field(..., description="name of embedding bag")
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
vocab: str = pydantic.Field(
None, description="Directory to parquet files of mapping from entity ID to table index."
)
class EmbeddingOptimizerConfig(base_config.BaseConfig):
learning_rate: optimizer_config.LearningRate = pydantic.Field(
None, description="learning rate scheduler for the EBC"
)
init_learning_rate: float = pydantic.Field(description="initial learning rate for the EBC")
# NB: Only sgd is supported right now and implicitly.
# FBGemm only supports simple exact_sgd which only takes LR as an argument.
class LargeEmbeddingsConfig(base_config.BaseConfig):
"""Configuration for EmbeddingBagCollection.
The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
"""
tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
optimizer: EmbeddingOptimizerConfig
tables_to_log: List[str] = pydantic.Field(
None, description="list of embedding table names that we want to log during training"
)
class StratifierConfig(base_config.BaseConfig):
name: str
index: int
value: int
class SmallEmbeddingBagConfig(base_config.BaseConfig):
"""Configuration for SmallEmbeddingBag."""
name: str = pydantic.Field(..., description="name of embedding bag")
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
index: int = pydantic.Field(..., description="index in the discrete tensor to look for")
class SmallEmbeddingsConfig(base_config.BaseConfig):
"""Configuration for SmallEmbeddingConfig.
Here we can use discrete features that already are present in our TFRecords generated using
segdense conversion as "home_recap_2022_discrete__segdense_vals" which are available in
the model as "discrete_features", and embed a user-defined set of them with configurable
dimensions and vocabulary sizes.
Compared with LargeEmbedding, this config is for small embedding tables that can fit inside
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
serving time due to size (>>1 GB).
This small embeddings table uses the same optimizer as the rest of the model."""
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
..., description="list of embedding tables"
)
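# Illustrative construction (hypothetical table; values are not production settings):
#   SmallEmbeddingsConfig(
#     tables=[
#       SmallEmbeddingBagConfig(
#         name="author_follower_bucket", num_embeddings=100, embedding_dim=8, index=3
#       )
#     ]
#   )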

108
projects/home/recap/main.py Normal file
View File

@ -0,0 +1,108 @@
import datetime
import os
from typing import Callable, List, Optional, Tuple
import tensorflow as tf
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.common.device import setup_and_get_device
from tml.core import config as tml_config_mod
import tml.core.custom_training_loop as ctl
from tml.core import debug_training_loop
from tml.core import losses
from tml.core.loss_type import LossType
from tml.model import maybe_shard_model
import tml.projects.home.recap.data.dataset as ds
import tml.projects.home.recap.config as recap_config_mod
import tml.projects.home.recap.optimizer as optimizer_mod
# from tml.projects.home.recap import feature
import tml.projects.home.recap.model as model_mod
import torchmetrics as tm
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from absl import app, flags, logging
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_bool("debug_loop", False, "Run with debug loop (slow)")
FLAGS = flags.FLAGS
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
print("#" * 100)
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
logging.info("Config: %s", config.pretty_print())
device = setup_and_get_device()
# Always enable tensorfloat on supported devices.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
loss_fn = losses.build_multi_task_loss(
loss_type=LossType.BCE_WITH_LOGITS,
tasks=list(config.model.tasks.keys()),
pos_weights=[task.pos_weight for task in config.model.tasks.values()],
)
# Since the prod model doesn't use large embeddings, for now we won't support them.
assert config.model.large_embeddings is None
train_dataset = ds.RecapDataset(
data_config=config.train_data,
dataset_service=data_service_dispatcher,
mode=recap_config_mod.JobMode.TRAIN,
compression=config.train_data.dataset_service_compression,
vocab_mapper=None,
repeat=True,
)
train_iterator = iter(train_dataset.to_dataloader())
torch_element_spec = train_dataset.torch_element_spec
model = model_mod.create_ranking_model(
data_spec=torch_element_spec[0],
config=config,
loss_fn=loss_fn,
device=device,
)
optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)
model = maybe_shard_model(model, device)
datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
print(f"{datetime_str}\n", end="")
if FLAGS.debug_loop:
logging.warning("Running debug mode, slow!")
train_mod = debug_training_loop
else:
train_mod = ctl
train_mod.train(
model=model,
optimizer=optimizer,
device=device,
save_dir=config.training.save_dir,
logging_interval=config.training.train_log_every_n,
train_steps=config.training.num_train_steps,
checkpoint_frequency=config.training.checkpoint_every_n,
dataset=train_iterator,
worker_batch_size=config.train_data.global_batch_size,
enable_amp=False,
initial_checkpoint_dir=config.training.initial_checkpoint_dir,
gradient_accumulation=config.training.gradient_accumulation,
scheduler=scheduler,
)
if __name__ == "__main__":
app.run(run)

View File

@ -0,0 +1,7 @@
from tml.projects.home.recap.model.entrypoint import (
create_ranking_model,
sanitize,
unsanitize,
MultiTaskRankingModel,
)
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss

View File

@ -0,0 +1,260 @@
"""Configuration for the main Recap model."""
import enum
from typing import List, Optional, Dict
import tml.core.config as base_config
from tml.projects.home.recap.embedding import config as embedding_config
import pydantic
class DropoutConfig(base_config.BaseConfig):
"""Configuration for the dropout layer."""
rate: pydantic.PositiveFloat = pydantic.Field(
0.1, description="Fraction of inputs to be dropped."
)
class LayerNormConfig(base_config.BaseConfig):
"""Configruation for the layer normalization."""
epsilon: float = pydantic.Field(
1e-3, description="Small float added to variance to avoid dividing by zero."
)
axis: int = pydantic.Field(-1, description="Axis or axes to normalize across.")
center: bool = pydantic.Field(True, description="Whether to add learnable center.")
scale: bool = pydantic.Field(True, description="Whether to add learnable scale.")
class BatchNormConfig(base_config.BaseConfig):
"""Configuration of the batch normalization layer."""
epsilon: pydantic.PositiveFloat = 1e-5
momentum: pydantic.PositiveFloat = 0.9
training_mode_at_inference_time: bool = False
use_renorm: bool = False
center: bool = pydantic.Field(True, description="Whether to add learnable center.")
scale: bool = pydantic.Field(True, description="Whether to add learnable scale.")
class DenseLayerConfig(base_config.BaseConfig):
layer_size: pydantic.PositiveInt
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
class MlpConfig(base_config.BaseConfig):
"""Configuration for MLP model."""
layer_sizes: List[pydantic.PositiveInt] = pydantic.Field(None, one_of="mlp_layer_definition")
layers: List[DenseLayerConfig] = pydantic.Field(None, one_of="mlp_layer_definition")
class BatchNormConfig(base_config.BaseConfig):
"""Configuration for the batch norm layer."""
affine: bool = pydantic.Field(True, description="Use affine transformation.")
momentum: pydantic.PositiveFloat = pydantic.Field(
0.1, description="Forgetting parameter in moving average."
)
class DoubleNormLogConfig(base_config.BaseConfig):
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
)
layer_norm_config: Optional[LayerNormConfig] = pydantic.Field(None)
class Log1pAbsConfig(base_config.BaseConfig):
"""Simple configuration where only the log transform is performed."""
class ClipLog1pAbsConfig(base_config.BaseConfig):
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
3e38, description="Threshold to clip the input values."
)
class ZScoreLogConfig(base_config.BaseConfig):
analysis_path: str
schema_path: str = pydantic.Field(
None,
description="Schema path which feaure statistics are generated with. Can be different from scehma in data config.",
)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
)
use_batch_norm: bool = pydantic.Field(
False, description="Option to use batch normalization on the inputs."
)
use_renorm: bool = pydantic.Field(
False, description="Option to use batch renormalization for trainig and serving consistency."
)
use_bq_stats: bool = pydantic.Field(
False, description="Option to load the partitioned json files from BQ as statistics."
)
class FeaturizationConfig(base_config.BaseConfig):
"""Configuration for featurization."""
log1p_abs_config: Log1pAbsConfig = pydantic.Field(None, one_of="featurization")
clip_log1p_abs_config: ClipLog1pAbsConfig = pydantic.Field(None, one_of="featurization")
z_score_log_config: ZScoreLogConfig = pydantic.Field(None, one_of="featurization")
double_norm_log_config: DoubleNormLogConfig = pydantic.Field(None, one_of="featurization")
feature_names_to_concat: List[str] = pydantic.Field(
["binary"], description="Feature names to concatenate as raw values with continuous features."
)
class MlpConfig(base_config.BaseConfig):
"""Configuration for MLP model."""
layer_sizes: List[pydantic.PositiveInt]
batch_norm: BatchNormConfig = pydantic.Field(
None, description="Optional batch norm configuration."
)
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout configuration.")
final_layer_activation: bool = pydantic.Field(
False, description="Whether to include activation on final layer."
)
class DcnConfig(base_config.BaseConfig):
"""Config for DCN model."""
poly_degree: pydantic.PositiveInt
projection_dim: pydantic.PositiveInt = pydantic.Field(
None, description="Factorizes main DCN matmul with projection."
)
parallel_mlp: Optional[MlpConfig] = pydantic.Field(
None, description="Config for the mlp if used. If None, only the cross layers are used."
)
use_parallel: bool = pydantic.Field(True, description="Whether to use parallel DCN.")
output_mlp: Optional[MlpConfig] = pydantic.Field(None, description="Config for the output mlp.")
class MaskBlockConfig(base_config.BaseConfig):
output_size: int
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
None, one_of="aggregation_size"
)
aggregation_size: Optional[pydantic.PositiveInt] = pydantic.Field(
None, description="Specify the aggregation size directly.", one_of="aggregation_size"
)
input_layer_norm: bool
class MaskNetConfig(base_config.BaseConfig):
mask_blocks: List[MaskBlockConfig]
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
class PositionDebiasConfig(base_config.BaseConfig):
"""
Configuration for Position Debias.
"""
max_position: int = pydantic.Field(256, description="Bucket all later positions.")
num_dims: pydantic.PositiveInt = pydantic.Field(
64, description="Number of dimensions in embedding."
)
drop_probability: float = pydantic.Field(0.5, description="Probability of dropping position.")
# Currently this should be 51, based on the dataset being tested at the time of writing this model.
# However, no default is provided here to make sure the user of the model is aware of its importance.
position_feature_index: int = pydantic.Field(
description="The index of the position feature in the discrete features"
)
class AffineMap(base_config.BaseConfig):
"""An affine map that scales the logits into the appropriate range."""
scale: float = pydantic.Field(1.0)
bias: float = pydantic.Field(0.0)
class DLRMConfig(base_config.BaseConfig):
bottom_mlp: MlpConfig = pydantic.Field(
...,
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
)
top_mlp: MlpConfig = pydantic.Field(..., description="Top mlp, generate the final output")
class TaskModel(base_config.BaseConfig):
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
mask_net_config: MaskNetConfig = pydantic.Field(None, one_of="architecture")
affine_map: AffineMap = pydantic.Field(
None,
description="Affine map applied to logits so we can represent a broader range of probabilities.",
)
# DANGER DANGER: not implemented yet.
# loss_weight: float = pydantic.Field(1.0, description="Weight for task in loss.")
pos_weight: float = pydantic.Field(1.0, description="Weight of positive in loss.")
class MultiTaskType(str, enum.Enum):
SHARE_NONE = "share_none" # Tasks are separate.
SHARE_ALL = "share_all" # Tasks share same backbone.
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
class ModelConfig(base_config.BaseConfig):
"""Specify model architecture."""
tasks: Dict[str, TaskModel] = pydantic.Field(
description="Specification of architecture per task."
)
large_embeddings: embedding_config.LargeEmbeddingsConfig = pydantic.Field(None)
small_embeddings: embedding_config.SmallEmbeddingsConfig = pydantic.Field(None)
# Not implemented yet.
# multi_task_loss_reduction_fn: str = "mean"
position_debias_config: PositionDebiasConfig = pydantic.Field(
default=None, description="position debias model configuration"
)
featurization_config: FeaturizationConfig = pydantic.Field(None)
multi_task_type: MultiTaskType = pydantic.Field(
MultiTaskType.SHARE_NONE, description="Multi task architecture"
)
backbone: TaskModel = pydantic.Field(None, description="Type of architecture for the backbone.")
stratifiers: List[embedding_config.StratifierConfig] = pydantic.Field(
default=None, description="Discrete features and values to stratify metrics by."
)
@pydantic.root_validator()
def _validate_mtl(cls, values):
if values.get("multi_task_type", None) is None:
return values
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
if values.get("backbone", None) is None:
raise ValueError("Require `backbone` for SHARE_ALL and SHARE_PARTIAL.")
elif values["multi_task_type"] in [
MultiTaskType.SHARE_NONE,
]:
if values.get("backbone", None) is not None:
raise ValueError("Can not have backbone if the share type is SHARE_NONE")
return values

View File

@ -0,0 +1,304 @@
from __future__ import annotations
from absl import logging
import torch
from typing import Optional, Callable, Mapping, Dict, Sequence, TYPE_CHECKING
from tml.projects.home.recap.model import feature_transform
from tml.projects.home.recap.model import config as model_config_mod
from tml.projects.home.recap.model import mlp
from tml.projects.home.recap.model import mask_net
from tml.projects.home.recap.model import numeric_calibration
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss
import tml.projects.home.recap.model.config as model_config_mod
if TYPE_CHECKING:
from tml.projects.home.recap import config as config_mod
from tml.projects.home.recap.data.config import RecapDataConfig
from tml.projects.home.recap.model.config import ModelConfig
def sanitize(task_name):
return task_name.replace(".", "__")
def unsanitize(sanitized_task_name):
return sanitized_task_name.replace("__", ".")
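# Example: task names contain dots, which torch.nn.ModuleDict keys cannot, so
#   sanitize("recap.engagement.is_favorited")     -> "recap__engagement__is_favorited"
#   unsanitize("recap__engagement__is_favorited") -> "recap.engagement.is_favorited"
# (round-tripping assumes the original task name contains no "__").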
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
""" "Builds a model for a single task"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:
return dcn.Dcn(dcn_config=task.dcn_config, in_features=input_shape)
elif task.mask_net_config:
return mask_net.MaskNet(mask_net_config=task.mask_net_config, in_features=input_shape)
else:
raise ValueError("This should never be reached.")
class MultiTaskRankingModel(torch.nn.Module):
"""Multi-task ranking model."""
def __init__(
self,
input_shapes: Mapping[str, torch.Size],
config: ModelConfig,
data_config: RecapDataConfig,
return_backbone: bool = False,
):
"""Constructor for Multi task learning.
Assumptions made:
1. Tasks specified in data config match model architecture.
These are all validated in config.
"""
super().__init__()
self._config = config
self._data_config = data_config
self._preprocessor = feature_transform.build_features_preprocessor(
config.featurization_config, input_shapes
)
self.return_backbone = return_backbone
self.embeddings = None
self.small_embeddings = None
embedding_dims = 0
if config.large_embeddings:
from large_embeddings.models.learnable_embeddings import LargeEmbeddings
self.embeddings = LargeEmbeddings(large_embeddings_config=config.large_embeddings)
embedding_dims += sum([table.embedding_dim for table in config.large_embeddings.tables])
logging.info(f"Emb dim: {embedding_dims}")
if config.small_embeddings:
self.small_embeddings = SmallEmbedding(config.small_embeddings)
embedding_dims += sum([table.embedding_dim for table in config.small_embeddings.tables])
logging.info(f"Emb dim (with small embeddings): {embedding_dims}")
if "user_embedding" in data_config.seg_dense_schema.renamed_features:
embedding_dims += input_shapes["user_embedding"][-1]
self._user_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["user_embedding"][-1])
else:
self._user_embedding_layer_norm = None
if "user_eng_embedding" in data_config.seg_dense_schema.renamed_features:
embedding_dims += input_shapes["user_eng_embedding"][-1]
self._user_eng_embedding_layer_norm = torch.nn.LayerNorm(
input_shapes["user_eng_embedding"][-1]
)
else:
self._user_eng_embedding_layer_norm = None
if "author_embedding" in data_config.seg_dense_schema.renamed_features:
embedding_dims += input_shapes["author_embedding"][-1]
self._author_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["author_embedding"][-1])
else:
self._author_embedding_layer_norm = None
input_dims = input_shapes["continuous"][-1] + input_shapes["binary"][-1] + embedding_dims
if config.position_debias_config:
self.position_debias_model = PositionDebias(config.position_debias_config)
input_dims += self.position_debias_model.out_features
else:
self.position_debias_model = None
logging.info(f"input dim: {input_dims}")
if config.multi_task_type in [
model_config_mod.MultiTaskType.SHARE_ALL,
model_config_mod.MultiTaskType.SHARE_PARTIAL,
]:
self._backbone = _build_single_task_model(config.backbone, input_dims)
else:
self._backbone = None
_towers: Dict[str, torch.nn.Module] = {}
_calibrators: Dict[str, torch.nn.Module] = {}
_affine_maps: Dict[str, torch.nn.Module] = {}
for task_name, task_architecture in config.tasks.items():
safe_name = sanitize(task_name)
# Complex input dimension calculation.
if config.multi_task_type == model_config_mod.MultiTaskType.SHARE_NONE:
num_inputs = input_dims
elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
num_inputs = self._backbone.out_features
elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
num_inputs = input_dims + self._backbone.out_features
else:
raise ValueError("Unreachable branch of enum.")
# Annoyingly, ModuleDict doesn't allow . inside key names.
_towers[safe_name] = _build_single_task_model(task_architecture, num_inputs)
if task_architecture.affine_map:
affine_map = torch.nn.Linear(1, 1)
affine_map.weight.data = torch.tensor([[task_architecture.affine_map.scale]])
affine_map.bias.data = torch.tensor([task_architecture.affine_map.bias])
_affine_maps[safe_name] = affine_map
else:
_affine_maps[safe_name] = torch.nn.Identity()
_calibrators[safe_name] = numeric_calibration.NumericCalibration(
pos_downsampling_rate=data_config.tasks[task_name].pos_downsampling_rate,
neg_downsampling_rate=data_config.tasks[task_name].neg_downsampling_rate,
)
self._task_names = list(config.tasks.keys())
self._towers = torch.nn.ModuleDict(_towers)
self._affine_maps = torch.nn.ModuleDict(_affine_maps)
self._calibrators = torch.nn.ModuleDict(_calibrators)
self._counter = torch.autograd.Variable(torch.tensor(0), requires_grad=False)
def forward(
self,
continuous_features: torch.Tensor,
binary_features: torch.Tensor,
discrete_features: Optional[torch.Tensor] = None,
sparse_features=None, # Optional[KeyedJaggedTensor]
user_embedding: Optional[torch.Tensor] = None,
user_eng_embedding: Optional[torch.Tensor] = None,
author_embedding: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
weights: Optional[torch.Tensor] = None,
):
concat_dense_features = [
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
]
if self.embeddings:
concat_dense_features.append(self.embeddings(sparse_features))
# Twhin embedding layer norms
if self.small_embeddings:
if discrete_features is None:
raise ValueError(
"Forward arg discrete_features is None, but since small_embeddings are used, a Tensor is expected."
)
concat_dense_features.append(self.small_embeddings(discrete_features))
if self._user_embedding_layer_norm:
if user_embedding is None:
raise ValueError(
"Forward arg user_embedding is None, but since Twhin user_embeddings are used by the model, a Tensor is expected."
)
concat_dense_features.append(self._user_embedding_layer_norm(user_embedding))
if self._user_eng_embedding_layer_norm:
if user_eng_embedding is None:
raise ValueError(
"Forward arg user_eng_embedding is None, but since Twhin user_eng_embeddings are used by the model, a Tensor is expected."
)
concat_dense_features.append(self._user_eng_embedding_layer_norm(user_eng_embedding))
if self._author_embedding_layer_norm:
if author_embedding is None:
raise ValueError(
"Forward arg author_embedding is None, but since Twhin author_embeddings are used by the model, a Tensor is expected."
)
concat_dense_features.append(self._author_embedding_layer_norm(author_embedding))
if self.position_debias_model:
if discrete_features is None:
raise ValueError(
"Forward arg discrete_features is None, but since position_debias_model is used, a Tensor is expected."
)
concat_dense_features.append(self.position_debias_model(discrete_features))
if discrete_features is not None and not (self.position_debias_model or self.small_embeddings):
logging.warning("Forward arg discrete_features is passed, but never used.")
concat_dense_features = torch.cat(concat_dense_features, dim=1)
if self._backbone:
if self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
net = self._backbone(concat_dense_features)["output"]
elif self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
net = torch.cat(
[concat_dense_features, self._backbone(concat_dense_features)["output"]], dim=1
)
else:
net = concat_dense_features
backbone_result = net
all_logits = []
all_probabilities = []
all_calibrated_probabilities = []
for task_name in self._task_names:
safe_name = sanitize(task_name)
tower_outputs = self._towers[safe_name](net)
logits = tower_outputs["output"]
scaled_logits = self._affine_maps[safe_name](logits)
probabilities = torch.sigmoid(scaled_logits)
calibrated_probabilities = self._calibrators[safe_name](probabilities)
all_logits.append(scaled_logits)
all_probabilities.append(probabilities)
all_calibrated_probabilities.append(calibrated_probabilities)
results = {
"logits": torch.squeeze(torch.stack(all_logits, dim=1), dim=-1),
"probabilities": torch.squeeze(torch.stack(all_probabilities, dim=1), dim=-1),
"calibrated_probabilities": torch.squeeze(
torch.stack(all_calibrated_probabilities, dim=1), dim=-1
),
}
# Returning the backbone is intended for stitching post-tf conversion
# Leaving this on will increase the size of the output by roughly 200x and could slow things down.
if self.return_backbone:
results["backbone"] = backbone_result
return results
def create_ranking_model(
data_spec,
# Used for planner to be batch size aware.
config: config_mod.RecapConfig,
device: torch.device,
loss_fn: Optional[Callable] = None,
data_config=None,
return_backbone=False,
):
if list(config.model.tasks.values())[0].dlrm_config:
raise NotImplementedError()
model = EmbeddingRankingModel(
input_shapes=data_spec,
config=all_config.model,
data_config=all_config.train_data,
)
else:
model = MultiTaskRankingModel(
input_shapes=data_spec,
config=config.model,
data_config=data_config if data_config is not None else config.train_data,
return_backbone=return_backbone,
)
logging.info("***** Model Architecture *****")
logging.info(model)
logging.info("***** Named Parameters *****")
for elem in model.named_parameters():
logging.info(elem[0])
if loss_fn:
logging.info("***** Wrapping in loss *****")
model = ModelAndLoss(
model=model,
loss_fn=loss_fn,
stratifiers=config.model.stratifiers,
)
return model

View File

@ -0,0 +1,122 @@
from typing import Mapping, Sequence, Union
from tml.projects.home.recap.model.config import (
BatchNormConfig,
DoubleNormLogConfig,
FeaturizationConfig,
LayerNormConfig,
)
import torch
def log_transform(x: torch.Tensor) -> torch.Tensor:
"""Safe log transform that works across both negative, zero, and positive floats."""
return torch.sign(x) * torch.log1p(torch.abs(x))
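# Worked example (added for clarity): sign(x) * log1p(|x|) maps 0 -> 0, preserves sign,
# and compresses magnitudes, e.g. x = -99.0 -> -log(100) ~= -4.605 and x = 99.0 -> +4.605.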
class BatchNorm(torch.nn.Module):
def __init__(self, num_features: int, config: BatchNormConfig):
super().__init__()
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x)
class LayerNorm(torch.nn.Module):
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
super().__init__()
if config.axis != -1:
raise NotImplementedError
if config.center != config.scale:
raise ValueError(
f"Center and scale must match in torch, received {config.center}, {config.scale}"
)
self.layer = torch.nn.LayerNorm(
normalized_shape, eps=config.epsilon, elementwise_affine=config.center
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x)
class Log1pAbs(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return log_transform(x)
class InputNonFinite(torch.nn.Module):
def __init__(self, fill_value: float = 0):
super().__init__()
self.register_buffer(
"fill_value", torch.as_tensor(fill_value, dtype=torch.float32), persistent=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.where(torch.isfinite(x), x, self.fill_value)
class Clamp(torch.nn.Module):
def __init__(self, min_value: float, max_value: float):
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
self.register_buffer(
"min_value", torch.as_tensor(min_value, dtype=torch.float32), persistent=True
)
self.register_buffer(
"max_value", torch.as_tensor(max_value, dtype=torch.float32), persistent=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.clamp(x, min=self.min_value, max=self.max_value)
class DoubleNormLog(torch.nn.Module):
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
def __init__(
self,
input_shapes: Mapping[str, Sequence[int]],
config: DoubleNormLogConfig,
):
super().__init__()
_before_concat_layers = [
InputNonFinite(),
Log1pAbs(),
]
if config.batch_norm_config:
_before_concat_layers.append(
BatchNorm(input_shapes["continuous"][-1], config.batch_norm_config)
)
_before_concat_layers.append(
Clamp(min_value=-config.clip_magnitude, max_value=config.clip_magnitude)
)
self._before_concat_layers = torch.nn.Sequential(*_before_concat_layers)
self.layer_norm = None
if config.layer_norm_config:
last_dim = input_shapes["continuous"][-1] + input_shapes["binary"][-1]
self.layer_norm = LayerNorm(last_dim, config.layer_norm_config)
def forward(
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
) -> torch.Tensor:
x = self._before_concat_layers(continuous_features)
x = torch.cat([x, binary_features], dim=1)
if self.layer_norm:
return self.layer_norm(x)
return x
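# Illustrative shapes (hypothetical): with input_shapes = {"continuous": [None, 100], "binary": [None, 20]},
# forward(continuous [B, 100], binary [B, 20]) applies fill -> log1p|x| -> (optional batch norm) -> clamp
# to the continuous block, concatenates the binary block, and returns a [B, 120] tensor
# (layer-normalized if layer_norm_config is set).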
def build_features_preprocessor(
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
):
"""Trivial right now, but we will change in the future."""
return DoubleNormLog(input_shapes, config.double_norm_log_config)

View File

@ -0,0 +1,98 @@
"""MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619)."""
from tml.projects.home.recap.model import config, mlp
import torch
def _init_weights(module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module):
def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None:
super(MaskBlock, self).__init__()
self.mask_block_config = mask_block_config
output_size = mask_block_config.output_size
if mask_block_config.input_layer_norm:
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
else:
self._input_layer_norm = None
if mask_block_config.reduction_factor:
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
elif mask_block_config.aggregation_size is not None:
aggregation_size = mask_block_config.aggregation_size
else:
raise ValueError("Need one of reduction factor or aggregation size.")
self._mask_layer = torch.nn.Sequential(
torch.nn.Linear(mask_input_dim, aggregation_size),
torch.nn.ReLU(),
torch.nn.Linear(aggregation_size, input_dim),
)
self._mask_layer.apply(_init_weights)
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
self._hidden_layer.apply(_init_weights)
self._layer_norm = torch.nn.LayerNorm(output_size)
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
if self._input_layer_norm:
net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
return self._layer_norm(hidden_layer_output)
class MaskNet(torch.nn.Module):
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
super().__init__()
self.mask_net_config = mask_net_config
mask_blocks = []
if mask_net_config.use_parallel:
total_output_mask_blocks = 0
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
total_output_mask_blocks += mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
else:
input_size = in_features
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
input_size = mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
total_output_mask_blocks = mask_block_config.output_size
if mask_net_config.mlp:
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
self.out_features = mask_net_config.mlp.layer_sizes[-1]
else:
self.out_features = total_output_mask_blocks
self.shared_size = total_output_mask_blocks
def forward(self, inputs: torch.Tensor):
if self.mask_net_config.use_parallel:
mask_outputs = []
for mask_layer in self._mask_blocks:
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
# Share the outputs of the MaskBlocks.
all_mask_outputs = torch.cat(mask_outputs, dim=1)
output = (
all_mask_outputs
if self.mask_net_config.mlp is None
else self._dense_layers(all_mask_outputs)["output"]
)
return {"output": output, "shared_layer": all_mask_outputs}
else:
net = inputs
for mask_layer in self._mask_blocks:
net = mask_layer(net=net, mask_input=inputs)
# Share the output of the stacked MaskBlocks.
output = net if self.mask_net_config.mlp is None else self._dense_layers(net)["output"]
return {"output": output, "shared_layer": net}

View File

@ -0,0 +1,58 @@
"""MLP feed forward stack in torch."""
from tml.projects.home.recap.model.config import MlpConfig
import torch
from absl import logging
def _init_weights(module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class Mlp(torch.nn.Module):
def __init__(self, in_features: int, mlp_config: MlpConfig):
super().__init__()
self._mlp_config = mlp_config
input_size = in_features
layer_sizes = mlp_config.layer_sizes
modules = []
for layer_size in layer_sizes[:-1]:
modules.append(torch.nn.Linear(input_size, layer_size, bias=True))
if mlp_config.batch_norm:
modules.append(
torch.nn.BatchNorm1d(
layer_size, affine=mlp_config.batch_norm.affine, momentum=mlp_config.batch_norm.momentum
)
)
modules.append(torch.nn.ReLU())
if mlp_config.dropout:
modules.append(torch.nn.Dropout(mlp_config.dropout.rate))
input_size = layer_size
modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True))
if mlp_config.final_layer_activation:
modules.append(torch.nn.ReLU())
self.layers = torch.nn.ModuleList(modules)
self.layers.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
net = x
for i, layer in enumerate(self.layers):
net = layer(net)
if i == 1: # Share the first (widest?) set of activations for other applications.
shared_layer = net
return {"output": net, "shared_layer": shared_layer}
@property
def shared_size(self):
return self._mlp_config.layer_sizes[-1]
@property
def out_features(self):
return self._mlp_config.layer_sizes[-1]
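# Illustrative usage (hypothetical sizes; not production settings):
#   layer = Mlp(in_features=128, mlp_config=MlpConfig(layer_sizes=[64, 32, 8]))
#   out = layer(torch.rand(4, 128))
#   out["output"].shape        # torch.Size([4, 8])
#   out["shared_layer"].shape  # torch.Size([4, 64]) -- activations captured after module index 1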

View File

@ -0,0 +1,67 @@
from typing import Callable, Optional, List
from tml.projects.home.recap.embedding import config as embedding_config_mod
import torch
from absl import logging
class ModelAndLoss(torch.nn.Module):
def __init__(
self,
model,
loss_fn: Callable,
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
stratifiers: List of stratifier configs (name and discrete-feature index) to emit for metrics stratification.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.stratifiers = stratifiers
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(
continuous_features=batch.continuous_features,
binary_features=batch.binary_features,
discrete_features=batch.discrete_features,
sparse_features=batch.sparse_features,
user_embedding=batch.user_embedding,
user_eng_embedding=batch.user_eng_embedding,
author_embedding=batch.author_embedding,
labels=batch.labels,
weights=batch.weights,
)
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
if self.stratifiers:
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
outputs["stratifiers"] = {}
for stratifier in self.stratifiers:
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
# In general, we can have a large number of losses returned by our loss function.
if isinstance(losses, dict):
return losses["loss"], {
**outputs,
**losses,
"labels": batch.labels,
"weights": batch.weights,
}
else: # Assume that this is a float.
return losses, {
**outputs,
"loss": losses,
"labels": batch.labels,
"weights": batch.weights,
}

View File

@ -0,0 +1,19 @@
import torch
class NumericCalibration(torch.nn.Module):
def __init__(
self,
pos_downsampling_rate: float,
neg_downsampling_rate: float,
):
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
self.register_buffer(
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
)
def forward(self, probs: torch.Tensor):
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
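# Derivation (added for clarity): if positives were kept with rate r_pos and negatives with
# rate r_neg during downsampling, the odds seen by the model are inflated by r_pos / r_neg.
# With ratio = r_neg / r_pos, the true probability is recovered from the predicted probability p as
#   odds_true = ratio * p / (1 - p)
#   p_true    = odds_true / (1 + odds_true) = p * ratio / (1 - p + ratio * p)
# which is exactly the expression in forward(). For example (hypothetical rates), with
# pos_downsampling_rate = 1.0 and neg_downsampling_rate = 0.05 (ratio = 0.05), a raw
# prediction of p = 0.5 calibrates to 0.025 / 0.525 ~= 0.048.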

View File

@ -0,0 +1 @@
from tml.projects.home.recap.optimizer.optimizer import build_optimizer

View File

@ -0,0 +1,36 @@
"""Optimization configurations for models."""
import typing
import tml.core.config as base_config
import tml.optimizers.config as optimizers_config_mod
import pydantic
class RecapAdamConfig(base_config.BaseConfig):
beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator.
class MultiTaskLearningRates(base_config.BaseConfig):
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model."
)
backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
None, description="Learning rate for backbone of the model."
)
class RecapOptimizerConfig(base_config.BaseConfig):
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr"
)
single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
None, description="Single task learning rates", one_of="lr"
)
adam: RecapAdamConfig = pydantic.Field(one_of="optimizer")

View File

@ -0,0 +1,166 @@
"""Build optimizers and learning rate schedules."""
import bisect
from collections import defaultdict
import functools
import math
import typing
from typing import Optional
import warnings
# from large_embeddings.config import EmbeddingOptimizerConfig
from tml.projects.home.recap import model as model_mod
from tml.optimizers import config
from tml.optimizers import compute_lr
from absl import logging # type: ignore[attr-defined]
import torch
from torchrec.optim import keyed
_DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate.
_BACKBONE = "backbone"
_DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
def __init__(
self,
optimizer,
lr_dict: typing.Dict[str, config.LearningRate],
emb_learning_rate,
last_epoch=-1,
verbose=False,
):
self.optimizer = optimizer
self.lr_dict = lr_dict
self.group_names = list(self.lr_dict.keys())
self.emb_learning_rate = emb_learning_rate
# We handle sparse LR scheduling separately, so only validate LR groups against dense param groups
num_dense_param_groups = sum(
1
for _, _optim in optimizer._optims
for _ in _optim.param_groups
if isinstance(_optim, keyed.KeyedOptimizerWrapper)
)
if num_dense_param_groups != len(lr_dict):
raise ValueError(
f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}."
)
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
UserWarning,
)
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
learning_rates = []
for lr_config in self.lr_dict.values():
learning_rates.append(compute_lr(lr_config, self.last_epoch))
# WARNING: The order of appending is important.
if self.emb_learning_rate:
learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch))
return learning_rates
def build_optimizer(
model: torch.nn.Module,
optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
):
"""Builds an optimizer and scheduler.
Args:
model: A torch model, probably with DDP/DMP.
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
emb_optimizer_config: Optional embedding optimizer config for the sparse/EBC parameters.
Returns:
A torch.optim instance, and a scheduler instance.
"""
optimizer_fn = functools.partial(
torch.optim.Adam,
lr=_DEFAULT_LR,
betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2),
eps=optimizer_config.adam.epsilon,
maximize=False,
)
if optimizer_config.multi_task_learning_rates:
logging.info("***** Parameter groups for optimization *****")
# Importantly, we preserve insertion order in dictionaries here.
parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict)
added_parameters: typing.Set[str] = set()
for task in optimizer_config.multi_task_learning_rates.tower_learning_rates:
for name, parameter in model.named_parameters():
if f".{model_mod.sanitize(task)}." in name:
parameter_groups[task][name] = parameter
logging.info(f"{task}: {name}")
if name in added_parameters:
raise ValueError(f"Parameter {name} matched multiple tasks.")
added_parameters.add(name)
for name, parameter in model.named_parameters():
if name not in added_parameters and "embedding_bags" not in name:
parameter_groups[_BACKBONE][name] = parameter
added_parameters.add(name)
logging.info(f"{_BACKBONE}: {name}")
for name, parameter in model.named_parameters():
if name not in added_parameters and "embedding_bags" in name:
parameter_groups[_DENSE_EMBEDDINGS][name] = parameter
logging.info(f"{_DENSE_EMBEDDINGS}: {name}")
all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy()
if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None:
all_learning_rates[
_BACKBONE
] = optimizer_config.multi_task_learning_rates.backbone_learning_rate
if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config:
all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy()
else:
parameter_groups = dict(model.named_parameters())
all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate}
optimizers = [
keyed.KeyedOptimizerWrapper(param_group, optimizer_fn)
for param_name, param_group in parameter_groups.items()
if param_name != _DENSE_EMBEDDINGS
]
# Make the EBC optimizer SGD to match the fused optimizer.
if _DENSE_EMBEDDINGS in parameter_groups:
optimizers.append(
keyed.KeyedOptimizerWrapper(
parameter_groups[_DENSE_EMBEDDINGS],
functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False),
)
)
if parameter_groups.keys() != all_learning_rates.keys():
raise ValueError("Learning rates do not match optimizers")
# If the optimiser is dense, model.fused_optimizer will be empty (but not None)
emb_learning_rate = None
if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers:
logging.info(f"Model fused optimiser: {model.fused_optimizer}")
optimizers.append(model.fused_optimizer)
if emb_optimizer_config:
emb_learning_rate = emb_optimizer_config.learning_rate.copy()
else:
raise ValueError("Fused kernel exists, but LR is not set")
logging.info(f"***** Combining optimizers: {optimizers} *****")
optimizer = keyed.CombinedOptimizer(optimizers)
scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate)
logging.info(f"***** Combined optimizer after init: {optimizer} *****")
return optimizer, scheduler

View File

@ -0,0 +1,10 @@
#!/usr/bin/env bash
# Runs from inside venv
rm -rf $HOME/tmp/runs/recap_local_random_data
python -m tml.machines.is_venv || exit 1
export TML_BASE="$(git rev-parse --show-toplevel)"
mkdir -p $HOME/tmp/recap_local_random_data
python projects/home/recap/data/generate_random_data.py --config_path $(pwd)/projects/home/recap/config/local_prod.yaml

View File

@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Runs from inside venv
rm -rf $HOME/tmp/runs/recap_local_debug
mkdir -p $HOME/tmp/runs/recap_local_debug
python -m tml.machines.is_venv || exit 1
export TML_BASE="$(git rev-parse --show-toplevel)"
torchrun \
--standalone \
--nnodes 1 \
--nproc_per_node 1 \
projects/home/recap/main.py \
--config_path $(pwd)/projects/home/recap/config/local_prod.yaml \
$@