mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-09 22:39:22 +01:00
mmm
This commit is contained in:
parent
92eaaad3ed
commit
590e8b76fe
@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
|
||||
|
||||
|
||||
def on_kf():
|
||||
return "SPEC_TYPE" in os.environ
|
||||
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
|
||||
|
||||
Returns:
|
||||
bool: True if running on KF, False otherwise.
|
||||
"""
|
||||
return "SPEC_TYPE" in os.environ
|
||||
|
||||
|
||||
def has_readers():
|
||||
if on_kf():
|
||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||
return machines_config_env["dataset_worker"] is not None
|
||||
return os.environ.get("HAS_READERS", "False") == "True"
|
||||
"""Check if the current task has dataset workers.
|
||||
|
||||
Returns:
|
||||
bool: True if the task has dataset workers, False otherwise.
|
||||
"""
|
||||
if on_kf():
|
||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||
return machines_config_env.get("dataset_worker") is not None
|
||||
return os.environ.get("HAS_READERS", "False") == "True"
|
||||
|
||||
|
||||
def get_task_type():
|
||||
if on_kf():
|
||||
return os.environ["SPEC_TYPE"]
|
||||
return os.environ["TASK_TYPE"]
|
||||
"""Get the type of the current task.
|
||||
|
||||
Returns:
|
||||
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
|
||||
"""
|
||||
if on_kf():
|
||||
return os.environ["SPEC_TYPE"]
|
||||
return os.environ["TASK_TYPE"]
|
||||
|
||||
|
||||
def is_chief() -> bool:
|
||||
return get_task_type() == "chief"
|
||||
"""Check if the current task is the 'chief'.
|
||||
|
||||
Returns:
|
||||
bool: True if the current task is the 'chief', False otherwise.
|
||||
"""
|
||||
return get_task_type() == "chief"
|
||||
|
||||
|
||||
def is_reader() -> bool:
|
||||
return get_task_type() == "datasetworker"
|
||||
"""Check if the current task is a 'datasetworker'.
|
||||
|
||||
Returns:
|
||||
bool: True if the current task is a 'datasetworker', False otherwise.
|
||||
"""
|
||||
return get_task_type() == "datasetworker"
|
||||
|
||||
|
||||
def is_dispatcher() -> bool:
|
||||
return get_task_type() == "datasetdispatcher"
|
||||
"""Check if the current task is a 'datasetdispatcher'.
|
||||
|
||||
Returns:
|
||||
bool: True if the current task is a 'datasetdispatcher', False otherwise.
|
||||
"""
|
||||
return get_task_type() == "datasetdispatcher"
|
||||
|
||||
|
||||
def get_task_index():
|
||||
if on_kf():
|
||||
pod_name = os.environ["MY_POD_NAME"]
|
||||
return int(pod_name.split("-")[-1])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
"""Get the index of the current task.
|
||||
|
||||
Returns:
|
||||
int: Task index.
|
||||
Raises:
|
||||
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
|
||||
"""
|
||||
if on_kf():
|
||||
pod_name = os.environ["MY_POD_NAME"]
|
||||
return int(pod_name.split("-")[-1])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_reader_port():
|
||||
if on_kf():
|
||||
return KF_DDS_PORT
|
||||
return SLURM_DDS_PORT
|
||||
"""Get the port used by readers.
|
||||
|
||||
Returns:
|
||||
int: Reader port.
|
||||
"""
|
||||
if on_kf():
|
||||
return KF_DDS_PORT
|
||||
return SLURM_DDS_PORT
|
||||
|
||||
|
||||
def get_dds():
|
||||
if not has_readers():
|
||||
return None
|
||||
dispatcher_address = get_dds_dispatcher_address()
|
||||
if dispatcher_address:
|
||||
return f"grpc://{dispatcher_address}"
|
||||
else:
|
||||
raise ValueError("Job does not have DDS.")
|
||||
"""Get the Distributed Data Service (DDS) address.
|
||||
|
||||
Returns:
|
||||
str: DDS address in the format 'grpc://host:port'.
|
||||
Raises:
|
||||
ValueError: If the job does not have DDS.
|
||||
"""
|
||||
if not has_readers():
|
||||
return None
|
||||
dispatcher_address = get_dds_dispatcher_address()
|
||||
if dispatcher_address:
|
||||
return f"grpc://{dispatcher_address}"
|
||||
else:
|
||||
raise ValueError("Job does not have DDS.")
|
||||
|
||||
|
||||
def get_dds_dispatcher_address():
|
||||
if not has_readers():
|
||||
return None
|
||||
if on_kf():
|
||||
job_name = os.environ["JOB_NAME"]
|
||||
dds_host = f"{job_name}-datasetdispatcher-0"
|
||||
else:
|
||||
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
|
||||
return f"{dds_host}:{get_reader_port()}"
|
||||
"""Get the DDS dispatcher address.
|
||||
|
||||
Returns:
|
||||
str: DDS dispatcher address in the format 'host:port'.
|
||||
"""
|
||||
if not has_readers():
|
||||
return None
|
||||
if on_kf():
|
||||
job_name = os.environ["JOB_NAME"]
|
||||
dds_host = f"{job_name}-datasetdispatcher-0"
|
||||
else:
|
||||
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
|
||||
return f"{dds_host}:{get_reader_port()}"
|
||||
|
||||
|
||||
def get_dds_worker_address():
|
||||
if not has_readers():
|
||||
return None
|
||||
if on_kf():
|
||||
job_name = os.environ["JOB_NAME"]
|
||||
task_index = get_task_index()
|
||||
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
|
||||
else:
|
||||
node = os.environ["SLURMD_NODENAME"]
|
||||
return f"{node}:{get_reader_port()}"
|
||||
"""Get the DDS worker address.
|
||||
|
||||
Returns:
|
||||
str: DDS worker address in the format 'host:port'.
|
||||
"""
|
||||
if not has_readers():
|
||||
return None
|
||||
if on_kf():
|
||||
job_name = os.environ["JOB_NAME"]
|
||||
task_index = get_task_index()
|
||||
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
|
||||
else:
|
||||
node = os.environ["SLURMD_NODENAME"]
|
||||
return f"{node}:{get_reader_port()}"
|
||||
|
||||
|
||||
def get_num_readers():
|
||||
if not has_readers():
|
||||
return 0
|
||||
if on_kf():
|
||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||
return int(machines_config_env["num_dataset_workers"] or 0)
|
||||
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
|
||||
"""Get the number of dataset workers.
|
||||
|
||||
Returns:
|
||||
int: Number of dataset workers.
|
||||
"""
|
||||
if not has_readers():
|
||||
return 0
|
||||
if on_kf():
|
||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||
return int(machines_config_env.get("num_dataset_workers") or 0)
|
||||
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
|
||||
|
||||
|
||||
def get_flight_server_addresses():
|
||||
if on_kf():
|
||||
job_name = os.environ["JOB_NAME"]
|
||||
return [
|
||||
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
|
||||
for task_index in range(get_num_readers())
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
"""Get Flight server addresses for dataset workers.
|
||||
|
||||
Returns:
|
||||
List[str]: List of Flight server addresses in the format 'grpc://host:port'.
|
||||
Raises:
|
||||
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
|
||||
"""
|
||||
if on_kf():
|
||||
job_name = os.environ["JOB_NAME"]
|
||||
return [
|
||||
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
|
||||
for task_index in range(get_num_readers())
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_dds_journaling_dir():
|
||||
return os.environ.get("DATASET_JOURNALING_DIR", None)
|
||||
"""Get the DDS journaling directory.
|
||||
|
||||
Returns:
|
||||
str: DDS journaling directory.
|
||||
"""
|
||||
return os.environ.get("DATASET_JOURNALING_DIR", None)
|
||||
|
@ -14,18 +14,21 @@ def update_mean(
|
||||
weight: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update the mean according to Welford formula:
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
||||
Update the mean according to the Welford formula.
|
||||
|
||||
This function updates the mean and the weighted sum of values using the Welford algorithm.
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
||||
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
|
||||
Args:
|
||||
current_mean: The value of the current accumulated mean.
|
||||
current_weight_sum: The current weighted sum.
|
||||
value: The new value that needs to be added to get a new mean.
|
||||
weight: The weights for the new value.
|
||||
|
||||
Returns: The updated mean and updated weighted sum.
|
||||
Args:
|
||||
current_mean (torch.Tensor): The value of the current accumulated mean.
|
||||
current_weight_sum (torch.Tensor): The current weighted sum.
|
||||
value (torch.Tensor): The new value that needs to be added to get a new mean.
|
||||
weight (torch.Tensor): The weights for the new value.
|
||||
|
||||
"""
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
|
||||
"""
|
||||
weight = torch.broadcast_to(weight, value.shape)
|
||||
|
||||
# Avoiding (on purpose) in-place operation when using += in case
|
||||
@ -38,12 +41,15 @@ def update_mean(
|
||||
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Merge the state from multiple workers.
|
||||
Args:
|
||||
state: A tensor with the first dimension indicating workers.
|
||||
|
||||
Returns: The accumulated mean from all workers.
|
||||
This function merges the state from multiple workers to compute the accumulated mean.
|
||||
|
||||
"""
|
||||
Args:
|
||||
state (torch.Tensor): A tensor with the first dimension indicating workers.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The accumulated mean from all workers.
|
||||
"""
|
||||
mean, weight_sum = update_mean(
|
||||
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
||||
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
||||
@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
class StableMean(torchmetrics.Metric):
|
||||
"""
|
||||
This implements a numerical stable mean metrics computation using Welford algorithm according to
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
||||
A numerical stable mean metric using the Welford algorithm.
|
||||
|
||||
This class implements a numerical stable mean metrics computation using the Welford algorithm.
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
||||
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
|
||||
than the maximum float32 as far as the mean is within the limit of float32.
|
||||
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
|
||||
"""
|
||||
|
||||
Args:
|
||||
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||
|
||||
Attributes:
|
||||
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
|
||||
)
|
||||
|
||||
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
|
||||
"""
|
||||
Update the current mean.
|
||||
Args:
|
||||
value: Value to update the mean with.
|
||||
weight: weight to use. Shape should be broadcastable to that of value.
|
||||
"""
|
||||
"""Update the current mean.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor): Value to update the mean with.
|
||||
weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
|
||||
"""
|
||||
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
|
||||
|
||||
if not isinstance(weight, torch.Tensor):
|
||||
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
|
||||
)
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
"""
|
||||
Compute and return the accumulated mean.
|
||||
"""
|
||||
"""Compute and return the accumulated mean.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The accumulated mean.
|
||||
"""
|
||||
return self.mean_and_weight_sum[0]
|
||||
|
@ -29,6 +29,9 @@ def _compute_helper(
|
||||
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
|
||||
we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise,
|
||||
we assume that they are correct prediction (i.e weight = 0).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The computed AUROC
|
||||
"""
|
||||
dim = 0
|
||||
|
||||
@ -52,24 +55,34 @@ def _compute_helper(
|
||||
|
||||
class AUROCWithMWU(torchmetrics.Metric):
|
||||
"""
|
||||
AUROC using Mann-Whitney U-test.
|
||||
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
|
||||
AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.
|
||||
|
||||
This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return
|
||||
the correct AUROC even if the predicted probabilities are all close to 0.
|
||||
Currently only support binary classification.
|
||||
"""
|
||||
This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
|
||||
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
|
||||
|
||||
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
|
||||
|
||||
Note: Currently, this implementation only supports binary classification.
|
||||
|
||||
Args:
|
||||
label_threshold (float): Threshold for classifying labels as positive or negative.
|
||||
Labels above this threshold are considered positive, and those below are considered negative.
|
||||
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
|
||||
Otherwise, a warning is logged, and AUROC is computed.
|
||||
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||
"""
|
||||
|
||||
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
|
||||
"""
|
||||
Initializes the AUROCWithMWU metric.
|
||||
|
||||
Args:
|
||||
label_threshold: Labels strictly above this threshold are considered positive labels,
|
||||
otherwise, they are considered negative.
|
||||
raise_missing_class: If True, an error will be raise if negative or positive class is missing.
|
||||
Otherwise, we will simply log a warning.
|
||||
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||
"""
|
||||
Args:
|
||||
label_threshold (float): Threshold for classifying labels as positive or negative.
|
||||
Labels above this threshold are considered positive, and those below are considered negative.
|
||||
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
|
||||
Otherwise, a warning is logged, and AUROC is computed.
|
||||
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.add_state("predictions", default=[], dist_reduce_fx="cat")
|
||||
self.add_state("target", default=[], dist_reduce_fx="cat")
|
||||
|
Loading…
x
Reference in New Issue
Block a user