mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-06-16 18:58:30 +02:00
Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings
This commit is contained in:
0
ml_logging/__init__.py
Normal file
0
ml_logging/__init__.py
Normal file
26
ml_logging/absl_logging.py
Normal file
26
ml_logging/absl_logging.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""Sets up logging through absl for training usage.
|
||||
|
||||
- Redirects logging to sys.stdout so that severity levels in GCP Stackdriver are accurate.
|
||||
|
||||
Usage:
|
||||
>>> from twitter.ml.logging.absl_logging import logging
|
||||
>>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.")
|
||||
|
||||
"""
|
||||
import logging as py_logging
|
||||
import sys
|
||||
|
||||
from absl import logging as logging
|
||||
|
||||
|
||||
def setup_absl_logging():
|
||||
"""Make sure that absl logging pushes to stdout rather than stderr."""
|
||||
logging.get_absl_handler().python_handler.stream = sys.stdout
|
||||
formatter = py_logging.Formatter(
|
||||
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
|
||||
)
|
||||
logging.get_absl_handler().setFormatter(formatter)
|
||||
logging.set_verbosity(logging.INFO)
|
||||
|
||||
|
||||
setup_absl_logging()
|
21
ml_logging/test_torch_logging.py
Normal file
21
ml_logging/test_torch_logging.py
Normal file
@ -0,0 +1,21 @@
|
||||
import unittest
|
||||
|
||||
from tml.ml_logging.torch_logging import logging
|
||||
|
||||
|
||||
class Testtlogging(unittest.TestCase):
|
||||
def test_warn_once(self):
|
||||
with self.assertLogs(level="INFO") as captured_logs:
|
||||
logging.info("first info")
|
||||
logging.warning("first warning")
|
||||
logging.warning("first warning")
|
||||
logging.info("second info")
|
||||
|
||||
self.assertEqual(
|
||||
captured_logs.output,
|
||||
[
|
||||
"INFO:absl:first info",
|
||||
"WARNING:absl:first warning",
|
||||
"INFO:absl:second info",
|
||||
],
|
||||
)
|
61
ml_logging/torch_logging.py
Normal file
61
ml_logging/torch_logging.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""Overrides absl logger to be rank-aware for distributed pytorch usage.
|
||||
|
||||
>>> # in-bazel import
|
||||
>>> from twitter.ml.logging.torch_logging import logging
|
||||
>>> # out-bazel import
|
||||
>>> from ml.logging.torch_logging import logging
|
||||
>>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.")
|
||||
>>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1)
|
||||
|
||||
"""
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
from tml.ml_logging.absl_logging import logging as logging
|
||||
from absl import logging as absl_logging
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def rank_specific(logger):
|
||||
"""Ensures that we only override a given logger once."""
|
||||
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
|
||||
return logger
|
||||
|
||||
def _if_rank(logger_method, limit: Optional[int] = None):
|
||||
if limit:
|
||||
# If we are limiting redundant logs, wrap logging call with a cache
|
||||
# to not execute if already cached.
|
||||
def _wrap(_call):
|
||||
@functools.lru_cache(limit)
|
||||
def _logger_method(*args, **kwargs):
|
||||
_call(*args, **kwargs)
|
||||
|
||||
return _logger_method
|
||||
|
||||
logger_method = _wrap(logger_method)
|
||||
|
||||
def _inner(msg, *args, rank: int = 0, **kwargs):
|
||||
if not dist.is_initialized():
|
||||
logger_method(msg, *args, **kwargs)
|
||||
elif dist.get_rank() == rank:
|
||||
logger_method(msg, *args, **kwargs)
|
||||
elif rank < 0:
|
||||
logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
|
||||
|
||||
# Register this stack frame with absl logging so that it doesn't trample logging lines.
|
||||
absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
|
||||
|
||||
return _inner
|
||||
|
||||
logger.fatal = _if_rank(logger.fatal)
|
||||
logger.error = _if_rank(logger.error)
|
||||
logger.warning = _if_rank(logger.warning, limit=1)
|
||||
logger.info = _if_rank(logger.info)
|
||||
logger.debug = _if_rank(logger.debug)
|
||||
logger.exception = _if_rank(logger.exception)
|
||||
|
||||
logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
|
||||
|
||||
|
||||
rank_specific(logging)
|
Reference in New Issue
Block a user