"""Schema construction and parsing of serialized tf.Example records for seg dense data."""

import functools
import json
from tml.projects.home.recap.data import config as recap_data_config
from absl import logging
import tensorflow as tf
# Default value handed to tf.io.FixedLenFeature, keyed by the schema entry's dtype name.
DEFAULTS_MAP = {
    "int64_list": 0,
    "float_list": 0.0,
    "bytes_list": "",
}
# TensorFlow dtype used to parse a feature, keyed by the schema entry's dtype name.
DTYPE_MAP = {
    "int64_list": tf.int64,
    "float_list": tf.float32,
    "bytes_list": tf.string,
}
def create_tf_example_schema(
    data_config: recap_data_config.RecapDataConfig,
    segdense_schema,
):
    """Generate schema for deserializing tf.Example.

    Only schema entries whose feature name is a configured feature, the source
    name of a renamed feature, or a label are kept.

    Args:
      data_config: Data config. Its `seg_dense_schema` section supplies `features`,
        `renamed_features` and `schema_path`; the keys of `tasks` are used as labels.
      segdense_schema: List of dicts of segdense features (includes feature_name,
        dtype, length).

    Returns:
      A dictionary schema suitable for deserializing tf.Example.

    Raises:
      ValueError: If a used feature or label has no entry in `segdense_schema`.
    """
    segdense_config = data_config.seg_dense_schema
    labels = list(data_config.tasks.keys())
    used_features = (
        segdense_config.features + list(segdense_config.renamed_features.values()) + labels
    )
    logging.info(used_features)

    # Sets keep the per-entry membership tests O(1); the list is kept for logging
    # and for reporting the first missing feature in configuration order.
    used_feature_set = set(used_features)
    label_set = set(labels)

    tfe_schema = {}
    for entry in segdense_schema:
        feature_name = entry["feature_name"]
        if feature_name not in used_feature_set:
            continue

        length = entry["length"]
        dtype = entry["dtype"]

        if feature_name in label_set:
            # Labels always parse as fixed-length features with a scalar default.
            logging.info(f"Label: feature name is {feature_name} type is {dtype}")
            tfe_schema[feature_name] = tf.io.FixedLenFeature(
                length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
            )
        elif length == -1:
            # A length of -1 marks a variable-length feature.
            tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
        else:
            tfe_schema[feature_name] = tf.io.FixedLenFeature(
                length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
            )

    for feature_name in used_features:
        if feature_name not in tfe_schema:
            raise ValueError(
                f"{feature_name} missing from schema: {segdense_config.schema_path}."
            )

    return tfe_schema
@functools.lru_cache(1)
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
    """Build the int32 bit mask that clears the lowest `mask_length` bits.

    For experimenting with emulating bfloat16 or less precise types.

    Args:
      mask_length: Number of low-order bits to clear in a 32-bit word.

    Returns:
      A scalar int32 constant whose bits at positions >= `mask_length` are set.
    """
    # The unsigned pattern (1 << 32) - (1 << mask_length) does not fit in a signed
    # int32 for any mask_length below 32, so reinterpret it as the two's-complement
    # signed value with the same bits before handing it to tf.constant.
    unsigned_mask = ((1 << 32) - (1 << mask_length)) & 0xFFFFFFFF
    if unsigned_mask >= (1 << 31):
        signed_mask = unsigned_mask - (1 << 32)
    else:
        signed_mask = unsigned_mask
    # NOTE(review): the cache holds a tf.Tensor; if this is ever called while tracing
    # different graphs the cached tensor may belong to another graph — confirm.
    return tf.constant(signed_mask, dtype=tf.int32)
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
    """Clear the lowest `mask_length` bits of every element of `tensor`.

    For experimenting with emulating bfloat16 or less precise types. The tensor is
    bitcast to int32, ANDed with the mask from `make_mantissa_mask`, and bitcast back
    to its original dtype.
    """
    low_bits_cleared_mask: tf.Tensor = make_mantissa_mask(mask_length)
    raw_bits = tf.bitcast(tensor, tf.int32)
    masked_bits = tf.bitwise.bitwise_and(raw_bits, low_bits_cleared_mask)
    return tf.bitcast(masked_bits, tensor.dtype)
def parse_tf_example(
    serialized_example,
    tfe_schema,
    seg_dense_schema_config,
):
    """Parse serialized tf.Example into dict of tensors.

    Args:
      serialized_example: Serialized tf.Example to be parsed.
      tfe_schema: Dictionary schema suitable for deserializing tf.Example.
      seg_dense_schema_config: Seg dense schema config. `renamed_features` (new name ->
        parsed name) is always read; `mask_mantissa_features` (feature name -> mask
        length) is read only when the membership check below succeeds.

    Returns:
      Dictionary of tensors to be used as model input.
    """
    inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
    renames = seg_dense_schema_config.renamed_features

    # Move each parsed feature to its new name; the old key is removed.
    for target_name, source_name in renames.items():
        inputs[target_name] = inputs.pop(source_name)

    # This should not actually be used except for experimentation with low precision floats.
    # NOTE(review): this relies on the config object supporting `in` by field name —
    # confirm for the config type actually passed here.
    if "mask_mantissa_features" in seg_dense_schema_config:
        mantissa_masks = seg_dense_schema_config.mask_mantissa_features
        for masked_name, bits_to_clear in mantissa_masks.items():
            inputs[masked_name] = mask_mantissa(inputs[masked_name], bits_to_clear)

    # DANGER DANGER: This default seems really scary, and it's only here because it has
    # to be visible at TF level.
    # We should not return empty tensors if we don't use embeddings.
    # Otherwise, it breaks numpy->pt conversion.
    # NOTE(review): every renamed key was inserted by the rename loop above, so this
    # fallback only fires if a later rename popped an earlier new name — confirm intent.
    for target_name in list(renames.keys()):
        if "embedding" not in target_name or target_name in inputs:
            continue
        inputs[target_name] = tf.zeros([], tf.float32)

    logging.info(f"parsed example and inputs are {inputs}")
    return inputs
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
    """Build the parse function for seg dense data.

    Loads the "schema" entry of the JSON file at
    `data_config.seg_dense_schema.schema_path`, derives the tf.Example schema from it,
    and returns `parse_tf_example` with that schema and the seg dense config bound.

    Placeholder for seg dense: when more seg dense variations are used, this can change.
    """
    seg_dense_config = data_config.seg_dense_schema
    with tf.io.gfile.GFile(seg_dense_config.schema_path, "r") as schema_file:
        schema_entries = json.load(schema_file)["schema"]

    tfe_schema = create_tf_example_schema(data_config, schema_entries)

    logging.info("***** TF Example Schema *****")
    logging.info(tfe_schema)

    return functools.partial(
        parse_tf_example,
        tfe_schema=tfe_schema,
        seg_dense_schema_config=seg_dense_config,
    )