add & config mypy on common package

Signed-off-by: wiseaidev <business@wiseai.dev>
This commit is contained in:
wiseaidev 2023-04-01 13:30:34 +03:00
parent 78c3235eee
commit 2cc1abedd7
17 changed files with 219 additions and 85 deletions

View File

@ -1,22 +1,40 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset. """Extension of torchrec.dataset.utils.Batch to cover any dataset.
""" """
# flake8: noqa # flake8: noqa
from __future__ import annotations from __future__ import (
from typing import Dict annotations,
)
import abc import abc
from dataclasses import dataclass
import dataclasses import dataclasses
from collections import (
UserDict,
)
from dataclasses import (
dataclass,
)
from typing import (
Any,
Dict,
List,
TypeVar,
)
import torch import torch
from torchrec.streamable import Pipelineable from torchrec.streamable import (
Pipelineable,
)
_KT = TypeVar("_KT") # key type
_VT = TypeVar("_VT") # value type
class BatchBase(Pipelineable, abc.ABC): class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def as_dict(self) -> Dict: def as_dict(self) -> Dict[str, Any]:
raise NotImplementedError raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False): def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
args = {} args = {}
for feature_name, feature_value in self.as_dict().items(): for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking) args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@ -26,14 +44,14 @@ class BatchBase(Pipelineable, abc.ABC):
for feature_value in self.as_dict().values(): for feature_value in self.as_dict().values():
feature_value.record_stream(stream) feature_value.record_stream(stream)
def pin_memory(self): def pin_memory(self) -> BatchBase:
args = {} args = {}
for feature_name, feature_value in self.as_dict().items(): for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory() args[feature_name] = feature_value.pin_memory()
return self.__class__(**args) return self.__class__(**args)
def __repr__(self) -> str: def __repr__(self) -> str:
def obj2str(v): def obj2str(v: Any) -> str:
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}" return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()]) return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@ -52,10 +70,10 @@ class BatchBase(Pipelineable, abc.ABC):
@dataclass @dataclass
class DataclassBatch(BatchBase): class DataclassBatch(BatchBase):
@classmethod @classmethod
def feature_names(cls): def feature_names(cls) -> List[str]:
return list(cls.__dataclass_fields__.keys()) return list(cls.__dataclass_fields__.keys())
def as_dict(self): def as_dict(self) -> Dict[str, Any]:
return { return {
feature_name: getattr(self, feature_name) feature_name: getattr(self, feature_name)
for feature_name in self.feature_names() for feature_name in self.feature_names()
@ -63,7 +81,7 @@ class DataclassBatch(BatchBase):
} }
@staticmethod @staticmethod
def from_schema(name: str, schema): def from_schema(name: str, schema: Any) -> type:
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor.""" """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass( return dataclasses.make_dataclass(
cls_name=name, cls_name=name,
@ -72,7 +90,7 @@ class DataclassBatch(BatchBase):
) )
@staticmethod @staticmethod
def from_fields(name: str, fields: dict): def from_fields(name: str, fields: Dict[str, Any]) -> type:
return dataclasses.make_dataclass( return dataclasses.make_dataclass(
cls_name=name, cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()], fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@ -80,6 +98,6 @@ class DataclassBatch(BatchBase):
) )
class DictionaryBatch(BatchBase, dict): class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
def as_dict(self) -> Dict: def as_dict(self) -> Dict[str, Any]:
return self return self

View File

@ -1 +1,4 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot from tml.common.checkpointing.snapshot import (
Snapshot,
get_checkpoint,
)

View File

@ -1,12 +1,24 @@
import os import os
import time import time
from typing import Any, Dict, List, Optional from typing import (
Any,
from tml.ml_logging.torch_logging import logging Dict,
from tml.common.filesystem import infer_fs, is_gcs_fs Generator,
List,
Optional,
)
import torchsnapshot import torchsnapshot
from tml.common.filesystem import (
infer_fs,
is_gcs_fs,
)
from tml.ml_logging.torch_logging import (
logging,
)
from torch import (
FloatTensor,
)
DONE_EVAL_SUBDIR = "evaled_by" DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://" GCS_PREFIX = "gs://"
@ -25,7 +37,7 @@ class Snapshot:
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0) self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property @property
def step(self): def step(self) -> int:
return self.state["extra_state"]["step"] return self.state["extra_state"]["step"]
@step.setter @step.setter
@ -33,14 +45,14 @@ class Snapshot:
self.state["extra_state"]["step"] = step self.state["extra_state"]["step"] = step
@property @property
def walltime(self): def walltime(self) -> float:
return self.state["extra_state"]["walltime"] return self.state["extra_state"]["walltime"]
@walltime.setter @walltime.setter
def walltime(self, walltime: float) -> None: def walltime(self, walltime: float) -> None:
self.state["extra_state"]["walltime"] = walltime self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot": def save(self, global_step: int) -> "PendingSnapshot": # type: ignore
"""Saves checkpoint with given global_step.""" """Saves checkpoint with given global_step."""
path = os.path.join(self.save_dir, str(global_step)) path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.") logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@ -98,7 +110,7 @@ class Snapshot:
cls, cls,
embedding_snapshot: torchsnapshot.Snapshot, embedding_snapshot: torchsnapshot.Snapshot,
snapshot_emb_name: str, snapshot_emb_name: str,
weight_tensor, weight_tensor: FloatTensor,
) -> None: ) -> None:
"""Loads pretrained embedding from the snapshot to the model. """Loads pretrained embedding from the snapshot to the model.
Utilise partial lodaing meachanism from torchsnapshot. Utilise partial lodaing meachanism from torchsnapshot.
@ -128,11 +140,11 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE") return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
def is_done_eval(checkpoint_path: str, eval_partition: str): def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition)) return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition)) # type: ignore[attr-defined]
def mark_done_eval(checkpoint_path: str, eval_partition: str): def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition)) infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
@ -140,7 +152,9 @@ def step_from_checkpoint(checkpoint: str) -> int:
return int(os.path.basename(checkpoint)) return int(os.path.basename(checkpoint))
def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800): def checkpoints_iterator(
save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
) -> Generator[str, None, None]:
"""Simplified equivalent of tf.train.checkpoints_iterator. """Simplified equivalent of tf.train.checkpoints_iterator.
Args: Args:
@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
""" """
def _poll(last_checkpoint: Optional[str] = None): def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
stop_time = time.time() + timeout stop_time = time.time() + timeout
while True: while True:
_checkpoint_path = get_checkpoint(save_dir, missing_ok=True) _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)

View File

@ -4,7 +4,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
def maybe_setup_tensorflow(): def maybe_setup_tensorflow() -> None:
try: try:
import tensorflow as tf import tensorflow as tf
except ImportError: except ImportError:

View File

@ -1 +1,5 @@
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs from tml.common.filesystem.util import (
infer_fs,
is_gcs_fs,
is_local_fs,
)

View File

@ -2,7 +2,9 @@
Mostly a test that it returns an object Mostly a test that it returns an object
""" """
from tml.common.filesystem import infer_fs from tml.common.filesystem import (
infer_fs,
)
def test_infer_fs(): def test_infer_fs():

View File

@ -1,13 +1,18 @@
"""Utilities for interacting with the file systems.""" """Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem from typing import (
import gcsfs Union,
)
import gcsfs
from fsspec.implementations.local import (
LocalFileSystem,
)
GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1) GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem() LOCAL_FS = LocalFileSystem()
def infer_fs(path: str): def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
if path.startswith("gs://"): if path.startswith("gs://"):
return GCS_FS return GCS_FS
elif path.startswith("hdfs://"): elif path.startswith("hdfs://"):
@ -17,9 +22,9 @@ def infer_fs(path: str):
return LOCAL_FS return LOCAL_FS
def is_local_fs(fs): def is_local_fs(fs: LocalFileSystem) -> bool:
return fs == LOCAL_FS return fs == LOCAL_FS
def is_gcs_fs(fs): def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
return fs == GCS_FS return fs == GCS_FS

View File

@ -1,17 +1,28 @@
"""For logging model weights.""" """For logging model weights."""
import itertools import itertools
from typing import Callable, Dict, List, Optional, Union from typing import (
Any,
Callable,
Dict,
List,
Optional,
Union,
)
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel from tml.ml_logging.torch_logging import (
logging,
)
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
)
def weights_to_log( def weights_to_log(
model: torch.nn.Module, model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None, how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
): ) -> Optional[Dict[str, Any]]:
"""Creates dict of reduced weights to log to give sense of training. """Creates dict of reduced weights to log to give sense of training.
Args: Args:
@ -21,7 +32,7 @@ def weights_to_log(
""" """
if not how_to_log: if not how_to_log:
return return None
to_log = dict() to_log = dict()
named_parameters = model.named_parameters() named_parameters = model.named_parameters()
@ -38,14 +49,14 @@ def weights_to_log(
how = how_to_log how = how_to_log
else: else:
how = how_to_log.get(param_name) # type: ignore[assignment] how = how_to_log.get(param_name) # type: ignore[assignment]
if not how: if how is None:
continue # type: ignore continue
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy() to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
return to_log return to_log
def log_ebc_norms( def log_ebc_norms(
model_state_dict, model_state_dict: Dict[str, Any],
ebc_keys: List[str], ebc_keys: List[str],
sample_size: int = 4_000_000, sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:

View File

@ -1,10 +1,13 @@
from typing import List
from enum import Enum from enum import Enum
from typing import (
import tml.core.config as base_config List,
from tml.optimizers.config import OptimizerConfig )
import pydantic import pydantic
import tml.core.config as base_config
from tml.optimizers.config import (
OptimizerConfig,
)
class DataType(str, Enum): class DataType(str, Enum):

View File

@ -1,13 +1,25 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging
import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np import numpy as np
import torch
import torchrec
from tml.common.modules.embedding.config import (
DataType,
LargeEmbeddingsConfig,
)
from tml.ml_logging.torch_logging import (
logging,
)
from torch import nn
from torchrec import (
EmbeddingBagCollection,
EmbeddingBagConfig,
)
from torchrec.modules import (
embedding_configs,
)
from torchrec.sparse.jagged_tensor import (
KeyedJaggedTensor,
KeyedTensor,
)
class LargeEmbeddings(nn.Module): class LargeEmbeddings(nn.Module):

View File

@ -1,30 +1,49 @@
import os import os
import subprocess import subprocess
import sys import sys
from typing import Optional from typing import (
Any,
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] Callable,
from twitter.ml.tensorflow.experimental.distributed import utils Dict,
Optional,
Protocol,
Tuple,
TypedDict,
)
import torch import torch
import torch.distributed.run import torch.distributed.run
from tml.ml_logging.torch_logging import (
logging,
)
from twitter.ml.tensorflow.experimental.distributed import (
utils,
)
from typing_extensions import (
Unpack,
)
def is_distributed_worker(): class Fn(Protocol):
def __call__(self, *args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
...
def is_distributed_worker() -> bool:
world_size = os.environ.get("WORLD_SIZE", None) world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None) rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None return world_size is not None and rank is not None
def maybe_run_training( def maybe_run_training(
train_fn, train_fn: Fn,
module_name, module_name: Any,
nproc_per_node: Optional[int] = None, nproc_per_node: Optional[int] = None,
num_nodes: Optional[int] = None, num_nodes: Optional[int] = 0,
set_python_path_in_subprocess: bool = False, set_python_path_in_subprocess: bool = False,
is_chief: Optional[bool] = False, is_chief: Optional[bool] = False,
**training_kwargs, **training_kwargs: Any,
): ) -> None:
"""Wrapper function for single node, multi-GPU Pytorch training. """Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables If the necessary distributed Pytorch environment variables
@ -73,7 +92,7 @@ def maybe_run_training(
] ]
if nproc_per_node: if nproc_per_node:
cmd.extend(["--nproc_per_node", str(nproc_per_node)]) cmd.extend(["--nproc_per_node", str(nproc_per_node)])
if num_nodes > 1: if num_nodes and num_nodes > 1:
cluster_resolver = utils.cluster_resolver() cluster_resolver = utils.cluster_resolver()
backend_address = cluster_resolver.cluster_spec().task_address("chief", 0) backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
cmd.extend( cmd.extend(

View File

@ -2,7 +2,9 @@
Mostly a test that this can be imported properly even tho moved. Mostly a test that this can be imported properly even tho moved.
""" """
from unittest.mock import patch from unittest.mock import (
patch,
)
import tml.common.device as device_utils import tml.common.device as device_utils

View File

@ -1,11 +1,16 @@
from contextlib import contextmanager
import datetime import datetime
import os import os
from unittest.mock import patch from contextlib import (
contextmanager,
)
from unittest.mock import (
patch,
)
import torch.distributed as dist import torch.distributed as dist
from tml.ml_logging.torch_logging import logging from tml.ml_logging.torch_logging import (
logging,
)
MOCK_ENV = { MOCK_ENV = {
"LOCAL_RANK": "0", "LOCAL_RANK": "0",

View File

@ -1,18 +1,24 @@
import yaml
import getpass import getpass
import os import os
import string import string
from typing import Tuple, Type, TypeVar from typing import (
Tuple,
from tml.core.config import base_config Type,
TypeVar,
)
import fsspec import fsspec
import yaml
from tml.core.config import (
base_config,
)
C = TypeVar("C", bound=base_config.BaseConfig) C = TypeVar("C", bound=base_config.BaseConfig)
T = TypeVar("T", str, bytes)
def _read_file(f): def _read_file(file: T) -> T:
with fsspec.open(f) as f: with fsspec.open(file) as f:
return f.read() return f.read()
@ -34,7 +40,7 @@ def setup_configuration(
The pydantic config object. The pydantic config object.
""" """
def _substitute(s): def _substitute(s: str) -> str:
if substitute_env_variable: if substitute_env_variable:
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser()) return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
return s return s

View File

@ -1,8 +1,11 @@
from typing import Any, Dict, List from typing import (
Any,
import tml.core.config as base_config Dict,
List,
)
import pydantic import pydantic
import tml.core.config as base_config
class WandbConfig(base_config.BaseConfig): class WandbConfig(base_config.BaseConfig):

22
mypy.ini Normal file
View File

@ -0,0 +1,22 @@
[mypy]
exclude = dist
ignore_errors = False
ignore_missing_imports = True
warn_unused_configs = True
warn_redundant_casts = True
warn_unused_ignores = True
no_implicit_optional = True
strict_equality = True
strict_concatenate = True
check_untyped_defs = True
disallow_subclassing_any = False
disallow_untyped_decorators = True
disallow_any_generics = True
disallow_untyped_calls = True
disallow_incomplete_defs = True
disallow_untyped_defs = True
no_implicit_reexport = True
strict = True
warn_return_any = False
explicit_package_bases = True
namespace_packages = True

View File

@ -15,3 +15,8 @@ exclude = '''
| dist | dist
)/ )/
''' '''
[tool.isort]
multi_line_output = 3
line_length = 21
include_trailing_comma = true