Twitter's Recommendation Algorithm - Heavy Ranker and TwHIN embeddings

twitter-team
2023-03-31 13:05:14 -05:00
commit 78c3235eee
111 changed files with 11876 additions and 0 deletions

common/__init__.py (new, empty file)

common/batch.py (new file)
@@ -0,0 +1,85 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
import abc
from dataclasses import dataclass
import dataclasses
import torch
from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict:
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False):
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self):
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@property
def batch_size(self) -> int:
for tensor in self.as_dict().values():
if tensor is None:
continue
if not isinstance(tensor, torch.Tensor):
continue
return tensor.shape[0]
raise Exception("Could not determine batch size from tensors.")
@dataclass
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls):
return list(cls.__dataclass_fields__.keys())
def as_dict(self):
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
if hasattr(self, feature_name)
}
@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
bases=(DataclassBatch,),
)
@staticmethod
def from_fields(name: str, fields: dict):
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
bases=(DataclassBatch,),
)
class DictionaryBatch(BatchBase, dict):
def as_dict(self) -> Dict:
return self
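
For reference, a minimal sketch (not part of the commit) of how DataclassBatch.from_fields can be used; the class name and field names are illustrative:

import torch
from tml.common.batch import DataclassBatch

# Build a batch subclass with two tensor-valued fields (names are hypothetical).
ExampleBatch = DataclassBatch.from_fields(
  "ExampleBatch", {"continuous_features": torch.Tensor, "labels": torch.Tensor}
)

batch = ExampleBatch(
  continuous_features=torch.randn(8, 4),
  labels=torch.zeros(8, 1),
)
assert batch.batch_size == 8           # first tensor dimension
batch = batch.to(torch.device("cpu"))  # moves every tensor field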

common/checkpointing/__init__.py (new file)
@@ -0,0 +1 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot

common/checkpointing/snapshot.py (new file)
@@ -0,0 +1,260 @@
import os
import time
from typing import Any, Dict, List, Optional

from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs

import torchsnapshot

DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://"


class Snapshot:
  """Checkpoints using torchsnapshot.

  Also saves the step so it can be updated by the training loop.
  """

  def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
    self.save_dir = save_dir
    self.state = state
    self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)

  @property
  def step(self):
    return self.state["extra_state"]["step"]

  @step.setter
  def step(self, step: int) -> None:
    self.state["extra_state"]["step"] = step

  @property
  def walltime(self):
    return self.state["extra_state"]["walltime"]

  @walltime.setter
  def walltime(self, walltime: float) -> None:
    self.state["extra_state"]["walltime"] = walltime

  def save(self, global_step: int) -> "PendingSnapshot":
    """Saves checkpoint with given global_step."""
    path = os.path.join(self.save_dir, str(global_step))
    logging.info(f"Saving snapshot global_step {global_step} to {path}.")
    start_time = time.time()
    # Take the snapshot asynchronously. The snapshot is consistent: state changes made after
    # this method returns have no effect on it. Storage I/O happens in the background.
    snapshot = torchsnapshot.Snapshot.async_take(
      app_state=self.state,
      path=path,
      # commented out because DistributedModelParallel model saving
      # errors with this on multi-GPU. With it removed, CPU, single
      # GPU, and multi-GPU training all successfully checkpoint.
      # replicated=["**"],
    )
    logging.info(f"Snapshot saved to {snapshot.path} ({time.time() - start_time:.05}s).")
    return snapshot

  def restore(self, checkpoint: str) -> None:
    """Restores a given checkpoint."""
    snapshot = torchsnapshot.Snapshot(path=checkpoint)
    logging.info(f"Restoring snapshot from {snapshot.path}.")
    start_time = time.time()
    # We can remove the try-except once we are confident that we no longer need to restore
    # checkpoints from before walltime was added.
    try:
      # Checkpoints that do not have extra_state["walltime"] will fail here.
      snapshot.restore(self.state)
    except RuntimeError:
      # extra_state["walltime"] does not exist in the checkpoint, but step should be there, so restore it.
      self.state["extra_state"] = torchsnapshot.StateDict(step=0)
      snapshot.restore(self.state)
      # We still need to ensure that extra_state has walltime in it.
      self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)
    logging.info(f"Restored snapshot from {snapshot.path} ({time.time() - start_time:.05}s).")

  @classmethod
  def get_torch_snapshot(
    cls,
    snapshot_path: str,
    global_step: Optional[int] = None,
    missing_ok: bool = False,
  ) -> torchsnapshot.Snapshot:
    """Gets the torchsnapshot Snapshot object without actually loading it.

    Args:
      snapshot_path: path to the model snapshot.
      global_step: restores from this checkpoint if specified.
      missing_ok: if True and checkpoints do not exist, returns without restoration.
    """
    path = get_checkpoint(snapshot_path, global_step, missing_ok)
    logging.info(f"Loading snapshot from {path}.")
    return torchsnapshot.Snapshot(path=path)

  @classmethod
  def load_snapshot_to_weight(
    cls,
    embedding_snapshot: torchsnapshot.Snapshot,
    snapshot_emb_name: str,
    weight_tensor,
  ) -> None:
    """Loads a pretrained embedding from the snapshot into the model.

    Uses the partial loading mechanism from torchsnapshot.

    Args:
      embedding_snapshot: Snapshot containing the pretrained embeddings (EBC).
      snapshot_emb_name: Name of the layer in the *snapshot* model containing the EBC.
      weight_tensor: embeddings tensor of the *current* model, into which the embeddings will be loaded.
    """
    start_time = time.time()
    manifest = embedding_snapshot.get_manifest()
    for path in manifest.keys():
      if path.startswith("0") and snapshot_emb_name in path:
        snapshot_path_to_load = path
    embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)
    logging.info(
      f"Loaded embedding snapshot from {snapshot_path_to_load}: {time.time() - start_time:.05}s",
      rank=-1,
    )
    logging.info(f"Snapshot loaded to {weight_tensor.metadata()}", rank=-1)


def _eval_subdir(checkpoint_path: str) -> str:
  return os.path.join(checkpoint_path, DONE_EVAL_SUBDIR)


def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
  return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")


def is_done_eval(checkpoint_path: str, eval_partition: str):
  # Check for the done-marker on the checkpoint's filesystem.
  return infer_fs(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))


def mark_done_eval(checkpoint_path: str, eval_partition: str):
  infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))


def step_from_checkpoint(checkpoint: str) -> int:
  return int(os.path.basename(checkpoint))


def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
  """Simplified equivalent of tf.train.checkpoints_iterator.

  Args:
    save_dir: directory in which to look for checkpoints.
    seconds_to_sleep: time between polling calls.
    timeout: how long to wait for a new checkpoint.
  """

  def _poll(last_checkpoint: Optional[str] = None):
    stop_time = time.time() + timeout
    while True:
      _checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
      if not _checkpoint_path or _checkpoint_path == last_checkpoint:
        if time.time() + seconds_to_sleep > stop_time:
          logging.info(
            f"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s."
          )
          return None
        logging.info(f"Waiting for next available checkpoint from {save_dir}.")
        time.sleep(seconds_to_sleep)
      else:
        logging.info(f"Found latest checkpoint {_checkpoint_path}.")
        return _checkpoint_path

  checkpoint_path = None
  while True:
    new_checkpoint = _poll(checkpoint_path)
    if not new_checkpoint:
      return
    checkpoint_path = new_checkpoint
    yield checkpoint_path


def get_checkpoint(
  save_dir: str,
  global_step: Optional[int] = None,
  missing_ok: bool = False,
) -> str:
  """Gets the latest checkpoint, or the checkpoint at the specified global_step.

  Args:
    global_step: Finds this checkpoint if specified.
    missing_ok: if True and no checkpoints exist, returns an empty string instead of raising.
  """
  checkpoints = get_checkpoints(save_dir)
  if not checkpoints:
    if not missing_ok:
      raise Exception(f"No checkpoints found at {save_dir}")
    else:
      logging.info(f"No checkpoints found for restoration at {save_dir}.")
      return ""
  if global_step is None:
    return checkpoints[-1]

  logging.info(f"Found checkpoints: {checkpoints}")
  for checkpoint in checkpoints:
    step = step_from_checkpoint(checkpoint)
    if global_step == step:
      chosen_checkpoint = checkpoint
      break
  else:
    raise Exception(f"Desired checkpoint at {global_step} not found in {save_dir}")
  return chosen_checkpoint


def get_checkpoints(save_dir: str) -> List[str]:
  """Gets all checkpoints that have been fully written."""
  checkpoints = []
  fs = infer_fs(save_dir)
  if fs.exists(save_dir):
    prefix = GCS_PREFIX if is_gcs_fs(fs) else ""
    checkpoints = list(f"{prefix}{elem}" for elem in fs.ls(save_dir, detail=False))
    # Only take checkpoints that were fully written.
    checkpoints = list(
      filter(
        lambda path: fs.exists(f"{path}/{torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME}"),
        checkpoints,
      )
    )
    checkpoints = sorted(checkpoints, key=lambda path: int(os.path.basename(path)))
  return checkpoints


def wait_for_evaluators(
  save_dir: str,
  partition_names: List[str],
  global_step: int,
  timeout: int,
) -> None:
  logging.info("Waiting for all evaluators to finish.")
  start_time = time.time()

  for checkpoint in checkpoints_iterator(save_dir):
    step = step_from_checkpoint(checkpoint)
    logging.info(f"Considering checkpoint {checkpoint} for global step {global_step}.")
    if step == global_step:
      while partition_names:
        if is_done_eval(checkpoint, partition_names[-1]):
          logging.info(
            f"Checkpoint {checkpoint} marked as finished eval for partition {partition_names[-1]} at step {step}, still waiting for {partition_names}."
          )
          partition_names.pop()
        if time.time() - start_time >= timeout:
          logging.warning(
            f"Not all evaluators finished after waiting for {time.time() - start_time}"
          )
          return
        time.sleep(10)
      logging.info("All evaluators finished.")
      return
    if time.time() - start_time >= timeout:
      logging.warning(f"Not all evaluators finished after waiting for {time.time() - start_time}")
      return
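
For orientation, a minimal sketch (not part of the commit) of how a training loop might use Snapshot together with get_checkpoint, assuming torchsnapshot's async PendingSnapshot API; the save directory, step, and model are illustrative:

import torch
from tml.common.checkpointing.snapshot import Snapshot, get_checkpoint

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# torchsnapshot can snapshot nn.Module / Optimizer objects directly.
snapshot = Snapshot(save_dir="/tmp/example_run", state={"model": model, "optimizer": optimizer})

snapshot.step = 100
pending = snapshot.save(global_step=100)  # async; storage I/O happens in the background
pending.wait()                            # block until the write completes

# Restore the latest fully written checkpoint, if any.
latest = get_checkpoint("/tmp/example_run", missing_ok=True)
if latest:
  snapshot.restore(latest)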

common/device.py (new file)
@@ -0,0 +1,30 @@
import os

import torch
import torch.distributed as dist


def maybe_setup_tensorflow():
  try:
    import tensorflow as tf
  except ImportError:
    pass
  else:
    tf.config.set_visible_devices([], "GPU")  # disable tf gpu


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
  if tf_ok:
    maybe_setup_tensorflow()

  device = torch.device("cpu")
  backend = "gloo"
  if torch.cuda.is_available():
    rank = os.environ["LOCAL_RANK"]
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
  if not torch.distributed.is_initialized():
    dist.init_process_group(backend)

  return device

common/filesystem/__init__.py (new file)
@@ -0,0 +1 @@
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs

common/filesystem/test_infer_fs.py (new file)
@@ -0,0 +1,16 @@
"""Minimal test for infer_fs.
Mostly a test that it returns an object
"""
from tml.common.filesystem import infer_fs
def test_infer_fs():
local_path = "/tmp/local_path"
gcs_path = "gs://somebucket/somepath"
local_fs = infer_fs(local_path)
gcs_fs = infer_fs(gcs_path)
# This should return two different objects
assert local_fs != gcs_fs

common/filesystem/util.py (new file)
@@ -0,0 +1,25 @@
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs
GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()
def infer_fs(path: str):
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
# We can probably use pyarrow HDFS to support this.
raise NotImplementedError("HDFS not yet supported")
else:
return LOCAL_FS
def is_local_fs(fs):
return fs == LOCAL_FS
def is_gcs_fs(fs):
return fs == GCS_FS
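
A quick sketch (not part of the commit) of how these helpers combine; the bucket and paths are illustrative:

from tml.common.filesystem import infer_fs, is_gcs_fs, is_local_fs

fs = infer_fs("gs://somebucket/checkpoints")
if is_gcs_fs(fs):
  # fsspec-style API; requires GCS credentials to actually list the bucket.
  entries = fs.ls("gs://somebucket/checkpoints", detail=False)

assert is_local_fs(infer_fs("/tmp/checkpoints"))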

common/log_weights.py (new file)
@@ -0,0 +1,90 @@
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
def weights_to_log(
model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
"""Creates dict of reduced weights to log to give sense of training.
Args:
model: model to traverse.
how_to_log: if a function, then applies this to every parameter, if a dict
then only applies and logs specified parameters.
"""
if not how_to_log:
return
to_log = dict()
named_parameters = model.named_parameters()
logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
if isinstance(model, DistributedModelParallel):
named_parameters = itertools.chain(
named_parameters, model._dmp_wrapped_module.named_parameters()
)
logging.info(
f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
)
for param_name, params in named_parameters:
if callable(how_to_log):
how = how_to_log
else:
how = how_to_log.get(param_name) # type: ignore[assignment]
if not how:
continue # type: ignore
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
return to_log
def log_ebc_norms(
model_state_dict,
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict: model.state_dict()
ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
norm_logs = dict()
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
if emb_key in model_state_dict:
emb_weight = model_state_dict[emb_key]
try:
emb_weight_tensor = emb_weight.local_tensor()
except AttributeError as e:
logging.info(e)
emb_weight_tensor = emb_weight
logging.info("Running Tensor.detach()")
emb_weight_tensor = emb_weight_tensor.detach()
sample_mask = torch.randperm(emb_weight_tensor.shape[0])[
: min(sample_size, emb_weight_tensor.shape[0])
]
# WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks
# Change sample_size if the you observe frequent OOM errors or remove weight logging.
norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}"))
all_norms = [
torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size())
]
dist.all_gather(all_norms, norms)
for idx, norm in enumerate(all_norms):
if norm != -1.0:
norm_logs[f"{emb_key}-norm-{idx}"] = norm
logging.info(f"Norm Logs are {norm_logs}")
return norm_logs
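
A small sketch (not part of the commit) of calling weights_to_log on a plain (non-DMP) module; the model and reductions are illustrative:

import torch
from tml.common.log_weights import weights_to_log

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))

# Apply one reduction to every parameter...
logs = weights_to_log(model, how_to_log=torch.linalg.norm)

# ...or reduce only specific parameters by name.
logs = weights_to_log(
  model,
  how_to_log={"0.weight": torch.mean, "2.bias": torch.linalg.norm},
)
# Keys look like "model/<reduction name>/<parameter name>"; values are numpy scalars.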

common/modules/embedding/config.py (new file)
@@ -0,0 +1,59 @@
from typing import List
from enum import Enum

import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig

import pydantic


class DataType(str, Enum):
  FP32 = "fp32"
  FP16 = "fp16"


class EmbeddingSnapshot(base_config.BaseConfig):
  """Configuration for Embedding snapshot."""

  emb_name: str = pydantic.Field(
    ..., description="Name of the embedding table from the loaded snapshot"
  )
  embedding_snapshot_uri: str = pydantic.Field(
    ..., description="Path to torchsnapshot of the embedding"
  )


class EmbeddingBagConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBag."""

  name: str = pydantic.Field(..., description="name of embedding bag")
  num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
  embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
  pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
  vocab: str = pydantic.Field(
    None, description="Directory to parquet files of mapping from entity ID to table index."
  )
  # make sure to use an optimizer that matches:
  # https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
  optimizer: OptimizerConfig
  data_type: DataType


class LargeEmbeddingsConfig(base_config.BaseConfig):
  """Configuration for EmbeddingBagCollection.

  The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
  """

  tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
  tables_to_log: List[str] = pydantic.Field(
    None, description="list of embedding table names that we want to log during training"
  )


class Mode(str, Enum):
  """Job modes."""

  TRAIN = "train"
  EVALUATE = "evaluate"
  INFERENCE = "inference"

common/modules/embedding/embedding.py (new file)
@@ -0,0 +1,58 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging

import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np


class LargeEmbeddings(nn.Module):
  def __init__(
    self,
    large_embeddings_config: LargeEmbeddingsConfig,
  ):
    super().__init__()

    tables = []
    for table in large_embeddings_config.tables:
      data_type = (
        embedding_configs.DataType.FP32
        if (table.data_type == DataType.FP32)
        else embedding_configs.DataType.FP16
      )

      tables.append(
        EmbeddingBagConfig(
          embedding_dim=table.embedding_dim,
          feature_names=[table.name],  # restricted to 1 feature per table for now
          name=table.name,
          num_embeddings=table.num_embeddings,
          pooling=torchrec.PoolingType.SUM,
          data_type=data_type,
        )
      )

    self.ebc = EmbeddingBagCollection(
      device="meta",
      tables=tables,
    )

    logging.info("********************** EBC named params are **********")
    logging.info(list(self.ebc.named_parameters()))

    # This hook is used to perform post-processing surgery
    # on large_embedding models to prep them for serving.
    self.surgery_cut_point = torch.nn.Identity()

  def forward(
    self,
    sparse_features: KeyedJaggedTensor,
  ) -> KeyedTensor:
    pooled_embs = self.ebc(sparse_features)

    # a KeyedTensor
    return self.surgery_cut_point(pooled_embs)
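
For intuition, a self-contained sketch (not part of the commit) of the underlying torchrec call that LargeEmbeddings wraps, built on a CPU device instead of "meta" so the forward pass actually runs; the table and ids are illustrative:

import torch
import torchrec
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One table serving one feature, mirroring what LargeEmbeddings builds per config entry.
ebc = EmbeddingBagCollection(
  tables=[
    EmbeddingBagConfig(
      name="user_id",
      embedding_dim=16,
      num_embeddings=1000,
      feature_names=["user_id"],
      pooling=torchrec.PoolingType.SUM,
    )
  ],
  device=torch.device("cpu"),
)

# Two examples: the first has two ids for the feature, the second has one.
sparse_features = KeyedJaggedTensor.from_lengths_sync(
  keys=["user_id"],
  values=torch.tensor([101, 202, 303]),
  lengths=torch.tensor([2, 1]),
)
pooled = ebc(sparse_features)   # KeyedTensor
print(pooled["user_id"].shape)  # torch.Size([2, 16])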

common/run_training.py (new file)
@@ -0,0 +1,109 @@
import os
import subprocess
import sys
from typing import Optional

from tml.ml_logging.torch_logging import logging  # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils

import torch
import torch.distributed.run


def is_distributed_worker():
  world_size = os.environ.get("WORLD_SIZE", None)
  rank = os.environ.get("RANK", None)
  return world_size is not None and rank is not None


def maybe_run_training(
  train_fn,
  module_name,
  nproc_per_node: Optional[int] = None,
  num_nodes: Optional[int] = None,
  set_python_path_in_subprocess: bool = False,
  is_chief: Optional[bool] = False,
  **training_kwargs,
):
  """Wrapper function for single node, multi-GPU PyTorch training.

  If the necessary distributed PyTorch environment variables
  (WORLD_SIZE, RANK) have been set, then this function executes
  `train_fn(**training_kwargs)`.

  Otherwise, this function calls torchrun and points at the calling module
  `module_name`. After this call, the necessary environment variables are set
  and training will commence.

  Args:
    train_fn: The function that is responsible for training.
    module_name: The name of the module that this function was called from;
      used to indicate torchrun entrypoint.
    nproc_per_node: Number of workers per node; defaults to the number of
      available accelerators if not set.
    num_nodes: Number of nodes, otherwise inferred from environment.
    is_chief: If process is running on chief.
    set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
  """

  machines = utils.machine_from_env()
  if num_nodes is None:
    num_nodes = 1
    if machines.num_workers:
      num_nodes += machines.num_workers

  if is_distributed_worker():
    # WORLD_SIZE, RANK, etc. are set; assuming any other env vars are set (checks to come).
    # Start the actual training!
    train_fn(**training_kwargs)
  else:
    if nproc_per_node is None:
      if torch.cuda.is_available():
        nproc_per_node = torch.cuda.device_count()
      else:
        nproc_per_node = machines.chief.num_accelerators

    # Rejoin all arguments to send back through torchrun.
    # This is a temporary measure; we will replace the subprocess call
    # with torchrun API calls.
    args = list(f"--{key}={val}" for key, val in training_kwargs.items())

    cmd = [
      "--nnodes",
      str(num_nodes),
    ]
    if nproc_per_node:
      cmd.extend(["--nproc_per_node", str(nproc_per_node)])
    if num_nodes > 1:
      cluster_resolver = utils.cluster_resolver()
      backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
      cmd.extend(
        [
          "--rdzv_backend",
          "c10d",
          "--rdzv_id",
          backend_address,
        ]
      )
      # Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
      if is_chief:
        cmd.extend(["--rdzv_endpoint", "localhost:2222"])
      else:
        cmd.extend(["--rdzv_endpoint", backend_address])
    else:
      cmd.append("--standalone")

    cmd.extend(
      [
        str(module_name),
        *args,
      ]
    )
    logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")

    # Call torchrun on this module; it will spawn new processes and re-run this
    # function, eventually calling "train_fn". The following line sets the
    # PYTHONPATH to accommodate bazel stubbing for the main binary.
    if set_python_path_in_subprocess:
      subprocess.run(["torchrun"] + cmd, env={**os.environ, "PYTHONPATH": ":".join(sys.path)})
    else:
      torch.distributed.run.main(cmd)
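
A sketch (not part of the commit) of the launch pattern this helper expects, assuming an absl-style entrypoint; the module, flag, and train function are hypothetical:

# hypothetical my_project/main.py
from absl import app, flags
from tml.common.run_training import maybe_run_training

FLAGS = flags.FLAGS
flags.DEFINE_string("config_yaml_path", None, "Path to the training config.")


def train(config_yaml_path: str):
  ...  # build the model, data, and training loop here


def main(argv):
  # The first invocation re-launches this module under torchrun; inside the spawned
  # workers WORLD_SIZE/RANK are set, so train() is called directly.
  maybe_run_training(
    train,
    __name__,
    config_yaml_path=FLAGS.config_yaml_path,
  )


if __name__ == "__main__":
  app.run(main)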

common/test_device.py (new file)
@@ -0,0 +1,13 @@
"""Minimal test for device.
Mostly a test that this can be imported properly even tho moved.
"""
from unittest.mock import patch
import tml.common.device as device_utils
def test_device():
with patch("tml.common.device.dist.init_process_group"):
device = device_utils.setup_and_get_device(tf_ok=False)
assert device.type == "cpu"

common/testing_utils.py (new file)
@@ -0,0 +1,33 @@
from contextlib import contextmanager
import datetime
import os
from unittest.mock import patch

import torch.distributed as dist
from tml.ml_logging.torch_logging import logging

MOCK_ENV = {
  "LOCAL_RANK": "0",
  "WORLD_SIZE": "1",
  "LOCAL_WORLD_SIZE": "1",
  "MASTER_ADDR": "localhost",
  "MASTER_PORT": "29501",
  "RANK": "0",
}


@contextmanager
def mock_pg():
  with patch.dict(os.environ, MOCK_ENV):
    try:
      dist.init_process_group(
        backend="gloo",
        timeout=datetime.timedelta(1),
      )
      yield
    finally:
      # Tear down the process group exactly once, on both the success and
      # failure paths (a second destroy_process_group() call would raise).
      dist.destroy_process_group()
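
A short sketch (not part of the commit) of using mock_pg in a unit test:

import torch.distributed as dist
from tml.common.testing_utils import mock_pg


def test_with_fake_process_group():
  # A single-process gloo group is initialized under mocked env vars, so code that
  # calls dist.get_rank() / dist.get_world_size() runs without a real cluster.
  with mock_pg():
    assert dist.get_world_size() == 1
    assert dist.get_rank() == 0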

common/utils.py (new file)
@@ -0,0 +1,44 @@
import yaml
import getpass
import os
import string
from typing import Type, TypeVar

from tml.core.config import base_config

import fsspec

C = TypeVar("C", bound=base_config.BaseConfig)


def _read_file(f):
  with fsspec.open(f) as f:
    return f.read()


def setup_configuration(
  config_type: Type[C],
  yaml_path: str,
  substitute_env_variable: bool = False,
) -> C:
  """Resolves a config at a yaml path.

  Args:
    config_type: Pydantic config class to load.
    yaml_path: yaml path of the config file.
    substitute_env_variable: If True, substitutes strings of the form $VAR or ${VAR} with their
      environment-variable values whenever possible. If an environment variable doesn't exist,
      the string is left unchanged.

  Returns:
    The pydantic config object.
  """

  def _substitute(s):
    if substitute_env_variable:
      return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
    return s

  assert config_type is not None, "can't use all_config without config_type"
  content = _substitute(yaml.safe_load(_read_file(yaml_path)))
  return config_type.parse_obj(content)
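
A brief sketch (not part of the commit) of loading a config with this helper; the config class, fields, and path are illustrative:

import pydantic
from tml.core.config import base_config
from tml.common.utils import setup_configuration


class JobConfig(base_config.BaseConfig):
  save_dir: str = pydantic.Field(..., description="Where to write checkpoints.")
  batch_size: int = pydantic.Field(512, description="Per-replica batch size.")


# /tmp/job.yaml might contain:
#   save_dir: /tmp/runs/demo
#   batch_size: 1024
config = setup_configuration(JobConfig, "/tmp/job.yaml")
print(config.save_dir, config.batch_size)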

common/wandb.py (new file)
@@ -0,0 +1,20 @@
from typing import Any, Dict, List

import tml.core.config as base_config

import pydantic


class WandbConfig(base_config.BaseConfig):
  host: str = pydantic.Field(
    "https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
    description="Host of Weights and Biases instance, passed to login.",
  )
  key_path: str = pydantic.Field(description="Path to key file.")

  name: str = pydantic.Field(None, description="Name of the experiment, passed to init.")
  entity: str = pydantic.Field(None, description="Name of user/service account, passed to init.")
  project: str = pydantic.Field(None, description="Name of wandb project, passed to init.")
  tags: List[str] = pydantic.Field([], description="List of tags, passed to init.")
  notes: str = pydantic.Field(None, description="Notes, passed to init.")

  metadata: Dict[str, Any] = pydantic.Field(None, description="Additional metadata to log.")
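
A sketch (not part of the commit) of wiring WandbConfig into the wandb client library; the key path and project names are illustrative:

import wandb
from tml.common.wandb import WandbConfig

wandb_config = WandbConfig(
  key_path="/var/secrets/wandb_api_key",
  name="twhin-debug-run",
  project="recsys-experiments",
  tags=["debug"],
)

with open(wandb_config.key_path) as f:
  wandb.login(host=wandb_config.host, key=f.read().strip())
run = wandb.init(
  name=wandb_config.name,
  entity=wandb_config.entity,
  project=wandb_config.project,
  tags=wandb_config.tags,
  notes=wandb_config.notes,
)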