129 lines
3.5 KiB
Python
129 lines
3.5 KiB
Python
"""
|
|
Different type of convolution layers to be used in the ClemNet.
|
|
"""
|
|
from typing import Any
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
class KerasConv1D(tf.keras.layers.Layer):
|
|
"""
|
|
Basic Conv1D layer in a wrapper to be compatible with ClemNet.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
kernel_size: int,
|
|
filters: int,
|
|
strides: int,
|
|
padding: str,
|
|
use_bias: bool = True,
|
|
kernel_initializer: str = "glorot_uniform",
|
|
bias_initializer: str = "zeros",
|
|
**kwargs: Any,
|
|
):
|
|
super(KerasConv1D, self).__init__(**kwargs)
|
|
self.kernel_size = kernel_size
|
|
self.filters = filters
|
|
self.use_bias = use_bias
|
|
self.kernel_initializer = kernel_initializer
|
|
self.bias_initializer = bias_initializer
|
|
self.strides = strides
|
|
self.padding = padding
|
|
|
|
def build(self, input_shape: tf.TensorShape) -> None:
|
|
assert (
|
|
len(input_shape) == 3
|
|
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
|
|
self.features = input_shape[1]
|
|
|
|
self.w = tf.keras.layers.Conv1D(
|
|
kernel_size=self.kernel_size,
|
|
filters=self.filters,
|
|
strides=self.strides,
|
|
padding=self.padding,
|
|
use_bias=self.use_bias,
|
|
kernel_initializer=self.kernel_initializer,
|
|
bias_initializer=self.bias_initializer,
|
|
name=self.name,
|
|
)
|
|
|
|
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
return self.w(inputs)
|
|
|
|
|
|
class ChannelWiseDense(tf.keras.layers.Layer):
|
|
"""
|
|
Dense layer is applied to each channel separately. This is more memory and computationally
|
|
efficient than flattening the channels and performing single dense layers over it which is the
|
|
default behavior in tf1.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
output_size: int,
|
|
use_bias: bool,
|
|
kernel_initializer: str = "uniform_glorot",
|
|
bias_initializer: str = "zeros",
|
|
**kwargs: Any,
|
|
):
|
|
super(ChannelWiseDense, self).__init__(**kwargs)
|
|
self.output_size = output_size
|
|
self.use_bias = use_bias
|
|
self.kernel_initializer = kernel_initializer
|
|
self.bias_initializer = bias_initializer
|
|
|
|
def build(self, input_shape: tf.TensorShape) -> None:
|
|
assert (
|
|
len(input_shape) == 3
|
|
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
|
|
input_size = input_shape[1]
|
|
channels = input_shape[2]
|
|
|
|
self.kernel = self.add_weight(
|
|
name="kernel",
|
|
shape=(channels, input_size, self.output_size),
|
|
initializer=self.kernel_initializer,
|
|
trainable=True,
|
|
)
|
|
|
|
self.bias = self.add_weight(
|
|
name="bias",
|
|
shape=(channels, self.output_size),
|
|
initializer=self.bias_initializer,
|
|
trainable=self.use_bias,
|
|
)
|
|
|
|
def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
x = inputs
|
|
|
|
transposed_x = tf.transpose(x, perm=[2, 0, 1])
|
|
transposed_residual = (
|
|
tf.transpose(tf.matmul(transposed_x, self.kernel), perm=[1, 0, 2]) + self.bias
|
|
)
|
|
output = tf.transpose(transposed_residual, perm=[0, 2, 1])
|
|
|
|
return output
|
|
|
|
|
|
class ResidualLayer(tf.keras.layers.Layer):
|
|
"""
|
|
Layer implementing a 3D-residual connection.
|
|
"""
|
|
|
|
def build(self, input_shape: tf.TensorShape) -> None:
|
|
assert (
|
|
len(input_shape) == 3
|
|
), f"Tensor shape must be of length 3. Passed tensor of shape {input_shape}."
|
|
|
|
def call(self, inputs: tf.Tensor, residual: tf.Tensor, **kwargs: Any) -> tf.Tensor:
|
|
shortcut = tf.keras.layers.Conv1D(
|
|
filters=int(residual.shape[2]), strides=1, kernel_size=1, padding="SAME", use_bias=False
|
|
)(inputs)
|
|
|
|
output = tf.add(shortcut, residual)
|
|
|
|
return output
|