Compare commits
2 Commits
deec9a820e
...
590e8b76fe
Author | SHA1 | Date |
---|---|---|
rajveer43 | 590e8b76fe | |
rajveer43 | 92eaaad3ed |
|
@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"
|
||||||
|
|
||||||
|
|
||||||
class Snapshot:
|
class Snapshot:
|
||||||
"""Checkpoints using torchsnapshot.
|
|
||||||
|
|
||||||
Also saves step to be updated by the training loop.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
|
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Initializes a Snapshot object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir (str): Directory where checkpoints will be saved.
|
||||||
|
state (Dict[str, Any]): State dictionary containing checkpoint information.
|
||||||
|
"""
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.state = state
|
self.state = state
|
||||||
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
|
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step(self):
|
def step(self):
|
||||||
|
"""Get the current training step."""
|
||||||
return self.state["extra_state"]["step"]
|
return self.state["extra_state"]["step"]
|
||||||
|
|
||||||
@step.setter
|
@step.setter
|
||||||
def step(self, step: int) -> None:
|
def step(self, step: int) -> None:
|
||||||
|
"""Set the current training step."""
|
||||||
self.state["extra_state"]["step"] = step
|
self.state["extra_state"]["step"] = step
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -41,7 +48,15 @@ class Snapshot:
|
||||||
self.state["extra_state"]["walltime"] = walltime
|
self.state["extra_state"]["walltime"] = walltime
|
||||||
|
|
||||||
def save(self, global_step: int) -> "PendingSnapshot":
|
def save(self, global_step: int) -> "PendingSnapshot":
|
||||||
"""Saves checkpoint with given global_step."""
|
"""
|
||||||
|
Saves a checkpoint with a given global step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_step (int): The global step to associate with the checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PendingSnapshot: A pending snapshot object.
|
||||||
|
"""
|
||||||
path = os.path.join(self.save_dir, str(global_step))
|
path = os.path.join(self.save_dir, str(global_step))
|
||||||
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
|
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -58,7 +73,12 @@ class Snapshot:
|
||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
def restore(self, checkpoint: str) -> None:
|
def restore(self, checkpoint: str) -> None:
|
||||||
"""Restores a given checkpoint."""
|
"""
|
||||||
|
Restores a given checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint (str): Path to the checkpoint to restore.
|
||||||
|
"""
|
||||||
snapshot = torchsnapshot.Snapshot(path=checkpoint)
|
snapshot = torchsnapshot.Snapshot(path=checkpoint)
|
||||||
logging.info(f"Restoring snapshot from {snapshot.path}.")
|
logging.info(f"Restoring snapshot from {snapshot.path}.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -83,12 +103,17 @@ class Snapshot:
|
||||||
global_step: Optional[int] = None,
|
global_step: Optional[int] = None,
|
||||||
missing_ok: bool = False,
|
missing_ok: bool = False,
|
||||||
) -> torchsnapshot.Snapshot:
|
) -> torchsnapshot.Snapshot:
|
||||||
"""Get torch stateless snapshot, without actually loading it.
|
|
||||||
Args:
|
|
||||||
snapshot_path: path to the model snapshot
|
|
||||||
global_step: restores from this checkpoint if specified.
|
|
||||||
missing_ok: if True and checkpoints do not exist, returns without restoration.
|
|
||||||
"""
|
"""
|
||||||
|
Get a torch stateless snapshot, without actually loading it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot_path (str): Path to the model snapshot.
|
||||||
|
global_step (int, optional): Restores from this checkpoint if specified.
|
||||||
|
missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torchsnapshot.Snapshot: A torch snapshot object.
|
||||||
|
"""
|
||||||
path = get_checkpoint(snapshot_path, global_step, missing_ok)
|
path = get_checkpoint(snapshot_path, global_step, missing_ok)
|
||||||
logging.info(f"Loading snapshot from {path}.")
|
logging.info(f"Loading snapshot from {path}.")
|
||||||
return torchsnapshot.Snapshot(path=path)
|
return torchsnapshot.Snapshot(path=path)
|
||||||
|
@ -100,13 +125,14 @@ class Snapshot:
|
||||||
snapshot_emb_name: str,
|
snapshot_emb_name: str,
|
||||||
weight_tensor,
|
weight_tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Loads pretrained embedding from the snapshot to the model.
|
|
||||||
Utilise the partial loading mechanism from torchsnapshot.
|
|
||||||
Args:
|
|
||||||
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
|
|
||||||
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
|
|
||||||
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
|
|
||||||
"""
|
"""
|
||||||
|
Loads pretrained embedding from the snapshot to the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_snapshot (torchsnapshot.Snapshot): Snapshot containing pretrained embeddings (EBC).
|
||||||
|
snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
|
||||||
|
weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
|
||||||
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
manifest = embedding_snapshot.get_manifest()
|
manifest = embedding_snapshot.get_manifest()
|
||||||
for path in manifest.keys():
|
for path in manifest.keys():
|
||||||
|
@ -209,7 +235,22 @@ def get_checkpoint(
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoints(save_dir: str) -> List[str]:
|
def get_checkpoints(save_dir: str) -> List[str]:
|
||||||
"""Gets all checkpoints that have been fully written."""
|
"""
|
||||||
|
Get a list of fully written checkpoints in the specified directory.
|
||||||
|
|
||||||
|
This function retrieves a list of fully written checkpoints in the given directory.
|
||||||
|
Checkpoints that are considered fully written include those that have a
|
||||||
|
corresponding snapshot metadata file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir (str): The directory where checkpoints are stored.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: A list of fully written checkpoint paths.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Checkpoints are sorted by their numeric filenames in ascending order.
|
||||||
|
"""
|
||||||
checkpoints = []
|
checkpoints = []
|
||||||
fs = infer_fs(save_dir)
|
fs = infer_fs(save_dir)
|
||||||
if fs.exists(save_dir):
|
if fs.exists(save_dir):
|
||||||
|
@ -232,6 +273,18 @@ def wait_for_evaluators(
|
||||||
global_step: int,
|
global_step: int,
|
||||||
timeout: int,
|
timeout: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Waits for all evaluators to finish and checks for their completion status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir (str): Directory where checkpoints are saved.
|
||||||
|
partition_names (List[str]): List of partition names to check for completion.
|
||||||
|
global_step (int): The global step for which to wait for evaluators.
|
||||||
|
timeout (int): Maximum time in seconds to wait for evaluators to finish.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: This function returns nothing but logs the progress and results.
|
||||||
|
"""
|
||||||
logging.info("Waiting for all evaluators to finish.")
|
logging.info("Waiting for all evaluators to finish.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
|
@ -8,59 +8,69 @@ import pydantic
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(pydantic.BaseModel):
|
class BaseConfig(pydantic.BaseModel):
|
||||||
"""Base class for all derived config classes.
|
"""Base class for all derived config classes.
|
||||||
|
|
||||||
This class provides some convenient functionality:
|
This class provides convenient functionality and constraints for derived config classes:
|
||||||
- Disallows extra fields when constructing an object. User error
|
|
||||||
should be reduced by exact arguments.
|
|
||||||
- "one_of" fields. A subclass can group optional fields and enforce
|
|
||||||
that only one of the fields be set. For example:
|
|
||||||
|
|
||||||
```
|
- Disallows extra fields when constructing an object. User errors due to extraneous arguments
|
||||||
|
are minimized.
|
||||||
|
- "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
|
||||||
|
be set. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
class ExampleConfig(BaseConfig):
|
class ExampleConfig(BaseConfig):
|
||||||
x: int = Field(None, one_of="group_1")
|
x: int = Field(None, one_of="group_1")
|
||||||
y: int = Field(None, one_of="group_1")
|
y: int = Field(None, one_of="group_1")
|
||||||
|
|
||||||
ExampleConfig(x=1) # ok
|
ExampleConfig(x=1) # OK
|
||||||
ExampleConfig(y=1) # ok
|
ExampleConfig(y=1) # OK
|
||||||
ExampleConfig(x=1, y=1) # throws error
|
ExampleConfig(x=1, y=1) # Raises an error
|
||||||
```
|
```
|
||||||
"""
|
|
||||||
|
|
||||||
class Config:
|
Attributes:
|
||||||
"""Forbids extras."""
|
Config (class): Configuration options for this class, forbidding extra fields.
|
||||||
|
|
||||||
extra = pydantic.Extra.forbid # noqa
|
Methods:
|
||||||
|
_field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
|
||||||
|
_one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
|
||||||
|
_at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
|
||||||
|
pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
|
||||||
|
|
||||||
@classmethod
|
"""
|
||||||
@functools.lru_cache()
|
class Config:
|
||||||
def _field_data_map(cls, field_data_name):
|
"""Configuration options that forbid extra fields."""
|
||||||
"""Create a map of fields with provided the field data."""
|
extra = pydantic.Extra.forbid # noqa
|
||||||
schema = cls.schema()
|
|
||||||
one_of = collections.defaultdict(list)
|
|
||||||
for field, fdata in schema["properties"].items():
|
|
||||||
if field_data_name in fdata:
|
|
||||||
one_of[fdata[field_data_name]].append(field)
|
|
||||||
return one_of
|
|
||||||
|
|
||||||
@pydantic.root_validator
|
@classmethod
|
||||||
def _one_of_check(cls, values):
|
@functools.lru_cache()
|
||||||
"""Validate that all 'one of' fields are appear exactly once."""
|
def _field_data_map(cls, field_data_name):
|
||||||
one_of_map = cls._field_data_map("one_of")
|
"""Create a map of fields with the provided field data."""
|
||||||
for one_of, field_names in one_of_map.items():
|
schema = cls.schema()
|
||||||
if sum([values.get(n, None) is not None for n in field_names]) != 1:
|
one_of = collections.defaultdict(list)
|
||||||
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
|
for field, fdata in schema["properties"].items():
|
||||||
return values
|
if field_data_name in fdata:
|
||||||
|
one_of[fdata[field_data_name]].append(field)
|
||||||
|
return one_of
|
||||||
|
|
||||||
@pydantic.root_validator
|
@pydantic.root_validator
|
||||||
def _at_most_one_of_check(cls, values):
|
def _one_of_check(cls, values):
|
||||||
"""Validate that all 'at_most_one_of' fields appear at most once."""
|
"""Validate that all 'one of' fields appear exactly once."""
|
||||||
at_most_one_of_map = cls._field_data_map("at_most_one_of")
|
one_of_map = cls._field_data_map("one_of")
|
||||||
for one_of, field_names in at_most_one_of_map.items():
|
for one_of, field_names in one_of_map.items():
|
||||||
if sum([values.get(n, None) is not None for n in field_names]) > 1:
|
if sum([values.get(n, None) is not None for n in field_names]) != 1:
|
||||||
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
|
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def pretty_print(self) -> str:
|
@pydantic.root_validator
|
||||||
"""Return a human legible (yaml) representation of the config useful for logging."""
|
def _at_most_one_of_check(cls, values):
|
||||||
return yaml.dump(self.dict())
|
"""Validate that all 'at_most_one_of' fields appear at most once."""
|
||||||
|
at_most_one_of_map = cls._field_data_map("at_most_one_of")
|
||||||
|
for one_of, field_names in at_most_one_of_map.items():
|
||||||
|
if sum([values.get(n, None) is not None for n in field_names]) > 1:
|
||||||
|
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
def pretty_print(self) -> str:
|
||||||
|
"""Return a human-readable (YAML) representation of the config useful for logging."""
|
||||||
|
return yaml.dump(self.dict())
|
||||||
|
|
|
@ -13,6 +13,13 @@ class BaseConfigTest(TestCase):
|
||||||
def test_extra_forbidden(self):
|
def test_extra_forbidden(self):
|
||||||
"""
|
"""
|
||||||
Test that extra fields are forbidden when creating a Config instance.
|
Test that extra fields are forbidden when creating a Config instance.
|
||||||
|
|
||||||
|
This test case checks whether the `BaseConfig` class correctly raises a
|
||||||
|
`pydantic.ValidationError` when extra fields are provided when creating a
|
||||||
|
`Config` instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the test fails.
|
||||||
"""
|
"""
|
||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
x: int
|
x: int
|
||||||
|
@ -24,6 +31,13 @@ class BaseConfigTest(TestCase):
|
||||||
def test_one_of(self):
|
def test_one_of(self):
|
||||||
"""
|
"""
|
||||||
Test the use of the `one_of` attribute for fields in a Config instance.
|
Test the use of the `one_of` attribute for fields in a Config instance.
|
||||||
|
|
||||||
|
This test case checks the behavior of the `one_of` attribute in a `Config`
|
||||||
|
instance. It verifies that the `BaseConfig` validator correctly enforces the
|
||||||
|
specified constraint.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the test fails.
|
||||||
"""
|
"""
|
||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
x: int = pydantic.Field(None, one_of="f")
|
x: int = pydantic.Field(None, one_of="f")
|
||||||
|
@ -39,6 +53,13 @@ class BaseConfigTest(TestCase):
|
||||||
def test_at_most_one_of(self):
|
def test_at_most_one_of(self):
|
||||||
"""
|
"""
|
||||||
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
|
Test the use of the `at_most_one_of` attribute for fields in a Config instance.
|
||||||
|
|
||||||
|
This test case checks the behavior of the `at_most_one_of` attribute in a
|
||||||
|
`Config` instance. It verifies that the `BaseConfig` validator enforces the
|
||||||
|
constraint where at most one of the specified fields can be provided.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the test fails.
|
||||||
"""
|
"""
|
||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
x: int = pydantic.Field(None, at_most_one_of="f")
|
x: int = pydantic.Field(None, at_most_one_of="f")
|
||||||
|
|
|
@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
|
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
|
||||||
"""Recommended method to load a config file (a yaml file) and parse it.
|
|
||||||
|
|
||||||
Because we have a shared filesystem, the recommended route to running jobs is to put modified config
|
|
||||||
files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
|
|
||||||
"""
|
"""
|
||||||
|
Recommended method to load and parse a configuration from a YAML file.
|
||||||
|
|
||||||
|
This function loads a configuration from a YAML file, parses it, and returns an instance of the
|
||||||
|
specified config type.
|
||||||
|
|
||||||
|
Because we have a shared filesystem, the recommended route to running jobs is to put modified config
|
||||||
|
files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_type (Type[BaseConfig]): The Pydantic config class to load.
|
||||||
|
yaml_path (str): The path to the YAML configuration file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseConfig: An instance of the specified config type populated with values from the YAML file.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Suppose you have a YAML file 'my_config.yaml' containing the following:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
x: 42
|
||||||
|
y: "hello"
|
||||||
|
```
|
||||||
|
|
||||||
|
You can load and parse it using this function as follows:
|
||||||
|
|
||||||
|
```python
|
||||||
|
my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This function performs environment variable substitution in the YAML file. It replaces
|
||||||
|
occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
|
||||||
|
values. If an environment variable does not exist, the string is left unchanged.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
def _substitute(s):
|
def _substitute(s):
|
||||||
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
|
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
|
||||||
|
|
|
@ -8,11 +8,27 @@ import pydantic
|
||||||
|
|
||||||
|
|
||||||
class _PointlessConfig(BaseConfig):
|
class _PointlessConfig(BaseConfig):
|
||||||
|
|
||||||
a: int
|
a: int
|
||||||
user: str
|
user: str
|
||||||
|
|
||||||
|
|
||||||
def test_load_config_from_yaml(tmp_path):
|
def test_load_config_from_yaml(tmp_path):
|
||||||
|
"""Test loading a configuration from a YAML file and verifying its values.
|
||||||
|
|
||||||
|
This test function checks the functionality of the `load_config_from_yaml` function by creating
|
||||||
|
a temporary YAML configuration file, loading it, and asserting that the loaded config object
|
||||||
|
has the expected values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: A temporary directory provided by the `pytest` framework.
|
||||||
|
|
||||||
|
Test Steps:
|
||||||
|
1. Create a temporary YAML file containing configuration data.
|
||||||
|
2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
|
||||||
|
3. Assert that the loaded configuration object has the expected values.
|
||||||
|
|
||||||
|
"""
|
||||||
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
|
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
|
||||||
with open(yaml_path, "w") as yaml_file:
|
with open(yaml_path, "w") as yaml_file:
|
||||||
yaml_file.write("""a: 3\nuser: ${USER}\n""")
|
yaml_file.write("""a: 3\nuser: ${USER}\n""")
|
||||||
|
|
|
@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
|
||||||
|
|
||||||
|
|
||||||
def on_kf():
|
def on_kf():
|
||||||
return "SPEC_TYPE" in os.environ
|
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if running on KF, False otherwise.
|
||||||
|
"""
|
||||||
|
return "SPEC_TYPE" in os.environ
|
||||||
|
|
||||||
|
|
||||||
def has_readers():
|
def has_readers():
|
||||||
if on_kf():
|
"""Check if the current task has dataset workers.
|
||||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
|
||||||
return machines_config_env["dataset_worker"] is not None
|
Returns:
|
||||||
return os.environ.get("HAS_READERS", "False") == "True"
|
bool: True if the task has dataset workers, False otherwise.
|
||||||
|
"""
|
||||||
|
if on_kf():
|
||||||
|
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||||
|
return machines_config_env.get("dataset_worker") is not None
|
||||||
|
return os.environ.get("HAS_READERS", "False") == "True"
|
||||||
|
|
||||||
|
|
||||||
def get_task_type():
|
def get_task_type():
|
||||||
if on_kf():
|
"""Get the type of the current task.
|
||||||
return os.environ["SPEC_TYPE"]
|
|
||||||
return os.environ["TASK_TYPE"]
|
Returns:
|
||||||
|
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
|
||||||
|
"""
|
||||||
|
if on_kf():
|
||||||
|
return os.environ["SPEC_TYPE"]
|
||||||
|
return os.environ["TASK_TYPE"]
|
||||||
|
|
||||||
|
|
||||||
def is_chief() -> bool:
|
def is_chief() -> bool:
|
||||||
return get_task_type() == "chief"
|
"""Check if the current task is the 'chief'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the current task is the 'chief', False otherwise.
|
||||||
|
"""
|
||||||
|
return get_task_type() == "chief"
|
||||||
|
|
||||||
|
|
||||||
def is_reader() -> bool:
|
def is_reader() -> bool:
|
||||||
return get_task_type() == "datasetworker"
|
"""Check if the current task is a 'datasetworker'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the current task is a 'datasetworker', False otherwise.
|
||||||
|
"""
|
||||||
|
return get_task_type() == "datasetworker"
|
||||||
|
|
||||||
|
|
||||||
def is_dispatcher() -> bool:
|
def is_dispatcher() -> bool:
|
||||||
return get_task_type() == "datasetdispatcher"
|
"""Check if the current task is a 'datasetdispatcher'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the current task is a 'datasetdispatcher', False otherwise.
|
||||||
|
"""
|
||||||
|
return get_task_type() == "datasetdispatcher"
|
||||||
|
|
||||||
|
|
||||||
def get_task_index():
|
def get_task_index():
|
||||||
if on_kf():
|
"""Get the index of the current task.
|
||||||
pod_name = os.environ["MY_POD_NAME"]
|
|
||||||
return int(pod_name.split("-")[-1])
|
Returns:
|
||||||
else:
|
int: Task index.
|
||||||
raise NotImplementedError
|
Raises:
|
||||||
|
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
|
||||||
|
"""
|
||||||
|
if on_kf():
|
||||||
|
pod_name = os.environ["MY_POD_NAME"]
|
||||||
|
return int(pod_name.split("-")[-1])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def get_reader_port():
|
def get_reader_port():
|
||||||
if on_kf():
|
"""Get the port used by readers.
|
||||||
return KF_DDS_PORT
|
|
||||||
return SLURM_DDS_PORT
|
Returns:
|
||||||
|
int: Reader port.
|
||||||
|
"""
|
||||||
|
if on_kf():
|
||||||
|
return KF_DDS_PORT
|
||||||
|
return SLURM_DDS_PORT
|
||||||
|
|
||||||
|
|
||||||
def get_dds():
|
def get_dds():
|
||||||
if not has_readers():
|
"""Get the Distributed Data Service (DDS) address.
|
||||||
return None
|
|
||||||
dispatcher_address = get_dds_dispatcher_address()
|
Returns:
|
||||||
if dispatcher_address:
|
Optional[str]: DDS address in the format 'grpc://host:port', or None if the job has no readers.
|
||||||
return f"grpc://{dispatcher_address}"
|
Raises:
|
||||||
else:
|
ValueError: If the job does not have DDS.
|
||||||
raise ValueError("Job does not have DDS.")
|
"""
|
||||||
|
if not has_readers():
|
||||||
|
return None
|
||||||
|
dispatcher_address = get_dds_dispatcher_address()
|
||||||
|
if dispatcher_address:
|
||||||
|
return f"grpc://{dispatcher_address}"
|
||||||
|
else:
|
||||||
|
raise ValueError("Job does not have DDS.")
|
||||||
|
|
||||||
|
|
||||||
def get_dds_dispatcher_address():
|
def get_dds_dispatcher_address():
|
||||||
if not has_readers():
|
"""Get the DDS dispatcher address.
|
||||||
return None
|
|
||||||
if on_kf():
|
Returns:
|
||||||
job_name = os.environ["JOB_NAME"]
|
Optional[str]: DDS dispatcher address in the format 'host:port', or None if the job has no readers.
|
||||||
dds_host = f"{job_name}-datasetdispatcher-0"
|
"""
|
||||||
else:
|
if not has_readers():
|
||||||
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
|
return None
|
||||||
return f"{dds_host}:{get_reader_port()}"
|
if on_kf():
|
||||||
|
job_name = os.environ["JOB_NAME"]
|
||||||
|
dds_host = f"{job_name}-datasetdispatcher-0"
|
||||||
|
else:
|
||||||
|
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
|
||||||
|
return f"{dds_host}:{get_reader_port()}"
|
||||||
|
|
||||||
|
|
||||||
def get_dds_worker_address():
|
def get_dds_worker_address():
|
||||||
if not has_readers():
|
"""Get the DDS worker address.
|
||||||
return None
|
|
||||||
if on_kf():
|
Returns:
|
||||||
job_name = os.environ["JOB_NAME"]
|
Optional[str]: DDS worker address in the format 'host:port', or None if the job has no readers.
|
||||||
task_index = get_task_index()
|
"""
|
||||||
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
|
if not has_readers():
|
||||||
else:
|
return None
|
||||||
node = os.environ["SLURMD_NODENAME"]
|
if on_kf():
|
||||||
return f"{node}:{get_reader_port()}"
|
job_name = os.environ["JOB_NAME"]
|
||||||
|
task_index = get_task_index()
|
||||||
|
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
|
||||||
|
else:
|
||||||
|
node = os.environ["SLURMD_NODENAME"]
|
||||||
|
return f"{node}:{get_reader_port()}"
|
||||||
|
|
||||||
|
|
||||||
def get_num_readers():
|
def get_num_readers():
|
||||||
if not has_readers():
|
"""Get the number of dataset workers.
|
||||||
return 0
|
|
||||||
if on_kf():
|
Returns:
|
||||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
int: Number of dataset workers.
|
||||||
return int(machines_config_env["num_dataset_workers"] or 0)
|
"""
|
||||||
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
|
if not has_readers():
|
||||||
|
return 0
|
||||||
|
if on_kf():
|
||||||
|
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||||
|
return int(machines_config_env.get("num_dataset_workers") or 0)
|
||||||
|
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
|
||||||
|
|
||||||
|
|
||||||
def get_flight_server_addresses():
|
def get_flight_server_addresses():
|
||||||
if on_kf():
|
"""Get Flight server addresses for dataset workers.
|
||||||
job_name = os.environ["JOB_NAME"]
|
|
||||||
return [
|
Returns:
|
||||||
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
|
List[str]: List of Flight server addresses in the format 'grpc://host:port'.
|
||||||
for task_index in range(get_num_readers())
|
Raises:
|
||||||
]
|
NotImplementedError: If not running on Kubernetes with Kubeflow (KF) environment.
|
||||||
else:
|
"""
|
||||||
raise NotImplementedError
|
if on_kf():
|
||||||
|
job_name = os.environ["JOB_NAME"]
|
||||||
|
return [
|
||||||
|
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
|
||||||
|
for task_index in range(get_num_readers())
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def get_dds_journaling_dir():
|
def get_dds_journaling_dir():
|
||||||
return os.environ.get("DATASET_JOURNALING_DIR", None)
|
"""Get the DDS journaling directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: DDS journaling directory, or None if DATASET_JOURNALING_DIR is not set.
|
||||||
|
"""
|
||||||
|
return os.environ.get("DATASET_JOURNALING_DIR", None)
|
||||||
|
|
|
@ -14,18 +14,21 @@ def update_mean(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Update the mean according to Welford formula:
|
Update the mean according to the Welford formula.
|
||||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
|
||||||
|
This function updates the mean and the weighted sum of values using the Welford algorithm.
|
||||||
|
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
||||||
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
|
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
|
||||||
Args:
|
|
||||||
current_mean: The value of the current accumulated mean.
|
|
||||||
current_weight_sum: The current weighted sum.
|
|
||||||
value: The new value that needs to be added to get a new mean.
|
|
||||||
weight: The weights for the new value.
|
|
||||||
|
|
||||||
Returns: The updated mean and updated weighted sum.
|
Args:
|
||||||
|
current_mean (torch.Tensor): The value of the current accumulated mean.
|
||||||
|
current_weight_sum (torch.Tensor): The current weighted sum.
|
||||||
|
value (torch.Tensor): The new value that needs to be added to get a new mean.
|
||||||
|
weight (torch.Tensor): The weights for the new value.
|
||||||
|
|
||||||
"""
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
|
||||||
|
"""
|
||||||
weight = torch.broadcast_to(weight, value.shape)
|
weight = torch.broadcast_to(weight, value.shape)
|
||||||
|
|
||||||
# Avoiding (on purpose) in-place operation when using += in case
|
# Avoiding (on purpose) in-place operation when using += in case
|
||||||
|
@ -38,12 +41,15 @@ def update_mean(
|
||||||
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Merge the state from multiple workers.
|
Merge the state from multiple workers.
|
||||||
Args:
|
|
||||||
state: A tensor with the first dimension indicating workers.
|
|
||||||
|
|
||||||
Returns: The accumulated mean from all workers.
|
This function merges the state from multiple workers to compute the accumulated mean.
|
||||||
|
|
||||||
"""
|
Args:
|
||||||
|
state (torch.Tensor): A tensor with the first dimension indicating workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The accumulated mean from all workers.
|
||||||
|
"""
|
||||||
mean, weight_sum = update_mean(
|
mean, weight_sum = update_mean(
|
||||||
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
||||||
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
|
||||||
|
@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
class StableMean(torchmetrics.Metric):
|
class StableMean(torchmetrics.Metric):
|
||||||
"""
|
"""
|
||||||
This implements a numerical stable mean metrics computation using Welford algorithm according to
|
A numerically stable mean metric using the Welford algorithm.
|
||||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
|
||||||
|
This class implements a numerically stable mean metric computation using the Welford algorithm.
|
||||||
|
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
|
||||||
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
|
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
|
||||||
than the maximum float32 as long as the mean is within the limit of float32.
|
than the maximum float32 as long as the mean is within the limit of float32.
|
||||||
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
|
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
|
||||||
"""
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
mean_and_weight_sum (torch.Tensor): A tensor to store the mean and weighted sum.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
|
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
|
||||||
"""
|
"""Update the current mean.
|
||||||
Update the current mean.
|
|
||||||
Args:
|
Args:
|
||||||
value: Value to update the mean with.
|
value (torch.Tensor): Value to update the mean with.
|
||||||
weight: weight to use. Shape should be broadcastable to that of value.
|
weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
|
||||||
"""
|
"""
|
||||||
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
|
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
|
||||||
|
|
||||||
if not isinstance(weight, torch.Tensor):
|
if not isinstance(weight, torch.Tensor):
|
||||||
|
@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute(self) -> torch.Tensor:
|
def compute(self) -> torch.Tensor:
|
||||||
"""
|
"""Compute and return the accumulated mean.
|
||||||
Compute and return the accumulated mean.
|
|
||||||
"""
|
Returns:
|
||||||
|
torch.Tensor: The accumulated mean.
|
||||||
|
"""
|
||||||
return self.mean_and_weight_sum[0]
|
return self.mean_and_weight_sum[0]
|
||||||
|
|
|
@ -29,6 +29,9 @@ def _compute_helper(
|
||||||
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
|
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
|
||||||
we assume that they are correct predictions (i.e. weight = 1) when this is False. Otherwise,
|
we assume that they are correct predictions (i.e. weight = 1) when this is False. Otherwise,
|
||||||
we assume that they are incorrect predictions (i.e. weight = 0).
|
we assume that they are incorrect predictions (i.e. weight = 0).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The computed AUROC
|
||||||
"""
|
"""
|
||||||
dim = 0
|
dim = 0
|
||||||
|
|
||||||
|
@ -52,24 +55,34 @@ def _compute_helper(
|
||||||
|
|
||||||
class AUROCWithMWU(torchmetrics.Metric):
|
class AUROCWithMWU(torchmetrics.Metric):
|
||||||
"""
|
"""
|
||||||
AUROC using Mann-Whitney U-test.
|
AUROC (Area Under the Receiver Operating Characteristic) using Mann-Whitney U-test.
|
||||||
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
|
|
||||||
|
|
||||||
This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return
|
This AUROC implementation is well suited for (non-zero) low-CTR (Click-Through Rate)
|
||||||
the correct AUROC even if the predicted probabilities are all close to 0.
|
scenarios. It returns the correct AUROC even when predicted probabilities are close to 0.
|
||||||
Currently only support binary classification.
|
|
||||||
"""
|
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
|
||||||
|
|
||||||
|
Note: Currently, this implementation only supports binary classification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label_threshold (float): Threshold for classifying labels as positive or negative.
|
||||||
|
Labels strictly above this threshold are considered positive; all others are considered negative.
|
||||||
|
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
|
||||||
|
Otherwise, a warning is logged, and AUROC is computed.
|
||||||
|
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
|
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
Initializes the AUROCWithMWU metric.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label_threshold: Labels strictly above this threshold are considered positive labels,
|
label_threshold (float): Threshold for classifying labels as positive or negative.
|
||||||
otherwise, they are considered negative.
|
Labels strictly above this threshold are considered positive; all others are considered negative.
|
||||||
raise_missing_class: If True, an error will be raise if negative or positive class is missing.
|
raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
|
||||||
Otherwise, we will simply log a warning.
|
Otherwise, a warning is logged, and AUROC is computed.
|
||||||
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
**kwargs: Additional parameters supported by all torchmetrics.Metric.
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add_state("predictions", default=[], dist_reduce_fx="cat")
|
self.add_state("predictions", default=[], dist_reduce_fx="cat")
|
||||||
self.add_state("target", default=[], dist_reduce_fx="cat")
|
self.add_state("target", default=[], dist_reduce_fx="cat")
|
||||||
|
|
Loading…
Reference in New Issue