This commit is contained in:
rajveer43 2023-09-21 18:18:15 +05:30
parent 0813989fd9
commit cc73f5fcb7
2 changed files with 515 additions and 373 deletions

View File

@ -43,7 +43,8 @@ class RecapBatch(DataclassBatch):
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
"""Converts a torch data loader output into `RecapBatch`."""
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
x = tf.nest.map_structure(functools.partial(
sparse_or_dense_tf_to_torch, pin_memory=False), x)
try:
features_in, labels = x
except ValueError:
@ -53,25 +54,29 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
if sparse_feature_names:
sparse_features = keyed_jagged_tensor_from_tensors_dict(
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
{embedding_name: features_in[embedding_name]
for embedding_name in sparse_feature_names}
)
user_embedding, user_eng_embedding, author_embedding = None, None, None
if "user_embedding" in features_in:
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
raise ValueError(
"Only one source of embedding for user is supported")
else:
user_embedding = features_in["user_embedding"]
if "user_eng_embedding" in features_in:
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
raise ValueError(
"Only one source of embedding for user is supported")
else:
user_eng_embedding = features_in["user_eng_embedding"]
if "author_embedding" in features_in:
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
raise ValueError(
"Only one source of embedding for user is supported")
else:
author_embedding = features_in["author_embedding"]
@ -84,7 +89,8 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
user_eng_embedding=user_eng_embedding,
author_embedding=author_embedding,
labels=labels,
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
# Defaults to torch.ones_like(labels)
weights=features_in.get("weights", None),
)
@ -119,18 +125,21 @@ def _add_weights(inputs, tasks: Dict[str, TaskData]):
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
float_label / task.pos_downsampling_rate +
(1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
inputs["weights"] = tf.squeeze(tf.transpose(
tf.convert_to_tensor(weights)), axis=0)
return inputs
def get_datetimes(explicit_datetime_inputs):
"""Compute list datetime strings for train/validation data."""
datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
end = datetime.strptime(
explicit_datetime_inputs.end_datetime, datetime_format)
dates = sorted(
[
(end - timedelta(hours=i + 1)).strftime(datetime_format)
@ -158,7 +167,8 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
"""
datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime)
for datetime in datetimes]
data_files = []
num_hours_missing = 0
for folder in folders:
@ -176,20 +186,34 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
)
except tf.errors.NotFoundError as e:
num_hours_missing += 1
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
logging.warning(
f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
return sorted(data_files), num_hours_missing
def _map_output_for_inference(
    inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
    """Prepare a parsed batch for inference: zero weights, labels removed.

    The `inputs` dict is modified in place and also returned.

    Args:
        inputs (dict): Feature dictionary. Must contain a "continuous" entry
            indexable as `[:, 0]` and one entry per key of `tasks`.
        tasks (Dict[str, TaskData]): Task configurations; only the keys
            (label names) are used here.
        preprocessor (tf.keras.Model, optional): Must be falsy; preprocessing
            is rejected at inference time. Defaults to None.
        add_weights (bool, optional): Must be False; weighting is not
            implemented for inference. Defaults to False.

    Returns:
        dict: The same `inputs` dict, with a "weights" entry added and every
        label key from `tasks` deleted.

    Raises:
        ValueError: If a preprocessor is supplied.
        NotImplementedError: If `add_weights` is True.
        KeyError: If a label named in `tasks` is absent from `inputs`.
    """
    if preprocessor:
        raise ValueError("No preprocessor should be used at inference time.")
    if add_weights:
        raise NotImplementedError(
            "add_weights is not supported at inference time.")

    # Add zero weights: one zero per row of "continuous", with a trailing
    # axis of size 1 (first column selected, then expanded on the last axis).
    inputs["weights"] = tf.zeros_like(
        tf.expand_dims(inputs["continuous"][:, 0], -1))

    # Labels are not part of the inference payload; drop them from the dict.
    for label in tasks:
        del inputs[label]
    return inputs
@ -198,6 +222,18 @@ def _map_output_for_inference(
def _map_output_for_train_eval(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
"""
Map the input data for training and evaluation.
Args:
inputs (dict): Input data dictionary.
tasks (Dict[str, TaskData]): Dictionary of task data configurations.
preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
add_weights (bool, optional): Whether to add weights. Defaults to False.
Returns:
Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation.
"""
if add_weights:
inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
@ -205,7 +241,8 @@ def _map_output_for_train_eval(
if preprocessor:
inputs = preprocessor(inputs)
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
label_values = tf.squeeze(
tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
for label in tasks:
del inputs[label]
@ -231,11 +268,13 @@ def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
float_label / task.pos_downsampling_rate +
(1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
inputs["weights"] = tf.squeeze(tf.transpose(
tf.convert_to_tensor(weights)), axis=0)
return inputs
@ -249,6 +288,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
repeat: bool = False,
vocab_mapper: tf.keras.Model = None,
):
"""
Create a RecapDataset for training or inference.
Args:
data_config (RecapDataConfig): Data configuration.
dataset_service (str, optional): Dataset service identifier. Defaults to None.
mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN.
compression (str, optional): Compression type. Defaults to "AUTO".
repeat (bool, optional): Whether to repeat the dataset. Defaults to False.
vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None.
"""
logging.info("***** Labels *****")
logging.info(list(data_config.tasks.keys()))
@ -263,7 +313,8 @@ class RecapDataset(torch.utils.data.IterableDataset):
preprocessor = None
self._batch_size_multiplier = 1
if data_config.preprocess:
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
preprocessor = preprocessors.build_preprocess(
data_config.preprocess, mode=mode)
if data_config.preprocess.downsample_negatives:
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
@ -295,13 +346,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
add_weights=should_add_weights,
)
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
sparse_feature_names = list(
vocab_mapper.vocabs.keys()) if vocab_mapper else None
self._tf_dataset = self._create_tf_dataset()
self._init_tensor_spec()
def _init_tensor_spec(self):
"""
Initialize the tensor specification for the dataset.
"""
def _tensor_spec_to_torch_shape(spec):
if spec.shape is None:
return None
@ -313,8 +368,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
)
def _create_tf_dataset(self):
"""
Create a TensorFlow dataset from the data files.
Returns:
tf.data.Dataset: TensorFlow dataset.
"""
if hasattr(self, "_tf_dataset"):
raise ValueError("Do not call `_create_tf_dataset` more than once.")
raise ValueError(
"Do not call `_create_tf_dataset` more than once.")
world_size = dist.get_world_size() if dist.is_initialized() else 1
per_replica_bsz = (
@ -371,6 +433,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
return dataset.prefetch(world_size * 2)
def _create_base_tf_dataset(self, batch_size: int):
"""
Create the base TensorFlow dataset.
Args:
batch_size (int): Batch size.
Returns:
tf.data.Dataset: Base TensorFlow dataset.
"""
if self._data_config.inputs:
glob = self._data_config.inputs
filenames = sorted(tf.io.gfile.glob(glob))
@ -411,16 +482,27 @@ class RecapDataset(torch.utils.data.IterableDataset):
logging.info(f"Retaining only {len(filenames)} files.")
filenames_ds = (
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
tf.data.Dataset.from_tensor_slices(
filenames).shuffle(len(filenames))
# Because of drop_remainder, if our dataset does not fill
# up a batch, it will emit nothing without this repeat.
.repeat(-1)
)
if self._data_config.file_batch_size:
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
filenames_ds = filenames_ds.batch(
self._data_config.file_batch_size)
def per_shard_dataset(filename):
    """Build the dataset that reads a single shard.

    The shard is opened as a GZIP-compressed TFRecord dataset and a prefetch
    of 4 elements is attached before it is returned.

    Args:
        filename: Shard to read; it is wrapped in a one-element list before
            being handed to `tf.data.TFRecordDataset`.
            NOTE(review): presumably a scalar string path, but when file
            batching is enabled upstream this may be a batch of paths -
            confirm against the caller.

    Returns:
        tf.data.Dataset: The prefetching TFRecord dataset for this shard.
    """
    shard_records = tf.data.TFRecordDataset(
        [filename],
        compression_type="GZIP",
    )
    prefetched = shard_records.prefetch(4)
    return prefetched
@ -471,6 +553,12 @@ class RecapDataset(torch.utils.data.IterableDataset):
yield to_batch(x)
def to_dataloader(self) -> torch.utils.data.DataLoader:
    """Wrap this dataset in a PyTorch DataLoader.

    `batch_size=None` is passed so the loader hands items through one at a
    time instead of grouping them itself.
    NOTE(review): presumably each item yielded by this dataset is already a
    full batch - confirm against `__iter__`.

    Returns:
        torch.utils.data.DataLoader: Loader iterating over `self`.
    """
    return torch.utils.data.DataLoader(self, batch_size=None)
def __iter__(self):

View File

@ -17,6 +17,16 @@ FLAGS = flags.FLAGS
def _generate_random_example(
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
"""
Generate a random example based on the provided TensorFlow example schema.
Args:
tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example.
Returns:
Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema.
"""
example = {}
for feature_name, feature_spec in tf_example_schema.items():
dtype = feature_spec.dtype
@ -33,14 +43,43 @@ def _generate_random_example(
def _float_feature(value):
    """Wrap float data in a `tf.train.Feature`.

    Args:
        value: Data forwarded unchanged as `tf.train.FloatList(value=value)`.
            NOTE(review): presumably this must be an iterable of floats
            rather than a bare scalar - confirm against the callers.

    Returns:
        tf.train.Feature: Feature whose `float_list` holds `value`.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap integer data in a `tf.train.Feature`.

    Args:
        value: Data forwarded unchanged as `tf.train.Int64List(value=value)`.
            NOTE(review): presumably this must be an iterable of integers
            rather than a bare scalar - confirm against the callers.

    Returns:
        tf.train.Feature: Feature whose `int64_list` holds `value`.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
"""
Serialize a dictionary of TensorFlow tensors into a binary string.
Args:
x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors.
Returns:
bytes: The serialized binary string.
"""
feature = {}
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
for feature_name, tensor in x.items():
@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
"""
Generate random data based on the provided configuration and save it as a TFRecord file.
Args:
data_path (str): The path where the TFRecord file will be saved.
config (recap_config_mod.RecapConfig): The configuration for generating the random data.
"""
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
def _generate_data_main(unused_argv):
"""
Main function to generate random data according to the provided configuration.
Args:
unused_argv: Unused command-line arguments.
"""
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
# Find the path where to put the data