sincekmori 2023-04-01 04:44:00 +09:00
parent 78c3235eee
commit b0dbeabc64
8 changed files with 11 additions and 11 deletions

View File

@@ -101,7 +101,7 @@ class Snapshot:
     weight_tensor,
   ) -> None:
     """Loads pretrained embedding from the snapshot to the model.
-    Utilise partial lodaing meachanism from torchsnapshot.
+    Utilise partial loading mechanism from torchsnapshot.
     Args:
       embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
       snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
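For context on the docstring above: torchsnapshot's partial loading lets a single tensor be restored from a snapshot without materializing the whole checkpoint. A minimal sketch, assuming a hypothetical snapshot location and manifest path (the real names depend on how the snapshot was written):

import torchsnapshot

# Both the snapshot location and the manifest path below are illustrative only.
snapshot = torchsnapshot.Snapshot(path="/tmp/embedding_snapshot")

# Partial loading: read_object() restores just this tensor from the snapshot
# instead of rebuilding the entire application state.
weight = snapshot.read_object(path="0/model/ebc.embedding_bags.table_0.weight")
print(weight.shape)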

View File

@@ -11,7 +11,7 @@ def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
   """Recommended method to load a config file (a yaml file) and parse it.
   Because we have a shared filesystem, the recommended route to running jobs is to put modified config
-  files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
+  files with the desired parameters somewhere on the filesystem and run jobs pointing to them.
   """
   def _substitute(s):
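As a usage sketch for load_config_from_yaml: the YAML path below is made up, and RecapConfig is assumed to be the config type exported by the recap config module that appears later in this commit.

from tml.core import config as tml_config_mod  # module alias as used elsewhere in this commit
import tml.projects.home.recap.config as recap_config_mod

# Hypothetical config file placed on the shared filesystem.
config = tml_config_mod.load_config_from_yaml(
  recap_config_mod.RecapConfig, "/mnt/shared/experiments/my_run.yaml"
)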

View File

@@ -115,7 +115,7 @@ def train(
     dataset: data iterator for the training set
     evaluation_iterators: data iterators for the different evaluation sets
     scheduler: optional learning rate scheduler
-    output_transform_for_metrics: optional transformation functions to transorm the model
+    output_transform_for_metrics: optional transformation functions to transform the model
       output and labels into a format the metrics can understand
   """

View File

@ -50,7 +50,7 @@ class DatasetConfig(base_config.BaseConfig):
None, description="Number of shards to keep." None, description="Number of shards to keep."
) )
repeat_files: bool = pydantic.Field( repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to." True, description="DEPRECATED. Files are repeated no matter what this is set to."
) )
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size") file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
@ -211,7 +211,7 @@ class Sampler(base_config.BaseConfig):
Only use this for quick experimentation. Only use this for quick experimentation.
If samplers are useful, we should sample from upstream data generation. If samplers are useful, we should sample from upstream data generation.
DEPRICATED, DO NOT USE. DEPRECATED, DO NOT USE.
""" """
name: str name: str
@ -234,7 +234,7 @@ class RecapDataConfig(DatasetConfig):
sampler: Sampler = pydantic.Field( sampler: Sampler = pydantic.Field(
None, None,
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""", description="""DEPRECATED, DO NOT USE. Sampling function for offline experiments.""",
) )
@pydantic.root_validator() @pydantic.root_validator()
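For readers unfamiliar with @pydantic.root_validator(): it runs after the per-field validators and receives all fields at once, so it can enforce cross-field rules. A minimal sketch with invented fields (not the real RecapDataConfig schema), echoing the input-exclusivity error raised in the next file:

from typing import Optional
import pydantic

class ExampleDataConfig(pydantic.BaseModel):
  # Illustrative fields only.
  inputs: Optional[str] = None
  explicit_datetime_inputs: Optional[str] = None

  @pydantic.root_validator()
  def _one_input_source(cls, values):
    # Cross-field check: exactly the kind of rule a root validator exists for.
    if not (values.get("inputs") or values.get("explicit_datetime_inputs")):
      raise ValueError("Must specify either `inputs` or `explicit_datetime_inputs`.")
    return values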

View File

@@ -398,7 +398,7 @@ class RecapDataset(torch.utils.data.IterableDataset):
       )
     else:
       raise ValueError(
-        "Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
+        "Must specify either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
       )
     num_files = len(filenames)

View File

@@ -9,7 +9,7 @@ from tml.core import config as tml_config_mod
 import tml.projects.home.recap.config as recap_config_mod
 flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
-flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.")
+flags.DEFINE_integer("n_examples", 100, "Number of examples to generate.")
 FLAGS = flags.FLAGS
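These are standard absl flags. A self-contained sketch of how such a script defines and consumes them (the main body is illustrative):

from absl import app, flags

flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_integer("n_examples", 100, "Number of examples to generate.")
FLAGS = flags.FLAGS

def main(argv):
  # Flag values are populated once app.run() has parsed the command line.
  print(f"config_path={FLAGS.config_path}, n_examples={FLAGS.n_examples}")

if __name__ == "__main__":
  app.run(main)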

View File

@@ -18,7 +18,7 @@ class DropoutConfig(base_config.BaseConfig):
 class LayerNormConfig(base_config.BaseConfig):
-  """Configruation for the layer normalization."""
+  """Configuration for the layer normalization."""
   epsilon: float = pydantic.Field(
     1e-3, description="Small float added to variance to avoid dividing by zero."
@@ -91,7 +91,7 @@ class ZScoreLogConfig(base_config.BaseConfig):
     False, description="Option to use batch normalization on the inputs."
   )
   use_renorm: bool = pydantic.Field(
-    False, description="Option to use batch renormalization for trainig and serving consistency."
+    False, description="Option to use batch renormalization for training and serving consistency."
   )
   use_bq_stats: bool = pydantic.Field(
     False, description="Option to load the partitioned json files from BQ as statistics."

View File

@@ -96,7 +96,7 @@ class EdgesDataset(Dataset):
   Returns a KeyedJaggedTensor used to look up all embeddings.
-  Note: We treat the lhs and rhs as though they're separate lookups: `len(lenghts) == 2 * bsz * len(tables)`.
+  Note: We treat the lhs and rhs as though they're separate lookups: `len(lengths) == 2 * bsz * len(tables)`.
   This differs from the DLRM pattern where we have `len(lengths) = bsz * len(tables)`.
   For the example above:
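To make `len(lengths) == 2 * bsz * len(tables)` concrete, here is a toy construction with bsz = 2 and a single table, so the lhs and rhs lookups together contribute four length entries (the ids and table name are invented):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# bsz = 2, len(tables) = 1; lhs and rhs are separate lookups,
# so len(lengths) == 2 * 2 * 1 == 4.
kjt = KeyedJaggedTensor.from_lengths_sync(
  keys=["table_0"],
  values=torch.tensor([10, 20, 30, 40]),  # one embedding id per lookup
  lengths=torch.tensor([1, 1, 1, 1]),     # four single-id lookups in total
)
assert kjt.lengths().numel() == 2 * 2 * 1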