mirror of
https://github.com/twitter/the-algorithm.git
synced 2024-06-02 17:28:45 +02:00
112 lines
3.7 KiB
Java
112 lines
3.7 KiB
Java
package com.twitter.search.common.util.ml.prediction_engine;
|
|
|
|
import java.util.Collection;
|
|
import java.util.Comparator;
|
|
import java.util.List;
|
|
|
|
import com.google.common.collect.Lists;
|
|
|
|
import com.twitter.ml.api.FeatureParser;
|
|
import com.twitter.ml.api.transform.DiscretizerTransform;
|
|
import com.twitter.ml.tool.prediction.ModelInterpreter;
|
|
|
|
/**
|
|
* The base model builder for LightweightLinearModels.
|
|
*/
|
|
public abstract class BaseModelBuilder implements ModelBuilder {
|
|
// Ignore features that have an absolute weight lower than this value
|
|
protected static final double MIN_WEIGHT = 1e-9;
|
|
private static final String BIAS_FIELD_NAME = ModelInterpreter.BIAS_FIELD_NAME;
|
|
static final String DISCRETIZER_NAME_SUFFIX =
|
|
"." + DiscretizerTransform.DEFAULT_FEATURE_NAME_SUFFIX;
|
|
|
|
protected final String modelName;
|
|
protected double bias;
|
|
|
|
public BaseModelBuilder(String modelName) {
|
|
this.modelName = modelName;
|
|
this.bias = 0.0;
|
|
}
|
|
|
|
/**
|
|
* Collects all the ranges of a discretized feature and sorts them.
|
|
*/
|
|
static DiscretizedFeature buildFeature(Collection<DiscretizedFeatureRange> ranges) {
|
|
List<DiscretizedFeatureRange> sortedRanges = Lists.newArrayList(ranges);
|
|
sortedRanges.sort(Comparator.comparingDouble(a -> a.minValue));
|
|
|
|
double[] splits = new double[ranges.size()];
|
|
double[] weights = new double[ranges.size()];
|
|
|
|
for (int i = 0; i < sortedRanges.size(); i++) {
|
|
splits[i] = sortedRanges.get(i).minValue;
|
|
weights[i] = sortedRanges.get(i).weight;
|
|
}
|
|
return new DiscretizedFeature(splits, weights);
|
|
}
|
|
|
|
/**
|
|
* Parses a line from the interpreted model text file. See the javadoc of the constructor for
|
|
* more details about how to create the text file.
|
|
* <p>
|
|
* The file uses TSV format with 3 columns:
|
|
* <p>
|
|
* Model name (Generated by ML API, but ignored by this class)
|
|
* Feature definition:
|
|
* Name of the feature or definition from the MDL discretizer
|
|
* Weight:
|
|
* Weight of the feature using LOGIT scale.
|
|
* <p>
|
|
* When it parses each line, it stores the weights for all the features defined in the context,
|
|
* as well as the bias, but it ignores any other feature (e.g. label, prediction or
|
|
* meta.record_weight) and features with a small absolute weight (see MIN_WEIGHT).
|
|
* <p>
|
|
* Example lines:
|
|
* <p>
|
|
* model_name bias 0.019735312089324074
|
|
* model_name demo.binary_feature 0.06524706073105327
|
|
* model_name demo.continuous_feature 0.0
|
|
* model_name demo.continuous_feature.dz/dz_model=mdl/dz_range=-inf_3.58e-01 0.07155931927263737
|
|
* model_name demo.continuous_feature.dz/dz_model=mdl/dz_range=3.58e-01_inf -0.08979256264865387
|
|
*
|
|
* @see ModelInterpreter
|
|
* @see DiscretizerTransform
|
|
*/
|
|
@Override
|
|
public ModelBuilder parseLine(String line) {
|
|
String[] columns = line.split("\t");
|
|
if (columns.length != 3) {
|
|
return this;
|
|
}
|
|
|
|
// columns[0] has the model name, which we don't need
|
|
String featureName = columns[1];
|
|
double weight = Double.parseDouble(columns[2]);
|
|
|
|
if (BIAS_FIELD_NAME.equals(featureName)) {
|
|
bias = weight;
|
|
return this;
|
|
}
|
|
|
|
FeatureParser parser = FeatureParser.parse(featureName);
|
|
String baseName = parser.getBaseName();
|
|
|
|
if (Math.abs(weight) < MIN_WEIGHT && !baseName.endsWith(DISCRETIZER_NAME_SUFFIX)) {
|
|
// skip, unless it represents a range of a discretized feature.
|
|
// discretized features with all zeros should also be removed, but will handle that later
|
|
return this;
|
|
}
|
|
|
|
addFeature(baseName, weight, parser);
|
|
return this;
|
|
}
|
|
|
|
/**
|
|
* Adds feature to the model
|
|
*/
|
|
protected abstract void addFeature(String baseName, double weight, FeatureParser parser);
|
|
|
|
@Override
|
|
public abstract LightweightLinearModel build();
|
|
}
|