mirror of https://github.com/twitter/the-algorithm-ml.git
synced 2024-11-16 21:29:24 +01:00

Compare commits
2 commits: 63573c6f22 ... d3e9477fb0

- d3e9477fb0
- cc73f5fcb7

@ -11,8 +11,8 @@ from tml.projects.home.recap.data import preprocessors
from tml.projects.home.recap.config import JobMode
from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn
from tml.projects.home.recap.data.util import (
  keyed_jagged_tensor_from_tensors_dict,
  sparse_or_dense_tf_to_torch,
)
from absl import logging
import torch.distributed as dist

@ -20,458 +20,546 @@ import torch.distributed as dist


@dataclass
class RecapBatch(DataclassBatch):
  """Holds features and labels from the Recap dataset."""

  continuous_features: torch.Tensor
  binary_features: torch.Tensor
  discrete_features: torch.Tensor
  sparse_features: "KeyedJaggedTensor"  # type: ignore[name-defined]  # noqa: F821
  labels: torch.Tensor
  user_embedding: torch.Tensor = None
  user_eng_embedding: torch.Tensor = None
  author_embedding: torch.Tensor = None
  weights: torch.Tensor = None

  def __post_init__(self):
    if self.weights is None:
      self.weights = torch.ones_like(self.labels)
    for feature_name, feature_value in self.as_dict().items():
      if ("embedding" in feature_name) and (feature_value is None):
        setattr(self, feature_name, torch.empty([0, 0]))

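For intuition, a minimal construction sketch; the shapes and the empty `KeyedJaggedTensor` are illustrative assumptions, not values from this diff:

  # Illustrative only: feature shapes below are invented for the example.
  labels = torch.zeros(8, 2)
  batch = RecapBatch(
    continuous_features=torch.randn(8, 100),
    binary_features=torch.zeros(8, 16),
    discrete_features=torch.zeros(8, 4),
    sparse_features=keyed_jagged_tensor_from_tensors_dict({}),
    labels=labels,
  )
  assert torch.equal(batch.weights, torch.ones_like(labels))  # default from __post_init__
  assert batch.user_embedding.numel() == 0  # absent embeddings become torch.empty([0, 0])
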
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
  """Converts a torch data loader output into `RecapBatch`."""

  x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
  try:
    features_in, labels = x
  except ValueError:
    # For Mode.INFERENCE, we do not expect to receive labels as part of the input tuple
    features_in, labels = x, None

  sparse_features = keyed_jagged_tensor_from_tensors_dict({})
  if sparse_feature_names:
    sparse_features = keyed_jagged_tensor_from_tensors_dict(
      {embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
    )

  user_embedding, user_eng_embedding, author_embedding = None, None, None
  if "user_embedding" in features_in:
    if sparse_feature_names and "meta__user_id" in sparse_feature_names:
      raise ValueError("Only one source of embedding for user is supported")
    else:
      user_embedding = features_in["user_embedding"]

  if "user_eng_embedding" in features_in:
    if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
      raise ValueError("Only one source of embedding for user is supported")
    else:
      user_eng_embedding = features_in["user_eng_embedding"]

  if "author_embedding" in features_in:
    if sparse_feature_names and "meta__author_id" in sparse_feature_names:
      raise ValueError("Only one source of embedding for user is supported")
    else:
      author_embedding = features_in["author_embedding"]

  return RecapBatch(
    continuous_features=features_in["continuous"],
    binary_features=features_in["binary"],
    discrete_features=features_in["discrete"],
    sparse_features=sparse_features,
    user_embedding=user_embedding,
    user_eng_embedding=user_eng_embedding,
    author_embedding=author_embedding,
    labels=labels,
    weights=features_in.get("weights", None),  # Defaults to torch.ones_like(labels)
  )

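A hedged usage sketch of the two call shapes `to_batch` accepts (the `features` dict stands for whatever the parse function emits; at inference it already carries zero weights from `_map_output_for_inference` below):

  batch = to_batch((features, labels))  # train/eval: a (features, labels) tuple
  batch = to_batch(features, sparse_feature_names=["meta__user_id"])  # inference: labels become None
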
def _chain(param, f1, f2):
  """
  Reduce multiple functions into one chained function
  _chain(x, f1, f2) -> f2(f1(x))
  """
  output = param
  fns = [f1, f2]
  for f in fns:
    output = f(output)
  return output

def _add_weights(inputs, tasks: Dict[str, TaskData]):
  """Adds weights based on label sampling for positive and negatives.

  This is useful for numeric calibration etc. This mutates inputs.

  Args:
    inputs: A dictionary of strings to tensor-like structures.
    tasks: A dict of string (label) to `TaskData` specifying inputs.

  Returns:
    The mutated `inputs` dictionary, with per-example weights added under the "weights" key.
  """
  weights = []
  for key, task in tasks.items():
    label = inputs[key]
    float_label = tf.cast(label, tf.float32)

    weights.append(
      float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
    )

  # Ensure we are batch-major (assumes we batch before this call).
  inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
  return inputs

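For intuition about the weight expression: with `pos_downsampling_rate=1.0` and `neg_downsampling_rate=0.1`, a positive example gets weight 1.0 / 1.0 = 1.0 while a negative gets 1.0 / 0.1 = 10.0, undoing the 10x negative downsampling in expectation (inverse-propensity weighting).
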
def get_datetimes(explicit_datetime_inputs):
  """Compute list of datetime strings for train/validation data."""
  datetime_format = "%Y/%m/%d/%H"
  end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
  dates = sorted(
    [
      (end - timedelta(hours=i + 1)).strftime(datetime_format)
      for i in range(int(explicit_datetime_inputs.hours))
    ]
  )
  return dates

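As a worked example (hypothetical inputs): with `end_datetime="2023/01/02/00"` and `hours=3`, the 1, 2, and 3 hours before the end are generated, giving `['2023/01/01/21', '2023/01/01/22', '2023/01/01/23']` after sorting.
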
def get_explicit_datetime_inputs_files(explicit_datetime_inputs, increment: str = "hourly"):
  """
  Compile list of files for training/validation.

  Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
  For each hour of data, if the directory is missing or empty, we increment a counter to keep
  track of the number of missing data hours.
  Returns only files with a `.gz` extension.

  Args:
    explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object

  Returns:
    data_files: Sorted list of files to read corresponding to data at the desired datetimes
    num_hours_missing: Number of hours that we are missing data
  """
  datetimes = get_datetimes(explicit_datetime_inputs)
  folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
  data_files = []
  num_hours_missing = 0
  for folder in folders:
    try:
      files = tf.io.gfile.listdir(folder)
      if not files:
        logging.warning(f"{folder} contained no data files")
        num_hours_missing += 1
      data_files.extend(
        [
          os.path.join(folder, filename)
          for filename in files
          if filename.rsplit(".", 1)[-1].lower() == "gz"
        ]
      )
    except tf.errors.NotFoundError as e:
      num_hours_missing += 1
      logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
  return sorted(data_files), num_hours_missing

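Given the `%Y/%m/%d/%H` format above, the expected layout is one directory per hour under `data_root`, e.g. a (hypothetical) shard at `<data_root>/2023/01/01/23/part-00000.gz`; files in those directories without a `.gz` extension are ignored.
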
def _map_output_for_inference(
  inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
  """
  Map the input data for inference.

  Args:
    inputs (dict): Input data dictionary.
    tasks (Dict[str, TaskData]): Dictionary of task data configurations.
    preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
    add_weights (bool, optional): Whether to add weights. Defaults to False.

  Returns:
    dict: Modified input data dictionary for inference.
  """
  if preprocessor:
    raise ValueError("No preprocessor should be used at inference time.")
  if add_weights:
    raise NotImplementedError()

  # Add zero weights.
  inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
  for label in tasks:
    del inputs[label]
  return inputs

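Shape note: if `inputs["continuous"]` is `[batch, num_continuous]`, then `inputs["continuous"][:, 0]` is `[batch]` and the `expand_dims(..., -1)` makes the zero weights `[batch, 1]`, one weight per example.
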
def _map_output_for_train_eval(
  inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
  """
  Map the input data for training and evaluation.

  Args:
    inputs (dict): Input data dictionary.
    tasks (Dict[str, TaskData]): Dictionary of task data configurations.
    preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
    add_weights (bool, optional): Whether to add weights. Defaults to False.

  Returns:
    Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation.
  """
  if add_weights:
    inputs = _add_weights_based_on_sampling_rates(inputs, tasks)

  # Warning: this has to happen first as it changes the input
  if preprocessor:
    inputs = preprocessor(inputs)

  label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])

  for label in tasks:
    del inputs[label]

  return inputs, label_values

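Shape note: with T tasks whose labels are each `[batch, 1]`, `tf.stack(..., axis=1)` yields `[batch, T, 1]`, and the squeeze on the last axis leaves `label_values` as `[batch, T]`.
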
def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
  """Adds weights based on label sampling for positive and negatives.

  This is useful for numeric calibration etc. This mutates inputs.

  Args:
    inputs: A dictionary of strings to tensor-like structures.
    tasks: A dict of string (label) to `TaskData` specifying inputs.

  Returns:
    The mutated `inputs` dictionary, with per-example weights added under the "weights" key.
  """
  weights = []
  for key, task in tasks.items():
    label = inputs[key]
    float_label = tf.cast(label, tf.float32)

    weights.append(
      float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
    )

  # Ensure we are batch-major (assumes we batch before this call).
  inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
  return inputs

class RecapDataset(torch.utils.data.IterableDataset):
  def __init__(
    self,
    data_config: RecapDataConfig,
    dataset_service: Optional[str] = None,
    mode: JobMode = JobMode.TRAIN,
    compression: Optional[str] = "AUTO",
    repeat: bool = False,
    vocab_mapper: tf.keras.Model = None,
  ):
    """
    Create a RecapDataset for training or inference.

    Args:
      data_config (RecapDataConfig): Data configuration.
      dataset_service (str, optional): Dataset service identifier. Defaults to None.
      mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN.
      compression (str, optional): Compression type. Defaults to "AUTO".
      repeat (bool, optional): Whether to repeat the dataset. Defaults to False.
      vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None.
    """
    logging.info("***** Labels *****")
    logging.info(list(data_config.tasks.keys()))

    self._data_config = data_config
    self._parse_fn = get_seg_dense_parse_fn(data_config)
    self._mode = mode
    self._repeat = repeat
    self._num_concurrent_iterators = 1
    self._vocab_mapper = vocab_mapper
    self.dataset_service = dataset_service

    preprocessor = None
    self._batch_size_multiplier = 1
    if data_config.preprocess:
      preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
      if data_config.preprocess.downsample_negatives:
        self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier

    self._preprocessor = preprocessor

    if mode == JobMode.INFERENCE:
      if preprocessor is not None:
        raise ValueError("Expect no preprocessor at inference time.")
      should_add_weights = False
      output_map_fn = _map_output_for_inference  # (features,)
    else:
      # Only add weights if there is a reason to! If all weights will
      # be equal to 1.0, save bandwidth between DDS and Chief by simply
      # relying on the fact that weights default to 1.0 in `RecapBatch`.
      # WARNING: Weights may still be added as a side effect of a preprocessor
      # such as `DownsampleNegatives`.
      should_add_weights = any(
        [
          task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0
          for task_cfg in data_config.tasks.values()
        ]
      )
      output_map_fn = _map_output_for_train_eval  # (features, labels)

    self._output_map_fn = functools.partial(
      output_map_fn,
      tasks=data_config.tasks,
      preprocessor=preprocessor,
      add_weights=should_add_weights,
    )

    sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None

    self._tf_dataset = self._create_tf_dataset()
    self._init_tensor_spec()

  def _init_tensor_spec(self):
    """
    Initialize the tensor specification for the dataset.
    """
    def _tensor_spec_to_torch_shape(spec):
      if spec.shape is None:
        return None
      shape = [x if x is not None else -1 for x in spec.shape]
      return torch.Size(shape)

    self.torch_element_spec = tf.nest.map_structure(
      _tensor_spec_to_torch_shape, self._tf_dataset.element_spec
    )

  def _create_tf_dataset(self):
    """
    Create a TensorFlow dataset from the data files.

    Returns:
      tf.data.Dataset: TensorFlow dataset.
    """
    if hasattr(self, "_tf_dataset"):
      raise ValueError("Do not call `_create_tf_dataset` more than once.")

    world_size = dist.get_world_size() if dist.is_initialized() else 1
    per_replica_bsz = (
      self._batch_size_multiplier * self._data_config.global_batch_size // world_size
    )

    dataset: tf.data.Dataset = self._create_base_tf_dataset(
      batch_size=per_replica_bsz,
    )

    if self._repeat:
      logging.info("Repeating dataset")
      dataset = dataset.repeat()

    if self.dataset_service:
      if self._num_concurrent_iterators > 1:
        if not self.machines_config:
          raise ValueError(
            "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
          )
        dataset = dataset_lib.with_auto_tune_budget(
          dataset,
          machine_config=self.machines_config.chief,
          num_concurrent_iterators=self._num_concurrent_iterators,
          on_chief=False,
        )

      self.dataset_id, self.job_name = register_dataset(
        dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
      )
      dataset = distribute_from_dataset_id(
        dataset_id=self.dataset_id,  # type: ignore[arg-type]
        job_name=self.job_name,
        dataset_service=self.dataset_service,
        compression=self.compression,
      )

    elif self._num_concurrent_iterators > 1:
      if not self.machines_config:
        raise ValueError(
          "Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
        )
      dataset = dataset_lib.with_auto_tune_budget(
        dataset,
        machine_config=self.machines_config.chief,
        num_concurrent_iterators=self._num_concurrent_iterators,
        on_chief=True,
      )

    # Vocabulary mapping happens on the training node, not in dds because of size.
    if self._vocab_mapper:
      dataset = dataset.map(self._vocab_mapper)

    return dataset.prefetch(world_size * 2)

  def _create_base_tf_dataset(self, batch_size: int):
    """
    Create the base TensorFlow dataset.

    Args:
      batch_size (int): Batch size.

    Returns:
      tf.data.Dataset: Base TensorFlow dataset.
    """
    if self._data_config.inputs:
      glob = self._data_config.inputs
      filenames = sorted(tf.io.gfile.glob(glob))
    elif self._data_config.explicit_datetime_inputs:
      num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
      filenames, num_hours_missing = get_explicit_datetime_inputs_files(
        self._data_config.explicit_datetime_inputs,
        increment="hourly",
      )
      if num_hours_missing > num_missing_hours_tol:
        raise ValueError(
          f"We are missing {num_hours_missing} hours of data, "
          f"more than tolerance {num_missing_hours_tol}."
        )
    elif self._data_config.explicit_date_inputs:
      num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
      filenames, num_days_missing = get_explicit_datetime_inputs_files(
        self._data_config.explicit_date_inputs,
        increment="daily",
      )
      if num_days_missing > num_missing_days_tol:
        raise ValueError(
          f"We are missing {num_days_missing} days of data, "
          f"more than tolerance {num_missing_days_tol}."
        )
    else:
      raise ValueError(
        "Must specify either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
      )

    num_files = len(filenames)
    logging.info(f"Found {num_files} data files")
    if num_files < 1:
      raise ValueError("No data files found")

    if self._data_config.num_files_to_keep is not None:
      filenames = filenames[: self._data_config.num_files_to_keep]
      logging.info(f"Retaining only {len(filenames)} files.")

    filenames_ds = (
      tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
      # Because of drop_remainder, if our dataset does not fill
      # up a batch, it will emit nothing without this repeat.
      .repeat(-1)
    )

    if self._data_config.file_batch_size:
      filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)

    def per_shard_dataset(filename):
      """
      Create a TensorFlow dataset for a single shard file.

      Args:
        filename (str): Path to the shard file.

      Returns:
        tf.data.Dataset: TensorFlow dataset for the shard file.
      """
      ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
      return ds.prefetch(4)

    ds = filenames_ds.interleave(
      per_shard_dataset,
      block_length=4,
      deterministic=False,
      num_parallel_calls=self._data_config.interleave_num_parallel_calls
      or tf.data.experimental.AUTOTUNE,
    )

    # Combine functions into one map call to reduce overhead.
    map_fn = functools.partial(
      _chain,
      f1=self._parse_fn,
      f2=self._output_map_fn,
    )

    # Shuffle -> Batch -> Parse is the correct ordering.
    # Shuffling needs to be performed before batching, otherwise there is not much point.
    # Batching happens before parsing because tf.Example parsing is actually vectorized
    # and works much faster overall on batches of data.
    ds = (
      # DANGER DANGER: there is a default shuffle size here.
      ds.shuffle(self._data_config.examples_shuffle_buffer_size)
      .batch(batch_size=batch_size, drop_remainder=True)
      .map(
        map_fn,
        num_parallel_calls=self._data_config.map_num_parallel_calls
        or tf.data.experimental.AUTOTUNE,
      )
    )

    if self._data_config.cache:
      ds = ds.cache()

    if self._data_config.ignore_data_errors:
      ds = ds.apply(tf.data.experimental.ignore_errors())

    options = tf.data.Options()
    options.experimental_deterministic = False
    ds = ds.with_options(options)

    return ds

  def _gen(self):
    for x in self._tf_dataset:
      yield to_batch(x)

  def to_dataloader(self) -> Dict[str, torch.Tensor]:
    """
    Convert the RecapDataset to a PyTorch DataLoader.

    Returns:
      torch.utils.data.DataLoader: PyTorch DataLoader for the dataset.
    """
    return torch.utils.data.DataLoader(self, batch_size=None)

  def __iter__(self):
    return iter(self._gen())

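A hedged end-to-end sketch of intended usage (assumes a valid `RecapDataConfig` named `data_config`; `train_step` is hypothetical):

  dataset = RecapDataset(data_config=data_config, mode=JobMode.TRAIN, repeat=True)
  for batch in dataset.to_dataloader():  # yields RecapBatch objects; batching was done in tf.data
    loss = train_step(batch)  # hypothetical training step
    break
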
@ -17,6 +17,16 @@ FLAGS = flags.FLAGS
def _generate_random_example(
  tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
  """
  Generate a random example based on the provided TensorFlow example schema.

  Args:
    tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example.

  Returns:
    Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema.
  """
  example = {}
  for feature_name, feature_spec in tf_example_schema.items():
    dtype = feature_spec.dtype

@ -33,14 +43,43 @@ def _generate_random_example(

def _float_feature(value):
  """
  Create a TensorFlow float feature.

  Args:
    value: A float or list of floats.

  Returns:
    tf.train.Feature: A TensorFlow feature containing the float value(s).
  """
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def _int64_feature(value):
  """
  Create a TensorFlow int64 feature.

  Args:
    value: An integer or list of integers.

  Returns:
    tf.train.Feature: A TensorFlow feature containing the int64 value(s).
  """
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
  """
  Serialize a dictionary of TensorFlow tensors into a binary string.

  Args:
    x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors.

  Returns:
    bytes: The serialized binary string.
  """
  feature = {}
  serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
  for feature_name, tensor in x.items():

@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:

def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
  """
  Generate random data based on the provided configuration and save it as a TFRecord file.

  Args:
    data_path (str): The path where the TFRecord file will be saved.
    config (recap_config_mod.RecapConfig): The configuration for generating the random data.
  """
  with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
    seg_dense_schema = json.load(f)["schema"]

@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig):

def _generate_data_main(unused_argv):
  """
  Main function to generate random data according to the provided configuration.

  Args:
    unused_argv: Unused command-line arguments.
  """
  config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)

  # Find the path where to put the data

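A hedged sketch of how the feature helpers above compose into a serialized record (values are illustrative; this mirrors what `_serialize_example` produces):

  feat = _float_feature([0.5, 1.5])  # tf.train.Feature wrapping a FloatList
  example = tf.train.Example(features=tf.train.Features(feature={"continuous": feat}))
  payload = example.SerializeToString()  # bytes, like _serialize_example's return value
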