the-algorithm-ml/common/device.py
wiseaidev 2cc1abedd7 add & config mypy on common package
Signed-off-by: wiseaidev <business@wiseai.dev>
2023-04-01 14:38:59 +03:00

31 lines
654 B
Python

import os
import torch
import torch.distributed as dist
def maybe_setup_tensorflow() -> None:
try:
import tensorflow as tf
except ImportError:
pass
else:
tf.config.set_visible_devices([], "GPU") # disable tf gpu
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
if tf_ok:
maybe_setup_tensorflow()
device = torch.device("cpu")
backend = "gloo"
if torch.cuda.is_available():
rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
if not torch.distributed.is_initialized():
dist.init_process_group(backend)
return device