Mirror of https://github.com/twitter/the-algorithm-ml.git, synced 2025-01-26 21:45:26 +01:00.

Commit 2cc1abedd7 (parent 78c3235eee)

add & config mypy on common package

Signed-off-by: wiseaidev <business@wiseai.dev>
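
The diff below does two things: it adds parameter and return-type annotations across the common package, and it introduces a mypy.ini that turns on strict checking. The annotations are what `disallow_untyped_defs = True` demands: a definition like `def get_step(state): ...` (a hypothetical example, not taken from the diff) is rejected with "function is missing a type annotation"; the fixed form annotates everything:

from typing import Any, Dict


def get_step(state: Dict[str, Any]) -> int:
  return state["step"]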

common/batch.py

@@ -1,22 +1,40 @@
 """Extension of torchrec.dataset.utils.Batch to cover any dataset.
 """
 # flake8: noqa
-from __future__ import annotations
-from typing import Dict
+from __future__ import (
+  annotations,
+)
 
 import abc
-from dataclasses import dataclass
 import dataclasses
+from collections import (
+  UserDict,
+)
+from dataclasses import (
+  dataclass,
+)
+from typing import (
+  Any,
+  Dict,
+  List,
+  TypeVar,
+)
 
 import torch
-from torchrec.streamable import Pipelineable
+from torchrec.streamable import (
+  Pipelineable,
+)
+
+_KT = TypeVar("_KT")  # key type
+_VT = TypeVar("_VT")  # value type
 
 
 class BatchBase(Pipelineable, abc.ABC):
   @abc.abstractmethod
-  def as_dict(self) -> Dict:
+  def as_dict(self) -> Dict[str, Any]:
     raise NotImplementedError
 
-  def to(self, device: torch.device, non_blocking: bool = False):
+  def to(self, device: torch.device, non_blocking: bool = False) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
@@ -26,14 +44,14 @@ class BatchBase(Pipelineable, abc.ABC):
     for feature_value in self.as_dict().values():
       feature_value.record_stream(stream)
 
-  def pin_memory(self):
+  def pin_memory(self) -> BatchBase:
     args = {}
     for feature_name, feature_value in self.as_dict().items():
       args[feature_name] = feature_value.pin_memory()
     return self.__class__(**args)
 
   def __repr__(self) -> str:
-    def obj2str(v):
+    def obj2str(v: Any) -> str:
       return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
 
     return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@@ -52,10 +70,10 @@ class BatchBase(Pipelineable, abc.ABC):
 @dataclass
 class DataclassBatch(BatchBase):
   @classmethod
-  def feature_names(cls):
+  def feature_names(cls) -> List[str]:
     return list(cls.__dataclass_fields__.keys())
 
-  def as_dict(self):
+  def as_dict(self) -> Dict[str, Any]:
     return {
       feature_name: getattr(self, feature_name)
       for feature_name in self.feature_names()
@@ -63,7 +81,7 @@ class DataclassBatch(BatchBase):
     }
 
   @staticmethod
-  def from_schema(name: str, schema):
+  def from_schema(name: str, schema: Any) -> type:
     """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
     return dataclasses.make_dataclass(
       cls_name=name,
@@ -72,7 +90,7 @@ class DataclassBatch(BatchBase):
     )
 
   @staticmethod
-  def from_fields(name: str, fields: dict):
+  def from_fields(name: str, fields: Dict[str, Any]) -> type:
     return dataclasses.make_dataclass(
       cls_name=name,
       fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
@@ -80,6 +98,6 @@ class DataclassBatch(BatchBase):
     )
 
 
-class DictionaryBatch(BatchBase, dict):
-  def as_dict(self) -> Dict:
+class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
+  def as_dict(self) -> Dict[str, Any]:
     return self
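
A note on the last hunk: basing DictionaryBatch on UserDict[_KT, _VT] rather than plain dict makes the class explicitly generic, which is the style `disallow_any_generics = True` pushes toward. A minimal sketch of the pattern, assuming Python 3.9+ (where collections.UserDict supports subscripting at runtime); TypedBag is a hypothetical name:

from collections import UserDict
from typing import TypeVar

_KT = TypeVar("_KT")  # key type
_VT = TypeVar("_VT")  # value type


class TypedBag(UserDict[_KT, _VT]):
  """A dict-like container whose key and value types mypy can check."""


bag: TypedBag[str, int] = TypedBag()
bag["count"] = 3  # OK: str key, int value; bag["count"] = "3" would be flagged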

common/checkpointing/__init__.py

@@ -1 +1,4 @@
-from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot
+from tml.common.checkpointing.snapshot import (
+  Snapshot,
+  get_checkpoint,
+)

common/checkpointing/snapshot.py

@@ -1,12 +1,24 @@
 import os
 import time
-from typing import Any, Dict, List, Optional
-
-from tml.ml_logging.torch_logging import logging
-from tml.common.filesystem import infer_fs, is_gcs_fs
+from typing import (
+  Any,
+  Dict,
+  Generator,
+  List,
+  Optional,
+)
 
 import torchsnapshot
+from tml.common.filesystem import (
+  infer_fs,
+  is_gcs_fs,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import (
+  FloatTensor,
+)
 
 DONE_EVAL_SUBDIR = "evaled_by"
 GCS_PREFIX = "gs://"
@@ -25,7 +37,7 @@ class Snapshot:
     self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
 
   @property
-  def step(self):
+  def step(self) -> int:
     return self.state["extra_state"]["step"]
 
   @step.setter
@@ -33,14 +45,14 @@ class Snapshot:
     self.state["extra_state"]["step"] = step
 
   @property
-  def walltime(self):
+  def walltime(self) -> float:
     return self.state["extra_state"]["walltime"]
 
   @walltime.setter
   def walltime(self, walltime: float) -> None:
     self.state["extra_state"]["walltime"] = walltime
 
-  def save(self, global_step: int) -> "PendingSnapshot":
+  def save(self, global_step: int) -> "PendingSnapshot":  # type: ignore
     """Saves checkpoint with given global_step."""
     path = os.path.join(self.save_dir, str(global_step))
     logging.info(f"Saving snapshot global_step {global_step} to {path}.")
@@ -98,7 +110,7 @@ class Snapshot:
     cls,
     embedding_snapshot: torchsnapshot.Snapshot,
     snapshot_emb_name: str,
-    weight_tensor,
+    weight_tensor: FloatTensor,
   ) -> None:
     """Loads pretrained embedding from the snapshot to the model.
     Utilise partial lodaing meachanism from torchsnapshot.
@@ -128,11 +140,11 @@ def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
   return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
 
 
-def is_done_eval(checkpoint_path: str, eval_partition: str):
-  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
+def is_done_eval(checkpoint_path: str, eval_partition: str) -> bool:
+  return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))  # type: ignore[attr-defined]
 
 
-def mark_done_eval(checkpoint_path: str, eval_partition: str):
+def mark_done_eval(checkpoint_path: str, eval_partition: str) -> Any:
   infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
 
 
@@ -140,7 +152,9 @@ def step_from_checkpoint(checkpoint: str) -> int:
   return int(os.path.basename(checkpoint))
 
 
-def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
+def checkpoints_iterator(
+  save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800
+) -> Generator[str, None, None]:
   """Simplified equivalent of tf.train.checkpoints_iterator.
 
   Args:
@@ -149,7 +163,7 @@ def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int
 
   """
 
-  def _poll(last_checkpoint: Optional[str] = None):
+  def _poll(last_checkpoint: Optional[str] = None) -> Optional[str]:
     stop_time = time.time() + timeout
     while True:
       _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
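
The Generator[str, None, None] annotation on checkpoints_iterator spells out all three type parameters: the yield type, the send type, and the return type. A minimal sketch of the same shape, with a hypothetical wait_for_checkpoints standing in for the real polling logic:

import time
from typing import Generator, List


def wait_for_checkpoints(paths: List[str]) -> Generator[str, None, None]:
  """Yields str, is never sent a value, and returns None."""
  for path in paths:
    time.sleep(0.01)  # stand-in for the real seconds_to_sleep polling delay
    yield path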

common/device.py

@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 
 
-def maybe_setup_tensorflow():
+def maybe_setup_tensorflow() -> None:
   try:
     import tensorflow as tf
   except ImportError:

common/filesystem/__init__.py

@@ -1 +1,5 @@
-from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs
+from tml.common.filesystem.util import (
+  infer_fs,
+  is_gcs_fs,
+  is_local_fs,
+)

common/filesystem/test_infer_fs.py

@@ -2,7 +2,9 @@
 
 Mostly a test that it returns an object
 """
-from tml.common.filesystem import infer_fs
+from tml.common.filesystem import (
+  infer_fs,
+)
 
 
 def test_infer_fs():

common/filesystem/util.py

@@ -1,13 +1,18 @@
 """Utilities for interacting with the file systems."""
-from fsspec.implementations.local import LocalFileSystem
-import gcsfs
+from typing import (
+  Union,
+)
+
+import gcsfs
+from fsspec.implementations.local import (
+  LocalFileSystem,
+)
 
 GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
 LOCAL_FS = LocalFileSystem()
 
 
-def infer_fs(path: str):
+def infer_fs(path: str) -> Union[LocalFileSystem, gcsfs.core.GCSFileSystem, NotImplementedError]:
   if path.startswith("gs://"):
     return GCS_FS
   elif path.startswith("hdfs://"):
@@ -17,9 +22,9 @@ def infer_fs(path: str):
     return LOCAL_FS
 
 
-def is_local_fs(fs):
+def is_local_fs(fs: LocalFileSystem) -> bool:
   return fs == LOCAL_FS
 
 
-def is_gcs_fs(fs):
+def is_gcs_fs(fs: gcsfs.core.GCSFileSystem) -> bool:
   return fs == GCS_FS
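
One annotation here is worth flagging: the Union return type of infer_fs lists NotImplementedError as a value, i.e. it claims the function may return the exception instance rather than raise it. If the hdfs branch actually raises, the exception type does not belong in the return annotation; if it really is returned, every caller must narrow the result, roughly like this hedged sketch:

fs = infer_fs("hdfs://namenode/data")
if isinstance(fs, NotImplementedError):
  raise fs  # surface the unsupported-scheme error explicitly
fs.ls("/data")  # after narrowing, mypy sees a real filesystem object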

common/log_weights.py

@@ -1,17 +1,28 @@
 """For logging model weights."""
 import itertools
-from typing import Callable, Dict, List, Optional, Union
+from typing import (
+  Any,
+  Callable,
+  Dict,
+  List,
+  Optional,
+  Union,
+)
 
-from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
 import torch
 import torch.distributed as dist
-from torchrec.distributed.model_parallel import DistributedModelParallel
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torchrec.distributed.model_parallel import (
+  DistributedModelParallel,
+)
 
 
 def weights_to_log(
   model: torch.nn.Module,
-  how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
-):
+  how_to_log: Optional[Union[Callable[[Any], Any], Dict[str, Callable[[Any], Any]]]] = None,
+) -> Optional[Dict[str, Any]]:
   """Creates dict of reduced weights to log to give sense of training.
 
   Args:
@@ -21,7 +32,7 @@ def weights_to_log(
 
   """
   if not how_to_log:
-    return
+    return None
 
   to_log = dict()
   named_parameters = model.named_parameters()
@@ -38,14 +49,14 @@ def weights_to_log(
       how = how_to_log
     else:
       how = how_to_log.get(param_name)  # type: ignore[assignment]
-    if not how:
-      continue  # type: ignore
+    if how is None:
+      continue
     to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
   return to_log
 
 
 def log_ebc_norms(
-  model_state_dict,
+  model_state_dict: Dict[str, Any],
   ebc_keys: List[str],
   sample_size: int = 4_000_000,
 ) -> Dict[str, torch.Tensor]:
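
The switch from `if not how:` to `if how is None:` is more than style: `not how` is true for any falsy value, while `is None` skips only the genuinely missing case and lets mypy narrow the Optional, which is why the `# type: ignore` on the `continue` could be dropped. A small illustration with hypothetical names:

from typing import Callable, Optional


def apply(how: Optional[Callable[[int], int]], x: int) -> Optional[int]:
  if how is None:  # narrows Optional[Callable[[int], int]] to the callable
    return None
  return how(x)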

common/modules/embedding/config.py

@@ -1,10 +1,13 @@
-from typing import List
 from enum import Enum
-
-import tml.core.config as base_config
-from tml.optimizers.config import OptimizerConfig
+from typing import (
+  List,
+)
 
 import pydantic
+import tml.core.config as base_config
+from tml.optimizers.config import (
+  OptimizerConfig,
+)
 
 
 class DataType(str, Enum):

common/modules/embedding/embedding.py

@@ -1,13 +1,25 @@
-from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
-from tml.ml_logging.torch_logging import logging
-
-import torch
-from torch import nn
-import torchrec
-from torchrec.modules import embedding_configs
-from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 import numpy as np
+import torch
+import torchrec
+from tml.common.modules.embedding.config import (
+  DataType,
+  LargeEmbeddingsConfig,
+)
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from torch import nn
+from torchrec import (
+  EmbeddingBagCollection,
+  EmbeddingBagConfig,
+)
+from torchrec.modules import (
+  embedding_configs,
+)
+from torchrec.sparse.jagged_tensor import (
+  KeyedJaggedTensor,
+  KeyedTensor,
+)
 
 
 class LargeEmbeddings(nn.Module):

common/run_training.py

@@ -1,30 +1,49 @@
 import os
 import subprocess
 import sys
-from typing import Optional
-
-from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
-from twitter.ml.tensorflow.experimental.distributed import utils
+from typing import (
+  Any,
+  Callable,
+  Dict,
+  Optional,
+  Protocol,
+  Tuple,
+  TypedDict,
+)
 
 import torch
 import torch.distributed.run
+from tml.ml_logging.torch_logging import (
+  logging,
+)
+from twitter.ml.tensorflow.experimental.distributed import (
+  utils,
+)
+from typing_extensions import (
+  Unpack,
+)
 
 
-def is_distributed_worker():
+class Fn(Protocol):
+  def __call__(self, *args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
+    ...
+
+
+def is_distributed_worker() -> bool:
   world_size = os.environ.get("WORLD_SIZE", None)
   rank = os.environ.get("RANK", None)
   return world_size is not None and rank is not None
 
 
 def maybe_run_training(
-  train_fn,
-  module_name,
+  train_fn: Fn,
+  module_name: Any,
   nproc_per_node: Optional[int] = None,
-  num_nodes: Optional[int] = None,
+  num_nodes: Optional[int] = 0,
   set_python_path_in_subprocess: bool = False,
   is_chief: Optional[bool] = False,
-  **training_kwargs,
-):
+  **training_kwargs: Any,
+) -> None:
   """Wrapper function for single node, multi-GPU Pytorch training.
 
   If the necessary distributed Pytorch environment variables
@@ -73,7 +92,7 @@ def maybe_run_training(
     ]
    if nproc_per_node:
       cmd.extend(["--nproc_per_node", str(nproc_per_node)])
-    if num_nodes > 1:
+    if num_nodes and num_nodes > 1:
       cluster_resolver = utils.cluster_resolver()
       backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
       cmd.extend(
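
The new Fn protocol types train_fn structurally: any callable whose __call__ matches the protocol's signature conforms, with no inheritance required. A minimal sketch of how such a protocol is consumed (run and train are hypothetical names):

from typing import Any, Dict, Protocol, Tuple


class Fn(Protocol):
  def __call__(self, *args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
    ...


def train(*args: Tuple[str], **kwargs: Dict[str, Any]) -> None:
  print("training with", args, kwargs)


def run(fn: Fn) -> None:
  fn()  # calls are checked against the protocol's __call__ signature


run(train)  # accepted: train matches Fn structurally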

common/test_device.py

@@ -2,7 +2,9 @@
 
 Mostly a test that this can be imported properly even tho moved.
 """
-from unittest.mock import patch
+from unittest.mock import (
+  patch,
+)
 
 import tml.common.device as device_utils
 

common/testing_utils.py

@@ -1,11 +1,16 @@
-from contextlib import contextmanager
 import datetime
 import os
-from unittest.mock import patch
+from contextlib import (
+  contextmanager,
+)
+from unittest.mock import (
+  patch,
+)
 
 import torch.distributed as dist
-from tml.ml_logging.torch_logging import logging
+from tml.ml_logging.torch_logging import (
+  logging,
+)
 
 MOCK_ENV = {
   "LOCAL_RANK": "0",

common/utils.py

@@ -1,18 +1,24 @@
-import yaml
 import getpass
 import os
 import string
-from typing import Tuple, Type, TypeVar
-
-from tml.core.config import base_config
+from typing import (
+  Tuple,
+  Type,
+  TypeVar,
+)
 
 import fsspec
+import yaml
+from tml.core.config import (
+  base_config,
+)
 
 C = TypeVar("C", bound=base_config.BaseConfig)
+T = TypeVar("T", str, bytes)
 
 
-def _read_file(f):
-  with fsspec.open(f) as f:
+def _read_file(file: T) -> T:
+  with fsspec.open(file) as f:
     return f.read()
 
 
@@ -34,7 +40,7 @@ def setup_configuration(
     The pydantic config object.
   """
 
-  def _substitute(s):
+  def _substitute(s: str) -> str:
     if substitute_env_variable:
       return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
     return s
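
T = TypeVar("T", str, bytes) is a value-constrained TypeVar: T must be exactly str or exactly bytes, and the same one at every occurrence in a signature, so a T -> T function preserves which of the two it received. A minimal sketch of the mechanism (first_line is a hypothetical function):

from typing import TypeVar

T = TypeVar("T", str, bytes)


def first_line(blob: T) -> T:
  # mypy checks the body once with T = str and once with T = bytes.
  return blob.splitlines()[0]


s = first_line("a\nb")   # inferred as str
b = first_line(b"a\nb")  # inferred as bytes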

common/wandb.py

@@ -1,8 +1,11 @@
-from typing import Any, Dict, List
-
-import tml.core.config as base_config
+from typing import (
+  Any,
+  Dict,
+  List,
+)
 
 import pydantic
+import tml.core.config as base_config
 
 
 class WandbConfig(base_config.BaseConfig):

mypy.ini (new file, 22 lines)

@@ -0,0 +1,22 @@
+[mypy]
+exclude = dist
+ignore_errors = False
+ignore_missing_imports = True
+warn_unused_configs = True
+warn_redundant_casts = True
+warn_unused_ignores = True
+no_implicit_optional = True
+strict_equality = True
+strict_concatenate = True
+check_untyped_defs = True
+disallow_subclassing_any = False
+disallow_untyped_decorators = True
+disallow_any_generics = True
+disallow_untyped_calls = True
+disallow_incomplete_defs = True
+disallow_untyped_defs = True
+no_implicit_reexport = True
+strict = True
+warn_return_any = False
+explicit_package_bases = True
+namespace_packages = True
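
Most of these options are the ones --strict enables, written out individually, plus two deliberate relaxations (warn_return_any = False and disallow_subclassing_any = False). Spelling them out matters: mypy has historically honored `strict` only on the command line, not in config files, so the explicit entries are what actually take effect here. As one example, warn_unused_ignores = True reports suppressions that no longer suppress anything (hypothetical snippet):

import os  # type: ignore  # error under warn_unused_ignores: unused "type: ignore" comment


def rank() -> int:
  return int(os.environ.get("RANK", "0"))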

pyproject.toml

@@ -15,3 +15,8 @@ exclude = '''
   | dist
 )/
 '''
+
+[tool.isort]
+multi_line_output = 3
+line_length = 21
+include_trailing_comma = true
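
These isort settings account for the import churn throughout this commit: multi_line_output = 3 selects the vertical-hanging-indent style, include_trailing_comma = true adds the trailing commas, and the very small line_length = 21 pushes nearly every `from ... import` onto the parenthesized one-name-per-line form. A before/after sketch (the two-space indent is assumed to match the repository's code style):

# Before: 39 characters, over the 21-character limit.
from typing import Tuple, Type, TypeVar

# After isort with the settings above:
from typing import (
  Tuple,
  Type,
  TypeVar,
)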