# 109 lines · 3.2 KiB · Python  (file-viewer header residue, kept as a comment)

import datetime
import os
from typing import Callable, List, Optional, Tuple
import tensorflow as tf
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.common.device import setup_and_get_device
from tml.core import config as tml_config_mod
import tml.core.custom_training_loop as ctl
from tml.core import debug_training_loop
from tml.core import losses
from tml.core.loss_type import LossType
from tml.model import maybe_shard_model
import tml.projects.home.recap.data.dataset as ds
import tml.projects.home.recap.config as recap_config_mod
import tml.projects.home.recap.optimizer as optimizer_mod
# from tml.projects.home.recap import feature
import tml.projects.home.recap.model as model_mod
import torchmetrics as tm
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from absl import app, flags, logging
# Command-line flags consumed by `run` (config location and loop selection).
flags.DEFINE_string(
    name="config_path",
    default=None,
    help="Path to hyperparameters for model.",
)
flags.DEFINE_bool(
    name="debug_loop",
    default=False,
    help="Run with debug loop (slow)",
)
FLAGS = flags.FLAGS
def run(unused_argv: List[str], data_service_dispatcher: Optional[str] = None) -> None:
    """Train the Recap ranking model described by the YAML at ``--config_path``.

    Steps:
    - Load a ``RecapConfig``.
    - Build a multi-task BCE-with-logits loss.
    - Create the training dataset iterator, the ranking model, and the
      optimizer/scheduler.
    - Pass the model through ``maybe_shard_model``.
    - Run either the custom training loop or, with ``--debug_loop``, the
      debug training loop.

    Args:
        unused_argv: Positional command-line arguments forwarded by
            ``absl.app.run``; ignored.
        data_service_dispatcher: Optional data-service dispatcher, forwarded
            unchanged to ``RecapDataset(dataset_service=...)``.

    Raises:
        ValueError: If ``config.model.large_embeddings`` is set; large
            embeddings are not supported by this entry point.
    """
    print("#" * 100)

    config = tml_config_mod.load_config_from_yaml(
        recap_config_mod.RecapConfig, FLAGS.config_path
    )
    logging.info("Config: %s", config.pretty_print())

    device = setup_and_get_device()

    # Always enable TensorFloat-32 on supported devices: trades some float32
    # matmul/convolution precision for speed.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Task names and pos_weights are both taken in config.model.tasks iteration
    # order, so they line up index-for-index.
    loss_fn = losses.build_multi_task_loss(
        loss_type=LossType.BCE_WITH_LOGITS,
        tasks=list(config.model.tasks.keys()),
        pos_weights=[task.pos_weight for task in config.model.tasks.values()],
    )

    # Since the prod model doesn't use large embeddings, for now we won't support
    # them. Raise explicitly (not `assert`) so the check survives `python -O`.
    if config.model.large_embeddings is not None:
        raise ValueError(
            "config.model.large_embeddings is not supported; it must be unset (None)."
        )

    train_dataset = ds.RecapDataset(
        data_config=config.train_data,
        dataset_service=data_service_dispatcher,
        mode=recap_config_mod.JobMode.TRAIN,
        compression=config.train_data.dataset_service_compression,
        vocab_mapper=None,
        repeat=True,
    )
    train_iterator = iter(train_dataset.to_dataloader())
    torch_element_spec = train_dataset.torch_element_spec

    # Only element 0 of the spec is used to build the model; presumably the
    # feature spec (TODO confirm against RecapDataset.torch_element_spec).
    model = model_mod.create_ranking_model(
        data_spec=torch_element_spec[0],
        config=config,
        loss_fn=loss_fn,
        device=device,
    )

    # NOTE(review): the optimizer/scheduler are built from the model *before*
    # maybe_shard_model wraps it; order kept as-is to preserve behavior.
    optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)
    model = maybe_shard_model(model, device)

    datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
    print(datetime_str)

    if FLAGS.debug_loop:
        logging.warning("Running debug mode, slow!")
        train_mod = debug_training_loop
    else:
        train_mod = ctl

    train_mod.train(
        model=model,
        optimizer=optimizer,
        device=device,
        save_dir=config.training.save_dir,
        logging_interval=config.training.train_log_every_n,
        train_steps=config.training.num_train_steps,
        checkpoint_frequency=config.training.checkpoint_every_n,
        dataset=train_iterator,
        # NOTE(review): the *global* batch size is passed as worker_batch_size;
        # confirm whether the training loop expects a per-worker value.
        worker_batch_size=config.train_data.global_batch_size,
        enable_amp=False,
        initial_checkpoint_dir=config.training.initial_checkpoint_dir,
        gradient_accumulation=config.training.gradient_accumulation,
        scheduler=scheduler,
    )
if __name__ == "__main__":
    # absl parses command-line flags (populating FLAGS) and then invokes
    # run(argv) with the remaining positional arguments.
    app.run(main=run)