mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-26 13:45:25 +01:00
2cc1abedd7
Signed-off-by: wiseaidev <business@wiseai.dev>
102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
"""For logging model weights."""
|
|
import itertools
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from tml.ml_logging.torch_logging import (
|
|
logging,
|
|
)
|
|
from torchrec.distributed.model_parallel import (
|
|
DistributedModelParallel,
|
|
)
|
|
|
|
|
|
def _callable_name(fn: Callable[..., Any]) -> str:
  """Return a name for ``fn`` suitable for use inside a log key.

  Plain functions and builtins expose ``__name__``. ``functools.partial``
  objects and instances of classes defining ``__call__`` do not, so reading
  ``fn.__name__`` directly would raise ``AttributeError``. For objects with a
  ``func`` attribute (as ``functools.partial`` has), the wrapped callable's name
  is used. Otherwise the callable's type name is used.
  """
  name = getattr(fn, "__name__", None)
  if isinstance(name, str):
    return name
  wrapped = getattr(fn, "func", None)
  if callable(wrapped) and wrapped is not fn:
    return _callable_name(wrapped)
  return type(fn).__name__


def weights_to_log(
  model: torch.nn.Module,
  how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
) -> Optional[Dict[str, Any]]:
  """Creates dict of reduced weights to log to give sense of training.

  Args:
    model: model to traverse. If it is a ``DistributedModelParallel``, the
      parameters of ``model._dmp_wrapped_module`` are visited as well, after the
      model's own parameters.
    how_to_log: if a function, then applies this to every parameter; if a dict,
      then only applies and logs the parameters whose names are keys, each with
      its mapped function.

  Returns:
    ``None`` if ``how_to_log`` is falsy (``None`` or an empty dict). Otherwise a
    dict mapping ``"model/<function name>/<parameter name>"`` to the result of
    applying the function to the detached parameter, converted with
    ``.cpu().numpy()``.
  """
  if not how_to_log:
    return None

  to_log = dict()
  named_parameters = model.named_parameters()
  is_dmp = isinstance(model, DistributedModelParallel)
  logging.info(f"Using DMP: {is_dmp}")
  if is_dmp:
    # Also visit the wrapped module's parameters, presumably so that names as
    # they appear on the unwrapped module can be matched by a dict `how_to_log`.
    # NOTE(review): with a callable `how_to_log`, a parameter may therefore be
    # logged under both naming schemes.
    named_parameters = itertools.chain(
      named_parameters, model._dmp_wrapped_module.named_parameters()
    )
    logging.info(
      f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
    )
  for param_name, params in named_parameters:
    if callable(how_to_log):
      how = how_to_log
    else:
      how = how_to_log.get(param_name)  # type: ignore[assignment]
      if how is None:
        # Parameter not requested in the dict: skip it.
        continue
    # `_callable_name` tolerates functools.partial / callable objects, which
    # have no `__name__` attribute.
    to_log[f"model/{_callable_name(how)}/{param_name}"] = how(params.detach()).cpu().numpy()
  return to_log
|
|
|
|
|
|
def _norm_device() -> torch.device:
  """Return the CUDA device this process should use for the norm collectives.

  ``dist.get_rank()`` is the global rank. In a multi-node job it can exceed the
  number of GPUs on the node, in which case ``cuda:<global rank>`` does not
  exist. ``torchrun`` exports the per-node rank as ``LOCAL_RANK``, so that is
  used when set. Otherwise this falls back to the global rank.
  """
  import os

  local_rank = os.environ.get("LOCAL_RANK")
  device_index = int(local_rank) if local_rank is not None else dist.get_rank()
  return torch.device(f"cuda:{device_index}")


def log_ebc_norms(
  model_state_dict: Dict[str, Any],
  ebc_keys: List[str],
  sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
  """Logs the norms of the embedding tables as specified by ebc_keys.

  As of now, log average norm per rank.

  For every key, each rank computes the mean L2 row norm over a random sample
  (at most ``sample_size`` rows) of its local part of the table. A rank that
  does not hold the key, or whose local part has no rows, contributes the
  sentinel -1. The per-rank values are combined with ``dist.all_gather``, so
  this is a collective call: every rank must call it with the same
  ``ebc_keys`` in the same order.

  Args:
    model_state_dict: model.state_dict()
    ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
    sample_size: Limits number of rows per rank to compute average on to avoid OOM.

  Returns:
    Dict mapping ``"<emb_key>-norm-<rank>"`` to that rank's mean norm, as a
    one-element float32 tensor on this process's CUDA device. Only ranks that
    did not report the -1 sentinel are included.
  """
  device = _norm_device()
  world_size = dist.get_world_size()
  norm_logs = dict()
  for emb_key in ebc_keys:
    # -1 is the sentinel for "this rank has nothing to report for emb_key".
    norms = torch.full((1,), -1.0, dtype=torch.float32, device=device)
    if emb_key in model_state_dict:
      emb_weight = model_state_dict[emb_key]
      try:
        # Sharded weights (e.g. torch ShardedTensor) provide local_tensor().
        # Plain tensors raise AttributeError and are used as-is.
        emb_weight_tensor = emb_weight.local_tensor()
      except AttributeError as e:
        logging.info(e)
        emb_weight_tensor = emb_weight
      logging.info("Running Tensor.detach()")
      emb_weight_tensor = emb_weight_tensor.detach()
      num_rows = emb_weight_tensor.shape[0]
      # An empty local table would give mean() == NaN. Keep the sentinel instead.
      if num_rows > 0:
        if num_rows > sample_size:
          sample_mask = torch.randperm(num_rows)[:sample_size]
          sampled_rows = emb_weight_tensor[sample_mask]
        else:
          # Every row is used, so no permutation is needed: the mean is
          # independent of row order.
          sampled_rows = emb_weight_tensor
        # WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks
        # Change sample_size if you observe frequent OOM errors or remove weight logging.
        # Cast before .norm so that half-precision weights are reduced in float32
        # (more accurate, and works where CPU half-precision norm is unsupported).
        norms = sampled_rows.cpu().to(torch.float32).norm(dim=1)
        logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
        # Reshape to (1,) so every rank sends the same shape as the sentinel.
        norms = norms.mean().reshape(1).to(device)

    all_norms = [torch.zeros_like(norms) for _ in range(world_size)]
    dist.all_gather(all_norms, norms)
    for rank_idx, norm in enumerate(all_norms):
      if norm != -1.0:
        norm_logs[f"{emb_key}-norm-{rank_idx}"] = norm
  logging.info(f"Norm Logs are {norm_logs}")
  return norm_logs
|