the-algorithm/src/java/com/twitter/search/common/query/FilteredQuery.java

226 lines
7.5 KiB
Java

package com.twitter.search.common.query;
import java.io.IOException;
import java.util.Set;
import com.google.common.base.Preconditions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
/**
* A pairing of a query and a filter. The hits traversal is driven by the query's DocIdSetIterator,
* and the filter is used only to do post-filtering. In other words, the filter is never used to
* find the next doc ID: it's only used to filter out the doc IDs returned by the query's
* DocIdSetIterator. This is useful when we need to have a conjunction between a query that can
* quickly iterate through doc IDs (eg. a posting list), and an expensive filter (eg. a filter based
* on the values stored in a CSF).
*
* For example, say we want to build a query that returns all docs that have at least 100 faves.
* 1. One option is to go with the [min_faves 100] query. This would be very expensive though,
* because this query would have to walk through every doc in the segment and for each one of
* them it would have to extract the number of faves from the forward index.
* 2. Another option is to go with a conjunction between this query and the HAS_ENGAGEMENT filter:
* (+[min_faves 100] +[cached_filter has_engagements]). The HAS_ENGAGEMENT filter could
* traverse the doc ID space faster (if it's backed by a posting list). But this approach would
* still be slow, because as soon as the HAS_ENGAGEMENT filter finds a doc ID, the conjunction
* scorer would trigger an advance(docID) call on the min_faves part of the query, which has
* the same problem as the first option.
* 3. Finally, a better option for this particular case would be to drive by the HAS_ENGAGEMENT
* filter (because it can quickly jump over all docs that do not have any engagement), and use
* the min_faves filter as a post-processing step, on a much smaller set of docs.
*/
public class FilteredQuery extends Query {
  /**
   * A doc ID predicate that determines if the given doc ID should be accepted.
   */
  @FunctionalInterface
  public interface DocIdFilter {
    /**
     * Determines if the given doc ID should be accepted.
     *
     * @param docId the doc ID to test
     * @return {@code true} to keep the doc ID, {@code false} to drop it
     * @throws IOException if the implementation fails while evaluating the doc ID
     */
    boolean accept(int docId) throws IOException;
  }

  /**
   * A factory for creating DocIdFilter instances based on a given LeafReaderContext instance.
   */
  @FunctionalInterface
  public interface DocIdFilterFactory {
    /**
     * Returns a DocIdFilter instance for the given LeafReaderContext instance.
     *
     * @param context the segment the returned filter will be applied to
     * @return the filter to apply to doc IDs of that segment
     * @throws IOException if the implementation fails while building the filter
     */
    DocIdFilter getDocIdFilter(LeafReaderContext context) throws IOException;
  }

  /**
   * Iterates over the doc IDs produced by the wrapped query iterator, skipping every doc ID that
   * the filter does not accept. The filter is never used to find the next candidate: only the
   * wrapped iterator moves the position.
   */
  private static class FilteredQueryDocIdSetIterator extends DocIdSetIterator {
    private final DocIdSetIterator queryScorerIterator;
    private final DocIdFilter docIdFilter;

    public FilteredQueryDocIdSetIterator(
        DocIdSetIterator queryScorerIterator, DocIdFilter docIdFilter) {
      this.queryScorerIterator = Preconditions.checkNotNull(queryScorerIterator);
      this.docIdFilter = Preconditions.checkNotNull(docIdFilter);
    }

    @Override
    public int docID() {
      // The position lives entirely in the wrapped iterator; this class keeps no state of its own.
      return queryScorerIterator.docID();
    }

    @Override
    public int nextDoc() throws IOException {
      int docId;
      do {
        docId = queryScorerIterator.nextDoc();
        // Stop on exhaustion without consulting the filter, otherwise keep going until accepted.
      } while (docId != NO_MORE_DOCS && !docIdFilter.accept(docId));
      return docId;
    }

    @Override
    public int advance(int target) throws IOException {
      int docId = queryScorerIterator.advance(target);
      if (docId == NO_MORE_DOCS || docIdFilter.accept(docId)) {
        return docId;
      }
      // The doc the query landed on was rejected: fall back to a linear scan for the next
      // accepted doc.
      return nextDoc();
    }

    @Override
    public long cost() {
      // Reports the wrapped iterator's cost unchanged; the filter is not factored in.
      return queryScorerIterator.cost();
    }
  }

  /**
   * A scorer that delegates scoring to the wrapped query scorer and exposes an iterator that
   * post-filters the wrapped scorer's hits.
   */
  private static class FilteredQueryScorer extends Scorer {
    private final Scorer queryScorer;
    private final DocIdFilter docIdFilter;

    public FilteredQueryScorer(Weight weight, Scorer queryScorer, DocIdFilter docIdFilter) {
      super(weight);
      this.queryScorer = Preconditions.checkNotNull(queryScorer);
      this.docIdFilter = Preconditions.checkNotNull(docIdFilter);
    }

    @Override
    public int docID() {
      return queryScorer.docID();
    }

    @Override
    public float score() throws IOException {
      return queryScorer.score();
    }

    @Override
    public DocIdSetIterator iterator() {
      return new FilteredQueryDocIdSetIterator(queryScorer.iterator(), docIdFilter);
    }

    @Override
    public float getMaxScore(int upTo) throws IOException {
      return queryScorer.getMaxScore(upTo);
    }
  }

  /**
   * The weight for a FilteredQuery: wraps the weight of the underlying query and creates, per
   * segment, a scorer that applies the doc ID filter on top of the underlying scorer.
   */
  private static class FilteredQueryWeight extends Weight {
    private final Weight queryWeight;
    private final DocIdFilterFactory docIdFilterFactory;

    public FilteredQueryWeight(
        FilteredQuery query, Weight queryWeight, DocIdFilterFactory docIdFilterFactory) {
      super(query);
      this.queryWeight = Preconditions.checkNotNull(queryWeight);
      this.docIdFilterFactory = Preconditions.checkNotNull(docIdFilterFactory);
    }

    @Override
    public void extractTerms(Set<Term> terms) {
      queryWeight.extractTerms(terms);
    }

    /**
     * Explains the score of the given doc. The explanation of the wrapped query is returned as is
     * when the wrapped query does not match, or when it matches and the filter accepts the doc.
     * When the wrapped query matches but the filter rejects the doc, a no-match explanation is
     * returned (with the wrapped explanation attached as a detail), so that the explanation
     * agrees with the hits produced by {@link #scorer(LeafReaderContext)}.
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
      Explanation queryExplanation = queryWeight.explain(context, doc);
      if (!queryExplanation.isMatch()) {
        return queryExplanation;
      }

      // Only consult the filter for docs the query matched, mirroring the iterator, which never
      // shows the filter a doc ID that did not come from the query. A fresh filter is requested
      // for this call so no state is shared with any scorer.
      DocIdFilter docIdFilter =
          Preconditions.checkNotNull(docIdFilterFactory.getDocIdFilter(context));
      if (docIdFilter.accept(doc)) {
        return queryExplanation;
      }
      return Explanation.noMatch(
          "doc rejected by filter " + docIdFilterFactory, queryExplanation);
    }

    /**
     * Returns a post-filtering scorer for the given segment, or null when the wrapped weight has
     * no scorer for it.
     */
    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
      Scorer queryScorer = queryWeight.scorer(context);
      if (queryScorer == null) {
        return null;
      }
      return new FilteredQueryScorer(this, queryScorer, docIdFilterFactory.getDocIdFilter(context));
    }

    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
      // NOTE(review): only the wrapped weight is consulted here; the filter factory is opaque, so
      // whether its output is safe to cache has to be confirmed by whoever supplies it.
      return queryWeight.isCacheable(ctx);
    }
  }

  private final Query query;
  private final DocIdFilterFactory docIdFilterFactory;

  /**
   * Creates a query that is driven by the given query and post-filtered by the filters produced
   * by the given factory.
   *
   * @param query the query whose iterator drives the traversal; must not be null
   * @param docIdFilterFactory creates the per-segment post-filter; must not be null
   */
  public FilteredQuery(Query query, DocIdFilterFactory docIdFilterFactory) {
    this.query = Preconditions.checkNotNull(query);
    this.docIdFilterFactory = Preconditions.checkNotNull(docIdFilterFactory);
  }

  /** Returns the wrapped query that drives the traversal. */
  public Query getQuery() {
    return query;
  }

  /**
   * Rewrites the wrapped query. Returns a new FilteredQuery around the rewritten query (with the
   * same filter factory) if the rewrite produced a different instance, and this query otherwise.
   */
  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    Query rewrittenQuery = query.rewrite(reader);
    if (rewrittenQuery != query) {
      return new FilteredQuery(rewrittenQuery, docIdFilterFactory);
    }
    return this;
  }

  @Override
  public int hashCode() {
    return query.hashCode() * 13 + docIdFilterFactory.hashCode();
  }

  /**
   * Two FilteredQuery instances are equal when both their wrapped queries and their filter
   * factories are equal.
   */
  @Override
  public boolean equals(Object obj) {
    if (!(obj instanceof FilteredQuery)) {
      return false;
    }

    FilteredQuery filteredQuery = (FilteredQuery) obj;
    return query.equals(filteredQuery.query)
        && docIdFilterFactory.equals(filteredQuery.docIdFilterFactory);
  }

  /**
   * Renders this query as {@code FilteredQuery(<query> -> <filter factory>)}. The wrapped query is
   * rendered with the same default field that was passed in.
   */
  @Override
  public String toString(String field) {
    return "FilteredQuery(" + query.toString(field) + " -> " + docIdFilterFactory + ")";
  }

  /**
   * Creates the weight of the wrapped query with the given parameters and wraps it in a weight
   * that applies the doc ID filter.
   */
  @Override
  public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
      throws IOException {
    Weight queryWeight = Preconditions.checkNotNull(query.createWeight(searcher, scoreMode, boost));
    return new FilteredQueryWeight(this, queryWeight, docIdFilterFactory);
  }
}