twitter-team ef4c5eb65e Twitter Recommendation Algorithm
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
2023-03-31 17:36:31 -05:00

30 lines
676 B
Python

from twml.contrib.pruning import apply_mask
from twml.layers import Layer
class MaskLayer(Layer):
    """
    Layer wrapper around `twml.contrib.pruning.apply_mask`.

    Masks out channels of a given tensor by applying a binary mask. The masks
    can be optimized using `twml.contrib.trainers.PruningDataRecordTrainer`.
    """

    def compute_output_shape(self, input_shape):
        """
        Reports the shape of this layer's output.

        Arguments:
          input_shape:
            shape of the input tensor

        Returns:
          `input_shape`, handed back unchanged
        """
        return input_shape

    def call(self, inputs, **kwargs):
        """
        Runs the input through `apply_mask` to mask its channels.

        Arguments:
          inputs:
            input tensor
          **kwargs:
            additional keyword arguments; accepted but not forwarded
            to `apply_mask`

        Returns:
          Masked tensor
        """
        masked = apply_mask(inputs)
        return masked