wiseaidev 2cc1abedd7 add & config mypy on common package
Signed-off-by: wiseaidev <business@wiseai.dev>
2023-04-01 14:38:59 +03:00

104 lines
2.7 KiB
Python

"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import (
annotations,
)
import abc
import dataclasses
from collections import (
UserDict,
)
from dataclasses import (
dataclass,
)
from typing import (
Any,
Dict,
List,
TypeVar,
)
import torch
from torchrec.streamable import (
Pipelineable,
)
_KT = TypeVar("_KT") # key type
_VT = TypeVar("_VT") # value type
class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict[str, Any]:
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
def obj2str(v: Any) -> str:
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@property
def batch_size(self) -> int:
for tensor in self.as_dict().values():
if tensor is None:
continue
if not isinstance(tensor, torch.Tensor):
continue
return tensor.shape[0]
raise Exception("Could not determine batch size from tensors.")
@dataclass
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls) -> List[str]:
return list(cls.__dataclass_fields__.keys())
def as_dict(self) -> Dict[str, Any]:
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
if hasattr(self, feature_name)
}
@staticmethod
def from_schema(name: str, schema: Any) -> type:
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
bases=(DataclassBatch,),
)
@staticmethod
def from_fields(name: str, fields: Dict[str, Any]) -> type:
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
bases=(DataclassBatch,),
)
class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
def as_dict(self) -> Dict[str, Any]:
return self