diff --git a/ml_logging/torch_logging.py b/ml_logging/torch_logging.py index e791c46..fedabcb 100644 --- a/ml_logging/torch_logging.py +++ b/ml_logging/torch_logging.py @@ -17,45 +17,46 @@ from absl import logging as absl_logging import torch.distributed as dist +import functools +from typing import Optional + +import torch.distributed as dist + +from absl import logging as absl_logging +from tml.ml_logging.absl_logging import logging as logging + + def rank_specific(logger): - """Ensures that we only override a given logger once.""" - if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"): - return logger + """Ensures that we only override a given logger once.""" - def _if_rank(logger_method, limit: Optional[int] = None): - if limit: - # If we are limiting redundant logs, wrap logging call with a cache - # to not execute if already cached. - def _wrap(_call): - @functools.lru_cache(limit) - def _logger_method(*args, **kwargs): - _call(*args, **kwargs) + def _if_rank(logger_method, limit: Optional[int] = None): + """Decorator to wrap logger_method and execute only if rank matches.""" + if limit: + @functools.lru_cache(limit) + def _logger_method(*args, **kwargs): + logger_method(*args, **kwargs) + return _logger_method - return _logger_method + def _inner(msg, *args, rank: int = 0, **kwargs): + """Inner function to execute logger_method only if rank matches.""" + if not dist.is_initialized() or dist.get_rank() == rank or rank < 0: + logger_method(msg, *args, **kwargs) - logger_method = _wrap(logger_method) + # Register this stack frame with absl logging so that it doesn't trample logging lines. + absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__) - def _inner(msg, *args, rank: int = 0, **kwargs): - if not dist.is_initialized(): - logger_method(msg, *args, **kwargs) - elif dist.get_rank() == rank: - logger_method(msg, *args, **kwargs) - elif rank < 0: - logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs) + return _inner - # Register this stack frame with absl logging so that it doesn't trample logging lines. - absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__) + logger.fatal = _if_rank(logger.fatal) + logger.error = _if_rank(logger.error) + logger.warning = _if_rank(logger.warning, limit=1) + logger.info = _if_rank(logger.info) + logger.debug = _if_rank(logger.debug) + logger.exception = _if_rank(logger.exception) - return _inner + logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True - logger.fatal = _if_rank(logger.fatal) - logger.error = _if_rank(logger.error) - logger.warning = _if_rank(logger.warning, limit=1) - logger.info = _if_rank(logger.info) - logger.debug = _if_rank(logger.debug) - logger.exception = _if_rank(logger.exception) - - logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True + return logger if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC") else None rank_specific(logging)