Mirror of https://github.com/twitter/the-algorithm-ml.git, synced 2024-06-16 14:48:44 +02:00
This commit is contained in:
parent 0813989fd9
commit cc73f5fcb7
@@ -43,7 +43,8 @@ class RecapBatch(DataclassBatch):
 def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
   """Converts a torch data loader output into `RecapBatch`."""
 
-  x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
+  x = tf.nest.map_structure(functools.partial(
+    sparse_or_dense_tf_to_torch, pin_memory=False), x)
   try:
     features_in, labels = x
   except ValueError:
@@ -53,25 +54,29 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
   sparse_features = keyed_jagged_tensor_from_tensors_dict({})
   if sparse_feature_names:
     sparse_features = keyed_jagged_tensor_from_tensors_dict(
-      {embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
+      {embedding_name: features_in[embedding_name]
+       for embedding_name in sparse_feature_names}
     )
 
   user_embedding, user_eng_embedding, author_embedding = None, None, None
   if "user_embedding" in features_in:
     if sparse_feature_names and "meta__user_id" in sparse_feature_names:
-      raise ValueError("Only one source of embedding for user is supported")
+      raise ValueError(
+        "Only one source of embedding for user is supported")
     else:
       user_embedding = features_in["user_embedding"]
 
   if "user_eng_embedding" in features_in:
     if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
-      raise ValueError("Only one source of embedding for user is supported")
+      raise ValueError(
+        "Only one source of embedding for user is supported")
     else:
       user_eng_embedding = features_in["user_eng_embedding"]
 
   if "author_embedding" in features_in:
     if sparse_feature_names and "meta__author_id" in sparse_feature_names:
-      raise ValueError("Only one source of embedding for user is supported")
+      raise ValueError(
+        "Only one source of embedding for user is supported")
     else:
       author_embedding = features_in["author_embedding"]
 
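Note: the comprehension in this hunk just selects the configured sparse ID columns out of the incoming feature dict before they are packed into a keyed jagged tensor. A minimal illustrative sketch (the feature names are hypothetical, not taken from the diff):

# Hypothetical incoming features and configured sparse columns.
features_in = {"meta__user_id": ..., "meta__author_id": ..., "continuous": ...}
sparse_feature_names = ["meta__user_id", "meta__author_id"]
sparse_inputs = {name: features_in[name] for name in sparse_feature_names}
# Only the ID columns remain, keyed by name, ready for keyed_jagged_tensor_from_tensors_dict.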
@@ -84,7 +89,8 @@ def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
     user_eng_embedding=user_eng_embedding,
     author_embedding=author_embedding,
     labels=labels,
-    weights=features_in.get("weights", None),  # Defaults to torch.ones_like(labels)
+    # Defaults to torch.ones_like(labels)
+    weights=features_in.get("weights", None),
   )
 
 
@@ -119,18 +125,21 @@ def _add_weights(inputs, tasks: Dict[str, TaskData]):
     float_label = tf.cast(label, tf.float32)
 
     weights.append(
-      float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
+      float_label / task.pos_downsampling_rate +
+      (1.0 - float_label) / task.neg_downsampling_rate
     )
 
   # Ensure we are batch-major (assumes we batch before this call).
-  inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
+  inputs["weights"] = tf.squeeze(tf.transpose(
+    tf.convert_to_tensor(weights)), axis=0)
   return inputs
 
 
 def get_datetimes(explicit_datetime_inputs):
   """Compute list datetime strings for train/validation data."""
   datetime_format = "%Y/%m/%d/%H"
-  end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
+  end = datetime.strptime(
+    explicit_datetime_inputs.end_datetime, datetime_format)
   dates = sorted(
     [
       (end - timedelta(hours=i + 1)).strftime(datetime_format)
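Note: the reweighting expression wrapped in this hunk compensates for label-dependent downsampling: each example is weighted by the inverse of the rate at which its class was kept. A worked sketch with made-up rates (the rates below are illustrative, not from any config):

pos_downsampling_rate, neg_downsampling_rate = 0.5, 0.1
for float_label in (1.0, 0.0):
  weight = float_label / pos_downsampling_rate + (1.0 - float_label) / neg_downsampling_rate
  # positive example -> weight 2.0, negative example -> weight 10.0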
@@ -158,7 +167,8 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
 
   """
   datetimes = get_datetimes(explicit_datetime_inputs)
-  folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
+  folders = [os.path.join(explicit_datetime_inputs.data_root, datetime)
+             for datetime in datetimes]
   data_files = []
   num_hours_missing = 0
   for folder in folders:
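Note: `get_datetimes` emits one "%Y/%m/%d/%H" string per hour counting back from `end_datetime`, and this hunk joins each string onto `data_root` to form hourly folders. A small sketch with made-up inputs (the hour count and root path are assumptions):

import os
from datetime import datetime, timedelta

datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime("2023/01/02/03", datetime_format)  # hypothetical end_datetime
dates = sorted((end - timedelta(hours=i + 1)).strftime(datetime_format) for i in range(3))
# dates == ["2023/01/02/00", "2023/01/02/01", "2023/01/02/02"]
folders = [os.path.join("/data_root", d) for d in dates]  # hypothetical data_root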
@@ -176,20 +186,34 @@ def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
       )
     except tf.errors.NotFoundError as e:
       num_hours_missing += 1
-      logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
+      logging.warning(
+        f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
   return sorted(data_files), num_hours_missing
 
 
 def _map_output_for_inference(
   inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
 ):
+  """
+  Map the input data for inference.
+
+  Args:
+    inputs (dict): Input data dictionary.
+    tasks (Dict[str, TaskData]): Dictionary of task data configurations.
+    preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
+    add_weights (bool, optional): Whether to add weights. Defaults to False.
+
+  Returns:
+    dict: Modified input data dictionary for inference.
+  """
   if preprocessor:
     raise ValueError("No preprocessor should be used at inference time.")
   if add_weights:
     raise NotImplementedError()
 
   # Add zero weights.
-  inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
+  inputs["weights"] = tf.zeros_like(
+    tf.expand_dims(inputs["continuous"][:, 0], -1))
   for label in tasks:
     del inputs[label]
   return inputs
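Note: at inference time no reweighting happens; the wrapped line just allocates one zero weight per example, shaped off the first continuous feature column. A minimal sketch:

import tensorflow as tf

inputs = {"continuous": tf.zeros([3, 4])}  # hypothetical batch: 3 examples, 4 dense features
weights = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
# weights.shape == (3, 1): one zero weight per example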
@@ -198,6 +222,18 @@ def _map_output_for_inference(
 def _map_output_for_train_eval(
   inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
 ):
+  """
+  Map the input data for training and evaluation.
+
+  Args:
+    inputs (dict): Input data dictionary.
+    tasks (Dict[str, TaskData]): Dictionary of task data configurations.
+    preprocessor (tf.keras.Model, optional): Preprocessor model for input data. Defaults to None.
+    add_weights (bool, optional): Whether to add weights. Defaults to False.
+
+  Returns:
+    Tuple[dict, tf.Tensor]: Modified input data dictionary and label values for training and evaluation.
+  """
   if add_weights:
     inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
 
@@ -205,7 +241,8 @@ def _map_output_for_train_eval(
   if preprocessor:
     inputs = preprocessor(inputs)
 
-  label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
+  label_values = tf.squeeze(
+    tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
 
   for label in tasks:
     del inputs[label]
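Note: the wrapped `label_values` line stacks one label column per task and drops the trailing singleton dimension. A minimal sketch with fake per-task labels:

import tensorflow as tf

# Hypothetical batch of 2 examples and 2 tasks; each label tensor is [batch, 1].
inputs = {"task_a": tf.constant([[1.0], [0.0]]), "task_b": tf.constant([[0.0], [1.0]])}
tasks = ["task_a", "task_b"]
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
# label_values.shape == (2, 2): row i holds example i's labels across tasks.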
@@ -231,11 +268,13 @@ def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
     float_label = tf.cast(label, tf.float32)
 
     weights.append(
-      float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
+      float_label / task.pos_downsampling_rate +
+      (1.0 - float_label) / task.neg_downsampling_rate
     )
 
   # Ensure we are batch-major (assumes we batch before this call).
-  inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
+  inputs["weights"] = tf.squeeze(tf.transpose(
+    tf.convert_to_tensor(weights)), axis=0)
   return inputs
 
 
@@ -249,6 +288,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
     repeat: bool = False,
     vocab_mapper: tf.keras.Model = None,
   ):
+    """
+    Create a RecapDataset for training or inference.
+
+    Args:
+      data_config (RecapDataConfig): Data configuration.
+      dataset_service (str, optional): Dataset service identifier. Defaults to None.
+      mode (JobMode, optional): Job mode (TRAIN or INFERENCE). Defaults to JobMode.TRAIN.
+      compression (str, optional): Compression type. Defaults to "AUTO".
+      repeat (bool, optional): Whether to repeat the dataset. Defaults to False.
+      vocab_mapper (tf.keras.Model, optional): Vocabulary mapper. Defaults to None.
+    """
     logging.info("***** Labels *****")
     logging.info(list(data_config.tasks.keys()))
 
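Note: a construction sketch matching the docstring added above (argument values are illustrative; how `config` is loaded is outside this diff):

# Assumes `config` is a loaded RecapConfig with a populated train_data section.
train_dataset = RecapDataset(
  data_config=config.train_data,
  dataset_service=None,
  mode=JobMode.TRAIN,
  compression="AUTO",
  repeat=True,
)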
@@ -263,7 +313,8 @@ class RecapDataset(torch.utils.data.IterableDataset):
     preprocessor = None
     self._batch_size_multiplier = 1
     if data_config.preprocess:
-      preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
+      preprocessor = preprocessors.build_preprocess(
+        data_config.preprocess, mode=mode)
       if data_config.preprocess.downsample_negatives:
         self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
 
@@ -295,13 +346,17 @@ class RecapDataset(torch.utils.data.IterableDataset):
       add_weights=should_add_weights,
     )
 
-    sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
+    sparse_feature_names = list(
+      vocab_mapper.vocabs.keys()) if vocab_mapper else None
 
     self._tf_dataset = self._create_tf_dataset()
 
     self._init_tensor_spec()
 
   def _init_tensor_spec(self):
+    """
+    Initialize the tensor specification for the dataset.
+    """
     def _tensor_spec_to_torch_shape(spec):
       if spec.shape is None:
         return None
@@ -313,8 +368,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
     )
 
   def _create_tf_dataset(self):
+    """
+    Create a TensorFlow dataset from the data files.
+
+    Returns:
+      tf.data.Dataset: TensorFlow dataset.
+    """
     if hasattr(self, "_tf_dataset"):
-      raise ValueError("Do not call `_create_tf_dataset` more than once.")
+      raise ValueError(
+        "Do not call `_create_tf_dataset` more than once.")
 
     world_size = dist.get_world_size() if dist.is_initialized() else 1
     per_replica_bsz = (
@@ -371,6 +433,15 @@ class RecapDataset(torch.utils.data.IterableDataset):
     return dataset.prefetch(world_size * 2)
 
   def _create_base_tf_dataset(self, batch_size: int):
+    """
+    Create the base TensorFlow dataset.
+
+    Args:
+      batch_size (int): Batch size.
+
+    Returns:
+      tf.data.Dataset: Base TensorFlow dataset.
+    """
     if self._data_config.inputs:
       glob = self._data_config.inputs
       filenames = sorted(tf.io.gfile.glob(glob))
@@ -411,16 +482,27 @@ class RecapDataset(torch.utils.data.IterableDataset):
       logging.info(f"Retaining only {len(filenames)} files.")
 
     filenames_ds = (
-      tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
+      tf.data.Dataset.from_tensor_slices(
+        filenames).shuffle(len(filenames))
       # Because of drop_remainder, if our dataset does not fill
       # up a batch, it will emit nothing without this repeat.
       .repeat(-1)
     )
 
     if self._data_config.file_batch_size:
-      filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
+      filenames_ds = filenames_ds.batch(
+        self._data_config.file_batch_size)
 
     def per_shard_dataset(filename):
+      """
+      Create a TensorFlow dataset for a single shard file.
+
+      Args:
+        filename (str): Path to the shard file.
+
+      Returns:
+        tf.data.Dataset: TensorFlow dataset for the shard file.
+      """
       ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
       return ds.prefetch(4)
 
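Note: the surrounding pipeline (shuffle the shard list, repeat it so drop_remainder never starves a batch, then read each shard as a GZIP TFRecord stream) is a standard tf.data pattern. A self-contained sketch under assumed shard paths; the interleave step is implied by `per_shard_dataset` but not shown in this hunk:

import tensorflow as tf

filenames = ["shard-00000.gz", "shard-00001.gz"]  # hypothetical shard files

filenames_ds = (
  tf.data.Dataset.from_tensor_slices(filenames)
  .shuffle(len(filenames))
  .repeat(-1)  # keep cycling so partial batches are not silently dropped
)

def per_shard_dataset(filename):
  # Each shard is a GZIP-compressed TFRecord file.
  return tf.data.TFRecordDataset([filename], compression_type="GZIP").prefetch(4)

records = filenames_ds.interleave(
  per_shard_dataset, cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE
)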
@@ -471,6 +553,12 @@ class RecapDataset(torch.utils.data.IterableDataset):
         yield to_batch(x)
 
   def to_dataloader(self) -> Dict[str, torch.Tensor]:
+    """
+    Convert the RecapDataset to a PyTorch DataLoader.
+
+    Returns:
+      torch.utils.data.DataLoader: PyTorch DataLoader for the dataset.
+    """
     return torch.utils.data.DataLoader(self, batch_size=None)
 
   def __iter__(self):
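Note: because each element of the underlying TF pipeline is already a full batch, the DataLoader is created with `batch_size=None` and performs no extra collation. A minimal usage sketch (`train_dataset` is assumed to be a constructed RecapDataset):

loader = train_dataset.to_dataloader()  # torch.utils.data.DataLoader with batch_size=None
for batch in loader:
  print(batch.labels.shape)  # each item is a RecapBatch produced by to_batch()
  break

The hunks below belong to the second file touched by this commit, the script that generates random TFRecord data.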
@@ -17,6 +17,16 @@ FLAGS = flags.FLAGS
 def _generate_random_example(
   tf_example_schema: Dict[str, tf.io.FixedLenFeature]
 ) -> Dict[str, tf.Tensor]:
+  """
+  Generate a random example based on the provided TensorFlow example schema.
+
+  Args:
+    tf_example_schema (Dict[str, tf.io.FixedLenFeature]): A dictionary defining the schema of the TensorFlow example.
+
+  Returns:
+    Dict[str, tf.Tensor]: A dictionary containing random data for each feature defined in the schema.
+  """
+
   example = {}
   for feature_name, feature_spec in tf_example_schema.items():
     dtype = feature_spec.dtype
@@ -33,14 +43,43 @@ def _generate_random_example(
 
 
 def _float_feature(value):
+  """
+  Create a TensorFlow float feature.
+
+  Args:
+    value: A float or list of floats.
+
+  Returns:
+    tf.train.Feature: A TensorFlow feature containing the float value(s).
+  """
+
   return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
 
 def _int64_feature(value):
+  """
+  Create a TensorFlow int64 feature.
+
+  Args:
+    value: An integer or list of integers.
+
+  Returns:
+    tf.train.Feature: A TensorFlow feature containing the int64 value(s).
+  """
+
   return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
 
 
 def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
+  """
+  Serialize a dictionary of TensorFlow tensors into a binary string.
+
+  Args:
+    x (Dict[str, tf.Tensor]): A dictionary of TensorFlow tensors.
+
+  Returns:
+    bytes: The serialized binary string.
+  """
   feature = {}
   serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
   for feature_name, tensor in x.items():
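Note: `_float_feature`, `_int64_feature`, and `_serialize_example` together turn a tensor dict into a serialized `tf.train.Example`. A minimal sketch with a made-up feature dict:

import tensorflow as tf

x = {"continuous": tf.constant([0.1, 0.2]), "label": tf.constant([1], dtype=tf.int64)}  # hypothetical
feature = {
  "continuous": tf.train.Feature(float_list=tf.train.FloatList(value=x["continuous"].numpy())),
  "label": tf.train.Feature(int64_list=tf.train.Int64List(value=x["label"].numpy())),
}
serialized = tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()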
@@ -51,6 +90,15 @@ def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
 
 
 def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
+  """
+  Generate random data based on the provided configuration and save it as a TFRecord file.
+
+  Args:
+    data_path (str): The path where the TFRecord file will be saved.
+    config (recap_config_mod.RecapConfig): The configuration for generating the random data.
+  """
+
+
   with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
     seg_dense_schema = json.load(f)["schema"]
 
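Note: writing the generated examples to `data_path` would look roughly like the sketch below (file name, example count, and compression are assumptions; the diff does not show the writer code):

import tensorflow as tf

options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter("/tmp/random_data.tfrecord.gz", options=options) as writer:
  for _ in range(10):  # hypothetical number of examples
    example = _generate_random_example(tf_example_schema)  # schema assumed already loaded
    writer.write(_serialize_example(example))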
@@ -68,6 +116,12 @@ def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
 
 
 def _generate_data_main(unused_argv):
+  """
+  Main function to generate random data according to the provided configuration.
+
+  Args:
+    unused_argv: Unused command-line arguments.
+  """
   config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
 
   # Find the path where to put the data