This commit is contained in:
rajveer43 2023-09-12 22:48:45 +05:30
parent 92eaaad3ed
commit 590e8b76fe
3 changed files with 201 additions and 96 deletions

View File

@@ -9,35 +9,72 @@ FLIGHT_SERVER_PORT: int = 2222

def on_kf():
  """Check if the code is running on Kubernetes with Kubeflow (KF) environment.

  Returns:
    bool: True if running on KF, False otherwise.
  """
  return "SPEC_TYPE" in os.environ


def has_readers():
  """Check if the current task has dataset workers.

  Returns:
    bool: True if the task has dataset workers, False otherwise.
  """
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return machines_config_env.get("dataset_worker") is not None
  return os.environ.get("HAS_READERS", "False") == "True"


def get_task_type():
  """Get the type of the current task.

  Returns:
    str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
  """
  if on_kf():
    return os.environ["SPEC_TYPE"]
  return os.environ["TASK_TYPE"]


def is_chief() -> bool:
  """Check if the current task is the 'chief'.

  Returns:
    bool: True if the current task is the 'chief', False otherwise.
  """
  return get_task_type() == "chief"


def is_reader() -> bool:
  """Check if the current task is a 'datasetworker'.

  Returns:
    bool: True if the current task is a 'datasetworker', False otherwise.
  """
  return get_task_type() == "datasetworker"


def is_dispatcher() -> bool:
  """Check if the current task is a 'datasetdispatcher'.

  Returns:
    bool: True if the current task is a 'datasetdispatcher', False otherwise.
  """
  return get_task_type() == "datasetdispatcher"


def get_task_index():
  """Get the index of the current task.

  Returns:
    int: Task index.

  Raises:
    NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
  """
  if on_kf():
    pod_name = os.environ["MY_POD_NAME"]
    return int(pod_name.split("-")[-1])
@@ -46,12 +83,24 @@ def get_task_index():

def get_reader_port():
  """Get the port used by readers.

  Returns:
    int: Reader port.
  """
  if on_kf():
    return KF_DDS_PORT
  return SLURM_DDS_PORT


def get_dds():
  """Get the Distributed Data Service (DDS) address.

  Returns:
    str: DDS address in the format 'grpc://host:port'.

  Raises:
    ValueError: If the job does not have DDS.
  """
  if not has_readers():
    return None
  dispatcher_address = get_dds_dispatcher_address()
@@ -62,6 +111,11 @@ def get_dds():

def get_dds_dispatcher_address():
  """Get the DDS dispatcher address.

  Returns:
    str: DDS dispatcher address in the format 'host:port'.
  """
  if not has_readers():
    return None
  if on_kf():
@@ -73,6 +127,11 @@ def get_dds_dispatcher_address():

def get_dds_worker_address():
  """Get the DDS worker address.

  Returns:
    str: DDS worker address in the format 'host:port'.
  """
  if not has_readers():
    return None
  if on_kf():
@@ -85,15 +144,27 @@ def get_dds_worker_address():

def get_num_readers():
  """Get the number of dataset workers.

  Returns:
    int: Number of dataset workers.
  """
  if not has_readers():
    return 0
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return int(machines_config_env.get("num_dataset_workers") or 0)
  return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))


def get_flight_server_addresses():
  """Get Flight server addresses for dataset workers.

  Returns:
    List[str]: List of Flight server addresses in the format 'grpc://host:port'.

  Raises:
    NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
  """
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    return [
@@ -105,4 +176,9 @@ def get_flight_server_addresses():

def get_dds_journaling_dir():
  """Get the DDS journaling directory.

  Returns:
    str: DDS journaling directory.
  """
  return os.environ.get("DATASET_JOURNALING_DIR", None)
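A minimal usage sketch (not part of this commit) of how these helpers might be wired together at startup. The import path "environment" is a placeholder; the diff does not show the module's real location.

import os

# Hypothetical import path; substitute the module's actual location.
from environment import get_dds, get_num_readers, is_chief, is_dispatcher, is_reader


def describe_task() -> str:
  """Summarize the role of the current task using the helpers above."""
  if is_chief():
    return f"chief (dds={get_dds()}, readers={get_num_readers()})"
  if is_dispatcher():
    return "dataset dispatcher"
  if is_reader():
    return "dataset worker"
  return "unknown task"


if __name__ == "__main__":
  # Simulate the non-Kubeflow code path: SPEC_TYPE is unset, so on_kf() is False.
  os.environ.setdefault("TASK_TYPE", "chief")
  os.environ.setdefault("HAS_READERS", "False")
  print(describe_task())  # chief (dds=None, readers=0)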

View File

@@ -14,17 +14,20 @@ def update_mean(
  weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Update the mean according to the Welford formula.

  This function updates the mean and the weighted sum of values using the Welford algorithm.
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

  Args:
    current_mean (torch.Tensor): The value of the current accumulated mean.
    current_weight_sum (torch.Tensor): The current weighted sum.
    value (torch.Tensor): The new value that needs to be added to get a new mean.
    weight (torch.Tensor): The weights for the new value.

  Returns:
    Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
  """
  weight = torch.broadcast_to(weight, value.shape)
@@ -38,11 +41,14 @@ def update_mean(

def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
  """
  Merge the state from multiple workers.

  This function merges the state from multiple workers to compute the accumulated mean.

  Args:
    state (torch.Tensor): A tensor with the first dimension indicating workers.

  Returns:
    torch.Tensor: The accumulated mean from all workers.
  """
  mean, weight_sum = update_mean(
    current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@@ -55,11 +61,19 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:

class StableMean(torchmetrics.Metric):
  """
  A numerically stable mean metric using the Welford algorithm.

  This class implements a numerically stable mean computation using the Welford algorithm,
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  For example, when using float32 the algorithm will give a valid output even if the "sum" is larger
  than the maximum float32, as long as the mean is within the limit of float32.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

  Args:
    **kwargs: Additional parameters supported by all torchmetrics.Metric.

  Attributes:
    mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
  """

  def __init__(self, **kwargs):
@@ -75,11 +89,11 @@ class StableMean(torchmetrics.Metric):
    )

  def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
    """Update the current mean.

    Args:
      value (torch.Tensor): Value to update the mean with.
      weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
    """
    mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
@@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
    )

  def compute(self) -> torch.Tensor:
    """Compute and return the accumulated mean.

    Returns:
      torch.Tensor: The accumulated mean.
    """
    return self.mean_and_weight_sum[0]
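To make the Welford update concrete, here is a small, self-contained sketch (illustrative only, not taken from this commit) that folds mini-batches into a running mean the same way update_mean does, without ever materializing the raw sum:

import torch


def welford_update(mean, weight_sum, value, weight):
  # Weighted, batched Welford step: fold a batch into the running mean.
  weight = torch.broadcast_to(weight, value.shape)
  new_weight_sum = weight_sum + weight.sum()
  mean = mean + (weight * (value - mean)).sum() / new_weight_sum
  return mean, new_weight_sum


mean = torch.tensor(0.0)
weight_sum = torch.tensor(0.0)
for batch in torch.arange(12.0).chunk(3):  # three mini-batches: 0..3, 4..7, 8..11
  mean, weight_sum = welford_update(mean, weight_sum, batch, torch.tensor(1.0))
print(mean)  # tensor(5.5000), same as torch.arange(12.0).mean()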

View File

@@ -29,6 +29,9 @@ def _compute_helper(
    equal_predictions_as_incorrect: For positive & negative labels having identical scores,
      we assume that they are correct predictions (i.e. weight = 1) when this is False. Otherwise,
      we assume that they are incorrect predictions (i.e. weight = 0).

  Returns:
    torch.Tensor: The computed AUROC.
  """
  dim = 0
@@ -52,22 +55,32 @@

class AUROCWithMWU(torchmetrics.Metric):
  """
  AUROC (Area Under the Receiver Operating Characteristic) using the Mann-Whitney U-test.

  This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
  scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
  See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.

  Note: Currently, this implementation only supports binary classification.

  Args:
    label_threshold (float): Threshold for classifying labels as positive or negative.
      Labels strictly above this threshold are considered positive; all others are considered negative.
    raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
      Otherwise, a warning is logged and AUROC is still computed.
    **kwargs: Additional parameters supported by all torchmetrics.Metric.
  """

  def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
    """
    Initializes the AUROCWithMWU metric.

    Args:
      label_threshold (float): Threshold for classifying labels as positive or negative.
        Labels strictly above this threshold are considered positive; all others are considered negative.
      raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
        Otherwise, a warning is logged and AUROC is still computed.
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(**kwargs)
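For intuition on the Mann-Whitney U connection: AUROC equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting one half. The sketch below (illustrative only; the class itself presumably uses ranks rather than enumerating all pairs) computes that quantity directly and stays exact even when all scores are tiny:

import torch


def auroc_mwu(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  # AUROC via the Mann-Whitney U statistic: fraction of positive/negative
  # pairs where the positive outscores the negative (ties count 0.5).
  pos = scores[labels > 0.5]
  neg = scores[labels <= 0.5]
  diff = pos.unsqueeze(1) - neg.unsqueeze(0)  # all positive/negative pairs
  wins = (diff > 0).float() + 0.5 * (diff == 0).float()
  return wins.mean()


scores = torch.tensor([0.02, 0.03, 0.01, 0.04])  # tiny probabilities, low-CTR style
labels = torch.tensor([0.0, 1.0, 0.0, 1.0])
print(auroc_mwu(scores, labels))  # tensor(1.) -- every positive outranks every negative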