mirror of https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-04 15:55:07 +01:00

Merge 0813989fd9 into b85210863f
commit 63573c6f22

common/batch.py | 121
@@ -12,27 +12,79 @@ from torchrec.streamable import Pipelineable


 class BatchBase(Pipelineable, abc.ABC):
+  """
+  A base class for batches used in pipelines.
+
+  Attributes:
+    None
+  """
+
   @abc.abstractmethod
   def as_dict(self) -> Dict:
+    """
+    Convert the batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+
+    Raises:
+      NotImplementedError: If the method is not implemented in a subclass.
+    """
     raise NotImplementedError

   def to(self, device: torch.device, non_blocking: bool = False):
+    """
+    Move the batch to the specified device.
+
+    Args:
+      device (torch.device): The target device.
+      non_blocking (bool, optional): Whether to use non-blocking transfers. Defaults to False.
+
+    Returns:
+      BatchBase: A new batch on the target device.
+    """
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
     return self.__class__(**args)

   def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
+    """
+    Record a CUDA stream for all tensors in the batch.
+
+    Args:
+      stream (torch.cuda.streams.Stream): The CUDA stream to record.
+
+    Returns:
+      None
+    """
     for feature_value in self.as_dict().values():
       feature_value.record_stream(stream)

   def pin_memory(self):
+    """
+    Pin memory for all tensors in the batch.
+
+    Returns:
+      BatchBase: A new batch with pinned memory.
+    """
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.pin_memory()
     return self.__class__(**args)

   def __repr__(self) -> str:
+    """
+    Generate a string representation of the batch.
+
+    Returns:
+      str: A string representation of the batch.
+    """
     def obj2str(v):
       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
@@ -40,6 +92,16 @@ class BatchBase(Pipelineable, abc.ABC):

   @property
   def batch_size(self) -> int:
+    """
+    Get the batch size from the tensors in the batch.
+
+    Returns:
+      int: The batch size.
+
+    Raises:
+      Exception: If the batch size cannot be determined from the tensors.
+    """
     for tensor in self.as_dict().values():
       if tensor is None:
         continue
@@ -51,11 +113,32 @@ class BatchBase(Pipelineable, abc.ABC):

 @dataclass
 class DataclassBatch(BatchBase):
+  """
+  A batch class that uses dataclasses to define its fields.
+
+  Attributes:
+    None
+  """
+
   @classmethod
   def feature_names(cls):
+    """
+    Get the feature names of the dataclass.
+
+    Returns:
+      List[str]: A list of feature names.
+    """
     return list(cls.__dataclass_fields__.keys())

   def as_dict(self):
+    """
+    Convert the dataclass batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+    """
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
@@ -64,7 +147,18 @@ class DataclassBatch(BatchBase):

   @staticmethod
   def from_schema(name: str, schema):
-    """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
+    """
+    Instantiate a custom batch subclass if all columns can be represented as a torch.Tensor.
+
+    Args:
+      name (str): The name of the custom batch class.
+      schema: The schema or structure of the batch.
+
+    Returns:
+      Type[DataclassBatch]: A custom batch class.
+    """
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
@@ -73,6 +167,17 @@ class DataclassBatch(BatchBase):

   @staticmethod
   def from_fields(name: str, fields: dict):
+    """
+    Create a custom batch subclass from a set of fields.
+
+    Args:
+      name (str): The name of the custom batch class.
+      fields (dict): A dictionary specifying the fields and their types.
+
+    Returns:
+      Type[DataclassBatch]: A custom batch class.
+    """
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@@ -81,5 +186,19 @@ class DataclassBatch(BatchBase):


 class DictionaryBatch(BatchBase, dict):
+  """
+  A batch class that represents data as a dictionary.
+
+  Attributes:
+    None
+  """
+
   def as_dict(self) -> Dict:
+    """
+    Convert the dictionary batch into a dictionary representation.
+
+    Returns:
+      Dict: A dictionary representation of the batch.
+    """
     return self
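Editor's note (not part of the diff): a minimal sketch of how the batch classes documented above are typically used. The import path and field names are assumptions for illustration.

```python
import torch

from tml.common.batch import DataclassBatch  # import path assumed

# Build a dataclass-based batch type with two tensor fields (hypothetical names).
ExampleBatch = DataclassBatch.from_fields(
  "ExampleBatch",
  {"continuous_features": torch.Tensor, "labels": torch.Tensor},
)

batch = ExampleBatch(
  continuous_features=torch.randn(8, 4),
  labels=torch.zeros(8, 1),
)

print(batch.as_dict().keys())  # dict view used by to() / pin_memory()
print(batch.batch_size)        # inferred from the tensors, per the docstring above
cpu_batch = batch.to(torch.device("cpu"), non_blocking=False)
```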
@@ -13,23 +13,30 @@ GCS_PREFIX = "gs://"


 class Snapshot:
-  """Checkpoints using torchsnapshot.
-
-  Also saves step to be updated by the training loop.
-
-  """
+  """
+  Checkpoints using torchsnapshot. Also saves step to be updated by the training loop.
+  """

   def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
+    """
+    Initializes a Snapshot object.
+
+    Args:
+      save_dir (str): Directory where checkpoints will be saved.
+      state (Dict[str, Any]): State dictionary containing checkpoint information.
+    """
     self.save_dir = save_dir
     self.state = state
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

   @property
   def step(self):
+    """Get the current training step."""
     return self.state["extra_state"]["step"]

   @step.setter
   def step(self, step: int) -> None:
+    """Set the current training step."""
     self.state["extra_state"]["step"] = step

   @property
@@ -41,7 +48,15 @@ class Snapshot:
     self.state["extra_state"]["walltime"] = walltime

   def save(self, global_step: int) -> "PendingSnapshot":
-    """Saves checkpoint with given global_step."""
+    """
+    Saves a checkpoint with a given global step.
+
+    Args:
+      global_step (int): The global step to associate with the checkpoint.
+
+    Returns:
+      PendingSnapshot: A pending snapshot object.
+    """
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
     start_time = time.time()
@@ -58,7 +73,12 @@ class Snapshot:
     return snapshot

   def restore(self, checkpoint: str) -> None:
-    """Restores a given checkpoint."""
+    """
+    Restores a given checkpoint.
+
+    Args:
+      checkpoint (str): Path to the checkpoint to restore.
+    """
     snapshot = torchsnapshot.Snapshot(path=checkpoint)
     logging.info(f"Restoring snapshot from {snapshot.path}.")
     start_time = time.time()
@@ -83,12 +103,17 @@ class Snapshot:
     global_step: Optional[int] = None,
     missing_ok: bool = False,
   ) -> torchsnapshot.Snapshot:
-    """Get torch stateless snapshot, without actually loading it.
-
-    Args:
-      snapshot_path: path to the model snapshot
-      global_step: restores from this checkpoint if specified.
-      missing_ok: if True and checkpoints do not exist, returns without restoration.
-    """
+    """
+    Get a torch stateless snapshot, without actually loading it.
+
+    Args:
+      snapshot_path (str): Path to the model snapshot.
+      global_step (int, optional): Restores from this checkpoint if specified.
+      missing_ok (bool): If True and checkpoints do not exist, returns without restoration.
+
+    Returns:
+      torchsnapshot.Snapshot: A torch snapshot object.
+    """
     path = get_checkpoint(snapshot_path, global_step, missing_ok)
     logging.info(f"Loading snapshot from {path}.")
     return torchsnapshot.Snapshot(path=path)
@@ -100,13 +125,14 @@ class Snapshot:
     snapshot_emb_name: str,
     weight_tensor,
   ) -> None:
-    """Loads pretrained embedding from the snapshot to the model.
-    Utilise partial lodaing meachanism from torchsnapshot.
-    Args:
-      embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
-      snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
-      weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
-    """
+    """
+    Loads pretrained embedding from the snapshot to the model.
+
+    Args:
+      embedding_snapshot (torchsnapshot.Snapshot): Path to the snapshot containing pretrained embeddings (EBC).
+      snapshot_emb_name (str): Name of the layer in the snapshot model containing the EBC.
+      weight_tensor: Embeddings tensor of the current model where the embeddings will be loaded.
+    """
     start_time = time.time()
     manifest = embedding_snapshot.get_manifest()
     for path in manifest.keys():
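Editor's note (not part of the diff): a rough usage sketch for the Snapshot API documented above. The import path, directory, and state keys are assumptions.

```python
import torch

from tml.common.checkpointing.snapshot import Snapshot  # import path assumed

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The caller chooses the state keys; "extra_state" (step/walltime) is added by Snapshot itself.
snapshot = Snapshot(save_dir="/tmp/snapshots", state={"model": model, "optimizer": optimizer})

snapshot.step = 100                       # typically updated by the training loop
pending = snapshot.save(global_step=100)  # returns a PendingSnapshot for /tmp/snapshots/100

# Later (e.g. on restart), restore model and optimizer state from that checkpoint.
snapshot.restore("/tmp/snapshots/100")
```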
@@ -209,7 +235,22 @@ def get_checkpoint(


 def get_checkpoints(save_dir: str) -> List[str]:
-  """Gets all checkpoints that have been fully written."""
+  """
+  Get a list of fully written checkpoints in the specified directory.
+
+  This function retrieves a list of fully written checkpoints in the given directory.
+  Checkpoints that are considered fully written include those that have a
+  corresponding snapshot metadata file.
+
+  Args:
+    save_dir (str): The directory where checkpoints are stored.
+
+  Returns:
+    List[str]: A list of fully written checkpoint paths.
+
+  Note:
+    Checkpoints are sorted by their numeric filenames in ascending order.
+  """
   checkpoints = []
   fs = infer_fs(save_dir)
   if fs.exists(save_dir):
@@ -232,6 +273,18 @@ def wait_for_evaluators(
   global_step: int,
   timeout: int,
 ) -> None:
+  """
+  Waits for all evaluators to finish and checks for their completion status.
+
+  Args:
+    save_dir (str): Directory where checkpoints are saved.
+    partition_names (List[str]): List of partition names to check for completion.
+    global_step (int): The global step for which to wait for evaluators.
+    timeout (int): Maximum time in seconds to wait for evaluators to finish.
+
+  Returns:
+    None: This function returns nothing but logs the progress and results.
+  """
   logging.info("Waiting for all evaluators to finish.")
   start_time = time.time()

@@ -5,6 +5,15 @@ import torch.distributed as dist


 def maybe_setup_tensorflow():
+  """
+  Try to import TensorFlow and disable GPU devices if TensorFlow is available.
+
+  This function checks if TensorFlow is installed and, if so, disables GPU devices used by TensorFlow to avoid conflicts with PyTorch.
+
+  Returns:
+    None
+  """
   try:
     import tensorflow as tf
   except ImportError:
@@ -14,6 +23,18 @@ def maybe_setup_tensorflow():


 def setup_and_get_device(tf_ok: bool = True) -> torch.device:
+  """
+  Set up the distributed environment and get the appropriate torch device.
+
+  This function sets up the distributed environment using PyTorch's `dist.init_process_group` and retrieves the appropriate torch device based on GPU availability and local rank.
+
+  Args:
+    tf_ok (bool, optional): Whether to run `maybe_setup_tensorflow` to disable TensorFlow GPU devices. Defaults to True.
+
+  Returns:
+    torch.device: The torch device for the current process.
+  """
   if tf_ok:
     maybe_setup_tensorflow()

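Editor's note (not part of the diff): a short sketch of calling the device helper at the top of a distributed training entrypoint. The import path is assumed, and the call expects the usual torch.distributed environment variables to be present.

```python
import torch

from tml.common.device import setup_and_get_device  # import path assumed

# Silences TensorFlow GPU usage first (tf_ok=True), then initializes the process
# group and returns the device for this process (a CUDA device when available, else CPU).
device = setup_and_get_device(tf_ok=True)
model = torch.nn.Linear(4, 2).to(device)
```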
@@ -8,6 +8,18 @@ LOCAL_FS = LocalFileSystem()


 def infer_fs(path: str):
+  """
+  Infer the file system (fs) type based on the given path.
+
+  Args:
+    path (str): The file path.
+
+  Returns:
+    str: The inferred file system type ("gs://" for Google Cloud Storage, "hdfs://" for Hadoop Distributed File System, or "local" for local file system).
+
+  Raises:
+    NotImplementedError: If the path indicates Hadoop Distributed File System (HDFS) which is not yet supported.
+  """
   if path.startswith("gs://"):
     return GCS_FS
   elif path.startswith("hdfs://"):
@@ -18,8 +30,26 @@ def infer_fs(path: str):


 def is_local_fs(fs):
+  """
+  Check if the given file system is the local file system.
+
+  Args:
+    fs (str): The file system type to check.
+
+  Returns:
+    bool: True if the file system is the local file system, False otherwise.
+  """
   return fs == LOCAL_FS


 def is_gcs_fs(fs):
+  """
+  Check if the given file system is Google Cloud Storage (GCS).
+
+  Args:
+    fs (str): The file system type to check.
+
+  Returns:
+    bool: True if the file system is GCS, False otherwise.
+  """
   return fs == GCS_FS
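Editor's note (not part of the diff): a quick illustration of the filesystem helpers above; the module path and bucket name are placeholders.

```python
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs  # import path assumed

local_fs = infer_fs("/tmp/model")            # local filesystem implementation
gcs_fs = infer_fs("gs://some-bucket/model")  # GCS filesystem implementation

assert is_local_fs(local_fs)
assert is_gcs_fs(gcs_fs)
# infer_fs("hdfs://cluster/path") raises NotImplementedError, per the docstring above.
```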
@@ -49,15 +49,25 @@ def log_ebc_norms(
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
-  """Logs the norms of the embedding tables as specified by ebc_keys.
-  As of now, log average norm per rank.
-
-  Args:
-    model_state_dict: model.state_dict()
-    ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
-      i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
-    sample_size: Limits number of rows per rank to compute average on to avoid OOM.
-  """
+  """
+  Logs the norms of the embedding tables as specified by ebc_keys.
+  As of now, log average norm per rank.
+
+  Args:
+    model_state_dict (dict): The state dictionary of the model.
+    ebc_keys (List[str]): List of embedding keys from state_dict to log.
+      Each key must contain the full name, e.g., "model.embeddings.ebc.embedding_bags.meta__user_id.weight".
+    sample_size (int, optional): Limits the number of rows per rank to compute average on
+      to avoid out-of-memory (OOM) errors. Defaults to 4,000,000.
+
+  Returns:
+    Dict[str, torch.Tensor]: A dictionary containing the computed norms of the embedding tables.
+      The keys are in the format "{emb_key}-norm-{idx}".
+
+  Note:
+    This function computes and logs the average norm of embedding tables across ranks.
+    It gathers the norms from all ranks and returns them as a dictionary.
+  """
   norm_logs = dict()
   for emb_key in ebc_keys:
     norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
@@ -11,6 +11,20 @@ import numpy as np


 class LargeEmbeddings(nn.Module):
+  """
+  A module for handling large embeddings.
+
+  Args:
+    large_embeddings_config (LargeEmbeddingsConfig): The configuration for large embeddings.
+
+  Attributes:
+    ebc (EmbeddingBagCollection): An instance of EmbeddingBagCollection for managing embeddings.
+    surgery_cut_point (torch.nn.Identity): A hook for performing post-processing surgery on large embedding models.
+
+  Note:
+    The `surgery_cut_point` attribute is used for post-processing surgery on large embedding models to prepare them for serving.
+  """
   def __init__(
     self,
     large_embeddings_config: LargeEmbeddingsConfig,
@@ -52,6 +66,16 @@ class LargeEmbeddings(nn.Module):
     self,
     sparse_features: KeyedJaggedTensor,
   ) -> KeyedTensor:
+    """
+    Forward pass of the LargeEmbeddings module.
+
+    Args:
+      sparse_features (KeyedJaggedTensor): Sparse input features.
+
+    Returns:
+      KeyedTensor: The output of the forward pass, a KeyedTensor.
+    """
     pooled_embs = self.ebc(sparse_features)

     # a KeyedTensor
@@ -11,6 +11,12 @@ import torch.distributed.run


 def is_distributed_worker():
+  """
+  Checks if the current process is a distributed worker.
+
+  Returns:
+    bool: True if the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) are set, else False.
+  """
   world_size = os.environ.get("WORLD_SIZE", None)
   rank = os.environ.get("RANK", None)
   return world_size is not None and rank is not None
@@ -25,25 +31,36 @@ def maybe_run_training(
   is_chief: Optional[bool] = False,
   **training_kwargs,
 ):
-  """Wrapper function for single node, multi-GPU Pytorch training.
-
-  If the necessary distributed Pytorch environment variables
-  (WORLD_SIZE, RANK) have been set, then this function executes
-  `train_fn(**training_kwargs)`.
-
-  Otherwise, this function calls torchrun and points at the calling module
-  `module_name`. After this call, the necessary environment variables are set
-  and training will commence.
-
-  Args:
-    train_fn: The function that is responsible for training
-    module_name: The name of the module that this function was called from;
-      used to indicate torchrun entrypoint.
-    nproc_per_node: Number of workers per node; supported values.
-    num_nodes: Number of nodes, otherwise inferred from environment.
-    is_chief: If process is running on chief.
-    set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
-  """
+  """
+  Wrapper function for single node, multi-GPU PyTorch training.
+
+  If the necessary distributed PyTorch environment variables (WORLD_SIZE, RANK) have been set, then this function executes
+  `train_fn(**training_kwargs)`.
+
+  Otherwise, this function calls torchrun and points at the calling module
+  `module_name`. After this call, the necessary environment variables are set
+  and training will commence.
+
+  Args:
+    train_fn (callable): The function responsible for training.
+    module_name (str): The name of the module that this function was called from; used to indicate torchrun entrypoint.
+    nproc_per_node (int, optional): Number of workers per node. Defaults to None.
+    num_nodes (int, optional): Number of nodes. Defaults to None.
+    is_chief (bool, optional): If the process is running on the chief node. Defaults to False.
+    set_python_path_in_subprocess (bool, optional): Whether to set PYTHONPATH in the subprocess. Defaults to False.
+    **training_kwargs: Additional keyword arguments to pass to the `train_fn`.
+
+  Note:
+    This function checks if the current process is a distributed worker by examining the environment variables.
+    If it is a worker, it directly calls `train_fn(**training_kwargs)`. Otherwise, it sets up the necessary
+    environment variables and launches the training process using torchrun.
+
+  Example:
+    To run training on a single node with 4 GPUs, you can use:
+    ```
+    maybe_run_training(train_function, __name__, nproc_per_node=4)
+    ```
+  """

   machines = utils.machine_from_env()
   if num_nodes is None:
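Editor's note (not part of the diff): a sketch of the launch pattern described above. The import path and the entrypoint module are assumptions; the extra keyword argument is forwarded via `**training_kwargs`.

```python
# Hypothetical entrypoint module, e.g. my_project/main.py. Launched as
# `python -m my_project.main`, maybe_run_training re-launches the module via
# torchrun; inside each worker (WORLD_SIZE/RANK set) it calls train() directly.
from tml.common.run_training import maybe_run_training  # import path assumed


def train(learning_rate: float = 1e-3):
  print(f"training with lr={learning_rate}")


if __name__ == "__main__":
  maybe_run_training(
    train,
    __name__,
    nproc_per_node=4,    # one worker per GPU
    learning_rate=1e-3,  # forwarded to train() via **training_kwargs
  )
```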
@@ -19,6 +19,23 @@ MOCK_ENV = {

 @contextmanager
 def mock_pg():
+  """
+  A context manager for mocking the distributed process group for testing purposes.
+
+  This context manager temporarily sets environment variables to mock the distributed process group
+  and initializes it using the Gloo backend. It is useful for testing distributed training without
+  actually launching multiple processes.
+
+  Example:
+    ```python
+    with mock_pg():
+      # Your distributed training code here
+    ```
+
+  Note:
+    This context manager should be used within a testing environment to simulate distributed training
+    without actually creating multiple processes.
+  """
   with patch.dict(os.environ, MOCK_ENV):
     try:
       dist.init_process_group(
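Editor's note (not part of the diff): a minimal test sketch using the context manager documented above; the import path is assumed and the all_reduce call is only an illustrative collective.

```python
import torch
import torch.distributed as dist

from tml.common.testing_utils import mock_pg  # import path assumed


def test_all_reduce_under_mock_pg():
  # Inside the context a single-process gloo group is initialized, so
  # collectives run locally and distributed code paths can be exercised.
  with mock_pg():
    t = torch.ones(1)
    dist.all_reduce(t)
    assert t.item() == 1.0
```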
@@ -12,6 +12,15 @@ C = TypeVar("C", bound=base_config.BaseConfig)


 def _read_file(f):
+  """
+  Read the contents of a file using fsspec.
+
+  Args:
+    f: File path or URL.
+
+  Returns:
+    The contents of the file.
+  """
   with fsspec.open(f) as f:
     return f.read()

@@ -21,18 +30,24 @@ def setup_configuration(
   yaml_path: str,
   substitute_env_variable: bool = False,
 ) -> Tuple[C, str]:
-  """Resolves a config at a yaml path.
-
-  Args:
-    config_type: Pydantic config class to load.
-    yaml_path: yaml path of the config file.
-    substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
-      environment variable value whenever possible. If an environment variable doesn't exist,
-      the string is left unchanged.
-
-  Returns:
-    The pydantic config object.
-  """
+  """
+  Load a Pydantic config object from a YAML file and optionally substitute environment variables.
+
+  Args:
+    config_type: Pydantic config class to load.
+    yaml_path: YAML path of the config file.
+    substitute_env_variable: If True, substitute strings in the format $VAR or ${VAR}
+      with their environment variable values whenever possible.
+      If an environment variable doesn't exist, the string is left unchanged.
+
+  Returns:
+    A tuple containing the Pydantic config object and the resolved YAML content.
+
+  Example:
+    ```python
+    config, resolved_yaml = setup_configuration(MyConfig, "config.yaml", substitute_env_variable=True)
+    ```
+  """

   def _substitute(s):
     if substitute_env_variable:
@@ -6,6 +6,33 @@ import pydantic


 class WandbConfig(base_config.BaseConfig):
+  """
+  Configuration for integrating with Weights and Biases (WandB).
+
+  Attributes:
+    host (str): Host of the Weights and Biases instance, passed to login.
+    key_path (str): Path to the key file.
+    name (str): Name of the experiment, passed to init.
+    entity (str): Name of the user/service account, passed to init.
+    project (str): Name of the WandB project, passed to init.
+    tags (List[str]): List of tags, passed to init.
+    notes (str): Notes, passed to init.
+    metadata (Dict[str, Any]): Additional metadata to log.
+
+  Example:
+    ```python
+    wandb_config = WandbConfig(
+      host="https://wandb.example.com",
+      key_path="/path/to/key",
+      name="experiment_1",
+      entity="user123",
+      project="my_project",
+      tags=["experiment", "ml"],
+      notes="This is a test experiment.",
+      metadata={"version": "1.0"}
+    )
+    ```
+  """
   host: str = pydantic.Field(
     "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
     description="Host of Weights and Biases instance, passed to login.",
@@ -8,59 +8,69 @@ import pydantic


 class BaseConfig(pydantic.BaseModel):
   """Base class for all derived config classes.

-  This class provides some convenient functionality:
-    - Disallows extra fields when constructing an object. User error
-      should be reduced by exact arguments.
-    - "one_of" fields. A subclass can group optional fields and enforce
-      that only one of the fields be set. For example:
-
-    ```
+  This class provides convenient functionality and constraints for derived config classes:
+
+  - Disallows extra fields when constructing an object. User errors due to extraneous arguments
+    are minimized.
+  - "one_of" fields: Subclasses can group optional fields and enforce that only one of the fields
+    be set. For example:
+
+    ```python
     class ExampleConfig(BaseConfig):
       x: int = Field(None, one_of="group_1")
       y: int = Field(None, one_of="group_1")

-    ExampleConfig(x=1) # ok
-    ExampleConfig(y=1) # ok
-    ExampleConfig(x=1, y=1) # throws error
+    ExampleConfig(x=1) # OK
+    ExampleConfig(y=1) # OK
+    ExampleConfig(x=1, y=1) # Raises an error
     ```
+
+  Attributes:
+    Config (class): Configuration options for this class, forbidding extra fields.
+
+  Methods:
+    _field_data_map(cls, field_data_name): Create a map of fields with the provided field data.
+    _one_of_check(cls, values): Validate that all 'one of' fields appear exactly once.
+    _at_most_one_of_check(cls, values): Validate that all 'at_most_one_of' fields appear at most once.
+    pretty_print(self): Return a human-readable (YAML) representation of the config useful for logging.
   """

   class Config:
-    """Forbids extras."""
+    """Configuration options that forbid extra fields."""

     extra = pydantic.Extra.forbid  # noqa

   @classmethod
   @functools.lru_cache()
   def _field_data_map(cls, field_data_name):
-    """Create a map of fields with provided the field data."""
+    """Create a map of fields with the provided field data."""
     schema = cls.schema()
     one_of = collections.defaultdict(list)
     for field, fdata in schema["properties"].items():
       if field_data_name in fdata:
         one_of[fdata[field_data_name]].append(field)
     return one_of

   @pydantic.root_validator
   def _one_of_check(cls, values):
-    """Validate that all 'one of' fields are appear exactly once."""
+    """Validate that all 'one of' fields appear exactly once."""
     one_of_map = cls._field_data_map("one_of")
     for one_of, field_names in one_of_map.items():
       if sum([values.get(n, None) is not None for n in field_names]) != 1:
         raise ValueError(f"Exactly one of {','.join(field_names)} required.")
     return values

   @pydantic.root_validator
   def _at_most_one_of_check(cls, values):
     """Validate that all 'at_most_one_of' fields appear at most once."""
     at_most_one_of_map = cls._field_data_map("at_most_one_of")
     for one_of, field_names in at_most_one_of_map.items():
       if sum([values.get(n, None) is not None for n in field_names]) > 1:
         raise ValueError(f"At most one of {','.join(field_names)} can be set.")
     return values

   def pretty_print(self) -> str:
-    """Return a human legible (yaml) representation of the config useful for logging."""
+    """Return a human-readable (YAML) representation of the config useful for logging."""
     return yaml.dump(self.dict())
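Editor's note (not part of the diff): the class docstring above shows `one_of`; here is a small companion sketch of the `at_most_one_of` behaviour enforced by `_at_most_one_of_check`. The class and field names are illustrative.

```python
import pydantic

from tml.core.config.base_config import BaseConfig


class SamplingConfig(BaseConfig):
  # At most one of these fields may be set; leaving both unset is allowed.
  rate: float = pydantic.Field(None, at_most_one_of="sampling")
  count: int = pydantic.Field(None, at_most_one_of="sampling")


SamplingConfig()          # OK: neither field set
SamplingConfig(rate=0.1)  # OK: exactly one field set
try:
  SamplingConfig(rate=0.1, count=10)  # both set -> ValidationError
except pydantic.ValidationError as err:
  print(err)
```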
@@ -6,33 +6,67 @@ import pydantic


 class BaseConfigTest(TestCase):
+  """
+  Unit tests for the BaseConfig class.
+  """
+
   def test_extra_forbidden(self):
+    """
+    Test that extra fields are forbidden when creating a Config instance.
+
+    This test case checks whether the `BaseConfig` class correctly raises a
+    `pydantic.ValidationError` when extra fields are provided when creating a
+    `Config` instance.
+
+    Raises:
+      AssertionError: If the test fails.
+    """
     class Config(BaseConfig):
       x: int

     Config(x=1)
     with self.assertRaises(pydantic.ValidationError):
       Config(x=1, y=2)

   def test_one_of(self):
+    """
+    Test the use of the `one_of` attribute for fields in a Config instance.
+
+    This test case checks the behavior of the `one_of` attribute in a `Config`
+    instance. It verifies that the `pydantic.Field` correctly enforces the
+    specified constraint.
+
+    Raises:
+      AssertionError: If the test fails.
+    """
     class Config(BaseConfig):
       x: int = pydantic.Field(None, one_of="f")
       y: int = pydantic.Field(None, one_of="f")

     with self.assertRaises(pydantic.ValidationError):
       Config()
     Config(x=1)
     Config(y=1)
     with self.assertRaises(pydantic.ValidationError):
       Config(x=1, y=3)

   def test_at_most_one_of(self):
+    """
+    Test the use of the `at_most_one_of` attribute for fields in a Config instance.
+
+    This test case checks the behavior of the `at_most_one_of` attribute in a
+    `Config` instance. It verifies that the `pydantic.Field` enforces the
+    constraint where at most one of the specified fields can be provided.
+
+    Raises:
+      AssertionError: If the test fails.
+    """
     class Config(BaseConfig):
       x: int = pydantic.Field(None, at_most_one_of="f")
       y: str = pydantic.Field(None, at_most_one_of="f")

     Config()
     Config(x=1)
     Config(y="a")
     with self.assertRaises(pydantic.ValidationError):
       Config(x=1, y="a")
@@ -8,11 +8,42 @@ from tml.core.config.base_config import BaseConfig


 def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
-  """Recommend method to load a config file (a yaml file) and parse it.
-
-  Because we have a shared filesystem the recommended route to running jobs it put modified config
-  files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
-  """
+  """
+  Recommend method to Load and parse a configuration from a YAML file.
+
+  This function loads a configuration from a YAML file, parses it, and returns an instance of the
+  specified config type.
+
+  Because we have a shared filesystem the recommended route to running jobs it put modified config
+  files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
+
+  Args:
+    config_type (Type[BaseConfig]): The Pydantic config class to load.
+    yaml_path (str): The path to the YAML configuration file.
+
+  Returns:
+    BaseConfig: An instance of the specified config type populated with values from the YAML file.
+
+  Example:
+    Suppose you have a YAML file 'my_config.yaml' containing the following:
+
+    ```yaml
+    x: 42
+    y: "hello"
+    ```
+
+    You can load and parse it using this function as follows:
+
+    ```python
+    my_config = load_config_from_yaml(MyConfigClass, 'my_config.yaml')
+    ```
+
+  Note:
+    This function performs environment variable substitution in the YAML file. It replaces
+    occurrences of the format '$VAR' or '${VAR}' with their corresponding environment variable
+    values. If an environment variable does not exist, the string is left unchanged.
+  """

   def _substitute(s):
     return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
@@ -8,11 +8,27 @@ import pydantic


 class _PointlessConfig(BaseConfig):

   a: int
   user: str


 def test_load_config_from_yaml(tmp_path):
+  """Test loading a configuration from a YAML file and verifying its values.
+
+  This test function checks the functionality of the `load_config_from_yaml` function by creating
+  a temporary YAML configuration file, loading it, and asserting that the loaded config object
+  has the expected values.
+
+  Args:
+    tmp_path: A temporary directory provided by the `pytest` framework.
+
+  Test Steps:
+    1. Create a temporary YAML file containing configuration data.
+    2. Use the `load_config_from_yaml` function to load the configuration from the YAML file.
+    3. Assert that the loaded configuration object has the expected values.
+  """
   yaml_path = tmp_path.joinpath("test.yaml").as_posix()
   with open(yaml_path, "w") as yaml_file:
     yaml_file.write("""a: 3\nuser: ${USER}\n""")
@@ -9,6 +9,14 @@ import pydantic


 class RuntimeConfig(base_config.BaseConfig):
+  """
+  Configuration for runtime settings.
+
+  Attributes:
+    - wandb (Optional[WandbConfig]): Configuration for Wandb (Weights and Biases) integration.
+    - enable_tensorfloat32 (bool): Enable tensorfloat32 if on Ampere devices.
+    - enable_amp (bool): Enable automatic mixed precision.
+  """
   wandb: WandbConfig = pydantic.Field(None)
   enable_tensorfloat32: bool = pydantic.Field(
     False, description="Use tensorfloat32 if on Ampere devices."
@@ -17,6 +25,22 @@ class RuntimeConfig(base_config.BaseConfig):


 class TrainingConfig(base_config.BaseConfig):
+  """
+  Configuration for training settings.
+
+  Attributes:
+    - save_dir (str): Directory to save checkpoints.
+    - num_train_steps (pydantic.PositiveInt): Number of training steps.
+    - initial_checkpoint_dir (str): Directory of initial checkpoints (optional).
+    - checkpoint_every_n (pydantic.PositiveInt): Save checkpoints every 'n' steps.
+    - checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
+    - train_log_every_n (pydantic.PositiveInt): Log training information every 'n' steps.
+    - num_eval_steps (int): Number of evaluation steps. If < 0, the entire dataset will be used.
+    - eval_log_every_n (pydantic.PositiveInt): Log evaluation information every 'n' steps.
+    - eval_timeout_in_s (pydantic.PositiveFloat): Evaluation timeout in seconds.
+    - gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).
+    - num_epochs (pydantic.PositiveInt): Number of training epochs.
+  """
   save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
   num_train_steps: pydantic.PositiveInt = 10000
   initial_checkpoint_dir: str = pydantic.Field(
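Editor's note (not part of the diff): a hedged sketch of instantiating these config classes directly in Python; the import path is assumed and the values are placeholders.

```python
from tml.core.config.training import RuntimeConfig, TrainingConfig  # import path assumed

training = TrainingConfig(
  save_dir="/tmp/model",
  num_train_steps=10000,
  checkpoint_every_n=1000,
  train_log_every_n=100,
  num_eval_steps=500,
  eval_log_every_n=100,
)
runtime = RuntimeConfig(enable_tensorfloat32=True, enable_amp=False)

# pretty_print() is inherited from BaseConfig and returns a YAML dump.
print(training.pretty_print())
```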
@@ -46,7 +46,24 @@ def get_new_iterator(iterable: Iterable):


 def _get_step_fn(pipeline, data_iterator, training: bool):
+  """
+  Returns a function to perform a single evaluation step.
+
+  Args:
+    pipeline (Pipeline): The pipeline object containing the model.
+    data_iterator (Iterator): The data iterator for evaluation.
+    training (bool): Flag indicating if the model should be in training mode.
+
+  Returns:
+    function: A function that performs a single evaluation step.
+  """
   def step_fn():
+    """
+    Perform a single evaluation step.
+
+    Returns:
+      Any: The evaluation results after a single step.
+    """
     # It turns out that model.train() and model.eval() simply switch a single field inside the model
     # class,so it's somewhat safer to wrap in here.
     if training:
@@ -69,7 +86,21 @@ def _run_evaluation(
   eval_batch_size: int,
   logger=None,
 ):
-  """Runs the evaluation loop over all evaluation iterators."""
+  """
+  Run the evaluation loop over all evaluation iterators.
+
+  Args:
+    pipeline (Pipeline): The pipeline object containing the model.
+    dataset (Dataset): The dataset to evaluate.
+    eval_steps (int): The number of evaluation steps to perform.
+    metrics (tm.MetricCollection): A collection of evaluation metrics.
+    eval_batch_size (int): Batch size for evaluation.
+    logger (Optional[Logger]): A logger for recording evaluation progress (default: None).
+
+  Returns:
+    dict: A dictionary containing the computed evaluation metrics.
+  """

   dataset = get_new_iterator(dataset)
   step_fn = _get_step_fn(pipeline, dataset, training=False)
   last_time = datetime.datetime.now()
@@ -109,15 +140,29 @@ def train(
   parameters_to_log: Optional[Dict[str, Callable]] = None,
   tables_to_log: Optional[List[str]] = None,
 ) -> None:
-  """Runs training and eval on the given TrainPipeline
-
-  Args:
-    dataset: data iterator for the training set
-    evaluation_iterators: data iterators for the different evaluation sets
-    scheduler: optional learning rate scheduler
-    output_transform_for_metrics: optional transformation functions to transorm the model
-      output and labels into a format the metrics can understand
-  """
+  """
+  Runs training and evaluation on the given TrainPipeline.
+
+  Args:
+    model (torch.nn.Module): The neural network model to train.
+    optimizer (torch.optim.Optimizer): The optimizer for model optimization.
+    device (str): The target device for model training (e.g., 'cuda' or 'cpu').
+    save_dir (str): The directory to save model checkpoints and logs.
+    logging_interval (int): Interval for logging training progress.
+    train_steps (int): The number of training steps to perform.
+    checkpoint_frequency (int): Frequency of saving model checkpoints.
+    dataset (Iterable): Data iterator for the training set.
+    worker_batch_size (int): Batch size for data loading workers.
+    num_workers (Optional[int]): Number of data loading workers (default: 0).
+    enable_amp (bool): Flag to enable Automatic Mixed Precision (AMP) training (default: False).
+    initial_checkpoint_dir (Optional[str]): Directory to initialize training from (default: None).
+    gradient_accumulation (Optional[int]): Number of gradient accumulation steps (default: None).
+    logger_initializer (Optional[Callable]): A logger initializer function (default: None).
+    scheduler (_LRScheduler): Optional learning rate scheduler (default: None).
+    metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
+    parameters_to_log (Optional[Dict[str, Callable]]): Dictionary of parameters to log (default: None).
+    tables_to_log (Optional[List[str]]): List of tables to log (default: None).
+  """

   train_pipeline = TrainPipelineSparseDist(
     model=model,
@@ -262,6 +307,15 @@ def log_eval_results(
   partition_name: str,
   step: int,
 ):
+  """
+  Logs evaluation results and optionally records them using a provided logger.
+
+  Args:
+    results (Any): The evaluation results to log.
+    eval_logger (Callable): A logger for recording evaluation results.
+    partition_name (str): The name of the evaluation partition.
+    step (int): The current step in the evaluation.
+  """
   results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
   logging.info(f"Step: {step}, evaluation ({partition_name}).")
   for metric_name, metric_value in results.items():
@@ -285,6 +339,23 @@ def only_evaluate(
   partition_name: str,
   metrics: Optional[tm.MetricCollection] = None,
 ):
+  """
+  Performs evaluation on a given dataset partition.
+
+  Args:
+    model (torch.nn.Module): The neural network model for evaluation.
+    optimizer (torch.optim.Optimizer): The optimizer used during evaluation.
+    device (str): The target device for evaluation (e.g., 'cuda' or 'cpu').
+    save_dir (str): The directory containing model checkpoints.
+    num_train_steps (int): The total number of training steps.
+    dataset (Iterable): Data iterator for evaluation.
+    eval_batch_size (int): Batch size for evaluation.
+    num_eval_steps (int): The number of evaluation steps to perform.
+    eval_timeout_in_s (int): Timeout for evaluating checkpoints in seconds.
+    eval_logger (Callable): A logger for recording evaluation results.
+    partition_name (str): The name of the evaluation partition.
+    metrics (Optional[tm.MetricCollection]): A collection of evaluation metrics (default: None).
+  """
   logging.info(f"Evaluating on partition {partition_name}.")
   logging.info("Computing metrics:")
   logging.info(metrics)
@@ -28,6 +28,18 @@ def train(
   *args,
   **kwargs,
 ) -> None:
+  """
+  Debugging training loop. Do not use for actual model training.
+
+  Args:
+    model (torch.nn.Module): The neural network model.
+    optimizer (torch.optim.Optimizer): The optimizer for model optimization.
+    train_steps (int): The number of training steps to perform.
+    dataset (Iterable): Data iterator for training data.
+    scheduler (_LRScheduler, optional): Learning rate scheduler (default: None).
+    *args: Additional arguments (ignored).
+    **kwargs: Additional keyword arguments (ignored).
+  """

   logging.warning("Running debug training loop, don't use for model training.")

@@ -10,8 +10,11 @@ import torch

 def _maybe_warn(reduction: str):
-  """
-  Warning for reduction different than mean.
-  """
+  """
+  Emit a warning if the reduction method is different from 'mean'.
+
+  Args:
+    reduction (str): The reduction method being used.
+  """
   if reduction != "mean":
     logging.warn(
       f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
@@ -24,6 +27,16 @@ def build_loss(
   loss_type: LossType,
   reduction="mean",
 ):
+  """
+  Build a loss function based on the specified loss type and reduction method.
+
+  Args:
+    loss_type (LossType): The type of loss to build.
+    reduction (str): The reduction method for the loss (default: 'mean').
+
+  Returns:
+    Callable: A loss function that takes logits and labels as input.
+  """
   _maybe_warn(reduction)
   f = _LOSS_TYPE_TO_FUNCTION[loss_type]

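Editor's note (not part of the diff): a hedged sketch of calling the loss builder documented above. The import paths and the `LossType.BCE_WITH_LOGITS` member name are assumptions about this codebase.

```python
import torch

from tml.core.losses import build_loss   # import path assumed
from tml.core.loss_type import LossType  # enum location and member name assumed

loss_fn = build_loss(LossType.BCE_WITH_LOGITS, reduction="mean")

logits = torch.randn(8, 1)
labels = torch.randint(0, 2, (8, 1)).float()
loss = loss_fn(logits, labels)  # scalar tensor under mean reduction
```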
@ -35,11 +48,15 @@ def build_loss(


def get_global_loss_detached(local_loss, reduction="mean"):
  """
  Perform all_reduce to obtain the global loss function using the provided reduction.

  Args:
    local_loss (torch.Tensor): The local loss of the current rank.
    reduction (str): The reduction to use for all_reduce. Should match the reduction used by DDP.

  Returns:
    torch.Tensor: The reduced and detached global loss.
  """
  if reduction != "mean":
    logging.warn(
      f"The reduction used in this function should be the same as the one used by "
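For readers unfamiliar with the all_reduce step described above, here is a minimal sketch in plain `torch.distributed` (assuming an already-initialized process group); it mirrors the documented 'mean' behavior but is not necessarily this module's exact implementation:

import torch
import torch.distributed as dist

def global_mean_loss(local_loss: torch.Tensor) -> torch.Tensor:
  # Detach so the all_reduce does not participate in autograd, then average
  # the per-rank losses; this matches DDP's default 'mean' gradient reduction.
  global_loss = local_loss.detach().clone()
  dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
  return global_loss / dist.get_world_size()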
@ -66,6 +83,19 @@ def build_multi_task_loss(
  global_reduction="mean",
  pos_weights=None,
):
  """
  Build a multi-task loss function based on the specified loss type and configurations.

  Args:
    loss_type (LossType): The type of loss to build.
    tasks (typing.List[str]): List of task names.
    task_loss_reduction (str): Reduction method for task-specific losses (default: 'mean').
    global_reduction (str): Reduction method for the global loss (default: 'mean').
    pos_weights (Optional): Positive class weights for tasks (default: None).

  Returns:
    Callable: A multi-task loss function that takes logits, labels, and weights as input.
  """
  _maybe_warn(global_reduction)
  _maybe_warn(task_loss_reduction)
  f = _LOSS_TYPE_TO_FUNCTION[loss_type]
@ -36,9 +36,24 @@ import torchmetrics
class MetricMixin:
  @abstractmethod
  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
    """
    Abstract method to transform model outputs into a dictionary of metrics.

    Args:
      outputs (Dict[str, torch.Tensor]): Model outputs.

    Returns:
      Dict: A dictionary of computed metrics.
    """
    ...

  def update(self, outputs: Dict[str, torch.Tensor]):
    """
    Update the metrics based on model outputs.

    Args:
      outputs (Dict[str, torch.Tensor]): Model outputs.
    """
    results = self.transform(outputs)
    # Do not try to update if any tensor is empty as a result of stratification.
    for value in results.values():

@ -49,6 +64,13 @@ class MetricMixin:


class TaskMixin:
  def __init__(self, task_idx: int = -1, **kwargs):
    """
    Initialize a TaskMixin instance.

    Args:
      task_idx (int): Index of the task associated with this mixin (default: -1).
      **kwargs: Additional keyword arguments.
    """
    super().__init__(**kwargs)
    self._task_idx = task_idx

@ -59,13 +81,31 @@ class StratifyMixin:
    stratifier=None,
    **kwargs,
  ):
    """
    Initialize a StratifyMixin instance.

    Args:
      stratifier: A stratifier for filtering outputs (default: None).
      **kwargs: Additional keyword arguments.
    """
    super().__init__(**kwargs)
    self._stratifier = stratifier

  def maybe_apply_stratification(
    self, outputs: Dict[str, torch.Tensor], value_names: List[str]
  ) -> Dict[str, torch.Tensor]:
    """
    Apply stratification to filter examples in the outputs.

    Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value.

    Args:
      outputs (Dict[str, torch.Tensor]): Model outputs.
      value_names (List[str]): Names of values to filter.

    Returns:
      Dict[str, torch.Tensor]: Filtered outputs.
    """
    outputs = outputs.copy()
    if not self._stratifier:
      return outputs
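The three mixins above are meant to be combined with a `torchmetrics` base class. The sketch below shows that composition pattern with a hypothetical `LabelMean` metric; the class name is illustrative only, and the real metrics built this way appear later in this diff:

import torch
import torchmetrics as tm

class LabelMean(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  def transform(self, outputs):
    # Optional stratification, then optional per-task selection.
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]
    if self._task_idx >= 0:
      value = value[:, self._task_idx]
    return {"value": value}

metric = LabelMean()
metric.update({"labels": torch.tensor([0.0, 1.0, 1.0])})
print(metric.compute())  # ~0.667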
@ -84,12 +124,20 @@ class StratifyMixin:


def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
  """
  Returns a new class using MetricMixin and the given base_metric.

  Functionally the same as using inheritance, but it saves some lines of code
  if there's no need for class attributes.

  Args:
    base_metric (torchmetrics.Metric): The base metric class to prepend the transform to.
    transform (Callable): The transformation function to prepend to the metric.

  Returns:
    Type: A new class that includes MetricMixin and the provided base_metric
    with the specified transformation method.
  """

  def transform_method(_self, *args, **kwargs):
    return transform(*args, **kwargs)
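A usage sketch of `prepend_transform`, mirroring the `Max` helper that the test file later in this diff builds the same way:

import torch
from torchmetrics import MaxMetric

Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})

m = Max()
m.update({"value": torch.tensor(3.0)})
m.update({"value": torch.tensor(7.0)})
print(m.compute())  # tensor(7.)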
@ -15,6 +15,16 @@ def probs_and_labels(
  outputs: Dict[str, torch.Tensor],
  task_idx: int,
) -> Dict[str, torch.Tensor]:
  """
  Extract probabilities and labels from model outputs.

  Args:
    outputs (Dict[str, torch.Tensor]): Model outputs.
    task_idx (int): Index of the task.

  Returns:
    Dict[str, torch.Tensor]: Dictionary containing 'preds' and 'target' tensors.
  """
  preds = outputs["probabilities"]
  target = outputs["labels"]
  if task_idx >= 0:

@ -28,6 +38,11 @@ def probs_and_labels(


class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
  def transform(self, outputs):
    """
    Count metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and SumMetric.

    This metric counts values after potential stratification and task selection.
    """
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]
    if self._task_idx >= 0:
@ -36,6 +51,12 @@ class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):


class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  Ctr (Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.

  This metric calculates the mean metric value after potential stratification and task selection.
  """

  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["labels"])
    value = outputs["labels"]

@ -45,6 +66,11 @@ class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):


class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  Pctr (Predicted Click-Through Rate) metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and MeanMetric.

  This metric calculates the mean metric value using probabilities after potential stratification and task selection.
  """
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
    value = outputs["probabilities"]

@ -54,12 +80,22 @@ class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):


class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
  """
  Precision metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Precision.

  This metric computes precision after potential stratification and task selection.
  """
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)


class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
  """
  Recall metric class that inherits from StratifyMixin, TaskMixin, MetricMixin, and Recall.

  This metric computes recall after potential stratification and task selection.
  """
  def transform(self, outputs):
    outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
    return probs_and_labels(outputs, self._task_idx)
@ -73,6 +109,14 @@ class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):


class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  AUC (Area Under the ROC Curve) metric class.

  This metric computes the AUC metric based on the logits and labels in the model outputs.

  Args:
    num_samples (int): The number of samples used to compute AUC.
    **kwargs: Additional keyword arguments.

  Based on:
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
  """

@ -94,8 +138,14 @@ class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):


class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  PosRanks metric class.

  This metric computes the ranks of all positive examples based on the logits and labels
  in the model outputs.

  Args:
    **kwargs: Additional keyword arguments.

  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
  """

@ -112,8 +162,13 @@ class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):


class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  ReciprocalRank metric class.

  This metric computes the reciprocal of the ranks of all positive examples based on the logits and labels
  in the model outputs.

  Args:
    **kwargs: Additional keyword arguments.
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
  """

@ -130,9 +185,14 @@ class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):


class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
  """
  HitAtK metric class.

  This metric computes the fraction of positive examples that rank in the top K among their negatives,
  which is equivalent to precision@K.

  Args:
    k (int): The value of K.
    **kwargs: Additional keyword arguments.
  https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
  """
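A small worked example of the rank-based quantities these classes report, using the same scores and labels as the tests below. One common convention, used here, counts ties against the positive; the classes' internals may differ in that detail:

import torch

scores = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
labels = torch.tensor([0, 0, 1, 1, 1])
pos = scores[labels == 1]                       # [0.5, -0.1, 1.5]
neg = scores[labels == 0]                       # [-1.0, 1.0]
# Rank of each positive = 1 + number of negatives scoring at least as high.
ranks = 1 + (neg.unsqueeze(0) >= pos.unsqueeze(1)).sum(dim=1)
print(ranks)                                    # tensor([2, 2, 1])
print((1.0 / ranks).mean())                     # mean reciprocal rank ~0.667
print((ranks <= 1).float().mean())              # hit@1 ~0.333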
@ -9,12 +9,26 @@ from torchmetrics import MaxMetric, MetricCollection, SumMetric


@dataclass
class MockStratifierConfig:
  """
  Configuration dataclass for mocking a stratifier.

  Args:
    name (str): The name of the stratifier.
    index (int): The index of the stratifier.
    value (int): The value of the stratifier.
  """
  name: str
  index: int
  value: int


class Count(MetricMixin, SumMetric):
  """
  Count metric class that inherits from MetricMixin and SumMetric.

  This metric counts occurrences.
  """
  def transform(self, outputs):
    return {"value": 1}

@ -23,6 +37,12 @@ Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})


def test_count_metric():
  """
  Test function for the Count metric.

  It checks if the Count metric correctly counts the number of examples.
  """
  num_examples = 123
  examples = [
    {"stuff": 0},

@ -36,6 +56,12 @@ def test_count_metric():


def test_collections():
  """
  Test function for metric collections.

  It tests if metric collections correctly aggregate metrics.
  """
  max_metric = Max()
  count_metric = Count()
  metric = MetricCollection([max_metric, count_metric])

@ -51,6 +77,12 @@ def test_collections():


def test_task_dependent_ctr():
  """
  Test function for task-dependent Ctr (Click-Through Rate) metric.

  It checks if the Ctr metric computes the correct value for different tasks.
  """
  num_examples = 144
  batch_size = 1024
  outputs = [

@ -69,6 +101,13 @@ def test_task_dependent_ctr():


def test_stratified_ctr():
  """
  Test function for the Stratified Ctr (Click-Through Rate) metric.

  It checks if the Stratified Ctr metric computes the correct value for different tasks
  and stratified samples.
  """
  outputs = [
    {
      "stuff": 0,

@ -114,6 +153,12 @@ def test_stratified_ctr():


def test_auc():
  """
  Test function for the AUC (Area Under the Curve) metric.

  It checks if the AUC metric correctly computes the Area Under the ROC Curve.
  """
  num_samples = 10000
  metric = core_metrics.Auc(num_samples)
  target = torch.tensor([0, 0, 1, 1, 1])

@ -131,6 +176,12 @@ def test_auc():


def test_pos_rank():
  """
  Test function for the PosRanks metric.

  It checks if the PosRanks metric correctly computes the ranks of positive samples.
  """
  metric = core_metrics.PosRanks()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])

@ -147,6 +198,12 @@ def test_pos_rank():


def test_reciprocal_rank():
  """
  Test function for the Reciprocal Rank metric.

  It checks if the Reciprocal Rank metric correctly computes the reciprocal of ranks.
  """
  metric = core_metrics.ReciprocalRank()
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])

@ -163,6 +220,12 @@ def test_reciprocal_rank():


def test_hit_k():
  """
  Test function for the Hit@K metric.

  It checks if the Hit@K metric correctly computes the fraction of positives that rank in the top K among their negatives.
  """
  hit1_metric = core_metrics.HitAtK(1)
  target = torch.tensor([0, 0, 1, 1, 1])
  preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
@ -11,23 +11,60 @@ from torchrec.distributed import DistributedModelParallel


@dataclass
class MockDataclassBatch(DataclassBatch):
  """
  Mock data class batch for testing purposes.

  This class represents a batch of data with continuous features and labels.

  Attributes:
    continuous_features (torch.Tensor): Tensor containing continuous feature data.
    labels (torch.Tensor): Tensor containing label data.
  """
  continuous_features: torch.Tensor
  labels: torch.Tensor


class MockModule(torch.nn.Module):
  """
  Mock PyTorch module for testing purposes.

  This module defines a simple neural network model with a linear layer
  followed by a BCEWithLogitsLoss loss function.

  Attributes:
    model (torch.nn.Linear): The linear model layer.
    loss_fn (torch.nn.BCEWithLogitsLoss): Binary cross-entropy loss function.
  """
  def __init__(self) -> None:
    super().__init__()
    self.model = torch.nn.Linear(10, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass of the mock module.

    Args:
      batch (MockDataclassBatch): Input data batch with continuous features and labels.

    Returns:
      Tuple[torch.Tensor, torch.Tensor]: A tuple containing the loss and predictions.
    """
    pred = self.model(batch.continuous_features)
    loss = self.loss_fn(pred, batch.labels)
    return (loss, pred)


def create_batch(bsz: int):
  """
  Create a mock data batch with random continuous features and labels.

  Args:
    bsz (int): Batch size.

  Returns:
    MockDataclassBatch: A batch of data with continuous features and labels.
  """
  return MockDataclassBatch(
    continuous_features=torch.rand(bsz, 10).float(),
    labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),

@ -35,6 +72,13 @@ def create_batch(bsz: int):


def test_sparse_pipeline():
  """
  Test function for the sparse pipeline with distributed model parallelism.

  This function tests the behavior of the sparse training pipeline using
  a mock module and data.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)
@ -65,6 +109,15 @@ def test_sparse_pipeline():


def test_amp():
  """
  Test automatic mixed-precision (AMP) training with the sparse pipeline.

  This function tests the behavior of the sparse training pipeline with
  automatic mixed-precision (AMP) enabled, using a mock module and data.

  AMP allows for faster training by using lower-precision data types, such as
  torch.bfloat16, while maintaining model accuracy.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)
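Independent of the pipeline classes in this diff, the core AMP idea exercised by this test can be sketched in a few lines of plain PyTorch (bfloat16 autocast works on CPU as well):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.rand(4, 10)
y = torch.bernoulli(torch.empty(4, 1).uniform_(0, 1))

# Run the forward pass in lower precision; parameters and updates stay float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
  loss = F.binary_cross_entropy_with_logits(model(x), y)
loss.backward()
opt.step()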
@ -39,12 +39,42 @@ Out = TypeVar("Out")


class TrainPipeline(abc.ABC, Generic[In, Out]):
  """
  Abstract base class for training pipelines.

  Attributes:
    In (TypeVar): Input data type.
    Out (TypeVar): Output data type.

  Methods:
    progress(dataloader_iter: Iterator[In]) -> Out: Abstract method to make progress in the training pipeline.
  """
  @abc.abstractmethod
  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    """
    Make progress in the training pipeline.

    Args:
      dataloader_iter (Iterator[In]): An iterator over input data.

    Returns:
      Out: The output data.
    """
    pass


def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
  """
  Move a batch of data to a specified device.

  Args:
    batch (In): The input batch.
    device (torch.device): The target device.
    non_blocking (bool): If True, move the data asynchronously.

  Returns:
    In: The batch of data on the target device.
  """
  assert isinstance(
    batch, (torch.Tensor, Pipelineable)
  ), f"{type(batch)} must implement Pipelineable interface"

@ -52,6 +82,16 @@ def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:


def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
  """
  Wait for a batch of data on a specified stream.

  Args:
    batch (In): The input batch.
    stream (Optional[Stream]): The CUDA stream to wait for.

  Note:
    This function is used for managing asynchronous CUDA operations.
  """
  if stream is None:
    return
  torch.cuda.current_stream().wait_stream(stream)
@ -72,11 +112,26 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> N

class TrainPipelineBase(TrainPipeline[In, Out]):
  """
  This class runs training iterations using a pipeline of two stages, each as a CUDA
  stream, namely, the current (default) stream and `self._memcpy_stream`. For each
  iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
  memory, and the default stream runs forward, backward, and optimization.

  Attributes:
    In (TypeVar): Input data type.
    Out (TypeVar): Output data type.

  Methods:
    __init__(model: torch.nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> None:
      Initialize the TrainPipelineBase.

    _connect(dataloader_iter: Iterator[In]) -> None:
      Establish a connection to the data loader and move the input data to the GPU.

    progress(dataloader_iter: Iterator[In]) -> Out:
      Execute a training iteration, including forward and backward passes.
  """

  def __init__(
    self,

@ -84,6 +139,14 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
    optimizer: torch.optim.Optimizer,
    device: torch.device,
  ) -> None:
    """
    Initialize the TrainPipelineBase.

    Args:
      model (torch.nn.Module): The PyTorch model to be trained.
      optimizer (torch.optim.Optimizer): The optimizer used for training.
      device (torch.device): The target device for training (CPU or GPU).
    """
    self._model = model
    self._optimizer = optimizer
    self._device = device

@ -94,6 +157,12 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
    self._connected = False

  def _connect(self, dataloader_iter: Iterator[In]) -> None:
    """
    Establish a connection to the data loader and move the input data to the GPU.

    Args:
      dataloader_iter (Iterator[In]): An iterator over input data.
    """
    cur_batch = next(dataloader_iter)
    self._cur_batch = cur_batch
    with torch.cuda.stream(self._memcpy_stream):

@ -101,6 +170,15 @@ class TrainPipelineBase(TrainPipeline[In, Out]):
    self._connected = True

  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    """
    Execute a training iteration, including forward and backward passes.

    Args:
      dataloader_iter (Iterator[In]): An iterator over input data.

    Returns:
      Out: The output data.
    """
    if not self._connected:
      self._connect(dataloader_iter)
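A self-contained sketch of the two-stream overlap described in this class's docstring. It requires a CUDA device and shows the general pattern, not this class's exact code; the function and variable names here are illustrative:

import torch

def overlapped_loop(batches, model, optimizer, device):
  memcpy_stream = torch.cuda.Stream()
  it = iter(batches)
  with torch.cuda.stream(memcpy_stream):
    cur = next(it).to(device, non_blocking=True)
  for nxt_cpu in it:
    # The default stream must wait until the copy of `cur` has finished.
    torch.cuda.current_stream().wait_stream(memcpy_stream)
    cur.record_stream(torch.cuda.current_stream())
    # Prefetch the next batch on the side stream while `cur` trains below.
    with torch.cuda.stream(memcpy_stream):
      nxt = nxt_cpu.to(device, non_blocking=True)
    optimizer.zero_grad()
    loss = model(cur).sum()
    loss.backward()
    optimizer.step()
    cur = nxt
  # (Handling of the final prefetched batch omitted for brevity.)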
@ -139,6 +217,16 @@ class TrainPipelineBase(TrainPipeline[In, Out]):


class Tracer(torch.fx.Tracer):
  """
  Custom tracer class for PyTorch models.

  This tracer is used to trace PyTorch models while also considering specific leaf modules and buffer proxying settings.

  Attributes:
    proxy_buffer_attributes (bool): Flag to enable/disable proxying buffers during tracing.
    _leaf_modules (List[str]): List of qualified names of leaf modules.
  """

  # Disable proxying buffers during tracing. Ideally, proxying buffers would
  # be disabled, but some models are currently mutating buffer values, which
  # causes errors during tracing. If those models can be rewritten to not do

@ -146,10 +234,26 @@ class Tracer(torch.fx.Tracer):
  proxy_buffer_attributes = False

  def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
    """
    Initialize the Tracer.

    Args:
      leaf_modules (Optional[List[str]]): List of qualified names of leaf modules to consider as leaf nodes during tracing.
    """
    super().__init__()
    self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []

  def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    """
    Check if a module is a leaf module during tracing.

    Args:
      m (torch.nn.Module): The PyTorch module.
      module_qualified_name (str): The qualified name of the module.

    Returns:
      bool: True if the module is considered a leaf module, False otherwise.
    """
    if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
      return True
    return super().is_leaf_module(m, module_qualified_name)
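The leaf-module mechanism is standard `torch.fx`. A standalone illustration (with made-up module names) of how a listed submodule shows up as a single call_module node instead of being traced through:

import torch
import torch.fx

class LeafAwareTracer(torch.fx.Tracer):
  def __init__(self, leaf_modules=None):
    super().__init__()
    self._leaf_modules = leaf_modules or []

  def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    if module_qualified_name in self._leaf_modules:
      return True
    return super().is_leaf_module(m, module_qualified_name)

class Block(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.lin = torch.nn.Linear(4, 4)

  def forward(self, x):
    return torch.relu(self.lin(x))

class Net(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.block = Block()

  def forward(self, x):
    return self.block(x)

net = Net()
graph = LeafAwareTracer(leaf_modules=["block"]).trace(net)
print(torch.fx.GraphModule(net, graph).graph)  # 'block' appears as one call_module node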
@ -157,6 +261,15 @@ class Tracer(torch.fx.Tracer):

@dataclass
class TrainPipelineContext:
  """
  Dataclass to store information related to the training pipeline context.

  Attributes:
    input_dist_requests (Dict[str, Awaitable[Any]]): A dictionary of input distribution requests.
    module_contexts (Dict[str, Multistreamable]): A dictionary of module contexts.
    feature_processor_forwards (List[Any]): A list of feature processor forwards.
  """

  # pyre-ignore [4]
  input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
  module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)

@ -166,6 +279,14 @@ class TrainPipelineContext:

@dataclass
class ArgInfo:
  """
  Dataclass to store information about arguments in the training pipeline.

  Attributes:
    input_attrs (List[str]): List of attribute names of the input batch.
    is_getitems (List[bool]): List of boolean values indicating whether the argument is accessed using getitem.
    name (Optional[str]): Name for the keyword argument in the pipelined forward() call or None for positional arguments.
  """
  # attributes of input batch, e.g. batch.attr1.attr2 call
  # will produce ["attr1", "attr2"]
  input_attrs: List[str]
@ -177,6 +298,16 @@ class ArgInfo:


class PipelinedForward:
  """
  Represents a pipelined forward pass operation.

  Attributes:
    name (str): The name of the forward pass.
    args (List[ArgInfo]): List of argument information for the forward pass.
    module (ShardedModule): The sharded module associated with the forward pass.
    context (TrainPipelineContext): The training pipeline context.
    dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
  """
  def __init__(
    self,
    name: str,

@ -185,6 +316,16 @@ class PipelinedForward:
    context: TrainPipelineContext,
    dist_stream: Optional[torch.cuda.streams.Stream],
  ) -> None:
    """
    Initialize a PipelinedForward instance.

    Args:
      name (str): The name of the forward pass.
      args (List[ArgInfo]): List of argument information for the forward pass.
      module (ShardedModule): The sharded module associated with the forward pass.
      context (TrainPipelineContext): The training pipeline context.
      dist_stream (Optional[torch.cuda.streams.Stream]): CUDA stream for distributed processing.
    """
    self._name = name
    self._args = args
    self._module = module

@ -193,6 +334,16 @@ class PipelinedForward:

  # pyre-ignore [2, 24]
  def __call__(self, *input, **kwargs) -> Awaitable:
    """
    Perform the pipelined forward pass operation.

    Args:
      *input: Variable-length positional arguments.
      **kwargs: Variable-length keyword arguments.

    Returns:
      Awaitable: An awaitable object representing the forward pass result.
    """
    assert self._name in self._context.input_dist_requests
    request = self._context.input_dist_requests[self._name]
    assert isinstance(request, Awaitable)

@ -230,10 +381,22 @@ class PipelinedForward:

  @property
  def name(self) -> str:
    """
    Get the name of the forward pass.

    Returns:
      str: The name of the forward pass.
    """
    return self._name

  @property
  def args(self) -> List[ArgInfo]:
    """
    Get the list of argument information for the forward pass.

    Returns:
      List[ArgInfo]: List of argument information.
    """
    return self._args
@ -242,6 +405,17 @@ def _start_data_dist(
  batch: In,
  context: TrainPipelineContext,
) -> None:
  """
  Start data distribution for a list of pipelined modules.

  Args:
    pipelined_modules (List[ShardedModule]): List of ShardedModule instances representing pipelined modules.
    batch (In): The input batch.
    context (TrainPipelineContext): The training pipeline context.

  Returns:
    None: This function doesn't return a value.
  """
  context.input_dist_requests.clear()
  context.module_contexts.clear()
  for module in pipelined_modules:

@ -286,9 +460,17 @@ def _get_node_args_helper(
  feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
  """
  Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
  It also counts the number of (args + kwargs) found.

  Args:
    arguments: The arguments to process.
    num_found: The current count of arguments found.
    feature_processor_arguments: Optional list of feature processor arguments.

  Returns:
    Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the updated count of arguments found.
  """
  arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
  for arg, arg_info in zip(arguments, arg_info_list):

@ -332,6 +514,16 @@ def _get_node_args_helper(
def _get_node_args(
  node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
  """
  Get argument information for a given node.

  Args:
    node (Node): The node to process.
    feature_processor_nodes (Optional[List[Node]]): Optional list of feature processor nodes.

  Returns:
    Tuple[List[ArgInfo], int]: A tuple containing a list of `ArgInfo` objects and the number of arguments found.
  """
  num_found = 0
  pos_arg_info_list, num_found = _get_node_args_helper(
    node.args, num_found, feature_processor_nodes

@ -351,6 +543,17 @@ def _get_unsharded_module_names_helper(
  path: str,
  unsharded_module_names: Set[str],
) -> bool:
  """
  Get the names of unsharded modules in a model.

  Args:
    model (torch.nn.Module): The model to analyze.
    path (str): The current path in the model hierarchy.
    unsharded_module_names (Set[str]): A set to store the names of unsharded modules.

  Returns:
    bool: True if any sharded modules were found in the hierarchy, False otherwise.
  """
  sharded_children = set()
  for name, child in model.named_children():
    curr_path = path + name

@ -375,8 +578,14 @@ def _get_unsharded_module_names_helper(


def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
  """
  Returns a list of top-level modules that do not contain any sharded sub-modules.

  Args:
    model (torch.nn.Module): The model to analyze.

  Returns:
    List[str]: A list of top-level module names without sharded sub-modules.
  """
  unsharded_module_names: Set[str] = set()
  _get_unsharded_module_names_helper(
@ -392,6 +601,21 @@ def _rewrite_model( # noqa C901
  context: TrainPipelineContext,
  dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
  """
  Rewrites the model to enable pipelined execution for selected sharded modules.

  This function traces the input model using a custom tracer and identifies sharded modules
  that can be pipelined. It then creates PipelinedForward objects for these modules,
  which enable pipelining during training.

  Args:
    model (torch.nn.Module): The input model to be rewritten.
    context (TrainPipelineContext): The context containing information needed for pipelining.
    dist_stream (Optional[torch.cuda.streams.Stream]): The CUDA stream for data distribution.

  Returns:
    List[ShardedModule]: A list of sharded modules that have been rewritten for pipelined execution.
  """
  # Get underlying nn.Module
  if isinstance(model, DistributedModelParallel):
@ -442,20 +666,32 @@ def _rewrite_model( # noqa C901
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
  """
  This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
  forward and backward. This helps hide the all2all latency while preserving the
  training forward / backward ordering.

  stage 3: forward, backward - uses default CUDA stream
  stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
  stage 1: device transfer - uses memcpy CUDA stream

  `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
  To be considered a top-level module, a module can only depend on 'getattr' calls on
  input.

  Input model must be symbolically traceable with the exception of `ShardedModule` and
  `DistributedDataParallel` modules.

  Args:
    model (torch.nn.Module): The input model to be used for training.
    optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
    device (torch.device): The device where training will be performed.
    enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
    enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
    grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.

  Attributes:
    synced_pipeline_id (Dict[int, int]): A dictionary to track synchronized pipelines.
  """

  synced_pipeline_id: Dict[int, int] = {}
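A hedged usage sketch based only on the signature documented above. It is not runnable as-is: `model`, `optimizer`, `dataloader`, and `num_steps` are placeholders for a real sharded training setup:

pipeline = TrainPipelineSparseDist(
  model=model,            # typically wrapped in DistributedModelParallel
  optimizer=optimizer,
  device=torch.device("cuda"),
  enable_amp=False,
  grad_accum=None,
)
data_iter = iter(dataloader)
for _ in range(num_steps):
  outputs = pipeline.progress(data_iter)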
@ -468,6 +704,17 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
    enable_grad_scaling: bool = True,
    grad_accum: Optional[int] = None,
  ) -> None:
    """
    Initializes the training pipeline.

    Args:
      model (torch.nn.Module): The input model to be used for training.
      optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
      device (torch.device): The device where training will be performed.
      enable_amp (bool, optional): Whether to enable automatic mixed precision (AMP). Defaults to False.
      enable_grad_scaling (bool, optional): Whether to enable gradient scaling. Defaults to True.
      grad_accum (int, optional): Number of gradient accumulation steps. Defaults to None.
    """
    self._model = model
    self._optimizer = optimizer
    self._device = device

@ -504,6 +751,13 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
    self._grad_accum = grad_accum

  def _connect(self, dataloader_iter: Iterator[In]) -> None:
    """
    Connects the training pipeline to data and prepares for forward and backward passes.

    Args:
      dataloader_iter (Iterator[In]): An iterator providing input data batches.
    """
    # batch 1
    with torch.cuda.stream(self._memcpy_stream):
      batch_i = next(dataloader_iter)
@ -524,13 +778,20 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):

  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    """
    Progresses through the training pipeline, performing forward and backward passes.

    NOTE: This method has been updated to perform gradient accumulation.
    If `_grad_accum` is set, then loss values are scaled by this amount and
    optimizer update/reset is skipped for `_grad_accum` calls of `progress`
    (congruent to training steps), and then update/reset on every `_grad_accum`th
    step.

    Args:
      dataloader_iter (Iterator[In]): An iterator providing input data batches.

    Returns:
      Out: The output of the forward pass.
    """
    should_step_optimizer = (
      self._grad_accum is not None
      and self._progress_calls > 0
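The accumulation rule described in this docstring, shown as a plain-PyTorch sketch independent of this pipeline class: scale each loss by 1/grad_accum and step/zero the optimizer only on every grad_accum-th call.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum = 4

for step in range(20):
  x, y = torch.rand(8, 10), torch.rand(8, 1)
  loss = F.mse_loss(model(x), y) / grad_accum
  loss.backward()                       # gradients accumulate across calls
  if (step + 1) % grad_accum == 0:
    optimizer.step()
    optimizer.zero_grad()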
@ -617,9 +878,9 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):

  def _sync_pipeline(self) -> None:
    """
    Syncs `PipelinedForward` for sharded modules with context and dist stream of the
    current train pipeline. Used when switching between train pipelines for the same
    model.
    """
    for module in self._pipelined_modules:
      module.forward._context = self._context
@ -9,100 +9,176 @@ FLIGHT_SERVER_PORT: int = 2222
|
|||||||
|
|
||||||
|
|
||||||
def on_kf():
|
def on_kf():
|
||||||
return "SPEC_TYPE" in os.environ
|
"""Check if the code is running on Kubernetes with Kubeflow (KF) environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if running on KF, False otherwise.
|
||||||
|
"""
|
||||||
|
return "SPEC_TYPE" in os.environ
|
||||||
|
|
||||||
|
|
||||||
def has_readers():
|
def has_readers():
|
||||||
if on_kf():
|
"""Check if the current task has dataset workers.
|
||||||
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
|
||||||
return machines_config_env["dataset_worker"] is not None
|
Returns:
|
||||||
return os.environ.get("HAS_READERS", "False") == "True"
|
bool: True if the task has dataset workers, False otherwise.
|
||||||
|
"""
|
||||||
|
if on_kf():
|
||||||
|
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
|
||||||
|
return machines_config_env.get("dataset_worker") is not None
|
||||||
|
return os.environ.get("HAS_READERS", "False") == "True"
|
||||||
|
|
||||||
|
|
||||||
def get_task_type():
|
def get_task_type():
|
||||||
if on_kf():
|
"""Get the type of the current task.
|
||||||
return os.environ["SPEC_TYPE"]
|
|
||||||
return os.environ["TASK_TYPE"]
|
Returns:
|
||||||
|
str: Task type, such as 'chief', 'datasetworker', or 'datasetdispatcher'.
|
||||||
|
"""
|
||||||
|
if on_kf():
|
||||||
|
return os.environ["SPEC_TYPE"]
|
||||||
|
return os.environ["TASK_TYPE"]
|
||||||
|
|
||||||
|
|
||||||
def is_chief() -> bool:
|
def is_chief() -> bool:
|
||||||
return get_task_type() == "chief"
|
"""Check if the current task is the 'chief'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the current task is the 'chief', False otherwise.
|
||||||
|
"""
|
||||||
|
return get_task_type() == "chief"
|
||||||
|
|
||||||
|
|
||||||
def is_reader() -> bool:
|
def is_reader() -> bool:
|
||||||
return get_task_type() == "datasetworker"
|
"""Check if the current task is a 'datasetworker'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the current task is a 'datasetworker', False otherwise.
|
||||||
|
"""
|
||||||
|
return get_task_type() == "datasetworker"
|
||||||
|
|
def is_dispatcher() -> bool:
  """Check if the current task is a 'datasetdispatcher'.

  Returns:
    bool: True if the current task is a 'datasetdispatcher', False otherwise.
  """
  return get_task_type() == "datasetdispatcher"


def get_task_index():
  """Get the index of the current task.

  Returns:
    int: Task index.

  Raises:
    NotImplementedError: If not running on Kubernetes with Kubeflow (KF).
  """
  if on_kf():
    pod_name = os.environ["MY_POD_NAME"]
    return int(pod_name.split("-")[-1])
  else:
    raise NotImplementedError


def get_reader_port():
  """Get the port used by readers.

  Returns:
    int: Reader port.
  """
  if on_kf():
    return KF_DDS_PORT
  return SLURM_DDS_PORT


def get_dds():
  """Get the Distributed Data Service (DDS) address.

  Returns:
    str: DDS address in the format 'grpc://host:port', or None if the job has no readers.

  Raises:
    ValueError: If the job does not have DDS.
  """
  if not has_readers():
    return None
  dispatcher_address = get_dds_dispatcher_address()
  if dispatcher_address:
    return f"grpc://{dispatcher_address}"
  else:
    raise ValueError("Job does not have DDS.")


def get_dds_dispatcher_address():
  """Get the DDS dispatcher address.

  Returns:
    str: DDS dispatcher address in the format 'host:port', or None if the job has no readers.
  """
  if not has_readers():
    return None
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    dds_host = f"{job_name}-datasetdispatcher-0"
  else:
    dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
  return f"{dds_host}:{get_reader_port()}"


def get_dds_worker_address():
  """Get the DDS worker address.

  Returns:
    str: DDS worker address in the format 'host:port', or None if the job has no readers.
  """
  if not has_readers():
    return None
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    task_index = get_task_index()
    return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
  else:
    node = os.environ["SLURMD_NODENAME"]
    return f"{node}:{get_reader_port()}"


def get_num_readers():
  """Get the number of dataset workers.

  Returns:
    int: Number of dataset workers.
  """
  if not has_readers():
    return 0
  if on_kf():
    machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
    return int(machines_config_env.get("num_dataset_workers") or 0)
  return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))


def get_flight_server_addresses():
  """Get Flight server addresses for dataset workers.

  Returns:
    List[str]: List of Flight server addresses in the format 'grpc://host:port'.

  Raises:
    NotImplementedError: If not running on Kubernetes with Kubeflow (KF).
  """
  if on_kf():
    job_name = os.environ["JOB_NAME"]
    return [
      f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
      for task_index in range(get_num_readers())
    ]
  else:
    raise NotImplementedError


def get_dds_journaling_dir():
  """Get the DDS journaling directory.

  Returns:
    str: DDS journaling directory, if configured.
  """
  return os.environ.get("DATASET_JOURNALING_DIR", None)
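As an aside (not part of the diff itself): every endpoint above is assembled from a handful of environment variables. The standalone sketch below mirrors that composition under assumed values; `KF_DDS_PORT`, the job name, and the pod name are placeholders, not the module's real constants.

```python
# Standalone sketch; mirrors the composition logic above under assumed values.
KF_DDS_PORT = 5050  # assumed placeholder, not the module's real constant


def sketch_dispatcher_address(job_name: str, port: int = KF_DDS_PORT) -> str:
  # On Kubeflow the dispatcher host is "<job>-datasetdispatcher-0".
  return f"{job_name}-datasetdispatcher-0:{port}"


def sketch_task_index(pod_name: str) -> int:
  # Pod names end with the task index, e.g. "myjob-datasetworker-3" -> 3.
  return int(pod_name.split("-")[-1])


print(f"grpc://{sketch_dispatcher_address('myjob')}")  # grpc://myjob-datasetdispatcher-0:5050
print(sketch_task_index("myjob-datasetworker-3"))      # 3
```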
@@ -14,18 +14,21 @@ def update_mean(
  weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Update the mean according to the Welford formula.

  This function updates the mean and the weighted sum of values using the Welford algorithm:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

  Args:
    current_mean (torch.Tensor): The value of the current accumulated mean.
    current_weight_sum (torch.Tensor): The current weighted sum.
    value (torch.Tensor): The new value that needs to be added to get a new mean.
    weight (torch.Tensor): The weights for the new value.

  Returns:
    Tuple[torch.Tensor, torch.Tensor]: The updated mean and updated weighted sum.
  """
  weight = torch.broadcast_to(weight, value.shape)

  # Avoiding (on purpose) in-place operation when using += in case
@@ -38,12 +41,15 @@ def update_mean(
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
  """
  Merge the state from multiple workers.

  This function merges the state from multiple workers to compute the accumulated mean.

  Args:
    state (torch.Tensor): A tensor with the first dimension indicating workers.

  Returns:
    torch.Tensor: The accumulated mean from all workers.
  """
  mean, weight_sum = update_mean(
    current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
    current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
@@ -55,12 +61,20 @@ def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:

class StableMean(torchmetrics.Metric):
  """
  A numerically stable mean metric using the Welford algorithm.

  This class implements a numerically stable mean computation based on the Welford algorithm:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
  For example, when using float32 the algorithm gives a valid output even if the "sum" is larger
  than the maximum float32, as long as the mean is within the limit of float32.
  See also https://nullbuffer.com/articles/welford_algorithm.html for more information.

  Args:
    **kwargs: Additional parameters supported by all torchmetrics.Metric.

  Attributes:
    mean_and_weight_sum (torch.Tensor): A tensor that stores the mean and the weighted sum.
  """

  def __init__(self, **kwargs):
    """
@@ -75,12 +89,12 @@ class StableMean(torchmetrics.Metric):
    )

  def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
    """Update the current mean.

    Args:
      value (torch.Tensor): Value to update the mean with.
      weight (Union[float, torch.Tensor]): Weight to use. Shape should be broadcastable to that of value.
    """
    mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]

    if not isinstance(weight, torch.Tensor):
@@ -91,7 +105,9 @@ class StableMean(torchmetrics.Metric):
    )

  def compute(self) -> torch.Tensor:
    """Compute and return the accumulated mean.

    Returns:
      torch.Tensor: The accumulated mean.
    """
    return self.mean_and_weight_sum[0]
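To see what the Welford-style update buys in practice: streaming weighted updates reproduce the overall mean without ever materializing the full sum. A minimal sketch with plain tensors, independent of the metric class above; the formula follows the weighted batched Welford update referenced in the docstrings.

```python
import torch


def welford_update(mean, weight_sum, value, weight):
  # One weighted Welford step: shift the mean by the weighted residuals.
  weight = torch.broadcast_to(weight, value.shape)
  new_weight_sum = weight_sum + weight.sum()
  new_mean = mean + (weight * (value - mean)).sum() / new_weight_sum
  return new_mean, new_weight_sum


mean = torch.tensor(0.0)
weight_sum = torch.tensor(0.0)
for batch in [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]:
  mean, weight_sum = welford_update(mean, weight_sum, batch, torch.tensor(1.0))
print(mean)  # tensor(3.) -- the mean of [1, 2, 3, 4, 5]
```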
@@ -29,6 +29,9 @@ def _compute_helper(
    equal_predictions_as_incorrect: For positive & negative labels having identical scores,
      we treat them as correct predictions (i.e. weight = 1) when this is False. Otherwise,
      we treat them as incorrect predictions (i.e. weight = 0).

  Returns:
    torch.Tensor: The computed AUROC.
  """
  dim = 0

@@ -52,24 +55,34 @@ def _compute_helper(

class AUROCWithMWU(torchmetrics.Metric):
  """
  AUROC (Area Under the Receiver Operating Characteristic curve) using the Mann-Whitney U-test.

  This AUROC implementation is well suited for (non-zero) low-CTR (click-through rate)
  scenarios: it returns the correct AUROC even when the predicted probabilities are all close to 0.

  See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.

  Note: Currently, this implementation only supports binary classification.

  Args:
    label_threshold (float): Threshold for classifying labels as positive or negative.
      Labels strictly above this threshold are considered positive; all others are considered negative.
    raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
      Otherwise, a warning is logged and the AUROC is still computed.
    **kwargs: Additional parameters supported by all torchmetrics.Metric.
  """

  def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
    """
    Initialize the AUROCWithMWU metric.

    Args:
      label_threshold (float): Threshold for classifying labels as positive or negative.
        Labels strictly above this threshold are considered positive; all others are considered negative.
      raise_missing_class (bool): If True, an error is raised when the negative or positive class is missing.
        Otherwise, a warning is logged and the AUROC is still computed.
      **kwargs: Additional parameters supported by all torchmetrics.Metric.
    """
    super().__init__(**kwargs)
    self.add_state("predictions", default=[], dist_reduce_fx="cat")
    self.add_state("target", default=[], dist_reduce_fx="cat")
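The Mann-Whitney view of AUROC can be checked directly with plain tensors: it is the fraction of (positive, negative) pairs in which the positive outranks the negative, with ties counted as one half. The quadratic pairwise sketch below illustrates the statistic only; it is not the class's implementation, just the quantity it targets.

```python
import torch


def auroc_by_pairwise_comparison(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  # AUROC as a Mann-Whitney U statistic:
  # P(score_pos > score_neg) + 0.5 * P(score_pos == score_neg).
  pos = scores[labels == 1]
  neg = scores[labels == 0]
  diff = pos.unsqueeze(1) - neg.unsqueeze(0)  # all positive/negative score pairs
  wins = (diff > 0).float() + 0.5 * (diff == 0).float()
  return wins.mean()


scores = torch.tensor([0.001, 0.004, 0.002, 0.003])  # tiny probabilities, ranking still well defined
labels = torch.tensor([0, 1, 0, 1])
print(auroc_by_pairwise_comparison(scores, labels))  # tensor(1.) -- every positive outranks every negative
```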
@@ -14,7 +14,24 @@ from absl import logging as logging


def setup_absl_logging():
  """
  Configure absl-py logging to write to stdout and apply a custom log message format.

  This function ensures that log messages generated by the absl-py library are written to stdout
  rather than stderr. It also applies a custom log message format that includes the module, function,
  line number, log level, and the log message content.

  Note:
    This function should be called once at the beginning of your script or application to
    configure absl-py logging.

  Example:
    To use this function, simply call it at the start of your script:

    ```
    setup_absl_logging()
    ```
  """
  logging.get_absl_handler().python_handler.stream = sys.stdout
  formatter = py_logging.Formatter(
    fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
@@ -5,6 +5,21 @@ from tml.ml_logging.torch_logging import logging

class Testtlogging(unittest.TestCase):
  def test_warn_once(self):
    """
    Test that warning messages are logged only once when using the assertLogs context manager.

    This unit test checks the behavior of the logging system when warning messages are issued
    multiple times within the same context. It uses the assertLogs context manager to capture
    log messages at the INFO level and verifies that warning messages are logged only once.

    Example:
      To run this test case with a test runner such as unittest:

      ```
      python -m unittest your_test_module.Testtlogging.test_warn_once
      ```
    """

    with self.assertLogs(level="INFO") as captured_logs:
      logging.info("first info")
      logging.warning("first warning")
@@ -18,7 +18,35 @@ import torch.distributed as dist


def rank_specific(logger):
  """
  Customize logger behavior based on the distributed environment and rank.

  This function overrides the standard logging methods (e.g., error, warning) of the given logger so that
  messages are logged conditionally depending on the rank of the current process, and so that redundant
  logs are limited. A given logger is only overridden once.

  Args:
    logger: The logger object to customize.

  Returns:
    The customized logger.

  Example:
    To use this function with the `logging` module:

    ```python
    import logging
    from rank_specific_logging import rank_specific

    logger = logging.getLogger(__name__)
    rank_specific(logger)
    ```

  Customization:
    - Messages are only logged if the distributed environment is not initialized or if the rank matches.
    - The 'warning' method is limited to logging a single redundant warning.
    - Logging from rank -1 is redirected to include the rank information.
  """
  if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
    return logger
@@ -8,11 +8,60 @@ import pydantic


class PiecewiseConstant(base_config.BaseConfig):
  """
  Configuration for a piecewise constant learning rate schedule.

  This configuration class allows you to specify a piecewise constant learning rate schedule
  by defining boundaries and corresponding learning rate values.

  Attributes:
    learning_rate_boundaries (List[int], optional): List of step boundaries at which
      the learning rate will change. If None, no boundaries are defined.
    learning_rate_values (List[float], optional): List of learning rate values
      corresponding to the boundaries. If None, no values are defined.

  Example:
    To configure a piecewise constant learning rate schedule, create an instance of this class
    and set the attributes accordingly. For example:

    ```python
    piecewise_lr = PiecewiseConstant(
      learning_rate_boundaries=[1000, 2000, 3000],
      learning_rate_values=[0.1, 0.05, 0.01, 0.001]
    )
    ```

  Note:
    The number of learning rate values should be one more than the number of boundaries.
  """
  learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
  learning_rate_values: typing.List[float] = pydantic.Field(None)


class LinearRampToConstant(base_config.BaseConfig):
  """
  Configuration for a linear ramp-up to constant learning rate schedule.

  This configuration class allows you to specify a learning rate schedule that ramps up linearly
  from zero to a constant value over a specified number of steps.

  Attributes:
    learning_rate (float): The final constant learning rate.
    num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.

  Example:
    To configure a linear ramp-up to a constant learning rate, create an instance of this class
    and set the attributes accordingly. For example:

    ```python
    linear_ramp_lr = LinearRampToConstant(
      learning_rate=0.1,
      num_ramp_steps=1000
    )
    ```
  """
  learning_rate: float
  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
    description="Number of steps to ramp this up from zero."
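For readers skimming the config: the boundary/value relationship noted above (one more value than there are boundaries) amounts to a bisect-style lookup of the current step. The sketch below shows that semantics under the assumption that a step at a boundary switches to the next value; the project's own lookup lives in `compute_lr`.

```python
import bisect
from typing import List


def piecewise_constant_lr(step: int, boundaries: List[int], values: List[float]) -> float:
  # values[i] applies before boundaries[i]; the last value applies after the final boundary.
  assert len(values) == len(boundaries) + 1
  return values[bisect.bisect_right(boundaries, step)]


boundaries = [1000, 2000, 3000]
values = [0.1, 0.05, 0.01, 0.001]
print(piecewise_constant_lr(500, boundaries, values))   # 0.1
print(piecewise_constant_lr(2500, boundaries, values))  # 0.01
print(piecewise_constant_lr(5000, boundaries, values))  # 0.001
```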
@@ -20,6 +69,32 @@ class LinearRampToConstant(base_config.BaseConfig):


class LinearRampToCosine(base_config.BaseConfig):
  """
  Configuration for a linear ramp-up to cosine decay learning rate schedule.

  This configuration class allows you to specify a learning rate schedule that ramps up linearly
  from zero, then decays following a cosine schedule to a final constant learning rate.

  Attributes:
    learning_rate (float): The initial learning rate at the start of ramp-up.
    final_learning_rate (float): The final constant learning rate after decay.
    num_ramp_steps (PositiveInt): Number of steps to ramp up the learning rate from zero.
    final_num_steps (PositiveInt): Final number of steps where decay stops.

  Example:
    To configure a linear ramp-up to cosine decay learning rate, create an instance of this
    class and set the attributes accordingly. For example:

    ```python
    ramp_to_cosine_lr = LinearRampToCosine(
      learning_rate=0.01,
      final_learning_rate=0.001,
      num_ramp_steps=1000,
      final_num_steps=5000
    )
    ```
  """
  learning_rate: float
  final_learning_rate: float
  num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
@@ -31,6 +106,41 @@ class LinearRampToCosine(base_config.BaseConfig):


class LearningRate(base_config.BaseConfig):
  """
  Learning rate configuration for training.

  This configuration class allows you to specify different learning rate schedules
  for your training process.

  Attributes:
    constant (float, optional): Constant learning rate to be used throughout training.
    linear_ramp_to_cosine (LinearRampToCosine, optional): Learning rate that ramps up linearly
      and then decays following a cosine schedule.
    linear_ramp_to_constant (LinearRampToConstant, optional): Learning rate that ramps up
      linearly and then remains constant.
    piecewise_constant (PiecewiseConstant, optional): Learning rate that changes at specified
      boundaries with corresponding values.

  Example:
    To configure a learning rate schedule, create an instance of this class and set the
    attributes accordingly. For example:

    ```python
    learning_rate = LearningRate(
      linear_ramp_to_cosine=LinearRampToCosine(
        learning_rate=0.1,
        final_learning_rate=0.001,
        num_ramp_steps=1000,
        final_num_steps=5000
      )
    )
    ```

  Note:
    Each learning rate schedule attribute can be set to `None` if not needed.
  """
  constant: float = pydantic.Field(None, one_of="lr")
  linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
  linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
@@ -38,30 +148,166 @@ class LearningRate(base_config.BaseConfig):


class OptimizerAlgorithmConfig(base_config.BaseConfig):
  """
  Base class for optimizer configurations.

  This base configuration class provides a structure for specifying various optimizer-related
  settings, including the learning rate.

  Attributes:
    lr (float): The base learning rate used by the optimizer.

  Subclasses should inherit from this base class and define additional attributes specific to
  the optimizer algorithm they represent.

  Example:
    To create a custom optimizer configuration, create a subclass of this base class and
    define the necessary attributes. For example:

    ```python
    class MyOptimizerConfig(OptimizerAlgorithmConfig):
      momentum: float = pydantic.Field(0.9, description="Momentum value for SGD.")
    ```

  Note:
    This base class does not include specific optimizer settings. Subclasses should define
    the optimizer-specific attributes as needed.
  """

  lr: float
  ...


class AdamConfig(OptimizerAlgorithmConfig):
  """
  Configuration for the Adam optimizer.

  This configuration class allows you to specify the hyperparameters for the Adam optimizer.

  Attributes:
    lr (float): The learning rate for optimization.
    betas (Tuple[float, float], optional): Coefficients used for computing running averages
      of the gradient and the squared gradient. Defaults to (0.9, 0.999).
    eps (float, optional): A small constant added to the denominator for numerical stability.
      Defaults to 1e-7.

  Example:
    To configure the Adam optimizer, create an instance of this class and set the attributes
    accordingly. For example:

    ```python
    adam_optimizer = AdamConfig(
      lr=0.001,
      betas=(0.9, 0.999),
      eps=1e-8
    )
    ```

  See Also:
    [PyTorch Adam Documentation](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)
  """
  lr: float
  betas: typing.Tuple[float, float] = [0.9, 0.999]
  eps: float = 1e-7  # Numerical stability in denominator.


class SgdConfig(OptimizerAlgorithmConfig):
  """
  Configuration for the Stochastic Gradient Descent (SGD) optimizer.

  This configuration class allows you to specify the hyperparameters for the SGD optimizer.

  Attributes:
    lr (float): The learning rate for optimization.
    momentum (float, optional): The momentum factor for SGD. Defaults to 0.0.

  Example:
    To configure the SGD optimizer, create an instance of this class and set the attributes
    accordingly. For example:

    ```python
    sgd_optimizer = SgdConfig(
      lr=0.01,
      momentum=0.9
    )
    ```
  """
  lr: float
  momentum: float = 0.0


class AdagradConfig(OptimizerAlgorithmConfig):
  """
  Configuration for the Adagrad optimizer.

  This configuration class allows you to specify the hyperparameters for the Adagrad optimizer.

  Attributes:
    lr (float): The learning rate for optimization.
    eps (float, optional): A small constant added to the denominator for numerical stability.
      Defaults to 0.
  """
  lr: float
  eps: float = 0


class OptimizerConfig(base_config.BaseConfig):
  """
  Configuration for defining different optimizer algorithms and their parameters.

  This class allows you to configure various optimizer algorithms such as Adam, SGD, and Adagrad,
  along with their respective hyperparameters.

  Attributes:
    learning_rate (LearningRate): The learning rate configuration, which can include
      constant learning rates or other learning rate schedules.
    adam (AdamConfig): Configuration for the Adam optimizer.
    sgd (SgdConfig): Configuration for the Stochastic Gradient Descent (SGD) optimizer.
    adagrad (AdagradConfig): Configuration for the Adagrad optimizer.

  Example:
    ```python
    optimizer_config = OptimizerConfig(
      learning_rate=LearningRate(constant=0.001),
      adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8),
    )
    ```

  Note:
    You can specify only one of the optimizer configurations (adam, sgd, or adagrad) in an
    `OptimizerConfig` instance.

  See Also:
    - `LearningRate`: Configuration for specifying learning rates.
    - `AdamConfig`: Configuration for the Adam optimizer.
    - `SgdConfig`: Configuration for the Stochastic Gradient Descent (SGD) optimizer.
    - `AdagradConfig`: Configuration for the Adagrad optimizer.
  """
  learning_rate: LearningRate = pydantic.Field(
    None,
    description="Constant learning rates",
@@ -72,6 +318,33 @@ class OptimizerConfig(base_config.BaseConfig):


def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
  """
  Get the optimizer algorithm configuration from the given `OptimizerConfig`.

  This function extracts and returns the specific optimizer algorithm configuration
  (e.g., Adam, SGD, or Adagrad) from the provided `OptimizerConfig`.

  Args:
    optimizer_config (OptimizerConfig): The optimizer configuration object containing
      one of the optimizer algorithm configurations.

  Returns:
    Union[AdamConfig, SgdConfig, AdagradConfig]: The specific optimizer algorithm
      configuration extracted from `optimizer_config`.

  Raises:
    ValueError: If no optimizer algorithm is selected in `optimizer_config`.

  Example:
    ```python
    optimizer_config = OptimizerConfig(
      adam=AdamConfig(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
    )
    algorithm_config = get_optimizer_algorithm_config(optimizer_config)
    # `algorithm_config` will be an instance of `AdamConfig`.
    ```
  """
  if optimizer_config.adam is not None:
    return optimizer_config.adam
  elif optimizer_config.sgd is not None:
@@ -14,7 +14,35 @@ from tml.ml_logging.torch_logging import logging


def compute_lr(lr_config, step):
  """
  Compute the learning rate based on the specified learning rate configuration.

  This function calculates the learning rate according to the given configuration, which can include
  constant learning rates, piecewise constant schedules, linear ramps, and cosine annealing.

  Args:
    lr_config (LearningRate): The learning rate configuration specifying the learning rate schedule.
    step (int): The current training step or iteration.

  Returns:
    float: The computed learning rate for the current step.

  Raises:
    ValueError: If the `lr_config` is invalid or contains conflicting options.

  Example:
    ```python
    lr_schedule = LearningRate(
      piecewise_constant=PiecewiseConstant(
        learning_rate_boundaries=[1000, 2000, 3000],
        learning_rate_values=[0.1, 0.05, 0.01, 0.001]
      )
    )
    current_step = 2500
    learning_rate = compute_lr(lr_schedule, current_step)
    ```
  """
  if lr_config.constant is not None:
    return lr_config.constant
  elif lr_config.piecewise_constant is not None:
@@ -46,11 +74,54 @@ def compute_lr(lr_config, step):


class LRShim(_LRScheduler):
  """
  Learning rate scheduler shim to adjust learning rates during training.

  This class acts as a shim to apply different learning rates to individual parameter groups
  within an optimizer. It adheres to the torch.optim scheduler API and can be plugged in anywhere
  that, e.g., exponential decay can be used, allowing fine-grained control over learning rates
  based on configuration.

  Args:
    optimizer (torch.optim.Optimizer): The optimizer for which learning rates will be adjusted.
    lr_dict (Dict[str, LearningRate]): A dictionary mapping parameter group names to their
      corresponding learning rate configurations.
    last_epoch (int, optional): The index of the last epoch. Default is -1.
    verbose (bool, optional): If True, prints a warning message when learning rates are accessed
      through the deprecated `get_lr()` method. Default is False.

  Raises:
    ValueError: If the number of parameter groups in the optimizer does not match the number
      of learning rate configurations provided.

  Note:
    To obtain the last computed learning rates, please use `get_last_lr()`.

  Example:
    ```python
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    lr_schedule = {
      'main': LearningRate(constant=0.01),
      'auxiliary': LearningRate(piecewise_constant=PiecewiseConstant(
        learning_rate_boundaries=[1000, 2000],
        learning_rate_values=[0.01, 0.001]
      ))
    }
    lr_shim = LRShim(optimizer, lr_schedule)

    for epoch in range(num_epochs):
      # Train the model
      train(...)
      # Update learning rates at the end of each epoch
      lr_shim.step(epoch)

    final_lrs = lr_shim.get_last_lr()
    ```

  See Also:
    - `LearningRate`: Configuration for specifying learning rates.
    - `PiecewiseConstant`: Configuration for piecewise constant learning rate schedules.
  """

  def __init__(
    self,
@@ -95,9 +166,42 @@ def get_optimizer_class(optimizer_config: OptimizerConfig):
def build_optimizer(
  model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
  """
  Build an optimizer and learning rate scheduler based on the provided optimizer configuration.

  Args:
    model (torch.nn.Module): The PyTorch model for which the optimizer will be created.
    optimizer_config (OptimizerConfig): The optimizer configuration specifying the optimizer
      algorithm and learning rate settings.

  Returns:
    Tuple[Optimizer, _LRScheduler]: A tuple containing the optimizer and learning rate scheduler
      objects.

  Note:
    This function is intended for cases where you want the same optimizer and learning rate
    schedule for all model parameters.

  Example:
    ```python
    model = MyModel()
    optimizer_config = OptimizerConfig(
      learning_rate=LearningRate(constant=0.01),
      sgd=SgdConfig(lr=0.01, momentum=0.9)
    )
    optimizer, scheduler = build_optimizer(model, optimizer_config)

    for epoch in range(num_epochs):
      # Train the model with the optimizer
      train(model, optimizer, ...)
      # Update learning rates at the end of each epoch
      scheduler.step(epoch)
    ```

  See Also:
    - `OptimizerConfig`: Configuration for specifying optimizer settings.
    - `LRShim`: Learning rate scheduler shim for fine-grained learning rate control.
  """
  optimizer_class = get_optimizer_class(optimizer_config)
  optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
  # We're passing everything in as one group here
@@ -9,6 +9,36 @@ import pydantic


class TrainingConfig(config_mod.BaseConfig):
  """
  Configuration settings for the training process.

  This class defines various training-related settings, including the directory to save checkpoints,
  the number of training steps, logging intervals, and other training parameters.

  Attributes:
    save_dir (str): The directory where checkpoints and training artifacts will be saved.
    num_train_steps (pydantic.PositiveInt): The total number of training steps to run.
    initial_checkpoint_dir (str): The directory containing initial checkpoints (optional).
    checkpoint_every_n (pydantic.PositiveInt): Frequency of saving checkpoints during training.
    checkpoint_max_to_keep (pydantic.PositiveInt): Maximum number of checkpoints to keep (optional).
    train_log_every_n (pydantic.PositiveInt): Frequency of logging training progress.
    num_eval_steps (int): Number of evaluation steps. Use a negative value to evaluate the entire dataset.
    eval_log_every_n (pydantic.PositiveInt): Frequency of logging evaluation progress.
    eval_timeout_in_s (pydantic.PositiveFloat): Maximum time (in seconds) allowed for evaluation.
    gradient_accumulation (int): Number of replica steps to accumulate gradients (optional).

  Example:
    To configure training with checkpoints saved every 1000 steps, use the following settings:

    ```python
    TrainingConfig(
      save_dir="/tmp/model",
      num_train_steps=1000000,
      checkpoint_every_n=1000,
      train_log_every_n=1000,
    )
    ```
  """
  save_dir: str = "/tmp/model"
  num_train_steps: pydantic.PositiveInt = 1000000
  initial_checkpoint_dir: str = pydantic.Field(
@@ -32,6 +62,42 @@ class TrainingConfig(config_mod.BaseConfig):


class RecapConfig(config_mod.BaseConfig):
  """
  Configuration settings for the Recap model training process.

  This class defines the overall configuration for training a Recap model. It includes settings for
  training, model architecture, data loading, optimization, and evaluation.

  Attributes:
    training (TrainingConfig): Configuration settings for the training process.
    model (model_config.ModelConfig): Configuration settings for the Recap model architecture.
    train_data (data_config.RecapDataConfig): Configuration settings for training data.
    validation_data (Dict[str, data_config.RecapDataConfig]): Configuration settings for validation data.
    optimizer (optimizer_config.RecapOptimizerConfig): Configuration settings for optimization.
    which_metrics (Optional[str]): Optional specification of which metrics to pick.

  Example:
    To configure a Recap model training process, use the following settings:

    ```python
    RecapConfig(
      training=TrainingConfig(
        save_dir="/tmp/model",
        num_train_steps=1000000,
        checkpoint_every_n=1000,
        train_log_every_n=1000,
      ),
      model=model_config.ModelConfig(...),
      train_data=data_config.RecapDataConfig(...),
      validation_data={"dev": data_config.RecapDataConfig(...)},
      optimizer=optimizer_config.RecapOptimizerConfig(...),
    )
    ```
  """
  training: TrainingConfig = pydantic.Field(TrainingConfig())
  model: model_config.ModelConfig
  train_data: data_config.RecapDataConfig
@@ -30,6 +30,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):


class EmbeddingOptimizerConfig(base_config.BaseConfig):
  """Configuration for the optimizer used for embedding tables."""
  learning_rate: optimizer_config.LearningRate = pydantic.Field(
    None, description="learning rate scheduler for the EBC"
  )
@@ -52,6 +53,7 @@ class LargeEmbeddingsConfig(base_config.BaseConfig):


class StratifierConfig(base_config.BaseConfig):
  """Configuration for Stratifier."""
  name: str
  index: int
  value: int
@@ -87,7 +89,8 @@ class SmallEmbeddingsConfig(base_config.BaseConfig):
  the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
  serving time due to size (>>1 GB).

  This small embeddings table uses the same optimizer as the rest of the model.
  """

  tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
    ..., description="list of embedding tables"
@@ -34,6 +34,33 @@ FLAGS = flags.FLAGS


def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
  """
  Main function to run the training of a ranking model.

  This function initializes and runs the training process for a ranking model based on the provided configuration.

  Args:
    unused_argv (str): Unused argument.
    data_service_dispatcher (Optional[str]): The data service dispatcher for accessing training data (optional).

  Returns:
    None

  Raises:
    AssertionError: If the configuration or input data is not valid.

  Note:
    This function serves as the main entry point for training a ranking model. It loads the configuration, sets up
    the training environment, defines the loss function, creates the model, optimizer, and scheduler, and runs the
    training loop.

  Example:
    To run the training process, use the following command:

    ```
    python run_training.py --config_path=config.yaml
    ```
  """
  print("#" * 100)

  config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
@@ -40,6 +40,7 @@ class BatchNormConfig(base_config.BaseConfig):


class DenseLayerConfig(base_config.BaseConfig):
  """Configuration for the dense layer."""
  layer_size: pydantic.PositiveInt
  dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")

@@ -61,6 +62,7 @@ class BatchNormConfig(base_config.BaseConfig):


class DoubleNormLogConfig(base_config.BaseConfig):
  """Configuration for the double norm log transform."""
  batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
  clip_magnitude: float = pydantic.Field(
    5.0, description="Threshold to clip the normalized input values."
@@ -73,12 +75,14 @@ class Log1pAbsConfig(base_config.BaseConfig):


class ClipLog1pAbsConfig(base_config.BaseConfig):
  """Configuration for the clip log transform."""
  clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
    3e38, description="Threshold to clip the input values."
  )


class ZScoreLogConfig(base_config.BaseConfig):
  """Configuration for the z-score log transform."""
  analysis_path: str
  schema_path: str = pydantic.Field(
    None,
@@ -148,6 +152,7 @@ class DcnConfig(base_config.BaseConfig):


class MaskBlockConfig(base_config.BaseConfig):
  """Config for MaskNet block."""
  output_size: int
  reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
    None, one_of="aggregation_size"
@@ -159,6 +164,7 @@ class MaskBlockConfig(base_config.BaseConfig):


class MaskNetConfig(base_config.BaseConfig):
  """Config for MaskNet model."""
  mask_blocks: List[MaskBlockConfig]
  mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
  use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
@@ -190,6 +196,7 @@ class AffineMap(base_config.BaseConfig):


class DLRMConfig(base_config.BaseConfig):
  """Config for DLRM model."""
  bottom_mlp: MlpConfig = pydantic.Field(
    ...,
    description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
@@ -198,6 +205,7 @@ class DLRMConfig(base_config.BaseConfig):


class TaskModel(base_config.BaseConfig):
  """Configuration for a single task."""
  mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
  dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
  dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
@@ -213,6 +221,7 @@ class TaskModel(base_config.BaseConfig):


class MultiTaskType(str, enum.Enum):
  """Type of multi task architecture."""
  SHARE_NONE = "share_none"  # Tasks are separate.
  SHARE_ALL = "share_all"  # Tasks share same backbone.
  SHARE_PARTIAL = "share_partial"  # Tasks share some backbone, but have their own portions.
@@ -247,6 +256,7 @@ class ModelConfig(base_config.BaseConfig):

  @pydantic.root_validator()
  def _validate_mtl(cls, values):
    """Validate the multi task architecture."""
    if values.get("multi_task_type", None) is None:
      return values
    elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
@@ -26,7 +26,19 @@ def unsanitize(sanitized_task_name):


def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
  """
  Build a model for a single task based on the provided configuration.

  Args:
    task (model_config_mod.TaskModel): The task model configuration.
    input_shape (int): The input shape for the model.

  Returns:
    torch.nn.Module: The constructed model for the single task.

  Raises:
    ValueError: If the task configuration is not recognized.
  """
  if task.mlp_config:
    return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
  elif task.dcn_config:
@@ -38,7 +50,12 @@ def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int)


class MultiTaskRankingModel(torch.nn.Module):
  """
  Multi-task ranking model that handles multiple ranking tasks simultaneously.

  This model takes various input features and predicts rankings for multiple
  tasks using shared or separate towers.
  """

  def __init__(
    self,
@@ -47,12 +64,18 @@ class MultiTaskRankingModel(torch.nn.Module):
    data_config: RecapDataConfig,
    return_backbone: bool = False,
  ):
    """
    Constructor for the multi-task ranking model.

    Args:
      input_shapes (Mapping[str, torch.Size]): A mapping of input feature names to their shapes.
      config (ModelConfig): The model configuration.
      data_config (RecapDataConfig): The data configuration.
      return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.

    Assumptions made:
      1. Tasks specified in data config match model architecture.

    These are all validated in config.
    """
    super().__init__()

@@ -168,6 +191,23 @@ class MultiTaskRankingModel(torch.nn.Module):
    labels: Optional[torch.Tensor] = None,
    weights: Optional[torch.Tensor] = None,
  ):
    """
    Forward pass of the multi-task ranking model.

    Args:
      continuous_features (torch.Tensor): Continuous input features.
      binary_features (torch.Tensor): Binary input features.
      discrete_features (Optional[torch.Tensor], optional): Discrete input features. Defaults to None.
      sparse_features: Sparse input features. Defaults to None.
      user_embedding (Optional[torch.Tensor], optional): User embeddings. Defaults to None.
      user_eng_embedding (Optional[torch.Tensor], optional): User engagement embeddings. Defaults to None.
      author_embedding (Optional[torch.Tensor], optional): Author embeddings. Defaults to None.
      labels (Optional[torch.Tensor], optional): Target labels. Defaults to None.
      weights (Optional[torch.Tensor], optional): Weights for the loss function. Defaults to None.

    Returns:
      Dict[str, torch.Tensor]: A dictionary containing the model's outputs.
    """
    concat_dense_features = [
      self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
    ]
@@ -270,6 +310,20 @@ def create_ranking_model(
  data_config=None,
  return_backbone=False,
):
  """
  Create a ranking model based on the provided specifications and configuration.

  Args:
    data_spec: The input data specifications.
    config (config_mod.RecapConfig): The model configuration.
    device (torch.device): The device where the model should be placed.
    loss_fn (Optional[Callable], optional): A custom loss function. Defaults to None.
    data_config: The data configuration. Defaults to None.
    return_backbone (bool, optional): Whether to return the backbone network in the output. Defaults to False.

  Returns:
    torch.nn.Module: The created ranking model.
  """

  if list(config.model.tasks.values())[0].dlrm_config:
    raise NotImplementedError()
@@ -11,21 +11,52 @@ import torch


def log_transform(x: torch.Tensor) -> torch.Tensor:
  """
  Safe log transform that works across negative, zero, and positive floats.

  Args:
    x (torch.Tensor): Input tensor.

  Returns:
    torch.Tensor: Transformed tensor with log1p applied to absolute values.
  """
  return torch.sign(x) * torch.log1p(torch.abs(x))
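
As a quick check of the `log_transform` above: the transform is sign-preserving and maps `±(e - 1)` to `±1`. The standalone snippet below just evaluates it on a few values; the printed numbers in the comment are approximate.

```python
import math

import torch

def log_transform(x: torch.Tensor) -> torch.Tensor:
  return torch.sign(x) * torch.log1p(torch.abs(x))

x = torch.tensor([-(math.e - 1), 0.0, math.e - 1, 100.0])
print(log_transform(x))
# tensor([-1.0000,  0.0000,  1.0000,  4.6151])  (approximately)
```
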

class BatchNorm(torch.nn.Module):
  def __init__(self, num_features: int, config: BatchNormConfig):
    """
    Batch normalization layer.

    Args:
      num_features (int): Number of input features.
      config (BatchNormConfig): Configuration for batch normalization.
    """
    super().__init__()
    self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass through the batch normalization layer.

    Args:
      x (torch.Tensor): Input tensor.

    Returns:
      torch.Tensor: Output tensor after batch normalization.
    """
    return self.layer(x)


class LayerNorm(torch.nn.Module):
  def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
    """
    Layer normalization layer.

    Args:
      normalized_shape (Union[int, Sequence[int]]): Size or shape of the input tensor.
      config (LayerNormConfig): Configuration for layer normalization.
    """
    super().__init__()
    if config.axis != -1:
      raise NotImplementedError
@@ -38,6 +69,16 @@ class LayerNorm(torch.nn.Module):
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass through the layer normalization layer.

    Args:
      x (torch.Tensor): Input tensor.

    Returns:
      torch.Tensor: Output tensor after layer normalization.
    """
    return self.layer(x)


@@ -46,11 +87,27 @@ class Log1pAbs(torch.nn.Module):
    super().__init__()

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass that applies a log transformation to the input tensor.

    Args:
      x (torch.Tensor): Input tensor.

    Returns:
      torch.Tensor: Transformed tensor with a sign-preserving log1p applied to absolute values.
    """
    return log_transform(x)


class InputNonFinite(torch.nn.Module):
  def __init__(self, fill_value: float = 0):
    """
    Replaces non-finite (NaN and Inf) values in the input tensor with a specified fill value.

    Args:
      fill_value (float): The value to fill non-finite elements with. Default is 0.
    """
    super().__init__()

    self.register_buffer(
@@ -58,11 +115,27 @@ class InputNonFinite(torch.nn.Module):
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass that replaces non-finite values in the input tensor with the specified fill value.

    Args:
      x (torch.Tensor): Input tensor.

    Returns:
      torch.Tensor: Transformed tensor with non-finite values replaced.
    """
    return torch.where(torch.isfinite(x), x, self.fill_value)


class Clamp(torch.nn.Module):
  def __init__(self, min_value: float, max_value: float):
    """
    Applies element-wise clamping to a tensor, ensuring that values are within a specified range.

    Args:
      min_value (float): The minimum value to clamp elements to.
      max_value (float): The maximum value to clamp elements to.
    """
    super().__init__()
    # Using buffer to make sure they are on correct device (and not moved every time).
    # Will also be part of state_dict.
@@ -74,12 +147,31 @@ class Clamp(torch.nn.Module):
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass that clamps the input tensor element-wise within the specified range.

    Args:
      x (torch.Tensor): Input tensor.

    Returns:
      torch.Tensor: Transformed tensor with elements clamped within the specified range.
    """
    return torch.clamp(x, min=self.min_value, max=self.max_value)


class DoubleNormLog(torch.nn.Module):
  """
  Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features.

  Args:
    input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.
    config (DoubleNormLogConfig): Configuration for the DoubleNormLog module.

  Attributes:
    _before_concat_layers (torch.nn.Sequential): Sequential layers for batch normalization, log transformation,
      batch normalization (optional), and clamping.
    layer_norm (LayerNorm or None): Layer normalization layer for binary and continuous features (optional).
  """

  def __init__(
    self,
    input_shapes: Mapping[str, Sequence[int]],
@@ -108,6 +200,17 @@ class DoubleNormLog(torch.nn.Module):
  def forward(
    self, continuous_features: torch.Tensor, binary_features: torch.Tensor
  ) -> torch.Tensor:
    """
    Forward pass that processes continuous and binary features using batch normalization, log transformation,
    optional batch normalization (if configured), clamping, and layer normalization (if configured).

    Args:
      continuous_features (torch.Tensor): Input tensor of continuous features.
      binary_features (torch.Tensor): Input tensor of binary features.

    Returns:
      torch.Tensor: Transformed tensor containing both continuous and binary features.
    """
    x = self._before_concat_layers(continuous_features)
    x = torch.cat([x, binary_features], dim=1)
    if self.layer_norm:
@@ -118,5 +221,15 @@ class DoubleNormLog(torch.nn.Module):
def build_features_preprocessor(
  config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
):
  """
  Build a feature preprocessor module based on the provided configuration.
  Trivial right now, but we will change in the future.

  Args:
    config (FeaturizationConfig): Configuration for feature preprocessing.
    input_shapes (Mapping[str, Sequence[int]]): A mapping of input feature names to their corresponding shapes.

  Returns:
    DoubleNormLog: An instance of the DoubleNormLog feature preprocessor.
  """
  return DoubleNormLog(input_shapes, config.double_norm_log_config)
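
To make the `DoubleNormLog` description above concrete, here is a minimal, self-contained sketch of the same pipeline built from plain torch ops: batch norm, signed log1p, and clamp on the continuous features, then concatenation with the binary features and a layer norm. The clamp range and feature sizes are illustrative assumptions, not the repo's configured defaults.

```python
import torch

batch, n_cont, n_bin = 4, 6, 3
continuous = torch.randn(batch, n_cont) * 100.0
binary = torch.randint(0, 2, (batch, n_bin)).float()

x = torch.nn.BatchNorm1d(n_cont)(continuous)   # BatchNorm on continuous features
x = torch.sign(x) * torch.log1p(torch.abs(x))  # Log1pAbs / log_transform
x = torch.clamp(x, min=-5.0, max=5.0)          # Clamp (illustrative range)
x = torch.cat([x, binary], dim=1)              # append binary features
x = torch.nn.LayerNorm(n_cont + n_bin)(x)      # LayerNorm over the concatenation
print(x.shape)  # torch.Size([4, 9])
```
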
@@ -6,15 +6,84 @@ import torch


def _init_weights(module):
  """Initializes weights

  Example

  ```python
  import torch
  import torch.nn as nn

  # Define a simple linear layer
  linear_layer = nn.Linear(64, 32)

  # Initialize the weights and biases using _init_weights
  _init_weights(linear_layer)
  ```
  """
  if isinstance(module, torch.nn.Linear):
    torch.nn.init.xavier_uniform_(module.weight)
    torch.nn.init.constant_(module.bias, 0)


class MaskBlock(torch.nn.Module):
  """
  MaskBlock module in a mask-based neural network.

  This module represents a MaskBlock, which applies a masking operation to the input data and then
  passes it through a hidden layer. It is typically used as a building block within a MaskNet.

  Args:
    mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
    input_dim (int): Dimensionality of the input data.
    mask_input_dim (int): Dimensionality of the mask input.

  Example:
    To create and use a MaskBlock within a MaskNet, follow these steps:

    ```python
    # Define the configuration for the MaskBlock
    mask_block_config = MaskBlockConfig(
      input_layer_norm=True,  # Apply input layer normalization
      reduction_factor=0.5    # Reduce input dimensionality by 50%
    )

    # Create an instance of the MaskBlock
    mask_block = MaskBlock(mask_block_config, input_dim=64, mask_input_dim=32)

    # Generate input tensors
    input_data = torch.randn(batch_size, 64)
    mask_input = torch.randn(batch_size, 32)

    # Perform a forward pass through the MaskBlock
    output = mask_block(input_data, mask_input)
    ```

  Note:
    The `MaskBlock` module applies layer normalization to the input if specified, followed by a masking
    operation that combines the input and mask input. Then, it passes the result through a hidden layer
    with optional dimensionality reduction.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(
    self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
  ) -> None:
    """
    Initializes the MaskBlock module.

    Args:
      mask_block_config (config.MaskBlockConfig): Configuration for the MaskBlock.
      input_dim (int): Dimensionality of the input data.
      mask_input_dim (int): Dimensionality of the mask input.

    Returns:
      None
    """
    super(MaskBlock, self).__init__()
    self.mask_block_config = mask_block_config
    output_size = mask_block_config.output_size
@@ -42,6 +111,16 @@ class MaskBlock(torch.nn.Module):
    self._layer_norm = torch.nn.LayerNorm(output_size)

  def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
    """
    Performs a forward pass through the MaskBlock.

    Args:
      net (torch.Tensor): Input data tensor.
      mask_input (torch.Tensor): Mask input tensor.

    Returns:
      torch.Tensor: Output tensor of the MaskBlock.
    """
    if self._input_layer_norm:
      net = self._input_layer_norm(net)
    hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
@@ -49,7 +128,60 @@ class MaskBlock(torch.nn.Module):


class MaskNet(torch.nn.Module):
  """
  MaskNet module in a mask-based neural network.

  This module represents a MaskNet, which consists of multiple MaskBlocks. It can be used to
  create mask-based neural networks with parallel or stacked MaskBlocks.

  Args:
    mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
    in_features (int): Dimensionality of the input data.

  Example:
    To create and use a MaskNet, you can follow these steps:

    ```python
    # Define the configuration for the MaskNet
    mask_net_config = MaskNetConfig(
      use_parallel=True,                    # Use parallel MaskBlocks
      mlp=MlpConfig(layer_sizes=[128, 64])  # Optional MLP on the outputs
    )

    # Create an instance of the MaskNet
    mask_net = MaskNet(mask_net_config, in_features=64)

    # Generate input tensors
    input_data = torch.randn(batch_size, 64)

    # Perform a forward pass through the MaskNet
    outputs = mask_net(input_data)

    # Access the output and shared layer
    output = outputs["output"]
    shared_layer = outputs["shared_layer"]
    ```

  Note:
    The `MaskNet` module allows you to create mask-based neural networks with parallel or stacked
    MaskBlocks. You can also optionally apply an MLP to the outputs for further processing.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
    """
    Initializes the MaskNet module.

    Args:
      mask_net_config (config.MaskNetConfig): Configuration for the MaskNet.
      in_features (int): Dimensionality of the input data.

    Returns:
      None
    """
    super().__init__()
    self.mask_net_config = mask_net_config
    mask_blocks = []
@@ -77,6 +209,15 @@ class MaskNet(torch.nn.Module):
    self.shared_size = total_output_mask_blocks

  def forward(self, inputs: torch.Tensor):
    """
    Performs a forward pass through the MaskNet.

    Args:
      inputs (torch.Tensor): Input data tensor.

    Returns:
      torch.Tensor: Output tensor of the MaskNet.
    """
    if self.mask_net_config.use_parallel:
      mask_outputs = []
      for mask_layer in self._mask_blocks:
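
The `use_parallel` flag above is the main structural switch in MaskNet. The sketch below shows the two wirings schematically; `ToyBlock` stands in for `MaskBlock` (it reuses the `hidden(net * mask(mask_input))` pattern from the forward pass shown earlier), and the optional MLP head is omitted, so this is illustrative rather than the repo's exact module.

```python
import torch

class ToyBlock(torch.nn.Module):
  """Stand-in for MaskBlock: a hidden layer applied to `net`, gated by a mask of `mask_input`."""

  def __init__(self, dim: int):
    super().__init__()
    self._mask_layer = torch.nn.Linear(dim, dim)
    self._hidden_layer = torch.nn.Linear(dim, dim)

  def forward(self, net: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
    return self._hidden_layer(net * self._mask_layer(mask_input))

def masknet_forward(blocks, inputs: torch.Tensor, use_parallel: bool) -> torch.Tensor:
  if use_parallel:
    # Parallel: every block sees the raw inputs, and the outputs are concatenated.
    return torch.cat([block(inputs, inputs) for block in blocks], dim=-1)
  # Stacked: each block refines the previous output, masked by the raw inputs.
  net = inputs
  for block in blocks:
    net = block(net, inputs)
  return net

blocks = torch.nn.ModuleList(ToyBlock(8) for _ in range(3))
x = torch.randn(2, 8)
print(masknet_forward(blocks, x, use_parallel=True).shape)   # torch.Size([2, 24])
print(masknet_forward(blocks, x, use_parallel=False).shape)  # torch.Size([2, 8])
```
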
@@ -7,13 +7,83 @@ from absl import logging


def _init_weights(module):
  """Initializes weights

  Example
  -------
  ```python
  import torch
  import torch.nn as nn

  # Define a simple linear layer
  linear_layer = nn.Linear(64, 32)

  # Initialize the weights and biases using _init_weights
  _init_weights(linear_layer)
  ```
  """
  if isinstance(module, torch.nn.Linear):
    torch.nn.init.xavier_uniform_(module.weight)
    torch.nn.init.constant_(module.bias, 0)


class Mlp(torch.nn.Module):
  """
  Multi-Layer Perceptron (MLP) feedforward neural network module in PyTorch.

  This module defines an MLP with customizable layers and activation functions. It is suitable for various
  applications such as deep learning for tabular data, feature extraction, and more.

  Args:
    in_features (int): The number of input features or input dimensions.
    mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.

  Example:
    To create an instance of the `Mlp` module and use it for forward passes, you can follow these steps:

    ```python
    # Define the configuration for the MLP
    mlp_config = MlpConfig(
      layer_sizes=[128, 64],       # Specify the sizes of hidden layers
      batch_norm=True,             # Enable batch normalization
      dropout=0.2,                 # Apply dropout with a rate of 0.2
      final_layer_activation=True  # Apply ReLU activation to the final layer
    )

    # Create an instance of the MLP module
    mlp_model = Mlp(in_features=input_dim, mlp_config=mlp_config)

    # Generate an input tensor
    input_tensor = torch.randn(batch_size, input_dim)

    # Perform a forward pass through the MLP
    outputs = mlp_model(input_tensor)

    # Access the output and shared layer
    output = outputs["output"]
    shared_layer = outputs["shared_layer"]
    ```

  Note:
    The `Mlp` class allows you to create customizable MLP architectures by specifying the layer sizes,
    enabling batch normalization and dropout, and choosing the activation function for the final layer.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(self, in_features: int, mlp_config: MlpConfig):
    """
    Initializes the Mlp module.

    Args:
      in_features (int): The number of input features or input dimensions.
      mlp_config (MlpConfig): Configuration object specifying the MLP's architecture.

    Returns:
      None
    """
    super().__init__()
    self._mlp_config = mlp_config
    input_size = in_features
@@ -42,6 +112,15 @@ class Mlp(torch.nn.Module):
    self.layers.apply(_init_weights)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Performs a forward pass through the MLP.

    Args:
      x (torch.Tensor): Input tensor of shape (batch_size, in_features).

    Returns:
      torch.Tensor: Output tensor of the MLP.
    """
    net = x
    for i, layer in enumerate(self.layers):
      net = layer(net)
@@ -51,8 +130,21 @@ class Mlp(torch.nn.Module):

  @property
  def shared_size(self):
    """
    Returns the size of the shared layer in the MLP.

    Returns:
      int: Size of the shared layer.
    """
    return self._mlp_config.layer_sizes[-1]

  @property
  def out_features(self):
    """
    Returns the number of output features from the MLP.

    Returns:
      int: Number of output features.
    """
    return self._mlp_config.layer_sizes[-1]
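
For readers wondering how the `MlpConfig` knobs named above (layer_sizes, batch_norm, dropout, final_layer_activation) typically turn into torch layers, here is a hedged, self-contained sketch. The exact layer ordering and options inside the repo's `Mlp` may differ; this only illustrates the configuration surface described in the docstring.

```python
import torch

def build_mlp(in_features: int, layer_sizes, batch_norm=False, dropout=None,
              final_layer_activation=True) -> torch.nn.Sequential:
  layers = []
  input_size = in_features
  for i, size in enumerate(layer_sizes):
    layers.append(torch.nn.Linear(input_size, size))
    is_final = i == len(layer_sizes) - 1
    if not is_final or final_layer_activation:
      if batch_norm:
        layers.append(torch.nn.BatchNorm1d(size))
      layers.append(torch.nn.ReLU())
      if dropout:
        layers.append(torch.nn.Dropout(dropout))
    input_size = size
  return torch.nn.Sequential(*layers)

mlp = build_mlp(32, [128, 64], batch_norm=True, dropout=0.2)
print(mlp(torch.randn(4, 32)).shape)  # torch.Size([4, 64])
```
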
@@ -5,6 +5,53 @@ from absl import logging


class ModelAndLoss(torch.nn.Module):
  """
  PyTorch module that combines a neural network model and loss function.

  This module wraps a neural network model and facilitates the forward pass through the model
  while also calculating the loss based on the model's predictions and provided labels.

  Args:
    model: The torch module to wrap.
    loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
    stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
      for metrics stratification. Each stratifier config includes the name and index of discrete features
      to emit for stratification.

  Example:
    To use `ModelAndLoss` in a PyTorch training loop, you can create an instance of it and pass your model
    and loss function as arguments:

    ```python
    # Create a neural network model
    model = YourNeuralNetworkModel()

    # Define a loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Create an instance of ModelAndLoss
    model_and_loss = ModelAndLoss(model, loss_fn)

    # Generate a batch of training data (e.g., RecapBatch)
    batch = generate_training_batch()

    # Perform a forward pass through the model and calculate the loss
    loss, outputs = model_and_loss(batch)

    # You can now backpropagate and optimize using the computed loss
    loss.backward()
    optimizer.step()
    ```

  Note:
    The `ModelAndLoss` class simplifies the process of running forward passes through a model and
    calculating loss, making it easier to integrate the model into your training loop. Additionally,
    it supports the addition of stratifiers for metrics stratification, if needed.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(
    self,
    model,
@@ -12,10 +59,13 @@ class ModelAndLoss(torch.nn.Module):
    stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
  ) -> None:
    """
    Initializes the ModelAndLoss module.

    Args:
      model: The torch module to wrap.
      loss_fn (Callable): Function for calculating the loss, which should accept logits and labels.
      stratifiers (Optional[List[embedding_config_mod.StratifierConfig]]): A list of stratifier configurations
        for metrics stratification.
    """
    super().__init__()
    self.model = model
@@ -2,11 +2,57 @@ import torch


class NumericCalibration(torch.nn.Module):
  """
  Numeric calibration module for adjusting probability scores.

  This module scales probability scores to correct for imbalanced datasets, where positive and negative samples
  may be underrepresented or have different ratios. It is designed to be used as a component in a neural network
  for tasks such as binary classification.

  Args:
    pos_downsampling_rate (float): The downsampling rate for positive samples.
    neg_downsampling_rate (float): The downsampling rate for negative samples.

  Example:
    To use `NumericCalibration` in a PyTorch model, you can create an instance of it and apply it to probability
    scores like this:

    ```python
    # Create a NumericCalibration instance with downsampling rates
    calibration = NumericCalibration(pos_downsampling_rate=0.1, neg_downsampling_rate=0.2)

    # Generate probability scores (e.g., from a neural network)
    raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])

    # Apply numeric calibration to adjust the probabilities
    calibrated_probs = calibration(raw_probs)

    # The `calibrated_probs` now contains the adjusted probability scores
    ```

  Note:
    The `NumericCalibration` module is used to adjust probability scores to account for differences in
    the number of positive and negative samples in a dataset. It can help improve the calibration of
    probability estimates in imbalanced classification problems.

  Warning:
    This class is intended for internal use within neural network architectures and should not be
    directly accessed or modified by external code.
  """

  def __init__(
    self,
    pos_downsampling_rate: float,
    neg_downsampling_rate: float,
  ):
    """
    Initializes the NumericCalibration module.

    Args:
      pos_downsampling_rate (float): The downsampling rate for positive samples.
      neg_downsampling_rate (float): The downsampling rate for negative samples.
    """
    super().__init__()

    # Using buffer to make sure they are on correct device (and not moved every time).
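
The correction that the `NumericCalibration` docstring describes is the standard adjustment for class-dependent downsampling: scale the positive mass by the positive sampling rate and the negative mass by the negative sampling rate, then renormalize. Whether the repo's forward pass uses exactly this expression is an assumption; the snippet below is an illustrative reference only.

```python
import torch

def calibrate(probs: torch.Tensor, pos_rate: float, neg_rate: float) -> torch.Tensor:
  # Undo training-time downsampling: each class' probability mass is re-weighted
  # by the rate at which that class was kept, then renormalized.
  return probs * pos_rate / (probs * pos_rate + (1.0 - probs) * neg_rate)

raw_probs = torch.tensor([0.8, 0.6, 0.2, 0.9])
print(calibrate(raw_probs, pos_rate=0.1, neg_rate=0.2))
# tensor([0.6667, 0.4286, 0.1111, 0.8182])  (approximately)
```
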
@@ -9,12 +9,60 @@ import pydantic


class RecapAdamConfig(base_config.BaseConfig):
  """
  Configuration settings for the Adam optimizer used in Recap.

  Args:
    beta_1 (float): Momentum term (default: 0.9).
    beta_2 (float): Exponential weighted decay factor (default: 0.999).
    epsilon (float): Numerical stability in the denominator (default: 1e-7).

  Example:
    To define an Adam optimizer configuration for Recap, use:

    ```python
    adam_config = RecapAdamConfig(beta_1=0.9, beta_2=0.999, epsilon=1e-7)
    ```

  Note:
    This class configures the parameters of the Adam optimizer, which is commonly used for optimizing neural networks.

  Warning:
    This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
  """

  beta_1: float = 0.9  # Momentum term.
  beta_2: float = 0.999  # Exponential weighted decay factor.
  epsilon: float = 1e-7  # Numerical stability in denominator.


class MultiTaskLearningRates(base_config.BaseConfig):
  """
  Configuration settings for multiple learning rates in Recap.

  Args:
    tower_learning_rates (Dict[str, optimizers_config_mod.LearningRate]): Learning rates for different towers of the model.
    backbone_learning_rate (optimizers_config_mod.LearningRate): Learning rate for the model's backbone (default: None).

  Example:
    To define multiple learning rates for different towers in Recap, use:

    ```python
    multi_task_lr = MultiTaskLearningRates(
      tower_learning_rates={
        'task1': learning_rate1,
        'task2': learning_rate2,
      },
      backbone_learning_rate=backbone_lr,
    )
    ```

  Note:
    This class allows specifying different learning rates for different parts of the model, including task-specific towers and the backbone.

  Warning:
    This class is intended for internal use within Recap and should not be directly accessed or modified by external code.
  """

  tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
    description="Learning rates for different towers of the model."
  )
@@ -25,6 +73,30 @@ class MultiTaskLearningRates(base_config.BaseConfig):


class RecapOptimizerConfig(base_config.BaseConfig):
  """
  Configuration settings for the Recap optimizer.

  Args:
    multi_task_learning_rates (MultiTaskLearningRates): Multiple learning rates for different tasks (optional).
    single_task_learning_rate (optimizers_config_mod.LearningRate): Learning rate for a single task (optional).
    adam (RecapAdamConfig): Configuration settings for the Adam optimizer.

  Example:
    To define an optimizer configuration for training with Recap, use:

    ```python
    optimizer_config = RecapOptimizerConfig(
      multi_task_learning_rates=multi_task_lr,
      single_task_learning_rate=single_task_lr,
      adam=adam_config,
    )
    ```

  Warning:
    This class is intended for internal use to configure the optimizer settings within Recap and should not be
    directly accessed by external code.
  """

  multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
    None, description="Multiple learning rates for different tasks.", one_of="lr"
  )
@@ -23,12 +23,30 @@ _DENSE_EMBEDDINGS = "dense_ebc"


class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
  """
  A shim to get learning rates into a LRScheduler.

  This class adheres to the torch.optim scheduler API and can be plugged into any scheduler that supports
  learning rate schedules, such as exponential decay.

  Args:
    optimizer: The optimizer to which this scheduler is applied.
    lr_dict (Dict[str, config.LearningRate]): A dictionary mapping group names to learning rate configurations.
    emb_learning_rate: The learning rate for embeddings (optional).
    last_epoch (int): The index of the last epoch (default: -1).
    verbose (bool): If True, print warnings for deprecated functions (default: False).

  Example:
    To create a RecapLRShim scheduler for an optimizer and a dictionary of learning rates, use:

    ```python
    scheduler = RecapLRShim(optimizer, lr_dict, emb_learning_rate)
    ```

  Warning:
    This class is intended for internal use to handle learning rate scheduling within Recap training and should not
    be directly accessed by external code.
  """

  def __init__(
    self,
@@ -80,15 +98,25 @@ def build_optimizer(
  optimizer_config: config.OptimizerConfig,
  emb_optimizer_config: None = None,  # Optional[EmbeddingOptimizerConfig] = None,
):
  """
  Build an optimizer and scheduler for training.

  Args:
    model: The torch model, possibly with DDP/DMP.
    optimizer_config (config.OptimizerConfig): Configuration settings for the optimizer.
    emb_optimizer_config: Configuration settings for embedding optimization (optional).

  Returns:
    torch.optim.Optimizer: The optimizer for training.
    RecapLRShim: The learning rate scheduler for the optimizer.

  Example:
    To build an optimizer and scheduler for training, use:

    ```python
    optimizer, scheduler = build_optimizer(model, optimizer_config, emb_optimizer_config)
    ```
  """
  optimizer_fn = functools.partial(
    torch.optim.Adam,
    lr=_DEFAULT_LR,
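
Since `RecapLRShim` follows the standard `torch.optim` scheduler API, it is stepped exactly like any built-in scheduler. The loop below is a generic, self-contained illustration that uses a stock scheduler in its place; the real training loop, batching, and loss computation live elsewhere in the repo.

```python
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Stand-in for the RecapLRShim instance returned by build_optimizer above.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

for step in range(3):
  optimizer.zero_grad()
  loss = model(torch.randn(4, 8)).pow(2).mean()
  loss.backward()
  optimizer.step()
  scheduler.step()
  print(step, scheduler.get_last_lr())
```
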
@@ -4,6 +4,17 @@ import pydantic


class TwhinDataConfig(base_config.BaseConfig):
  """
  Configuration for Twhin model training data.

  Args:
    data_root (str): The root directory for the training data.
    per_replica_batch_size (pydantic.PositiveInt): Batch size per replica.
    global_negatives (int): The number of global negatives.
    in_batch_negatives (int): The number of in-batch negatives.
    limit (pydantic.PositiveInt): The limit on the number of data points to use.
    offset (pydantic.PositiveInt, optional): The offset to start reading from. Default is None.
  """

  data_root: str
  per_replica_batch_size: pydantic.PositiveInt
  global_negatives: int
@@ -4,6 +4,16 @@ from tml.projects.twhin.data.edges import EdgesDataset


def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
  """
  Create a dataset for Twhin model training.

  Args:
    data_config (TwhinDataConfig): The data configuration for the dataset.
    model_config (TwhinModelConfig): The model configuration containing embeddings and relations.

  Returns:
    EdgesDataset: The dataset for Twhin model training.
  """
  tables = model_config.embeddings.tables
  table_sizes = {table.name: table.num_embeddings for table in tables}
  relations = model_config.relations
@@ -15,6 +15,15 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@dataclass
class EdgeBatch(DataclassBatch):
  """
  Batch data structure for edge-based models.

  Args:
    nodes (KeyedJaggedTensor): A KeyedJaggedTensor containing node embeddings.
    labels (torch.Tensor): Tensor containing labels.
    rels (torch.Tensor): Tensor containing relation information.
    weights (torch.Tensor): Tensor containing weights.
  """

  nodes: KeyedJaggedTensor
  labels: torch.Tensor
  rels: torch.Tensor
@@ -22,6 +31,18 @@ class EdgeBatch(DataclassBatch):


class EdgesDataset(Dataset):
  """
  Dataset for edge-based models.

  Args:
    file_pattern (str): The file pattern for the dataset.
    table_sizes (Dict[str, int]): A dictionary of table names and their sizes.
    relations (List[Relation]): A list of relations between tables.
    lhs_column_name (str): The name of the left-hand-side column.
    rhs_column_name (str): The name of the right-hand-side column.
    rel_column_name (str): The name of the relation column.
    **dataset_kwargs: Additional keyword arguments for the parent Dataset class.
  """

  rng = np.random.default_rng()

  def __init__(
@@ -56,6 +77,15 @@ class EdgesDataset(Dataset):
    super().__init__(file_pattern=file_pattern, **dataset_kwargs)

  def pa_to_batch(self, batch: pa.RecordBatch):
    """
    Converts a pyarrow RecordBatch to an EdgeBatch.

    Args:
      batch (pa.RecordBatch): A pyarrow RecordBatch containing data.

    Returns:
      EdgeBatch: An EdgeBatch containing node embeddings, labels, relations, and weights.
    """
    lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
    rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
    rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
@@ -74,6 +104,14 @@ class EdgesDataset(Dataset):
  ) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:

    """Process edges that contain lhs index, rhs index, relation index.

    Args:
      lhs (torch.Tensor): Tensor containing left-hand-side indices.
      rhs (torch.Tensor): Tensor containing right-hand-side indices.
      rel (torch.Tensor): Tensor containing relation indices.

    Returns:
      Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]: A KeyedJaggedTensor and relation index pairs.

    Example:

    ```
@@ -147,6 +185,12 @@ class EdgesDataset(Dataset):
    return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)

  def to_batches(self):
    """
    Converts data to batches.

    Yields:
      pa.RecordBatch: A pyarrow RecordBatch containing data.
    """
    ds = super().to_batches()
    batch_size = self._dataset_kwargs["batch_size"]
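
To make the `EdgeBatch` fields concrete, here is a toy construction for two edges over two hypothetical tables. The table names, id values, and lengths layout are invented for illustration; in the real pipeline `pa_to_batch` derives all of this from the parquet columns, and the `KeyedJaggedTensor(keys=..., values=..., lengths=...)` call mirrors the one shown above.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from tml.projects.twhin.data.edges import EdgeBatch

# Two edges over tables "user" and "tweet": one node id per (table, edge) slot.
nodes = KeyedJaggedTensor(
  keys=["user", "tweet"],
  values=torch.tensor([3, 7, 11, 42]),  # user ids for both edges, then tweet ids
  lengths=torch.tensor([1, 1, 1, 1]),
)
batch = EdgeBatch(
  nodes=nodes,
  labels=torch.ones(2),
  rels=torch.tensor([0, 0]),  # index into the configured relations
  weights=torch.ones(2),
)
print(batch.labels, batch.rels)
```
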
@@ -23,6 +23,25 @@ RELATIONS = [


def test_gen():
  """Test function for generating edge-based datasets and dataloaders.

  This function generates a synthetic dataset and tests the creation of an `EdgesDataset`
  instance and a dataloader for it.

  The test includes the following steps:
  1. Create synthetic data with left-hand-side (lhs), right-hand-side (rhs), and relation (rel) columns.
  2. Write the synthetic data to a Parquet file.
  3. Create an `EdgesDataset` instance with the Parquet file pattern, table sizes, relations, and batch size.
  4. Initialize the local file system for the dataset.
  5. Create a dataloader for the dataset and retrieve the first batch.
  6. Assert that the labels in the batch are positive.
  7. Verify that the positive examples in the batch match the expected values.

  This function serves as a test case for the data generation and dataset creation process.

  Raises:
    AssertionError: If any of the test assertions fail.
  """
  import os
  import tempfile
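
A minimal sketch of steps 1-2 from the `test_gen` description above: build a tiny edge table with lhs/rhs/rel columns and write it to a Parquet file. The column names and values here are placeholders; the real test goes on to construct an `EdgesDataset` over the written file.

```python
import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({
  "lhs": [0, 1, 2],
  "rhs": [10, 11, 12],
  "rel": [0, 1, 0],
})
with tempfile.TemporaryDirectory() as tmp:
  path = os.path.join(tmp, "edges.parquet")
  pq.write_table(table, path)
  print(pq.read_table(path).num_rows)  # 3
```
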
@@ -10,8 +10,29 @@ from pydantic import validator


class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
  """
  Configuration class for Twhin model embeddings.

  This class inherits from LargeEmbeddingsConfig and ensures that the embedding dimensions and data types
  for all tables in the Twhin model embeddings configuration match.

  Attributes:
    tables (List[TableConfig]): A list of table configurations for the model's embeddings.
  """

  @validator("tables")
  def embedding_dims_match(cls, tables):
    """
    Validate that embedding dimensions and data types match for all tables.

    Args:
      tables (List[TableConfig]): List of table configurations.

    Returns:
      List[TableConfig]: The list of validated table configurations.

    Raises:
      AssertionError: If embedding dimensions or data types do not match.
    """
    embedding_dim = tables[0].embedding_dim
    data_type = tables[0].data_type
    for table in tables:
@@ -21,11 +42,26 @@ class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):


class Operator(str, enum.Enum):
  """
  Enumeration of operator types.

  This enumeration defines different types of operators that can be applied to Twhin model relations.
  """

  TRANSLATION = "translation"


class Relation(pydantic.BaseModel):
  """
  Configuration class for graph relationships in the Twhin model.

  This class defines properties and operators for graph relationships in the Twhin model.

  Attributes:
    name (str): The name of the relationship.
    lhs (str): The name of the entity on the left-hand side of the relation.
    rhs (str): The name of the entity on the right-hand side of the relation.
    operator (Operator): The transformation operator to apply to the left-hand side embedding before dot product.
  """

  name: str = pydantic.Field(..., description="Relationship name.")
  lhs: str = pydantic.Field(
@@ -42,12 +78,35 @@ class Relation(pydantic.BaseModel):


class TwhinModelConfig(base_config.BaseConfig):
  """
  Configuration class for the Twhin model.

  This class defines configuration options specific to the Twhin model.

  Attributes:
    embeddings (TwhinEmbeddingsConfig): Configuration for the model's embeddings.
    relations (List[Relation]): List of graph relationship configurations.
    translation_optimizer (OptimizerConfig): Configuration for the optimizer used for translation.
  """

  embeddings: TwhinEmbeddingsConfig
  relations: typing.List[Relation]
  translation_optimizer: OptimizerConfig

  @validator("relations", each_item=True)
  def valid_node_types(cls, relation, values, **kwargs):
    """
    Validate that the specified node types in relations are valid table names in embeddings.

    Args:
      relation (Relation): A single relation configuration.
      values (dict): The values dictionary containing the "embeddings" configuration.

    Returns:
      Relation: The validated relation configuration.

    Raises:
      AssertionError: If the specified node types are not valid table names in embeddings.
    """
    table_names = [table.name for table in values["embeddings"].tables]
    assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
    assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
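
A hedged sketch of the relation wiring that the config classes above describe. The table names here ("user", "tweet") are invented; `valid_node_types` requires that `lhs`/`rhs` match the names of tables declared in `TwhinEmbeddingsConfig`, and `Relation`/`Operator` are assumed to be imported from the config module shown above.

```python
# Assumes Relation and Operator are imported from the Twhin config module defined above.
follows = Relation(
  name="follows",
  lhs="user",
  rhs="user",
  operator=Operator.TRANSLATION,
)
favorites = Relation(name="favorites", lhs="user", rhs="tweet", operator=Operator.TRANSLATION)

# These would then feed TwhinModelConfig:
#   TwhinModelConfig(embeddings=..., relations=[follows, favorites], translation_optimizer=...)
```
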
@@ -14,6 +14,28 @@ from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backwa


class TwhinModel(nn.Module):
  """
  Twhin model for graph-based entity embeddings and translation.

  This class defines the Twhin model, which is used for learning embeddings of entities in a graph
  and applying translations to these embeddings based on graph relationships.

  Args:
    model_config (TwhinModelConfig): Configuration for the Twhin model.
    data_config (TwhinDataConfig): Configuration for the data used by the model.

  Attributes:
    batch_size (int): The batch size used for training.
    table_names (List[str]): Names of tables in the model's embeddings.
    large_embeddings (LargeEmbeddings): LargeEmbeddings instance for entity embeddings.
    embedding_dim (int): Dimensionality of entity embeddings.
    num_tables (int): Number of tables in the model's embeddings.
    in_batch_negatives (int): Number of in-batch negative samples to use during training.
    global_negatives (int): Number of global negative samples to use during training.
    num_relations (int): Number of graph relationships in the model.
    all_trans_embs (torch.nn.Parameter): Parameter tensor for translation embeddings.
  """

  def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
    super().__init__()
    self.batch_size = data_config.per_replica_batch_size
@@ -31,7 +53,17 @@ class TwhinModel(nn.Module):
    )

  def forward(self, batch: EdgeBatch):
    """
    Forward pass of the Twhin model.

    Args:
      batch (EdgeBatch): Input batch containing graph edge information.

    Returns:
      dict: A dictionary containing model output with "logits" and "probabilities".
        - "logits" (torch.Tensor): Logit scores.
        - "probabilities" (torch.Tensor): Sigmoid probabilities.
    """
    # B x D
    trans_embs = self.all_trans_embs.data[batch.rels]

@@ -98,6 +130,18 @@ class TwhinModel(nn.Module):


def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
  """
  Apply optimizers to the Twhin model's embeddings.

  This function applies optimizers to the embeddings of the Twhin model based on the provided configuration.

  Args:
    model (TwhinModel): The Twhin model to apply optimizers to.
    model_config (TwhinModelConfig): Configuration for the Twhin model.

  Returns:
    TwhinModel: The Twhin model with optimizers applied to its embeddings.
  """
  for table in model_config.embeddings.tables:
    optimizer_class = get_optimizer_class(table.optimizer)
    optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
@@ -124,10 +168,14 @@ class TwhinModelAndLoss(torch.nn.Module):
    device: torch.device,
  ) -> None:
    """
    Initialize a TwhinModelAndLoss module.

    Args:
      model: The torch module to wrap.
      loss_fn: A function for calculating loss, should accept logits and labels.
      data_config: Configuration for Twhin data.
      device: The torch device to use for calculations.
    """
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn
@@ -136,14 +184,21 @@ class TwhinModelAndLoss(torch.nn.Module):
    self.device = device

  def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
    """
    Run the model forward and calculate the loss according to the given loss_fn.

    NOTE: The input signature here needs to be a Pipelineable object for
    prefetching purposes during training using torchrec's pipeline. However
    the underlying model signature needs to be exportable to onnx, requiring
    generic python types. see https://pytorch.org/docs/stable/onnx.html#types

    Args:
      batch ("RecapBatch"): The input batch for model inference.

    Returns:
      Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple containing the loss tensor and a dictionary of
        additional outputs including logits, labels, and weights.
    """
    outputs = self.model(batch)
    logits = outputs["logits"]
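
A schematic sketch of the translation-based scoring that the docstrings above describe: the relation's translation vector is added to the left-hand-side embedding before a dot product with the right-hand-side embedding. Shapes and the final sigmoid mirror the "logits"/"probabilities" outputs named in `TwhinModel.forward`; the exact negative-sampling and embedding-lookup details are omitted, so treat this as illustrative rather than the repo's implementation.

```python
import torch

batch, dim = 4, 16
lhs_emb = torch.randn(batch, dim)
rhs_emb = torch.randn(batch, dim)
trans_embs = torch.randn(batch, dim)  # per-edge translation, selected by batch.rels

logits = ((lhs_emb + trans_embs) * rhs_emb).sum(dim=-1)
probabilities = torch.sigmoid(logits)
print(logits.shape, probabilities.shape)  # torch.Size([4]) torch.Size([4])
```
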
@ -18,6 +18,12 @@ EMB_DIM = 128
|
|||||||
|
|
||||||
|
|
||||||
def twhin_model_config() -> TwhinModelConfig:
|
def twhin_model_config() -> TwhinModelConfig:
|
||||||
|
"""
|
||||||
|
Create a configuration for the Twhin model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TwhinModelConfig: The Twhin model configuration.
|
||||||
|
"""
|
||||||
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
|
sgd_config_0 = OptimizerConfig(sgd=SgdConfig(lr=0.01))
|
||||||
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
|
||||||
|
|
||||||
@ -52,6 +58,12 @@ def twhin_model_config() -> TwhinModelConfig:


def twhin_data_config() -> TwhinDataConfig:
  """
  Create a configuration for the Twhin data.

  Returns:
    TwhinDataConfig: The Twhin data configuration.
  """
  data_config = TwhinDataConfig(
    data_root="/",
    per_replica_batch_size=10,
@ -65,6 +77,15 @@ def twhin_data_config() -> TwhinDataConfig:


def test_twhin_model():
  """
  Test the Twhin model creation and optimization.

  This function creates a Twhin model using the specified configuration and tests its optimization.
  It also checks the device placement of model parameters.

  Returns:
    None
  """
  model_config = twhin_model_config()
  loss_fn = F.binary_cross_entropy_with_logits
@ -84,6 +105,23 @@ def test_twhin_model():


def test_unequal_dims():
  """
  Test that validation rejects unequal embedding dimensions in TwhinEmbeddingsConfig.

  This test checks that the validation logic raises a `ValidationError` when the embedding
  dimensions in the `TwhinEmbeddingsConfig` are not equal for all tables. It proceeds in three steps:
  1. Create two embedding configurations with different embedding dimensions.
  2. Attempt to create a `TwhinEmbeddingsConfig` instance with the unequal embedding dimensions.
  3. Assert that a `ValidationError` is raised, indicating that embedding dimensions must match.

  Raises:
    AssertionError: If the expected `ValidationError` is not raised.
  """
  sgd_config_1 = OptimizerConfig(sgd=SgdConfig(lr=0.02))
  sgd_config_2 = OptimizerConfig(sgd=SgdConfig(lr=0.05))
  table0 = EmbeddingBagConfig(
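The same validation behaviour can be reproduced in isolation with a small pydantic model. The sketch below is illustrative only: the `TableConfig` and `EmbeddingsConfigSketch` names and the validator are hypothetical stand-ins, not the repository's classes.

```python
from typing import List

import pytest
from pydantic import BaseModel, ValidationError, validator


class TableConfig(BaseModel):
  name: str
  embedding_dim: int


class EmbeddingsConfigSketch(BaseModel):
  tables: List[TableConfig]

  @validator("tables")
  def embedding_dims_must_match(cls, tables):
    # Reject configurations whose tables disagree on embedding dimension.
    if len({t.embedding_dim for t in tables}) > 1:
      raise ValueError("embedding dimensions must match across tables")
    return tables


def test_unequal_dims_sketch():
  # Mismatched dims (128 vs 64) should be rejected at construction time.
  with pytest.raises(ValidationError):
    EmbeddingsConfigSketch(
      tables=[
        TableConfig(name="user", embedding_dim=128),
        TableConfig(name="item", embedding_dim=64),
      ]
    )
```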
@ -15,6 +15,14 @@ TRANSLATION_OPT_KEY = "operator_opt"


def _lr_from_config(optimizer_config):
  """Get the learning rate from an optimizer configuration.

  Args:
    optimizer_config: Optimizer configuration.

  Returns:
    Learning rate from the optimizer configuration.
  """
  if optimizer_config.learning_rate is not None:
    return optimizer_config.learning_rate
  else:
@ -26,13 +34,13 @@ def _lr_from_config(optimizer_config):

def build_optimizer(model: TwhinModel, config: TwhinModelConfig):
  """Builds an optimizer for a Twhin model combining the embeddings optimizer with an optimizer for per-relation translations.

  Args:
    model: TwhinModel to build optimizer for.
    config: TwhinModelConfig for model.

  Returns:
    Optimizer for model.
  """
  translation_optimizer_fn = functools.partial(
    get_optimizer_class(config.translation_optimizer),
    **get_optimizer_algorithm_config(config.translation_optimizer).dict(),
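The `functools.partial` pattern above is easy to demonstrate with plain `torch.optim` classes; the mapping and helper below are illustrative stand-ins for the repository's config helpers.

```python
import functools

import torch

# Hypothetical mapping from an algorithm name in a config to an optimizer class.
_OPTIMIZER_CLASSES = {"sgd": torch.optim.SGD, "adam": torch.optim.Adam}


def make_optimizer_fn(algorithm: str, **hyperparams):
  """Bind the optimizer class and its hyperparameters once; apply to params later."""
  return functools.partial(_OPTIMIZER_CLASSES[algorithm], **hyperparams)


model = torch.nn.Linear(4, 1)
opt_fn = make_optimizer_fn("sgd", lr=0.05)
optimizer = opt_fn(model.parameters())  # equivalent to torch.optim.SGD(params, lr=0.05)
```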
@ -37,6 +37,12 @@ def run(
  all_config: TwhinConfig,
  save_dir: Optional[str] = None,
):
  """Run the training process for TwhinModel.

  Args:
    all_config (TwhinConfig): The configuration for the entire Twhin model.
    save_dir (str, optional): The directory where model checkpoints will be saved. Defaults to None.
  """
  train_dataset = create_dataset(all_config.train_data, all_config.model)

  if env.is_reader():
@ -80,6 +86,11 @@ def run(


def main(argv):
  """Main entry point for the Twhin training script.

  Args:
    argv: Command-line arguments.
  """
  logging.info("Starting")

  logging.info(f"parsing config from {FLAGS.config_yaml_path}...")
@ -25,14 +25,58 @@ from tml.ml_logging.torch_logging import logging


class _Reader(pa.flight.FlightServerBase):
  """
  Distributed reader flight server wrapping a dataset.

  This class implements a Flight server that wraps a dataset, allowing clients to retrieve data
  from the dataset over the Flight protocol. It is designed to be used in a distributed environment
  for efficient data access.

  Args:
    location (str): The location of the Flight server.
    ds (Dataset): The dataset to be wrapped by the Flight server.

  Attributes:
    _location (str): The location of the Flight server.
    _ds (Dataset): The dataset wrapped by the Flight server.

  Methods:
    do_get(_, __): Handles Flight requests for data retrieval.

  Note:
    Flight is an Apache Arrow project that provides a framework for efficient data transfer.
    This class allows clients to retrieve data from the dataset using Flight.
  """

  def __init__(self, location: str, ds: "Dataset"):
    """
    Initialize a new _Reader instance.

    Args:
      location (str): The location of the Flight server.
      ds (Dataset): The dataset to be wrapped by the Flight server.
    """
    super().__init__(location=location)
    self._location = location
    self._ds = ds

  def do_get(self, _, __):
    """
    Handle Flight requests for data retrieval.

    This method retrieves data from the wrapped dataset and provides it to clients over the Flight protocol.

    Args:
      _: Unused argument.
      __: Unused argument.

    Returns:
      pa.flight.RecordBatchStream: A stream of record batches containing data from the dataset.

    Note:
      An updated schema (to account for column selection) must be given to the stream.
    """
    # NB: An updated schema (to account for column selection) has to be given the stream.
    schema = next(iter(self._ds.to_batches())).schema
    batches = self._ds.to_batches()
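As background on the Flight pattern used here, a self-contained PyArrow Flight server of the same shape looks roughly like this. A toy in-memory table stands in for the wrapped dataset and the port is arbitrary; this is not the repository's class.

```python
import pyarrow as pa
import pyarrow.flight as flight


class ToyReader(flight.FlightServerBase):
  """Minimal Flight server that streams a fixed Arrow table to any do_get request."""

  def __init__(self, location: str, table: pa.Table):
    super().__init__(location=location)
    self._table = table

  def do_get(self, context, ticket):
    # RecordBatchStream accepts a Table (or RecordBatchReader) and streams it to the client.
    return flight.RecordBatchStream(self._table)


if __name__ == "__main__":
  table = pa.table({"x": [1, 2, 3], "y": [0.1, 0.2, 0.3]})
  ToyReader(location="grpc://0.0.0.0:8815", table=table).serve()  # blocks and serves requests
```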
@ -46,13 +90,49 @@ class _Reader(pa.flight.FlightServerBase):


class Dataset(torch.utils.data.IterableDataset):
  """
  A PyTorch IterableDataset wrapping a Parquet dataset for efficient data loading.

  This class enables efficient loading of data from Parquet files using PyArrow.
  It is designed to be used as an IterableDataset in PyTorch for training and inference.

  Args:
    file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
    **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
      Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset
      for more details.

  Attributes:
    LOCATION (str): The default location for the Flight server used for data distribution.
    _file_pattern (str): The glob pattern specifying Parquet files in the dataset.
    _fs: The filesystem object used for file operations.
    _dataset_kwargs (dict): Additional keyword arguments passed to PyArrow's `to_batches` method.
    _files (list): A list of file paths matching the glob pattern.
    _schema (pa.Schema): The schema of the Parquet dataset.

  Methods:
    serve(): Start serving the dataset using a Flight server.
    to_batches(): Generate batches of data from the Parquet dataset.
    pa_to_batch(batch: pa.RecordBatch) -> DataclassBatch: Convert a Parquet RecordBatch to a custom data batch.
    dataloader(remote: bool = False): Create a PyTorch DataLoader for iterating through the dataset.

  Note:
    This class efficiently loads data from Parquet files using PyArrow, and it can be used with PyTorch
    to create DataLoader instances for training or inference.
  """

  LOCATION = "grpc://0.0.0.0:2222"

  def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
    """
    Initialize a new Dataset instance. Specify batch size and column to select for.

    Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.

    Args:
      file_pattern (str): A glob pattern specifying the Parquet files to include in the dataset.
      **dataset_kwargs: Additional keyword arguments passed to PyArrow's `to_batches` method.
    """
    self._file_pattern = file_pattern
    self._fs = infer_fs(self._file_pattern)
    self._dataset_kwargs = dataset_kwargs
@ -64,16 +144,25 @@ class Dataset(torch.utils.data.IterableDataset):
    self._validate_columns()

  def _validate_columns(self):
    """
    Validate the specified columns against the dataset schema.

    Raises:
      Exception: If any specified columns are not found in the dataset schema.
    """
    columns = set(self._dataset_kwargs.get("columns", []))
    wrong_columns = set(columns) - set(self._schema.names)
    if wrong_columns:
      raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")

  def serve(self):
    """Start serving the dataset using a Flight server."""
    self.reader = _Reader(location=self.LOCATION, ds=self)
    self.reader.serve()

  def _create_dataset(self):
    """Create a PyArrow dataset for data retrieval."""
    return pads.dataset(
      source=random.sample(self._files, len(self._files))[0],
      format="parquet",
@ -100,9 +189,33 @@ class Dataset(torch.utils.data.IterableDataset):
  @abc.abstractmethod
  def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
    """
    Convert a Parquet RecordBatch to a custom data batch.

    Args:
      batch (pa.RecordBatch): A batch of data from the Parquet dataset.

    Returns:
      DataclassBatch: A custom data batch used in PyTorch training.

    Raises:
      NotImplementedError: This method must be implemented in derived classes.
    """
    raise NotImplementedError

  def dataloader(self, remote: bool = False):
    """
    Create a PyTorch DataLoader for iterating through the dataset.

    Args:
      remote (bool, optional): If True, create a remote DataLoader using Flight for distributed training.

    Returns:
      DataLoader: A PyTorch DataLoader for iterating through the dataset.

    Note:
      If `remote` is True, a remote DataLoader is created for distributed training using Flight.
    """
    if not remote:
      return map(self.pa_to_batch, self.to_batches())
    readers = get_readers(2)
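A stripped-down, hedged version of this Parquet-backed IterableDataset pattern, without the Flight and filesystem plumbing, is sketched below. The class name and the tensor conversion are illustrative, and numeric columns are assumed.

```python
import pyarrow.dataset as pads
import torch


class ParquetIterable(torch.utils.data.IterableDataset):
  """Streams record batches from Parquet files and converts each one to torch tensors."""

  def __init__(self, files, batch_size: int = 1024, columns=None):
    self._files = files
    self._batch_size = batch_size
    self._columns = columns

  def _to_batch(self, record_batch):
    # Copy each selected column out of Arrow and into a torch tensor.
    return {
      name: torch.from_numpy(column.to_numpy(zero_copy_only=False))
      for name, column in zip(record_batch.schema.names, record_batch.columns)
    }

  def __iter__(self):
    ds = pads.dataset(self._files, format="parquet")
    for record_batch in ds.to_batches(batch_size=self._batch_size, columns=self._columns):
      yield self._to_batch(record_batch)
```

Iterating this object directly corresponds to the `dataloader(remote=False)` path above, which simply maps `pa_to_batch` over `to_batches()`.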
@ -117,6 +230,25 @@ GRPC_OPTIONS = [


def get_readers(num_readers_per_worker: int):
  """
  Get Flight readers for distributed data loading.

  This function retrieves Flight readers for distributed data loading in a PyTorch environment.

  Args:
    num_readers_per_worker (int): The number of Flight readers to retrieve per worker.

  Returns:
    List[pa.RecordBatchFileReader]: A list of Flight readers for distributed data loading.

  Note:
    Flight readers are used to fetch data in a distributed manner for efficient data loading.

  Example:
    To obtain Flight readers, use the following code:

    >>> readers = get_readers(num_readers_per_worker=2)
  """
  addresses = env.get_flight_server_addresses()

  readers = []
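For completeness, the client side of such a reader is a few lines of generic PyArrow Flight code; the address and the empty ticket payload are illustrative, not the repository's helper.

```python
import pyarrow.flight as flight


def read_all_batches(address: str = "grpc://0.0.0.0:2222"):
  """Fetch every record batch served by a Flight endpoint and return them as a list."""
  client = flight.connect(address)
  reader = client.do_get(flight.Ticket(b""))  # the ticket contents are server-defined
  return [chunk.data for chunk in reader]     # each chunk.data is a pyarrow.RecordBatch
```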
@ -21,6 +21,16 @@ import torch.distributed as dist


def maybe_start_dataset_service():
  """
  Start the dataset service if readers are available and required dependencies are met.

  This function checks if readers are available and if the required TensorFlow version is >= 2.5.
  If both conditions are met and the current environment is the dispatcher or reader, it starts
  the TensorFlow dataset service.

  Raises:
    Exception: If the required TensorFlow version is not met (>= 2.5).
  """
  if not env.has_readers():
    return
|
|||||||
def register_dataset(
|
def register_dataset(
|
||||||
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
|
dataset: tf.data.Dataset, dataset_service: str, compression: Optional[str] = "AUTO"
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Register a dataset with the distributed dataset service.
|
||||||
|
|
||||||
|
This function registers a dataset with the distributed dataset service and broadcasts the dataset ID
|
||||||
|
and job name to all processes in the distributed environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (tf.data.Dataset): The dataset to be registered.
|
||||||
|
dataset_service (str): The name of the dataset service.
|
||||||
|
compression (Optional[str]): The compression type for the dataset (default is "AUTO").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[int, str]: A tuple containing the dataset ID and job name.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This function should be called on the rank 0 process.
|
||||||
|
|
||||||
|
"""
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
dataset_id = _register_dataset(
|
dataset_id = _register_dataset(
|
||||||
service=dataset_service,
|
service=dataset_service,
|
||||||
@ -82,6 +110,23 @@ def distribute_from_dataset_id(
  compression: Optional[str] = "AUTO",
  prefetch: Optional[int] = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
  """
  Distribute a dataset from a registered dataset ID.

  This function consumes a dataset from the distributed dataset service using the provided dataset ID
  and job name. It also supports prefetching for improved performance.

  Args:
    dataset_service (str): The name of the dataset service.
    dataset_id (int): The ID of the dataset to be consumed.
    job_name (Optional[str]): The name of the job associated with the dataset (optional).
    compression (Optional[str]): The compression type for the dataset (default is "AUTO").
    prefetch (Optional[int]): The number of elements to prefetch (default is tf.data.experimental.AUTOTUNE).

  Returns:
    tf.data.Dataset: The distributed dataset.
  """
  logging.info(f"rank{dist.get_rank()}: Consuming dds job with {dataset_id}, {job_name}")
  dataset = _from_dataset_id(
    processing_mode="parallel_epochs",
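The two wrappers above build on the register/consume pair of the tf.data service API. A hedged end-to-end sketch follows; the dispatcher address and job name are illustrative.

```python
import tensorflow as tf

service = "grpc://dispatcher.example:5050"  # illustrative dispatcher address

# Producer side (rank 0): register the dataset and remember its ID.
dataset = tf.data.Dataset.range(1000).batch(32)
dataset_id = tf.data.experimental.service.register_dataset(service=service, dataset=dataset)

# Consumer side (every rank): attach to the same registered dataset by ID,
# sharing a job name so elements are split across consumers instead of duplicated.
consumed = tf.data.experimental.service.from_dataset_id(
  processing_mode="parallel_epochs",
  service=service,
  dataset_id=dataset_id,
  element_spec=dataset.element_spec,
  job_name="shared_job",
).prefetch(tf.data.AUTOTUNE)
```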
@ -97,15 +142,28 @@ def distribute_from_dataset_id(


def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset:
  """
  Distribute a TensorFlow dataset for Torch-compatible and distributed training-aware consumption.

  This function is used to distribute a dataset in a distributed training environment. It performs the
  following steps:
  - On the rank 0 process, it registers the given dataset with the distributed dataset service.
  - It broadcasts the job name and dataset ID to all rank processes.
  - All rank processes then consume the same dataset from the distributed dataset service.

  Args:
    dataset (tf.data.Dataset): The TensorFlow dataset to be distributed.

  Returns:
    tf.data.Dataset: The distributed TensorFlow dataset.

  Note:
    - If there are no reader processes in the distributed environment, the original dataset is returned
      without any distribution.
    - This function is intended for use in distributed training environments to prevent out-of-memory (OOM)
      issues caused by each rank process trying to serve one job.
  """
  if not env.has_readers():
    return dataset
  dataset_service = env.get_dds()
@ -12,6 +12,17 @@ import torch


def create_dataset(tmpdir):
  """
  Create a mock dataset for testing.

  This function creates a mock dataset using PyArrow and Parquet for testing purposes.

  Args:
    tmpdir: A temporary directory where the dataset will be created.

  Returns:
    MockDataset: A mock dataset for testing.
  """

  table = pa.table(
    {
@ -34,6 +45,14 @@ def create_dataset(tmpdir):


def test_dataset(tmpdir):
  """
  Test the created dataset.

  This function tests the created mock dataset and checks if it behaves as expected.

  Args:
    tmpdir: A temporary directory used for testing.
  """
  ds = create_dataset(tmpdir)
  batch = next(iter(ds.dataloader(remote=False)))
  assert batch.batch_size == 2
@ -46,6 +65,14 @@ def test_dataset(tmpdir):
  reason="Multiprocessing doesn't work on github yet.",
)
def test_distributed_dataset(tmpdir):
  """
  Test the distributed dataset.

  This function tests the distributed version of the mock dataset using multiprocessing.

  Args:
    tmpdir: A temporary directory used for testing.
  """
  MOCK_ENV = {"TEMP_SLURM_NUM_READERS": "1"}

  def _client():
151
reader/utils.py
@ -11,11 +11,55 @@ import torch


def roundrobin(*iterables):
  """
  Iterate through provided iterables in a round-robin fashion.

  This function takes multiple iterables and returns an iterator that yields elements from
  each iterable in a round-robin manner. It continues cycling through the iterables until
  all of them are exhausted.

  Adapted from https://docs.python.org/3/library/itertools.html.

  Args:
    *iterables: One or more iterable objects to iterate through.

  Yields:
    Elements from the provided iterables in a round-robin fashion. The iterator simply
    stops once every iterable is exhausted.

  Example:
    ```python
    iterable1 = [1, 2, 3]
    iterable2 = ['a', 'b', 'c']
    iterable3 = [0.1, 0.2, 0.3]

    for item in roundrobin(iterable1, iterable2, iterable3):
        print(item)

    # Output:
    # 1
    # a
    # 0.1
    # 2
    # b
    # 0.2
    # 3
    # c
    # 0.3
    ```

  Note:
    - If one of the provided iterables is shorter than the others, the function will
      continue iterating through the remaining iterables until all are exhausted.
    - If an iterable raises an exception during iteration, a warning message is logged,
      and the function continues with the next iterable.

  See Also:
    - `itertools.cycle`: A function that repeatedly cycles through elements of an iterable.
    - `itertools.islice`: A function to slice an iterable to limit the number of iterations.
  """
  num_active = len(iterables)
  nexts = itertools.cycle(iter(it).__next__ for it in iterables)
  while num_active:
@ -35,6 +79,48 @@ def roundrobin(*iterables):


def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]):
  """
  Monitor the speed and progress of data loading using a data loader.

  This function iterates through a data loader for a specified number of steps or until
  the end of the data loader is reached, periodically logging progress information.

  Args:
    data_loader: The data loader to monitor.
    max_steps: The maximum number of steps to iterate through the data loader.
    frequency: The frequency (in steps) at which to log progress.
    peek (optional): If specified, it indicates the frequency (in steps) at which to log
      batch contents for inspection.

  Example:
    ```python
    import torch
    from torch.utils.data import DataLoader

    # Create a data loader (replace with your own DataLoader configuration)
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Monitor data loading speed and progress
    speed_check(data_loader, max_steps=1000, frequency=50, peek=500)
    ```

  Note:
    - The function logs information about elapsed time, the number of examples processed,
      and the processing speed in examples per second.
    - If `peek` is provided, batch contents will be logged for inspection at the specified
      frequency.

  See Also:
    - `torch.utils.data.DataLoader`: PyTorch's data loading utility for batching and
      iterating through datasets.
  """
  num_examples = 0
  prev = time.perf_counter()
  for idx, batch in enumerate(data_loader):
@ -57,11 +143,66 @@ def speed_check(data_loader, max_steps: int, frequency: int, peek: Optional[int]


def pa_to_torch(array: pa.array) -> torch.Tensor:
  """
  Convert a PyArrow Array to a PyTorch Tensor.

  Args:
    array (pa.array): The PyArrow Array to convert.

  Returns:
    torch.Tensor: A PyTorch Tensor containing the data from the input PyArrow Array.

  Example:
    ```python
    import pyarrow as pa
    import torch

    # Create a PyArrow Array
    arrow_array = pa.array([1, 2, 3])

    # Convert it to a PyTorch Tensor
    torch_tensor = pa_to_torch(arrow_array)
    ```
  """
  return torch.from_numpy(array.to_numpy())


def create_default_pa_to_batch(schema) -> DataclassBatch:
  """
  Create a function that converts a PyArrow RecordBatch to a custom DataclassBatch with imputed values for missing data.

  Args:
    schema (pa.Schema): The PyArrow schema describing the data structure of the RecordBatch.

  Returns:
    callable: A function that takes a PyArrow RecordBatch as input and returns a custom DataclassBatch.

  Example:
    ```python
    import pandas as pd
    import pyarrow as pa
    from dataclass_batch import DataclassBatch

    # Define a PyArrow schema
    schema = pa.schema([
      ("feature1", pa.float64()),
      ("feature2", pa.int64()),
      ("label", pa.int64()),
    ])

    # Create the conversion function
    pa_to_batch = create_default_pa_to_batch(schema)

    # Create a PyArrow RecordBatch
    record_batch = pa.RecordBatch.from_pandas(pd.DataFrame({
      "feature1": [1.0, 2.0, None],
      "feature2": [10, 20, 30],
      "label": [0, 1, None],
    }))

    # Convert the RecordBatch to a custom DataclassBatch
    custom_batch = pa_to_batch(record_batch)
    ```
  """
  _CustomBatch = DataclassBatch.from_schema("DefaultBatch", schema=schema)

  def get_imputation_value(pa_type):