the-algorithm/src/java/com/twitter/search/common/util/ml/prediction_engine/BaseModelBuilder.java
2023-03-31 22:16:43 -04:00

112 lines
3.7 KiB
Java

package com.twitter.search.common.util.ml.prediction_engine;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import com.google.common.collect.Lists;
import com.twitter.ml.api.FeatureParser;
import com.twitter.ml.api.transform.DiscretizerTransform;
import com.twitter.ml.tool.prediction.ModelInterpreter;
/**
 * The base model builder for LightweightLinearModels.
 * <p>
 * Lines of the interpreted model text are fed one at a time to {@link #parseLine(String)}. That
 * method records the bias itself and forwards every other retained feature weight to the
 * subclass through {@link #addFeature(String, double, FeatureParser)}. The subclass then
 * assembles the final model in {@link #build()}.
 */
public abstract class BaseModelBuilder implements ModelBuilder {
  // Ignore features that have an absolute weight lower than this value
  protected static final double MIN_WEIGHT = 1e-9;
  private static final String BIAS_FIELD_NAME = ModelInterpreter.BIAS_FIELD_NAME;
  // Suffix of a feature base name that marks it as one range of a discretized feature.
  static final String DISCRETIZER_NAME_SUFFIX =
      "." + DiscretizerTransform.DEFAULT_FEATURE_NAME_SUFFIX;

  protected final String modelName;
  protected double bias;

  /**
   * Creates an empty builder for the given model.
   * <p>
   * The bias starts at 0.0 and is overwritten when {@link #parseLine(String)} encounters the
   * bias line.
   *
   * @param modelName name of the model being built
   */
  public BaseModelBuilder(String modelName) {
    this.modelName = modelName;
    this.bias = 0.0;
  }

  /**
   * Collects all the ranges of a discretized feature and sorts them by their minimum value.
   *
   * @param ranges the ranges of a single discretized feature, in any order
   * @return a feature whose i-th split is the minimum value of the i-th smallest range and whose
   *     i-th weight is the weight of that same range
   */
  static DiscretizedFeature buildFeature(Collection<DiscretizedFeatureRange> ranges) {
    List<DiscretizedFeatureRange> sortedRanges = Lists.newArrayList(ranges);
    sortedRanges.sort(Comparator.comparingDouble(range -> range.minValue));
    // Size the arrays from the list that is actually iterated, so splits/weights always line up
    // with the sorted ranges.
    int count = sortedRanges.size();
    double[] splits = new double[count];
    double[] weights = new double[count];
    for (int i = 0; i < count; i++) {
      DiscretizedFeatureRange range = sortedRanges.get(i);
      splits[i] = range.minValue;
      weights[i] = range.weight;
    }
    return new DiscretizedFeature(splits, weights);
  }

  /**
   * Parses a line from the interpreted model text file (see {@link ModelInterpreter}).
   * <p>
   * The file uses TSV format with 3 columns:
   * <ul>
   *   <li>Model name (generated by ML API, but ignored by this class)</li>
   *   <li>Feature definition: name of the feature, or the definition produced by the MDL
   *       discretizer</li>
   *   <li>Weight: weight of the feature using LOGIT scale</li>
   * </ul>
   * Lines that do not split into exactly 3 tab-separated columns are ignored.
   * <p>
   * When the feature is the bias, its weight is stored as the model bias. Every other feature
   * is passed to {@link #addFeature(String, double, FeatureParser)}. The exception is a feature
   * whose absolute weight is below {@link #MIN_WEIGHT}; it is skipped unless it is a range of a
   * discretized feature. Any further filtering (e.g. of label, prediction or
   * meta.record_weight) is left to the subclass.
   * <p>
   * Example lines:
   * <pre>
   * model_name  bias  0.019735312089324074
   * model_name  demo.binary_feature  0.06524706073105327
   * model_name  demo.continuous_feature  0.0
   * model_name  demo.continuous_feature.dz/dz_model=mdl/dz_range=-inf_3.58e-01  0.07155931927263737
   * model_name  demo.continuous_feature.dz/dz_model=mdl/dz_range=3.58e-01_inf  -0.08979256264865387
   * </pre>
   *
   * @param line one line of the interpreted model text file
   * @return this builder, to allow chaining
   * @throws NumberFormatException if the line has 3 columns but the third is not a valid double
   * @see ModelInterpreter
   * @see DiscretizerTransform
   */
  @Override
  public ModelBuilder parseLine(String line) {
    String[] columns = line.split("\t");
    if (columns.length != 3) {
      return this;
    }

    // columns[0] has the model name, which we don't need
    String featureName = columns[1];
    double weight = Double.parseDouble(columns[2]);

    if (BIAS_FIELD_NAME.equals(featureName)) {
      bias = weight;
      return this;
    }

    FeatureParser parser = FeatureParser.parse(featureName);
    String baseName = parser.getBaseName();
    if (Math.abs(weight) < MIN_WEIGHT && !baseName.endsWith(DISCRETIZER_NAME_SUFFIX)) {
      // skip, unless it represents a range of a discretized feature.
      // discretized features with all zeros should also be removed, but will handle that later
      return this;
    }

    addFeature(baseName, weight, parser);
    return this;
  }

  /**
   * Adds a feature to the model.
   * <p>
   * Called by {@link #parseLine(String)} for every non-bias feature it keeps. For ranges of a
   * discretized feature the weight may be below {@link #MIN_WEIGHT} (including zero).
   *
   * @param baseName base name of the feature as returned by {@link FeatureParser#getBaseName()}
   * @param weight weight of the feature (LOGIT scale)
   * @param parser the parsed feature definition, giving access to the rest of the feature name
   */
  protected abstract void addFeature(String baseName, double weight, FeatureParser parser);

  /**
   * Builds the model from everything accumulated by {@link #parseLine(String)}.
   *
   * @return the built model
   */
  @Override
  public abstract LightweightLinearModel build();
}