mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-06-13 12:58:39 +02:00
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
This commit is contained in:
2267
projects/home/recap/FEATURES.md
Normal file
2267
projects/home/recap/FEATURES.md
Normal file
File diff suppressed because it is too large
Load Diff
48
projects/home/recap/README.md
Normal file
48
projects/home/recap/README.md
Normal file
@ -0,0 +1,48 @@
|
||||
This project is the "heavy ranker" used on the "For You" timeline. This is used to generate the ranking of Tweet after candidate retrieval and light ranker (note the final ordering of the Tweet is not directly the highest -> lowest scoring, because after scoring other heuristics are used).
|
||||
|
||||
This model captures the ranking model used for the majority of users of Twitter "For You" timeline in early March 2023. Due to the need to make sure this runs independently from other parts of Twitter codebase, there may be small differences from the production model.
|
||||
|
||||
The model receives various features, describing the Tweet and the user whose timeline is being constructed as input (see FEATURES.md for more details). The model outputs multiple binary predictions about how the user will respond if shown the Tweet.
|
||||
|
||||
|
||||
Those are:
|
||||
"recap.engagement.is_favorited": The probability the user will favorite the Tweet.
|
||||
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": The probability the user will click into the conversation of this Tweet and reply or Like a Tweet.
|
||||
"recap.engagement.is_good_clicked_convo_desc_v2": The probability the user will click into the conversation of this Tweet and stay there for at least 2 minutes.
|
||||
"recap.engagement.is_negative_feedback_v2": The probability the user will react negatively (requesting "show less often" on the Tweet or author, block or mute the Tweet author)
|
||||
"recap.engagement.is_profile_clicked_and_profile_engaged": The probability the user opens the Tweet author profile and Likes or replies to a Tweet.
|
||||
"recap.engagement.is_replied": The probability the user replies to the Tweet.
|
||||
"recap.engagement.is_replied_reply_engaged_by_author": The probability the user replies to the Tweet and this reply is engaged by the Tweet author.
|
||||
"recap.engagement.is_report_tweet_clicked": The probability the user will click Report Tweet.
|
||||
"recap.engagement.is_retweeted": The probability the user will ReTweet the Tweet.
|
||||
"recap.engagement.is_video_playback_50": The probability (for a video Tweet) that the user will watch at least half of the video
|
||||
|
||||
For ranking the candidates these predictions are combined into a score by weighting them:
|
||||
"recap.engagement.is_favorited": 0.5
|
||||
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": 11* (the maximum prediction from these two "good click" features is used and weighted by 11, the other prediction is ignored).
|
||||
"recap.engagement.is_good_clicked_convo_desc_v2": 11*
|
||||
"recap.engagement.is_negative_feedback_v2": -74
|
||||
"recap.engagement.is_profile_clicked_and_profile_engaged": 12
|
||||
"recap.engagement.is_replied": 27
|
||||
"recap.engagement.is_replied_reply_engaged_by_author": 75
|
||||
"recap.engagement.is_report_tweet_clicked": -369
|
||||
"recap.engagement.is_retweeted": 1
|
||||
"recap.engagement.is_video_playback_50": 0.005
|
||||
|
||||
|
||||
We cannot release the real training data due to privacy restrictions. However, we have included a script to generate random data to ensure you can run the model training code.
|
||||
|
||||
To try training the model (assuming you have already followed the repo setup instructions and are inside a virtualenv).
|
||||
|
||||
Run
|
||||
$ ./projects/home/recap/scripts/create_random_data.sh
|
||||
|
||||
This will create some random data (in $HOME/tmp/recap_local_random_data).
|
||||
|
||||
$ ./projects/home/recap/scripts/run_local.sh
|
||||
|
||||
This will train the model (for a small number of iterations). Checkpoints and logs will be written to $HOME/tmp/runs/recap_local_debug.
|
||||
|
||||
The model training is configured through a yaml file (./projects/home/recap/config/local_prod.yaml).
|
||||
|
||||
The model architecture is a parallel masknet (https://arxiv.org/abs/2102.07619).
|
0
projects/home/recap/__init__.py
Normal file
0
projects/home/recap/__init__.py
Normal file
54
projects/home/recap/config.py
Normal file
54
projects/home/recap/config.py
Normal file
@ -0,0 +1,54 @@
|
||||
from tml.core import config as config_mod
|
||||
import tml.projects.home.recap.data.config as data_config
|
||||
import tml.projects.home.recap.model.config as model_config
|
||||
import tml.projects.home.recap.optimizer.config as optimizer_config
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional
|
||||
import pydantic
|
||||
|
||||
|
||||
class TrainingConfig(config_mod.BaseConfig):
|
||||
save_dir: str = "/tmp/model"
|
||||
num_train_steps: pydantic.PositiveInt = 1000000
|
||||
initial_checkpoint_dir: str = pydantic.Field(
|
||||
None, description="Directory of initial checkpoints", at_most_one_of="initialization"
|
||||
)
|
||||
checkpoint_every_n: pydantic.PositiveInt = 1000
|
||||
checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
|
||||
)
|
||||
train_log_every_n: pydantic.PositiveInt = 1000
|
||||
num_eval_steps: int = pydantic.Field(
|
||||
16384, description="Number of evaluation steps. If < 0 the entire dataset " "will be used."
|
||||
)
|
||||
eval_log_every_n: pydantic.PositiveInt = 5000
|
||||
|
||||
eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
|
||||
|
||||
gradient_accumulation: int = pydantic.Field(
|
||||
None, description="Number of replica steps to accumulate gradients."
|
||||
)
|
||||
|
||||
|
||||
class RecapConfig(config_mod.BaseConfig):
|
||||
training: TrainingConfig = pydantic.Field(TrainingConfig())
|
||||
model: model_config.ModelConfig
|
||||
train_data: data_config.RecapDataConfig
|
||||
validation_data: Dict[str, data_config.RecapDataConfig]
|
||||
optimizer: optimizer_config.RecapOptimizerConfig
|
||||
|
||||
which_metrics: Optional[str] = pydantic.Field(None, description="which metrics to pick.")
|
||||
|
||||
# DANGER DANGER! You might expect validators here to ensure that multi task learning setups are
|
||||
# the same as the data. Unfortunately, this throws opaque errors when the model configuration is
|
||||
# invalid. In our judgement, that is a more frequency and worse occurrence than tasks not matching
|
||||
# the data.
|
||||
|
||||
|
||||
class JobMode(str, Enum):
|
||||
"""Job modes."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVALUATE = "evaluate"
|
||||
INFERENCE = "inference"
|
114
projects/home/recap/config/home_recap_2022/segdense.json
Normal file
114
projects/home/recap/config/home_recap_2022/segdense.json
Normal file
@ -0,0 +1,114 @@
|
||||
{
|
||||
"schema": [
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "home_recap_2022_discrete__segdense_vals",
|
||||
"length": 320
|
||||
},
|
||||
{
|
||||
"dtype": "float_list",
|
||||
"feature_name": "home_recap_2022_cont__segdense_vals",
|
||||
"length": 6000
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "home_recap_2022_binary__segdense_vals",
|
||||
"length": 512
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_tweet_detail_dwelled_15_sec",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_profile_clicked_and_profile_engaged",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_replied_reply_engaged_by_author",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_video_playback_50",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_report_tweet_clicked",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_replied",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "meta.author_id",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_negative_feedback_v2",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_retweeted",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_favorited",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "meta.tweet_id",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_good_clicked_convo_desc_v2",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "meta.user_id",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_bookmarked",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "int64_list",
|
||||
"feature_name": "recap.engagement.is_shared",
|
||||
"length": 1
|
||||
},
|
||||
{
|
||||
"dtype": "float_list",
|
||||
"feature_name": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings",
|
||||
"length": 200
|
||||
},
|
||||
{
|
||||
"dtype": "float_list",
|
||||
"feature_name": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings",
|
||||
"length": 200
|
||||
},
|
||||
{
|
||||
"dtype": "float_list",
|
||||
"feature_name": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings",
|
||||
"length": 200
|
||||
}
|
||||
]
|
||||
}
|
477
projects/home/recap/config/local_prod.yaml
Normal file
477
projects/home/recap/config/local_prod.yaml
Normal file
@ -0,0 +1,477 @@
|
||||
training:
|
||||
num_train_steps: 10
|
||||
num_eval_steps: 5
|
||||
checkpoint_every_n: 5
|
||||
train_log_every_n: 1
|
||||
eval_log_every_n: 1
|
||||
save_dir: ${HOME}/tmp/runs/recap_local_debug
|
||||
eval_timeout_in_s: 7200
|
||||
model:
|
||||
backbone:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config:
|
||||
mask_blocks:
|
||||
- aggregation_size: 1024
|
||||
input_layer_norm: false
|
||||
output_size: 1024
|
||||
reduction_factor: null
|
||||
- aggregation_size: 1024
|
||||
input_layer_norm: false
|
||||
output_size: 1024
|
||||
reduction_factor: null
|
||||
- aggregation_size: 1024
|
||||
input_layer_norm: false
|
||||
output_size: 1024
|
||||
reduction_factor: null
|
||||
- aggregation_size: 1024
|
||||
input_layer_norm: false
|
||||
output_size: 1024
|
||||
reduction_factor: null
|
||||
mlp:
|
||||
batch_norm: null
|
||||
dropout: null
|
||||
final_layer_activation: true
|
||||
layer_sizes:
|
||||
- 2048
|
||||
use_parallel: true
|
||||
mlp_config: null
|
||||
pos_weight: 1.0
|
||||
featurization_config:
|
||||
clip_log1p_abs_config: null
|
||||
double_norm_log_config:
|
||||
batch_norm_config:
|
||||
affine: true
|
||||
momentum: 0.01
|
||||
clip_magnitude: 5.0
|
||||
layer_norm_config:
|
||||
axis: -1
|
||||
center: true
|
||||
epsilon: 0.0
|
||||
scale: true
|
||||
feature_names_to_concat:
|
||||
- binary
|
||||
log1p_abs_config: null
|
||||
z_score_log_config: null
|
||||
large_embeddings: null
|
||||
multi_task_type: share_all
|
||||
position_debias_config: null
|
||||
small_embeddings: null
|
||||
stratifiers: null
|
||||
tasks:
|
||||
recap.engagement.is_favorited:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_good_clicked_convo_desc_v2:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_negative_feedback_v2:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout:
|
||||
rate: 0.1
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_profile_clicked_and_profile_engaged:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_replied:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_replied_reply_engaged_by_author:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_report_tweet_clicked:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout:
|
||||
rate: 0.2
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_retweeted:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
recap.engagement.is_video_playback_50:
|
||||
affine_map: null
|
||||
dcn_config: null
|
||||
dlrm_config: null
|
||||
mask_net_config: null
|
||||
mlp_config:
|
||||
batch_norm:
|
||||
affine: false
|
||||
momentum: 0.1
|
||||
dropout: null
|
||||
final_layer_activation: false
|
||||
layer_sizes:
|
||||
- 256
|
||||
- 128
|
||||
- 1
|
||||
pos_weight: 1.0
|
||||
train_data:
|
||||
global_batch_size: 128
|
||||
dataset_service_compression: AUTO
|
||||
inputs: &data_root "${HOME}/tmp/recap_local_random_data/*.gz"
|
||||
seg_dense_schema: &seg_dense_schema
|
||||
schema_path: "${TML_BASE}/projects/home/recap/config/home_recap_2022/segdense.json"
|
||||
renamed_features:
|
||||
"continuous": "home_recap_2022_cont__segdense_vals"
|
||||
"binary": "home_recap_2022_binary__segdense_vals"
|
||||
"discrete": "home_recap_2022_discrete__segdense_vals"
|
||||
"author_embedding": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings"
|
||||
"user_embedding": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings"
|
||||
"user_eng_embedding": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings"
|
||||
"meta__author_id": "meta.author_id"
|
||||
"meta__user_id": "meta.user_id"
|
||||
"meta__tweet_id": "meta.tweet_id"
|
||||
tasks: &data_tasks
|
||||
"recap.engagement.is_bookmarked": {}
|
||||
"recap.engagement.is_favorited": {}
|
||||
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": {}
|
||||
"recap.engagement.is_good_clicked_convo_desc_v2": {}
|
||||
"recap.engagement.is_negative_feedback_v2": {}
|
||||
"recap.engagement.is_profile_clicked_and_profile_engaged": {}
|
||||
"recap.engagement.is_replied": {}
|
||||
"recap.engagement.is_replied_reply_engaged_by_author": {}
|
||||
"recap.engagement.is_report_tweet_clicked": {}
|
||||
"recap.engagement.is_retweeted": {}
|
||||
"recap.engagement.is_shared": {}
|
||||
"recap.engagement.is_tweet_detail_dwelled_15_sec": {}
|
||||
"recap.engagement.is_video_playback_50": {}
|
||||
preprocess: &preprocess
|
||||
truncate_and_slice:
|
||||
continuous_feature_truncation: 2117
|
||||
binary_feature_truncation: 59
|
||||
validation_data:
|
||||
validation: &validation
|
||||
global_batch_size: &eval_batch_size 128
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks: *data_tasks
|
||||
preprocess: *preprocess
|
||||
train:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks: *data_tasks
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_favorited:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_favorited":
|
||||
pos_downsampling_rate: 0.8387
|
||||
neg_downsampling_rate: 0.01
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_favorited"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied":
|
||||
pos_downsampling_rate: 0.9164
|
||||
neg_downsampling_rate: 0.00195
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_good_clicked_convo_desc_v2:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_good_clicked_convo_desc_v2":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.00174
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_good_clicked_convo_desc_v2"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_negative_feedback_v2:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_negative_feedback_v2":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.00280
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_negative_feedback_v2"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_profile_clicked_and_profile_engaged:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_profile_clicked_and_profile_engaged":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.0015
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_profile_clicked_and_profile_engaged"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_replied:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_replied":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.005
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_replied"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_replied_reply_engaged_by_author:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_replied_reply_engaged_by_author":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.001
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_replied_reply_engaged_by_author"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_report_tweet_clicked:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_report_tweet_clicked":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.000014
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_report_tweet_clicked"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_retweeted:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_retweeted":
|
||||
pos_downsampling_rate: 0.9561
|
||||
neg_downsampling_rate: 0.004
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_retweeted"
|
||||
preprocess: *preprocess
|
||||
recap.engagement.is_video_playback_50:
|
||||
global_batch_size: *eval_batch_size
|
||||
inputs: *data_root
|
||||
seg_dense_schema: *seg_dense_schema
|
||||
tasks:
|
||||
<<: *data_tasks
|
||||
"recap.engagement.is_video_playback_50":
|
||||
pos_downsampling_rate: 1.0
|
||||
neg_downsampling_rate: 0.00427
|
||||
evaluation_tasks:
|
||||
- "recap.engagement.is_video_playback_50"
|
||||
preprocess: *preprocess
|
||||
|
||||
optimizer:
|
||||
adam:
|
||||
beta_1: 0.95
|
||||
beta_2: 0.999
|
||||
epsilon: 1.0e-07
|
||||
multi_task_learning_rates:
|
||||
backbone_learning_rate:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0001
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
tower_learning_rates:
|
||||
recap.engagement.is_favorited:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0008
|
||||
num_ramp_steps: 5000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0001
|
||||
num_ramp_steps: 2000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_good_clicked_convo_desc_v2:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0002
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_negative_feedback_v2:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0005
|
||||
num_ramp_steps: 5000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_profile_clicked_and_profile_engaged:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0003
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_replied:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.001
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_replied_reply_engaged_by_author:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0001
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_report_tweet_clicked:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0001
|
||||
num_ramp_steps: 3000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_retweeted:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.0001
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
recap.engagement.is_video_playback_50:
|
||||
constant: null
|
||||
linear_ramp_to_constant:
|
||||
learning_rate: 0.003
|
||||
num_ramp_steps: 1000
|
||||
linear_ramp_to_cosine: null
|
||||
piecewise_constant: null
|
||||
single_task_learning_rate: null
|
0
projects/home/recap/data/__init__.py
Normal file
0
projects/home/recap/data/__init__.py
Normal file
246
projects/home/recap/data/config.py
Normal file
246
projects/home/recap/data/config.py
Normal file
@ -0,0 +1,246 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
|
||||
|
||||
from tml.core import config as base_config
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class ExplicitDateInputs(base_config.BaseConfig):
|
||||
"""Arguments to select train/validation data using end_date and days of data."""
|
||||
|
||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
|
||||
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
|
||||
num_missing_days_tol: int = pydantic.Field(
|
||||
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
||||
)
|
||||
|
||||
|
||||
class ExplicitDatetimeInputs(base_config.BaseConfig):
|
||||
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
||||
|
||||
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||||
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
|
||||
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
|
||||
num_missing_hours_tol: int = pydantic.Field(
|
||||
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
||||
)
|
||||
|
||||
|
||||
class DdsCompressionOption(str, Enum):
|
||||
"""The only valid compression option is 'AUTO'"""
|
||||
|
||||
AUTO = "AUTO"
|
||||
|
||||
|
||||
class DatasetConfig(base_config.BaseConfig):
|
||||
inputs: str = pydantic.Field(
|
||||
None, description="A glob for selecting data.", one_of="date_inputs_format"
|
||||
)
|
||||
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
||||
None, one_of="date_inputs_format"
|
||||
)
|
||||
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
|
||||
|
||||
global_batch_size: pydantic.PositiveInt
|
||||
|
||||
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Number of shards to keep."
|
||||
)
|
||||
repeat_files: bool = pydantic.Field(
|
||||
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
||||
)
|
||||
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
|
||||
|
||||
cache: bool = pydantic.Field(
|
||||
False,
|
||||
description="Cache dataset in memory. Careful to only use this when you"
|
||||
" have enough memory to fit entire dataset.",
|
||||
)
|
||||
|
||||
data_service_dispatcher: str = pydantic.Field(None)
|
||||
ignore_data_errors: bool = pydantic.Field(
|
||||
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
|
||||
)
|
||||
dataset_service_compression: DdsCompressionOption = pydantic.Field(
|
||||
None,
|
||||
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
|
||||
)
|
||||
|
||||
# tf.data.Dataset options
|
||||
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
|
||||
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Number of parallel calls."
|
||||
)
|
||||
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Number of shards to interleave."
|
||||
)
|
||||
|
||||
|
||||
class TruncateAndSlice(base_config.BaseConfig):
|
||||
# Apply truncation and then slice.
|
||||
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Experimental. Truncates continuous features to this amount for efficiency."
|
||||
)
|
||||
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Experimental. Truncates binary features to this amount for efficiency."
|
||||
)
|
||||
|
||||
continuous_feature_mask_path: str = pydantic.Field(
|
||||
None, description="Path of mask used to slice input continuous features."
|
||||
)
|
||||
binary_feature_mask_path: str = pydantic.Field(
|
||||
None, description="Path of mask used to slice input binary features."
|
||||
)
|
||||
|
||||
|
||||
class DataType(str, Enum):
|
||||
BFLOAT16 = "bfloat16"
|
||||
BOOL = "bool"
|
||||
|
||||
FLOAT32 = "float32"
|
||||
FLOAT16 = "float16"
|
||||
|
||||
UINT8 = "uint8"
|
||||
|
||||
|
||||
class DownCast(base_config.BaseConfig):
|
||||
# Apply down casting to selected features.
|
||||
features: typing.Dict[str, DataType] = pydantic.Field(
|
||||
None, description="Map features to down cast data types."
|
||||
)
|
||||
|
||||
|
||||
class TaskData(base_config.BaseConfig):
|
||||
pos_downsampling_rate: float = pydantic.Field(
|
||||
1.0,
|
||||
description="Downsampling rate of positives used to generate dataset.",
|
||||
)
|
||||
neg_downsampling_rate: float = pydantic.Field(
|
||||
1.0,
|
||||
description="Downsampling rate of negatives used to generate dataset.",
|
||||
)
|
||||
|
||||
|
||||
class SegDenseSchema(base_config.BaseConfig):
|
||||
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
|
||||
features: typing.List[str] = pydantic.Field(
|
||||
[],
|
||||
description="List of features (in addition to the renamed features) to read from schema path above.",
|
||||
)
|
||||
renamed_features: typing.Dict[str, str] = pydantic.Field(
|
||||
{}, description="Dictionary of renamed features."
|
||||
)
|
||||
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
|
||||
{},
|
||||
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
|
||||
)
|
||||
|
||||
|
||||
class RectifyLabels(base_config.BaseConfig):
|
||||
label_rectification_window_in_hours: float = pydantic.Field(
|
||||
3.0, description="overlap time in hours for which to flip labels"
|
||||
)
|
||||
served_timestamp_field: str = pydantic.Field(
|
||||
..., description="input field corresponding to served time"
|
||||
)
|
||||
impressed_timestamp_field: str = pydantic.Field(
|
||||
..., description="input field corresponding to impressed time"
|
||||
)
|
||||
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
|
||||
..., description="label to the input field corresponding to engagement time"
|
||||
)
|
||||
|
||||
|
||||
class ExtractFeaturesRow(base_config.BaseConfig):
|
||||
name: str = pydantic.Field(
|
||||
...,
|
||||
description="name of the new field name to be created",
|
||||
)
|
||||
source_tensor: str = pydantic.Field(
|
||||
...,
|
||||
description="name of the dense tensor to look for the feature",
|
||||
)
|
||||
index: int = pydantic.Field(
|
||||
...,
|
||||
description="index of the feature in the dense tensor",
|
||||
)
|
||||
|
||||
|
||||
class ExtractFeatures(base_config.BaseConfig):
|
||||
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
|
||||
[],
|
||||
description="list of features to be extracted with their name, source tensor and index",
|
||||
)
|
||||
|
||||
|
||||
class DownsampleNegatives(base_config.BaseConfig):
|
||||
batch_multiplier: int = pydantic.Field(
|
||||
None,
|
||||
description="batch multiplier",
|
||||
)
|
||||
engagements_list: typing.List[str] = pydantic.Field(
|
||||
[],
|
||||
description="engagements with kept positives",
|
||||
)
|
||||
num_engagements: int = pydantic.Field(
|
||||
...,
|
||||
description="number engagements used in the model, including ones excluded in engagements_list",
|
||||
)
|
||||
|
||||
|
||||
class Preprocess(base_config.BaseConfig):
|
||||
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
|
||||
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
|
||||
rectify_labels: RectifyLabels = pydantic.Field(
|
||||
None, description="Rectify labels for a given overlap window"
|
||||
)
|
||||
extract_features: ExtractFeatures = pydantic.Field(
|
||||
None, description="Extract features from dense tensors."
|
||||
)
|
||||
downsample_negatives: DownsampleNegatives = pydantic.Field(
|
||||
None, description="Downsample negatives."
|
||||
)
|
||||
|
||||
|
||||
class Sampler(base_config.BaseConfig):
|
||||
"""Assumes function is defined in data/samplers.py.
|
||||
|
||||
Only use this for quick experimentation.
|
||||
If samplers are useful, we should sample from upstream data generation.
|
||||
|
||||
DEPRICATED, DO NOT USE.
|
||||
"""
|
||||
|
||||
name: str
|
||||
kwargs: typing.Dict
|
||||
|
||||
|
||||
class RecapDataConfig(DatasetConfig):
|
||||
seg_dense_schema: SegDenseSchema
|
||||
|
||||
tasks: typing.Dict[str, TaskData] = pydantic.Field(
|
||||
description="Description of individual tasks in this dataset."
|
||||
)
|
||||
evaluation_tasks: typing.List[str] = pydantic.Field(
|
||||
[], description="If specified, lists the tasks we're generating metrics for."
|
||||
)
|
||||
|
||||
preprocess: Preprocess = pydantic.Field(
|
||||
None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
|
||||
)
|
||||
|
||||
sampler: Sampler = pydantic.Field(
|
||||
None,
|
||||
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
|
||||
)
|
||||
|
||||
@pydantic.root_validator()
|
||||
def _validate_evaluation_tasks(cls, values):
|
||||
if values.get("evaluation_tasks") is not None:
|
||||
for task in values["evaluation_tasks"]:
|
||||
if task not in values["tasks"]:
|
||||
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||||
return values
|
477
projects/home/recap/data/dataset.py
Normal file
477
projects/home/recap/data/dataset.py
Normal file
@ -0,0 +1,477 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple, Dict
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import tensorflow as tf
|
||||
|
||||
from tml.common.batch import DataclassBatch
|
||||
from tml.projects.home.recap.data.config import RecapDataConfig, TaskData
|
||||
from tml.projects.home.recap.data import preprocessors
|
||||
from tml.projects.home.recap.config import JobMode
|
||||
from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn
|
||||
from tml.projects.home.recap.data.util import (
|
||||
keyed_jagged_tensor_from_tensors_dict,
|
||||
sparse_or_dense_tf_to_torch,
|
||||
)
|
||||
from absl import logging
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecapBatch(DataclassBatch):
|
||||
"""Holds features and labels from the Recap dataset."""
|
||||
|
||||
continuous_features: torch.Tensor
|
||||
binary_features: torch.Tensor
|
||||
discrete_features: torch.Tensor
|
||||
sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821
|
||||
labels: torch.Tensor
|
||||
user_embedding: torch.Tensor = None
|
||||
user_eng_embedding: torch.Tensor = None
|
||||
author_embedding: torch.Tensor = None
|
||||
weights: torch.Tensor = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.weights is None:
|
||||
self.weights = torch.ones_like(self.labels)
|
||||
for feature_name, feature_value in self.as_dict().items():
|
||||
if ("embedding" in feature_name) and (feature_value is None):
|
||||
setattr(self, feature_name, torch.empty([0, 0]))
|
||||
|
||||
|
||||
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
|
||||
"""Converts a torch data loader output into `RecapBatch`."""
|
||||
|
||||
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
|
||||
try:
|
||||
features_in, labels = x
|
||||
except ValueError:
|
||||
# For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple
|
||||
features_in, labels = x, None
|
||||
|
||||
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
|
||||
if sparse_feature_names:
|
||||
sparse_features = keyed_jagged_tensor_from_tensors_dict(
|
||||
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
|
||||
)
|
||||
|
||||
user_embedding, user_eng_embedding, author_embedding = None, None, None
|
||||
if "user_embedding" in features_in:
|
||||
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
|
||||
raise ValueError("Only one source of embedding for user is supported")
|
||||
else:
|
||||
user_embedding = features_in["user_embedding"]
|
||||
|
||||
if "user_eng_embedding" in features_in:
|
||||
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
|
||||
raise ValueError("Only one source of embedding for user is supported")
|
||||
else:
|
||||
user_eng_embedding = features_in["user_eng_embedding"]
|
||||
|
||||
if "author_embedding" in features_in:
|
||||
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
|
||||
raise ValueError("Only one source of embedding for user is supported")
|
||||
else:
|
||||
author_embedding = features_in["author_embedding"]
|
||||
|
||||
return RecapBatch(
|
||||
continuous_features=features_in["continuous"],
|
||||
binary_features=features_in["binary"],
|
||||
discrete_features=features_in["discrete"],
|
||||
sparse_features=sparse_features,
|
||||
user_embedding=user_embedding,
|
||||
user_eng_embedding=user_eng_embedding,
|
||||
author_embedding=author_embedding,
|
||||
labels=labels,
|
||||
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
|
||||
)
|
||||
|
||||
|
||||
def _chain(param, f1, f2):
|
||||
"""
|
||||
Reduce multiple functions into one chained function
|
||||
_chain(x, f1, f2) -> f2(f1(x))
|
||||
"""
|
||||
output = param
|
||||
fns = [f1, f2]
|
||||
for f in fns:
|
||||
output = f(output)
|
||||
return output
|
||||
|
||||
|
||||
def _add_weights(inputs, tasks: Dict[str, TaskData]):
|
||||
"""Adds weights based on label sampling for positive and negatives.
|
||||
|
||||
This is useful for numeric calibration etc. This mutates inputs.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary of strings to tensor-like structures.
|
||||
tasks: A dict of string (label) to `TaskData` specifying inputs.
|
||||
|
||||
Returns:
|
||||
A tuple of features and labels; weights are added to features.
|
||||
"""
|
||||
|
||||
weights = []
|
||||
for key, task in tasks.items():
|
||||
label = inputs[key]
|
||||
float_label = tf.cast(label, tf.float32)
|
||||
|
||||
weights.append(
|
||||
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
|
||||
)
|
||||
|
||||
# Ensure we are batch-major (assumes we batch before this call).
|
||||
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
|
||||
return inputs
|
||||
|
||||
|
||||
def get_datetimes(explicit_datetime_inputs):
|
||||
"""Compute list datetime strings for train/validation data."""
|
||||
datetime_format = "%Y/%m/%d/%H"
|
||||
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
|
||||
dates = sorted(
|
||||
[
|
||||
(end - timedelta(hours=i + 1)).strftime(datetime_format)
|
||||
for i in range(int(explicit_datetime_inputs.hours))
|
||||
]
|
||||
)
|
||||
return dates
|
||||
|
||||
|
||||
def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
|
||||
"""
|
||||
Compile list of files for training/validation.
|
||||
|
||||
Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
|
||||
For each hour of data, if the directory is missing or empty, we increment a counter to keep
|
||||
track of the number of missing data hours.
|
||||
Returns only files with a `.gz` extension.
|
||||
|
||||
Args:
|
||||
explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object
|
||||
|
||||
Returns:
|
||||
data_files: Sorted list of files to read corresponding to data at the desired datetimes
|
||||
num_hours_missing: Number of hours that we are missing data
|
||||
|
||||
"""
|
||||
datetimes = get_datetimes(explicit_datetime_inputs)
|
||||
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
|
||||
data_files = []
|
||||
num_hours_missing = 0
|
||||
for folder in folders:
|
||||
try:
|
||||
files = tf.io.gfile.listdir(folder)
|
||||
if not files:
|
||||
logging.warning(f"{folder} contained no data files")
|
||||
num_hours_missing += 1
|
||||
data_files.extend(
|
||||
[
|
||||
os.path.join(folder, filename)
|
||||
for filename in files
|
||||
if filename.rsplit(".", 1)[-1].lower() == "gz"
|
||||
]
|
||||
)
|
||||
except tf.errors.NotFoundError as e:
|
||||
num_hours_missing += 1
|
||||
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
|
||||
return sorted(data_files), num_hours_missing
|
||||
|
||||
|
||||
def _map_output_for_inference(
|
||||
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
|
||||
):
|
||||
if preprocessor:
|
||||
raise ValueError("No preprocessor should be used at inference time.")
|
||||
if add_weights:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Add zero weights.
|
||||
inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
|
||||
for label in tasks:
|
||||
del inputs[label]
|
||||
return inputs
|
||||
|
||||
|
||||
def _map_output_for_train_eval(
|
||||
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
|
||||
):
|
||||
if add_weights:
|
||||
inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
|
||||
|
||||
# Warning this has to happen first as it changes the input
|
||||
if preprocessor:
|
||||
inputs = preprocessor(inputs)
|
||||
|
||||
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
|
||||
|
||||
for label in tasks:
|
||||
del inputs[label]
|
||||
|
||||
return inputs, label_values
|
||||
|
||||
|
||||
def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
|
||||
"""Adds weights based on label sampling for positive and negatives.
|
||||
|
||||
This is useful for numeric calibration etc. This mutates inputs.
|
||||
|
||||
Args:
|
||||
inputs: A dictionary of strings to tensor-like structures.
|
||||
tasks: A dict of string (label) to `TaskData` specifying inputs.
|
||||
|
||||
Returns:
|
||||
A tuple of features and labels; weights are added to features.
|
||||
"""
|
||||
weights = []
|
||||
for key, task in tasks.items():
|
||||
label = inputs[key]
|
||||
float_label = tf.cast(label, tf.float32)
|
||||
|
||||
weights.append(
|
||||
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
|
||||
)
|
||||
|
||||
# Ensure we are batch-major (assumes we batch before this call).
|
||||
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
|
||||
return inputs
|
||||
|
||||
|
||||
class RecapDataset(torch.utils.data.IterableDataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_config: RecapDataConfig,
|
||||
dataset_service: Optional[str] = None,
|
||||
mode: JobMode = JobMode.TRAIN,
|
||||
compression: Optional[str] = "AUTO",
|
||||
repeat: bool = False,
|
||||
vocab_mapper: tf.keras.Model = None,
|
||||
):
|
||||
logging.info("***** Labels *****")
|
||||
logging.info(list(data_config.tasks.keys()))
|
||||
|
||||
self._data_config = data_config
|
||||
self._parse_fn = get_seg_dense_parse_fn(data_config)
|
||||
self._mode = mode
|
||||
self._repeat = repeat
|
||||
self._num_concurrent_iterators = 1
|
||||
self._vocab_mapper = vocab_mapper
|
||||
self.dataset_service = dataset_service
|
||||
|
||||
preprocessor = None
|
||||
self._batch_size_multiplier = 1
|
||||
if data_config.preprocess:
|
||||
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
|
||||
if data_config.preprocess.downsample_negatives:
|
||||
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
|
||||
|
||||
self._preprocessor = preprocessor
|
||||
|
||||
if mode == JobMode.INFERENCE:
|
||||
if preprocessor is not None:
|
||||
raise ValueError("Expect no preprocessor at inference time.")
|
||||
should_add_weights = False
|
||||
output_map_fn = _map_output_for_inference # (features,)
|
||||
else:
|
||||
# Only add weights if there is a reason to! If all weights will
|
||||
# be equal to 1.0, save bandwidth between DDS and Chief by simply
|
||||
# relying on the fact that weights default to 1.0 in `RecapBatch`
|
||||
# WARNING: Weights may still be added as a side effect of a preprocessor
|
||||
# such as `DownsampleNegatives`.
|
||||
should_add_weights = any(
|
||||
[
|
||||
task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0
|
||||
for task_cfg in data_config.tasks.values()
|
||||
]
|
||||
)
|
||||
output_map_fn = _map_output_for_train_eval # (features, labels)
|
||||
|
||||
self._output_map_fn = functools.partial(
|
||||
output_map_fn,
|
||||
tasks=data_config.tasks,
|
||||
preprocessor=preprocessor,
|
||||
add_weights=should_add_weights,
|
||||
)
|
||||
|
||||
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
|
||||
|
||||
self._tf_dataset = self._create_tf_dataset()
|
||||
|
||||
self._init_tensor_spec()
|
||||
|
||||
def _init_tensor_spec(self):
|
||||
def _tensor_spec_to_torch_shape(spec):
|
||||
if spec.shape is None:
|
||||
return None
|
||||
shape = [x if x is not None else -1 for x in spec.shape]
|
||||
return torch.Size(shape)
|
||||
|
||||
self.torch_element_spec = tf.nest.map_structure(
|
||||
_tensor_spec_to_torch_shape, self._tf_dataset.element_spec
|
||||
)
|
||||
|
||||
def _create_tf_dataset(self):
|
||||
if hasattr(self, "_tf_dataset"):
|
||||
raise ValueError("Do not call `_create_tf_dataset` more than once.")
|
||||
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
per_replica_bsz = (
|
||||
self._batch_size_multiplier * self._data_config.global_batch_size // world_size
|
||||
)
|
||||
|
||||
dataset: tf.data.Dataset = self._create_base_tf_dataset(
|
||||
batch_size=per_replica_bsz,
|
||||
)
|
||||
|
||||
if self._repeat:
|
||||
logging.info("Repeating dataset")
|
||||
dataset = dataset.repeat()
|
||||
|
||||
if self.dataset_service:
|
||||
if self._num_concurrent_iterators > 1:
|
||||
if not self.machines_config:
|
||||
raise ValueError(
|
||||
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
|
||||
)
|
||||
dataset = dataset_lib.with_auto_tune_budget(
|
||||
dataset,
|
||||
machine_config=self.machines_config.chief,
|
||||
num_concurrent_iterators=self.num_concurrent_iterators,
|
||||
on_chief=False,
|
||||
)
|
||||
|
||||
self.dataset_id, self.job_name = register_dataset(
|
||||
dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
|
||||
)
|
||||
dataset = distribute_from_dataset_id(
|
||||
dataset_id=self.dataset_id, # type: ignore[arg-type]
|
||||
job_name=self.job_name,
|
||||
dataset_service=self.dataset_service,
|
||||
compression=self.compression,
|
||||
)
|
||||
|
||||
elif self._num_concurrent_iterators > 1:
|
||||
if not self.machines_config:
|
||||
raise ValueError(
|
||||
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
|
||||
)
|
||||
dataset = dataset_lib.with_auto_tune_budget(
|
||||
dataset,
|
||||
machine_config=self.machines_config.chief,
|
||||
num_concurrent_iterators=self._num_concurrent_iterators,
|
||||
on_chief=True,
|
||||
)
|
||||
|
||||
# Vocabulary mapping happens on the training node, not in dds because of size.
|
||||
if self._vocab_mapper:
|
||||
dataset = dataset.map(self._vocab_mapper)
|
||||
|
||||
return dataset.prefetch(world_size * 2)
|
||||
|
||||
def _create_base_tf_dataset(self, batch_size: int):
|
||||
if self._data_config.inputs:
|
||||
glob = self._data_config.inputs
|
||||
filenames = sorted(tf.io.gfile.glob(glob))
|
||||
elif self._data_config.explicit_datetime_inputs:
|
||||
num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
|
||||
filenames, num_hours_missing = get_explicit_datetime_inputs_files(
|
||||
self._data_config.explicit_datetime_inputs,
|
||||
increment="hourly",
|
||||
)
|
||||
if num_hours_missing > num_missing_hours_tol:
|
||||
raise ValueError(
|
||||
f"We are missing {num_hours_missing} hours of data"
|
||||
f"more than tolerance {num_missing_hours_tol}."
|
||||
)
|
||||
elif self._data_config.explicit_date_inputs:
|
||||
num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
|
||||
filenames, num_days_missing = get_explicit_datetime_inputs_files(
|
||||
self._data_config.explicit_date_inputs,
|
||||
increment="daily",
|
||||
)
|
||||
if num_days_missing > num_missing_days_tol:
|
||||
raise ValueError(
|
||||
f"We are missing {num_days_missing} days of data"
|
||||
f"more than tolerance {num_missing_days_tol}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
|
||||
)
|
||||
|
||||
num_files = len(filenames)
|
||||
logging.info(f"Found {num_files} data files")
|
||||
if num_files < 1:
|
||||
raise ValueError("No data files found")
|
||||
|
||||
if self._data_config.num_files_to_keep is not None:
|
||||
filenames = filenames[: self._data_config.num_files_to_keep]
|
||||
logging.info(f"Retaining only {len(filenames)} files.")
|
||||
|
||||
filenames_ds = (
|
||||
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
|
||||
# Because of drop_remainder, if our dataset does not fill
|
||||
# up a batch, it will emit nothing without this repeat.
|
||||
.repeat(-1)
|
||||
)
|
||||
|
||||
if self._data_config.file_batch_size:
|
||||
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
|
||||
|
||||
def per_shard_dataset(filename):
|
||||
ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
|
||||
return ds.prefetch(4)
|
||||
|
||||
ds = filenames_ds.interleave(
|
||||
per_shard_dataset,
|
||||
block_length=4,
|
||||
deterministic=False,
|
||||
num_parallel_calls=self._data_config.interleave_num_parallel_calls
|
||||
or tf.data.experimental.AUTOTUNE,
|
||||
)
|
||||
|
||||
# Combine functions into one map call to reduce overhead.
|
||||
map_fn = functools.partial(
|
||||
_chain,
|
||||
f1=self._parse_fn,
|
||||
f2=self._output_map_fn,
|
||||
)
|
||||
|
||||
# Shuffle -> Batch -> Parse is the correct ordering
|
||||
# Shuffling needs to be performed before batching otherwise there is not much point
|
||||
# Batching happens before parsing because tf.Example parsing is actually vectorized
|
||||
# and works much faster overall on batches of data.
|
||||
ds = (
|
||||
# DANGER DANGER: there is a default shuffle size here.
|
||||
ds.shuffle(self._data_config.examples_shuffle_buffer_size)
|
||||
.batch(batch_size=batch_size, drop_remainder=True)
|
||||
.map(
|
||||
map_fn,
|
||||
num_parallel_calls=self._data_config.map_num_parallel_calls
|
||||
or tf.data.experimental.AUTOTUNE,
|
||||
)
|
||||
)
|
||||
|
||||
if self._data_config.cache:
|
||||
ds = ds.cache()
|
||||
|
||||
if self._data_config.ignore_data_errors:
|
||||
ds = ds.apply(tf.data.experimental.ignore_errors())
|
||||
|
||||
options = tf.data.Options()
|
||||
options.experimental_deterministic = False
|
||||
ds = ds.with_options(options)
|
||||
|
||||
return ds
|
||||
|
||||
def _gen(self):
|
||||
for x in self._tf_dataset:
|
||||
yield to_batch(x)
|
||||
|
||||
def to_dataloader(self) -> Dict[str, torch.Tensor]:
|
||||
return torch.utils.data.DataLoader(self, batch_size=None)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._gen())
|
81
projects/home/recap/data/generate_random_data.py
Normal file
81
projects/home/recap/data/generate_random_data.py
Normal file
@ -0,0 +1,81 @@
|
||||
import os
|
||||
import json
|
||||
from absl import app, flags, logging
|
||||
import tensorflow as tf
|
||||
from typing import Dict
|
||||
|
||||
from tml.projects.home.recap.data import tfe_parsing
|
||||
from tml.core import config as tml_config_mod
|
||||
import tml.projects.home.recap.config as recap_config_mod
|
||||
|
||||
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
|
||||
flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def _generate_random_example(
|
||||
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
|
||||
) -> Dict[str, tf.Tensor]:
|
||||
example = {}
|
||||
for feature_name, feature_spec in tf_example_schema.items():
|
||||
dtype = feature_spec.dtype
|
||||
if (dtype == tf.int64) or (dtype == tf.int32):
|
||||
x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype)
|
||||
elif (dtype == tf.float32) or (dtype == tf.float64):
|
||||
x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown type {dtype}")
|
||||
|
||||
example[feature_name] = x
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def _float_feature(value):
|
||||
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
|
||||
|
||||
|
||||
def _int64_feature(value):
|
||||
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
|
||||
|
||||
|
||||
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
|
||||
feature = {}
|
||||
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
|
||||
for feature_name, tensor in x.items():
|
||||
feature[feature_name] = serializers[tensor.dtype](tensor)
|
||||
|
||||
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
|
||||
return example_proto.SerializeToString()
|
||||
|
||||
|
||||
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
|
||||
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
|
||||
seg_dense_schema = json.load(f)["schema"]
|
||||
|
||||
tf_example_schema = tfe_parsing.create_tf_example_schema(
|
||||
config.train_data,
|
||||
seg_dense_schema,
|
||||
)
|
||||
|
||||
record_filename = os.path.join(data_path, "random.tfrecord.gz")
|
||||
|
||||
with tf.io.TFRecordWriter(record_filename, "GZIP") as writer:
|
||||
random_example = _generate_random_example(tf_example_schema)
|
||||
serialized_example = _serialize_example(random_example)
|
||||
writer.write(serialized_example)
|
||||
|
||||
|
||||
def _generate_data_main(unused_argv):
|
||||
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
|
||||
|
||||
# Find the path where to put the data
|
||||
data_path = os.path.dirname(config.train_data.inputs)
|
||||
logging.info("Putting random data in %s", data_path)
|
||||
|
||||
generate_data(data_path, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(_generate_data_main)
|
206
projects/home/recap/data/preprocessors.py
Normal file
206
projects/home/recap/data/preprocessors.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""
|
||||
Preprocessors applied on DDS workers in order to modify the dataset on the fly.
|
||||
Some of these preprocessors are also applied to the model at serving time.
|
||||
"""
|
||||
from tml.projects.home.recap import config as config_mod
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TruncateAndSlice(tf.keras.Model):
|
||||
"""Class for truncating and slicing."""
|
||||
|
||||
def __init__(self, truncate_and_slice_config):
|
||||
super().__init__()
|
||||
self._truncate_and_slice_config = truncate_and_slice_config
|
||||
|
||||
if self._truncate_and_slice_config.continuous_feature_mask_path:
|
||||
with tf.io.gfile.GFile(
|
||||
self._truncate_and_slice_config.continuous_feature_mask_path, "rb"
|
||||
) as f:
|
||||
self._continuous_mask = np.load(f).nonzero()[0]
|
||||
logging.info(f"Slicing {np.sum(self._continuous_mask)} continuous features.")
|
||||
else:
|
||||
self._continuous_mask = None
|
||||
|
||||
if self._truncate_and_slice_config.binary_feature_mask_path:
|
||||
with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, "rb") as f:
|
||||
self._binary_mask = np.load(f).nonzero()[0]
|
||||
logging.info(f"Slicing {np.sum(self._binary_mask)} binary features.")
|
||||
else:
|
||||
self._binary_mask = None
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||
if self._truncate_and_slice_config.continuous_feature_truncation:
|
||||
logging.info("Truncating continuous")
|
||||
outputs["continuous"] = outputs["continuous"][
|
||||
:, : self._truncate_and_slice_config.continuous_feature_truncation
|
||||
]
|
||||
if self._truncate_and_slice_config.binary_feature_truncation:
|
||||
logging.info("Truncating binary")
|
||||
outputs["binary"] = outputs["binary"][
|
||||
:, : self._truncate_and_slice_config.binary_feature_truncation
|
||||
]
|
||||
if self._continuous_mask is not None:
|
||||
outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1)
|
||||
if self._binary_mask is not None:
|
||||
outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1)
|
||||
return outputs
|
||||
|
||||
|
||||
class DownCast(tf.keras.Model):
|
||||
"""Class for Down casting dataset before serialization and transferring to training host.
|
||||
Depends on the data type and the actual data range, the down casting can be lossless or not.
|
||||
It is strongly recommended to compare the metrics before and after down casting.
|
||||
"""
|
||||
|
||||
def __init__(self, downcast_config):
|
||||
super().__init__()
|
||||
self.config = downcast_config
|
||||
self._type_map = {
|
||||
"bfloat16": tf.bfloat16,
|
||||
"bool": tf.bool,
|
||||
}
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
|
||||
for feature, type_str in self.config.features.items():
|
||||
assert type_str in self._type_map
|
||||
if type_str == "bfloat16":
|
||||
logging.warning(
|
||||
"Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
|
||||
)
|
||||
down_cast_data_type = self._type_map[type_str]
|
||||
outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type)
|
||||
return outputs
|
||||
|
||||
|
||||
class RectifyLabels(tf.keras.Model):
|
||||
"""Class for rectifying labels"""
|
||||
|
||||
def __init__(self, rectify_label_config):
|
||||
super().__init__()
|
||||
self._config = rectify_label_config
|
||||
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
served_ts_field = self._config.served_timestamp_field
|
||||
impressed_ts_field = self._config.impressed_timestamp_field
|
||||
|
||||
for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():
|
||||
impressed = inputs[impressed_ts_field]
|
||||
served = inputs[served_ts_field]
|
||||
engaged = inputs[engaged_ts_field]
|
||||
|
||||
keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window)
|
||||
keep = tf.math.logical_and(keep, engaged - served < self._window)
|
||||
inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label]))
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class ExtractFeatures(tf.keras.Model):
|
||||
"""Class for extracting individual features from dense tensors by their index."""
|
||||
|
||||
def __init__(self, extract_features_config):
|
||||
super().__init__()
|
||||
self._config = extract_features_config
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
|
||||
for row in self._config.extract_feature_table:
|
||||
inputs[row.name] = inputs[row.source_tensor][:, row.index]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class DownsampleNegatives(tf.keras.Model):
|
||||
"""Class for down-sampling/dropping negatives and updating the weights.
|
||||
|
||||
If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]
|
||||
inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]
|
||||
when batch_multiplier=2 and engagements_list=['fav']
|
||||
|
||||
It supports multiple engagements (union/logical_or is used to aggregate engagements), so we don't
|
||||
drop positives for any engagement.
|
||||
"""
|
||||
|
||||
def __init__(self, downsample_negatives_config):
|
||||
super().__init__()
|
||||
self.config = downsample_negatives_config
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
labels = self.config.engagements_list
|
||||
# union of engagements
|
||||
mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))
|
||||
n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))
|
||||
batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)
|
||||
negative_weights = tf.math.divide_no_nan(
|
||||
tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
|
||||
tf.cast(batch_size - n_positives, tf.float32),
|
||||
)
|
||||
new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights
|
||||
|
||||
def _split_by_label_concatenate_and_truncate(input_tensor):
|
||||
# takes positive examples and concatenate with negative examples and truncate
|
||||
# DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50)
|
||||
return tf.concat(
|
||||
[
|
||||
input_tensor[mask],
|
||||
input_tensor[tf.math.logical_not(mask)],
|
||||
],
|
||||
0,
|
||||
)[:batch_size]
|
||||
|
||||
if "weights" not in inputs:
|
||||
# add placeholder so logic below applies even if weights aren't present in inputs
|
||||
inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements])
|
||||
|
||||
for tensor in inputs:
|
||||
if tensor == "weights":
|
||||
inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1])
|
||||
|
||||
inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor])
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
|
||||
"""Builds a preprocess model to apply all preprocessing stages."""
|
||||
if mode == config_mod.JobMode.INFERENCE:
|
||||
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
|
||||
return None
|
||||
|
||||
preprocess_models = []
|
||||
if preprocess_config.downsample_negatives:
|
||||
preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives))
|
||||
if preprocess_config.truncate_and_slice:
|
||||
preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice))
|
||||
if preprocess_config.downcast:
|
||||
preprocess_models.append(DownCast(preprocess_config.downcast))
|
||||
if preprocess_config.rectify_labels:
|
||||
preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels))
|
||||
if preprocess_config.extract_features:
|
||||
preprocess_models.append(ExtractFeatures(preprocess_config.extract_features))
|
||||
|
||||
if len(preprocess_models) == 0:
|
||||
raise ValueError("No known preprocessor.")
|
||||
|
||||
class PreprocessModel(tf.keras.Model):
|
||||
def __init__(self, preprocess_models):
|
||||
super().__init__()
|
||||
self.preprocess_models = preprocess_models
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
outputs = inputs
|
||||
for model in self.preprocess_models:
|
||||
outputs = model(outputs, training, mask)
|
||||
return outputs
|
||||
|
||||
if len(preprocess_models) > 1:
|
||||
logging.warning(
|
||||
"With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
|
||||
)
|
||||
return PreprocessModel(preprocess_models)
|
129
projects/home/recap/data/tfe_parsing.py
Normal file
129
projects/home/recap/data/tfe_parsing.py
Normal file
@ -0,0 +1,129 @@
|
||||
import functools
|
||||
import json
|
||||
|
||||
from tml.projects.home.recap.data import config as recap_data_config
|
||||
|
||||
from absl import logging
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
|
||||
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
|
||||
|
||||
|
||||
def create_tf_example_schema(
|
||||
data_config: recap_data_config.SegDenseSchema,
|
||||
segdense_schema,
|
||||
):
|
||||
"""Generate schema for deseralizing tf.Example.
|
||||
|
||||
Args:
|
||||
segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).
|
||||
labels: List of strings denoting labels.
|
||||
|
||||
Returns:
|
||||
A dictionary schema suitable for deserializing tf.Example.
|
||||
"""
|
||||
segdense_config = data_config.seg_dense_schema
|
||||
labels = list(data_config.tasks.keys())
|
||||
used_features = (
|
||||
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
|
||||
)
|
||||
logging.info(used_features)
|
||||
|
||||
tfe_schema = {}
|
||||
for entry in segdense_schema:
|
||||
feature_name = entry["feature_name"]
|
||||
|
||||
if feature_name in used_features:
|
||||
length = entry["length"]
|
||||
dtype = entry["dtype"]
|
||||
|
||||
if feature_name in labels:
|
||||
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
|
||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
|
||||
)
|
||||
elif length == -1:
|
||||
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
|
||||
else:
|
||||
tfe_schema[feature_name] = tf.io.FixedLenFeature(
|
||||
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
|
||||
)
|
||||
for feature_name in used_features:
|
||||
if feature_name not in tfe_schema:
|
||||
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
|
||||
return tfe_schema
|
||||
|
||||
|
||||
@functools.lru_cache(1)
|
||||
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
|
||||
"""For experimentating with emulating bfloat16 or less precise types."""
|
||||
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
|
||||
|
||||
|
||||
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
|
||||
"""For experimentating with emulating bfloat16 or less precise types."""
|
||||
mask: tf.Tensor = make_mantissa_mask(mask_length)
|
||||
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
|
||||
|
||||
|
||||
def parse_tf_example(
|
||||
serialized_example,
|
||||
tfe_schema,
|
||||
seg_dense_schema_config,
|
||||
):
|
||||
"""Parse serialized tf.Example into dict of tensors.
|
||||
|
||||
Args:
|
||||
serialized_example: Serialized tf.Example to be parsed.
|
||||
tfe_schema: Dictionary schema suitable for deserializing tf.Example.
|
||||
|
||||
Returns:
|
||||
Dictionary of tensors to be used as model input.
|
||||
"""
|
||||
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
|
||||
|
||||
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
|
||||
inputs[new_feature_name] = inputs.pop(old_feature_name)
|
||||
|
||||
# This should not actually be used except for experimentation with low precision floats.
|
||||
if "mask_mantissa_features" in seg_dense_schema_config:
|
||||
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
|
||||
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
|
||||
|
||||
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
|
||||
# at TF level.
|
||||
# We should not return empty tensors if we dont use embeddings.
|
||||
# Otherwise, it breaks numpy->pt conversion
|
||||
renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
|
||||
for renamed_key in renamed_keys:
|
||||
if "embedding" in renamed_key and (renamed_key not in inputs):
|
||||
inputs[renamed_key] = tf.zeros([], tf.float32)
|
||||
|
||||
logging.info(f"parsed example and inputs are {inputs}")
|
||||
return inputs
|
||||
|
||||
|
||||
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
|
||||
"""Placeholder for seg dense.
|
||||
|
||||
In the future, when we use more seg dense variations, we can change this.
|
||||
"""
|
||||
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
|
||||
seg_dense_schema = json.load(f)["schema"]
|
||||
|
||||
tf_example_schema = create_tf_example_schema(
|
||||
data_config,
|
||||
seg_dense_schema,
|
||||
)
|
||||
|
||||
logging.info("***** TF Example Schema *****")
|
||||
logging.info(tf_example_schema)
|
||||
|
||||
parse = functools.partial(
|
||||
parse_tf_example,
|
||||
tfe_schema=tf_example_schema,
|
||||
seg_dense_schema_config=data_config.seg_dense_schema,
|
||||
)
|
||||
return parse
|
120
projects/home/recap/data/util.py
Normal file
120
projects/home/recap/data/util.py
Normal file
@ -0,0 +1,120 @@
|
||||
from typing import Mapping, Tuple, Union
|
||||
import torch
|
||||
import torchrec
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def keyed_tensor_from_tensors_dict(
|
||||
tensor_map: Mapping[str, torch.Tensor]
|
||||
) -> "torchrec.KeyedTensor":
|
||||
"""
|
||||
Convert a dictionary of torch tensor to torchrec keyed tensor
|
||||
Args:
|
||||
tensor_map:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
keys = list(tensor_map.keys())
|
||||
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
|
||||
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
|
||||
# [Batch_size x 1].
|
||||
values = [
|
||||
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
|
||||
for key in keys
|
||||
]
|
||||
return torchrec.KeyedTensor.from_tensor_list(keys, values)
|
||||
|
||||
|
||||
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if tensor.is_sparse:
|
||||
x = tensor.coalesce() # Ensure that the indices are ordered.
|
||||
lengths = torch.bincount(x.indices()[0])
|
||||
values = x.values()
|
||||
else:
|
||||
values = tensor
|
||||
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
|
||||
return values, lengths
|
||||
|
||||
|
||||
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
|
||||
"""
|
||||
Convert a torch tensor to torchrec jagged tensor.
|
||||
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
|
||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
|
||||
dense_shape of the sparse tensor can be arbitrary.
|
||||
Args:
|
||||
tensor: a torch (sparse) tensor.
|
||||
Returns:
|
||||
"""
|
||||
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
|
||||
return torchrec.JaggedTensor(values=values, lengths=lengths)
|
||||
|
||||
|
||||
def keyed_jagged_tensor_from_tensors_dict(
|
||||
tensor_map: Mapping[str, torch.Tensor]
|
||||
) -> "torchrec.KeyedJaggedTensor":
|
||||
"""
|
||||
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
|
||||
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
|
||||
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
|
||||
dense_shape of the sparse tensor can be arbitrary.
|
||||
Args:
|
||||
tensor_map:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
if not tensor_map:
|
||||
return torchrec.KeyedJaggedTensor(
|
||||
keys=[],
|
||||
values=torch.zeros(0, dtype=torch.int),
|
||||
lengths=torch.zeros(0, dtype=torch.int),
|
||||
)
|
||||
values = []
|
||||
lengths = []
|
||||
for tensor in tensor_map.values():
|
||||
tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
|
||||
values.append(torch.squeeze(tensor_val))
|
||||
lengths.append(tensor_len)
|
||||
|
||||
values = torch.cat(values, axis=0)
|
||||
lengths = torch.cat(lengths, axis=0)
|
||||
|
||||
return torchrec.KeyedJaggedTensor(
|
||||
keys=list(tensor_map.keys()),
|
||||
values=values,
|
||||
lengths=lengths,
|
||||
)
|
||||
|
||||
|
||||
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
|
||||
return tf_tensor._numpy() # noqa
|
||||
|
||||
|
||||
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
|
||||
tensor = _tf_to_numpy(tensor)
|
||||
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
|
||||
if tensor.dtype.name == "bfloat16":
|
||||
tensor = tensor.astype(np.float32)
|
||||
|
||||
tensor = torch.from_numpy(tensor)
|
||||
if pin_memory:
|
||||
tensor = tensor.pin_memory()
|
||||
return tensor
|
||||
|
||||
|
||||
def sparse_or_dense_tf_to_torch(
|
||||
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
|
||||
) -> torch.Tensor:
|
||||
if isinstance(tensor, tf.SparseTensor):
|
||||
tensor = torch.sparse_coo_tensor(
|
||||
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
|
||||
_dense_tf_to_torch(tensor.values, pin_memory),
|
||||
torch.Size(_tf_to_numpy(tensor.dense_shape)),
|
||||
)
|
||||
else:
|
||||
tensor = _dense_tf_to_torch(tensor, pin_memory)
|
||||
return tensor
|
94
projects/home/recap/embedding/config.py
Normal file
94
projects/home/recap/embedding/config.py
Normal file
@ -0,0 +1,94 @@
|
||||
from typing import List, Optional
|
||||
import tml.core.config as base_config
|
||||
from tml.optimizers import config as optimizer_config
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class EmbeddingSnapshot(base_config.BaseConfig):
|
||||
"""Configuration for Embedding snapshot"""
|
||||
|
||||
emb_name: str = pydantic.Field(
|
||||
..., description="Name of the embedding table from the loaded snapshot"
|
||||
)
|
||||
embedding_snapshot_uri: str = pydantic.Field(
|
||||
..., description="Path to torchsnapshot of the embedding"
|
||||
)
|
||||
|
||||
|
||||
# https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig
|
||||
class EmbeddingBagConfig(base_config.BaseConfig):
|
||||
"""Configuration for EmbeddingBag."""
|
||||
|
||||
name: str = pydantic.Field(..., description="name of embedding bag")
|
||||
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
|
||||
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
|
||||
pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
|
||||
vocab: str = pydantic.Field(
|
||||
None, description="Directory to parquet files of mapping from entity ID to table index."
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingOptimizerConfig(base_config.BaseConfig):
|
||||
learning_rate: optimizer_config.LearningRate = pydantic.Field(
|
||||
None, description="learning rate scheduler for the EBC"
|
||||
)
|
||||
init_learning_rate: float = pydantic.Field(description="initial learning rate for the EBC")
|
||||
# NB: Only sgd is supported right now and implicitly.
|
||||
# FBGemm only supports simple exact_sgd which only takes LR as an argument.
|
||||
|
||||
|
||||
class LargeEmbeddingsConfig(base_config.BaseConfig):
|
||||
"""Configuration for EmbeddingBagCollection.
|
||||
|
||||
The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
|
||||
"""
|
||||
|
||||
tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
|
||||
optimizer: EmbeddingOptimizerConfig
|
||||
tables_to_log: List[str] = pydantic.Field(
|
||||
None, description="list of embedding table names that we want to log during training"
|
||||
)
|
||||
|
||||
|
||||
class StratifierConfig(base_config.BaseConfig):
|
||||
name: str
|
||||
index: int
|
||||
value: int
|
||||
|
||||
|
||||
class SmallEmbeddingBagConfig(base_config.BaseConfig):
|
||||
"""Configuration for SmallEmbeddingBag."""
|
||||
|
||||
name: str = pydantic.Field(..., description="name of embedding bag")
|
||||
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
|
||||
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
|
||||
index: int = pydantic.Field(..., description="index in the discrete tensor to look for")
|
||||
|
||||
|
||||
class SmallEmbeddingBagConfig(base_config.BaseConfig):
|
||||
"""Configuration for SmallEmbeddingBag."""
|
||||
|
||||
name: str = pydantic.Field(..., description="name of embedding bag")
|
||||
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
|
||||
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
|
||||
index: int = pydantic.Field(..., description="index in the discrete tensor to look for")
|
||||
|
||||
|
||||
class SmallEmbeddingsConfig(base_config.BaseConfig):
|
||||
"""Configuration for SmallEmbeddingConfig.
|
||||
|
||||
Here we can use discrete features that already are present in our TFRecords generated using
|
||||
segdense conversion as "home_recap_2022_discrete__segdense_vals" which are available in
|
||||
the model as "discrete_features", and embed a user-defined set of them with configurable
|
||||
dimensions and vocabulary sizes.
|
||||
|
||||
Compared with LargeEmbedding, this config is for small embedding tables that can fit inside
|
||||
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
|
||||
serving time due to size (>>1 GB).
|
||||
|
||||
This small embeddings table uses the same optimizer as the rest of the model."""
|
||||
|
||||
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
|
||||
..., description="list of embedding tables"
|
||||
)
|
108
projects/home/recap/main.py
Normal file
108
projects/home/recap/main.py
Normal file
@ -0,0 +1,108 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
import tensorflow as tf
|
||||
|
||||
import tml.common.checkpointing.snapshot as snapshot_lib
|
||||
from tml.common.device import setup_and_get_device
|
||||
from tml.core import config as tml_config_mod
|
||||
import tml.core.custom_training_loop as ctl
|
||||
from tml.core import debug_training_loop
|
||||
from tml.core import losses
|
||||
from tml.core.loss_type import LossType
|
||||
from tml.model import maybe_shard_model
|
||||
|
||||
|
||||
import tml.projects.home.recap.data.dataset as ds
|
||||
import tml.projects.home.recap.config as recap_config_mod
|
||||
import tml.projects.home.recap.optimizer as optimizer_mod
|
||||
|
||||
|
||||
# from tml.projects.home.recap import feature
|
||||
import tml.projects.home.recap.model as model_mod
|
||||
import torchmetrics as tm
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torchrec.distributed.model_parallel import DistributedModelParallel
|
||||
|
||||
from absl import app, flags, logging
|
||||
|
||||
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
|
||||
flags.DEFINE_bool("debug_loop", False, "Run with debug loop (slow)")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
|
||||
print("#" * 100)
|
||||
|
||||
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
|
||||
logging.info("Config: %s", config.pretty_print())
|
||||
|
||||
device = setup_and_get_device()
|
||||
|
||||
# Always enable tensorfloat on supported devices.
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
loss_fn = losses.build_multi_task_loss(
|
||||
loss_type=LossType.BCE_WITH_LOGITS,
|
||||
tasks=list(config.model.tasks.keys()),
|
||||
pos_weights=[task.pos_weight for task in config.model.tasks.values()],
|
||||
)
|
||||
|
||||
# Since the prod model doesn't use large embeddings, for now we won't support them.
|
||||
assert config.model.large_embeddings is None
|
||||
|
||||
train_dataset = ds.RecapDataset(
|
||||
data_config=config.train_data,
|
||||
dataset_service=data_service_dispatcher,
|
||||
mode=recap_config_mod.JobMode.TRAIN,
|
||||
compression=config.train_data.dataset_service_compression,
|
||||
vocab_mapper=None,
|
||||
repeat=True,
|
||||
)
|
||||
|
||||
train_iterator = iter(train_dataset.to_dataloader())
|
||||
|
||||
torch_element_spec = train_dataset.torch_element_spec
|
||||
|
||||
model = model_mod.create_ranking_model(
|
||||
data_spec=torch_element_spec[0],
|
||||
config=config,
|
||||
loss_fn=loss_fn,
|
||||
device=device,
|
||||
)
|
||||
|
||||
optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)
|
||||
|
||||
model = maybe_shard_model(model, device)
|
||||
|
||||
datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
|
||||
print(f"{datetime_str}\n", end="")
|
||||
|
||||
if FLAGS.debug_loop:
|
||||
logging.warning("Running debug mode, slow!")
|
||||
train_mod = debug_training_loop
|
||||
else:
|
||||
train_mod = ctl
|
||||
|
||||
train_mod.train(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
device=device,
|
||||
save_dir=config.training.save_dir,
|
||||
logging_interval=config.training.train_log_every_n,
|
||||
train_steps=config.training.num_train_steps,
|
||||
checkpoint_frequency=config.training.checkpoint_every_n,
|
||||
dataset=train_iterator,
|
||||
worker_batch_size=config.train_data.global_batch_size,
|
||||
enable_amp=False,
|
||||
initial_checkpoint_dir=config.training.initial_checkpoint_dir,
|
||||
gradient_accumulation=config.training.gradient_accumulation,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(run)
|
7
projects/home/recap/model/__init__.py
Normal file
7
projects/home/recap/model/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from tml.projects.home.recap.model.entrypoint import (
|
||||
create_ranking_model,
|
||||
sanitize,
|
||||
unsanitize,
|
||||
MultiTaskRankingModel,
|
||||
)
|
||||
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss
|
260
projects/home/recap/model/config.py
Normal file
260
projects/home/recap/model/config.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""Configuration for the main Recap model."""
|
||||
|
||||
import enum
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
import tml.core.config as base_config
|
||||
from tml.projects.home.recap.embedding import config as embedding_config
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class DropoutConfig(base_config.BaseConfig):
|
||||
"""Configuration for the dropout layer."""
|
||||
|
||||
rate: pydantic.PositiveFloat = pydantic.Field(
|
||||
0.1, description="Fraction of inputs to be dropped."
|
||||
)
|
||||
|
||||
|
||||
class LayerNormConfig(base_config.BaseConfig):
|
||||
"""Configruation for the layer normalization."""
|
||||
|
||||
epsilon: float = pydantic.Field(
|
||||
1e-3, description="Small float added to variance to avoid dividing by zero."
|
||||
)
|
||||
axis: int = pydantic.Field(-1, description="Axis or axes to normalize across.")
|
||||
center: bool = pydantic.Field(True, description="Whether to add learnable center.")
|
||||
scale: bool = pydantic.Field(True, description="Whether to add learnable scale.")
|
||||
|
||||
|
||||
class BatchNormConfig(base_config.BaseConfig):
|
||||
"""Configuration of the batch normalization layer."""
|
||||
|
||||
epsilon: pydantic.PositiveFloat = 1e-5
|
||||
momentum: pydantic.PositiveFloat = 0.9
|
||||
training_mode_at_inference_time: bool = False
|
||||
use_renorm: bool = False
|
||||
center: bool = pydantic.Field(True, description="Whether to add learnable center.")
|
||||
scale: bool = pydantic.Field(True, description="Whether to add learnable scale.")
|
||||
|
||||
|
||||
class DenseLayerConfig(base_config.BaseConfig):
|
||||
layer_size: pydantic.PositiveInt
|
||||
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
|
||||
|
||||
|
||||
class MlpConfig(base_config.BaseConfig):
|
||||
"""Configuration for MLP model."""
|
||||
|
||||
layer_sizes: List[pydantic.PositiveInt] = pydantic.Field(None, one_of="mlp_layer_definition")
|
||||
layers: List[DenseLayerConfig] = pydantic.Field(None, one_of="mlp_layer_definition")
|
||||
|
||||
|
||||
class BatchNormConfig(base_config.BaseConfig):
|
||||
"""Configuration for the batch norm layer."""
|
||||
|
||||
affine: bool = pydantic.Field(True, description="Use affine transformation.")
|
||||
momentum: pydantic.PositiveFloat = pydantic.Field(
|
||||
0.1, description="Forgetting parameter in moving average."
|
||||
)
|
||||
|
||||
|
||||
class DoubleNormLogConfig(base_config.BaseConfig):
|
||||
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
|
||||
clip_magnitude: float = pydantic.Field(
|
||||
5.0, description="Threshold to clip the normalized input values."
|
||||
)
|
||||
layer_norm_config: Optional[LayerNormConfig] = pydantic.Field(None)
|
||||
|
||||
|
||||
class Log1pAbsConfig(base_config.BaseConfig):
|
||||
"""Simple configuration where only the log transform is performed."""
|
||||
|
||||
|
||||
class ClipLog1pAbsConfig(base_config.BaseConfig):
|
||||
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
|
||||
3e38, description="Threshold to clip the input values."
|
||||
)
|
||||
|
||||
|
||||
class ZScoreLogConfig(base_config.BaseConfig):
|
||||
analysis_path: str
|
||||
schema_path: str = pydantic.Field(
|
||||
None,
|
||||
description="Schema path which feaure statistics are generated with. Can be different from scehma in data config.",
|
||||
)
|
||||
clip_magnitude: float = pydantic.Field(
|
||||
5.0, description="Threshold to clip the normalized input values."
|
||||
)
|
||||
use_batch_norm: bool = pydantic.Field(
|
||||
False, description="Option to use batch normalization on the inputs."
|
||||
)
|
||||
use_renorm: bool = pydantic.Field(
|
||||
False, description="Option to use batch renormalization for trainig and serving consistency."
|
||||
)
|
||||
use_bq_stats: bool = pydantic.Field(
|
||||
False, description="Option to load the partitioned json files from BQ as statistics."
|
||||
)
|
||||
|
||||
|
||||
class FeaturizationConfig(base_config.BaseConfig):
|
||||
"""Configuration for featurization."""
|
||||
|
||||
log1p_abs_config: Log1pAbsConfig = pydantic.Field(None, one_of="featurization")
|
||||
clip_log1p_abs_config: ClipLog1pAbsConfig = pydantic.Field(None, one_of="featurization")
|
||||
z_score_log_config: ZScoreLogConfig = pydantic.Field(None, one_of="featurization")
|
||||
double_norm_log_config: DoubleNormLogConfig = pydantic.Field(None, one_of="featurization")
|
||||
feature_names_to_concat: List[str] = pydantic.Field(
|
||||
["binary"], description="Feature names to concatenate as raw values with continuous features."
|
||||
)
|
||||
|
||||
|
||||
class DropoutConfig(base_config.BaseConfig):
|
||||
"""Configuration for the dropout layer."""
|
||||
|
||||
rate: pydantic.PositiveFloat = pydantic.Field(
|
||||
0.1, description="Fraction of inputs to be dropped."
|
||||
)
|
||||
|
||||
|
||||
class MlpConfig(base_config.BaseConfig):
|
||||
"""Configuration for MLP model."""
|
||||
|
||||
layer_sizes: List[pydantic.PositiveInt]
|
||||
batch_norm: BatchNormConfig = pydantic.Field(
|
||||
None, description="Optional batch norm configuration."
|
||||
)
|
||||
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout configuration.")
|
||||
final_layer_activation: bool = pydantic.Field(
|
||||
False, description="Whether to include activation on final layer."
|
||||
)
|
||||
|
||||
|
||||
class DcnConfig(base_config.BaseConfig):
|
||||
"""Config for DCN model."""
|
||||
|
||||
poly_degree: pydantic.PositiveInt
|
||||
projection_dim: pydantic.PositiveInt = pydantic.Field(
|
||||
None, description="Factorizes main DCN matmul with projection."
|
||||
)
|
||||
|
||||
parallel_mlp: Optional[MlpConfig] = pydantic.Field(
|
||||
None, description="Config for the mlp if used. If None, only the cross layers are used."
|
||||
)
|
||||
use_parallel: bool = pydantic.Field(True, description="Whether to use parallel DCN.")
|
||||
|
||||
output_mlp: Optional[MlpConfig] = pydantic.Field(None, description="Config for the output mlp.")
|
||||
|
||||
|
||||
class MaskBlockConfig(base_config.BaseConfig):
|
||||
output_size: int
|
||||
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
|
||||
None, one_of="aggregation_size"
|
||||
)
|
||||
aggregation_size: Optional[pydantic.PositiveInt] = pydantic.Field(
|
||||
None, description="Specify the aggregation size directly.", one_of="aggregation_size"
|
||||
)
|
||||
input_layer_norm: bool
|
||||
|
||||
|
||||
class MaskNetConfig(base_config.BaseConfig):
|
||||
mask_blocks: List[MaskBlockConfig]
|
||||
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
|
||||
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
|
||||
|
||||
|
||||
class PositionDebiasConfig(base_config.BaseConfig):
|
||||
"""
|
||||
Configuration for Position Debias.
|
||||
"""
|
||||
|
||||
max_position: int = pydantic.Field(256, description="Bucket all later positions.")
|
||||
num_dims: pydantic.PositiveInt = pydantic.Field(
|
||||
64, description="Number of dimensions in embedding."
|
||||
)
|
||||
drop_probability: float = pydantic.Field(0.5, description="Probability of dropping position.")
|
||||
|
||||
# Currently it should be 51 based on dataset being tested at the time of writing this model
|
||||
# However, no default provided here to make sure user of the model is aware of its importance.
|
||||
position_feature_index: int = pydantic.Field(
|
||||
description="The index of the position feature in the discrete features"
|
||||
)
|
||||
|
||||
|
||||
class AffineMap(base_config.BaseConfig):
|
||||
"""An affine map that scales the logits into the appropriate range."""
|
||||
|
||||
scale: float = pydantic.Field(1.0)
|
||||
bias: float = pydantic.Field(0.0)
|
||||
|
||||
|
||||
class DLRMConfig(base_config.BaseConfig):
|
||||
bottom_mlp: MlpConfig = pydantic.Field(
|
||||
...,
|
||||
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
|
||||
)
|
||||
top_mlp: MlpConfig = pydantic.Field(..., description="Top mlp, generate the final output")
|
||||
|
||||
|
||||
class TaskModel(base_config.BaseConfig):
|
||||
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
|
||||
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
|
||||
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
|
||||
mask_net_config: MaskNetConfig = pydantic.Field(None, one_of="architecture")
|
||||
|
||||
affine_map: AffineMap = pydantic.Field(
|
||||
None,
|
||||
description="Affine map applied to logits so we can represent a broader range of probabilities.",
|
||||
)
|
||||
# DANGER DANGER: not implemented yet.
|
||||
# loss_weight: float = pydantic.Field(1.0, description="Weight for task in loss.")
|
||||
pos_weight: float = pydantic.Field(1.0, description="Weight of positive in loss.")
|
||||
|
||||
|
||||
class MultiTaskType(str, enum.Enum):
|
||||
SHARE_NONE = "share_none" # Tasks are separate.
|
||||
SHARE_ALL = "share_all" # Tasks share same backbone.
|
||||
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
|
||||
|
||||
|
||||
class ModelConfig(base_config.BaseConfig):
|
||||
"""Specify model architecture."""
|
||||
|
||||
tasks: Dict[str, TaskModel] = pydantic.Field(
|
||||
description="Specification of architecture per task."
|
||||
)
|
||||
|
||||
large_embeddings: embedding_config.LargeEmbeddingsConfig = pydantic.Field(None)
|
||||
small_embeddings: embedding_config.SmallEmbeddingsConfig = pydantic.Field(None)
|
||||
# Not implemented yet.
|
||||
# multi_task_loss_reduction_fn: str = "mean"
|
||||
|
||||
position_debias_config: PositionDebiasConfig = pydantic.Field(
|
||||
default=None, description="position debias model configuration"
|
||||
)
|
||||
|
||||
featurization_config: FeaturizationConfig = pydantic.Field(None)
|
||||
|
||||
multi_task_type: MultiTaskType = pydantic.Field(
|
||||
MultiTaskType.SHARE_NONE, description="Multi task architecture"
|
||||
)
|
||||
|
||||
backbone: TaskModel = pydantic.Field(None, description="Type of architecture for the backbone.")
|
||||
stratifiers: List[embedding_config.StratifierConfig] = pydantic.Field(
|
||||
default=None, description="Discrete features and values to stratify metrics by."
|
||||
)
|
||||
|
||||
@pydantic.root_validator()
|
||||
def _validate_mtl(cls, values):
|
||||
if values.get("multi_task_type", None) is None:
|
||||
return values
|
||||
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
|
||||
if values.get("backbone", None) is None:
|
||||
raise ValueError("Require `backbone` for SHARE_ALL and SHARE_PARTIAL.")
|
||||
elif values["multi_task_type"] in [
|
||||
MultiTaskType.SHARE_NONE,
|
||||
]:
|
||||
if values.get("backbone", None) is not None:
|
||||
raise ValueError("Can not have backbone if the share type is SHARE_NONE")
|
||||
return values
|
304
projects/home/recap/model/entrypoint.py
Normal file
304
projects/home/recap/model/entrypoint.py
Normal file
@ -0,0 +1,304 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from absl import logging
|
||||
import torch
|
||||
from typing import Optional, Callable, Mapping, Dict, Sequence, TYPE_CHECKING
|
||||
from tml.projects.home.recap.model import feature_transform
|
||||
from tml.projects.home.recap.model import config as model_config_mod
|
||||
from tml.projects.home.recap.model import mlp
|
||||
from tml.projects.home.recap.model import mask_net
|
||||
from tml.projects.home.recap.model import numeric_calibration
|
||||
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss
|
||||
import tml.projects.home.recap.model.config as model_config_mod
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tml.projects.home.recap import config as config_mod
|
||||
from tml.projects.home.recap.data.config import RecapDataConfig
|
||||
from tml.projects.home.recap.model.config import ModelConfig
|
||||
|
||||
|
||||
def sanitize(task_name):
|
||||
return task_name.replace(".", "__")
|
||||
|
||||
|
||||
def unsanitize(sanitized_task_name):
|
||||
return sanitized_task_name.replace("__", ".")
|
||||
|
||||
|
||||
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
|
||||
""" "Builds a model for a single task"""
|
||||
if task.mlp_config:
|
||||
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
|
||||
elif task.dcn_config:
|
||||
return dcn.Dcn(dcn_config=task.dcn_config, in_features=input_shape)
|
||||
elif task.mask_net_config:
|
||||
return mask_net.MaskNet(mask_net_config=task.mask_net_config, in_features=input_shape)
|
||||
else:
|
||||
raise ValueError("This should never be reached.")
|
||||
|
||||
|
||||
class MultiTaskRankingModel(torch.nn.Module):
|
||||
"""Multi-task ranking model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shapes: Mapping[str, torch.Size],
|
||||
config: ModelConfig,
|
||||
data_config: RecapDataConfig,
|
||||
return_backbone: bool = False,
|
||||
):
|
||||
"""Constructor for Multi task learning.
|
||||
|
||||
Assumptions made:
|
||||
1. Tasks specified in data config match model architecture.
|
||||
|
||||
These are all validated in config.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._config = config
|
||||
self._data_config = data_config
|
||||
|
||||
self._preprocessor = feature_transform.build_features_preprocessor(
|
||||
config.featurization_config, input_shapes
|
||||
)
|
||||
|
||||
self.return_backbone = return_backbone
|
||||
|
||||
self.embeddings = None
|
||||
self.small_embeddings = None
|
||||
embedding_dims = 0
|
||||
if config.large_embeddings:
|
||||
from large_embeddings.models.learnable_embeddings import LargeEmbeddings
|
||||
|
||||
self.embeddings = LargeEmbeddings(large_embeddings_config=config.large_embeddings)
|
||||
|
||||
embedding_dims += sum([table.embedding_dim for table in config.large_embeddings.tables])
|
||||
logging.info(f"Emb dim: {embedding_dims}")
|
||||
|
||||
if config.small_embeddings:
|
||||
self.small_embeddings = SmallEmbedding(config.small_embeddings)
|
||||
embedding_dims += sum([table.embedding_dim for table in config.small_embeddings.tables])
|
||||
logging.info(f"Emb dim (with small embeddings): {embedding_dims}")
|
||||
|
||||
if "user_embedding" in data_config.seg_dense_schema.renamed_features:
|
||||
embedding_dims += input_shapes["user_embedding"][-1]
|
||||
self._user_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["user_embedding"][-1])
|
||||
else:
|
||||
self._user_embedding_layer_norm = None
|
||||
if "user_eng_embedding" in data_config.seg_dense_schema.renamed_features:
|
||||
embedding_dims += input_shapes["user_eng_embedding"][-1]
|
||||
self._user_eng_embedding_layer_norm = torch.nn.LayerNorm(
|
||||
input_shapes["user_eng_embedding"][-1]
|
||||
)
|
||||
else:
|
||||
self._user_eng_embedding_layer_norm = None
|
||||
if "author_embedding" in data_config.seg_dense_schema.renamed_features:
|
||||
embedding_dims += input_shapes["author_embedding"][-1]
|
||||
self._author_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["author_embedding"][-1])
|
||||
else:
|
||||
self._author_embedding_layer_norm = None
|
||||
|
||||
input_dims = input_shapes["continuous"][-1] + input_shapes["binary"][-1] + embedding_dims
|
||||
|
||||
if config.position_debias_config:
|
||||
self.position_debias_model = PositionDebias(config.position_debias_config)
|
||||
input_dims += self.position_debias_model.out_features
|
||||
else:
|
||||
self.position_debias_model = None
|
||||
logging.info(f"input dim: {input_dims}")
|
||||
|
||||
if config.multi_task_type in [
|
||||
model_config_mod.MultiTaskType.SHARE_ALL,
|
||||
model_config_mod.MultiTaskType.SHARE_PARTIAL,
|
||||
]:
|
||||
self._backbone = _build_single_task_model(config.backbone, input_dims)
|
||||
else:
|
||||
self._backbone = None
|
||||
|
||||
_towers: Dict[str, torch.nn.Module] = {}
|
||||
_calibrators: Dict[str, torch.nn.Module] = {}
|
||||
_affine_maps: Dict[str, torch.nn.Module] = {}
|
||||
|
||||
for task_name, task_architecture in config.tasks.items():
|
||||
safe_name = sanitize(task_name)
|
||||
|
||||
# Complex input dimension calculation.
|
||||
if config.multi_task_type == model_config_mod.MultiTaskType.SHARE_NONE:
|
||||
num_inputs = input_dims
|
||||
elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
|
||||
num_inputs = self._backbone.out_features
|
||||
elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
|
||||
num_inputs = input_dims + self._backbone.out_features
|
||||
else:
|
||||
raise ValueError("Unreachable branch of enum.")
|
||||
|
||||
# Annoyingly, ModuleDict doesn't allow . inside key names.
|
||||
_towers[safe_name] = _build_single_task_model(task_architecture, num_inputs)
|
||||
|
||||
if task_architecture.affine_map:
|
||||
affine_map = torch.nn.Linear(1, 1)
|
||||
affine_map.weight.data = torch.tensor([[task_architecture.affine_map.scale]])
|
||||
affine_map.bias.data = torch.tensor([task_architecture.affine_map.bias])
|
||||
_affine_maps[safe_name] = affine_map
|
||||
else:
|
||||
_affine_maps[safe_name] = torch.nn.Identity()
|
||||
|
||||
_calibrators[safe_name] = numeric_calibration.NumericCalibration(
|
||||
pos_downsampling_rate=data_config.tasks[task_name].pos_downsampling_rate,
|
||||
neg_downsampling_rate=data_config.tasks[task_name].neg_downsampling_rate,
|
||||
)
|
||||
|
||||
self._task_names = list(config.tasks.keys())
|
||||
self._towers = torch.nn.ModuleDict(_towers)
|
||||
self._affine_maps = torch.nn.ModuleDict(_affine_maps)
|
||||
self._calibrators = torch.nn.ModuleDict(_calibrators)
|
||||
|
||||
self._counter = torch.autograd.Variable(torch.tensor(0), requires_grad=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
continuous_features: torch.Tensor,
|
||||
binary_features: torch.Tensor,
|
||||
discrete_features: Optional[torch.Tensor] = None,
|
||||
sparse_features=None, # Optional[KeyedJaggedTensor]
|
||||
user_embedding: Optional[torch.Tensor] = None,
|
||||
user_eng_embedding: Optional[torch.Tensor] = None,
|
||||
author_embedding: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
weights: Optional[torch.Tensor] = None,
|
||||
):
|
||||
concat_dense_features = [
|
||||
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
|
||||
]
|
||||
|
||||
if self.embeddings:
|
||||
concat_dense_features.append(self.embeddings(sparse_features))
|
||||
|
||||
# Twhin embedding layer norms
|
||||
if self.small_embeddings:
|
||||
if discrete_features is None:
|
||||
raise ValueError(
|
||||
"Forward arg discrete_features is None, but since small_embeddings are used, a Tensor is expected."
|
||||
)
|
||||
concat_dense_features.append(self.small_embeddings(discrete_features))
|
||||
|
||||
if self._user_embedding_layer_norm:
|
||||
if user_embedding is None:
|
||||
raise ValueError(
|
||||
"Forward arg user_embedding is None, but since Twhin user_embeddings are used by the model, a Tensor is expected."
|
||||
)
|
||||
concat_dense_features.append(self._user_embedding_layer_norm(user_embedding))
|
||||
|
||||
if self._user_eng_embedding_layer_norm:
|
||||
if user_eng_embedding is None:
|
||||
raise ValueError(
|
||||
"Forward arg user_eng_embedding is None, but since Twhin user_eng_embeddings are used by the model, a Tensor is expected."
|
||||
)
|
||||
concat_dense_features.append(self._user_eng_embedding_layer_norm(user_eng_embedding))
|
||||
|
||||
if self._author_embedding_layer_norm:
|
||||
if author_embedding is None:
|
||||
raise ValueError(
|
||||
"Forward arg author_embedding is None, but since Twhin author_embeddings are used by the model, a Tensor is expected."
|
||||
)
|
||||
concat_dense_features.append(self._author_embedding_layer_norm(author_embedding))
|
||||
|
||||
if self.position_debias_model:
|
||||
if discrete_features is None:
|
||||
raise ValueError(
|
||||
"Forward arg discrete_features is None, but since position_debias_model is used, a Tensor is expected."
|
||||
)
|
||||
concat_dense_features.append(self.position_debias_model(discrete_features))
|
||||
|
||||
if discrete_features is not None and not (self.position_debias_model or self.small_embeddings):
|
||||
logging.warning("Forward arg discrete_features is passed, but never used.")
|
||||
|
||||
concat_dense_features = torch.cat(concat_dense_features, dim=1)
|
||||
|
||||
if self._backbone:
|
||||
if self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
|
||||
net = self._backbone(concat_dense_features)["output"]
|
||||
elif self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
|
||||
net = torch.cat(
|
||||
[concat_dense_features, self._backbone(concat_dense_features)["output"]], dim=1
|
||||
)
|
||||
else:
|
||||
net = concat_dense_features
|
||||
|
||||
backbone_result = net
|
||||
|
||||
all_logits = []
|
||||
all_probabilities = []
|
||||
all_calibrated_probabilities = []
|
||||
|
||||
for task_name in self._task_names:
|
||||
safe_name = sanitize(task_name)
|
||||
tower_outputs = self._towers[safe_name](net)
|
||||
logits = tower_outputs["output"]
|
||||
scaled_logits = self._affine_maps[safe_name](logits)
|
||||
probabilities = torch.sigmoid(scaled_logits)
|
||||
calibrated_probabilities = self._calibrators[safe_name](probabilities)
|
||||
|
||||
all_logits.append(scaled_logits)
|
||||
all_probabilities.append(probabilities)
|
||||
all_calibrated_probabilities.append(calibrated_probabilities)
|
||||
|
||||
results = {
|
||||
"logits": torch.squeeze(torch.stack(all_logits, dim=1), dim=-1),
|
||||
"probabilities": torch.squeeze(torch.stack(all_probabilities, dim=1), dim=-1),
|
||||
"calibrated_probabilities": torch.squeeze(
|
||||
torch.stack(all_calibrated_probabilities, dim=1), dim=-1
|
||||
),
|
||||
}
|
||||
|
||||
# Returning the backbone is intended for stitching post-tf conversion
|
||||
# Leaving this on will ~200x the size of the output
|
||||
# and could slow things down
|
||||
if self.return_backbone:
|
||||
results["backbone"] = backbone_result
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def create_ranking_model(
|
||||
data_spec,
|
||||
# Used for planner to be batch size aware.
|
||||
config: config_mod.RecapConfig,
|
||||
device: torch.device,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
data_config=None,
|
||||
return_backbone=False,
|
||||
):
|
||||
|
||||
if list(config.model.tasks.values())[0].dlrm_config:
|
||||
raise NotImplementedError()
|
||||
model = EmbeddingRankingModel(
|
||||
input_shapes=data_spec,
|
||||
config=all_config.model,
|
||||
data_config=all_config.train_data,
|
||||
)
|
||||
else:
|
||||
model = MultiTaskRankingModel(
|
||||
input_shapes=data_spec,
|
||||
config=config.model,
|
||||
data_config=data_config if data_config is not None else config.train_data,
|
||||
return_backbone=return_backbone,
|
||||
)
|
||||
|
||||
logging.info("***** Model Architecture *****")
|
||||
logging.info(model)
|
||||
|
||||
logging.info("***** Named Parameters *****")
|
||||
for elem in model.named_parameters():
|
||||
logging.info(elem[0])
|
||||
|
||||
if loss_fn:
|
||||
logging.info("***** Wrapping in loss *****")
|
||||
model = ModelAndLoss(
|
||||
model=model,
|
||||
loss_fn=loss_fn,
|
||||
stratifiers=config.model.stratifiers,
|
||||
)
|
||||
|
||||
return model
|
122
projects/home/recap/model/feature_transform.py
Normal file
122
projects/home/recap/model/feature_transform.py
Normal file
@ -0,0 +1,122 @@
|
||||
from typing import Mapping, Sequence, Union
|
||||
|
||||
from tml.projects.home.recap.model.config import (
|
||||
BatchNormConfig,
|
||||
DoubleNormLogConfig,
|
||||
FeaturizationConfig,
|
||||
LayerNormConfig,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def log_transform(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Safe log transform that works across both negative, zero, and positive floats."""
|
||||
return torch.sign(x) * torch.log1p(torch.abs(x))
|
||||
|
||||
|
||||
class BatchNorm(torch.nn.Module):
|
||||
def __init__(self, num_features: int, config: BatchNormConfig):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
|
||||
super().__init__()
|
||||
if config.axis != -1:
|
||||
raise NotImplementedError
|
||||
if config.center != config.scale:
|
||||
raise ValueError(
|
||||
f"Center and scale must match in torch, received {config.center}, {config.scale}"
|
||||
)
|
||||
self.layer = torch.nn.LayerNorm(
|
||||
normalized_shape, eps=config.epsilon, elementwise_affine=config.center
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class Log1pAbs(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return log_transform(x)
|
||||
|
||||
|
||||
class InputNonFinite(torch.nn.Module):
|
||||
def __init__(self, fill_value: float = 0):
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer(
|
||||
"fill_value", torch.as_tensor(fill_value, dtype=torch.float32), persistent=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(torch.isfinite(x), x, self.fill_value)
|
||||
|
||||
|
||||
class Clamp(torch.nn.Module):
|
||||
def __init__(self, min_value: float, max_value: float):
|
||||
super().__init__()
|
||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||
# Will also be part of state_dict.
|
||||
self.register_buffer(
|
||||
"min_value", torch.as_tensor(min_value, dtype=torch.float32), persistent=True
|
||||
)
|
||||
self.register_buffer(
|
||||
"max_value", torch.as_tensor(max_value, dtype=torch.float32), persistent=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.clamp(x, min=self.min_value, max=self.max_value)
|
||||
|
||||
|
||||
class DoubleNormLog(torch.nn.Module):
|
||||
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shapes: Mapping[str, Sequence[int]],
|
||||
config: DoubleNormLogConfig,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
_before_concat_layers = [
|
||||
InputNonFinite(),
|
||||
Log1pAbs(),
|
||||
]
|
||||
if config.batch_norm_config:
|
||||
_before_concat_layers.append(
|
||||
BatchNorm(input_shapes["continuous"][-1], config.batch_norm_config)
|
||||
)
|
||||
_before_concat_layers.append(
|
||||
Clamp(min_value=-config.clip_magnitude, max_value=config.clip_magnitude)
|
||||
)
|
||||
self._before_concat_layers = torch.nn.Sequential(*_before_concat_layers)
|
||||
|
||||
self.layer_norm = None
|
||||
if config.layer_norm_config:
|
||||
last_dim = input_shapes["continuous"][-1] + input_shapes["binary"][-1]
|
||||
self.layer_norm = LayerNorm(last_dim, config.layer_norm_config)
|
||||
|
||||
def forward(
|
||||
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
x = self._before_concat_layers(continuous_features)
|
||||
x = torch.cat([x, binary_features], dim=1)
|
||||
if self.layer_norm:
|
||||
return self.layer_norm(x)
|
||||
return x
|
||||
|
||||
|
||||
def build_features_preprocessor(
|
||||
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
|
||||
):
|
||||
"""Trivial right now, but we will change in the future."""
|
||||
return DoubleNormLog(input_shapes, config.double_norm_log_config)
|
98
projects/home/recap/model/mask_net.py
Normal file
98
projects/home/recap/model/mask_net.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619)."""
|
||||
|
||||
from tml.projects.home.recap.model import config, mlp
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _init_weights(module):
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
torch.nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class MaskBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
|
||||
) -> None:
|
||||
super(MaskBlock, self).__init__()
|
||||
self.mask_block_config = mask_block_config
|
||||
output_size = mask_block_config.output_size
|
||||
|
||||
if mask_block_config.input_layer_norm:
|
||||
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
|
||||
else:
|
||||
self._input_layer_norm = None
|
||||
|
||||
if mask_block_config.reduction_factor:
|
||||
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
|
||||
elif mask_block_config.aggregation_size is not None:
|
||||
aggregation_size = mask_block_config.aggregation_size
|
||||
else:
|
||||
raise ValueError("Need one of reduction factor or aggregation size.")
|
||||
|
||||
self._mask_layer = torch.nn.Sequential(
|
||||
torch.nn.Linear(mask_input_dim, aggregation_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(aggregation_size, input_dim),
|
||||
)
|
||||
self._mask_layer.apply(_init_weights)
|
||||
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
|
||||
self._hidden_layer.apply(_init_weights)
|
||||
self._layer_norm = torch.nn.LayerNorm(output_size)
|
||||
|
||||
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
|
||||
if self._input_layer_norm:
|
||||
net = self._input_layer_norm(net)
|
||||
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
|
||||
return self._layer_norm(hidden_layer_output)
|
||||
|
||||
|
||||
class MaskNet(torch.nn.Module):
|
||||
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
|
||||
super().__init__()
|
||||
self.mask_net_config = mask_net_config
|
||||
mask_blocks = []
|
||||
|
||||
if mask_net_config.use_parallel:
|
||||
total_output_mask_blocks = 0
|
||||
for mask_block_config in mask_net_config.mask_blocks:
|
||||
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
|
||||
total_output_mask_blocks += mask_block_config.output_size
|
||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||
else:
|
||||
input_size = in_features
|
||||
for mask_block_config in mask_net_config.mask_blocks:
|
||||
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
|
||||
input_size = mask_block_config.output_size
|
||||
|
||||
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
|
||||
total_output_mask_blocks = mask_block_config.output_size
|
||||
|
||||
if mask_net_config.mlp:
|
||||
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
|
||||
self.out_features = mask_net_config.mlp.layer_sizes[-1]
|
||||
else:
|
||||
self.out_features = total_output_mask_blocks
|
||||
self.shared_size = total_output_mask_blocks
|
||||
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
if self.mask_net_config.use_parallel:
|
||||
mask_outputs = []
|
||||
for mask_layer in self._mask_blocks:
|
||||
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
|
||||
# Share the outputs of the MaskBlocks.
|
||||
all_mask_outputs = torch.cat(mask_outputs, dim=1)
|
||||
output = (
|
||||
all_mask_outputs
|
||||
if self.mask_net_config.mlp is None
|
||||
else self._dense_layers(all_mask_outputs)["output"]
|
||||
)
|
||||
return {"output": output, "shared_layer": all_mask_outputs}
|
||||
else:
|
||||
net = inputs
|
||||
for mask_layer in self._mask_blocks:
|
||||
net = mask_layer(net=net, mask_input=inputs)
|
||||
# Share the output of the stacked MaskBlocks.
|
||||
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
|
||||
return {"output": output, "shared_layer": net}
|
58
projects/home/recap/model/mlp.py
Normal file
58
projects/home/recap/model/mlp.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""MLP feed forward stack in torch."""
|
||||
|
||||
from tml.projects.home.recap.model.config import MlpConfig
|
||||
|
||||
import torch
|
||||
from absl import logging
|
||||
|
||||
|
||||
def _init_weights(module):
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
torch.nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
class Mlp(torch.nn.Module):
|
||||
def __init__(self, in_features: int, mlp_config: MlpConfig):
|
||||
super().__init__()
|
||||
self._mlp_config = mlp_config
|
||||
input_size = in_features
|
||||
layer_sizes = mlp_config.layer_sizes
|
||||
modules = []
|
||||
for layer_size in layer_sizes[:-1]:
|
||||
modules.append(torch.nn.Linear(input_size, layer_size, bias=True))
|
||||
|
||||
if mlp_config.batch_norm:
|
||||
modules.append(
|
||||
torch.nn.BatchNorm1d(
|
||||
layer_size, affine=mlp_config.batch_norm.affine, momentum=mlp_config.batch_norm.momentum
|
||||
)
|
||||
)
|
||||
|
||||
modules.append(torch.nn.ReLU())
|
||||
|
||||
if mlp_config.dropout:
|
||||
modules.append(torch.nn.Dropout(mlp_config.dropout.rate))
|
||||
|
||||
input_size = layer_size
|
||||
modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True))
|
||||
if mlp_config.final_layer_activation:
|
||||
modules.append(torch.nn.ReLU())
|
||||
self.layers = torch.nn.ModuleList(modules)
|
||||
self.layers.apply(_init_weights)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
net = x
|
||||
for i, layer in enumerate(self.layers):
|
||||
net = layer(net)
|
||||
if i == 1: # Share the first (widest?) set of activations for other applications.
|
||||
shared_layer = net
|
||||
return {"output": net, "shared_layer": shared_layer}
|
||||
|
||||
@property
|
||||
def shared_size(self):
|
||||
return self._mlp_config.layer_sizes[-1]
|
||||
|
||||
@property
|
||||
def out_features(self):
|
||||
return self._mlp_config.layer_sizes[-1]
|
67
projects/home/recap/model/model_and_loss.py
Normal file
67
projects/home/recap/model/model_and_loss.py
Normal file
@ -0,0 +1,67 @@
|
||||
from typing import Callable, Optional, List
|
||||
from tml.projects.home.recap.embedding import config as embedding_config_mod
|
||||
import torch
|
||||
from absl import logging
|
||||
|
||||
|
||||
class ModelAndLoss(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
loss_fn: Callable,
|
||||
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
model: torch module to wrap.
|
||||
loss_fn: Function for calculating loss, should accept logits and labels.
|
||||
straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.loss_fn = loss_fn
|
||||
self.stratifiers = stratifiers
|
||||
|
||||
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
|
||||
"""Runs model forward and calculates loss according to given loss_fn.
|
||||
|
||||
NOTE: The input signature here needs to be a Pipelineable object for
|
||||
prefetching purposes during training using torchrec's pipeline. However
|
||||
the underlying model signature needs to be exportable to onnx, requiring
|
||||
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
|
||||
|
||||
"""
|
||||
outputs = self.model(
|
||||
continuous_features=batch.continuous_features,
|
||||
binary_features=batch.binary_features,
|
||||
discrete_features=batch.discrete_features,
|
||||
sparse_features=batch.sparse_features,
|
||||
user_embedding=batch.user_embedding,
|
||||
user_eng_embedding=batch.user_eng_embedding,
|
||||
author_embedding=batch.author_embedding,
|
||||
labels=batch.labels,
|
||||
weights=batch.weights,
|
||||
)
|
||||
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
|
||||
|
||||
if self.stratifiers:
|
||||
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
|
||||
outputs["stratifiers"] = {}
|
||||
for stratifier in self.stratifiers:
|
||||
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
|
||||
|
||||
# In general, we can have a large number of losses returned by our loss function.
|
||||
if isinstance(losses, dict):
|
||||
return losses["loss"], {
|
||||
**outputs,
|
||||
**losses,
|
||||
"labels": batch.labels,
|
||||
"weights": batch.weights,
|
||||
}
|
||||
else: # Assume that this is a float.
|
||||
return losses, {
|
||||
**outputs,
|
||||
"loss": losses,
|
||||
"labels": batch.labels,
|
||||
"weights": batch.weights,
|
||||
}
|
19
projects/home/recap/model/numeric_calibration.py
Normal file
19
projects/home/recap/model/numeric_calibration.py
Normal file
@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
|
||||
class NumericCalibration(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
pos_downsampling_rate: float,
|
||||
neg_downsampling_rate: float,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Using buffer to make sure they are on correct device (and not moved every time).
|
||||
# Will also be part of state_dict.
|
||||
self.register_buffer(
|
||||
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
|
||||
)
|
||||
|
||||
def forward(self, probs: torch.Tensor):
|
||||
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
|
1
projects/home/recap/optimizer/__init__.py
Normal file
1
projects/home/recap/optimizer/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from tml.projects.home.recap.optimizer.optimizer import build_optimizer
|
36
projects/home/recap/optimizer/config.py
Normal file
36
projects/home/recap/optimizer/config.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Optimization configurations for models."""
|
||||
|
||||
import typing
|
||||
|
||||
import tml.core.config as base_config
|
||||
import tml.optimizers.config as optimizers_config_mod
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class RecapAdamConfig(base_config.BaseConfig):
|
||||
beta_1: float = 0.9 # Momentum term.
|
||||
beta_2: float = 0.999 # Exponential weighted decay factor.
|
||||
epsilon: float = 1e-7 # Numerical stability in denominator.
|
||||
|
||||
|
||||
class MultiTaskLearningRates(base_config.BaseConfig):
|
||||
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
|
||||
description="Learning rates for different towers of the model."
|
||||
)
|
||||
|
||||
backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
|
||||
None, description="Learning rate for backbone of the model."
|
||||
)
|
||||
|
||||
|
||||
class RecapOptimizerConfig(base_config.BaseConfig):
|
||||
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
|
||||
None, description="Multiple learning rates for different tasks.", one_of="lr"
|
||||
)
|
||||
|
||||
single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
|
||||
None, description="Single task learning rates", one_of="lr"
|
||||
)
|
||||
|
||||
adam: RecapAdamConfig = pydantic.Field(one_of="optimizer")
|
166
projects/home/recap/optimizer/optimizer.py
Normal file
166
projects/home/recap/optimizer/optimizer.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""Build optimizers and learning rate schedules."""
|
||||
import bisect
|
||||
from collections import defaultdict
|
||||
import functools
|
||||
import math
|
||||
import typing
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
# from large_embeddings.config import EmbeddingOptimizerConfig
|
||||
from tml.projects.home.recap import model as model_mod
|
||||
from tml.optimizers import config
|
||||
from tml.optimizers import compute_lr
|
||||
from absl import logging # type: ignore[attr-defined]
|
||||
|
||||
import torch
|
||||
from torchrec.optim import keyed
|
||||
|
||||
|
||||
_DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate.
|
||||
_BACKBONE = "backbone"
|
||||
_DENSE_EMBEDDINGS = "dense_ebc"
|
||||
|
||||
|
||||
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
|
||||
"""Shim to get learning rates into a LRScheduler.
|
||||
|
||||
This adheres to the torch.optim scheduler API and can be plugged anywhere that
|
||||
e.g. exponential decay can be used.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
lr_dict: typing.Dict[str, config.LearningRate],
|
||||
emb_learning_rate,
|
||||
last_epoch=-1,
|
||||
verbose=False,
|
||||
):
|
||||
self.optimizer = optimizer
|
||||
self.lr_dict = lr_dict
|
||||
self.group_names = list(self.lr_dict.keys())
|
||||
self.emb_learning_rate = emb_learning_rate
|
||||
|
||||
# We handle sparse LR scheduling separately, so only validate LR groups against dense param groups
|
||||
num_dense_param_groups = sum(
|
||||
1
|
||||
for _, _optim in optimizer._optims
|
||||
for _ in _optim.param_groups
|
||||
if isinstance(_optim, keyed.KeyedOptimizerWrapper)
|
||||
)
|
||||
if num_dense_param_groups != len(lr_dict):
|
||||
raise ValueError(
|
||||
f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}."
|
||||
)
|
||||
super().__init__(optimizer, last_epoch, verbose)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
)
|
||||
return self._get_closed_form_lr()
|
||||
|
||||
def _get_closed_form_lr(self):
|
||||
learning_rates = []
|
||||
|
||||
for lr_config in self.lr_dict.values():
|
||||
learning_rates.append(compute_lr(lr_config, self.last_epoch))
|
||||
# WARNING: The order of appending is important.
|
||||
if self.emb_learning_rate:
|
||||
learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch))
|
||||
return learning_rates
|
||||
|
||||
|
||||
def build_optimizer(
|
||||
model: torch.nn.Module,
|
||||
optimizer_config: config.OptimizerConfig,
|
||||
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
|
||||
):
|
||||
"""Builds an optimizer and scheduler.
|
||||
|
||||
Args:
|
||||
model: A torch model, probably with DDP/DMP.
|
||||
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
|
||||
|
||||
Returns:
|
||||
A torch.optim instance, and a scheduler instance.
|
||||
"""
|
||||
optimizer_fn = functools.partial(
|
||||
torch.optim.Adam,
|
||||
lr=_DEFAULT_LR,
|
||||
betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2),
|
||||
eps=optimizer_config.adam.epsilon,
|
||||
maximize=False,
|
||||
)
|
||||
if optimizer_config.multi_task_learning_rates:
|
||||
logging.info("***** Parameter groups for optimization *****")
|
||||
# Importantly, we preserve insertion order in dictionaries here.
|
||||
parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict)
|
||||
added_parameters: typing.Set[str] = set()
|
||||
for task in optimizer_config.multi_task_learning_rates.tower_learning_rates:
|
||||
for name, parameter in model.named_parameters():
|
||||
if f".{model_mod.sanitize(task)}." in name:
|
||||
parameter_groups[task][name] = parameter
|
||||
logging.info(f"{task}: {name}")
|
||||
if name in added_parameters:
|
||||
raise ValueError(f"Parameter {name} matched multiple tasks.")
|
||||
added_parameters.add(name)
|
||||
|
||||
for name, parameter in model.named_parameters():
|
||||
if name not in added_parameters and "embedding_bags" not in name:
|
||||
parameter_groups[_BACKBONE][name] = parameter
|
||||
added_parameters.add(name)
|
||||
logging.info(f"{_BACKBONE}: {name}")
|
||||
|
||||
for name, parameter in model.named_parameters():
|
||||
if name not in added_parameters and "embedding_bags" in name:
|
||||
parameter_groups[_DENSE_EMBEDDINGS][name] = parameter
|
||||
logging.info(f"{_DENSE_EMBEDDINGS}: {name}")
|
||||
|
||||
all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy()
|
||||
if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None:
|
||||
all_learning_rates[
|
||||
_BACKBONE
|
||||
] = optimizer_config.multi_task_learning_rates.backbone_learning_rate
|
||||
if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config:
|
||||
all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy()
|
||||
else:
|
||||
parameter_groups = dict(model.named_parameters())
|
||||
all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate}
|
||||
|
||||
optimizers = [
|
||||
keyed.KeyedOptimizerWrapper(param_group, optimizer_fn)
|
||||
for param_name, param_group in parameter_groups.items()
|
||||
if param_name != _DENSE_EMBEDDINGS
|
||||
]
|
||||
# Making EBC optimizer to be SGD to match fused optimiser
|
||||
if _DENSE_EMBEDDINGS in parameter_groups:
|
||||
optimizers.append(
|
||||
keyed.KeyedOptimizerWrapper(
|
||||
parameter_groups[_DENSE_EMBEDDINGS],
|
||||
functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False),
|
||||
)
|
||||
)
|
||||
|
||||
if not parameter_groups.keys() == all_learning_rates.keys():
|
||||
raise ValueError("Learning rates do not match optimizers")
|
||||
|
||||
# If the optimiser is dense, model.fused_optimizer will be empty (but not None)
|
||||
emb_learning_rate = None
|
||||
if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers:
|
||||
logging.info(f"Model fused optimiser: {model.fused_optimizer}")
|
||||
optimizers.append(model.fused_optimizer)
|
||||
if emb_optimizer_config:
|
||||
emb_learning_rate = emb_optimizer_config.learning_rate.copy()
|
||||
else:
|
||||
raise ValueError("Fused kernel exists, but LR is not set")
|
||||
logging.info(f"***** Combining optimizers: {optimizers} *****")
|
||||
optimizer = keyed.CombinedOptimizer(optimizers)
|
||||
scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate)
|
||||
logging.info(f"***** Combined optimizer after init: {optimizer} *****")
|
||||
|
||||
return optimizer, scheduler
|
10
projects/home/recap/script/create_random_data.sh
Executable file
10
projects/home/recap/script/create_random_data.sh
Executable file
@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Runs from inside venv
|
||||
|
||||
rm -rf $HOME/tmp/runs/recap_local_random_data
|
||||
python -m tml.machines.is_venv || exit 1
|
||||
export TML_BASE="$(git rev-parse --show-toplevel)"
|
||||
|
||||
mkdir -p $HOME/tmp/recap_local_random_data
|
||||
python projects/home/recap/data/generate_random_data.py --config_path $(pwd)/projects/home/recap/config/local_prod.yaml
|
15
projects/home/recap/script/run_local.sh
Executable file
15
projects/home/recap/script/run_local.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Runs from inside venv
|
||||
rm -rf $HOME/tmp/runs/recap_local_debug
|
||||
mkdir -p $HOME/tmp/runs/recap_local_debug
|
||||
python -m tml.machines.is_venv || exit 1
|
||||
export TML_BASE="$(git rev-parse --show-toplevel)"
|
||||
|
||||
torchrun \
|
||||
--standalone \
|
||||
--nnodes 1 \
|
||||
--nproc_per_node 1 \
|
||||
projects/home/recap/main.py \
|
||||
--config_path $(pwd)/projects/home/recap/config/local_prod.yaml \
|
||||
$@
|
Reference in New Issue
Block a user