mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 13:19:23 +01:00
52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
import os
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
|
|
def maybe_setup_tensorflow():
|
|
"""
|
|
Try to import TensorFlow and disable GPU devices if TensorFlow is available.
|
|
|
|
This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
|
|
|
|
Returns:
|
|
None
|
|
|
|
"""
|
|
try:
|
|
import tensorflow as tf
|
|
except ImportError:
|
|
pass
|
|
else:
|
|
tf.config.set_visible_devices([], "GPU") # disable tf gpu
|
|
|
|
|
|
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
|
|
"""
|
|
Set up the distributed environment and get the appropriate torch device.
|
|
|
|
This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
|
|
|
|
Args:
|
|
tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
|
|
|
|
Returns:
|
|
torch.device: The torch device for the current process.
|
|
|
|
"""
|
|
if tf_ok:
|
|
maybe_setup_tensorflow()
|
|
|
|
device = torch.device("cpu")
|
|
backend = "gloo"
|
|
if torch.cuda.is_available():
|
|
rank = os.environ["LOCAL_RANK"]
|
|
device = torch.device(f"cuda:{rank}")
|
|
backend = "nccl"
|
|
torch.cuda.set_device(device)
|
|
if not torch.distributed.is_initialized():
|
|
dist.init_process_group(backend)
|
|
|
|
return device
|