Skip to content

Commit 15bec3c

Browse files
authored
Add support for sparse_vector queries against semantic_text fields (#118617)
1 parent 7c65a8e commit 15bec3c

File tree

13 files changed

+887
-76
lines changed

13 files changed

+887
-76
lines changed

docs/changelog/118617.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 118617
2+
summary: Add support for `sparse_vector` queries against `semantic_text` fields
3+
area: "Search"
4+
type: enhancement
5+
issues: []

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,33 @@ public SparseVectorQueryBuilder(
9090
: (this.shouldPruneTokens ? new TokenPruningConfig() : null));
9191
this.weightedTokensSupplier = null;
9292

93-
if (queryVectors == null ^ inferenceId == null == false) {
93+
// Preserve BWC error messaging
94+
if (queryVectors != null && inferenceId != null) {
9495
throw new IllegalArgumentException(
9596
"["
9697
+ NAME
9798
+ "] requires one of ["
9899
+ QUERY_VECTOR_FIELD.getPreferredName()
99100
+ "] or ["
100101
+ INFERENCE_ID_FIELD.getPreferredName()
101-
+ "]"
102+
+ "] for "
103+
+ ALLOWED_FIELD_TYPE
104+
+ " fields"
102105
);
103106
}
104-
if (inferenceId != null && query == null) {
107+
108+
// Preserve BWC error messaging
109+
if ((queryVectors == null) == (query == null)) {
105110
throw new IllegalArgumentException(
106111
"["
107112
+ NAME
108-
+ "] requires ["
109-
+ QUERY_FIELD.getPreferredName()
110-
+ "] when ["
113+
+ "] requires one of ["
114+
+ QUERY_VECTOR_FIELD.getPreferredName()
115+
+ "] or ["
111116
+ INFERENCE_ID_FIELD.getPreferredName()
112-
+ "] is specified"
117+
+ "] for "
118+
+ ALLOWED_FIELD_TYPE
119+
+ " fields"
113120
);
114121
}
115122
}
@@ -143,6 +150,14 @@ public List<WeightedToken> getQueryVectors() {
143150
return queryVectors;
144151
}
145152

153+
public String getInferenceId() {
154+
return inferenceId;
155+
}
156+
157+
public String getQuery() {
158+
return query;
159+
}
160+
146161
public boolean shouldPruneTokens() {
147162
return shouldPruneTokens;
148163
}
@@ -176,7 +191,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
176191
}
177192
builder.endObject();
178193
} else {
179-
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
194+
if (inferenceId != null) {
195+
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
196+
}
180197
builder.field(QUERY_FIELD.getPreferredName(), query);
181198
}
182199
builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);
@@ -228,6 +245,11 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
228245
shouldPruneTokens,
229246
tokenPruningConfig
230247
);
248+
} else if (inferenceId == null) {
249+
// Edge case, where inference_id was not specified in the request,
250+
// but we did not intercept this and rewrite to a query o field with
251+
// pre-configured inference. So we trap here and output a nicer error message.
252+
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
231253
}
232254

233255
// TODO move this to xpack core and use inference APIs

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,16 +260,16 @@ public void testIllegalValues() {
260260
{
261261
IllegalArgumentException e = expectThrows(
262262
IllegalArgumentException.class,
263-
() -> new SparseVectorQueryBuilder("field name", null, "model id")
263+
() -> new SparseVectorQueryBuilder("field name", null, null)
264264
);
265-
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
265+
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
266266
}
267267
{
268268
IllegalArgumentException e = expectThrows(
269269
IllegalArgumentException.class,
270270
() -> new SparseVectorQueryBuilder("field name", "model text", null)
271271
);
272-
assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage());
272+
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
273273
}
274274
}
275275

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
import org.elasticsearch.features.FeatureSpecification;
1111
import org.elasticsearch.features.NodeFeature;
1212
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
13-
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
1413
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
1514
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
1615
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
1716

1817
import java.util.Set;
1918

19+
import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
20+
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
21+
2022
/**
2123
* Provides inference features.
2224
*/
@@ -45,7 +47,8 @@ public Set<NodeFeature> getTestFeatures() {
4547
SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
4648
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
4749
SEMANTIC_TEXT_HIGHLIGHTER,
48-
SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
50+
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
51+
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
4952
);
5053
}
5154
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
8181
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
8282
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
83+
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
8384
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
8485
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
8586
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
@@ -440,7 +441,7 @@ public List<QuerySpec<?>> getQueries() {
440441

441442
@Override
442443
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
443-
return List.of(new SemanticMatchQueryRewriteInterceptor());
444+
return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
444445
}
445446

446447
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java

Lines changed: 29 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,12 @@
77

88
package org.elasticsearch.xpack.inference.queries;
99

10-
import org.elasticsearch.action.ResolvedIndices;
11-
import org.elasticsearch.cluster.metadata.IndexMetadata;
12-
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
1310
import org.elasticsearch.features.NodeFeature;
14-
import org.elasticsearch.index.mapper.IndexFieldMapper;
1511
import org.elasticsearch.index.query.BoolQueryBuilder;
1612
import org.elasticsearch.index.query.MatchQueryBuilder;
1713
import org.elasticsearch.index.query.QueryBuilder;
18-
import org.elasticsearch.index.query.QueryRewriteContext;
19-
import org.elasticsearch.index.query.TermQueryBuilder;
20-
import org.elasticsearch.index.query.TermsQueryBuilder;
21-
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
2214

23-
import java.util.ArrayList;
24-
import java.util.Collection;
25-
import java.util.List;
26-
27-
public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
15+
public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
2816

2917
public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
3018
"search.semantic_match_query_rewrite_interception_supported"
@@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce
3321
public SemanticMatchQueryRewriteInterceptor() {}
3422

3523
@Override
36-
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
24+
protected String getFieldName(QueryBuilder queryBuilder) {
3725
assert (queryBuilder instanceof MatchQueryBuilder);
3826
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
39-
QueryBuilder rewritten = queryBuilder;
40-
ResolvedIndices resolvedIndices = context.getResolvedIndices();
41-
if (resolvedIndices != null) {
42-
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
43-
List<String> inferenceIndices = new ArrayList<>();
44-
List<String> nonInferenceIndices = new ArrayList<>();
45-
for (IndexMetadata indexMetadata : indexMetadataCollection) {
46-
String indexName = indexMetadata.getIndex().getName();
47-
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
48-
if (inferenceFieldMetadata != null) {
49-
inferenceIndices.add(indexName);
50-
} else {
51-
nonInferenceIndices.add(indexName);
52-
}
53-
}
54-
55-
if (inferenceIndices.isEmpty()) {
56-
return rewritten;
57-
} else if (nonInferenceIndices.isEmpty() == false) {
58-
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
59-
for (String inferenceIndexName : inferenceIndices) {
60-
// Add a separate clause for each semantic query, because they may be using different inference endpoints
61-
// TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
62-
boolQueryBuilder.should(
63-
createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
64-
);
65-
}
66-
boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
67-
rewritten = boolQueryBuilder;
68-
} else {
69-
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
70-
}
71-
}
72-
73-
return rewritten;
27+
return matchQueryBuilder.fieldName();
28+
}
7429

30+
@Override
31+
protected String getQuery(QueryBuilder queryBuilder) {
32+
assert (queryBuilder instanceof MatchQueryBuilder);
33+
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
34+
return (String) matchQueryBuilder.value();
7535
}
7636

7737
@Override
78-
public String getQueryName() {
79-
return MatchQueryBuilder.NAME;
38+
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
39+
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
8040
}
8141

82-
private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
42+
@Override
43+
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
44+
QueryBuilder queryBuilder,
45+
InferenceIndexInformationForField indexInformation
46+
) {
47+
assert (queryBuilder instanceof MatchQueryBuilder);
48+
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
8349
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
84-
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
85-
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
50+
boolQueryBuilder.should(
51+
createSemanticSubQuery(
52+
indexInformation.getInferenceIndices(),
53+
matchQueryBuilder.fieldName(),
54+
(String) matchQueryBuilder.value()
55+
)
56+
);
57+
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
8658
return boolQueryBuilder;
8759
}
8860

89-
private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
90-
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
91-
boolQueryBuilder.must(matchQueryBuilder);
92-
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
93-
return boolQueryBuilder;
61+
@Override
62+
public String getQueryName() {
63+
return MatchQueryBuilder.NAME;
9464
}
9565
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ public String getWriteableName() {
148148
return NAME;
149149
}
150150

151+
public String getFieldName() {
152+
return fieldName;
153+
}
154+
155+
public String getQuery() {
156+
return query;
157+
}
158+
151159
@Override
152160
public TransportVersion getMinimalSupportedVersion() {
153161
return TransportVersions.V_8_15_0;

0 commit comments

Comments
 (0)