Skip to content

Commit 6b63b13

Browse files
authored
Adds new ESAcceptDocs class and usage, allowing for future use in knn searching (#137750)
This adds a new ESAcceptDocs implementation that doesn't fully construct the underlying bit set unless absolutely necessary, allowing an approximate cost to be calculated cheaply. It also exposes the underlying BitSet because, for some usages, being able to call `nextSetBit` is extremely helpful.
1 parent 15f139b commit 6b63b13

File tree

6 files changed

+419
-9
lines changed

6 files changed

+419
-9
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.index.SegmentReadState;
1414
import org.apache.lucene.index.VectorEncoding;
1515
import org.apache.lucene.index.VectorSimilarityFunction;
16+
import org.apache.lucene.search.AcceptDocs;
1617
import org.apache.lucene.search.KnnCollector;
1718
import org.apache.lucene.store.IndexInput;
1819
import org.apache.lucene.util.Bits;
@@ -88,6 +89,7 @@ public CentroidIterator getCentroidIterator(
8889
IndexInput centroids,
8990
float[] targetQuery,
9091
IndexInput postingListSlice,
92+
AcceptDocs acceptDocs,
9193
float visitRatio
9294
) throws IOException {
9395
final FieldEntry fieldEntry = fields.get(fieldInfo.number);

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IVFVectorsReader.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.apache.lucene.util.Bits;
3232
import org.elasticsearch.core.IOUtils;
3333
import org.elasticsearch.index.codec.vectors.GenericFlatVectorReaders;
34+
import org.elasticsearch.search.vectors.ESAcceptDocs;
3435
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3536

3637
import java.io.Closeable;
@@ -114,6 +115,7 @@ public abstract CentroidIterator getCentroidIterator(
114115
IndexInput centroids,
115116
float[] target,
116117
IndexInput postingListSlice,
118+
AcceptDocs acceptDocs,
117119
float visitRatio
118120
) throws IOException;
119121

@@ -283,8 +285,17 @@ public final void search(String field, float[] target, KnnCollector knnCollector
283285
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
284286
);
285287
}
288+
final ESAcceptDocs esAcceptDocs;
289+
if (acceptDocs instanceof ESAcceptDocs) {
290+
esAcceptDocs = (ESAcceptDocs) acceptDocs;
291+
} else {
292+
esAcceptDocs = null;
293+
}
286294
int numVectors = getReaderForField(field).getFloatVectorValues(field).size();
287-
float percentFiltered = Math.max(0f, Math.min(1f, (float) acceptDocs.cost() / numVectors));
295+
float percentFiltered = Math.max(
296+
0f,
297+
Math.min(1f, (float) (esAcceptDocs == null ? acceptDocs.cost() : esAcceptDocs.approximateCost()) / numVectors)
298+
);
288299
float visitRatio = DYNAMIC_VISIT_RATIO;
289300
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
290301
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -311,6 +322,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
311322
entry.centroidSlice(ivfCentroids),
312323
target,
313324
postListSlice,
325+
esAcceptDocs == null ? acceptDocs : esAcceptDocs,
314326
visitRatio
315327
);
316328
Bits acceptDocsBits = acceptDocs.bits();

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.index.SegmentReadState;
1414
import org.apache.lucene.index.VectorEncoding;
1515
import org.apache.lucene.index.VectorSimilarityFunction;
16+
import org.apache.lucene.search.AcceptDocs;
1617
import org.apache.lucene.search.KnnCollector;
1718
import org.apache.lucene.store.IndexInput;
1819
import org.apache.lucene.util.Bits;
@@ -87,6 +88,7 @@ public CentroidIterator getCentroidIterator(
8788
IndexInput centroids,
8889
float[] targetQuery,
8990
IndexInput postingListSlice,
91+
AcceptDocs acceptDocs,
9092
float visitRatio
9193
) throws IOException {
9294
final FieldEntry fieldEntry = fields.get(fieldInfo.number);

server/src/main/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQuery.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import org.apache.lucene.search.QueryVisitor;
2626
import org.apache.lucene.search.ScoreDoc;
2727
import org.apache.lucene.search.ScoreMode;
28-
import org.apache.lucene.search.Scorer;
28+
import org.apache.lucene.search.ScorerSupplier;
2929
import org.apache.lucene.search.TaskExecutor;
3030
import org.apache.lucene.search.TopDocs;
3131
import org.apache.lucene.search.TopDocsCollector;
3232
import org.apache.lucene.search.Weight;
3333
import org.apache.lucene.search.knn.KnnCollectorManager;
3434
import org.apache.lucene.search.knn.KnnSearchStrategy;
35+
import org.apache.lucene.util.Bits;
3536
import org.elasticsearch.search.profile.query.QueryProfiler;
3637

3738
import java.io.IOException;
@@ -182,20 +183,31 @@ private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, IVFCollec
182183
TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, IVFCollectorManager knnCollectorManager, float visitRatio)
183184
throws IOException {
184185
final LeafReader reader = ctx.reader();
186+
final Bits liveDocs = reader.getLiveDocs();
187+
final int maxDoc = reader.maxDoc();
185188

186189
if (filterWeight == null) {
187-
AcceptDocs acceptDocs = AcceptDocs.fromLiveDocs(reader.getLiveDocs(), reader.maxDoc());
188-
return approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE, knnCollectorManager, visitRatio);
190+
return approximateSearch(
191+
ctx,
192+
liveDocs == null ? ESAcceptDocs.ESAcceptDocsAll.INSTANCE : new ESAcceptDocs.BitsAcceptDocs(liveDocs, maxDoc),
193+
Integer.MAX_VALUE,
194+
knnCollectorManager,
195+
visitRatio
196+
);
189197
}
190198

191-
Scorer scorer = filterWeight.scorer(ctx);
192-
if (scorer == null) {
199+
ScorerSupplier supplier = filterWeight.scorerSupplier(ctx);
200+
if (supplier == null) {
193201
return TopDocsCollector.EMPTY_TOPDOCS;
194202
}
195203

196-
AcceptDocs acceptDocs = AcceptDocs.fromIteratorSupplier(scorer::iterator, reader.getLiveDocs(), reader.maxDoc());
197-
final int cost = acceptDocs.cost();
198-
return approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager, visitRatio);
204+
return approximateSearch(
205+
ctx,
206+
new ESAcceptDocs.ScorerSupplierAcceptDocs(supplier, liveDocs, maxDoc),
207+
Integer.MAX_VALUE,
208+
knnCollectorManager,
209+
visitRatio
210+
);
199211
}
200212

201213
abstract TopDocs approximateSearch(
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* @notice
3+
* Licensed to the Apache Software Foundation (ASF) under one or more
4+
* contributor license agreements. See the NOTICE file distributed with
5+
* this work for additional information regarding copyright ownership.
6+
* The ASF licenses this file to You under the Apache License, Version 2.0
7+
* (the "License"); you may not use this file except in compliance with
8+
* the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*
18+
* Modifications copyright (C) 2025 Elasticsearch B.V.
19+
*/
20+
package org.elasticsearch.search.vectors;
21+
22+
import org.apache.lucene.search.AcceptDocs;
23+
import org.apache.lucene.search.DocIdSetIterator;
24+
import org.apache.lucene.search.FilteredDocIdSetIterator;
25+
import org.apache.lucene.search.ScorerSupplier;
26+
import org.apache.lucene.util.BitSet;
27+
import org.apache.lucene.util.BitSetIterator;
28+
import org.apache.lucene.util.Bits;
29+
import org.apache.lucene.util.FixedBitSet;
30+
31+
import java.io.IOException;
32+
import java.util.Objects;
33+
import java.util.Optional;
34+
35+
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
36+
37+
/**
38+
* An extension of {@link AcceptDocs} that provides additional methods to get an approximate cost
39+
* and a BitSet representation of the accepted documents.
40+
*/
41+
public abstract sealed class ESAcceptDocs extends AcceptDocs {
42+
43+
/** Returns an approximate cost of the accepted documents.
44+
* This is generally much cheaper than {@link #cost()}, as implementations may
45+
* not fully evaluate filters to provide this estimate and may ignore deletions
46+
* @return the approximate cost
47+
* @throws IOException if an I/O error occurs
48+
*/
49+
public abstract int approximateCost() throws IOException;
50+
51+
/**
52+
* Returns an optional BitSet representing the accepted documents.
53+
* If a BitSet representation is not available, returns an empty Optional. An empty optional indicates that
54+
* there are some accepted documents, but they cannot be represented as a BitSet efficiently.
55+
* A {@code null} return value (instead of an Optional) implies that all documents are accepted.
56+
* @return an Optional containing the BitSet of accepted documents, or empty if not available, or null if all documents are accepted
57+
* @throws IOException if an I/O error occurs
58+
*/
59+
public abstract Optional<BitSet> getBitSet() throws IOException;
60+
61+
private static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException {
62+
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
63+
// If we already have a BitSet and no deletions, reuse the BitSet
64+
return bitSetIterator.getBitSet();
65+
} else {
66+
int threshold = maxDoc >> 7; // same as BitSet#of
67+
if (iterator.cost() >= threshold) {
68+
FixedBitSet bitSet = new FixedBitSet(maxDoc);
69+
bitSet.or(iterator);
70+
if (liveDocs != null) {
71+
liveDocs.applyMask(bitSet, 0);
72+
}
73+
return bitSet;
74+
} else {
75+
return BitSet.of(liveDocs == null ? iterator : new FilteredDocIdSetIterator(iterator) {
76+
@Override
77+
protected boolean match(int doc) {
78+
return liveDocs.get(doc);
79+
}
80+
}, maxDoc); // create a sparse bitset
81+
}
82+
}
83+
}
84+
85+
/** An AcceptDocs that accepts all documents. */
86+
public static final class ESAcceptDocsAll extends ESAcceptDocs {
87+
public static final ESAcceptDocsAll INSTANCE = new ESAcceptDocsAll();
88+
89+
private ESAcceptDocsAll() {}
90+
91+
@Override
92+
public int approximateCost() throws IOException {
93+
return 0;
94+
}
95+
96+
@Override
97+
public Optional<BitSet> getBitSet() throws IOException {
98+
return null;
99+
}
100+
101+
@Override
102+
public Bits bits() throws IOException {
103+
return null;
104+
}
105+
106+
@Override
107+
public DocIdSetIterator iterator() throws IOException {
108+
return null;
109+
}
110+
111+
@Override
112+
public int cost() throws IOException {
113+
return 0;
114+
}
115+
}
116+
117+
/** An AcceptDocs that wraps a Bits instance. Generally indicates that no filter was provided, but there are deleted docs */
118+
public static final class BitsAcceptDocs extends ESAcceptDocs {
119+
private final Bits bits;
120+
private final BitSet bitSetRef;
121+
private final int maxDoc;
122+
private final int approximateCost;
123+
124+
BitsAcceptDocs(Bits bits, int maxDoc) {
125+
if (bits != null && bits.length() != maxDoc) {
126+
throw new IllegalArgumentException("Bits length = " + bits.length() + " != maxDoc = " + maxDoc);
127+
}
128+
this.bits = Objects.requireNonNull(bits);
129+
if (bits instanceof BitSet bitSet) {
130+
this.maxDoc = Objects.requireNonNull(bitSet).cardinality();
131+
this.approximateCost = Objects.requireNonNull(bitSet).approximateCardinality();
132+
this.bitSetRef = bitSet;
133+
} else {
134+
this.maxDoc = maxDoc;
135+
this.approximateCost = maxDoc;
136+
this.bitSetRef = null;
137+
}
138+
}
139+
140+
@Override
141+
public Bits bits() {
142+
return bits;
143+
}
144+
145+
@Override
146+
public DocIdSetIterator iterator() {
147+
if (bitSetRef != null) {
148+
return new BitSetIterator(bitSetRef, maxDoc);
149+
}
150+
return new FilteredDocIdSetIterator(DocIdSetIterator.all(maxDoc)) {
151+
@Override
152+
protected boolean match(int doc) {
153+
return bits.get(doc);
154+
}
155+
};
156+
}
157+
158+
@Override
159+
public int cost() {
160+
// We have no better estimate. This should be ok in practice since background merges should
161+
// keep the number of deletes under control (< 20% by default).
162+
return maxDoc;
163+
}
164+
165+
@Override
166+
public int approximateCost() {
167+
return approximateCost;
168+
}
169+
170+
@Override
171+
public Optional<BitSet> getBitSet() {
172+
if (bits == null) {
173+
return null;
174+
}
175+
return Optional.ofNullable(bitSetRef);
176+
}
177+
}
178+
179+
/** An AcceptDocs that wraps a ScorerSupplier. Indicates that a filter was provided. */
180+
public static final class ScorerSupplierAcceptDocs extends ESAcceptDocs {
181+
private final ScorerSupplier scorerSupplier;
182+
private BitSet acceptBitSet;
183+
private final Bits liveDocs;
184+
private final int maxDoc;
185+
private int cardinality = -1;
186+
187+
ScorerSupplierAcceptDocs(ScorerSupplier scorerSupplier, Bits liveDocs, int maxDoc) {
188+
this.scorerSupplier = scorerSupplier;
189+
this.liveDocs = liveDocs;
190+
this.maxDoc = maxDoc;
191+
}
192+
193+
private void createBitSetIfNecessary() throws IOException {
194+
if (acceptBitSet == null) {
195+
acceptBitSet = createBitSet(scorerSupplier.get(NO_MORE_DOCS).iterator(), liveDocs, maxDoc);
196+
}
197+
}
198+
199+
@Override
200+
public Bits bits() throws IOException {
201+
createBitSetIfNecessary();
202+
return acceptBitSet;
203+
}
204+
205+
@Override
206+
public DocIdSetIterator iterator() throws IOException {
207+
if (acceptBitSet != null) {
208+
return new BitSetIterator(acceptBitSet, cardinality);
209+
}
210+
return liveDocs == null
211+
? scorerSupplier.get(NO_MORE_DOCS).iterator()
212+
: new FilteredDocIdSetIterator(scorerSupplier.get(NO_MORE_DOCS).iterator()) {
213+
@Override
214+
protected boolean match(int doc) {
215+
return liveDocs.get(doc);
216+
}
217+
};
218+
}
219+
220+
@Override
221+
public int cost() throws IOException {
222+
createBitSetIfNecessary();
223+
if (cardinality == -1) {
224+
cardinality = acceptBitSet.cardinality();
225+
}
226+
return cardinality;
227+
}
228+
229+
@Override
230+
public int approximateCost() throws IOException {
231+
if (acceptBitSet != null) {
232+
return cardinality != -1 ? cardinality : acceptBitSet.approximateCardinality();
233+
}
234+
return Math.toIntExact(scorerSupplier.cost());
235+
}
236+
237+
@Override
238+
public Optional<BitSet> getBitSet() throws IOException {
239+
createBitSetIfNecessary();
240+
return Optional.of(acceptBitSet);
241+
}
242+
}
243+
}

0 commit comments

Comments
 (0)