diff --git a/common/batch.py b/common/batch.py
index cfd0ac0..7c04dcf 100644
--- a/common/batch.py
+++ b/common/batch.py
@@ -1,22 +1,40 @@
 """Extension of torchrec.dataset.utils.Batch to cover any dataset.
 """
 # flake8: noqa
-from __future__ import annotations
-from typing import Dict
+from __future__ import (
+  annotations,
+)
+
 import abc
-from dataclasses import dataclass
 import dataclasses
+from collections import (
+  UserDict,
+)
+from dataclasses import (
+  dataclass,
+)
+from typing import (
+  Any,
+  Dict,
+  List,
+  TypeVar,
+)
 
 import torch
-from torchrec.streamable import Pipelineable
+from torchrec.streamable import (
+  Pipelineable,
+)
+
+_KT = TypeVar("_KT")  # key type
+_VT = TypeVar("_VT")  # value type
 
 
 class BatchBase(Pipelineable, abc.ABC):
   @abc.abstractmethod
-  def as_dict(self) -> Dict:
+  def as_dict(self) -> Dict[str, Any]:
     raise NotImplementedError
 
-  def to(self, device: torch.device, non_blocking: bool = False):
+  def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@@ -26,14 +44,14 @@ class BatchBase(Pipelineable, abc.ABC):
     for feature_value in self.as_dict().values():
       feature_value.record_stream(stream)
 
-  def pin_memory(self):
+  def pin_memory(self) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.pin_memory()
     return self.__class__(**args)
 
   def __repr__(self) -> str:
-    def obj2str(v):
+    def obj2str(v: Any) -> str:
       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
 
     return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@@ -52,10 +70,10 @@ class BatchBase(Pipelineable, abc.ABC):
 @dataclass
 class DataclassBatch(BatchBase):
   @classmethod
-  def feature_names(cls):
+  def feature_names(cls) -> List[str]:
     return list(cls.__dataclass_fields__.keys())
 
-  def as_dict(self):
+  def as_dict(self) -> Dict[str, Any]:
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
@@ -63,7 +81,7 @@ class DataclassBatch(BatchBase):
     }
 
   @staticmethod
-  def from_schema(name: str, schema):
+  def from_schema(name: str, schema: Any) -> type:
     """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
     return dataclasses.make_dataclass(
       cls_name=name,
@@ -72,7 +90,7 @@ class DataclassBatch(BatchBase):
     )
 
   @staticmethod
-  def from_fields(name: str, fields: dict):
+  def from_fields(name: str, fields: Dict[str, Any]) -> type:
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@@ -80,6 +98,6 @@ class DataclassBatch(BatchBase):
     )
 
 
-class DictionaryBatch(BatchBase, dict):
-  def as_dict(self) -> Dict:
+class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
+  def as_dict(self) -> Dict[str, Any]:
     return self
diff --git a/common/checkpointing/__init__.py b/common/checkpointing/__init__.py
index 02fe102..879a00a 100644
--- a/common/checkpointing/__init__.py
+++ b/common/checkpointing/__init__.py
@@ -1 +1,4 @@
-from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
+from tml.common.checkpointing.snapshot import (
+  Snapshot,
+  get_checkpoint,
+)
diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py
index 2703efd..fce0e09 100644
--- a/common/checkpointing/snapshot.py
+++ b/common/checkpointing/snapshot.py
@@ -1,12 +1,24 @@
 import os
 import time
-from typing import Any, Dict, List, Optional
-
-from tml.ml_logging.torch_logging import logging
-from tml.common.filesystem import infer_fs, is_gcs_fs
+from typing import (
+  Any,
+  Dict,
+  Generator,
+  List,
+  Optional,
+)
 
 import torchsnapshot
-
+from tml.common.filesystem import (
+  infer_fs,
+  is_gcs_fs,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import (
+  FloatTensor,
+)
 
 DONE_EVAL_SUBDIR = "evaled_by"
 GCS_PREFIX = "gs://"
@@ -25,7 +37,7 @@ class Snapshot:
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
 
   @property
-  def step(self):
+  def step(self) -> int:
     return self.state["extra_state"]["step"]
 
   @step.setter
@@ -33,14 +45,14 @@ class Snapshot:
     self.state["extra_state"]["step"] = step
 
   @property
-  def walltime(self):
+  def walltime(self) -> float:
     return self.state["extra_state"]["walltime"]
 
   @walltime.setter
   def walltime(self, walltime: float) -> None:
     self.state["extra_state"]["walltime"] = walltime
 
-  def save(self, global_step: int) -> "PendingSnapshot":
+  def save(self, global_step: int) -> "PendingSnapshot":  # type: ignore
     """Saves checkpoint with given global_step."""
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@@ -98,7 +110,7 @@ class Snapshot:
     cls,
     embedding_snapshot: torchsnapshot.Snapshot,
     snapshot_emb_name: str,
-    weight_tensor,
+    weight_tensor: FloatTensor,
   ) -> None:
     """Loads pretrained embedding from the snapshot to the model.
     Utilise partial lodaing meachanism from torchsnapshot.
@@ -128,11 +140,11 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
   return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
 
 
-def is_done_eval(checkpoint_path: str, eval_partition: str):
-  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
+def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
+  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))  # type: ignore[attr-defined]
 
 
-def mark_done_eval(checkpoint_path: str, eval_partition: str):
+def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
   infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
 
 
@@ -140,7 +152,9 @@ def step_from_checkpoint(checkpoint: str) -> int:
   return int(os.path.basename(checkpoint))
 
 
-def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
+def checkpoints_iterator(
+  save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
+) -> Generator[str, None, None]:
   """Simplified equivalent of tf.train.checkpoints_iterator.
 
   Args:
@@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
 
   """
 
-  def _poll(last_checkpoint: Optional[str] = None):
+  def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
     stop_time = time.time() + timeout
     while True:
       _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
diff --git a/common/device.py b/common/device.py
index 33d1bb8..c4387ec 100644
--- a/common/device.py
+++ b/common/device.py
@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 
 
-def maybe_setup_tensorflow():
+def maybe_setup_tensorflow() -> None:
   try:
     import tensorflow as tf
   except ImportError:
diff --git a/common/filesystem/__init__.py b/common/filesystem/__init__.py
index 1b896ba..f212637 100644
--- a/common/filesystem/__init__.py
+++ b/common/filesystem/__init__.py
@@ -1 +1,5 @@
-from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
+from tml.common.filesystem.util import (
+  infer_fs,
+  is_gcs_fs,
+  is_local_fs,
+)
diff --git a/common/filesystem/test_infer_fs.py b/common/filesystem/test_infer_fs.py
index 2036441..1f39f71 100644
--- a/common/filesystem/test_infer_fs.py
+++ b/common/filesystem/test_infer_fs.py
@@ -2,7 +2,9 @@
 Mostly a test that it returns an object
 """
 
-from tml.common.filesystem import infer_fs
+from tml.common.filesystem import (
+  infer_fs,
+)
 
 
 def test_infer_fs():
diff --git a/common/filesystem/util.py b/common/filesystem/util.py
index 5148a0d..b89ac32 100644
--- a/common/filesystem/util.py
+++ b/common/filesystem/util.py
@@ -1,13 +1,18 @@
 """Utilities for interacting with the file systems."""
-from fsspec.implementations.local import LocalFileSystem
-import gcsfs
+from typing import (
+  Union,
+)
+import gcsfs
+from fsspec.implementations.local import (
+  LocalFileSystem,
+)
 
 
 GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
 LOCAL_FS = LocalFileSystem()
 
 
-def infer_fs(path: str):
+def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
   if path.startswith("gs://"):
     return GCS_FS
   elif path.startswith("hdfs://"):
@@ -17,9 +22,9 @@ def infer_fs(path: str):
     return LOCAL_FS
 
 
-def is_local_fs(fs):
+def is_local_fs(fs: LocalFileSystem) -> bool:
   return fs == LOCAL_FS
 
 
-def is_gcs_fs(fs):
+def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
   return fs == GCS_FS
diff --git a/common/log_weights.py b/common/log_weights.py
index d2c8e21..43e905b 100644
--- a/common/log_weights.py
+++ b/common/log_weights.py
@@ -1,17 +1,28 @@
 """For logging model weights."""
 import itertools
-from typing import Callable, Dict, List, Optional, Union
+from typing import (
+  Any,
+  Callable,
+  Dict,
+  List,
+  Optional,
+  Union,
+)
 
-from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
 import torch
 import torch.distributed as dist
-from torchrec.distributed.model_parallel import DistributedModelParallel
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torchrec.distributed.model_parallel import (
+  DistributedModelParallel,
+)
 
 
 def weights_to_log(
   model: torch.nn.Module,
-  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
-):
+  how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
+) -> Optional[Dict[str, Any]]:
   """Creates dict of reduced weights to log to give sense of training.
 
   Args:
@@ -21,7 +32,7 @@ def weights_to_log(
   """
 
   if not how_to_log:
-    return
+    return None
 
   to_log = dict()
   named_parameters = model.named_parameters()
@@ -38,14 +49,14 @@ def weights_to_log(
       how = how_to_log
     else:
       how = how_to_log.get(param_name)  # type: ignore[assignment]
-    if not how:
-      continue  # type: ignore
+    if how is None:
+      continue
     to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
   return to_log
 
 
 def log_ebc_norms(
-  model_state_dict,
+  model_state_dict: Dict[str, Any],
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
diff --git a/common/modules/embedding/config.py b/common/modules/embedding/config.py
index 2f5df15..30eff35 100644
--- a/common/modules/embedding/config.py
+++ b/common/modules/embedding/config.py
@@ -1,10 +1,13 @@
-from typing import List
 from enum import Enum
-
-import tml.core.config as base_config
-from tml.optimizers.config import OptimizerConfig
+from typing import (
+  List,
+)
 
 import pydantic
+import tml.core.config as base_config
+from tml.optimizers.config import (
+  OptimizerConfig,
+)
 
 
 class DataType(str, Enum):
diff --git a/common/modules/embedding/embedding.py b/common/modules/embedding/embedding.py
index b0a085e..5c47c85 100644
--- a/common/modules/embedding/embedding.py
+++ b/common/modules/embedding/embedding.py
@@ -1,13 +1,25 @@
-from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
-from tml.ml_logging.torch_logging import logging
-
-import torch
-from torch import nn
-import torchrec
-from torchrec.modules import embedding_configs
-from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 import numpy as np
+import torch
+import torchrec
+from tml.common.modules.embedding.config import (
+  DataType,
+  LargeEmbeddingsConfig,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import nn
+from torchrec import (
+  EmbeddingBagCollection,
+  EmbeddingBagConfig,
+)
+from torchrec.modules import (
+  embedding_configs,
+)
+from torchrec.sparse.jagged_tensor import (
+  KeyedJaggedTensor,
+  KeyedTensor,
+)
 
 
 class LargeEmbeddings(nn.Module):
diff --git a/common/run_training.py b/common/run_training.py
index fafb140..93293c1 100644
--- a/common/run_training.py
+++ b/common/run_training.py
@@ -1,30 +1,49 @@
 import os
 import subprocess
 import sys
-from typing import Optional
-
-from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
-from twitter.ml.tensorflow.experimental.distributed import utils
+from typing import (
+  Any,
+  Callable,
+  Dict,
+  Optional,
+  Protocol,
+  Tuple,
+  TypedDict,
+)
 
 import torch
 import torch.distributed.run
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from twitter.ml.tensorflow.experimental.distributed import (
+  utils,
+)
+from typing_extensions import (
+  Unpack,
+)
 
 
-def is_distributed_worker():
+class Fn(Protocol):
+  def __call__(self, *args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
+    ...
+
+
+def is_distributed_worker() -> bool:
   world_size = os.environ.get("WORLD_SIZE", None)
   rank = os.environ.get("RANK", None)
   return world_size is not None and rank is not None
 
 
 def maybe_run_training(
-  train_fn,
-  module_name,
+  train_fn: Fn,
+  module_name: Any,
   nproc_per_node: Optional[int] = None,
-  num_nodes: Optional[int] = None,
+  num_nodes: Optional[int] = 0,
   set_python_path_in_subprocess: bool = False,
   is_chief: Optional[bool] = False,
-  **training_kwargs,
-):
+  **training_kwargs: Any,
+) -> None:
   """Wrapper function for single node, multi-GPU Pytorch training.
 
   If the necessary distributed Pytorch environment variables
@@ -73,7 +92,7 @@ def maybe_run_training(
     ]
     if nproc_per_node:
       cmd.extend(["--nproc_per_node", str(nproc_per_node)])
-    if num_nodes > 1:
+    if num_nodes and num_nodes > 1:
       cluster_resolver = utils.cluster_resolver()
       backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
       cmd.extend(
diff --git a/common/test_device.py b/common/test_device.py
index fb77920..f37bcc3 100644
--- a/common/test_device.py
+++ b/common/test_device.py
@@ -2,7 +2,9 @@
 Mostly a test that this can be imported properly even tho moved.
 """
 
-from unittest.mock import patch
+from unittest.mock import (
+  patch,
+)
 
 import tml.common.device as device_utils
 
diff --git a/common/testing_utils.py b/common/testing_utils.py
index 69e14de..7c5e6be 100644
--- a/common/testing_utils.py
+++ b/common/testing_utils.py
@@ -1,11 +1,16 @@
-from contextlib import contextmanager
 import datetime
 import os
-from unittest.mock import patch
+from contextlib import (
+  contextmanager,
+)
+from unittest.mock import (
+  patch,
+)
 
 import torch.distributed as dist
-from tml.ml_logging.torch_logging import logging
-
+from tml.ml_logging.torch_logging import (
+  logging,
+)
 
 MOCK_ENV = {
   "LOCAL_RANK": "0",
diff --git a/common/utils.py b/common/utils.py
index 09d4464..e5cf34d 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -1,18 +1,24 @@
-import yaml
 import getpass
 import os
 import string
-from typing import Tuple, Type, TypeVar
-
-from tml.core.config import base_config
+from typing import (
+  Tuple,
+  Type,
+  TypeVar,
+)
 
 import fsspec
+import yaml
+from tml.core.config import (
+  base_config,
+)
 
 
 C = TypeVar("C", bound=base_config.BaseConfig)
+T = TypeVar("T", str, bytes)
 
 
-def _read_file(f):
-  with fsspec.open(f) as f:
+def _read_file(file: T) -> T:
+  with fsspec.open(file) as f:
     return f.read()
 
@@ -34,7 +40,7 @@ def setup_configuration(
     The pydantic config object.
   """
 
-  def _substitute(s):
+  def _substitute(s: str) -> str:
     if substitute_env_variable:
       return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
     return s
diff --git a/common/wandb.py b/common/wandb.py
index eb8cceb..c5a6972 100644
--- a/common/wandb.py
+++ b/common/wandb.py
@@ -1,8 +1,11 @@
-from typing import Any, Dict, List
-
-import tml.core.config as base_config
+from typing import (
+  Any,
+  Dict,
+  List,
+)
 
 import pydantic
+import tml.core.config as base_config
 
 
 class WandbConfig(base_config.BaseConfig):
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..318158f
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,22 @@
+[mypy]
+exclude = dist
+ignore_errors = False
+ignore_missing_imports = True
+warn_unused_configs = True
+warn_redundant_casts = True
+warn_unused_ignores = True
+no_implicit_optional = True
+strict_equality = True
+strict_concatenate = True
+check_untyped_defs = True
+disallow_subclassing_any = False
+disallow_untyped_decorators = True
+disallow_any_generics = True
+disallow_untyped_calls = True
+disallow_incomplete_defs = True
+disallow_untyped_defs = True
+no_implicit_reexport = True
+strict = True
+warn_return_any = False
+explicit_package_bases = True
+namespace_packages = True
diff --git a/pyproject.toml b/pyproject.toml
index 9a05aba..a25f55c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,3 +15,8 @@ exclude = '''
   | dist
 )/
 '''
+
+[tool.isort]
+multi_line_output = 3
+line_length = 21
+include_trailing_comma = true
\ No newline at end of file