# Mirror of https://github.com/twitter/the-algorithm-ml.git
# (synced 2025-01-25 21:31:14 +01:00; original listing: 99 lines, 3.1 KiB, Python)
from absl import app, flags
|
|
import json
|
|
from typing import Optional
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
|
|
# isort: on
|
|
from tml.common.device import setup_and_get_device
|
|
from tml.common.utils import setup_configuration
|
|
import tml.core.custom_training_loop as ctl
|
|
import tml.machines.environment as env
|
|
from tml.projects.twhin.models.models import apply_optimizers, TwhinModel, TwhinModelAndLoss
|
|
from tml.model import maybe_shard_model
|
|
from tml.projects.twhin.metrics import create_metrics
|
|
from tml.projects.twhin.config import TwhinConfig
|
|
from tml.projects.twhin.data.data import create_dataset
|
|
from tml.projects.twhin.optimizer import build_optimizer
|
|
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
import torch.distributed as dist
|
|
from torch.nn import functional as F
|
|
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
|
|
from torchrec.distributed.model_parallel import get_module
|
|
|
|
# Module-level handle on absl's global flag registry; main() reads parsed values from it.
FLAGS = flags.FLAGS

# Command-line flags for this entry point.
# NOTE(review): `overwrite_save_dir` and `task` are defined here but are not read
# anywhere else in this module -- confirm they are consumed elsewhere or remove them.
# NOTE(review): when --save_dir is omitted, main() forwards save_dir=None into run()
# and on to ctl.train; confirm that case is handled downstream.
flags.DEFINE_bool(
    name="overwrite_save_dir",
    default=False,
    help="Whether to clear preexisting save directories.",
)
flags.DEFINE_string(
    name="save_dir",
    default=None,
    help="If provided, overwrites the save directory.",
)
flags.DEFINE_string(
    name="config_yaml_path",
    default=None,
    help="Path to hyperparameters for model.",
)
flags.DEFINE_string(
    name="task",
    default=None,
    help="Task to run if this is local. Overrides TF_CONFIG etc.",
)
|
|
|
|
|
|
def run(
    all_config: TwhinConfig,
    save_dir: Optional[str] = None,
):
    """Create the TwHIN training dataset and, on the chief process, train the model.

    Every process builds the training dataset. Processes for which
    ``env.is_reader()`` is true call ``serve()`` on it. Only a process for which
    ``env.is_chief()`` is true continues on to set up the device, build the
    model, optimizer and loss wrapper, and call ``ctl.train``; every other
    process returns once the reader step is done.

    Args:
        all_config: Full TwHIN configuration. Its ``model``, ``train_data`` and
            ``training`` sections are read here.
        save_dir: Forwarded unchanged to ``ctl.train`` as ``save_dir``
            (including when it is ``None``).
    """
    train_dataset = create_dataset(all_config.train_data, all_config.model)

    if env.is_reader():
        # NOTE(review): presumably serve() runs the dataset service for remote
        # consumers -- confirm whether it blocks before the chief check below.
        train_dataset.serve()

    # Everything below runs on the chief process only.
    if not env.is_chief():
        return

    device = setup_and_get_device(tf_ok=False)
    logging.info(f"device: {device}")

    world_size = dist.get_world_size()
    logging.info(f"WORLD_SIZE: {world_size}")

    # Per-replica batch size scaled up by the number of processes in the group.
    global_batch_size = all_config.train_data.per_replica_batch_size * world_size

    # NOTE(review): `metrics` is built but never passed to ctl.train below --
    # confirm whether it should be wired into training or the call dropped.
    metrics = create_metrics(device)

    # apply_optimizers runs on the unsharded model; the optimizer and scheduler
    # are then built from the (possibly) sharded model returned by maybe_shard_model.
    model = TwhinModel(all_config.model, all_config.train_data)
    apply_optimizers(model, all_config.model)
    model = maybe_shard_model(model, device=device)
    optimizer, scheduler = build_optimizer(model=model, config=all_config.model)

    # Loss is computed on raw logits (sigmoid is folded into the loss function).
    model_and_loss = TwhinModelAndLoss(
        model,
        F.binary_cross_entropy_with_logits,
        data_config=all_config.train_data,
        device=device,
    )

    training_cfg = all_config.training
    train_loader = train_dataset.dataloader(remote=False)

    ctl.train(
        model=model_and_loss,
        optimizer=optimizer,
        device=device,
        save_dir=save_dir,
        logging_interval=training_cfg.train_log_every_n,
        train_steps=training_cfg.num_train_steps,
        checkpoint_frequency=training_cfg.checkpoint_every_n,
        dataset=train_loader,
        # NOTE(review): this is the global (all-replica) batch size, not a
        # per-worker one -- confirm that is what ctl.train expects here.
        worker_batch_size=global_batch_size,
        num_workers=0,
        scheduler=scheduler,
        initial_checkpoint_dir=training_cfg.initial_checkpoint_dir,
        gradient_accumulation=training_cfg.gradient_accumulation,
    )
|
|
|
|
|
|
def main(argv):
    """absl entry point: load the TwhinConfig named by --config_yaml_path and run training.

    Args:
        argv: Remaining command-line arguments supplied by ``app.run``; not used.
    """
    logging.info("Starting")

    config_path = FLAGS.config_yaml_path
    logging.info(f"parsing config from {config_path}...")
    all_config = setup_configuration(  # type: ignore[var-annotated]
        TwhinConfig,
        yaml_path=config_path,
    )

    run(all_config, save_dir=FLAGS.save_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to absl, which parses the flags defined above and then calls main().
    app.run(main=main)
|