the-algorithm/pushservice/src/main/python/models/libs/graph_utils.py
twitter-team b389c3d302 Open-sourcing pushservice
Pushservice is the main recommendation service we use to surface recommendations to our users via notifications. It fetches candidates from various sources, ranks them in order of relevance, and applies filters to determine the best one to send.
2023-05-19 16:27:07 -05:00

43 lines
1.5 KiB
Python

"""
Utilities that aid in building the magic recs graph.
"""
import re
import tensorflow.compat.v1 as tf
def get_trainable_variables(all_trainable_variables, trainable_regexes):
    """Returns the subset of trainable variables whose names match a regex.

    Each variable's ``name`` is tested against every entry of
    ``trainable_regexes`` with ``re.match`` (anchored at the start of the
    name, case-insensitive). A variable is kept when at least one regex
    matches. The selected variables are logged via ``tf.logging.info``.

    Args:
      all_trainable_variables (a sized collection of trainable tf.Variable):
        The variables to search through.
      trainable_regexes (a collection of regexes): Variables that match any
        regex will be included. If ``None`` or empty, filtering is skipped.

    Returns:
      A list of the matching variables, in their original order, or ``None``
      when no regexes were supplied (signalling that the caller should not
      use the filtered-variables behavior).

    Raises:
      AssertionError: If any element of ``all_trainable_variables`` is not
        recognized by ``tf.is_tensor``, or if no variable matches any regex.
    """
    if not trainable_regexes:
        tf.logging.info("No trainable regexes found. Not using get_trainable_variables behavior.")
        return None

    # Every element must be a TF variable/tensor. The previous check used
    # `any`, which let non-TF objects through as long as one element was
    # valid, and which rejected an empty collection with a misleading message.
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not all(tf.is_tensor(var) for var in all_trainable_variables):
        raise AssertionError(f"Non TF variable found: {all_trainable_variables}")

    trainable_variables = [
        var
        for var in all_trainable_variables
        if any(re.match(regex, var.name, re.IGNORECASE) for regex in trainable_regexes)
    ]
    tf.logging.info(f"Using filtered trainable variables: {trainable_variables}")

    # An empty selection almost certainly means the regexes are wrong;
    # fail loudly instead of silently training nothing.
    if not trainable_variables:
        raise AssertionError(
            "Did not find trainable variables after filtering after filtering from {} number of vars originaly. All vars: {} and train regexes: {}".format(
                len(all_trainable_variables), all_trainable_variables, trainable_regexes
            )
        )
    return trainable_variables