the-algorithm-ml/ml_logging/torch_logging.py
2023-09-13 11:22:13 +05:30

90 lines
2.9 KiB
Python

"""Overrides absl logger to be rank-aware for distributed pytorch usage.
>>> # in-bazel import
>>> from twitter.ml.logging.torch_logging import logging
>>> # out-bazel import
>>> from ml.logging.torch_logging import logging
>>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.")
>>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1)
"""
import functools
from typing import Optional
from tml.ml_logging.absl_logging import logging as logging
from absl import logging as absl_logging
import torch.distributed as dist
def rank_specific(logger):
"""
Customize logger behavior based on the distributed environment and rank.
This function allows for customizing the behavior of a logger based on the distributed environment and the rank
of the current process. It overrides standard logging methods (e.g., error, warning) to conditionally log messages
depending on the rank or limit the number of redundant logs.
Args:
logger: The logger object to customize.
Returns:
The customized logger.
Example:
To use this function with the `logging` module:
```python
import logging
from rank_specific_logging import rank_specific
logger = logging.getLogger(__name__)
rank_specific(logger)
```
Customization:
- Messages are only logged if the distributed environment is not initialized or if the rank matches.
- The 'warning' method is limited to logging a single redundant warning.
- Logging from rank -1 is redirected to include the rank information.
"""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger
def _if_rank(logger_method, limit: Optional[int] = None):
if limit:
# If we are limiting redundant logs, wrap logging call with a cache
# to not execute if already cached.
def _wrap(_call):
@functools.lru_cache(limit)
def _logger_method(*args, **kwargs):
_call(*args, **kwargs)
return _logger_method
logger_method = _wrap(logger_method)
def _inner(msg, *args, rank: int = 0, **kwargs):
if not dist.is_initialized():
logger_method(msg, *args, **kwargs)
elif dist.get_rank() == rank:
logger_method(msg, *args, **kwargs)
elif rank < 0:
logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
# Register this stack frame with absl logging so that it doesn't trample logging lines.
absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
return _inner
logger.fatal = _if_rank(logger.fatal)
logger.error = _if_rank(logger.error)
logger.warning = _if_rank(logger.warning, limit=1)
logger.info = _if_rank(logger.info)
logger.debug = _if_rank(logger.debug)
logger.exception = _if_rank(logger.exception)
logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
rank_specific(logging)