mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 14:51:49 +01:00
109 lines
2.4 KiB
Python
109 lines
2.4 KiB
Python
import json
|
|
import os
|
|
from typing import List
|
|
|
|
|
|
# Port appended to DDS dispatcher/worker hosts when running on Kubeflow
# (selected by get_reader_port when on_kf() is true).
KF_DDS_PORT: int = 5050
# Port appended to DDS dispatcher/worker hosts when not on Kubeflow (the
# Slurm code paths); selected by get_reader_port otherwise.
SLURM_DDS_PORT: int = 5051
# Port used in the grpc:// flight-server URLs built for dataset workers.
FLIGHT_SERVER_PORT: int = 2222
|
|
|
|
|
|
def on_kf() -> bool:
    """Report whether this process is running under Kubeflow.

    Detection is keyed only on the presence of the ``SPEC_TYPE`` environment
    variable; its value (even an empty string) is not inspected.
    """
    spec_type = os.getenv("SPEC_TYPE")
    return spec_type is not None
|
|
|
|
|
|
def has_readers() -> bool:
    """Return whether the job was launched with dataset reader workers.

    On Kubeflow the answer comes from the JSON stored in ``MACHINES_CONFIG``:
    readers exist when its ``dataset_worker`` entry is not ``None``. Elsewhere
    it is true only when ``HAS_READERS`` is exactly the string ``"True"``.

    Raises:
      KeyError: on Kubeflow, if ``MACHINES_CONFIG`` is unset or its JSON has
        no ``dataset_worker`` key.
    """
    if not on_kf():
        # Non-Kubeflow jobs opt in explicitly; any other value means "no readers".
        return os.environ.get("HAS_READERS", "False") == "True"
    machines_config = json.loads(os.environ["MACHINES_CONFIG"])
    return machines_config["dataset_worker"] is not None
|
|
|
|
|
|
def get_task_type() -> str:
    """Return this process's task type, read from the environment.

    Uses ``SPEC_TYPE`` on Kubeflow and ``TASK_TYPE`` otherwise; a missing
    variable raises ``KeyError``.
    """
    variable_name = "SPEC_TYPE" if on_kf() else "TASK_TYPE"
    return os.environ[variable_name]
|
|
|
|
|
|
def is_chief() -> bool:
    """Return True when this process's task type is ``"chief"``."""
    task_type = get_task_type()
    return task_type == "chief"
|
|
|
|
|
|
def is_reader() -> bool:
    """Return True when this process's task type is ``"datasetworker"``."""
    task_type = get_task_type()
    return task_type == "datasetworker"
|
|
|
|
|
|
def is_dispatcher() -> bool:
    """Return True when this process's task type is ``"datasetdispatcher"``."""
    task_type = get_task_type()
    return task_type == "datasetdispatcher"
|
|
|
|
|
|
def get_task_index() -> int:
    """Return this replica's index, parsed from the Kubeflow pod name.

    The index is the text after the last ``-`` in ``MY_POD_NAME``
    (e.g. ``"job-datasetworker-3"`` -> ``3``).

    Raises:
      NotImplementedError: when not running on Kubeflow.
      KeyError: if ``MY_POD_NAME`` is unset.
      ValueError: if the final ``-``-separated segment is not an integer.
    """
    if not on_kf():
        raise NotImplementedError
    _, _, index_text = os.environ["MY_POD_NAME"].rpartition("-")
    return int(index_text)
|
|
|
|
|
|
def get_reader_port() -> int:
    """Return the DDS port: ``KF_DDS_PORT`` on Kubeflow, else ``SLURM_DDS_PORT``."""
    return KF_DDS_PORT if on_kf() else SLURM_DDS_PORT
|
|
|
|
|
|
def get_dds():
    """Return the ``grpc://`` URL of the DDS dispatcher, or None without readers.

    Raises:
      ValueError: if readers are configured but the dispatcher address
        returned by ``get_dds_dispatcher_address`` is falsy.
    """
    if not has_readers():
        return None
    address = get_dds_dispatcher_address()
    if not address:
        raise ValueError("Job does not have DDS.")
    return f"grpc://{address}"
|
|
|
|
|
|
def get_dds_dispatcher_address():
    """Return ``host:port`` of the DDS dispatcher, or None when there are no readers.

    On Kubeflow the host is ``<JOB_NAME>-datasetdispatcher-0``. Otherwise the
    host is taken verbatim from ``SLURM_JOB_NODELIST_HET_GROUP_0``.
    NOTE(review): the Slurm path assumes het group 0 is a single node; a
    compressed multi-node list would yield an unusable host — confirm.
    """
    if not has_readers():
        return None
    host = (
        f"{os.environ['JOB_NAME']}-datasetdispatcher-0"
        if on_kf()
        else os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
    )
    return f"{host}:{get_reader_port()}"
|
|
|
|
|
|
def get_dds_worker_address():
    """Return ``host:port`` for this DDS worker, or None when there are no readers.

    On Kubeflow the host is ``<JOB_NAME>-datasetworker-<task index>``;
    elsewhere it is the Slurm node name from ``SLURMD_NODENAME``.
    """
    if not has_readers():
        return None
    if on_kf():
        host = f"{os.environ['JOB_NAME']}-datasetworker-{get_task_index()}"
    else:
        host = os.environ["SLURMD_NODENAME"]
    return f"{host}:{get_reader_port()}"
|
|
|
|
|
|
def _split_slurm_hostlist(hostlist: str) -> List[str]:
    """Split a Slurm hostlist on top-level commas, keeping ``[...]`` groups intact."""
    entries = []
    current = []
    depth = 0
    for char in hostlist:
        if char == "," and depth == 0:
            entries.append("".join(current))
            current = []
            continue
        if char == "[":
            depth += 1
        elif char == "]":
            depth -= 1
        current.append(char)
    entries.append("".join(current))
    # Drop empty pieces so "" (or a stray trailing comma) contributes no hosts.
    return [entry for entry in (piece.strip() for piece in entries) if entry]


def _count_slurm_range(spec: str) -> int:
    """Count indices in one bracket element: ``"7"`` -> 1, ``"01-04"`` -> 4."""
    low, dash, high = spec.partition("-")
    if not dash:
        return 1
    return int(high) - int(low) + 1


def _count_slurm_hosts(hostlist: str) -> int:
    """Count hosts named by a (possibly compressed) Slurm hostlist expression.

    Each top-level entry contributes the product of its bracket-group sizes,
    so ``"a[1-3],b"`` -> 4 and ``"r[1-2]-n[1-4]"`` -> 8. A hyphen outside
    brackets is part of the hostname, not a range.

    Raises:
      ValueError: on malformed input (unclosed bracket, non-integer range).
    """
    total = 0
    for entry in _split_slurm_hostlist(hostlist):
        hosts = 1
        start = entry.find("[")
        while start != -1:
            end = entry.index("]", start)
            specs = (spec.strip() for spec in entry[start + 1 : end].split(","))
            hosts *= sum(_count_slurm_range(spec) for spec in specs if spec)
            start = entry.find("[", end)
        total += hosts
    return total


def get_num_readers() -> int:
    """Return how many dataset reader workers the job has (0 without readers).

    On Kubeflow the count is ``num_dataset_workers`` from the JSON in
    ``MACHINES_CONFIG`` (a null value counts as 0). Otherwise it is the number
    of hosts in ``SLURM_JOB_NODELIST_HET_GROUP_1``. Slurm reports node lists in
    its compressed hostlist form (e.g. ``"reader[1-4]"``), so the list is
    expanded rather than split on commas, which would undercount ranges.
    """
    if not has_readers():
        return 0
    if on_kf():
        machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
        return int(machines_config_env["num_dataset_workers"] or 0)
    return _count_slurm_hosts(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"])
|
|
|
|
|
|
def get_flight_server_addresses() -> List[str]:
    """Return ``grpc://`` flight-server URLs, one per dataset worker (Kubeflow only).

    Worker ``i`` for ``i`` in ``range(get_num_readers())`` is addressed as
    ``grpc://<JOB_NAME>-datasetworker-<i>:<FLIGHT_SERVER_PORT>``.

    Raises:
      NotImplementedError: when not running on Kubeflow.
    """
    if not on_kf():
        raise NotImplementedError
    prefix = f"grpc://{os.environ['JOB_NAME']}-datasetworker-"
    return [f"{prefix}{index}:{FLIGHT_SERVER_PORT}" for index in range(get_num_readers())]
|
|
|
|
|
|
def get_dds_journaling_dir():
    """Return ``DATASET_JOURNALING_DIR`` from the environment, or None if unset."""
    return os.getenv("DATASET_JOURNALING_DIR")
|