add & configure mypy on the common package

Signed-off-by: wiseaidev <business@wiseai.dev>
wiseaidev 2023-04-01 13:30:34 +03:00
parent 78c3235eee
commit 2cc1abedd7
17 changed files with 219 additions and 85 deletions

View File

@ -1,22 +1,40 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
from __future__ import (
annotations,
)
import abc
from dataclasses import dataclass
import dataclasses
from collections import (
UserDict,
)
from dataclasses import (
dataclass,
)
from typing import (
Any,
Dict,
List,
TypeVar,
)
import torch
from torchrec.streamable import Pipelineable
from torchrec.streamable import (
Pipelineable,
)
_KT = TypeVar("_KT") # key type
_VT = TypeVar("_VT") # value type
class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict:
def as_dict(self) -> Dict[str, Any]:
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False):
def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@ -26,14 +44,14 @@ class BatchBase(Pipelineable, abc.ABC):
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self):
def pin_memory(self) -> BatchBase:
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
def obj2str(v):
def obj2str(v: Any) -> str:
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@ -52,10 +70,10 @@ class BatchBase(Pipelineable, abc.ABC):
@dataclass
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls):
def feature_names(cls) -> List[str]:
return list(cls.__dataclass_fields__.keys())
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
@ -63,7 +81,7 @@ class DataclassBatch(BatchBase):
}
@staticmethod
def from_schema(name: str, schema):
def from_schema(name: str, schema: Any) -> type:
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass(
cls_name=name,
@ -72,7 +90,7 @@ class DataclassBatch(BatchBase):
)
@staticmethod
def from_fields(name: str, fields: dict):
def from_fields(name: str, fields: Dict[str, Any]) -> type:
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@ -80,6 +98,6 @@ class DataclassBatch(BatchBase):
)
class DictionaryBatch(BatchBase, dict):
def as_dict(self) -> Dict:
class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
def as_dict(self) -> Dict[str, Any]:
return self
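Note: the switch from a bare dict base to UserDict[_KT, _VT] is what satisfies mypy's disallow_any_generics: collections.UserDict is generic and carries key/value type parameters, while subclassing plain dict leaves them as Any. A minimal standalone sketch of the pattern (hypothetical class, not project code; runtime subscription of UserDict needs Python 3.9+):

from collections import UserDict
from typing import TypeVar

_KT = TypeVar("_KT")  # key type
_VT = TypeVar("_VT")  # value type

class TypedBatchDict(UserDict[_KT, _VT]):
    """Generic dict-like container; mypy tracks key and value types."""

counts: TypedBatchDict[str, int] = TypedBatchDict()
counts["step"] = 10     # ok
# counts[0] = "oops"    # mypy error: key must be str, value must be int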

View File

@ -1 +1,4 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
from tml.common.checkpointing.snapshot import (
Snapshot,
get_checkpoint,
)

View File

@ -1,12 +1,24 @@
import os
import time
from typing import Any, Dict, List, Optional
from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs
from typing import (
Any,
Dict,
Generator,
List,
Optional,
)
import torchsnapshot
from tml.common.filesystem import (
infer_fs,
is_gcs_fs,
)
from tml.ml_logging.torch_logging import (
logging,
)
from torch import (
FloatTensor,
)
DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://"
@ -25,7 +37,7 @@ class Snapshot:
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property
def step(self):
def step(self) -> int:
return self.state["extra_state"]["step"]
@step.setter
@ -33,14 +45,14 @@ class Snapshot:
self.state["extra_state"]["step"] = step
@property
def walltime(self):
def walltime(self) -> float:
return self.state["extra_state"]["walltime"]
@walltime.setter
def walltime(self, walltime: float) -> None:
self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot":
def save(self, global_step: int) -> "PendingSnapshot": # type: ignore
"""Saves checkpoint with given global_step."""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@ -98,7 +110,7 @@ class Snapshot:
cls,
embedding_snapshot: torchsnapshot.Snapshot,
snapshot_emb_name: str,
weight_tensor,
weight_tensor: FloatTensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilizes the partial loading mechanism from torchsnapshot.
@ -128,11 +140,11 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
def is_done_eval(checkpoint_path: str, eval_partition: str):
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition)) # type: ignore[attr-defined]
def mark_done_eval(checkpoint_path: str, eval_partition: str):
def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
@ -140,7 +152,9 @@ def step_from_checkpoint(checkpoint: str) -> int:
return int(os.path.basename(checkpoint))
def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
def checkpoints_iterator(
save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
) -> Generator[str, None, None]:
"""Simplified equivalent of tf.train.checkpoints_iterator.
Args:
@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
"""
def _poll(last_checkpoint: Optional[str] = None):
def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
stop_time = time.time() + timeout
while True:
_checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
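Note: Generator[str, None, None] reads as Generator[YieldType, SendType, ReturnType], so the new annotation says the iterator yields checkpoint paths, accepts nothing via send(), and returns nothing. A standalone sketch with hypothetical names:

import time
from typing import Generator, List

def poll_paths(paths: List[str], delay: float = 0.0) -> Generator[str, None, None]:
    """Yields each path after an optional sleep; nothing is sent in or returned."""
    for path in paths:
        time.sleep(delay)
        yield path

for p in poll_paths(["ckpts/100", "ckpts/200"]):
    print(p)

Since the send and return slots are both None, Iterator[str] would be an equivalent, shorter annotation here.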

View File

@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
def maybe_setup_tensorflow():
def maybe_setup_tensorflow() -> None:
try:
import tensorflow as tf
except ImportError:

View File

@ -1 +1,5 @@
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
from tml.common.filesystem.util import (
infer_fs,
is_gcs_fs,
is_local_fs,
)

View File

@ -2,7 +2,9 @@
Mostly a test that it returns an object
"""
from tml.common.filesystem import infer_fs
from tml.common.filesystem import (
infer_fs,
)
def test_infer_fs():

View File

@ -1,13 +1,18 @@
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs
from typing import (
Union,
)
import gcsfs
from fsspec.implementations.local import (
LocalFileSystem,
)
GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()
def infer_fs(path: str):
def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
@ -17,9 +22,9 @@ def infer_fs(path: str):
return LOCAL_FS
def is_local_fs(fs):
def is_local_fs(fs: LocalFileSystem) -> bool:
return fs == LOCAL_FS
def is_gcs_fs(fs):
def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
return fs == GCS_FS
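Note: including NotImplementedError in the return Union implies infer_fs returns the exception object for hdfs:// paths. If the intent is to raise instead (the usual convention), the annotation can stay narrow; a sketch under that assumption:

from typing import Union

import gcsfs
from fsspec.implementations.local import LocalFileSystem

GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()

def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.GCSFileSystem]:
    if path.startswith("gs://"):
        return GCS_FS
    if path.startswith("hdfs://"):
        # raising keeps the annotation limited to real filesystem types
        raise NotImplementedError("HDFS support is not implemented")
    return LOCAL_FS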

View File

@ -1,17 +1,28 @@
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Union,
)
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from tml.ml_logging.torch_logging import (
logging,
)
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
)
def weights_to_log(
model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
) -> Optional[Dict[str, Any]]:
"""Creates dict of reduced weights to log to give sense of training.
Args:
@ -21,7 +32,7 @@ def weights_to_log(
"""
if not how_to_log:
return
return None
to_log = dict()
named_parameters = model.named_parameters()
@ -38,14 +49,14 @@ def weights_to_log(
how = how_to_log
else:
how = how_to_log.get(param_name) # type: ignore[assignment]
if not how:
continue # type: ignore
if how is None:
continue
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
return to_log
def log_ebc_norms(
model_state_dict,
model_state_dict: Dict[str, Any],
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:

View File

@ -1,10 +1,13 @@
from typing import List
from enum import Enum
import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig
from typing import (
List,
)
import pydantic
import tml.core.config as base_config
from tml.optimizers.config import (
OptimizerConfig,
)
class DataType(str, Enum):

View File

@ -1,13 +1,25 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging
import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np
import torch
import torchrec
from tml.common.modules.embedding.config import (
DataType,
LargeEmbeddingsConfig,
)
from tml.ml_logging.torch_logging import (
logging,
)
from torch import nn
from torchrec import (
EmbeddingBagCollection,
EmbeddingBagConfig,
)
from torchrec.modules import (
embedding_configs,
)
from torchrec.sparse.jagged_tensor import (
KeyedJaggedTensor,
KeyedTensor,
)
class LargeEmbeddings(nn.Module):

View File

@ -1,30 +1,49 @@
import os
import subprocess
import sys
from typing import Optional
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils
from typing import (
Any,
Callable,
Dict,
Optional,
Protocol,
Tuple,
TypedDict,
)
import torch
import torch.distributed.run
from tml.ml_logging.torch_logging import (
logging,
)
from twitter.ml.tensorflow.experimental.distributed import (
utils,
)
from typing_extensions import (
Unpack,
)
def is_distributed_worker():
class Fn(Protocol):
def __call__(self, *args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
...
def is_distributed_worker() -> bool:
world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None
def maybe_run_training(
train_fn,
module_name,
train_fn: Fn,
module_name: Any,
nproc_per_node: Optional[int] = None,
num_nodes: Optional[int] = None,
num_nodes: Optional[int] = 0,
set_python_path_in_subprocess: bool = False,
is_chief: Optional[bool] = False,
**training_kwargs,
):
**training_kwargs: Any,
) -> None:
"""Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables
@ -73,7 +92,7 @@ def maybe_run_training(
]
if nproc_per_node:
cmd.extend(["--nproc_per_node", str(nproc_per_node)])
if num_nodes > 1:
if num_nodes and num_nodes > 1:
cluster_resolver = utils.cluster_resolver()
backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
cmd.extend(
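Note: the Fn Protocol gives the train_fn callback a structural type, which mypy checks more precisely than a bare Callable: any function with a compatible signature matches without inheriting from the protocol. A minimal standalone sketch (hypothetical names):

from typing import Any, Protocol

class TrainFn(Protocol):
    # structural type: any callable with a compatible signature matches
    def __call__(self, *args: Any, **kwargs: Any) -> None:
        ...

def launch(fn: TrainFn, **kwargs: Any) -> None:
    fn(**kwargs)

def train(*args: Any, **kwargs: Any) -> None:
    print("training with", kwargs)

launch(train, learning_rate=0.1)  # accepted: train structurally matches TrainFn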

View File

@ -2,7 +2,9 @@
Mostly a test that this can be imported properly even though moved.
"""
from unittest.mock import patch
from unittest.mock import (
patch,
)
import tml.common.device as device_utils

View File

@ -1,11 +1,16 @@
from contextlib import contextmanager
import datetime
import os
from unittest.mock import patch
from contextlib import (
contextmanager,
)
from unittest.mock import (
patch,
)
import torch.distributed as dist
from tml.ml_logging.torch_logging import logging
from tml.ml_logging.torch_logging import (
logging,
)
MOCK_ENV = {
"LOCAL_RANK": "0",

View File

@ -1,18 +1,24 @@
import yaml
import getpass
import os
import string
from typing import Tuple, Type, TypeVar
from tml.core.config import base_config
from typing import (
Tuple,
Type,
TypeVar,
)
import fsspec
import yaml
from tml.core.config import (
base_config,
)
C = TypeVar("C", bound=base_config.BaseConfig)
T = TypeVar("T", str, bytes)
def _read_file(f):
with fsspec.open(f) as f:
def _read_file(file: T) -> T:
with fsspec.open(file) as f:
return f.read()
@ -34,7 +40,7 @@ def setup_configuration(
The pydantic config object.
"""
def _substitute(s):
def _substitute(s: str) -> str:
if substitute_env_variable:
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
return s

View File

@ -1,8 +1,11 @@
from typing import Any, Dict, List
import tml.core.config as base_config
from typing import (
Any,
Dict,
List,
)
import pydantic
import tml.core.config as base_config
class WandbConfig(base_config.BaseConfig):

mypy.ini (new file, 22 lines)
View File

@ -0,0 +1,22 @@
[mypy]
exclude = dist
ignore_errors = False
ignore_missing_imports = True
warn_unused_configs = True
warn_redundant_casts = True
warn_unused_ignores = True
no_implicit_optional = True
strict_equality = True
strict_concatenate = True
check_untyped_defs = True
disallow_subclassing_any = False
disallow_untyped_decorators = True
disallow_any_generics = True
disallow_untyped_calls = True
disallow_incomplete_defs = True
disallow_untyped_defs = True
no_implicit_reexport = True
strict = True
warn_return_any = False
explicit_package_bases = True
namespace_packages = True
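Note: strict = True already implies most of the flags listed above (disallow_untyped_defs, disallow_any_generics, warn_unused_ignores, no_implicit_reexport, and so on), so the explicit entries mostly document intent, while warn_return_any = False and disallow_subclassing_any = False selectively relax strict mode. Roughly, disallow_untyped_defs and no_implicit_optional reject code like this (hypothetical snippet, not project code):

from typing import Optional

def untyped(x):  # error under disallow_untyped_defs: function is missing annotations
    return x

def implicit(x: int = None) -> int:  # error under no_implicit_optional
    return x

def ok(x: Optional[int] = None) -> int:  # accepted
    return 0 if x is None else x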

View File

@ -15,3 +15,8 @@ exclude = '''
| dist
)/
'''
[tool.isort]
multi_line_output = 3
line_length = 21
include_trailing_comma = true
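Note: multi_line_output = 3 selects isort's vertical-hanging-indent style, and the unusually small line_length = 21 pushes nearly every from-import over the limit, which is what exploded the imports throughout this commit onto one name per line. For example:

# before (exceeds line_length = 21)
from typing import Any, Dict, List

# after isort with multi_line_output = 3 and include_trailing_comma = true
from typing import (
    Any,
    Dict,
    List,
)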