the-algorithm-ml/common/testing_utils.py
wiseaidev 2cc1abedd7 add & config mypy on common package
Signed-off-by: wiseaidev <business@wiseai.dev>
2023-04-01 14:38:59 +03:00

39 lines
660 B
Python

import datetime
import os
from contextlib import (
contextmanager,
)
from unittest.mock import (
patch,
)
import torch.distributed as dist
from tml.ml_logging.torch_logging import (
logging,
)
MOCK_ENV = {
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"LOCAL_WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29501",
"RANK": "0",
}
@contextmanager
def mock_pg():
with patch.dict(os.environ, MOCK_ENV):
try:
dist.init_process_group(
backend="gloo",
timeout=datetime.timedelta(1),
)
yield
except:
dist.destroy_process_group()
raise
finally:
dist.destroy_process_group()