This commit is contained in:
rajveer43 2023-09-12 22:48:45 +05:30
parent 92eaaad3ed
commit 590e8b76fe
3 changed files with 201 additions and 96 deletions

View File

@@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222


def on_kf():
  """Check if the code is running on Kubernetes with Kubeflow (KF) environment.

  Returns:
    bool: True if running on KF, False otherwise.
  """
  return "SPEC_TYPE" in os.environ


def has_readers():
  """Check if the current task has dataset workers.

  Returns:
    bool: True if the task has dataset workers, False otherwise.
  """
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return machines_config_env.get("dataset_worker") is not None
  return os.environ.get("HAS_READERS", "False") == "True"


def get_task_type():
  """Get the type of the current task.

  Returns:
    str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
  """
  if on_kf():
    return os.environ["SPEC_TYPE"]
  return os.environ["TASK_TYPE"]


def is_chief() -> bool:
  """Check if the current task is the 'chief'.

  Returns:
    bool: True if the current task is the 'chief', False otherwise.
  """
  return get_task_type() == "chief"


def is_reader() -> bool:
  """Check if the current task is a 'datasetworker'.

  Returns:
    bool: True if the current task is a 'datasetworker', False otherwise.
  """
  return get_task_type() == "datasetworker"


def is_dispatcher() -> bool:
  """Check if the current task is a 'datasetdispatcher'.

  Returns:
    bool: True if the current task is a 'datasetdispatcher', False otherwise.
  """
  return get_task_type() == "datasetdispatcher"


def get_task_index():
  """Get the index of the current task.

  Returns:
    int: Task index, parsed from the pod name.

  Raises:
    NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
  """
  if on_kf():
    pod_name = os.environ["MY_POD_NAME"]
    return int(pod_name.split("-")[-1])
  else:
    raise NotImplementedError


def get_reader_port():
  """Get the port used by readers.

  Returns:
    int: Reader port.
  """
  if on_kf():
    return KF_DDS_PORT
  return SLURM_DDS_PORT


def get_dds():
  """Get the Distributed Data Service (DDS) address.

  Returns:
    str: DDS address in the format 'grpc://host:port', or None if the job has no
      dataset workers.

  Raises:
    ValueError: If the job has readers but no dispatcher address, i.e. no DDS.
  """
  if not has_readers():
    return None
  dispatcher_address = get_dds_dispatcher_address()
  if dispatcher_address:
    return f"grpc://{dispatcher_address}"
  else:
    raise ValueError("Job does not have DDS.")


def get_dds_dispatcher_address():
  """Get the DDS dispatcher address.

  Returns:
    str: DDS dispatcher address in the format 'host:port', or None if the job has no
      dataset workers.
  """
  if not has_readers():
    return None
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    dds_host = f"{job_name}-datasetdispatcher-0"
  else:
    dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
  return f"{dds_host}:{get_reader_port()}"


def get_dds_worker_address():
  """Get the DDS worker address.

  Returns:
    str: DDS worker address in the format 'host:port', or None if the job has no
      dataset workers.
  """
  if not has_readers():
    return None
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    task_index = get_task_index()
    return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
  else:
    node = os.environ["SLURMD_NODENAME"]
    return f"{node}:{get_reader_port()}"


def get_num_readers():
  """Get the number of dataset workers.

  Returns:
    int: Number of dataset workers, or 0 if the job has none.
  """
  if not has_readers():
    return 0
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return int(machines_config_env.get("num_dataset_workers") or 0)
  return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))


def get_flight_server_addresses():
  """Get Flight server addresses for dataset workers.

  Returns:
    List[str]: List of Flight server addresses in the format 'grpc://host:port'.

  Raises:
    NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
  """
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    return [
      f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
      for task_index in range(get_num_readers())
    ]
  else:
    raise NotImplementedError


def get_dds_journaling_dir():
  """Get the DDS journaling directory.

  Returns:
    str: DDS journaling directory, or None if journaling is not configured.
  """
  return os.environ.get("DATASET_JOURNALING_DIR", None)

View File

@@ -14,18 +14,21 @@ def update_mean(
  weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Update the mean according to the Welford formula.

  This function updates the mean and the weighted sum of values using the Welford algorithm.
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

  Args:
    current_mean (torch.Tensor): The value of the current accumulated mean.
    current_weight_sum (torch.Tensor): The current weighted sum.
    value (torch.Tensor): The new value that needs to be added to get a new mean.
    weight (torch.Tensor): The weights for the new value.

  Returns:
    Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
  """
  weight = torch.broadcast_to(weight, value.shape)

  # Avoiding (on purpose) in-place operation when using += in case
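The update rule documented in this hunk can be restated compactly. The following is a standalone sketch of one weighted Welford step, not the commit's code, assuming only torch, together with a small numeric check:

import torch


def welford_update(mean, weight_sum, value, weight):
  # One weighted Welford step: fold a batch of values into the running scalar mean.
  weight = torch.broadcast_to(weight, value.shape)
  new_weight_sum = weight_sum + torch.sum(weight)
  new_mean = mean + torch.sum((weight / new_weight_sum) * (value - mean))
  return new_mean, new_weight_sum


mean = torch.tensor(0.0)
weight_sum = torch.tensor(0.0)
for batch in (torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])):
  mean, weight_sum = welford_update(mean, weight_sum, batch, torch.tensor(1.0))
print(mean)  # tensor(2.5000), the same as (1 + 2 + 3 + 4) / 4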
@@ -38,12 +41,15 @@ def update_mean(
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
  """
  Merge the state from multiple workers.

  This function merges the state from multiple workers to compute the accumulated mean.

  Args:
    state (torch.Tensor): A tensor with the first dimension indicating workers.

  Returns:
    torch.Tensor: The accumulated mean from all workers.
  """
  mean, weight_sum = update_mean(
    current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
    current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
class StableMean(torchmetrics.Metric):
  """
  A numerically stable mean metric using the Welford algorithm.

  This class implements a numerically stable mean computation using the Welford algorithm.
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  For example, when using float32 the algorithm will give a valid output even if the "sum"
  is larger than the maximum float32, as long as the mean is within the limit of float32.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

  Args:
    **kwargs: Additional parameters supported by all torchmetrics.Metric.

  Attributes:
    mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
  """

  def __init__(self, **kwargs):
    """
@@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
    )

  def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
    """Update the current mean.

    Args:
      value (torch.Tensor): Value to update the mean with.
      weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable
        to that of value.
    """
    mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
    if not isinstance(weight, torch.Tensor):
@@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
    )

  def compute(self) -> torch.Tensor:
    """Compute and return the accumulated mean.

    Returns:
      torch.Tensor: The accumulated mean.
    """
    return self.mean_and_weight_sum[0]
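A short usage sketch, assuming StableMean is importable from this module (the module path is not shown in the diff):

import torch

# from the_edited_module import StableMean  # hypothetical path; not shown in this diff

metric = StableMean()
metric.update(torch.tensor([1.0, 2.0, 3.0]))
metric.update(torch.tensor([4.0]), weight=2.0)  # weighted update: 4.0 counts twice
print(metric.compute())  # tensor(2.8000) == (1 + 2 + 3 + 4 * 2) / 5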

View File

@@ -29,6 +29,9 @@ def _compute_helper(
    equal_predictions_as_incorrect: For positive & negative labels having identical scores,
      we assume that they are correct predictions (i.e. weight = 1) when this is False.
      Otherwise, we assume that they are incorrect predictions (i.e. weight = 0).

  Returns:
    torch.Tensor: The computed AUROC.
  """

  dim = 0
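For context on the identity behind this helper: AUROC equals the probability that a randomly chosen positive outranks a randomly chosen negative, i.e. U / (n_pos * n_neg) where U is the Mann-Whitney U statistic. A standalone sketch of that identity follows; it is not _compute_helper itself (whose body lies mostly outside this hunk) and it uses the conventional half credit for ties rather than the flag described above:

import torch


def auroc_mwu(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  # Pairwise form of the Mann-Whitney U statistic: count (positive, negative)
  # pairs where the positive outranks the negative; ties get half credit.
  pos = scores[labels == 1]
  neg = scores[labels == 0]
  greater = (pos.unsqueeze(1) > neg.unsqueeze(0)).sum()
  ties = (pos.unsqueeze(1) == neg.unsqueeze(0)).sum()
  return (greater + 0.5 * ties) / (pos.numel() * neg.numel())


scores = torch.tensor([0.01, 0.02, 0.03, 0.02])  # all close to zero, as in low-CTR data
labels = torch.tensor([0, 0, 1, 1])
print(auroc_mwu(scores, labels))  # tensor(0.8750)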
@@ -52,24 +55,34 @@ def _compute_helper(
class AUROCWithMWU(torchmetrics.Metric):
  """
  AUROC (Area Under the Receiver Operating Characteristic curve) using the Mann-Whitney U-test.

  This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
  scenarios. It returns the correct AUROC even when the predicted probabilities are all
  close to 0.

  See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.

  Note: Currently, this implementation only supports binary classification.

  Args:
    label_threshold (float): Threshold for classifying labels as positive or negative.
      Labels strictly above this threshold are considered positive; all others are
      considered negative.
    raise_missing_class (bool): If True, an error is raised when the negative or positive
      class is missing. Otherwise, a warning is logged and the AUROC is still computed.
    **kwargs: Additional parameters supported by all torchmetrics.Metric.
  """

  def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
    """
    Initializes the AUROCWithMWU metric.

    Args:
      label_threshold (float): Threshold for classifying labels as positive or negative.
        Labels strictly above this threshold are considered positive; all others are
        considered negative.
      raise_missing_class (bool): If True, an error is raised when the negative or positive
        class is missing. Otherwise, a warning is logged and the AUROC is still computed.
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(**kwargs)
    self.add_state("predictions", default=[], dist_reduce_fx="cat")
    self.add_state("target", default=[], dist_reduce_fx="cat")