2023-04-01 00:36:31 +02:00
|
|
|
import numpy as np
|
|
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
|
|
|
|
|
2023-04-17 06:19:03 +02:00
|
|
|
def create_sparse_tensor(
    batch_size: int,
    input_size: int,
    num_values: int,
    dtype: tf.DType = tf.float32,
) -> tf.SparseTensor:
    """Create a random 2-D sparse tensor of shape ``(batch_size, input_size)``.

    Exactly ``num_values`` distinct cells are chosen at random from the
    ``batch_size * input_size`` cells of the dense shape. Indices are emitted
    in row-major (sorted) order with no duplicates. Values are drawn from
    ``[0, 1)`` with ``np.random.random`` and then cast to ``dtype``.

    All randomness comes from NumPy's global ``np.random`` state, so
    ``np.random.seed`` makes the result reproducible.

    Args:
        batch_size: Number of rows of the dense shape.
        input_size: Number of columns of the dense shape.
        num_values: Number of explicitly stored entries.
        dtype: Dtype of the stored values.

    Returns:
        A ``tf.SparseTensor`` with ``int64`` indices of shape
        ``(num_values, 2)``, ``num_values`` values of ``dtype``, and
        ``dense_shape == (batch_size, input_size)``.

    Raises:
        ValueError: If ``num_values`` is negative or larger than
            ``batch_size * input_size`` (that many distinct cells cannot
            be chosen).
    """
    num_cells = batch_size * input_size
    if num_values < 0 or num_values > num_cells:
        raise ValueError(
            f"num_values must be in [0, {num_cells}] for a "
            f"({batch_size}, {input_size}) tensor, got {num_values}."
        )

    # Sample flat (row-major) cell positions. Sampling with replacement can
    # produce duplicate coordinates, which make an invalid SparseTensor, so
    # np.unique removes them (and sorts), and we top up until we have
    # num_values distinct positions. The first draw matches a plain randint
    # call, so seeded output is unchanged whenever it has no collisions.
    # dtype=np.int64 because SparseTensor requires int64 indices and the
    # platform default for randint may be int32.
    flat_indices = np.empty(0, dtype=np.int64)
    while flat_indices.size < num_values:
        draw = np.random.randint(
            num_cells, size=num_values - flat_indices.size, dtype=np.int64
        )
        flat_indices = np.unique(np.concatenate([flat_indices, draw]))

    row_indices = flat_indices // input_size
    col_indices = flat_indices % input_size
    indices = np.stack([row_indices, col_indices], axis=1)
    values = np.random.random(num_values).astype(dtype.as_numpy_dtype)

    return tf.SparseTensor(
        indices=tf.constant(indices),
        values=tf.constant(values),
        dense_shape=(batch_size, input_size),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def create_reference_input(
    sparse_input: tf.SparseTensor, use_binary_values: bool
) -> tf.SparseTensor:
    """Build the reference sparse input corresponding to ``sparse_input``.

    Args:
        sparse_input: The sparse tensor to derive the reference from.
        use_binary_values: When true, every stored value is replaced by one.
            When false, the input is used as-is.

    Returns:
        ``sparse_input`` itself when ``use_binary_values`` is false.
        Otherwise, a new ``tf.SparseTensor`` with the same indices and
        dense shape whose values are ``tf.ones_like(sparse_input.values)``
        (same dtype, all ones).
    """
    # Nothing to rewrite: hand back the very same tensor object.
    if not use_binary_values:
        return sparse_input

    all_ones = tf.ones_like(sparse_input.values)
    return tf.SparseTensor(
        indices=sparse_input.indices,
        values=all_ones,
        dense_shape=sparse_input.dense_shape,
    )
|