the-algorithm/twml/twml/layers/partition.py
twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

75 lines
2.1 KiB
Python

"""
Implements the Partition layer.
"""
from .layer import Layer
import tensorflow.compat.v1 as tf
class Partition(Layer):
    """
    Splits the values and keys of a hashmap into ``partitions`` groups.

    This layer implements:

    .. code-block:: python

        tf.dynamic_partition(input_vals, partition_ids, self.partitions)

    Input:
        partitions:
            The number of partitions into which the hashmap keys/values are
            divided. Every entry of the ``partition_ids`` tensor passed to
            ``call`` is expected to lie in ``[0, partitions)``, as required by
            ``tf.dynamic_partition``.

    Output:
        A layer that performs partitioning.
    """

    def __init__(self, partitions=2, **kwargs):
        self.partitions = partitions
        super(Partition, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        """Computes the output shape of the layer given the input shape.

        Args:
            input_shape: A (possibly nested tuple of) `TensorShape`. It need not
                be fully defined (e.g. the batch size may be unknown).

        Raises:
            NotImplementedError: always; this layer does not implement
                static shape inference.
        """
        raise NotImplementedError

    def call(self, partition_ids, input_vals, input_keys, **kwargs):
        """Partitions the values/keys of a hashmap.

        Arguments:
            partition_ids:
                Tensor of partition ids, equivalent to boolean when
                ``partitions == 2``. It is cast to int32 once, because
                ``tf.dynamic_partition`` requires int32 partition ids, and
                the same ids are then used for values, keys and indices.
            input_vals:
                Tensor that represents the values of the hashmap
                (presumably float; confirm against callers).
            input_keys:
                Tensor that represents the keys of the hashmap
                (dtype not fixed by this layer).
            **kwargs:
                Ignored.

        Returns:
            The output of the partition layer, which is a list of lists that
            looks something like:

            .. code-block:: python

                [[vals_0, vals_1], [keys_0, keys_1], [indices_0, indices_1]]

            where:
                vals_x:
                    values of the hashmap for partition x
                keys_x:
                    keys of the hashmap for partition x
                indices_x:
                    int32 positions (along the first dimension of
                    ``partition_ids``) of the hashmap entries in partition x
        """
        # Cast once so all three partitions use identical int32 ids. A bool or
        # int64 tensor would otherwise be rejected by tf.dynamic_partition for
        # the values/keys while the indices were partitioned from a cast copy.
        partition_ids = tf.cast(partition_ids, tf.int32)

        partitioned_vals = tf.dynamic_partition(input_vals, partition_ids, self.partitions)
        partitioned_keys = tf.dynamic_partition(input_keys, partition_ids, self.partitions)

        # Positions 0..N-1 along the first dimension, partitioned the same way,
        # so each partition's entries can be traced back to their original rows.
        row_positions = tf.range(tf.shape(partition_ids)[0])
        partitioned_indices = tf.dynamic_partition(row_positions, partition_ids, self.partitions)

        return [partitioned_vals, partitioned_keys, partitioned_indices]