import numpy as np
import tensorflow.compat.v1 as tf


class PartitionInitializer(tf.keras.initializers.Initializer):
    """Initialize a (possibly partitioned) weight from a fixed numpy array.

    Required in tests to give a partitioned variable known values: each
    partition receives the sub-block of ``np_array`` that starts at the
    partition's offset and spans the partition's shape. When called without
    partition information, the block starting at the origin is returned.
    """

    def __init__(self, np_array: np.ndarray):
        # Full (unpartitioned) value of the variable being initialized.
        self.np_array = np_array

    def __call__(self, shape, dtype=None, partition_info=None) -> np.ndarray:
        """Return the slice of the stored array for the requested partition.

        Args:
            shape: Shape of the (partition of the) variable being initialized.
            dtype: Optional target dtype. A TF ``DType``, numpy dtype, or
                dtype name; when None, the stored array's dtype is kept.
            partition_info: TF partition descriptor exposing ``var_offset``
                (this partition's start index along each dimension), or None
                for an unpartitioned variable.

        Returns:
            ``np_array[offset : offset + shape]`` along every dimension,
            cast to ``dtype`` when one is given.

        Raises:
            ValueError: If the requested window does not fit inside
                ``np_array`` (numpy slicing would otherwise silently clamp
                and return an array of the wrong shape).
        """
        dims = [int(d) for d in shape]
        if partition_info is None:
            # Unpartitioned call (e.g. the plain Keras path): no offset.
            offsets = [0] * len(dims)
        else:
            offsets = [int(o) for o in partition_info.var_offset]

        # One slice per dimension; identical to the former explicit
        # [ix0:ix1, iy0:iy1] for 2-D variables, and also handles 1-D / N-D.
        window = tuple(
            slice(start, start + size) for start, size in zip(offsets, dims)
        )
        values = self.np_array[window]
        if values.shape != tuple(dims):
            raise ValueError(
                f"Requested shape {tuple(dims)} at offset {tuple(offsets)} "
                f"does not fit in array of shape {self.np_array.shape}"
            )

        if dtype is not None:
            # TF passes tf.DType objects; map them to the numpy equivalent so
            # ndarray.astype can interpret them. Other dtype specs pass through.
            return values.astype(getattr(dtype, "as_numpy_dtype", dtype))
        return values