# 2023-03-31 20:05:14 +02:00
|
|
|
"""Utilities for logging reduced model weights and embedding-table norms during training."""
|
|
|
|
import itertools
|
|
|
|
from typing import Callable, Dict, List, Optional, Union
|
|
|
|
|
|
|
|
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torchrec.distributed.model_parallel import DistributedModelParallel
|
|
|
|
|
|
|
|
|
|
|
|
def weights_to_log(
    model: torch.nn.Module,
    how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
) -> Optional[Dict[str, object]]:
    """Create a dict of reduced parameter values to log, to give a sense of training.

    Args:
        model: Model whose parameters are traversed. When ``model`` is a
            ``DistributedModelParallel``, the parameters of its wrapped module
            (``model._dmp_wrapped_module``) are traversed as well.
        how_to_log: Either a single callable applied to every parameter, or a
            dict mapping parameter names to callables, in which case only the
            listed parameters are reduced and logged. Each callable receives the
            detached parameter tensor and must return a tensor.

    Returns:
        ``None`` if ``how_to_log`` is falsy (``None`` or an empty dict).
        Otherwise, a dict mapping ``"model/{reduction}/{param_name}"`` to the
        reduced value as a NumPy array. ``{reduction}`` is the callable's
        ``__name__``; callables without one (e.g. ``functools.partial``) fall
        back to their type name.
    """
    if not how_to_log:
        return None

    is_dmp = isinstance(model, DistributedModelParallel)
    logging.info(f"Using DMP: {is_dmp}")

    named_parameters = model.named_parameters()
    if is_dmp:
        wrapped = model._dmp_wrapped_module
        named_parameters = itertools.chain(named_parameters, wrapped.named_parameters())
        logging.info(f"Using dmp parameters: {[name for name, _ in wrapped.named_parameters()]}")

    # The chained iterators may yield the same parameter name more than once.
    # Collapse them up front so every reduction (and its device->host copy)
    # runs once per name. Last occurrence wins and key order is that of first
    # occurrence, exactly as the previous overwrite-in-place loop produced.
    params_by_name = dict(named_parameters)

    apply_to_all = callable(how_to_log)
    to_log: Dict[str, object] = {}
    for param_name, param in params_by_name.items():
        how = how_to_log if apply_to_all else how_to_log.get(param_name)  # type: ignore[union-attr]
        if not how:
            # Dict mode and this parameter was not requested.
            continue
        # Callables such as functools.partial have no __name__.
        reduction_name = getattr(how, "__name__", None) or type(how).__name__
        to_log[f"model/{reduction_name}/{param_name}"] = how(param.detach()).cpu().numpy()  # type: ignore[operator]
    return to_log
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_embedding_row_norms(weight: torch.Tensor, sample_size: int) -> Optional[torch.Tensor]:
    """Return float32 L2 norms, on CPU, of up to ``sample_size`` randomly chosen rows of ``weight``.

    Returns ``None`` when there is nothing to sample (``weight`` has no rows or
    ``sample_size`` is not positive). The caller can then report the table as
    absent instead of averaging an empty tensor into NaN.
    """
    num_rows = weight.shape[0]
    num_samples = min(sample_size, num_rows)
    if num_samples <= 0:
        return None
    # A prefix of a random permutation samples rows without replacement.
    sample_idx = torch.randperm(num_rows)[:num_samples]
    # WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks.
    # Change sample_size if you observe frequent OOM errors, or remove weight logging.
    return weight[sample_idx].cpu().norm(dim=1).to(torch.float32)


def log_ebc_norms(
    model_state_dict,
    ebc_keys: List[str],
    sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
    """Gather the per-rank mean row norm of each embedding table listed in ``ebc_keys``.

    Each rank computes the mean L2 norm over a random sample of rows of its
    local part of every table it holds. The per-rank means are then exchanged
    with ``dist.all_gather``, so every rank returns the same keys.

    This is a collective operation on the default process group: every rank
    must call it with the same ``ebc_keys`` in the same order, otherwise the
    ``all_gather`` calls will not line up.

    Args:
        model_state_dict (dict): The state dictionary of the model.
        ebc_keys (List[str]): Embedding keys from ``model_state_dict`` to log.
            Each key must contain the full name, e.g.,
            "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
        sample_size (int, optional): Maximum number of rows per rank to average
            over, to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.

    Returns:
        Dict[str, torch.Tensor]: Maps "{emb_key}-norm-{rank}" to that rank's mean
        row norm, as a 1-element float32 tensor on this process's CUDA device.
        A rank is omitted for a key if it does not hold the table, or if it has
        no rows to sample (empty local part, or non-positive ``sample_size``).
        Norms are reported per rank; they are not averaged across ranks.
    """
    import os  # local import: the module's import block is outside this edit

    # Row norms are non-negative, so -1 can never be a genuine mean and is used
    # to mark "this rank has nothing for this table".
    missing = -1.0

    norm_logs: Dict[str, torch.Tensor] = {}
    for emb_key in ebc_keys:
        # ``cuda:{global rank}`` is an invalid device ordinal on every node but
        # the first in multi-node jobs. Prefer LOCAL_RANK (set by torchrun) and
        # fall back to the global rank only when it is unset. This is resolved
        # per key so an empty ``ebc_keys`` never touches the process group.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        device = torch.device(f"cuda:{local_rank}")

        # Every rank must reach the all_gather below for every key, even when it
        # holds nothing for this table, so start from the sentinel.
        local_norm = torch.full((1,), missing, dtype=torch.float32, device=device)
        if emb_key in model_state_dict:
            emb_weight = model_state_dict[emb_key]
            try:
                # Presumably a sharded tensor whose local_tensor() returns this
                # rank's part; plain tensors lack the method and are used as-is.
                emb_weight_tensor = emb_weight.local_tensor()
            except AttributeError as e:
                logging.info(e)
                emb_weight_tensor = emb_weight
            logging.info("Running Tensor.detach()")
            row_norms = _sample_embedding_row_norms(emb_weight_tensor.detach(), sample_size)
            if row_norms is not None:
                logging.info(f"Norm shape before reduction: {row_norms.shape}", rank=-1)
                # reshape(1): mean() is 0-d, but all_gather's input must match
                # the shape-[1] sentinel and output buffers.
                local_norm = row_norms.mean().reshape(1).to(device)

        gathered = [
            torch.zeros(1, dtype=torch.float32, device=device) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered, local_norm)
        # One host sync for all ranks instead of one `tensor != -1` per rank.
        present = (torch.cat(gathered) != missing).tolist()
        for rank_idx, (rank_norm, has_table) in enumerate(zip(gathered, present)):
            if has_table:
                norm_logs[f"{emb_key}-norm-{rank_idx}"] = rank_norm

    logging.info(f"Norm Logs are {norm_logs}")
    return norm_logs
|