Compare commits

...

3 Commits

Author SHA1 Message Date
Rajveer Rathod 91e0453899
Merge 590e8b76fe into b85210863f 2023-09-12 17:19:27 +00:00
rajveer43 590e8b76fe mmm 2023-09-12 22:48:45 +05:30
rajveer43 92eaaad3ed udpate 2023-09-12 18:12:05 +05:30
8 changed files with 399 additions and 163 deletions

View File

@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
class Snapshot:
"""Checkpoints using torchsnapshot.
Also saves step to be updated by the training loop.
"""
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
"""
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
"""
Initializes a Snapshot object.
Args:
save_dir (str): Directory where checkpoints will be saved.
state (Dict[str, Any]): State dictionary containing checkpoint information.
"""
self.save_dir = save_dir
self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property
def step(self):
"""Get the current training step."""
return self.state["extra_state"]["step"]
@step.setter
def step(self, step: int) -> None:
"""Set the current training step."""
self.state["extra_state"]["step"] = step
@property
@ -41,7 +48,15 @@ class Snapshot:
self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step."""
"""
Saves a checkpoint with a given global step.
Args:
global_step (int): The global step to associate with the checkpoint.
Returns:
PendingSnapshot: A pending snapshot object.
"""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time()
@ -58,7 +73,12 @@ class Snapshot:
return snapshot
def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint."""
"""
Restores a given checkpoint.
Args:
checkpoint (str): Path to the checkpoint to restore.
"""
snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time()
@ -83,12 +103,17 @@ class Snapshot:
global_step: Optional[int] = None,
missing_ok: bool = False,
) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it.
Args:
snapshot_path: path to the model snapshot
global_step: restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
Get a torch stateless snapshot, without actually loading it.
Args:
snapshot_path (str): Path to the model snapshot.
global_step (int, optional): Restores from this checkpoint if specified.
missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
Returns:
torchsnapshot.Snapshot: A torch snapshot object.
"""
path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.")
return torchsnapshot.Snapshot(path=path)
@ -100,13 +125,14 @@ class Snapshot:
snapshot_emb_name: str,
weight_tensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilise partial lodaing meachanism from torchsnapshot.
Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
"""
Loads pretrained embedding from the snapshot to the model.
Args:
embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
"""
start_time = time.time()
manifest = embedding_snapshot.get_manifest()
for path in manifest.keys():
@ -209,7 +235,22 @@ def get_checkpoint(
def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written."""
"""
Get a list of fully written checkpoints in the specified directory.
This function retrieves a list of fully written checkpoints in the given directory.
Checkpoints that are considered fully written include those that have a
corresponding snapshot metadata file.
Args:
save_dir (str): The directory where checkpoints are stored.
Returns:
List[str]: A list of fully written checkpoint paths.
Note:
Checkpoints are sorted by their numeric filenames in ascending order.
"""
checkpoints = []
fs = infer_fs(save_dir)
if fs.exists(save_dir):
@ -232,6 +273,18 @@ def wait_for_evaluators(
global_step: int,
timeout: int,
) -> None:
"""
Waits for all evaluators to finish and checks for their completion status.
Args:
save_dir (str): Directory where checkpoints are saved.
partition_names (List[str]): List of partition names to check for completion.
global_step (int): The global step for which to wait for evaluators.
timeout (int): Maximum time in seconds to wait for evaluators to finish.
Returns:
None: This function returns nothing but logs the progress and results.
"""
logging.info("Waiting for all evaluators to finish.")
start_time = time.time()

View File

@ -8,59 +8,69 @@ import pydantic
class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes.
"""Base class for all derived config classes.
This class provides some convenient functionality:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
This class provides convenient functionality and constraints for derived config classes:
```
- Disallows extra fields when constructing an object. User errors due to extraneous arguments
are minimized.
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
be set. For example:
```python
class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok
ExampleConfig(y=1) # ok
ExampleConfig(x=1, y=1) # throws error
ExampleConfig(x=1) # OK
ExampleConfig(y=1) # OK
ExampleConfig(x=1, y=1) # Raises an error
```
"""
class Config:
"""Forbids extras."""
Attributes:
Config (class): Configuration options for this class, forbidding extra fields.
extra = pydantic.Extra.forbid # noqa
Methods:
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with provided the field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
"""
class Config:
"""Configuration options that forbid extra fields."""
extra = pydantic.Extra.forbid # noqa
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields are appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with the provided field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
def pretty_print(self) -> str:
"""Return a human legible (yaml) representation of the config useful for logging."""
return yaml.dump(self.dict())
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human-readable (YAML) representation of the config useful for logging."""
return yaml.dump(self.dict())

View File

@ -13,6 +13,13 @@ class BaseConfigTest(TestCase):
def test_extra_forbidden(self):
"""
Test that extra fields are forbidden when creating a Config instance.
This test case checks whether the `BaseConfig` class correctly raises a
`pydantic.ValidationError` when extra fields are provided when creating a
`Config` instance.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig):
x: int
@ -24,6 +31,13 @@ class BaseConfigTest(TestCase):
def test_one_of(self):
"""
Test the use of the `one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `one_of` attribute in a `Config`
instance. It verifies that the `pydantic.Field` correctly enforces the
specified constraint.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
@ -39,6 +53,13 @@ class BaseConfigTest(TestCase):
def test_at_most_one_of(self):
"""
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
This test case checks the behavior of the `at_most_one_of` attribute in a
`Config` instance. It verifies that the `pydantic.Field` enforces the
constraint where at most one of the specified fields can be provided.
Raises:
AssertionError: If the test fails.
"""
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")

View File

@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommend method to load a config file (a yaml file) and parse it.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
"""
Recommend method to Load and parse a configuration from a YAML file.
This function loads a configuration from a YAML file, parses it, and returns an instance of the
specified config type.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
Args:
config_type (Type[BaseConfig]): The Pydantic config class to load.
yaml_path (str): The path to the YAML configuration file.
Returns:
BaseConfig: An instance of the specified config type populated with values from the YAML file.
Example:
Suppose you have a YAML file 'my_config.yaml' containing the following:
```yaml
x: 42
y: "hello"
```
You can load and parse it using this function as follows:
```python
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
```
Note:
This function performs environment variable substitution in the YAML file. It replaces
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
values. If an environment variable does not exist, the string is left unchanged.
"""
def _substitute(s):
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())

View File

@ -8,11 +8,27 @@ import pydantic
class _PointlessConfig(BaseConfig):
a: int
user: str
def test_load_config_from_yaml(tmp_path):
"""Test loading a configuration from a YAML file and verifying its values.
This test function checks the functionality of the `load_config_from_yaml` function by creating
a temporary YAML configuration file, loading it, and asserting that the loaded config object
has the expected values.
Args:
tmp_path: A temporary directory provided by the `pytest` framework.
Test Steps:
1. Create a temporary YAML file containing configuration data.
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
3. Assert that the loaded configuration object has the expected values.
"""
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""")

View File

@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
def on_kf():
return "SPEC_TYPE" in os.environ
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
Returns:
bool: True if running on KF, False otherwise.
"""
return "SPEC_TYPE" in os.environ
def has_readers():
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env["dataset_worker"] is not None
return os.environ.get("HAS_READERS", "False") == "True"
"""Check if the current task has dataset workers.
Returns:
bool: True if the task has dataset workers, False otherwise.
"""
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env.get("dataset_worker") is not None
return os.environ.get("HAS_READERS", "False") == "True"
def get_task_type():
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]
"""Get the type of the current task.
Returns:
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
"""
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]
def is_chief() -> bool:
return get_task_type() == "chief"
"""Check if the current task is the 'chief'.
Returns:
bool: True if the current task is the 'chief', False otherwise.
"""
return get_task_type() == "chief"
def is_reader() -> bool:
return get_task_type() == "datasetworker"
"""Check if the current task is a 'datasetworker'.
Returns:
bool: True if the current task is a 'datasetworker', False otherwise.
"""
return get_task_type() == "datasetworker"
def is_dispatcher() -> bool:
return get_task_type() == "datasetdispatcher"
"""Check if the current task is a 'datasetdispatcher'.
Returns:
bool: True if the current task is a 'datasetdispatcher', False otherwise.
"""
return get_task_type() == "datasetdispatcher"
def get_task_index():
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError
"""Get the index of the current task.
Returns:
int: Task index.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError
def get_reader_port():
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT
"""Get the port used by readers.
Returns:
int: Reader port.
"""
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT
def get_dds():
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")
"""Get the Distributed Data Service (DDS) address.
Returns:
str: DDS address in the format 'grpc://host:port'.
Raises:
ValueError: If the job does not have DDS.
"""
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")
def get_dds_dispatcher_address():
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"
"""Get the DDS dispatcher address.
Returns:
str: DDS dispatcher address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"
def get_dds_worker_address():
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"
"""Get the DDS worker address.
Returns:
str: DDS worker address in the format 'host:port'.
"""
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"
def get_num_readers():
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env["num_dataset_workers"] or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
"""Get the number of dataset workers.
Returns:
int: Number of dataset workers.
"""
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env.get("num_dataset_workers") or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
def get_flight_server_addresses():
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError
"""Get Flight server addresses for dataset workers.
Returns:
List[str]: List of Flight server addresses in the format 'grpc://host:port'.
Raises:
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
"""
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError
def get_dds_journaling_dir():
return os.environ.get("DATASET_JOURNALING_DIR", None)
"""Get the DDS journaling directory.
Returns:
str: DDS journaling directory.
"""
return os.environ.get("DATASET_JOURNALING_DIR", None)

View File

@ -14,18 +14,21 @@ def update_mean(
weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the mean according to Welford formula:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
Update the mean according to the Welford formula.
This function updates the mean and the weighted sum of values using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args:
current_mean: The value of the current accumulated mean.
current_weight_sum: The current weighted sum.
value: The new value that needs to be added to get a new mean.
weight: The weights for the new value.
Returns: The updated mean and updated weighted sum.
Args:
current_mean (torch.Tensor): The value of the current accumulated mean.
current_weight_sum (torch.Tensor): The current weighted sum.
value (torch.Tensor): The new value that needs to be added to get a new mean.
weight (torch.Tensor): The weights for the new value.
"""
Returns:
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
"""
weight = torch.broadcast_to(weight, value.shape)
# Avoiding (on purpose) in-place operation when using += in case
@ -38,12 +41,15 @@ def update_mean(
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
"""
Merge the state from multiple workers.
Args:
state: A tensor with the first dimension indicating workers.
Returns: The accumulated mean from all workers.
This function merges the state from multiple workers to compute the accumulated mean.
"""
Args:
state (torch.Tensor): A tensor with the first dimension indicating workers.
Returns:
torch.Tensor: The accumulated mean from all workers.
"""
mean, weight_sum = update_mean(
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
class StableMean(torchmetrics.Metric):
"""
This implements a numerical stable mean metrics computation using Welford algorithm according to
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
A numerical stable mean metric using the Welford algorithm.
This class implements a numerical stable mean metrics computation using the Welford algorithm.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
than the maximum float32 as far as the mean is within the limit of float32.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
"""
Args:
**kwargs: Additional parameters supported by all torchmetrics.Metric.
Attributes:
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
"""
def __init__(self, **kwargs):
"""
@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
)
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
"""
Update the current mean.
Args:
value: Value to update the mean with.
weight: weight to use. Shape should be broadcastable to that of value.
"""
"""Update the current mean.
Args:
value (torch.Tensor): Value to update the mean with.
weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
"""
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
if not isinstance(weight, torch.Tensor):
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
)
def compute(self) -> torch.Tensor:
"""
Compute and return the accumulated mean.
"""
"""Compute and return the accumulated mean.
Returns:
torch.Tensor: The accumulated mean.
"""
return self.mean_and_weight_sum[0]

View File

@ -29,6 +29,9 @@ def _compute_helper(
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise,
we assume that they are correct prediction (i.e weight = 0).
Returns:
torch.Tensor: The computed AUROC
"""
dim = 0
@ -52,24 +55,34 @@ def _compute_helper(
class AUROCWithMWU(torchmetrics.Metric):
"""
AUROC using Mann-Whitney U-test.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.
This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return
the correct AUROC even if the predicted probabilities are all close to 0.
Currently only support binary classification.
"""
This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
Note: Currently, this implementation only supports binary classification.
Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
"""
Initializes the AUROCWithMWU metric.
Args:
label_threshold: Labels strictly above this threshold are considered positive labels,
otherwise, they are considered negative.
raise_missing_class: If True, an error will be raise if negative or positive class is missing.
Otherwise, we will simply log a warning.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
Args:
label_threshold (float): Threshold for classifying labels as positive or negative.
Labels above this threshold are considered positive, and those below are considered negative.
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
Otherwise, a warning is logged, and AUROC is computed.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
self.add_state("predictions", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")