mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-25 13:21:10 +01:00
add & config mypy on common package
Signed-off-by: wiseaidev <business@wiseai.dev>
This commit is contained in:
parent
78c3235eee
commit
2cc1abedd7
@ -1,22 +1,40 @@
|
||||
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
|
||||
"""
|
||||
# flake8: noqa
|
||||
from __future__ import annotations
|
||||
from typing import Dict
|
||||
from __future__ import (
|
||||
annotations,
|
||||
)
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
import dataclasses
|
||||
from collections import (
|
||||
UserDict,
|
||||
)
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torchrec.streamable import Pipelineable
|
||||
from torchrec.streamable import (
|
||||
Pipelineable,
|
||||
)
|
||||
|
||||
_KT = TypeVar("_KT") # key type
|
||||
_VT = TypeVar("_VT") # value type
|
||||
|
||||
|
||||
class BatchBase(Pipelineable, abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def as_dict(self) -> Dict:
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device: torch.device, non_blocking: bool = False):
|
||||
def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
|
||||
args = {}
|
||||
for feature_name, feature_value in self.as_dict().items():
|
||||
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
|
||||
@ -26,14 +44,14 @@ class BatchBase(Pipelineable, abc.ABC):
|
||||
for feature_value in self.as_dict().values():
|
||||
feature_value.record_stream(stream)
|
||||
|
||||
def pin_memory(self):
|
||||
def pin_memory(self) -> BatchBase:
|
||||
args = {}
|
||||
for feature_name, feature_value in self.as_dict().items():
|
||||
args[feature_name] = feature_value.pin_memory()
|
||||
return self.__class__(**args)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def obj2str(v):
|
||||
def obj2str(v: Any) -> str:
|
||||
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
|
||||
|
||||
return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
|
||||
@ -52,10 +70,10 @@ class BatchBase(Pipelineable, abc.ABC):
|
||||
@dataclass
|
||||
class DataclassBatch(BatchBase):
|
||||
@classmethod
|
||||
def feature_names(cls):
|
||||
def feature_names(cls) -> List[str]:
|
||||
return list(cls.__dataclass_fields__.keys())
|
||||
|
||||
def as_dict(self):
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
feature_name: getattr(self, feature_name)
|
||||
for feature_name in self.feature_names()
|
||||
@ -63,7 +81,7 @@ class DataclassBatch(BatchBase):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_schema(name: str, schema):
|
||||
def from_schema(name: str, schema: Any) -> type:
|
||||
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
|
||||
return dataclasses.make_dataclass(
|
||||
cls_name=name,
|
||||
@ -72,7 +90,7 @@ class DataclassBatch(BatchBase):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_fields(name: str, fields: dict):
|
||||
def from_fields(name: str, fields: Dict[str, Any]) -> type:
|
||||
return dataclasses.make_dataclass(
|
||||
cls_name=name,
|
||||
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
|
||||
@ -80,6 +98,6 @@ class DataclassBatch(BatchBase):
|
||||
)
|
||||
|
||||
|
||||
class DictionaryBatch(BatchBase, dict):
|
||||
def as_dict(self) -> Dict:
|
||||
class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
return self
|
||||
|
@ -1 +1,4 @@
|
||||
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
|
||||
from tml.common.checkpointing.snapshot import (
|
||||
Snapshot,
|
||||
get_checkpoint,
|
||||
)
|
||||
|
@ -1,12 +1,24 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from tml.ml_logging.torch_logging import logging
|
||||
from tml.common.filesystem import infer_fs, is_gcs_fs
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import torchsnapshot
|
||||
|
||||
from tml.common.filesystem import (
|
||||
infer_fs,
|
||||
is_gcs_fs,
|
||||
)
|
||||
from tml.ml_logging.torch_logging import (
|
||||
logging,
|
||||
)
|
||||
from torch import (
|
||||
FloatTensor,
|
||||
)
|
||||
|
||||
DONE_EVAL_SUBDIR = "evaled_by"
|
||||
GCS_PREFIX = "gs://"
|
||||
@ -25,7 +37,7 @@ class Snapshot:
|
||||
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
|
||||
|
||||
@property
|
||||
def step(self):
|
||||
def step(self) -> int:
|
||||
return self.state["extra_state"]["step"]
|
||||
|
||||
@step.setter
|
||||
@ -33,14 +45,14 @@ class Snapshot:
|
||||
self.state["extra_state"]["step"] = step
|
||||
|
||||
@property
|
||||
def walltime(self):
|
||||
def walltime(self) -> float:
|
||||
return self.state["extra_state"]["walltime"]
|
||||
|
||||
@walltime.setter
|
||||
def walltime(self, walltime: float) -> None:
|
||||
self.state["extra_state"]["walltime"] = walltime
|
||||
|
||||
def save(self, global_step: int) -> "PendingSnapshot":
|
||||
def save(self, global_step: int) -> "PendingSnapshot": # type: ignore
|
||||
"""Saves checkpoint with given global_step."""
|
||||
path = os.path.join(self.save_dir, str(global_step))
|
||||
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
|
||||
@ -98,7 +110,7 @@ class Snapshot:
|
||||
cls,
|
||||
embedding_snapshot: torchsnapshot.Snapshot,
|
||||
snapshot_emb_name: str,
|
||||
weight_tensor,
|
||||
weight_tensor: FloatTensor,
|
||||
) -> None:
|
||||
"""Loads pretrained embedding from the snapshot to the model.
|
||||
Utilise partial lodaing meachanism from torchsnapshot.
|
||||
@ -128,11 +140,11 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
|
||||
return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
|
||||
|
||||
|
||||
def is_done_eval(checkpoint_path: str, eval_partition: str):
|
||||
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
|
||||
def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
|
||||
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition)) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def mark_done_eval(checkpoint_path: str, eval_partition: str):
|
||||
def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
|
||||
infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
|
||||
|
||||
|
||||
@ -140,7 +152,9 @@ def step_from_checkpoint(checkpoint: str) -> int:
|
||||
return int(os.path.basename(checkpoint))
|
||||
|
||||
|
||||
def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
|
||||
def checkpoints_iterator(
|
||||
save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
|
||||
) -> Generator[str, None, None]:
|
||||
"""Simplified equivalent of tf.train.checkpoints_iterator.
|
||||
|
||||
Args:
|
||||
@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
|
||||
|
||||
"""
|
||||
|
||||
def _poll(last_checkpoint: Optional[str] = None):
|
||||
def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
|
||||
stop_time = time.time() + timeout
|
||||
while True:
|
||||
_checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def maybe_setup_tensorflow():
|
||||
def maybe_setup_tensorflow() -> None:
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
|
@ -1 +1,5 @@
|
||||
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
|
||||
from tml.common.filesystem.util import (
|
||||
infer_fs,
|
||||
is_gcs_fs,
|
||||
is_local_fs,
|
||||
)
|
||||
|
@ -2,7 +2,9 @@
|
||||
|
||||
Mostly a test that it returns an object
|
||||
"""
|
||||
from tml.common.filesystem import infer_fs
|
||||
from tml.common.filesystem import (
|
||||
infer_fs,
|
||||
)
|
||||
|
||||
|
||||
def test_infer_fs():
|
||||
|
@ -1,13 +1,18 @@
|
||||
"""Utilities for interacting with the file systems."""
|
||||
from fsspec.implementations.local import LocalFileSystem
|
||||
import gcsfs
|
||||
from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import gcsfs
|
||||
from fsspec.implementations.local import (
|
||||
LocalFileSystem,
|
||||
)
|
||||
|
||||
GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
|
||||
LOCAL_FS = LocalFileSystem()
|
||||
|
||||
|
||||
def infer_fs(path: str):
|
||||
def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
|
||||
if path.startswith("gs://"):
|
||||
return GCS_FS
|
||||
elif path.startswith("hdfs://"):
|
||||
@ -17,9 +22,9 @@ def infer_fs(path: str):
|
||||
return LOCAL_FS
|
||||
|
||||
|
||||
def is_local_fs(fs):
|
||||
def is_local_fs(fs: LocalFileSystem) -> bool:
|
||||
return fs == LOCAL_FS
|
||||
|
||||
|
||||
def is_gcs_fs(fs):
|
||||
def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
|
||||
return fs == GCS_FS
|
||||
|
@ -1,17 +1,28 @@
|
||||
"""For logging model weights."""
|
||||
import itertools
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torchrec.distributed.model_parallel import DistributedModelParallel
|
||||
from tml.ml_logging.torch_logging import (
|
||||
logging,
|
||||
)
|
||||
from torchrec.distributed.model_parallel import (
|
||||
DistributedModelParallel,
|
||||
)
|
||||
|
||||
|
||||
def weights_to_log(
|
||||
model: torch.nn.Module,
|
||||
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
|
||||
):
|
||||
how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Creates dict of reduced weights to log to give sense of training.
|
||||
|
||||
Args:
|
||||
@ -21,7 +32,7 @@ def weights_to_log(
|
||||
|
||||
"""
|
||||
if not how_to_log:
|
||||
return
|
||||
return None
|
||||
|
||||
to_log = dict()
|
||||
named_parameters = model.named_parameters()
|
||||
@ -38,14 +49,14 @@ def weights_to_log(
|
||||
how = how_to_log
|
||||
else:
|
||||
how = how_to_log.get(param_name) # type: ignore[assignment]
|
||||
if not how:
|
||||
continue # type: ignore
|
||||
if how is None:
|
||||
continue
|
||||
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
|
||||
return to_log
|
||||
|
||||
|
||||
def log_ebc_norms(
|
||||
model_state_dict,
|
||||
model_state_dict: Dict[str, Any],
|
||||
ebc_keys: List[str],
|
||||
sample_size: int = 4_000_000,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
@ -1,10 +1,13 @@
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
|
||||
import tml.core.config as base_config
|
||||
from tml.optimizers.config import OptimizerConfig
|
||||
from typing import (
|
||||
List,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
import tml.core.config as base_config
|
||||
from tml.optimizers.config import (
|
||||
OptimizerConfig,
|
||||
)
|
||||
|
||||
|
||||
class DataType(str, Enum):
|
||||
|
@ -1,13 +1,25 @@
|
||||
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
|
||||
from tml.ml_logging.torch_logging import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchrec
|
||||
from torchrec.modules import embedding_configs
|
||||
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchrec
|
||||
from tml.common.modules.embedding.config import (
|
||||
DataType,
|
||||
LargeEmbeddingsConfig,
|
||||
)
|
||||
from tml.ml_logging.torch_logging import (
|
||||
logging,
|
||||
)
|
||||
from torch import nn
|
||||
from torchrec import (
|
||||
EmbeddingBagCollection,
|
||||
EmbeddingBagConfig,
|
||||
)
|
||||
from torchrec.modules import (
|
||||
embedding_configs,
|
||||
)
|
||||
from torchrec.sparse.jagged_tensor import (
|
||||
KeyedJaggedTensor,
|
||||
KeyedTensor,
|
||||
)
|
||||
|
||||
|
||||
class LargeEmbeddings(nn.Module):
|
||||
|
@ -1,30 +1,49 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
|
||||
from twitter.ml.tensorflow.experimental.distributed import utils
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.distributed.run
|
||||
from tml.ml_logging.torch_logging import (
|
||||
logging,
|
||||
)
|
||||
from twitter.ml.tensorflow.experimental.distributed import (
|
||||
utils,
|
||||
)
|
||||
from typing_extensions import (
|
||||
Unpack,
|
||||
)
|
||||
|
||||
|
||||
def is_distributed_worker():
|
||||
class Fn(Protocol):
|
||||
def __call__(self, *args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
|
||||
...
|
||||
|
||||
|
||||
def is_distributed_worker() -> bool:
|
||||
world_size = os.environ.get("WORLD_SIZE", None)
|
||||
rank = os.environ.get("RANK", None)
|
||||
return world_size is not None and rank is not None
|
||||
|
||||
|
||||
def maybe_run_training(
|
||||
train_fn,
|
||||
module_name,
|
||||
train_fn: Fn,
|
||||
module_name: Any,
|
||||
nproc_per_node: Optional[int] = None,
|
||||
num_nodes: Optional[int] = None,
|
||||
num_nodes: Optional[int] = 0,
|
||||
set_python_path_in_subprocess: bool = False,
|
||||
is_chief: Optional[bool] = False,
|
||||
**training_kwargs,
|
||||
):
|
||||
**training_kwargs: Any,
|
||||
) -> None:
|
||||
"""Wrapper function for single node, multi-GPU Pytorch training.
|
||||
|
||||
If the necessary distributed Pytorch environment variables
|
||||
@ -73,7 +92,7 @@ def maybe_run_training(
|
||||
]
|
||||
if nproc_per_node:
|
||||
cmd.extend(["--nproc_per_node", str(nproc_per_node)])
|
||||
if num_nodes > 1:
|
||||
if num_nodes and num_nodes > 1:
|
||||
cluster_resolver = utils.cluster_resolver()
|
||||
backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
|
||||
cmd.extend(
|
||||
|
@ -2,7 +2,9 @@
|
||||
|
||||
Mostly a test that this can be imported properly even tho moved.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import (
|
||||
patch,
|
||||
)
|
||||
|
||||
import tml.common.device as device_utils
|
||||
|
||||
|
@ -1,11 +1,16 @@
|
||||
from contextlib import contextmanager
|
||||
import datetime
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from contextlib import (
|
||||
contextmanager,
|
||||
)
|
||||
from unittest.mock import (
|
||||
patch,
|
||||
)
|
||||
|
||||
import torch.distributed as dist
|
||||
from tml.ml_logging.torch_logging import logging
|
||||
|
||||
from tml.ml_logging.torch_logging import (
|
||||
logging,
|
||||
)
|
||||
|
||||
MOCK_ENV = {
|
||||
"LOCAL_RANK": "0",
|
||||
|
@ -1,18 +1,24 @@
|
||||
import yaml
|
||||
import getpass
|
||||
import os
|
||||
import string
|
||||
from typing import Tuple, Type, TypeVar
|
||||
|
||||
from tml.core.config import base_config
|
||||
from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import fsspec
|
||||
import yaml
|
||||
from tml.core.config import (
|
||||
base_config,
|
||||
)
|
||||
|
||||
C = TypeVar("C", bound=base_config.BaseConfig)
|
||||
T = TypeVar("T", str, bytes)
|
||||
|
||||
|
||||
def _read_file(f):
|
||||
with fsspec.open(f) as f:
|
||||
def _read_file(file: T) -> T:
|
||||
with fsspec.open(file) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@ -34,7 +40,7 @@ def setup_configuration(
|
||||
The pydantic config object.
|
||||
"""
|
||||
|
||||
def _substitute(s):
|
||||
def _substitute(s: str) -> str:
|
||||
if substitute_env_variable:
|
||||
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
|
||||
return s
|
||||
|
@ -1,8 +1,11 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import tml.core.config as base_config
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
import tml.core.config as base_config
|
||||
|
||||
|
||||
class WandbConfig(base_config.BaseConfig):
|
||||
|
22
mypy.ini
Normal file
22
mypy.ini
Normal file
@ -0,0 +1,22 @@
|
||||
[mypy]
|
||||
exclude = dist
|
||||
ignore_errors = False
|
||||
ignore_missing_imports = True
|
||||
warn_unused_configs = True
|
||||
warn_redundant_casts = True
|
||||
warn_unused_ignores = True
|
||||
no_implicit_optional = True
|
||||
strict_equality = True
|
||||
strict_concatenate = True
|
||||
check_untyped_defs = True
|
||||
disallow_subclassing_any = False
|
||||
disallow_untyped_decorators = True
|
||||
disallow_any_generics = True
|
||||
disallow_untyped_calls = True
|
||||
disallow_incomplete_defs = True
|
||||
disallow_untyped_defs = True
|
||||
no_implicit_reexport = True
|
||||
strict = True
|
||||
warn_return_any = False
|
||||
explicit_package_bases = True
|
||||
namespace_packages = True
|
@ -15,3 +15,8 @@ exclude = '''
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
|
||||
[tool.isort]
|
||||
multi_line_output = 3
|
||||
line_length = 21
|
||||
include_trailing_comma = true
|
Loading…
x
Reference in New Issue
Block a user