mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 14:51:49 +01:00
Refactored rank-specific logger
#Refactored rank_specific functions. #Implemented the logic without changing the behavior. #The readability and maintainability of the code improved. ------------------------- Tested against: test_torch_logging.py
This commit is contained in:
parent
78c3235eee
commit
ca8db6ddf4
@ -17,45 +17,46 @@ from absl import logging as absl_logging
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from absl import logging as absl_logging
|
||||||
|
from tml.ml_logging.absl_logging import logging as logging
|
||||||
|
|
||||||
|
|
||||||
def rank_specific(logger):
|
def rank_specific(logger):
|
||||||
"""Ensures that we only override a given logger once."""
|
"""Ensures that we only override a given logger once."""
|
||||||
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
|
|
||||||
return logger
|
|
||||||
|
|
||||||
def _if_rank(logger_method, limit: Optional[int] = None):
|
def _if_rank(logger_method, limit: Optional[int] = None):
|
||||||
if limit:
|
"""Decorator to wrap logger_method and execute only if rank matches."""
|
||||||
# If we are limiting redundant logs, wrap logging call with a cache
|
if limit:
|
||||||
# to not execute if already cached.
|
@functools.lru_cache(limit)
|
||||||
def _wrap(_call):
|
def _logger_method(*args, **kwargs):
|
||||||
@functools.lru_cache(limit)
|
logger_method(*args, **kwargs)
|
||||||
def _logger_method(*args, **kwargs):
|
return _logger_method
|
||||||
_call(*args, **kwargs)
|
|
||||||
|
|
||||||
return _logger_method
|
def _inner(msg, *args, rank: int = 0, **kwargs):
|
||||||
|
"""Inner function to execute logger_method only if rank matches."""
|
||||||
|
if not dist.is_initialized() or dist.get_rank() == rank or rank < 0:
|
||||||
|
logger_method(msg, *args, **kwargs)
|
||||||
|
|
||||||
logger_method = _wrap(logger_method)
|
# Register this stack frame with absl logging so that it doesn't trample logging lines.
|
||||||
|
absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
|
||||||
|
|
||||||
def _inner(msg, *args, rank: int = 0, **kwargs):
|
return _inner
|
||||||
if not dist.is_initialized():
|
|
||||||
logger_method(msg, *args, **kwargs)
|
|
||||||
elif dist.get_rank() == rank:
|
|
||||||
logger_method(msg, *args, **kwargs)
|
|
||||||
elif rank < 0:
|
|
||||||
logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
|
|
||||||
|
|
||||||
# Register this stack frame with absl logging so that it doesn't trample logging lines.
|
logger.fatal = _if_rank(logger.fatal)
|
||||||
absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
|
logger.error = _if_rank(logger.error)
|
||||||
|
logger.warning = _if_rank(logger.warning, limit=1)
|
||||||
|
logger.info = _if_rank(logger.info)
|
||||||
|
logger.debug = _if_rank(logger.debug)
|
||||||
|
logger.exception = _if_rank(logger.exception)
|
||||||
|
|
||||||
return _inner
|
logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
|
||||||
|
|
||||||
logger.fatal = _if_rank(logger.fatal)
|
return logger if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC") else None
|
||||||
logger.error = _if_rank(logger.error)
|
|
||||||
logger.warning = _if_rank(logger.warning, limit=1)
|
|
||||||
logger.info = _if_rank(logger.info)
|
|
||||||
logger.debug = _if_rank(logger.debug)
|
|
||||||
logger.exception = _if_rank(logger.exception)
|
|
||||||
|
|
||||||
logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
|
|
||||||
|
|
||||||
|
|
||||||
rank_specific(logging)
|
rank_specific(logging)
|
||||||
|
Loading…
Reference in New Issue
Block a user