Mirror of https://github.com/twitter/the-algorithm-ml.git (synced 2025-06-19 00:59:26 +02:00)
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
common/__init__.py (new file, empty)

common/batch.py (new file, 85 lines)
@@ -0,0 +1,85 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
import abc
from dataclasses import dataclass
import dataclasses

import torch
from torchrec.streamable import Pipelineable


class BatchBase(Pipelineable, abc.ABC):
  @abc.abstractmethod
  def as_dict(self) -> Dict:
    raise NotImplementedError

  def to(self, device: torch.device, non_blocking: bool = False):
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
    return self.__class__(**args)

  def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
    for feature_value in self.as_dict().values():
      feature_value.record_stream(stream)

  def pin_memory(self):
    args = {}
    for feature_name, feature_value in self.as_dict().items():
      args[feature_name] = feature_value.pin_memory()
    return self.__class__(**args)

  def __repr__(self) -> str:
    def obj2str(v):
      return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"

    return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])

  @property
  def batch_size(self) -> int:
    for tensor in self.as_dict().values():
      if tensor is None:
        continue
      if not isinstance(tensor, torch.Tensor):
        continue
      return tensor.shape[0]
    raise Exception("Could not determine batch size from tensors.")


@dataclass
class DataclassBatch(BatchBase):
  @classmethod
  def feature_names(cls):
    return list(cls.__dataclass_fields__.keys())

  def as_dict(self):
    return {
      feature_name: getattr(self, feature_name)
      for feature_name in self.feature_names()
      if hasattr(self, feature_name)
    }

  @staticmethod
  def from_schema(name: str, schema):
    """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
      bases=(DataclassBatch,),
    )

  @staticmethod
  def from_fields(name: str, fields: dict):
    return dataclasses.make_dataclass(
      cls_name=name,
      fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
      bases=(DataclassBatch,),
    )


class DictionaryBatch(BatchBase, dict):
  def as_dict(self) -> Dict:
    return self
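Below is a minimal usage sketch of DataclassBatch.from_fields (illustrative, not part of the commit); the feature names and tensor shapes are made up, and it assumes a CPU-only environment.

import torch

from tml.common.batch import DataclassBatch

# Create a batch subclass with two tensor-valued features (hypothetical names).
ExampleBatch = DataclassBatch.from_fields(
  "ExampleBatch", {"continuous": torch.Tensor, "labels": torch.Tensor}
)

batch = ExampleBatch(continuous=torch.randn(8, 4), labels=torch.zeros(8))
assert batch.batch_size == 8

# `to` returns a new batch instance with every feature moved to the target device.
cpu_batch = batch.to(torch.device("cpu"))
print(cpu_batch)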
common/checkpointing/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
common/checkpointing/snapshot.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import os
import time
from typing import Any, Dict, List, Optional

from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs

import torchsnapshot


DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://"


class Snapshot:
  """Checkpoints using torchsnapshot.

  Also saves step to be updated by the training loop.

  """

  def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
    self.save_dir = save_dir
    self.state = state
    self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

  @property
  def step(self):
    return self.state["extra_state"]["step"]

  @step.setter
  def step(self, step: int) -> None:
    self.state["extra_state"]["step"] = step

  @property
  def walltime(self):
    return self.state["extra_state"]["walltime"]

  @walltime.setter
  def walltime(self, walltime: float) -> None:
    self.state["extra_state"]["walltime"] = walltime

  def save(self, global_step: int) -> "PendingSnapshot":
    """Saves a checkpoint for the given global_step."""
    path = os.path.join(self.save_dir, str(global_step))
    logging.info(f"Saving snapshot global_step {global_step} to {path}.")
    start_time = time.time()
    # Take the snapshot asynchronously. The snapshot is consistent: state changes after this
    # method returns have no effect on the snapshot. Storage I/O is performed in the background.
    snapshot = torchsnapshot.Snapshot.async_take(
      app_state=self.state,
      path=path,
      # commented out because DistributedModelParallel model saving
      # errors with this on multi-GPU. With it removed, CPU, single
      # GPU, and multi-GPU training all successfully checkpoint.
      # replicated=["**"],
    )
    logging.info(f"Snapshot saved to {snapshot.path} ({time.time() - start_time:.05}s)")
    return snapshot

  def restore(self, checkpoint: str) -> None:
    """Restores a given checkpoint."""
    snapshot = torchsnapshot.Snapshot(path=checkpoint)
    logging.info(f"Restoring snapshot from {snapshot.path}.")
    start_time = time.time()
    # We can remove the try-except when we are confident that we no longer need to restore from
    # checkpoints from before walltime was added.
    try:
      # Checkpoints that do not have extra_state[walltime] will fail here.
      snapshot.restore(self.state)
    except RuntimeError:
      # extra_state[walltime] does not exist in the checkpoint, but step should be there, so restore it.
      self.state["extra_state"] = torchsnapshot.StateDict(step=0)
      snapshot.restore(self.state)
      # We still need to ensure that extra_state has walltime in it.
      self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)

    logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s)")

  @classmethod
  def get_torch_snapshot(
    cls,
    snapshot_path: str,
    global_step: Optional[int] = None,
    missing_ok: bool = False,
  ) -> torchsnapshot.Snapshot:
    """Gets a torch stateless snapshot, without actually loading it.

    Args:
      snapshot_path: path to the model snapshot.
      global_step: restores from this checkpoint if specified.
      missing_ok: if True and checkpoints do not exist, returns without restoration.
    """
    path = get_checkpoint(snapshot_path, global_step, missing_ok)
    logging.info(f"Loading snapshot from {path}.")
    return torchsnapshot.Snapshot(path=path)

  @classmethod
  def load_snapshot_to_weight(
    cls,
    embedding_snapshot: torchsnapshot.Snapshot,
    snapshot_emb_name: str,
    weight_tensor,
  ) -> None:
    """Loads pretrained embeddings from the snapshot into the model.

    Uses the partial loading mechanism from torchsnapshot.

    Args:
      embedding_snapshot: Snapshot containing the pretrained embeddings (EBC).
      snapshot_emb_name: Name of the layer in the *snapshot* model containing the EBC.
      weight_tensor: Embeddings tensor of the *current* model into which the embeddings will be loaded.
    """
    start_time = time.time()
    manifest = embedding_snapshot.get_manifest()
    for path in manifest.keys():
      if path.startswith("0") and snapshot_emb_name in path:
        snapshot_path_to_load = path
    embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)
    logging.info(
      f"Loaded embedding snapshot from {snapshot_path_to_load}: {time.time() - start_time:.05}s",
      rank=-1,
    )
    logging.info(f"Snapshot loaded to {weight_tensor.metadata()}", rank=-1)


def _eval_subdir(checkpoint_path: str) -> str:
  return os.path.join(checkpoint_path, DONE_EVAL_SUBDIR)


def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
  return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")


def is_done_eval(checkpoint_path: str, eval_partition: str):
  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))


def mark_done_eval(checkpoint_path: str, eval_partition: str):
  infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))


def step_from_checkpoint(checkpoint: str) -> int:
  return int(os.path.basename(checkpoint))


def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
  """Simplified equivalent of tf.train.checkpoints_iterator.

  Args:
    seconds_to_sleep: time between polling calls.
    timeout: how long to wait for a new checkpoint.

  """

  def _poll(last_checkpoint: Optional[str] = None):
    stop_time = time.time() + timeout
    while True:
      _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
      if not _checkpoint_path or _checkpoint_path == last_checkpoint:
        if time.time() + seconds_to_sleep > stop_time:
          logging.info(
            f"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s."
          )
          return None
        logging.info(f"Waiting for next available checkpoint from {save_dir}.")
        time.sleep(seconds_to_sleep)
      else:
        logging.info(f"Found latest checkpoint {_checkpoint_path}.")
        return _checkpoint_path

  checkpoint_path = None
  while True:
    new_checkpoint = _poll(checkpoint_path)
    if not new_checkpoint:
      return
    checkpoint_path = new_checkpoint
    yield checkpoint_path


def get_checkpoint(
  save_dir: str,
  global_step: Optional[int] = None,
  missing_ok: bool = False,
) -> str:
  """Gets the latest checkpoint, or the checkpoint at the specified global_step.

  Args:
    global_step: Finds this checkpoint if specified.
    missing_ok: if True and checkpoints do not exist, returns without restoration.

  """
  checkpoints = get_checkpoints(save_dir)
  if not checkpoints:
    if not missing_ok:
      raise Exception(f"No checkpoints found at {save_dir}")
    else:
      logging.info(f"No checkpoints found for restoration at {save_dir}.")
      return ""

  if global_step is None:
    return checkpoints[-1]

  logging.info(f"Found checkpoints: {checkpoints}")
  for checkpoint in checkpoints:
    step = step_from_checkpoint(checkpoint)
    if global_step == step:
      chosen_checkpoint = checkpoint
      break
  else:
    raise Exception(f"Desired checkpoint at {global_step} not found in {save_dir}")
  return chosen_checkpoint


def get_checkpoints(save_dir: str) -> List[str]:
  """Gets all checkpoints that have been fully written."""
  checkpoints = []
  fs = infer_fs(save_dir)
  if fs.exists(save_dir):
    prefix = GCS_PREFIX if is_gcs_fs(fs) else ""
    checkpoints = list(f"{prefix}{elem}" for elem in fs.ls(save_dir, detail=False))
    # Only take checkpoints that were fully written.
    checkpoints = list(
      filter(
        lambda path: fs.exists(f"{path}/{torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME}"),
        checkpoints,
      )
    )
    checkpoints = sorted(checkpoints, key=lambda path: int(os.path.basename(path)))
  return checkpoints


def wait_for_evaluators(
  save_dir: str,
  partition_names: List[str],
  global_step: int,
  timeout: int,
) -> None:
  logging.info("Waiting for all evaluators to finish.")
  start_time = time.time()

  for checkpoint in checkpoints_iterator(save_dir):
    step = step_from_checkpoint(checkpoint)
    logging.info(f"Considering checkpoint {checkpoint} for global step {global_step}.")
    if step == global_step:
      while partition_names:
        if is_done_eval(checkpoint, partition_names[-1]):
          logging.info(
            f"Checkpoint {checkpoint} marked as finished eval for partition {partition_names[-1]} at step {step}, still waiting for {partition_names}."
          )
          partition_names.pop()

        if time.time() - start_time >= timeout:
          logging.warning(
            f"Not all evaluators finished after waiting for {time.time() - start_time}"
          )
          return
        time.sleep(10)
      logging.info("All evaluators finished.")
      return

    if time.time() - start_time >= timeout:
      logging.warning(f"Not all evaluators finished after waiting for {time.time() - start_time}")
      return
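A rough usage sketch (illustrative, not part of this commit), assuming a single process, a torchsnapshot-compatible model, and a local save directory; the model, directory, and step counts are placeholders.

import torch

from tml.common.checkpointing.snapshot import Snapshot, get_checkpoint

model = torch.nn.Linear(4, 2)  # stand-in for a real model
save_dir = "/tmp/snapshots"

snapshot = Snapshot(save_dir=save_dir, state={"model": model})

# Resume from the latest fully written checkpoint, if one exists.
checkpoint = get_checkpoint(save_dir, missing_ok=True)
if checkpoint:
  snapshot.restore(checkpoint)

for step in range(snapshot.step, 100):
  # ... run one training step here ...
  snapshot.step = step
  if step % 50 == 0:
    snapshot.save(global_step=step)  # async; storage I/O continues in the background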
common/device.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import os

import torch
import torch.distributed as dist


def maybe_setup_tensorflow():
  try:
    import tensorflow as tf
  except ImportError:
    pass
  else:
    tf.config.set_visible_devices([], "GPU")  # disable tf gpu


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
  if tf_ok:
    maybe_setup_tensorflow()

  device = torch.device("cpu")
  backend = "gloo"
  if torch.cuda.is_available():
    rank = os.environ["LOCAL_RANK"]
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
  if not torch.distributed.is_initialized():
    dist.init_process_group(backend)

  return device
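A single-process sketch of setup_and_get_device (illustrative; the environment variables below are normally set by torchrun, and the values here are placeholders).

import os

from tml.common.device import setup_and_get_device

# Normally provided by torchrun; set here only so the example runs standalone.
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

device = setup_and_get_device(tf_ok=False)
print(device)  # cuda:0 with the nccl backend if a GPU is visible, otherwise cpu with gloo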
common/filesystem/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
common/filesystem/test_infer_fs.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""Minimal test for infer_fs.

Mostly a test that it returns an object.
"""
from tml.common.filesystem import infer_fs


def test_infer_fs():
  local_path = "/tmp/local_path"
  gcs_path = "gs://somebucket/somepath"

  local_fs = infer_fs(local_path)
  gcs_fs = infer_fs(gcs_path)

  # These should be two different objects.
  assert local_fs != gcs_fs
common/filesystem/util.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs


GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()


def infer_fs(path: str):
  if path.startswith("gs://"):
    return GCS_FS
  elif path.startswith("hdfs://"):
    # We can probably use pyarrow HDFS to support this.
    raise NotImplementedError("HDFS not yet supported")
  else:
    return LOCAL_FS


def is_local_fs(fs):
  return fs == LOCAL_FS


def is_gcs_fs(fs):
  return fs == GCS_FS
common/log_weights.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel


def weights_to_log(
  model: torch.nn.Module,
  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
  """Creates a dict of reduced weights to log, to give a sense of training.

  Args:
    model: model to traverse.
    how_to_log: if a function, applies it to every parameter; if a dict,
      only applies to and logs the specified parameters.

  """
  if not how_to_log:
    return

  to_log = dict()
  named_parameters = model.named_parameters()
  logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
  if isinstance(model, DistributedModelParallel):
    named_parameters = itertools.chain(
      named_parameters, model._dmp_wrapped_module.named_parameters()
    )
    logging.info(
      f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
    )
  for param_name, params in named_parameters:
    if callable(how_to_log):
      how = how_to_log
    else:
      how = how_to_log.get(param_name)  # type: ignore[assignment]
    if not how:
      continue  # type: ignore
    to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
  return to_log


def log_ebc_norms(
  model_state_dict,
  ebc_keys: List[str],
  sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
  """Logs the norms of the embedding tables as specified by ebc_keys.

  As of now, logs the average norm per rank.

  Args:
    model_state_dict: model.state_dict()
    ebc_keys: list of embedding keys from state_dict to log. Must contain the full name,
      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
    sample_size: Limits the number of rows per rank used to compute the average, to avoid OOM.
  """
  norm_logs = dict()
  for emb_key in ebc_keys:
    norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
    if emb_key in model_state_dict:
      emb_weight = model_state_dict[emb_key]
      try:
        emb_weight_tensor = emb_weight.local_tensor()
      except AttributeError as e:
        logging.info(e)
        emb_weight_tensor = emb_weight
      logging.info("Running Tensor.detach()")
      emb_weight_tensor = emb_weight_tensor.detach()
      sample_mask = torch.randperm(emb_weight_tensor.shape[0])[
        : min(sample_size, emb_weight_tensor.shape[0])
      ]
      # WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks.
      # Change sample_size if you observe frequent OOM errors, or remove weight logging.
      norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
      logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
      norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}"))

    all_norms = [
      torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size())
    ]
    dist.all_gather(all_norms, norms)
    for idx, norm in enumerate(all_norms):
      if norm != -1.0:
        norm_logs[f"{emb_key}-norm-{idx}"] = norm
  logging.info(f"Norm Logs are {norm_logs}")
  return norm_logs
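An illustrative sketch of weights_to_log with a per-parameter reduction dict; the model is a stand-in and the parameter names ("weight", "bias") happen to match torch.nn.Linear.

import torch

from tml.common.log_weights import weights_to_log

model = torch.nn.Linear(4, 2)  # stand-in for a real model

# Reduce each named parameter to a single logged value.
to_log = weights_to_log(
  model,
  how_to_log={
    "weight": torch.mean,
    "bias": torch.norm,
  },
)
print(to_log)  # e.g. {"model/mean/weight": ..., "model/norm/bias": ...}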
common/modules/embedding/config.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import List
from enum import Enum

import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig

import pydantic


class DataType(str, Enum):
  FP32 = "fp32"
  FP16 = "fp16"


class EmbeddingSnapshot(base_config.BaseConfig):
  """Configuration for Embedding snapshot."""

  emb_name: str = pydantic.Field(
    ..., description="Name of the embedding table from the loaded snapshot"
  )
  embedding_snapshot_uri: str = pydantic.Field(
    ..., description="Path to torchsnapshot of the embedding"
  )


class EmbeddingBagConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBag."""

  name: str = pydantic.Field(..., description="name of embedding bag")
  num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
  embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
  pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
  vocab: str = pydantic.Field(
    None, description="Directory to parquet files of mapping from entity ID to table index."
  )
  # make sure to use an optimizer that matches:
  # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
  optimizer: OptimizerConfig
  data_type: DataType


class LargeEmbeddingsConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBagCollection.

  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
  """

  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
  tables_to_log: List[str] = pydantic.Field(
    None, description="list of embedding table names that we want to log during training"
  )


class Mode(str, Enum):
  """Job modes."""

  TRAIN = "train"
  EVALUATE = "evaluate"
  INFERENCE = "inference"
common/modules/embedding/embedding.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np


class LargeEmbeddings(nn.Module):
  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
  ):
    super().__init__()

    tables = []
    for table in large_embeddings_config.tables:
      data_type = (
        embedding_configs.DataType.FP32
        if (table.data_type == DataType.FP32)
        else embedding_configs.DataType.FP16
      )

      tables.append(
        EmbeddingBagConfig(
          embedding_dim=table.embedding_dim,
          feature_names=[table.name],  # restricted to 1 feature per table for now
          name=table.name,
          num_embeddings=table.num_embeddings,
          pooling=torchrec.PoolingType.SUM,
          data_type=data_type,
        )
      )

    self.ebc = EmbeddingBagCollection(
      device="meta",
      tables=tables,
    )

    logging.info("********************** EBC named params are **********")
    logging.info(list(self.ebc.named_parameters()))

    # This hook is used to perform post-processing surgery
    # on large_embedding models to prep them for serving.
    self.surgery_cut_point = torch.nn.Identity()

  def forward(
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
    return self.surgery_cut_point(pooled_embs)
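To show the sparse input format that forward expects, here is a small stand-alone torchrec sketch on CPU (one table, two examples for a hypothetical "user_id" feature); it mirrors, but is not, the production wiring above.

import torch
import torchrec
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
  tables=[
    EmbeddingBagConfig(
      name="user_id",
      feature_names=["user_id"],
      num_embeddings=100,
      embedding_dim=8,
      pooling=torchrec.PoolingType.SUM,
    )
  ],
  device=torch.device("cpu"),
)

# Two examples: the first has ids [3, 7], the second has id [42].
sparse_features = KeyedJaggedTensor.from_lengths_sync(
  keys=["user_id"],
  values=torch.tensor([3, 7, 42]),
  lengths=torch.tensor([2, 1]),
)

pooled = ebc(sparse_features)  # a KeyedTensor
print(pooled.to_dict()["user_id"].shape)  # torch.Size([2, 8])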
common/run_training.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import os
import subprocess
import sys
from typing import Optional

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils

import torch
import torch.distributed.run


def is_distributed_worker():
  world_size = os.environ.get("WORLD_SIZE", None)
  rank = os.environ.get("RANK", None)
  return world_size is not None and rank is not None


def maybe_run_training(
  train_fn,
  module_name,
  nproc_per_node: Optional[int] = None,
  num_nodes: Optional[int] = None,
  set_python_path_in_subprocess: bool = False,
  is_chief: Optional[bool] = False,
  **training_kwargs,
):
  """Wrapper function for single node, multi-GPU PyTorch training.

  If the necessary distributed PyTorch environment variables
  (WORLD_SIZE, RANK) have been set, then this function executes
  `train_fn(**training_kwargs)`.

  Otherwise, this function calls torchrun and points at the calling module
  `module_name`. After this call, the necessary environment variables are set
  and training will commence.

  Args:
    train_fn: The function that is responsible for training.
    module_name: The name of the module that this function was called from;
      used to indicate the torchrun entrypoint.
    nproc_per_node: Number of workers per node; inferred from the available
      accelerators if not specified.
    num_nodes: Number of nodes, otherwise inferred from the environment.
    is_chief: If the process is running on the chief.
    set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
  """

  machines = utils.machine_from_env()
  if num_nodes is None:
    num_nodes = 1
    if machines.num_workers:
      num_nodes += machines.num_workers

  if is_distributed_worker():
    # world_size, rank, etc. are set; assuming any other env vars are set (checks to come)
    # start the actual training!
    train_fn(**training_kwargs)
  else:
    if nproc_per_node is None:
      if torch.cuda.is_available():
        nproc_per_node = torch.cuda.device_count()
      else:
        nproc_per_node = machines.chief.num_accelerators

    # Rejoin all arguments to send back through torchrun.
    # This is a temporary measure; we will replace the os.system call
    # with torchrun API calls.
    args = list(f"--{key}={val}" for key, val in training_kwargs.items())

    cmd = [
      "--nnodes",
      str(num_nodes),
    ]
    if nproc_per_node:
      cmd.extend(["--nproc_per_node", str(nproc_per_node)])
    if num_nodes > 1:
      cluster_resolver = utils.cluster_resolver()
      backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
      cmd.extend(
        [
          "--rdzv_backend",
          "c10d",
          "--rdzv_id",
          backend_address,
        ]
      )
      # Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
      if is_chief:
        cmd.extend(["--rdzv_endpoint", "localhost:2222"])
      else:
        cmd.extend(["--rdzv_endpoint", backend_address])
    else:
      cmd.append("--standalone")

    cmd.extend(
      [
        str(module_name),
        *args,
      ]
    )
    logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")

    # Call torchrun on this module; will spawn new processes and re-run this
    # function, eventually calling "train_fn". The following line sets the PYTHONPATH to accommodate
    # bazel stubbing for the main binary.
    if set_python_path_in_subprocess:
      subprocess.run(["torchrun"] + cmd, env={**os.environ, "PYTHONPATH": ":".join(sys.path)})
    else:
      torch.distributed.run.main(cmd)
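A hypothetical entrypoint sketch showing how a project might call maybe_run_training; `run`, the module path "my_project.main", and the config argument are illustrative only and not part of this commit.

from tml.common.run_training import maybe_run_training


def run(config_path: str = None):
  # Build the config, model, and training loop here.
  ...


if __name__ == "__main__":
  maybe_run_training(
    run,
    "my_project.main",  # module that torchrun should re-invoke
    set_python_path_in_subprocess=True,
    config_path="/tmp/config.yaml",
  )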
common/test_device.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Minimal test for device.

Mostly a test that this can be imported properly even though it was moved.
"""
from unittest.mock import patch

import tml.common.device as device_utils


def test_device():
  with patch("tml.common.device.dist.init_process_group"):
    device = device_utils.setup_and_get_device(tf_ok=False)
    assert device.type == "cpu"
common/testing_utils.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from contextlib import contextmanager
import datetime
import os
from unittest.mock import patch

import torch.distributed as dist
from tml.ml_logging.torch_logging import logging


MOCK_ENV = {
  "LOCAL_RANK": "0",
  "WORLD_SIZE": "1",
  "LOCAL_WORLD_SIZE": "1",
  "MASTER_ADDR": "localhost",
  "MASTER_PORT": "29501",
  "RANK": "0",
}


@contextmanager
def mock_pg():
  with patch.dict(os.environ, MOCK_ENV):
    try:
      dist.init_process_group(
        backend="gloo",
        timeout=datetime.timedelta(1),
      )
      yield
    except:
      dist.destroy_process_group()
      raise
    finally:
      dist.destroy_process_group()
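A minimal test sketch using mock_pg (illustrative, not part of the commit): it patches the MOCK_ENV variables and starts a single-process gloo group so distributed code paths can run in unit tests.

import torch
import torch.distributed as dist

from tml.common.testing_utils import mock_pg


def test_all_reduce_single_process():
  with mock_pg():
    t = torch.ones(2)
    dist.all_reduce(t)  # world size is 1, so the tensor is unchanged
    assert torch.equal(t, torch.ones(2))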
common/utils.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import yaml
import getpass
import os
import string
from typing import Type, TypeVar

from tml.core.config import base_config

import fsspec

C = TypeVar("C", bound=base_config.BaseConfig)


def _read_file(f):
  with fsspec.open(f) as f:
    return f.read()


def setup_configuration(
  config_type: Type[C],
  yaml_path: str,
  substitute_env_variable: bool = False,
) -> C:
  """Resolves a config at a yaml path.

  Args:
    config_type: Pydantic config class to load.
    yaml_path: yaml path of the config file.
    substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR} with their
      environment variable values whenever possible. If an environment variable doesn't exist,
      the string is left unchanged.

  Returns:
    The pydantic config object.
  """

  def _substitute(s):
    if substitute_env_variable:
      return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
    return s

  assert config_type is not None, "can't use all_config without config_type"
  content = _substitute(yaml.safe_load(_read_file(yaml_path)))
  return config_type.parse_obj(content)
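A hedged usage sketch: write a small yaml file and parse it with setup_configuration. ExampleConfig, its fields, and the file path are hypothetical; in practice config_type is one of the project's BaseConfig subclasses.

import pydantic

from tml.core.config import base_config
from tml.common.utils import setup_configuration


class ExampleConfig(base_config.BaseConfig):
  name: str = pydantic.Field(..., description="Experiment name.")
  learning_rate: float = pydantic.Field(0.01, description="Learning rate.")


with open("/tmp/example_config.yaml", "w") as f:
  f.write("name: demo\nlearning_rate: 0.05\n")

config = setup_configuration(ExampleConfig, "/tmp/example_config.yaml")
print(config.name, config.learning_rate)  # demo 0.05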
common/wandb.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Any, Dict, List

import tml.core.config as base_config

import pydantic


class WandbConfig(base_config.BaseConfig):
  host: str = pydantic.Field(
    "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
    description="Host of Weights and Biases instance, passed to login.",
  )
  key_path: str = pydantic.Field(description="Path to key file.")

  name: str = pydantic.Field(None, description="Name of the experiment, passed to init.")
  entity: str = pydantic.Field(None, description="Name of user/service account, passed to init.")
  project: str = pydantic.Field(None, description="Name of wandb project, passed to init.")
  tags: List[str] = pydantic.Field([], description="List of tags, passed to init.")
  notes: str = pydantic.Field(None, description="Notes, passed to init.")
  metadata: Dict[str, Any] = pydantic.Field(None, description="Additional metadata to log.")
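An illustrative instantiation (values are made up; only key_path is required, and host defaults to the internal instance above).

from tml.common.wandb import WandbConfig

wandb_config = WandbConfig(
  key_path="/var/secrets/wandb_key",
  name="twhin-embeddings-run",
  entity="my-team",
  project="recommender",
  tags=["demo"],
)
print(wandb_config.host)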