mirror of https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-23 06:41:49 +01:00
fix typo
This commit is contained in:
parent 78c3235eee
commit b0dbeabc64
@@ -101,7 +101,7 @@ class Snapshot:
     weight_tensor,
   ) -> None:
     """Loads pretrained embedding from the snapshot to the model.
-    Utilise partial lodaing meachanism from torchsnapshot.
+    Utilise partial lodaing mechanism from torchsnapshot.
     Args:
       embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
       snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
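As background for the docstring above: torchsnapshot can restore a single entry from a snapshot instead of the whole application state. A minimal sketch of such a partial load, assuming the `get_manifest`/`read_object` route and a made-up manifest filter; this is not the repository's exact implementation:

import torch
import torchsnapshot


def load_pretrained_embedding(
  snapshot_path: str, snapshot_emb_name: str, weight_tensor: torch.Tensor
) -> None:
  # Open the snapshot without restoring full application state.
  snapshot = torchsnapshot.Snapshot(path=snapshot_path)
  # Locate the manifest entry for the embedding table (assumed naming scheme).
  manifest = snapshot.get_manifest()
  matches = [k for k in manifest if snapshot_emb_name in k and "optimizer" not in k]
  # read_object restores only this entry, writing into weight_tensor in place.
  snapshot.read_object(matches[0], obj=weight_tensor)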
@@ -11,7 +11,7 @@ def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
   """Recommend method to load a config file (a yaml file) and parse it.
 
   Because we have a shared filesystem the recommended route to running jobs it put modified config
-  files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
+  files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
   """
 
   def _substitute(s):
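For orientation, the function under edit turns a YAML file into a typed pydantic config. A stripped-down sketch of that pattern, ignoring the `_substitute` templating step that follows in the real file; the `BaseConfig` here is a stand-in for the repo's pydantic base class:

from typing import Type

import pydantic
import yaml


class BaseConfig(pydantic.BaseModel):  # stand-in for tml's base config class
  ...


def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
  # Parse the YAML text, then validate it against the typed schema.
  with open(yaml_path) as f:
    raw = yaml.safe_load(f)
  return config_type.parse_obj(raw)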
@@ -115,7 +115,7 @@ def train(
     dataset: data iterator for the training set
     evaluation_iterators: data iterators for the different evaluation sets
     scheduler: optional learning rate scheduler
-    output_transform_for_metrics: optional transformation functions to transorm the model
+    output_transform_for_metrics: optional transformation functions to transform the model
       output and labels into a format the metrics can understand
   """
 
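The docstring documents the custom training loop's arguments. Purely to show how they fit together, a hypothetical call; everything except the four documented parameter names is an assumption:

train(
  model=model,                                # assumed: the torch.nn.Module being fit
  optimizer=optimizer,                        # assumed: its optimizer
  dataset=train_iterator,                     # data iterator for the training set
  evaluation_iterators=[eval_iterator],       # one iterator per evaluation set
  scheduler=lr_scheduler,                     # optional learning rate scheduler
  output_transform_for_metrics=transform_fn,  # maps model output/labels into metric inputs
)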
@@ -50,7 +50,7 @@ class DatasetConfig(base_config.BaseConfig):
     None, description="Number of shards to keep."
   )
   repeat_files: bool = pydantic.Field(
-    True, description="DEPRICATED. Files are repeated no matter what this is set to."
+    True, description="DEPRECATED. Files are repeated no matter what this is set to."
   )
   file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
 
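Each field above pairs a default value with a description via `pydantic.Field`. A self-contained miniature of the same pattern (an illustrative class, not the repository's `DatasetConfig`):

import pydantic


class ExampleDatasetConfig(pydantic.BaseModel):
  repeat_files: bool = pydantic.Field(
    True, description="DEPRECATED. Files are repeated no matter what this is set to."
  )
  file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")


config = ExampleDatasetConfig()
print(config.file_batch_size)  # -> 16, the declared default
# ExampleDatasetConfig(file_batch_size=-1) would raise a ValidationError (PositiveInt)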
@@ -211,7 +211,7 @@ class Sampler(base_config.BaseConfig):
   Only use this for quick experimentation.
   If samplers are useful, we should sample from upstream data generation.
 
-  DEPRICATED, DO NOT USE.
+  DEPRECATED, DO NOT USE.
   """
 
   name: str
@@ -234,7 +234,7 @@ class RecapDataConfig(DatasetConfig):
 
   sampler: Sampler = pydantic.Field(
     None,
-    description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
+    description="""DEPRECATED, DO NOT USE. Sampling function for offline experiments.""",
   )
 
   @pydantic.root_validator()
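The hunk ends at a `@pydantic.root_validator()`, which runs after per-field validation and sees all fields at once. An illustrative (hypothetical) validator that flags use of the deprecated `sampler` field; the repository's actual validator is not shown in this diff:

import warnings
from typing import Optional

import pydantic


class ExampleRecapDataConfig(pydantic.BaseModel):
  sampler: Optional[str] = None

  @pydantic.root_validator()
  def _warn_on_deprecated_sampler(cls, values):
    # Root validators receive every field, so cross-field checks belong here.
    if values.get("sampler") is not None:
      warnings.warn("`sampler` is DEPRECATED, DO NOT USE.", DeprecationWarning)
    return values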
@@ -398,7 +398,7 @@ class RecapDataset(torch.utils.data.IterableDataset):
       )
     else:
       raise ValueError(
-        "Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
+        "Must specify either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
       )
 
     num_files = len(filenames)
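The `ValueError` above guards an if/elif chain over three ways of specifying inputs. A compact sketch of that guard, with the `data_config` attributes treated as hypothetical:

def _check_input_spec(data_config) -> None:
  # At least one of the three input specifications must be provided.
  specs = (
    data_config.inputs,
    data_config.explicit_datetime_inputs,
    data_config.explicit_date_inputs,
  )
  if all(spec is None for spec in specs):
    raise ValueError(
      "Must specify either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
    )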
@@ -9,7 +9,7 @@ from tml.core import config as tml_config_mod
 import tml.projects.home.recap.config as recap_config_mod
 
 flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
-flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.")
+flags.DEFINE_integer("n_examples", 100, "Number of examples to generate.")
 
 FLAGS = flags.FLAGS
 
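The two definitions above are standard absl flags. A runnable miniature showing how such flags are declared and consumed:

from absl import app, flags

flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_integer("n_examples", 100, "Number of examples to generate.")

FLAGS = flags.FLAGS


def main(argv):
  # Flag values are only populated after absl parses argv inside app.run.
  print(f"config={FLAGS.config_path}, n_examples={FLAGS.n_examples}")


if __name__ == "__main__":
  app.run(main)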
@@ -18,7 +18,7 @@ class DropoutConfig(base_config.BaseConfig):
 
 
 class LayerNormConfig(base_config.BaseConfig):
-  """Configruation for the layer normalization."""
+  """Configuration for the layer normalization."""
 
   epsilon: float = pydantic.Field(
     1e-3, description="Small float added to variance to avoid dividing by zero."
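The `epsilon` here is the usual layer-norm stabilizer in y = (x - mean) / sqrt(var + eps). Assuming this config ultimately feeds a standard PyTorch layer, the value maps onto `eps`:

import torch

# eps matches the config default above; it keeps sqrt(var + eps) strictly positive.
layer_norm = torch.nn.LayerNorm(normalized_shape=128, eps=1e-3)
out = layer_norm(torch.randn(4, 128))  # normalizes over the last dimension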
@@ -91,7 +91,7 @@ class ZScoreLogConfig(base_config.BaseConfig):
     False, description="Option to use batch normalization on the inputs."
   )
   use_renorm: bool = pydantic.Field(
-    False, description="Option to use batch renormalization for trainig and serving consistency."
+    False, description="Option to use batch renormalization for training and serving consistency."
   )
   use_bq_stats: bool = pydantic.Field(
     False, description="Option to load the partitioned json files from BQ as statistics."
@@ -96,7 +96,7 @@ class EdgesDataset(Dataset):
 
     Returns a KeyedJaggedTensor used to look up all embeddings.
 
-    Note: We treat the lhs and rhs as though they're separate lookups: `len(lenghts) == 2 * bsz * len(tables)`.
+    Note: We treat the lhs and rhs as though they're separate lookups: `len(lengths) == 2 * bsz * len(tables)`.
     This differs from the DLRM pattern where we have `len(lengths) = bsz * len(tables)`.
 
     For the example above:
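The note concerns torchrec's `KeyedJaggedTensor`: because lhs and rhs nodes are looked up separately, each table contributes `2 * bsz` length entries. A tiny sketch with made-up values (bsz=2, one table, so `len(lengths) == 4`):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One table ("user"), batch of 2 edges. lhs and rhs rows are concatenated,
# so lengths has 2 * bsz * len(tables) = 4 entries, one id per lookup row.
kjt = KeyedJaggedTensor(
  keys=["user"],
  values=torch.tensor([3, 7, 1, 9]),   # node ids: [lhs_0, lhs_1, rhs_0, rhs_1]
  lengths=torch.tensor([1, 1, 1, 1]),
)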