Refactored rank-specific logger

#Refactored rank_specific functions.
#Implemented the logic without changing the behavior.
#The readability and maintainability of the code improved.
-------------------------
Tested against:
test_torch_logging.py
This commit is contained in:
Darshan P 2023-04-02 18:19:39 +00:00
parent 78c3235eee
commit ca8db6ddf4

View File

@ -17,31 +17,30 @@ from absl import logging as absl_logging
import torch.distributed as dist import torch.distributed as dist
import functools
from typing import Optional
import torch.distributed as dist
from absl import logging as absl_logging
from tml.ml_logging.absl_logging import logging as logging
def rank_specific(logger): def rank_specific(logger):
"""Ensures that we only override a given logger once.""" """Ensures that we only override a given logger once."""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger
def _if_rank(logger_method, limit: Optional[int] = None): def _if_rank(logger_method, limit: Optional[int] = None):
"""Decorator to wrap logger_method and execute only if rank matches."""
if limit: if limit:
# If we are limiting redundant logs, wrap logging call with a cache
# to not execute if already cached.
def _wrap(_call):
@functools.lru_cache(limit) @functools.lru_cache(limit)
def _logger_method(*args, **kwargs): def _logger_method(*args, **kwargs):
_call(*args, **kwargs) logger_method(*args, **kwargs)
return _logger_method return _logger_method
logger_method = _wrap(logger_method)
def _inner(msg, *args, rank: int = 0, **kwargs): def _inner(msg, *args, rank: int = 0, **kwargs):
if not dist.is_initialized(): """Inner function to execute logger_method only if rank matches."""
if not dist.is_initialized() or dist.get_rank() == rank or rank < 0:
logger_method(msg, *args, **kwargs) logger_method(msg, *args, **kwargs)
elif dist.get_rank() == rank:
logger_method(msg, *args, **kwargs)
elif rank < 0:
logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
# Register this stack frame with absl logging so that it doesn't trample logging lines. # Register this stack frame with absl logging so that it doesn't trample logging lines.
absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__) absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
@ -57,5 +56,7 @@ def rank_specific(logger):
logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
return logger if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC") else None
rank_specific(logging) rank_specific(logging)