the-algorithm-ml/projects/home/recap/data/generate_random_data.py

136 lines
3.9 KiB
Python
Raw Normal View History

import os
import json
from absl import app, flags, logging
import tensorflow as tf
from typing import Dict
from tml.projects.home.recap.data import tfe_parsing
from tml.core import config as tml_config_mod
import tml.projects.home.recap.config as recap_config_mod
# Command-line flags for this script; values are read through FLAGS.
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_integer("n_examples", 100, "Number of examples to generate.")
FLAGS = flags.FLAGS
def _generate_random_example(
  tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
  """
  Build one random example whose features follow the given schema.

  Integer features (int64/int32) are drawn with
  `tf.experimental.numpy.random.randint(0, high=10, ...)`; floating point
  features (float32/float64) are drawn with `tf.random.uniform` using its
  default value range.

  Args:
    tf_example_schema (Dict[str, tf.io.FixedLenFeature]): Mapping from feature name to the spec providing its `dtype` and `shape`.

  Returns:
    Dict[str, tf.Tensor]: Mapping from feature name to a randomly filled tensor.

  Raises:
    NotImplementedError: If a feature has a dtype other than int32, int64, float32 or float64.
  """
  integer_dtypes = (tf.int64, tf.int32)
  floating_dtypes = (tf.float32, tf.float64)

  random_features = {}
  for name, spec in tf_example_schema.items():
    kind = spec.dtype
    if kind in integer_dtypes:
      random_features[name] = tf.experimental.numpy.random.randint(
        0, high=10, size=spec.shape, dtype=kind
      )
      continue
    if kind in floating_dtypes:
      # The spec's shape is wrapped in a list to form the `shape` argument.
      random_features[name] = tf.random.uniform(shape=[spec.shape], dtype=kind)
      continue
    raise NotImplementedError(f"Unknown type {kind}")

  return random_features
def _float_feature(value):
  """
  Wrap float values in a `tf.train.Feature`.

  Args:
    value: The float values to store; forwarded unchanged as the `value` argument of `tf.train.FloatList`.

  Returns:
    tf.train.Feature: A feature whose `float_list` holds the given values.
  """
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
  """
  Wrap integer values in a `tf.train.Feature`.

  Args:
    value: The integer values to store; forwarded unchanged as the `value` argument of `tf.train.Int64List`.

  Returns:
    tf.train.Feature: A feature whose `int64_list` holds the given values.
  """
  int64_list = tf.train.Int64List(value=value)
  return tf.train.Feature(int64_list=int64_list)
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
  """
  Encode a dictionary of tensors as a serialized `tf.train.Example`.

  Each tensor is converted to a `tf.train.Feature` by a converter chosen from
  its dtype; only `tf.float32` and `tf.int64` have converters, so any other
  dtype fails the lookup with a `KeyError`.

  Args:
    x (Dict[str, tf.Tensor]): Mapping from feature name to tensor.

  Returns:
    bytes: The serialized `tf.train.Example` message.
  """
  to_feature = {tf.float32: _float_feature, tf.int64: _int64_feature}
  encoded = {name: to_feature[values.dtype](values) for name, values in x.items()}
  features = tf.train.Features(feature=encoded)
  return tf.train.Example(features=features).SerializeToString()
def generate_data(
  data_path: str, config: recap_config_mod.RecapConfig, n_examples: int = 1
):
  """
  Write randomly generated examples to a GZIP-compressed TFRecord file.

  The segdense schema JSON referenced by the config is read, converted into a
  TensorFlow example schema, and used to draw `n_examples` independent random
  examples. They are written to `random.tfrecord.gz` inside `data_path`.

  Args:
    data_path (str): The directory where the TFRecord file will be saved.
    config (recap_config_mod.RecapConfig): The configuration supplying the schema path and train-data settings.
    n_examples (int): Number of records to write. Defaults to 1, so callers that omit it get a single record.

  Raises:
    ValueError: If `n_examples` is smaller than 1.
  """
  if n_examples < 1:
    raise ValueError(f"n_examples must be at least 1, got {n_examples}")

  with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
    seg_dense_schema = json.load(f)["schema"]

  tf_example_schema = tfe_parsing.create_tf_example_schema(
    config.train_data,
    seg_dense_schema,
  )

  record_filename = os.path.join(data_path, "random.tfrecord.gz")

  with tf.io.TFRecordWriter(record_filename, "GZIP") as writer:
    # Draw a fresh example for every record so the file is not one row repeated.
    for _ in range(n_examples):
      random_example = _generate_random_example(tf_example_schema)
      writer.write(_serialize_example(random_example))


def _generate_data_main(unused_argv):
  """
  Main function to generate random data according to the provided configuration.

  Loads the config named by `--config_path`, derives the output directory from
  the directory part of `config.train_data.inputs`, and writes `--n_examples`
  random records there.

  Args:
    unused_argv: Unused command-line arguments.

  Raises:
    app.UsageError: If `--config_path` was not provided.
  """
  # The flag defaults to None; fail with a usage message instead of passing
  # None on to the config loader.
  if not FLAGS.config_path:
    raise app.UsageError("--config_path is required.")

  config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)

  # Find the path where to put the data
  data_path = os.path.dirname(config.train_data.inputs)
  logging.info("Putting random data in %s", data_path)

  generate_data(data_path, config, n_examples=FLAGS.n_examples)
if __name__ == "__main__":
  # Hand control to absl's app runner, which calls the main function.
  app.run(_generate_data_main)