mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2025-01-24 12:51:10 +01:00
push
This commit is contained in:
parent
0813989fd9
commit
cc73f5fcb7
@ -43,7 +43,8 @@ class RecapBatch(DataclassBatch):
|
||||
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
|
||||
"""Converts a torch data loader output into `RecapBatch`."""
|
||||
|
||||
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
|
||||
x = tf.nest.map_structure(functools.partial(
|
||||
sparse_or_dense_tf_to_torch, pin_memory=False), x)
|
||||
try:
|
||||
features_in, labels = x
|
||||
except ValueError:
|
||||
@ -53,25 +54,29 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
|
||||
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
|
||||
if sparse_feature_names:
|
||||
sparse_features = keyed_jagged_tensor_from_tensors_dict(
|
||||
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
|
||||
{embedding_name: features_in[embedding_name]
|
||||
for embedding_name in sparse_feature_names}
|
||||
)
|
||||
|
||||
user_embedding, user_eng_embedding, author_embedding = None, None, None
|
||||
if "user_embedding" in features_in:
|
||||
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
|
||||
raise ValueError("Only one source of embedding for user is supported")
|
||||
raise ValueError(
|
||||
"Only one source of embedding for user is supported")
|
||||
else:
|
||||
user_embedding = features_in["user_embedding"]
|
||||
|
||||
if "user_eng_embedding" in features_in:
|
||||
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
|
||||
raise ValueError("Only one source of embedding for user is supported")
|
||||
raise ValueError(
|
||||
"Only one source of embedding for user is supported")
|
||||
else:
|
||||
user_eng_embedding = features_in["user_eng_embedding"]
|
||||
|
||||
if "author_embedding" in features_in:
|
||||
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
|
||||
raise ValueError("Only one source of embedding for user is supported")
|
||||
raise ValueError(
|
||||
"Only one source of embedding for user is supported")
|
||||
else:
|
||||
author_embedding = features_in["author_embedding"]
|
||||
|
||||
@ -84,7 +89,8 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
|
||||
user_eng_embedding=user_eng_embedding,
|
||||
author_embedding=author_embedding,
|
||||
labels=labels,
|
||||
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
|
||||
# Defaults to torch.ones_like(labels)
|
||||
weights=features_in.get("weights", None),
|
||||
)
|
||||
|
||||
|
||||
@ -119,18 +125,21 @@ def _add_weights(inputs, tasks: Dict[str, TaskData]):
|
||||
float_label = tf.cast(label, tf.float32)
|
||||
|
||||
weights.append(
|
||||
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
|
||||
float_label / task.pos_downsampling_rate +
|
||||
(1.0 - float_label) / task.neg_downsampling_rate
|
||||
)
|
||||
|
||||
# Ensure we are batch-major (assumes we batch before this call).
|
||||
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
|
||||
inputs["weights"] = tf.squeeze(tf.transpose(
|
||||
tf.convert_to_tensor(weights)), axis=0)
|
||||
return inputs
|
||||
|
||||
|
||||
def get_datetimes(explicit_datetime_inputs):
|
||||
"""Compute list datetime strings for train/validation data."""
|
||||
datetime_format = "%Y/%m/%d/%H"
|
||||
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
|
||||
end = datetime.strptime(
|
||||
explicit_datetime_inputs.end_datetime, datetime_format)
|
||||
dates = sorted(
|
||||
[
|
||||
(end - timedelta(hours=i + 1)).strftime(datetime_format)
|
||||
@ -158,7 +167,8 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
|
||||
|
||||
"""
|
||||
datetimes = get_datetimes(explicit_datetime_inputs)
|
||||
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
|
||||
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime)
|
||||
for datetime in datetimes]
|
||||
data_files = []
|
||||
num_hours_missing = 0
|
||||
for folder in folders:
|
||||
@ -176,20 +186,34 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
|
||||
)
|
||||
except tf.errors.NotFoundError as e:
|
||||
num_hours_missing += 1
|
||||
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
|
||||
logging.warning(
|
||||
f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
|
||||
return sorted(data_files), num_hours_missing
|
||||
|
||||
|
||||
def _map_output_for_inference(
|
||||
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
|
||||
):
|
||||
"""
|
||||
Map the input data for inference.
|
||||
|
||||
Args:
|
||||
inputs (dict): Input data dictionary.
|
||||
tasks (Dict[str, TaskData]): Dictionary of task data configurations.
|
||||
preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
|
||||
add_weights (bool, optional): Whether to add weights. Defaults to False.
|
||||
|
||||
Returns:
|
||||
dict: Modified input data dictionary for inference.
|
||||
"""
|
||||
if preprocessor:
|
||||
raise ValueError("No preprocessor should be used at inference time.")
|
||||
if add_weights:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Add zero weights.
|
||||
inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
|
||||
inputs["weights"] = tf.zeros_like(
|
||||
tf.expand_dims(inputs["continuous"][:, 0], -1))
|
||||
for label in tasks:
|
||||
del inputs[label]
|
||||
return inputs
|
||||
@ -198,6 +222,18 @@ def _map_output_for_inference(
|
||||
def _map_output_for_train_eval(
|
||||
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
|
||||
):
|
||||
"""
|
||||
Map the input data for training and evaluation.
|
||||
|
||||
Args:
|
||||
inputs (dict): Input data dictionary.
|
||||
tasks (Dict[str, TaskData]): Dictionary of task data configurations.
|
||||
preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
|
||||
add_weights (bool, optional): Whether to add weights. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation.
|
||||
"""
|
||||
if add_weights:
|
||||
inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
|
||||
|
||||
@ -205,7 +241,8 @@ def _map_output_for_train_eval(
|
||||
if preprocessor:
|
||||
inputs = preprocessor(inputs)
|
||||
|
||||
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
|
||||
label_values = tf.squeeze(
|
||||
tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
|
||||
|
||||
for label in tasks:
|
||||
del inputs[label]
|
||||
@ -231,11 +268,13 @@ def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
|
||||
float_label = tf.cast(label, tf.float32)
|
||||
|
||||
weights.append(
|
||||
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
|
||||
float_label / task.pos_downsampling_rate +
|
||||
(1.0 - float_label) / task.neg_downsampling_rate
|
||||
)
|
||||
|
||||
# Ensure we are batch-major (assumes we batch before this call).
|
||||
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
|
||||
inputs["weights"] = tf.squeeze(tf.transpose(
|
||||
tf.convert_to_tensor(weights)), axis=0)
|
||||
return inputs
|
||||
|
||||
|
||||
@ -249,6 +288,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
repeat: bool = False,
|
||||
vocab_mapper: tf.keras.Model = None,
|
||||
):
|
||||
"""
|
||||
Create a RecapDataset for training or inference.
|
||||
|
||||
Args:
|
||||
data_config (RecapDataConfig): Data configuration.
|
||||
dataset_service (str, optional): Dataset service identifier. Defaults to None.
|
||||
mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN.
|
||||
compression (str, optional): Compression type. Defaults to "AUTO".
|
||||
repeat (bool, optional): Whether to repeat the dataset. Defaults to False.
|
||||
vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None.
|
||||
"""
|
||||
logging.info("***** Labels *****")
|
||||
logging.info(list(data_config.tasks.keys()))
|
||||
|
||||
@ -263,7 +313,8 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
preprocessor = None
|
||||
self._batch_size_multiplier = 1
|
||||
if data_config.preprocess:
|
||||
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
|
||||
preprocessor = preprocessors.build_preprocess(
|
||||
data_config.preprocess, mode=mode)
|
||||
if data_config.preprocess.downsample_negatives:
|
||||
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
|
||||
|
||||
@ -295,13 +346,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
add_weights=should_add_weights,
|
||||
)
|
||||
|
||||
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
|
||||
sparse_feature_names = list(
|
||||
vocab_mapper.vocabs.keys()) if vocab_mapper else None
|
||||
|
||||
self._tf_dataset = self._create_tf_dataset()
|
||||
|
||||
self._init_tensor_spec()
|
||||
|
||||
def _init_tensor_spec(self):
|
||||
"""
|
||||
Initialize the tensor specification for the dataset.
|
||||
"""
|
||||
def _tensor_spec_to_torch_shape(spec):
|
||||
if spec.shape is None:
|
||||
return None
|
||||
@ -313,8 +368,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
)
|
||||
|
||||
def _create_tf_dataset(self):
|
||||
"""
|
||||
Create a TensorFlow dataset from the data files.
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset: TensorFlow dataset.
|
||||
"""
|
||||
if hasattr(self, "_tf_dataset"):
|
||||
raise ValueError("Do not call `_create_tf_dataset` more than once.")
|
||||
raise ValueError(
|
||||
"Do not call `_create_tf_dataset` more than once.")
|
||||
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
per_replica_bsz = (
|
||||
@ -371,6 +433,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
return dataset.prefetch(world_size * 2)
|
||||
|
||||
def _create_base_tf_dataset(self, batch_size: int):
|
||||
"""
|
||||
Create the base TensorFlow dataset.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size.
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset: Base TensorFlow dataset.
|
||||
"""
|
||||
if self._data_config.inputs:
|
||||
glob = self._data_config.inputs
|
||||
filenames = sorted(tf.io.gfile.glob(glob))
|
||||
@ -411,16 +482,27 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
logging.info(f"Retaining only {len(filenames)} files.")
|
||||
|
||||
filenames_ds = (
|
||||
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
|
||||
tf.data.Dataset.from_tensor_slices(
|
||||
filenames).shuffle(len(filenames))
|
||||
# Because of drop_remainder, if our dataset does not fill
|
||||
# up a batch, it will emit nothing without this repeat.
|
||||
.repeat(-1)
|
||||
)
|
||||
|
||||
if self._data_config.file_batch_size:
|
||||
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
|
||||
filenames_ds = filenames_ds.batch(
|
||||
self._data_config.file_batch_size)
|
||||
|
||||
def per_shard_dataset(filename):
|
||||
"""
|
||||
Create a TensorFlow dataset for a single shard file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the shard file.
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset: TensorFlow dataset for the shard file.
|
||||
"""
|
||||
ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
|
||||
return ds.prefetch(4)
|
||||
|
||||
@ -471,6 +553,12 @@ class RecapDataset(torch.utils.data.IterableDataset):
|
||||
yield to_batch(x)
|
||||
|
||||
def to_dataloader(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Convert the RecapDataset to a PyTorch DataLoader.
|
||||
|
||||
Returns:
|
||||
torch.utils.data.DataLoader: PyTorch DataLoader for the dataset.
|
||||
"""
|
||||
return torch.utils.data.DataLoader(self, batch_size=None)
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -17,6 +17,16 @@ FLAGS = flags.FLAGS
|
||||
def _generate_random_example(
|
||||
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
|
||||
) -> Dict[str, tf.Tensor]:
|
||||
"""
|
||||
Generate a random example based on the provided TensorFlow example schema.
|
||||
|
||||
Args:
|
||||
tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example.
|
||||
|
||||
Returns:
|
||||
Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema.
|
||||
"""
|
||||
|
||||
example = {}
|
||||
for feature_name, feature_spec in tf_example_schema.items():
|
||||
dtype = feature_spec.dtype
|
||||
@ -33,14 +43,43 @@ def _generate_random_example(
|
||||
|
||||
|
||||
def _float_feature(value):
|
||||
"""
|
||||
Create a TensorFlow float feature.
|
||||
|
||||
Args:
|
||||
value: A float or list of floats.
|
||||
|
||||
Returns:
|
||||
tf.train.Feature: A TensorFlow feature containing the float value(s).
|
||||
"""
|
||||
|
||||
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
|
||||
|
||||
|
||||
def _int64_feature(value):
|
||||
"""
|
||||
Create a TensorFlow int64 feature.
|
||||
|
||||
Args:
|
||||
value: An integer or list of integers.
|
||||
|
||||
Returns:
|
||||
tf.train.Feature: A TensorFlow feature containing the int64 value(s).
|
||||
"""
|
||||
|
||||
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
|
||||
|
||||
|
||||
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
|
||||
"""
|
||||
Serialize a dictionary of TensorFlow tensors into a binary string.
|
||||
|
||||
Args:
|
||||
x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors.
|
||||
|
||||
Returns:
|
||||
bytes: The serialized binary string.
|
||||
"""
|
||||
feature = {}
|
||||
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
|
||||
for feature_name, tensor in x.items():
|
||||
@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
|
||||
|
||||
|
||||
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
|
||||
"""
|
||||
Generate random data based on the provided configuration and save it as a TFRecord file.
|
||||
|
||||
Args:
|
||||
data_path (str): The path where the TFRecord file will be saved.
|
||||
config (recap_config_mod.RecapConfig): The configuration for generating the random data.
|
||||
"""
|
||||
|
||||
|
||||
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
|
||||
seg_dense_schema = json.load(f)["schema"]
|
||||
|
||||
@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
|
||||
|
||||
|
||||
def _generate_data_main(unused_argv):
|
||||
"""
|
||||
Main function to generate random data according to the provided configuration.
|
||||
|
||||
Args:
|
||||
unused_argv: Unused command-line arguments.
|
||||
"""
|
||||
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
|
||||
|
||||
# Find the path where to put the data
|
||||
|
Loading…
x
Reference in New Issue
Block a user