Mirror of https://github.com/twitter/the-algorithm.git
Synced 2025-02-09 09:03:27 +01:00
30 lines · 676 B · Python
from twml.contrib.pruning import apply_mask
|
||
|
from twml.layers import Layer
|
||
|
|
||
|
|
||
|
class MaskLayer(Layer):
    """
    Layer wrapper around `twml.contrib.pruning.apply_mask`.

    Masks out channels of the incoming tensor with a binary mask. Those masks
    can be optimized with `twml.contrib.trainers.PruningDataRecordTrainer`.
    """

    def call(self, inputs, **kwargs):
        """
        Mask the channels of ``inputs`` with a binary mask.

        Arguments:
          inputs:
            input tensor
          **kwargs:
            additional keyword arguments; accepted for interface
            compatibility and not forwarded to ``apply_mask``

        Returns:
          The masked tensor produced by ``apply_mask(inputs)``.
        """
        masked = apply_mask(inputs)
        return masked

    def compute_output_shape(self, input_shape):
        """
        Report the output shape of this layer.

        The layer declares its output shape to be identical to its input
        shape, so ``input_shape`` is handed back unchanged.
        """
        output_shape = input_shape
        return output_shape