"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel


def weights_to_log(
    model: torch.nn.Module,
    how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
    """Create a dict of reduced weights to log, to give a sense of training.

    Args:
      model: Model whose named parameters are traversed. When it is a
        ``DistributedModelParallel`` instance, the parameters of its wrapped
        module are traversed as well.
      how_to_log: If a callable, it is applied to every parameter. If a dict
        mapping parameter name to callable, only the listed parameters are
        reduced and logged.

    Returns:
      ``None`` when ``how_to_log`` is falsy; otherwise a dict keyed by
      ``"model/{reduction name}/{param name}"`` whose values are the reduced
      parameters moved to CPU and converted with ``.numpy()``.
    """
    if not how_to_log:
        return

    to_log = dict()
    named_parameters = model.named_parameters()
    is_dmp = isinstance(model, DistributedModelParallel)
    logging.info(f"Using DMP: {is_dmp}")
    if is_dmp:
        named_parameters = itertools.chain(
            named_parameters, model._dmp_wrapped_module.named_parameters()
        )
        dmp_param_names = [
            name for name, _ in model._dmp_wrapped_module.named_parameters()
        ]
        logging.info(f"Using dmp parameters: {dmp_param_names}")

    for param_name, params in named_parameters:
        if callable(how_to_log):
            how = how_to_log
        else:
            how = how_to_log.get(param_name)  # type: ignore[assignment]
        if not how:
            continue  # type: ignore
        # Not every callable carries ``__name__`` (e.g. functools.partial or a
        # callable instance); fall back to the type name instead of raising.
        how_name = getattr(how, "__name__", type(how).__name__)
        to_log[f"model/{how_name}/{param_name}"] = how(params.detach()).cpu().numpy()
    return to_log


def log_ebc_norms(
    model_state_dict,
    ebc_keys: List[str],
    sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
    """Log the norms of the embedding tables specified by ``ebc_keys``.

    As of now, logs the average row norm per rank.

    Args:
      model_state_dict (dict): The state dictionary of the model.
      ebc_keys (List[str]): Embedding keys from the state dict to log. Each key
        must contain the full name, e.g.
        "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
      sample_size (int, optional): Limits the number of rows per rank used to
        compute the average, to avoid out-of-memory (OOM) errors.
        Defaults to 4,000,000.

    Returns:
      Dict[str, torch.Tensor]: The gathered norms, keyed as
      "{emb_key}-norm-{idx}" where ``idx`` is the position in the all-gather
      result. Entries equal to the ``-1`` sentinel (key missing from the state
      dict, or no local rows to average) are omitted.

    Note:
      This is a collective operation: ``dist.all_gather`` is called once per
      key, so every rank must call it with the same ``ebc_keys``.
      NOTE(review): the global rank is used as the CUDA device index, which
      presumably assumes a single-node job -- confirm for multi-node runs.
    """
    norm_logs = dict()
    for emb_key in ebc_keys:
        device = torch.device(f"cuda:{dist.get_rank()}")
        # Sentinel meaning "this rank has nothing to report for this key".
        norms = torch.full((1,), -1.0, dtype=torch.float32, device=device)
        if emb_key in model_state_dict:
            emb_weight = model_state_dict[emb_key]
            try:
                emb_weight_tensor = emb_weight.local_tensor()
            except AttributeError as e:
                # Objects without ``local_tensor`` are used as-is.
                logging.info(e)
                emb_weight_tensor = emb_weight
            logging.info("Running Tensor.detach()")
            emb_weight_tensor = emb_weight_tensor.detach()
            num_rows = emb_weight_tensor.shape[0]
            # With zero local rows the mean would be NaN and slip past the
            # sentinel filter below, so keep the sentinel in that case.
            if num_rows > 0:
                sample_mask = torch.randperm(num_rows)[: min(sample_size, num_rows)]
                # WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks
                # Change sample_size if you observe frequent OOM errors or remove weight logging.
                norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
                logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
                # ``mean()`` yields a 0-dim tensor; reshape so the all_gather
                # input has the same shape as the 1-element output buffers.
                norms = norms.mean().reshape(1).to(device)

        all_norms = [torch.zeros_like(norms) for _ in range(dist.get_world_size())]
        dist.all_gather(all_norms, norms)
        for idx, norm in enumerate(all_norms):
            if norm != -1.0:
                norm_logs[f"{emb_key}-norm-{idx}"] = norm
    logging.info(f"Norm Logs are {norm_logs}")
    return norm_logs