Skip to content

Add support for sparse_vector queries against semantic_text fields #118617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
367b507
Add SemanticSparseVectorQueryRewriteInterceptor
kderusso Dec 12, 2024
1d3151a
Add yaml
kderusso Dec 12, 2024
c427143
Refactor match rewriting & cleanup
kderusso Dec 12, 2024
fe9deab
Update docs/changelog/118617.yaml
kderusso Dec 12, 2024
c3bf3a8
Update changelog
kderusso Dec 12, 2024
015cefc
Merge branch 'main' into kderusso/sparse-vector-semantic-text-field
kderusso Dec 12, 2024
2857250
Silly error introduced in refactoring
kderusso Dec 12, 2024
2181d86
Adding some yaml test cases for not specifying inference ID - these a…
kderusso Dec 13, 2024
8c82d00
Refactor from static utils into abstract class and add support for us…
kderusso Dec 13, 2024
78cafbb
Merge branch 'main' into kderusso/sparse-vector-semantic-text-field
kderusso Dec 13, 2024
0bab375
Fix some test errors, and do some cleanup
kderusso Dec 13, 2024
f9cb789
Merge branch 'main' into kderusso/sparse-vector-semantic-text-field
kderusso Dec 16, 2024
06be83d
Add some additional error validation to ensure BWC-compliant error me…
kderusso Dec 16, 2024
7bc962b
PR feedback
kderusso Dec 16, 2024
d28cd8b
Don't throw on multiple inference IDs
kderusso Dec 16, 2024
05c45e2
Cleanup
kderusso Dec 16, 2024
3e5e09c
Merge branch 'main' into kderusso/sparse-vector-semantic-text-field
kderusso Dec 16, 2024
fe00f43
Remove doc order from yaml test (to future proof against shard count …
kderusso Dec 16, 2024
46e03ff
Fix test
kderusso Dec 16, 2024
8afdd04
PR feedback - cleanup
kderusso Dec 17, 2024
9c25161
Add tests
kderusso Dec 17, 2024
db3c768
Revert error messages to be BWC compliant
kderusso Dec 17, 2024
6768b3e
Merge branch 'main' into kderusso/sparse-vector-semantic-text-field
kderusso Dec 17, 2024
e97f8b0
Grumble grumble missed a test...
kderusso Dec 17, 2024
187cf13
update yaml test
kderusso Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/118617.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118617
summary: Add support for `sparse_vector` queries against `semantic_text` fields
area: "Search"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -90,26 +90,33 @@ public SparseVectorQueryBuilder(
: (this.shouldPruneTokens ? new TokenPruningConfig() : null));
this.weightedTokensSupplier = null;

if (queryVectors == null ^ inferenceId == null == false) {
// Preserve BWC error messaging
if (queryVectors != null && inferenceId != null) {
throw new IllegalArgumentException(
"["
+ NAME
+ "] requires one of ["
+ QUERY_VECTOR_FIELD.getPreferredName()
+ "] or ["
+ INFERENCE_ID_FIELD.getPreferredName()
+ "]"
+ "] for "
+ ALLOWED_FIELD_TYPE
+ " fields"
);
}
if (inferenceId != null && query == null) {

// Preserve BWC error messaging
if ((queryVectors == null) == (query == null)) {
throw new IllegalArgumentException(
"["
+ NAME
+ "] requires ["
+ QUERY_FIELD.getPreferredName()
+ "] when ["
+ "] requires one of ["
+ QUERY_VECTOR_FIELD.getPreferredName()
+ "] or ["
+ INFERENCE_ID_FIELD.getPreferredName()
+ "] is specified"
+ "] for "
+ ALLOWED_FIELD_TYPE
+ " fields"
);
}
}
Expand Down Expand Up @@ -143,6 +150,14 @@ public List<WeightedToken> getQueryVectors() {
return queryVectors;
}

public String getInferenceId() {
return inferenceId;
}

public String getQuery() {
return query;
}

public boolean shouldPruneTokens() {
return shouldPruneTokens;
}
Expand Down Expand Up @@ -176,7 +191,9 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
}
builder.endObject();
} else {
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
if (inferenceId != null) {
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
}
builder.field(QUERY_FIELD.getPreferredName(), query);
}
builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);
Expand Down Expand Up @@ -228,6 +245,11 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
shouldPruneTokens,
tokenPruningConfig
);
} else if (inferenceId == null) {
// Edge case, where inference_id was not specified in the request,
// but we did not intercept this and rewrite to a query o field with
// pre-configured inference. So we trap here and output a nicer error message.
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including the field name and the (wrong) field type might help users diagnose the problem

Suggested change
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
throw new IllegalArgumentException("inference_id required to perform a sparse_vector query on sparse_vector field [" + fieldName + "]");

}

// TODO move this to xpack core and use inference APIs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,16 @@ public void testIllegalValues() {
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", null, "model id")
() -> new SparseVectorQueryBuilder("field name", null, null)
);
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
}
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", "model text", null)
);
assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

import java.util.Set;

import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;

/**
* Provides inference features.
*/
Expand Down Expand Up @@ -45,7 +47,8 @@ public Set<NodeFeature> getTestFeatures() {
SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
SEMANTIC_TEXT_HIGHLIGHTER,
SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
Expand Down Expand Up @@ -440,7 +441,7 @@ public List<QuerySpec<?>> getQueries() {

@Override
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(new SemanticMatchQueryRewriteInterceptor());
return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,12 @@

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_match_query_rewrite_interception_supported"
Expand All @@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce
public SemanticMatchQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
protected String getFieldName(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
if (inferenceFieldMetadata != null) {
inferenceIndices.add(indexName);
} else {
nonInferenceIndices.add(indexName);
}
}

if (inferenceIndices.isEmpty()) {
return rewritten;
} else if (nonInferenceIndices.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceIndexName : inferenceIndices) {
// Add a separate clause for each semantic query, because they may be using different inference endpoints
// TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
boolQueryBuilder.should(
createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
);
}
boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
rewritten = boolQueryBuilder;
} else {
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}
}

return rewritten;
return matchQueryBuilder.fieldName();
}

@Override
protected String getQuery(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return (String) matchQueryBuilder.value();
}

@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
}

private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
boolQueryBuilder.should(
createSemanticSubQuery(
indexInformation.getInferenceIndices(),
matchQueryBuilder.fieldName(),
(String) matchQueryBuilder.value()
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
return boolQueryBuilder;
}

private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(matchQueryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ public String getWriteableName() {
return NAME;
}

public String getFieldName() {
return fieldName;
}

public String getQuery() {
return query;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_15_0;
Expand Down
Loading