2023-03-31 20:05:14 +02:00
|
|
|
"""Extension of torchrec.datasets.utils.Batch to cover any dataset.
"""
|
|
|
|
# flake8: noqa
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Dict
|
|
|
|
import abc
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import dataclasses
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torchrec.streamable import Pipelineable
|
|
|
|
|
|
|
|
|
|
|
|
class BatchBase(Pipelineable, abc.ABC):
    """Base class for batches used in training pipelines.

    Subclasses only implement :meth:`as_dict`. Device transfer, stream
    recording, memory pinning, ``repr`` and ``batch_size`` are all derived
    from the ``{feature_name: value}`` mapping it returns.

    Values in that mapping may be ``None`` (e.g. an optional feature that was
    not set). The transfer helpers pass such entries through unchanged.

    Subclasses must be constructible as ``cls(**mapping)``, because
    :meth:`to` and :meth:`pin_memory` rebuild the batch that way.
    """

    @abc.abstractmethod
    def as_dict(self) -> Dict:
        """Return the batch contents as a ``{feature_name: value}`` mapping.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError

    def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
        """Return a new batch with every non-``None`` value moved to ``device``.

        Args:
            device: The target device.
            non_blocking: Forwarded to each value's ``.to()`` call.

        Returns:
            A new instance of the same class. ``None`` values stay ``None``.
        """
        args = {}
        for feature_name, feature_value in self.as_dict().items():
            if feature_value is not None:
                # Unset features (None) have nothing to transfer; keep them as-is.
                feature_value = feature_value.to(device=device, non_blocking=non_blocking)
            args[feature_name] = feature_value
        return self.__class__(**args)

    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        """Call ``record_stream(stream)`` on every non-``None`` value in the batch.

        Args:
            stream: The CUDA stream to record on each value.
        """
        for feature_value in self.as_dict().values():
            if feature_value is not None:
                feature_value.record_stream(stream)

    def pin_memory(self) -> BatchBase:
        """Return a new batch whose non-``None`` values are in pinned memory.

        Returns:
            A new instance of the same class. ``None`` values stay ``None``.
        """
        args = {}
        for feature_name, feature_value in self.as_dict().items():
            if feature_value is not None:
                feature_value = feature_value.pin_memory()
            args[feature_name] = feature_value
        return self.__class__(**args)

    def __repr__(self) -> str:
        """Return one ``name: summary,`` line per feature.

        The summary is the value's ``size()`` when it has a callable ``size``,
        else its ``length_per_key()`` when it has one, else its own ``repr``.
        The last case covers ``None`` and other plain values.
        """

        def obj2str(v) -> str:
            if callable(getattr(v, "size", None)):
                return f"{v.size()}"
            if hasattr(v, "length_per_key"):
                # Keyed/jagged containers (presumably torchrec KeyedJaggedTensor) expose per-key lengths.
                return f"{v.length_per_key()}"
            # Fall back to repr instead of raising AttributeError (e.g. for None).
            return repr(v)

        return "\n".join(f"{k}: {obj2str(v)}," for k, v in self.as_dict().items())

    @property
    def batch_size(self) -> int:
        """Return the leading dimension of the first tensor value with at least one dim.

        Non-tensor values (``None``, jagged containers, ...) and 0-dim tensors
        are skipped, since they carry no batch dimension.

        Raises:
            ValueError: If no value is a tensor with at least one dimension.
        """
        for value in self.as_dict().values():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                return value.shape[0]
        raise ValueError("Could not determine batch size from tensors.")
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DataclassBatch(BatchBase):
    """Batch whose features are the fields of a dataclass.

    Concrete batch types are normally generated at runtime with
    :meth:`from_schema` or :meth:`from_fields`. Both give every field a
    ``None`` default, so any feature may be left unset.

    Note: ``@dataclass`` (``repr=True`` by default) generates a ``__repr__``
    on this class and on every subclass created through ``make_dataclass``.
    That generated method takes precedence over the inherited
    ``BatchBase.__repr__``.
    """

    @classmethod
    def feature_names(cls):
        """Return the names of this dataclass's fields, in declaration order.

        ``dataclasses.fields`` is used rather than ``__dataclass_fields__``
        because the latter also contains ``ClassVar`` and ``InitVar``
        pseudo-fields. Those are not per-instance features. A ``ClassVar`` in
        particular is not an ``__init__`` parameter, so round-tripping it
        through ``cls(**self.as_dict())`` would raise ``TypeError``.

        Returns:
            List[str]: The feature (field) names.
        """
        return [f.name for f in dataclasses.fields(cls)]

    def as_dict(self) -> Dict:
        """Return ``{feature_name: value}`` for every feature field present on this instance.

        Returns:
            Dict: Mapping from field name to its current value (possibly ``None``).
        """
        return {
            feature_name: getattr(self, feature_name)
            for feature_name in self.feature_names()
            if hasattr(self, feature_name)
        }

    @staticmethod
    def from_schema(name: str, schema) -> type:
        """Create a ``DataclassBatch`` subclass with one tensor field per schema column.

        Every column is assumed to be representable as a ``torch.Tensor``.

        Args:
            name: Name of the generated class.
            schema: Object exposing a ``names`` iterable of column names
                (presumably a ``pyarrow.Schema`` — confirm against callers).

        Returns:
            Type[DataclassBatch]: A new dataclass type whose fields all default to ``None``.
        """
        return dataclasses.make_dataclass(
            cls_name=name,
            # `column` is deliberately distinct from `name` (the class name) to avoid shadowing.
            fields=[(column, torch.Tensor, dataclasses.field(default=None)) for column in schema.names],
            bases=(DataclassBatch,),
        )

    @staticmethod
    def from_fields(name: str, fields: dict) -> type:
        """Create a ``DataclassBatch`` subclass from a ``{field_name: type}`` mapping.

        Args:
            name: Name of the generated class.
            fields: Mapping of field names to their type annotations.

        Returns:
            Type[DataclassBatch]: A new dataclass type whose fields all default to ``None``.
        """
        return dataclasses.make_dataclass(
            cls_name=name,
            fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
            bases=(DataclassBatch,),
        )
|
|
|
|
|
|
|
|
|
|
|
|
class DictionaryBatch(BatchBase, dict):
    """Batch that is itself a ``dict`` mapping feature names to values.

    Because the instance already is the mapping, no field declarations are
    needed. Any keyword arguments given to the constructor become features.
    """

    def as_dict(self) -> Dict:
        """Hand back this very object as the feature mapping.

        No copy is made, so mutating the returned mapping mutates the batch.

        Returns:
            Dict: ``self``.
        """
        mapping: Dict = self
        return mapping
|