2023-04-17 06:19:03 +02:00
|
|
|
import numpy as np
|
2023-04-01 00:36:31 +02:00
|
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
|
|
|
|
|
|
|
|
class PartitionInitializer(tf.keras.initializers.Initializer):
    """Initialize a (possibly partitioned) variable from a fixed numpy array.

    Required to initialize partitioned weights with a numpy array in tests:
    each partition receives the slice of ``np_array`` that corresponds to its
    offset and shape within the full variable.
    """

    def __init__(self, np_array: np.ndarray):
        """Store the value of the complete variable.

        Args:
            np_array: Value of the full (unpartitioned) variable. Array-likes
                are converted with ``np.asarray``; an ndarray is stored as-is
                (same object, no copy).
        """
        super().__init__()
        self.np_array = np.asarray(np_array)

    def __call__(self, shape, dtype=None, partition_info=None) -> np.ndarray:
        """Return the slice of ``np_array`` that initializes one partition.

        Args:
            shape: Shape of the requested variable (or variable partition).
            dtype: Optional target dtype; a numpy dtype/string or a TF
                ``DType`` (converted via its ``as_numpy_dtype``).
            partition_info: Partition descriptor whose ``var_offset`` gives the
                start index along every axis. ``None`` means the variable is
                not partitioned, so the slice starts at the origin.

        Returns:
            ``np_array[offset_i : offset_i + shape_i]`` along every axis,
            cast to ``dtype`` when one is given.

        Raises:
            ValueError: If ``np_array`` is too small to provide a slice of the
                requested shape at the given offset.
        """
        sizes = tuple(int(size) for size in shape)
        if partition_info is None:
            # Unpartitioned variable: the single "partition" starts at the origin.
            offsets = (0,) * len(sizes)
        else:
            offsets = tuple(int(start) for start in partition_info.var_offset)

        # One slice per axis, so any rank works (not only 2-D weights).
        window = tuple(
            slice(start, start + size) for start, size in zip(offsets, sizes)
        )
        chunk = self.np_array[window]

        # Numpy silently truncates out-of-range slices; fail loudly instead of
        # handing TF an array of the wrong shape.
        if tuple(chunk.shape) != sizes:
            raise ValueError(
                f"np_array of shape {self.np_array.shape} cannot provide a "
                f"slice of shape {sizes} at offset {offsets}"
            )

        if dtype is not None:
            # TF passes a tf.DType here; map it to the equivalent numpy dtype.
            return chunk.astype(getattr(dtype, "as_numpy_dtype", dtype))
        return chunk
|