mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-25 13:21:10 +01:00
181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
"""Dataset service orchestrated by a TFJob
|
|
"""
|
|
from typing import Optional
|
|
import uuid
|
|
|
|
from tml.ml_logging.torch_logging import logging
|
|
import tml.machines.environment as env
|
|
|
|
import packaging.version
|
|
import tensorflow as tf
|
|
|
|
# tensorflow_io is treated as an optional dependency: the module must still
# import when it is missing or fails to load.
# NOTE(review): `tfio` is not referenced elsewhere in the visible code, so it is
# presumably imported for its import-time side effects -- confirm.
try:
    import tensorflow_io as tfio
except Exception:  # deliberately broad (native-extension load errors are not
    # always ImportError), but no longer swallows KeyboardInterrupt/SystemExit
    # the way a bare `except:` does.
    pass
|
|
from tensorflow.python.data.experimental.ops.data_service_ops import (
|
|
_from_dataset_id,
|
|
_register_dataset,
|
|
)
|
|
import torch.distributed as dist
|
|
|
|
|
|
def maybe_start_dataset_service():
    """
    Run this process as a tf.data service dispatcher or worker when applicable.

    Nothing happens unless ``env.has_readers()`` is true. When it is, the installed
    TensorFlow version is validated and then, depending on the role reported by
    ``env``:

    - dispatcher: a ``DispatchServer`` is created and joined;
    - reader: a ``WorkerServer`` is created and joined;
    - any other role: the function simply returns.

    Raises:
        Exception: If readers are configured but TensorFlow is older than 2.5.
    """
    if not env.has_readers():
        return

    installed_version = packaging.version.parse(tf.__version__)
    minimum_version = packaging.version.parse("2.5")
    if installed_version < minimum_version:
        raise Exception(f"maybe_distribute_dataset requires TF >= 2.5; got {tf.__version__}")

    if env.is_dispatcher():
        logging.info(f"env.get_reader_port() = {env.get_reader_port()}")
        logging.info(f"env.get_dds_journaling_dir() = {env.get_dds_journaling_dir()}")
        work_dir = env.get_dds_journaling_dir()
        dispatcher_config = tf.data.experimental.service.DispatcherConfig(
            port=env.get_reader_port(),
            protocol="grpc",
            work_dir=work_dir,
            # Fault tolerance is switched on only when a journaling directory is set.
            fault_tolerant_mode=bool(work_dir),
        )
        dispatcher = tf.data.experimental.service.DispatchServer(dispatcher_config)
        # join() waits on the server; per the TF docs it blocks until shutdown.
        dispatcher.join()
        return

    if env.is_reader():
        logging.info(f"env.get_reader_port() = {env.get_reader_port()}")
        logging.info(f"env.get_dds_dispatcher_address() = {env.get_dds_dispatcher_address()}")
        logging.info(f"env.get_dds_worker_address() = {env.get_dds_worker_address()}")
        worker_config = tf.data.experimental.service.WorkerConfig(
            port=env.get_reader_port(),
            dispatcher_address=env.get_dds_dispatcher_address(),
            worker_address=env.get_dds_worker_address(),
            protocol="grpc",
        )
        worker = tf.data.experimental.service.WorkerServer(worker_config)
        worker.join()
|
|
|
|
|
|
def register_dataset(
    dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
):
    """
    Register a dataset with the dataset service and share the result with all ranks.

    Only rank 0 performs the registration and generates a job name (the first
    8 hex characters of a random UUID). The ``[dataset_id, job_name]`` pair is
    then broadcast from rank 0 with ``dist.broadcast_object_list``. Because that
    broadcast runs unconditionally, every rank has to call this function, not
    just rank 0.

    Args:
        dataset (tf.data.Dataset): The dataset to register (used on rank 0 only).
        dataset_service (str): Passed as ``service`` to ``_register_dataset``.
        compression (Optional[str]): Compression setting forwarded to
            ``_register_dataset`` (default is "AUTO").

    Returns:
        tuple: ``(dataset_id, job_name)`` as broadcast from rank 0, where
        ``dataset_id`` is the ``.numpy()`` value of the registered id.
    """
    # Placeholder slots; non-zero ranks have these overwritten by the broadcast.
    broadcast_payload = [None, None]

    if dist.get_rank() == 0:
        dataset_id = _register_dataset(
            service=dataset_service,
            dataset=dataset,
            compression=compression,
        )
        job_name = uuid.uuid4().hex[:8]
        broadcast_payload = [dataset_id.numpy(), job_name]
        logging.info(f"rank{dist.get_rank()}: Created dds job with {dataset_id.numpy()}, {job_name}")

    dist.broadcast_object_list(broadcast_payload, src=0)
    return tuple(broadcast_payload)
|
|
|
|
|
|
def distribute_from_dataset_id(
    dataset_service: str,
    dataset_id: int,
    job_name: Optional[str],
    compression: Optional[str] = "AUTO",
    prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
    """
    Build a dataset that reads a previously registered dataset from the service.

    The dataset is obtained through ``_from_dataset_id`` using the
    "parallel_epochs" processing mode and no explicit element spec. Unless
    ``prefetch`` is ``None``, a ``prefetch`` stage is appended to the result.

    Args:
        dataset_service (str): Passed as ``service`` to ``_from_dataset_id``.
        dataset_id (int): The id of the registered dataset to consume.
        job_name (Optional[str]): The job name to consume under.
        compression (Optional[str]): Compression setting forwarded to
            ``_from_dataset_id`` (default is "AUTO").
        prefetch (Optional[int]): Buffer size for the prefetch stage; ``None``
            skips prefetching (default is tf.data.experimental.AUTOTUNE).

    Returns:
        tf.data.Dataset: The service-backed dataset, prefetched if requested.
    """
    logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")

    service_dataset = _from_dataset_id(
        processing_mode="parallel_epochs",
        service=dataset_service,
        dataset_id=dataset_id,
        job_name=job_name,
        element_spec=None,
        compression=compression,
    )

    if prefetch is None:
        return service_dataset
    return service_dataset.prefetch(prefetch)
|
|
|
|
|
|
def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
    """
    Route a TensorFlow dataset through the distributed dataset service if readers exist.

    When ``env.has_readers()`` is true:

    - the dataset is registered via ``register_dataset`` (rank 0 registers, and the
      dataset id and job name are broadcast to the other ranks);
    - every rank then consumes that registered dataset through
      ``distribute_from_dataset_id`` using the shared id and job name.

    When there are no readers, the input dataset is returned unchanged.

    Args:
        dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.

    Returns:
        tf.data.Dataset: The service-backed dataset, or ``dataset`` itself when no
        readers are configured.

    Note:
        Sharing a single job across ranks is intended to avoid the out-of-memory
        (OOM) issues that arise when each rank process tries to serve its own job.
    """
    if env.has_readers():
        dataset_service = env.get_dds()

        logging.info(f"using DDS = {dataset_service}")
        dataset_id, job_name = register_dataset(dataset=dataset, dataset_service=dataset_service)
        return distribute_from_dataset_id(
            dataset_service=dataset_service, dataset_id=dataset_id, job_name=job_name
        )

    # No reader processes: hand the dataset back untouched.
    return dataset
|
|
|
|
|
|
# Script entry point: start the dataset service for this process if its role
# (dispatcher or reader) calls for one; otherwise this is a no-op.
if __name__ == "__main__":
    maybe_start_dataset_service()
|