mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-10 13:18:47 +02:00
ef4c5eb65e
Please note we have force-pushed a new initial commit in order to remove some publicly-available Twitter user information. Note that this process may be required in the future.
35 lines
1.4 KiB
Python
35 lines
1.4 KiB
Python
from .hashing_utils import make_feature_id, numpy_hashing_uniform
|
|
|
|
import numpy as np
|
|
import tensorflow.compat.v1 as tf
|
|
import twml
|
|
|
|
|
|
class TFModelWeightsInitializerBuilder(object):
    '''
    Turns a model-initializer mapping into a (bias, weights) initializer pair.

    The weights are staged in a zero-filled numpy table with
    ``2 ** num_bits`` rows and a single column; individual rows are then
    overwritten with the per-feature weights found in the mapping.
    '''

    def __init__(self, num_bits):
        '''
        :param num_bits: exponent of the weight table size (the table has
          ``2 ** num_bits`` rows); also forwarded to the hashing helpers.
        '''
        self.num_bits = num_bits

    def build(self, tf_model_initializer):
        '''
        Create the initializers described by ``tf_model_initializer``.

        :param tf_model_initializer: mapping with a "features" entry, which
          itself provides the "binary", "discretized" and "bias" entries.
        :return: (bias_initializer, weight_initializer)
        '''
        # Start from an all-zero column vector; rows that are never written
        # by the setters below stay at zero.
        weight_table = np.zeros((2 ** self.num_bits, 1))

        feature_spec = tf_model_initializer["features"]
        self._set_binary_feature_weights(weight_table, feature_spec["binary"])
        self._set_discretized_feature_weights(weight_table, feature_spec["discretized"])

        bias_initializer = tf.constant_initializer(feature_spec["bias"])
        weight_initializer = twml.contrib.initializers.PartitionConstant(weight_table)
        return bias_initializer, weight_initializer

    def _set_binary_feature_weights(self, initial_weights, binary_features):
        '''
        Write one weight per binary feature into ``initial_weights`` in place.

        :param initial_weights: weight table, indexed as ``[row][0]``.
        :param binary_features: mapping of feature name -> weight.
        '''
        for name, value in binary_features.items():
            # The row is selected by the feature id derived from the name.
            row = initial_weights[make_feature_id(name, self.num_bits)]
            row[0] = value

    def _set_discretized_feature_weights(self, initial_weights, discretized_features):
        '''
        Write one weight per bin of every discretized feature, in place.

        :param initial_weights: weight table, indexed as ``[row][0]``.
        :param discretized_features: mapping of feature name -> mapping whose
          "weights" entry is iterated in order, one weight per bin index.
        '''
        for name, spec in discretized_features.items():
            base_id = make_feature_id(name, self.num_bits)
            for bin_idx, value in enumerate(spec["weights"]):
                # Each (feature id, bin index) pair is hashed to its own row.
                bucket = numpy_hashing_uniform(base_id, bin_idx, self.num_bits)
                initial_weights[bucket][0] = value
|