Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get better cost estimate on MultiTermQuery over few terms #13201

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,38 +154,17 @@ protected abstract WeightOrDocIdSetIterator rewriteInner(
List<TermAndState> collectedTerms)
throws IOException;

private IOSupplier<WeightOrDocIdSetIterator> rewrite(LeafReaderContext context, Terms terms)
throws IOException {
assert terms != null;

final int fieldDocCount = terms.getDocCount();
final TermsEnum termsEnum = q.getTermsEnum(terms);
assert termsEnum != null;

final List<TermAndState> collectedTerms = new ArrayList<>();
boolean collectResult = collectTerms(fieldDocCount, termsEnum, collectedTerms);
if (collectResult && collectedTerms.isEmpty()) {
return null;
private WeightOrDocIdSetIterator rewriteAsBooleanQuery(
LeafReaderContext context, List<TermAndState> collectedTerms) throws IOException {
BooleanQuery.Builder bq = new BooleanQuery.Builder();
for (TermAndState t : collectedTerms) {
final TermStates termStates = new TermStates(searcher.getTopReaderContext());
termStates.register(t.state, context.ord, t.docFreq, t.totalTermFreq);
bq.add(new TermQuery(new Term(q.field, t.term), termStates), BooleanClause.Occur.SHOULD);
}
return () -> {
if (collectResult) {
// build a boolean query
BooleanQuery.Builder bq = new BooleanQuery.Builder();
for (TermAndState t : collectedTerms) {
final TermStates termStates = new TermStates(searcher.getTopReaderContext());
termStates.register(t.state, context.ord, t.docFreq, t.totalTermFreq);
bq.add(
new TermQuery(new Term(q.field, t.term), termStates), BooleanClause.Occur.SHOULD);
}
Query q = new ConstantScoreQuery(bq.build());
final Weight weight = searcher.rewrite(q).createWeight(searcher, scoreMode, score());
return new WeightOrDocIdSetIterator(weight);
} else {
// Too many terms to rewrite as a simple bq.
// Invoke rewriteInner logic to handle rewriting:
return rewriteInner(context, fieldDocCount, terms, termsEnum, collectedTerms);
}
};
Query q = new ConstantScoreQuery(bq.build());
final Weight weight = searcher.rewrite(q).createWeight(searcher, scoreMode, score());
return new WeightOrDocIdSetIterator(weight);
}

private boolean collectTerms(int fieldDocCount, TermsEnum termsEnum, List<TermAndState> terms)
Expand Down Expand Up @@ -240,9 +219,38 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
return null;
}

final long cost = estimateCost(terms, q.getTermsCount());
IOSupplier<WeightOrDocIdSetIterator> weightOrIteratorSupplier = rewrite(context, terms);
if (weightOrIteratorSupplier == null) return null;
assert terms != null;

final int fieldDocCount = terms.getDocCount();
final TermsEnum termsEnum = q.getTermsEnum(terms);
assert termsEnum != null;

List<TermAndState> collectedTerms = new ArrayList<>();
boolean collectResult = collectTerms(fieldDocCount, termsEnum, collectedTerms);

if (collectResult && collectedTerms.isEmpty()) return null;

final long cost;
if (collectResult) {
long sumTermCost = 0;
for (TermAndState collectedTerm : collectedTerms) {
sumTermCost += collectedTerm.docFreq;
}
cost = sumTermCost;
} else {
cost = estimateCost(terms, q.getTermsCount());
}

IOSupplier<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
() -> {
if (collectResult) {
return rewriteAsBooleanQuery(context, collectedTerms);
} else {
// Too many terms to rewrite as a simple bq.
// Invoke rewriteInner logic to handle rewriting:
return rewriteInner(context, fieldDocCount, terms, termsEnum, collectedTerms);
}
};

return new ScorerSupplier() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.document.Field;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
Expand Down Expand Up @@ -418,4 +419,42 @@ public void testLarge() throws IOException {
reader.close();
dir.close();
}

public void testCostEstimate() throws IOException {
  Directory directory = newDirectory();
  RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
  // Index 1000 copies each of "foo bar" and "foo wuzzle" (so "foo*" expands to
  // exactly two terms, 1000 docs apiece) plus 1000 distinct "bar <i>" docs
  // (so "bar*" expands to 1001 terms across 2000 docs).
  for (int docId = 0; docId < 1000; docId++) {
    Document first = new Document();
    first.add(newStringField("body", "foo bar", Field.Store.NO));
    indexWriter.addDocument(first);
    Document second = new Document();
    second.add(newStringField("body", "foo wuzzle", Field.Store.NO));
    indexWriter.addDocument(second);
    Document third = new Document();
    third.add(newStringField("body", "bar " + docId, Field.Store.NO));
    indexWriter.addDocument(third);
  }
  indexWriter.flush();
  // Force a single segment so there is exactly one leaf to inspect.
  indexWriter.forceMerge(1);
  indexWriter.close();

  IndexReader indexReader = DirectoryReader.open(directory);
  IndexSearcher indexSearcher = newSearcher(indexReader);
  LeafReaderContext leafContext = indexReader.leaves().get(0);

  // Few expanded terms: cost should be the sum of the collected terms' doc freqs.
  WildcardQuery fooQuery = new WildcardQuery(new Term("body", "foo*"));
  Query fooRewritten = indexSearcher.rewrite(fooQuery);
  Weight fooWeight =
      fooRewritten.createWeight(indexSearcher, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
  ScorerSupplier fooSupplier = fooWeight.scorerSupplier(leafContext);
  assertEquals(2000, fooSupplier.cost()); // Sum the terms doc freqs

  // Too many expanded terms to collect: falls back to the worst-case estimate.
  WildcardQuery barQuery = new WildcardQuery(new Term("body", "bar*"));
  Query barRewritten = indexSearcher.rewrite(barQuery);
  Weight barWeight =
      barRewritten.createWeight(indexSearcher, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
  ScorerSupplier barSupplier = barWeight.scorerSupplier(leafContext);
  assertEquals(3000, barSupplier.cost()); // Too many terms, assume worst-case all terms match

  indexReader.close();
  directory.close();
}
}