mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 21:29:24 +01:00
31 lines
646 B
Python
31 lines
646 B
Python
|
import os
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
|
||
|
def maybe_setup_tensorflow():
    """Hide every GPU from TensorFlow when TensorFlow is importable.

    If the ``tensorflow`` package cannot be imported, this is a no-op.
    Otherwise TensorFlow's visible "GPU" device list is set to empty.
    """
    try:
        import tensorflow
    except ImportError:
        # TensorFlow is optional; nothing to configure without it.
        return

    # An empty visible-device list for "GPU" disables TF GPU usage.
    tensorflow.config.set_visible_devices([], "GPU")
|
||
|
|
||
|
|
||
|
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
    """Pick the torch device for this process and ensure a process group exists.

    Args:
        tf_ok: When true, ``maybe_setup_tensorflow()`` is called first.

    Returns:
        ``cuda:<LOCAL_RANK>`` when CUDA is available, otherwise the CPU device.

    Raises:
        KeyError: If CUDA is available but the ``LOCAL_RANK`` environment
            variable is not set.
    """
    if tf_ok:
        maybe_setup_tensorflow()

    if torch.cuda.is_available():
        # The CUDA device index is taken from the LOCAL_RANK environment variable.
        local_rank = os.environ["LOCAL_RANK"]
        device = torch.device(f"cuda:{local_rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        backend = "gloo"

    # Only create the default process group when none has been initialized yet.
    if not dist.is_initialized():
        dist.init_process_group(backend)

    return device
|