the-algorithm/pushservice/src/main/python/models/libs/customized_full_sparse.py
twitter-team b389c3d302 Open-sourcing pushservice
Pushservice is the main recommendation service we use to surface recommendations to our users via notifications. It fetches candidates from various sources, ranks them in order of relevance, and applies filters to determine the best one to send.
2023-05-19 16:27:07 -05:00

57 lines
1.8 KiB
Python

# pylint: disable=no-member, arguments-differ, attribute-defined-outside-init, unused-argument
"""
Implements a Full Sparse layer that allows specifying use_binary_values in
call() to override the layer's default setting.
"""
from twml.layers import FullSparse as defaultFullSparse
from twml.layers.full_sparse import sparse_dense_matmul
import tensorflow.compat.v1 as tf
class FullSparse(defaultFullSparse):
    """Full sparse layer whose ``call()`` accepts a per-call ``use_binary_values``.

    Extends the twml ``FullSparse`` layer (imported in this module as
    ``defaultFullSparse``) and replaces only ``call()``.
    """

    def call(self, inputs, use_binary_values=None, **kwargs):  # pylint: disable=unused-argument
        """Run ``sparse_dense_matmul`` on ``inputs``, then add bias and activate.

        Arguments:
          inputs:
            Sparse input handed unchanged to ``sparse_dense_matmul``.
            NOTE(review): this override issues a single matmul and has no
            branch for a list of inputs -- confirm callers never pass one.
          use_binary_values:
            Optional per-call override. When ``None`` (the default), the
            layer's own ``self.use_binary_values`` is used instead.
          **kwargs:
            Accepted but not used by this method.

        Returns:
          The ``sparse_dense_matmul`` result, with ``self.bias`` added when it
          is not ``None`` and ``self.activation`` applied when it is not
          ``None``.

        Raises:
          ValueError: If the effective ``use_binary_values`` is a list or a
            tuple.
        """
        # An explicit argument wins; otherwise use the layer-level setting.
        binary_flag = (
            self.use_binary_values
            if use_binary_values is None
            else use_binary_values
        )

        # Only a single (non-sequence) setting is accepted by this method.
        if isinstance(binary_flag, (list, tuple)):
            message = "use_binary_values can not be %s when inputs is %s" % (
                type(binary_flag),
                type(inputs),
            )
            raise ValueError(message)

        projected = sparse_dense_matmul(
            inputs,
            self.weight,
            self.use_sparse_grads,
            binary_flag,
            name="sparse_mm",
            partition_axis=self.partition_axis,
            num_partitions=self.num_partitions,
            compress_ids=self._use_compression,
            cast_indices_dtype=self._cast_indices_dtype,
        )

        # Bias is optional on the layer; skip the add when it is absent.
        if self.bias is not None:
            projected = tf.nn.bias_add(projected, self.bias)

        # Without an activation the (possibly biased) product is returned as-is.
        if self.activation is None:
            return projected
        return self.activation(projected)  # pylint: disable=not-callable