mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-02-03 06:22:36 +01:00
ef4c5eb65e
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
30 lines
676 B
Python
30 lines
676 B
Python
from twml.contrib.pruning import apply_mask
|
|
from twml.layers import Layer
|
|
|
|
|
|
class MaskLayer(Layer):
  """
  Layer wrapper around `twml.contrib.pruning.apply_mask`.

  Applies a binary mask to the input tensor in order to mask out some of its
  channels. The mask values are intended to be optimized by
  `twml.contrib.trainers.PruningDataRecordTrainer`.
  """

  def call(self, inputs, **kwargs):
    """
    Mask out channels of ``inputs`` with a binary mask.

    Arguments:
      inputs:
        the tensor to be masked.
      **kwargs:
        extra keyword arguments; accepted but not used by this implementation.

    Returns:
      The tensor returned by ``apply_mask(inputs)``.
    """
    masked = apply_mask(inputs)
    return masked

  def compute_output_shape(self, input_shape):
    """
    Report the output shape for a given input shape.

    Masking does not change the shape, so ``input_shape`` is returned as-is.
    """
    output_shape = input_shape
    return output_shape
|