import os

import torch
import torch.distributed as dist


def maybe_setup_tensorflow():
"""
Try to import TensorFlow and disable GPU devices if TensorFlow is available .
This function checks if TensorFlow is installed and , if so , disables GPU devices used by TensorFlow to avoid conflicts with PyTorch .
Returns :
None
"""
    try:
        import tensorflow as tf
    except ImportError:
        pass
    else:
        tf.config.set_visible_devices([], "GPU")  # disable TensorFlow GPU usage
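

# A minimal sanity check, assuming TensorFlow 2.x is installed. The helper name
# below is illustrative only and not part of the original module; it simply
# confirms that maybe_setup_tensorflow() leaves TensorFlow with no visible GPUs.
def _check_tf_gpus_hidden() -> bool:
    """Return True if TensorFlow sees no GPU devices after maybe_setup_tensorflow()."""
    maybe_setup_tensorflow()
    try:
        import tensorflow as tf
    except ImportError:
        return True  # nothing to check when TensorFlow is absent
    return tf.config.get_visible_devices("GPU") == []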


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
"""
Set up the distributed environment and get the appropriate torch device .
This function sets up the distributed environment using PyTorch ' s `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
Args :
tf_ok ( bool , optional ) : Whether to run ` maybe_setup_tensorflow ` to disable TensorFlow GPU devices . Defaults to True .
Returns :
torch . device : The torch device for the current process .
"""
    if tf_ok:
        maybe_setup_tensorflow()

    device = torch.device("cpu")
    backend = "gloo"
    if torch.cuda.is_available():
        # LOCAL_RANK is set by distributed launchers such as torchrun.
        rank = os.environ["LOCAL_RANK"]
        device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)

    if not dist.is_initialized():
        dist.init_process_group(backend)

    return device
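

# A minimal launch sketch, assuming the script is started with torchrun (which
# sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT). The tensor
# and all-reduce below are purely illustrative.
if __name__ == "__main__":
    device = setup_and_get_device()
    x = torch.ones(1, device=device)
    dist.all_reduce(x)  # default op is SUM across all ranks
    print(f"rank {dist.get_rank()}: device={device}, summed x={x.item()}")
    # Example launch: torchrun --nproc_per_node=2 <this_file>.py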