diff --git a/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java index 5944fc3d8df7a..635b7fa3cb056 100644 --- a/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/BoolQueryBuilder.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.index.query.support.AutoPrefilteringScope; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -26,10 +27,13 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.common.lucene.search.Queries.fixNegativeQueryIfNeeded; @@ -299,16 +303,17 @@ public String getWriteableName() { @Override protected Query doToQuery(SearchExecutionContext context) throws IOException { BooleanQuery.Builder booleanQueryBuilder = new BooleanQuery.Builder(); - addBooleanClauses(context, booleanQueryBuilder, mustClauses, BooleanClause.Occur.MUST); + final List prefilters = collectPrefilters(); + addBooleanClauses(context, booleanQueryBuilder, mustClauses, BooleanClause.Occur.MUST, prefilters); try { // disable tracking of the @timestamp range for must_not and should clauses context.setTrackTimeRangeFilterFrom(false); - addBooleanClauses(context, booleanQueryBuilder, mustNotClauses, BooleanClause.Occur.MUST_NOT); - addBooleanClauses(context, booleanQueryBuilder, shouldClauses, BooleanClause.Occur.SHOULD); + addBooleanClauses(context, booleanQueryBuilder, mustNotClauses, BooleanClause.Occur.MUST_NOT, List.of()); + addBooleanClauses(context, booleanQueryBuilder, shouldClauses, BooleanClause.Occur.SHOULD, prefilters); } finally { context.setTrackTimeRangeFilterFrom(true); } - addBooleanClauses(context, booleanQueryBuilder, filterClauses, BooleanClause.Occur.FILTER); + addBooleanClauses(context, booleanQueryBuilder, filterClauses, BooleanClause.Occur.FILTER, List.of()); BooleanQuery booleanQuery = booleanQueryBuilder.build(); if (booleanQuery.clauses().isEmpty()) { return new MatchAllDocsQuery(); @@ -318,15 +323,25 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { return adjustPureNegative ? fixNegativeQueryIfNeeded(query) : query; } + private List collectPrefilters() { + return Stream.of(mustClauses, mustNotClauses.stream().map(c -> QueryBuilders.boolQuery().mustNot(c)).toList(), filterClauses) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + } + private static void addBooleanClauses( SearchExecutionContext context, BooleanQuery.Builder booleanQueryBuilder, List clauses, - Occur occurs + Occur occurs, + List prefilters ) throws IOException { for (QueryBuilder query : clauses) { - Query luceneQuery = query.toQuery(context); - booleanQueryBuilder.add(new BooleanClause(luceneQuery, occurs)); + try (AutoPrefilteringScope autoPrefilteringScope = context.autoPrefilteringScope()) { + autoPrefilteringScope.push(prefilters.stream().filter(c -> c != query).toList()); + Query luceneQuery = query.toQuery(context); + booleanQueryBuilder.add(new BooleanClause(luceneQuery, occurs)); + } } } diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index d50c4f0f618f5..3150470f57472 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -49,6 +49,7 @@ import org.elasticsearch.index.mapper.ParsedDocument; import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceToParse; +import org.elasticsearch.index.query.support.AutoPrefilteringScope; import org.elasticsearch.index.query.support.NestedScope; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.script.Script; @@ -103,6 +104,7 @@ public class SearchExecutionContext extends QueryRewriteContext { private final Map namedQueries = new HashMap<>(); private NestedScope nestedScope; + private AutoPrefilteringScope autoPrefilteringScope; private QueryBuilder aliasFilter; private boolean rewriteToNamedQueries = false; @@ -291,6 +293,7 @@ private SearchExecutionContext( this.bitsetFilterCache = bitsetFilterCache; this.indexFieldDataLookup = indexFieldDataLookup; this.nestedScope = new NestedScope(); + this.autoPrefilteringScope = new AutoPrefilteringScope(); this.searcher = searcher; this.requestSize = requestSize; this.mapperMetrics = mapperMetrics; @@ -301,7 +304,7 @@ private void reset() { this.lookup = null; this.namedQueries.clear(); this.nestedScope = new NestedScope(); - + this.autoPrefilteringScope = new AutoPrefilteringScope(); } // Set alias filter, so it can be applied for queries that need it (e.g. knn query) @@ -556,6 +559,10 @@ public NestedScope nestedScope() { return nestedScope; } + public AutoPrefilteringScope autoPrefilteringScope() { + return autoPrefilteringScope; + } + public IndexVersion indexVersionCreated() { return indexSettings.getIndexVersionCreated(); } diff --git a/server/src/main/java/org/elasticsearch/index/query/support/AutoPrefilteringScope.java b/server/src/main/java/org/elasticsearch/index/query/support/AutoPrefilteringScope.java new file mode 100644 index 0000000000000..0e57ed201bcbf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/query/support/AutoPrefilteringScope.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query.support; + +import org.elasticsearch.index.query.QueryBuilder; + +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; + +/** + * During query parsing this keeps track of the current prefiltering level. + */ +public final class AutoPrefilteringScope implements AutoCloseable { + + private final Deque> prefiltersStack = new LinkedList<>(); + + public List getPrefilters() { + return prefiltersStack.stream().flatMap(List::stream).toList(); + } + + public void push(List prefilters) { + prefiltersStack.push(prefilters); + } + + public void pop() { + prefiltersStack.pop(); + } + + @Override + public void close() { + pop(); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index c85ffcea2c46b..a97616bc82123 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -56,6 +56,9 @@ * {@link org.apache.lucene.search.KnnByteVectorQuery}. */ public class KnnVectorQueryBuilder extends AbstractQueryBuilder { + + public static final TransportVersion AUTO_PREFILTERING = TransportVersion.fromName("knn_vector_query_auto_prefiltering"); + public static final String NAME = "knn"; private static final int NUM_CANDS_LIMIT = 10_000; private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f; @@ -121,7 +124,7 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private static final TransportVersion VISIT_PERCENTAGE = TransportVersion.fromName("visit_percentage"); + public static final TransportVersion VISIT_PERCENTAGE = TransportVersion.fromName("visit_percentage"); private final String fieldName; private final VectorData queryVector; @@ -133,6 +136,7 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { private final QueryVectorBuilder queryVectorBuilder; private final Supplier queryVectorSupplier; private final RescoreVectorBuilder rescoreVectorBuilder; + private boolean isAutoPrefiltering = false; public KnnVectorQueryBuilder( String fieldName, @@ -302,6 +306,9 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.rescoreVectorBuilder = null; } + if (in.getTransportVersion().supports(AUTO_PREFILTERING)) { + this.isAutoPrefiltering = in.readBoolean(); + } this.queryVectorSupplier = null; } @@ -357,6 +364,10 @@ public KnnVectorQueryBuilder addFilterQueries(List filterQueries) return this; } + public boolean isAutoPrefiltering() { + return isAutoPrefiltering; + } + @Override protected void doWriteTo(StreamOutput out) throws IOException { if (queryVectorSupplier != null) { @@ -417,6 +428,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().supports(TransportVersions.V_8_18_0)) { out.writeOptionalWriteable(rescoreVectorBuilder); } + if (out.getTransportVersion().supports(AUTO_PREFILTERING)) { + out.writeBoolean(isAutoPrefiltering); + } } @Override @@ -479,7 +493,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { visitPercentage, rescoreVectorBuilder, vectorSimilarity - ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); + ).boost(boost).queryName(queryName).addFilterQueries(filterQueries).setAutoPrefiltering(isAutoPrefiltering); } if (queryVectorBuilder != null) { SetOnce toSet = new SetOnce<>(); @@ -509,7 +523,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { visitPercentage, rescoreVectorBuilder, vectorSimilarity - ).boost(boost).queryName(queryName).addFilterQueries(filterQueries); + ).boost(boost).queryName(queryName).addFilterQueries(filterQueries).setAutoPrefiltering(isAutoPrefiltering); } boolean changed = false; List rewrittenQueries = new ArrayList<>(filterQueries.size()); @@ -534,7 +548,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException { visitPercentage, rescoreVectorBuilder, vectorSimilarity - ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries); + ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries).setAutoPrefiltering(isAutoPrefiltering); } if (ctx.convertToInnerHitsRewriteContext() != null) { QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity); @@ -579,8 +593,9 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType; - List filtersInitial = new ArrayList<>(filterQueries.size()); - for (QueryBuilder query : this.filterQueries) { + List allApplicableFilters = getAllApplicableFilters(context); + List filtersInitial = new ArrayList<>(allApplicableFilters.size()); + for (QueryBuilder query : allApplicableFilters) { filtersInitial.add(query.toQuery(context)); } if (context.getAliasFilter() != null) { @@ -650,6 +665,14 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { ); } + private List getAllApplicableFilters(SearchExecutionContext context) { + List applicableFilters = new ArrayList<>(filterQueries); + if (isAutoPrefiltering) { + applicableFilters.addAll(context.autoPrefilteringScope().getPrefilters()); + } + return applicableFilters; + } + private static Query buildFilterQuery(List filters) { BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Query f : filters) { @@ -671,7 +694,8 @@ protected int doHashCode() { filterQueries, vectorSimilarity, queryVectorBuilder, - rescoreVectorBuilder + rescoreVectorBuilder, + isAutoPrefiltering ); } @@ -685,11 +709,17 @@ protected boolean doEquals(KnnVectorQueryBuilder other) { && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity) && Objects.equals(queryVectorBuilder, other.queryVectorBuilder) - && Objects.equals(rescoreVectorBuilder, other.rescoreVectorBuilder); + && Objects.equals(rescoreVectorBuilder, other.rescoreVectorBuilder) + && isAutoPrefiltering == other.isAutoPrefiltering; } @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_0_0; } + + public KnnVectorQueryBuilder setAutoPrefiltering(boolean isAutoPrefiltering) { + this.isAutoPrefiltering = isAutoPrefiltering; + return this; + } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java index 98cb674f46bd1..bf822c8bacbc8 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/RescoreVectorBuilder.java @@ -38,7 +38,7 @@ public class RescoreVectorBuilder implements Writeable, ToXContentObject { PARSER.declareFloat(ConstructingObjectParser.constructorArg(), OVERSAMPLE_FIELD); } - private static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO = TransportVersion.fromName("rescore_vector_allow_zero"); + public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO = TransportVersion.fromName("rescore_vector_allow_zero"); // Oversample is required as of now as it is the only field in the rescore vector private final float oversample; diff --git a/server/src/main/resources/transport/definitions/referable/knn_vector_query_auto_prefiltering.csv b/server/src/main/resources/transport/definitions/referable/knn_vector_query_auto_prefiltering.csv new file mode 100644 index 0000000000000..8d6255874d132 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/knn_vector_query_auto_prefiltering.csv @@ -0,0 +1 @@ +9216000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 6b0edb76f268f..393a630716b30 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -inference_api_eis_authorization_persistent_task,9215000 +knn_vector_query_auto_prefiltering,9216000 diff --git a/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java index e9ef3ac8ad748..62c3550ce7463 100644 --- a/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/BoolQueryBuilderTests.java @@ -14,6 +14,8 @@ import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.test.AbstractQueryTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -25,9 +27,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.index.query.QueryBuilders.boolQuery; import static org.elasticsearch.index.query.QueryBuilders.termQuery; @@ -35,6 +40,9 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; public class BoolQueryBuilderTests extends AbstractQueryTestCase { @@ -507,4 +515,192 @@ public void testShallowCopy() { } } } + + public void testAutoPrefiltering_GivenSingleMustPrefilteringClause() throws IOException { + BoolQueryBuilder query = new BoolQueryBuilder(); + Map> prefiltersToQueryNameMap = new HashMap<>(); + query.must(new TestAutoPrefilteringQueryBuilder("test", prefiltersToQueryNameMap)); + + query.toQuery(createSearchExecutionContext()); + + assertThat(prefiltersToQueryNameMap.get("test"), is(empty())); + } + + public void testAutoPrefiltering_GivenSingleShouldPrefilteringClauseAndFilters() throws IOException { + BoolQueryBuilder query = new BoolQueryBuilder(); + Map> prefiltersToQueryNameMap = new HashMap<>(); + query.should(new TestAutoPrefilteringQueryBuilder("test", prefiltersToQueryNameMap)); + randomList(1, 5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(query::filter); + + query = (BoolQueryBuilder) Rewriteable.rewrite(query, createQueryRewriteContext()); + query.toQuery(createSearchExecutionContext()); + + assertThat(prefiltersToQueryNameMap.get("test"), containsInAnyOrder(query.filter().toArray())); + } + + public void testAutoPrefiltering_GivenPrefilteringFilterClauseAndFilters_ShouldNotReceiveFilters() throws IOException { + BoolQueryBuilder query = new BoolQueryBuilder(); + Map> prefiltersToQueryNameMap = new HashMap<>(); + query.filter(new TestAutoPrefilteringQueryBuilder("test", prefiltersToQueryNameMap)); + randomList(1, 5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(query::filter); + + query = (BoolQueryBuilder) Rewriteable.rewrite(query, createQueryRewriteContext()); + query.toQuery(createSearchExecutionContext()); + + assertThat(prefiltersToQueryNameMap.getOrDefault("test", Set.of()), is(empty())); + } + + public void testAutoPrefiltering_GivenPrefilteringMustNotClauseAndFilters_ShouldNotReceiveFilters() throws IOException { + BoolQueryBuilder query = new BoolQueryBuilder(); + Map> prefiltersToQueryNameMap = new HashMap<>(); + query.mustNot(new TestAutoPrefilteringQueryBuilder("test", prefiltersToQueryNameMap)); + randomList(1, 5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(query::filter); + + query = (BoolQueryBuilder) Rewriteable.rewrite(query, createQueryRewriteContext()); + query.toQuery(createSearchExecutionContext()); + + assertThat(prefiltersToQueryNameMap.getOrDefault("test", Set.of()), is(empty())); + } + + public void testAutoPrefiltering_GivenSingleMustPrefilteringClauseAndFilters() throws IOException { + BoolQueryBuilder query = new BoolQueryBuilder(); + Map> prefiltersToQueryNameMap = new HashMap<>(); + query.must(new TestAutoPrefilteringQueryBuilder("test", prefiltersToQueryNameMap)); + randomList(1, 5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(query::filter); + + query = (BoolQueryBuilder) Rewriteable.rewrite(query, createQueryRewriteContext()); + query.toQuery(createSearchExecutionContext()); + + assertThat(prefiltersToQueryNameMap.get("test"), containsInAnyOrder(query.filter().toArray())); + } + + public void testAutoPrefiltering_GivenRandomMultipleClauses() throws IOException { + for (int i = 0; i < 100; i++) { + BoolQueryBuilder rootQuery = new BoolQueryBuilder(); + Map> prefiltersToQueryNameMap = new HashMap<>(); + TestAutoPrefilteringQueryBuilder must_1 = new TestAutoPrefilteringQueryBuilder("must_1", prefiltersToQueryNameMap); + TestAutoPrefilteringQueryBuilder must_2 = new TestAutoPrefilteringQueryBuilder("must_2", prefiltersToQueryNameMap); + TestAutoPrefilteringQueryBuilder should_1 = new TestAutoPrefilteringQueryBuilder("should_1", prefiltersToQueryNameMap); + TestAutoPrefilteringQueryBuilder should_2 = new TestAutoPrefilteringQueryBuilder("should_2", prefiltersToQueryNameMap); + rootQuery.must(must_1); + rootQuery.must(must_2); + rootQuery.should(should_1); + rootQuery.should(should_2); + + // We add a prefiltering clause for filter and must_not to later check those do not receive any prefilters + TestAutoPrefilteringQueryBuilder filter_1 = new TestAutoPrefilteringQueryBuilder("filter_1", prefiltersToQueryNameMap); + TestAutoPrefilteringQueryBuilder must_not_1 = new TestAutoPrefilteringQueryBuilder("must_not_1", prefiltersToQueryNameMap); + rootQuery.filter(filter_1); + rootQuery.mustNot(must_not_1); + + randomList(5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(rootQuery::must); + randomList(5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(rootQuery::should); + randomList(5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(rootQuery::filter); + randomList(5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(rootQuery::mustNot); + + // We add a must clause that is another bool query containing a prefiltering clause + BoolQueryBuilder bool_1 = new BoolQueryBuilder(); + TestAutoPrefilteringQueryBuilder must_3 = new TestAutoPrefilteringQueryBuilder("must_3", prefiltersToQueryNameMap); + bool_1.must(must_3); + randomList(5, () -> randomTermQueryMaybeWrappedInCompoundQuery()).forEach(bool_1::filter); + rootQuery.must(bool_1); + + rootQuery = (BoolQueryBuilder) Rewriteable.rewrite(rootQuery, createQueryRewriteContext()); + rootQuery.toQuery(createSearchExecutionContext()); + + Set expectedPrefilters = collectExpectedPrefilters(rootQuery); + assertThat( + prefiltersToQueryNameMap.get("must_1"), + equalTo(expectedPrefilters.stream().filter(q -> q != must_1).collect(Collectors.toSet())) + ); + assertThat( + prefiltersToQueryNameMap.get("must_2"), + equalTo(expectedPrefilters.stream().filter(q -> q != must_2).collect(Collectors.toSet())) + ); + assertThat(prefiltersToQueryNameMap.get("should_1"), containsInAnyOrder(expectedPrefilters.toArray())); + assertThat(prefiltersToQueryNameMap.get("should_2"), containsInAnyOrder(expectedPrefilters.toArray())); + + expectedPrefilters = collectExpectedPrefilters(rootQuery, bool_1); + assertThat( + prefiltersToQueryNameMap.get("must_3"), + equalTo(expectedPrefilters.stream().filter(q -> q != must_3 && q != bool_1).collect(Collectors.toSet())) + ); + + assertThat(prefiltersToQueryNameMap.get("filter_1"), is(empty())); + assertThat(prefiltersToQueryNameMap.get("must_not_1"), is(empty())); + } + } + + private static QueryBuilder randomTermQueryMaybeWrappedInCompoundQuery() { + QueryBuilder termQuery = randomTermQuery(); + if (randomBoolean()) { + return termQuery; + } + return switch (randomIntBetween(0, 4)) { + case 0 -> QueryBuilders.constantScoreQuery(termQuery); + case 1 -> QueryBuilders.functionScoreQuery(termQuery); + case 2 -> QueryBuilders.boostingQuery(termQuery, randomTermQuery()); + case 3 -> QueryBuilders.disMaxQuery().add(termQuery).add(randomTermQuery()); + case 4 -> QueryBuilders.boolQuery().filter(termQuery); + default -> throw new IllegalStateException("Unexpected value: " + randomIntBetween(0, 2)); + }; + } + + private static QueryBuilder randomTermQuery() { + String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME; + return termQuery(filterFieldName, randomAlphaOfLength(10)); + } + + private static Set collectExpectedPrefilters(BoolQueryBuilder... queries) { + Set expectedPrefilters = new HashSet<>(); + for (BoolQueryBuilder query : queries) { + expectedPrefilters.addAll(query.must()); + expectedPrefilters.addAll(query.filter()); + expectedPrefilters.addAll(query.mustNot().stream().map(q -> boolQuery().mustNot(q)).collect(Collectors.toList())); + } + return expectedPrefilters; + } + + private static final class TestAutoPrefilteringQueryBuilder extends AbstractQueryBuilder { + + Map> prefiltersToQueryNameMap; + + private TestAutoPrefilteringQueryBuilder(String queryName, Map> prefiltersToQueryNameMap) { + super(); + queryName(queryName); + this.prefiltersToQueryNameMap = prefiltersToQueryNameMap; + } + + @Override + protected void doWriteTo(StreamOutput out) {} + + @Override + protected void doXContent(XContentBuilder builder, Params params) {} + + @Override + protected Query doToQuery(SearchExecutionContext context) { + prefiltersToQueryNameMap.put(queryName(), context.autoPrefilteringScope().getPrefilters().stream().collect(Collectors.toSet())); + return new MatchAllDocsQuery(); + } + + @Override + protected boolean doEquals(TestAutoPrefilteringQueryBuilder other) { + return false; + } + + @Override + protected int doHashCode() { + return 0; + } + + @Override + public String getWriteableName() { + return ""; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return null; + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/query/support/AutoPrefilteringScopeTests.java b/server/src/test/java/org/elasticsearch/index/query/support/AutoPrefilteringScopeTests.java new file mode 100644 index 0000000000000..f49784c5e2e2d --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/query/support/AutoPrefilteringScopeTests.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.query.support; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.RandomQueryBuilder; +import org.elasticsearch.test.ESTestCase; + +import java.util.Collection; +import java.util.List; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class AutoPrefilteringScopeTests extends ESTestCase { + + public void testMultipleLevels() { + AutoPrefilteringScope autoPrefilteringScope = new AutoPrefilteringScope(); + assertThat(autoPrefilteringScope.getPrefilters(), is(empty())); + + List prefilters_1_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + List prefilters_1_2 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + List prefilters_2_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + List prefilters_2_2 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + List prefilters_3_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + + // Given + increases level and - decreases level, we add scope as follows: + // + 1_1 + 2_1 + 3_1 + // - 3_1 + 2_2 + // - 2_2 + 1_2 + // - 1_2 + 1_1 + // - 1_1 + // and we check current prefilters after each operation. + + autoPrefilteringScope.push(prefilters_1_1); + assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1_1)); + autoPrefilteringScope.push(prefilters_2_1); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList()) + ); + autoPrefilteringScope.push(prefilters_3_1); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_3_1, prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList()) + ); + autoPrefilteringScope.pop(); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList()) + ); + autoPrefilteringScope.push(prefilters_2_2); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_2_2, prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList()) + ); + autoPrefilteringScope.pop(); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_2_1, prefilters_1_1).flatMap(Collection::stream).toList()) + ); + autoPrefilteringScope.pop(); + assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1_1)); + autoPrefilteringScope.push(prefilters_1_2); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_1_2, prefilters_1_1).flatMap(Collection::stream).toList()) + ); + autoPrefilteringScope.pop(); + assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1_1)); + autoPrefilteringScope.pop(); + assertThat(autoPrefilteringScope.getPrefilters(), empty()); + } + + public void testAutoCloseable() { + AutoPrefilteringScope autoPrefilteringScope = new AutoPrefilteringScope(); + List prefilters_1 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + List prefilters_2 = randomList(0, 5, () -> RandomQueryBuilder.createQuery(random())); + + try (autoPrefilteringScope) { + autoPrefilteringScope.push(prefilters_1); + + try (autoPrefilteringScope) { + autoPrefilteringScope.push(prefilters_2); + assertThat( + autoPrefilteringScope.getPrefilters(), + equalTo(Stream.of(prefilters_2, prefilters_1).flatMap(Collection::stream).toList()) + ); + } + assertThat(autoPrefilteringScope.getPrefilters(), equalTo(prefilters_1)); + } + assertThat(autoPrefilteringScope.getPrefilters(), is(empty())); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index 04731f193fb14..4f299f9dfdeab 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -76,6 +76,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa protected static String indexType; protected static int vectorDimensions; + protected boolean rescoreVectorAllowZero = true; + @Before private void checkIndexTypeAndDimensions() { // Check that these are initialized - should be done as part of the createAdditionalMappings method @@ -182,7 +184,7 @@ protected RescoreVectorBuilder randomRescoreVectorBuilder() { return null; } - return new RescoreVectorBuilder(randomBoolean() ? 0f : randomFloatBetween(1.0f, 10.0f, false)); + return new RescoreVectorBuilder((rescoreVectorAllowZero && randomBoolean()) ? 0f : randomFloatBetween(1.0f, 10.0f, false)); } @Override @@ -492,6 +494,36 @@ public void testBWCVersionSerializationRescoreVector() throws IOException { assertBWCSerialization(query, queryNoRescoreVector, version); } + public void testBWCVersionSerialization_GivenAutoPrefiltering() throws IOException { + for (int i = 0; i < 100; i++) { + + TransportVersion version = TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.V_8_18_0, + TransportVersionUtils.getPreviousVersion(KnnVectorQueryBuilder.AUTO_PREFILTERING) + ); + rescoreVectorAllowZero = version.onOrAfter(RescoreVectorBuilder.RESCORE_VECTOR_ALLOW_ZERO); + KnnVectorQueryBuilder query = doCreateTestQueryBuilder().setAutoPrefiltering(true); + KnnVectorQueryBuilder queryNoAutoPrefiltering = new KnnVectorQueryBuilder( + query.getFieldName(), + query.queryVector(), + query.k(), + query.numCands(), + version.onOrAfter(KnnVectorQueryBuilder.VISIT_PERCENTAGE) ? query.visitPercentage() : null, + query.rescoreVectorBuilder(), + query.getVectorSimilarity() + ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()).setAutoPrefiltering(false); + assertBWCSerialization(query, queryNoAutoPrefiltering, version); + } + } + + public void testSerialization_GivenAutoPrefiltering() throws IOException { + KnnVectorQueryBuilder query = doCreateTestQueryBuilder().setAutoPrefiltering(true); + KnnVectorQueryBuilder serializedQuery = copyQuery(query); + assertThat(serializedQuery, equalTo(query)); + assertThat(serializedQuery.hashCode(), equalTo(query.hashCode())); + } + private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException { assertSerialization(bwcQuery, version); try (BytesStreamOutput output = new BytesStreamOutput()) { @@ -551,6 +583,7 @@ public void testRewriteWithQueryVectorBuilder() throws Exception { filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10))); } knnVectorQueryBuilder.addFilterQueries(filters); + knnVectorQueryBuilder.setAutoPrefiltering(randomBoolean()); QueryRewriteContext context = new QueryRewriteContext(null, null, null); PlainActionFuture knnFuture = new PlainActionFuture<>(); @@ -564,5 +597,6 @@ public void testRewriteWithQueryVectorBuilder() throws Exception { assertThat(rewritten.getVectorSimilarity(), equalTo(1f)); assertThat(rewritten.filterQueries(), hasSize(numFilters)); assertThat(rewritten.filterQueries(), equalTo(filters)); + assertThat(rewritten.isAutoPrefiltering(), equalTo(knnVectorQueryBuilder.isAutoPrefiltering())); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index d02195c5f52c9..be62dc8f43829 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -1047,7 +1047,8 @@ yield new SparseVectorQueryBuilder( k = Math.max(k, DEFAULT_SIZE); } - yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, null); + yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, k, null, null, null, null) + .setAutoPrefiltering(true); } default -> throw new IllegalStateException( "Field ["