the-algorithm-ml/common/device.py
rajveer43 deec9a820e new
2023-09-11 21:31:42 +05:30

52 lines
1.4 KiB
Python

import os
import torch
import torch.distributed as dist
def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available.
This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
Returns:
None
"""
try:
import tensorflow as tf
except ImportError:
pass
else:
tf.config.set_visible_devices([], "GPU") # disable tf gpu
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device.
This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
Args:
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
Returns:
torch.device: The torch device for the current process.
"""
if tf_ok:
maybe_setup_tensorflow()
device = torch.device("cpu")
backend = "gloo"
if torch.cuda.is_available():
rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
if not torch.distributed.is_initialized():
dist.init_process_group(backend)
return device