from collections import OrderedDict
import json
import os
from os.path import join

from twitter.magicpony.common import file_access
import twml

from .model_utils import read_config
import numpy as np
from scipy import stats
import tensorflow.compat.v1 as tf

# checkstyle: noqa


def get_model_type_to_tensors_to_change_axis():
  model_type_to_tensors_to_change_axis = {
    "magic_recs/model/batch_normalization/beta": ([0], "continuous"),
    "magic_recs/model/batch_normalization/gamma": ([0], "continuous"),
    "magic_recs/model/batch_normalization/moving_mean": ([0], "continuous"),
    "magic_recs/model/batch_normalization/moving_stddev": ([0], "continuous"),
    "magic_recs/model/batch_normalization/moving_variance": ([0], "continuous"),
    "magic_recs/model/batch_normalization/renorm_mean": ([0], "continuous"),
    "magic_recs/model/batch_normalization/renorm_stddev": ([0], "continuous"),
    "magic_recs/model/logits/EngagementGivenOONC_logits/clem_net_1/block2_4/channel_wise_dense_4/kernel": (
      [1],
      "all",
    ),
    "magic_recs/model/logits/OONC_logits/clem_net/block2/channel_wise_dense/kernel": ([1], "all"),
  }

  return model_type_to_tensors_to_change_axis


def mkdirp(dirname):
  if not tf.io.gfile.exists(dirname):
    tf.io.gfile.makedirs(dirname)


def rename_dir(dirname, dst):
  file_access.hdfs.mv(dirname, dst)


def rmdir(dirname):
  if tf.io.gfile.exists(dirname):
    if tf.io.gfile.isdir(dirname):
      tf.io.gfile.rmtree(dirname)
    else:
      tf.io.gfile.remove(dirname)


def get_var_dict(checkpoint_path):
  """Load every variable in the checkpoint into an OrderedDict keyed by variable name."""
  var_dict = OrderedDict()
  all_var_list = tf.train.list_variables(checkpoint_path)
  for var_name, _ in all_var_list:
    # tf.train.load_variable reads directly from the checkpoint; no session is needed.
    var_dict[var_name] = tf.train.load_variable(checkpoint_path, var_name)
  return var_dict


def get_continuous_mapping_from_feat_list(old_feature_list, new_feature_list):
  """
  For every feature present in both lists, return its index in the old list and the
  corresponding index in the new list, as two aligned integer arrays
  (old_var_ind, new_var_ind).
  """
  new_var_ind, old_var_ind = [], []
  for this_new_id, this_new_name in enumerate(new_feature_list):
    if this_new_name in old_feature_list:
      this_old_id = old_feature_list.index(this_new_name)
      new_var_ind.append(this_new_id)
      old_var_ind.append(this_old_id)
  return np.asarray(old_var_ind), np.asarray(new_var_ind)


def get_continuous_mapping_from_feat_dict(old_feature_dict, new_feature_dict):
  """
  Build index mappings between the old and new feature spaces: one for the continuous
  features alone ("continuous"), and one for the concatenation of continuous, binary,
  and dummy sparse features ("all").
  """
  old_cont = old_feature_dict["continuous"]
  old_bin = old_feature_dict["binary"]

  new_cont = new_feature_dict["continuous"]
  new_bin = new_feature_dict["binary"]

  _dummy_sparse_feat = [f"sparse_feature_{_idx}" for _idx in range(100)]

  cont_old_var_ind, cont_new_var_ind = get_continuous_mapping_from_feat_list(old_cont, new_cont)
  all_old_var_ind, all_new_var_ind = get_continuous_mapping_from_feat_list(
    old_cont + old_bin + _dummy_sparse_feat, new_cont + new_bin + _dummy_sparse_feat
  )

  _res = {
    "continuous": (cont_old_var_ind, cont_new_var_ind),
    "all": (all_old_var_ind, all_new_var_ind),
  }

  return _res
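
def _example_feature_mapping():
  """
  Illustrative sketch only, not used in production: shows what the mapping helpers
  above return for two made-up feature lists. Features present in both lists keep
  their learned weights during warm start; everything else is re-initialized.
  """
  old = ["feat_a", "feat_b", "feat_c"]
  new = ["feat_b", "feat_c", "feat_d"]
  old_ind, new_ind = get_continuous_mapping_from_feat_list(old, new)
  # old_ind == array([1, 2]): positions of the shared features in the old list.
  # new_ind == array([0, 1]): positions of the same features in the new list.
  return old_ind, new_ind
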
def warm_start_from_var_dict(
  old_ckpt_path,
  var_ind_dict,
  output_dir,
  new_len_var,
  var_to_change_dict_fn=get_model_type_to_tensors_to_change_axis,
):
  """
  Create a new checkpoint in output_dir by copying the old checkpoint and resizing
  the variables listed by var_to_change_dict_fn along the given axes.

  Parameters:
    old_ckpt_path (str): path to the old checkpoint.
    var_ind_dict ({str: (array, array)}): for each variable type ("continuous" or
      "all"), a pair (old_var_ind, new_var_ind) of aligned index arrays for the
      features shared by the old and new feature lists.
    output_dir (str): directory where the modified checkpoint is written.
    new_len_var ({str: int}): number of features in the new feature list, per
      variable type.
    var_to_change_dict_fn (callable): a function that returns a dictionary of the
      format {var_name: (dims_to_change, var_type)}.

  Returns:
    path to the modified checkpoint.
  """
  old_var_dict = get_var_dict(old_ckpt_path)

  ckpt_file_name = os.path.basename(old_ckpt_path)
  mkdirp(output_dir)
  output_path = join(output_dir, ckpt_file_name)

  tensors_to_change = var_to_change_dict_fn()
  tf.reset_default_graph()

  with tf.Session() as sess:
    var_name_shape_list = tf.train.list_variables(old_ckpt_path)
    count = 0
    for var_name, _ in var_name_shape_list:
      old_var = old_var_dict[var_name]
      if var_name in tensors_to_change:
        dims_to_remove_from, var_type = tensors_to_change[var_name]
        old_var_ind, new_var_ind = var_ind_dict[var_type]

        # Resize the variable along the axes that depend on the feature count.
        this_shape = list(old_var.shape)
        for this_dim in dims_to_remove_from:
          this_shape[this_dim] = new_len_var[var_type]

        # Initialize the resized variable from a truncated normal whose scale
        # matches the old variable, then copy over the weights of the features
        # shared by both feature lists.
        stddev = np.std(old_var)
        truncated_norm_generator = stats.truncnorm(-0.5, 0.5, loc=0, scale=stddev)
        size = np.prod(this_shape)
        new_var = truncated_norm_generator.rvs(size).reshape(this_shape)
        new_var = new_var.astype(old_var.dtype)

        new_var = copy_feat_based_on_mapping(
          new_var, old_var, dims_to_remove_from, new_var_ind, old_var_ind
        )
        count = count + 1
      else:
        new_var = old_var
      var = tf.Variable(new_var, name=var_name)

    assert count == len(tensors_to_change), "not all variables are exchanged.\n"
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, output_path)

  return output_path


def copy_feat_based_on_mapping(new_array, old_array, dims_to_remove_from, new_var_ind, old_var_ind):
  """
  Copy the entries of old_array indexed by old_var_ind into the positions of
  new_array indexed by new_var_ind, along the axes in dims_to_remove_from.
  """
  if dims_to_remove_from == [0, 1]:
    for this_new_ind, this_old_ind in zip(new_var_ind, old_var_ind):
      new_array[this_new_ind, new_var_ind] = old_array[this_old_ind, old_var_ind]
  elif dims_to_remove_from == [0]:
    new_array[new_var_ind] = old_array[old_var_ind]
  elif dims_to_remove_from == [1]:
    new_array[:, new_var_ind] = old_array[:, old_var_ind]
  else:
    raise RuntimeError(f"undefined dims_to_remove_from pattern: ({dims_to_remove_from})")
  return new_array
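
def _example_copy_feat_based_on_mapping():
  """
  Illustrative sketch only, not used in production: demonstrates the axis-0 case of
  copy_feat_based_on_mapping with tiny made-up arrays. Rows 0 and 1 of the new array
  receive rows 1 and 2 of the old array; row 2 keeps its (here zeroed) initialization,
  standing in for the truncated-normal init used in warm_start_from_var_dict.
  """
  old_array = np.asarray([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
  new_array = np.zeros_like(old_array)
  new_var_ind = np.asarray([0, 1])
  old_var_ind = np.asarray([1, 2])
  result = copy_feat_based_on_mapping(new_array, old_array, [0], new_var_ind, old_var_ind)
  # result == [[2., 2.], [3., 3.], [0., 0.]]
  return result
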
""" graph = tf.Graph() with graph.as_default(): read = tf.read_file(filename) with tf.Session(graph=graph) as sess: contents = sess.run(read) if not isinstance(contents, str): contents = contents.decode() if decode == "json": contents = json.loads(contents) return contents def read_feat_list_from_disk(file_path): return read_file(file_path, decode="json") def get_feature_list_for_light_ranking(feature_list_path, data_spec_path): feature_list = read_config(feature_list_path).items() string_feat_list = [f[0] for f in feature_list if f[1] != "S"] feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder( data_spec_path=data_spec_path ) feature_config_builder = feature_config_builder.extract_feature_group( feature_regexes=string_feat_list, group_name="continuous", default_value=-1, type_filter=["CONTINUOUS"], ) feature_config = feature_config_builder.build() feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[ "CONTINUOUS" ] return feature_list def get_feature_list_for_heavy_ranking(feature_list_path, data_spec_path): feature_list = read_config(feature_list_path).items() string_feat_list = [f[0] for f in feature_list if f[1] != "S"] feature_config_builder = twml.contrib.feature_config.FeatureConfigBuilder( data_spec_path=data_spec_path ) feature_config_builder = feature_config_builder.extract_feature_group( feature_regexes=string_feat_list, group_name="continuous", default_value=-1, type_filter=["CONTINUOUS"], ) feature_config_builder = feature_config_builder.extract_feature_group( feature_regexes=string_feat_list, group_name="binary", default_value=False, type_filter=["BINARY"], ) feature_config_builder = feature_config_builder.build() continuous_feature_list = feature_config_builder._feature_group_extraction_configs[0].feature_map[ "CONTINUOUS" ] binary_feature_list = feature_config_builder._feature_group_extraction_configs[1].feature_map[ "BINARY" ] return {"continuous": continuous_feature_list, "binary": binary_feature_list} def warm_start_checkpoint( old_best_ckpt_folder, old_feature_list_path, feature_allow_list_path, data_spec_path, output_ckpt_folder, *args, ): """ Reads old checkpoint and the old feature list, and create a new ckpt warm started from old ckpt using new features . Arguments: old_best_ckpt_folder: path to the best_checkpoint_folder for old model old_feature_list_path: path to the json file that stores the list of continuous features used in old models. feature_allow_list_path: yaml file that contain the feature allow list. data_spec_path: path to the data_spec file output_ckpt_folder: folder that contains the modified ckpt. Returns: path to the modified ckpt.""" old_ckpt_path = tf.train.latest_checkpoint(old_best_ckpt_folder, latest_filename=None) new_feature_dict = get_feature_list(feature_allow_list_path, data_spec_path) old_feature_dict = read_feat_list_from_disk(old_feature_list_path) var_ind_dict = get_continuous_mapping_from_feat_dict(new_feature_dict, old_feature_dict) new_len_var = { "continuous": len(new_feature_dict["continuous"]), "all": len(new_feature_dict["continuous"] + new_feature_dict["binary"]) + 100, } warm_started_ckpt_path = warm_start_from_var_dict( old_ckpt_path, var_ind_dict, output_dir=output_ckpt_folder, new_len_var=new_len_var, ) return warm_started_ckpt_path