ES|QL - Add scoring for full text functions disjunctions #121793

Merged
Changes from 1 commit
34 commits
978770d
LuceneQueryScoreEvaluator first implementation
carlosdelest Feb 5, 2025
cb2c3c4
Add ScoreOperator and ScoreMapper
carlosdelest Feb 5, 2025
9ca756a
Add a ExpressionScoreMapper and a ScoreMapper interface to retrieve s…
carlosdelest Feb 5, 2025
e4eb86d
Implement ExpressionScoreMapper for FullTextFunction and BinaryLogic
carlosdelest Feb 5, 2025
a437da3
Create a ScoreOperator that can be planned via the LocalExecutionPlan…
carlosdelest Feb 5, 2025
aa4ffbf
Fix EvalMapper
carlosdelest Feb 5, 2025
1044bfd
Add tests
carlosdelest Feb 5, 2025
5abed67
Spotless
carlosdelest Feb 5, 2025
8b7fd0a
Update docs/changelog/121793.yaml
carlosdelest Feb 5, 2025
72bbd5f
Fix tests
carlosdelest Feb 5, 2025
1bfa58f
Merge remote-tracking branch 'origin/main' into enhancement/esql-scor…
carlosdelest Mar 3, 2025
5cb0bfc
Add testing and capabilities
carlosdelest Mar 4, 2025
3cab2dc
Remove disjunction limitations from docs
carlosdelest Mar 4, 2025
c639ec6
Calculate the _score attr position instead of hardcoding it
carlosdelest Mar 4, 2025
3f3b5b7
Refactor LuceneQueryExpressionEvaluator into a superclass and subclas…
carlosdelest Mar 5, 2025
145955c
Fix tests
carlosdelest Mar 5, 2025
63ca98b
Refactor query evaluators to use subclasses instead of interfaces
carlosdelest Mar 5, 2025
0008559
Merge remote-tracking branch 'carlosdelest/enhancement/esql-score-dis…
carlosdelest Mar 5, 2025
ef3decc
[CI] Auto commit changes from spotless
Mar 5, 2025
6755ab2
Refactor tests
carlosdelest Mar 6, 2025
b2de161
Refactor tests
carlosdelest Mar 6, 2025
bd88335
Refactor tests
carlosdelest Mar 6, 2025
ab8bcf1
Spotless
carlosdelest Mar 6, 2025
0fb9dc7
Merge remote-tracking branch 'origin/main' into enhancement/esql-scor…
carlosdelest Mar 6, 2025
3b994f6
Merge remote-tracking branch 'carlosdelest/enhancement/esql-score-dis…
carlosdelest Mar 6, 2025
9f39ad3
Add javadoc
carlosdelest Mar 6, 2025
0de1df5
Added missing tests
carlosdelest Mar 6, 2025
33016a9
Merge remote-tracking branch 'origin/main' into enhancement/esql-scor…
carlosdelest Mar 6, 2025
c355bcc
Fix changelog
carlosdelest Mar 6, 2025
7457544
Fix test
carlosdelest Mar 10, 2025
e4758d3
Remove unnecessary method
carlosdelest Mar 10, 2025
e3e068c
Add missing capabilities to tests
carlosdelest Mar 10, 2025
5859c81
Merge remote-tracking branch 'origin/main' into enhancement/esql-scor…
carlosdelest Mar 10, 2025
0447824
Merge remote-tracking branch 'origin/main' into enhancement/esql-scor…
carlosdelest Mar 11, 2025
Refactor query evaluators to use subclasses instead of interfaces
carlosdelest committed Mar 5, 2025
commit 63ca98bf0fc224ee69bf813e8262f8844f83b90e
carlosdelest (Member, Author) commented:

This class is extracted from the previous LuceneQueryExpressionEvaluator and contains the base mechanism for executing Lucene queries over Pages.

Subclasses implement methods that decide what the resulting Block looks like and how to add matching and non-matching results to it.
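The design described above is a template method: the superclass owns the per-document evaluation loop, and subclasses only decide how results are appended. A minimal, self-contained sketch of that shape (simplified stand-in names, not the actual Elasticsearch or Lucene classes):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical sketch of the template-method design: the abstract class
// walks the doc ids and delegates result handling to subclasses.
abstract class QueryEvaluatorSketch<T> {
    // Base mechanism: iterate documents, append a match or no-match per doc.
    List<T> evaluate(int[] docs, Set<Integer> matches) {
        List<T> out = new ArrayList<>(docs.length);
        for (int doc : docs) {
            if (matches.contains(doc)) {
                appendMatch(out, doc);
            } else {
                appendNoMatch(out);
            }
        }
        return out;
    }

    protected abstract void appendMatch(List<T> out, int doc);

    protected abstract void appendNoMatch(List<T> out);
}

// A boolean-producing subclass, analogous in spirit to
// LuceneQueryExpressionEvaluator: true for matches, false otherwise.
class BooleanEvaluatorSketch extends QueryEvaluatorSketch<Boolean> {
    @Override
    protected void appendMatch(List<Boolean> out, int doc) {
        out.add(true);
    }

    @Override
    protected void appendNoMatch(List<Boolean> out) {
        out.add(false);
    }
}
```

The real classes additionally manage per-shard and per-segment Lucene state, but the division of responsibility is the same: the loop lives once in the superclass, and only the append behavior varies.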

@@ -16,8 +16,8 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BooleanVector;
@@ -32,7 +32,10 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.BiFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;

/**
* {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
@@ -41,26 +44,22 @@
* {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
* this evaluator is here to save the day.
*/
public abstract class LuceneQueryEvaluator implements Releasable {

public static final double NO_MATCH_SCORE = 0.0;
public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements Releasable {

public record ShardConfig(Query query, IndexSearcher searcher) {}

private final BlockFactory blockFactory;
private final ShardConfig[] shards;
private final BiFunction<BlockFactory, Integer, ScoreVectorBuilder> scoreVectorBuilderSupplier;

private ShardState[] perShardState = EMPTY_SHARD_STATES;
private final List<ShardState> perShardState;

protected LuceneQueryEvaluator(
BlockFactory blockFactory,
ShardConfig[] shards,
BiFunction<BlockFactory, Integer, ScoreVectorBuilder> scoreVectorBuilderSupplier
ShardConfig[] shards
) {
this.blockFactory = blockFactory;
this.shards = shards;
this.scoreVectorBuilderSupplier = scoreVectorBuilderSupplier;
this.perShardState = new ArrayList<>(Collections.nCopies(shards.length, null));
}

public Block executeQuery(Page page) {
@@ -115,7 +114,7 @@ private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException
int min = docs.docs().getInt(0);
int max = docs.docs().getInt(docs.getPositionCount() - 1);
int length = max - min + 1;
try (ScoreVectorBuilder scoreBuilder = scoreVectorBuilderSupplier.apply(blockFactory, length)) {
try (T scoreBuilder = createBuilder(blockFactory, length)) {
if (length == docs.getPositionCount() && length > 1) {
return segmentState.scoreDense(scoreBuilder, min, max);
}
@@ -143,8 +142,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
int prevShard = -1;
int prevSegment = -1;
SegmentState segmentState = null;
try (ScoreVectorBuilder scoreBuilder = scoreVectorBuilderSupplier.apply(blockFactory, docs.getPositionCount())) {
scoreBuilder.initVector();
try (T scoreBuilder = createBuilder(blockFactory, docs.getPositionCount())) {
for (int i = 0; i < docs.getPositionCount(); i++) {
int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
int segment = docs.segments().getInt(map[i]);
Expand All @@ -155,7 +153,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
prevSegment = segment;
}
if (segmentState.noMatch) {
scoreBuilder.appendNoMatch();
appendNoMatch(scoreBuilder);
} else {
segmentState.scoreSingleDocWithScorer(scoreBuilder, docs.docs().getInt(map[i]));
}
Expand All @@ -170,40 +168,39 @@ private Vector evalSlow(DocVector docs) throws IOException {
public void close() {
}

protected abstract ScoreMode scoreMode();

private ShardState shardState(int shard) throws IOException {
if (shard >= perShardState.length) {
perShardState = ArrayUtil.grow(perShardState, shard + 1);
} else if (perShardState[shard] != null) {
return perShardState[shard];
ShardState shardState = perShardState.get(shard);
if (shardState != null) {
return shardState;
}
perShardState[shard] = new ShardState(shards[shard]);
return perShardState[shard];
shardState = new ShardState(shards[shard]);
perShardState.set(shard, shardState);
return shardState;
}

private class ShardState {
private final Weight weight;
private final IndexSearcher searcher;
private SegmentState[] perSegmentState = EMPTY_SEGMENT_STATES;
private final List<SegmentState> perSegmentState;

ShardState(ShardConfig config) throws IOException {
weight = config.searcher.createWeight(config.query, scoreMode(), 1.0f);
searcher = config.searcher;
perSegmentState = new ArrayList<>(Collections.nCopies(searcher.getLeafContexts().size(), null));
}

SegmentState segmentState(int segment) throws IOException {
if (segment >= perSegmentState.length) {
perSegmentState = ArrayUtil.grow(perSegmentState, segment + 1);
} else if (perSegmentState[segment] != null) {
return perSegmentState[segment];
SegmentState segmentState = perSegmentState.get(segment);
if (segmentState != null) {
return segmentState;
}
perSegmentState[segment] = new SegmentState(weight, searcher.getLeafContexts().get(segment));
return perSegmentState[segment];
segmentState = new SegmentState(weight, searcher.getLeafContexts().get(segment));
perSegmentState.set(segment, segmentState);
return segmentState;
}
}

private static class SegmentState {
private class SegmentState {
private final Weight weight;
private final LeafReaderContext ctx;

@@ -244,9 +241,9 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
* Score a range using the {@link BulkScorer}. This should be faster
* than using {@link #scoreSparse} for dense doc ids.
*/
Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOException {
Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
if (noMatch) {
return scoreBuilder.createNoMatchVector();
return createNoMatchVector(blockFactory, max - min + 1);
}
if (bulkScorer == null || // The bulkScorer wasn't initialized
Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -255,10 +252,12 @@ Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOEx
bulkScorer = weight.bulkScorer(ctx);
if (bulkScorer == null) {
noMatch = true;
return scoreBuilder.createNoMatchVector();
return createNoMatchVector(blockFactory, max - min + 1);
}
}
try (DenseCollector collector = new DenseCollector(min, max, scoreBuilder)) {
try (DenseCollector<T> collector = new DenseCollector<>(min, max, scoreBuilder,
LuceneQueryEvaluator.this::appendNoMatch,
LuceneQueryEvaluator.this::appendMatch)) {
bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
return collector.build();
}
@@ -268,12 +267,11 @@ Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOEx
* Score a vector of doc ids using {@link Scorer}. If you have a dense range of
* doc ids it'd be faster to use {@link #scoreDense}.
*/
Vector scoreSparse(ScoreVectorBuilder scoreBuilder, IntVector docs) throws IOException {
Vector scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
initScorer(docs.getInt(0));
if (noMatch) {
return scoreBuilder.createNoMatchVector();
return createNoMatchVector(blockFactory, docs.getPositionCount());
}
scoreBuilder.initVector();
for (int i = 0; i < docs.getPositionCount(); i++) {
scoreSingleDocWithScorer(scoreBuilder, docs.getInt(i));
}
@@ -296,41 +294,47 @@ private void initScorer(int minDocId) throws IOException {
}
}

private void scoreSingleDocWithScorer(ScoreVectorBuilder builder, int doc) throws IOException {
private void scoreSingleDocWithScorer(T builder, int doc) throws IOException {
if (scorer.iterator().docID() == doc) {
builder.appendMatch(scorer);
appendMatch(builder, scorer);
} else if (scorer.iterator().docID() > doc) {
builder.appendNoMatch();
appendNoMatch(builder);
} else {
if (scorer.iterator().advance(doc) == doc) {
builder.appendMatch(scorer);
appendMatch(builder, scorer);
} else {
builder.appendNoMatch();
appendNoMatch(builder);
}
}
}
}

private static final ShardState[] EMPTY_SHARD_STATES = new ShardState[0];
private static final SegmentState[] EMPTY_SEGMENT_STATES = new SegmentState[0];

/**
* Collects matching information for dense range of doc ids. This assumes that
* doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
* which isn't documented, but @jpountz swears is true.
*/
static class DenseCollector implements LeafCollector, Releasable {
private final ScoreVectorBuilder scoreBuilder;
static class DenseCollector<U extends Vector.Builder> implements LeafCollector, Releasable {
private final U scoreBuilder;
private final int max;
private Scorable scorer;
private final Consumer<U> appendNoMatch;
private final CheckedBiConsumer<U, Scorable, IOException> appendMatch;

private Scorable scorer;
int next;

DenseCollector(int min, int max, ScoreVectorBuilder scoreBuilder) {
DenseCollector(
int min,
int max,
U scoreBuilder,
Consumer<U> appendNoMatch,
CheckedBiConsumer<U, Scorable, IOException> appendMatch
) {
this.scoreBuilder = scoreBuilder;
scoreBuilder.initVector();
this.max = max;
next = min;
this.appendNoMatch = appendNoMatch;
this.appendMatch = appendMatch;
}

@Override
@@ -341,9 +345,9 @@ public void setScorer(Scorable scorable) {
@Override
public void collect(int doc) throws IOException {
while (next++ < doc) {
scoreBuilder.appendNoMatch();
appendNoMatch.accept(scoreBuilder);
}
scoreBuilder.appendMatch(scorer);
appendMatch.accept(scoreBuilder, scorer);
}

public Vector build() {
@@ -353,7 +357,7 @@ public Vector build() {
@Override
public void finish() {
while (next++ <= max) {
scoreBuilder.appendNoMatch();
appendNoMatch.accept(scoreBuilder);
}
}

@@ -363,15 +367,13 @@ public void close() {
}
}

public interface ScoreVectorBuilder extends Releasable {
Vector createNoMatchVector();
protected abstract ScoreMode scoreMode();

void initVector();
protected abstract Vector createNoMatchVector(BlockFactory blockFactory, int size);

void appendNoMatch();
protected abstract T createBuilder(BlockFactory blockFactory, int size);

void appendMatch(Scorable scorer) throws IOException;
protected abstract void appendNoMatch(T builder);

Vector build();
}
protected abstract void appendMatch(T builder, Scorable scorer) throws IOException;
}
carlosdelest (Member, Author) commented:

Now the ExpressionEvaluator just overrides the methods it needs from the superclass.
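A useful consequence of this refactor (and the motivation of the PR) is that a score-producing evaluator can reuse the same base loop by overriding the append methods to emit doubles instead of booleans. A hedged sketch with simplified stand-in names — the real subclass works over Lucene `Scorable` scores and Block builders, not plain lists:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical base: same template-method shape as the evaluator above,
// but matches carry a score so scoring subclasses can use it.
abstract class EvaluatorBase<T> {
    List<T> evaluate(int[] docs, Map<Integer, Float> matchScores) {
        List<T> out = new ArrayList<>(docs.length);
        for (int doc : docs) {
            Float score = matchScores.get(doc);
            if (score != null) {
                appendMatch(out, score);
            } else {
                appendNoMatch(out);
            }
        }
        return out;
    }

    protected abstract void appendMatch(List<T> out, float score);

    protected abstract void appendNoMatch(List<T> out);
}

// Score-producing subclass: emits the match score, or 0.0 for non-matches
// (mirroring the NO_MATCH_SCORE constant seen in the diff).
class ScoreEvaluatorSketch extends EvaluatorBase<Double> {
    static final double NO_MATCH_SCORE = 0.0;

    @Override
    protected void appendMatch(List<Double> out, float score) {
        out.add((double) score);
    }

    @Override
    protected void appendNoMatch(List<Double> out) {
        out.add(NO_MATCH_SCORE);
    }
}
```

Subclassing rather than injecting a builder-supplier (the approach the diff replaces) keeps the type parameter `T extends Vector.Builder` visible in the class signature, so each subclass's builder type is checked at compile time.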

@@ -17,7 +17,8 @@
import org.elasticsearch.compute.data.Vector;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasables;

import java.io.IOException;

/**
* {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
@@ -26,15 +27,15 @@
* {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
* this evaluator is here to save the day.
*/
public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator implements EvalOperator.ExpressionEvaluator {

public static final double NO_MATCH_SCORE = 0.0;
public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanVector.Builder>
implements
EvalOperator.ExpressionEvaluator {

LuceneQueryExpressionEvaluator(
BlockFactory blockFactory,
ShardConfig[] shards
) {
super(blockFactory, shards, BooleanScoreVectorBuilder::new);
super(blockFactory, shards);
}

@Override
@@ -47,63 +48,30 @@ protected ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}

public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
private final ShardConfig[] shardConfigs;

public Factory(ShardConfig[] shardConfigs) {
this.shardConfigs = shardConfigs;
}

@Override
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
}
@Override
protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
return blockFactory.newConstantBooleanVector(false, size);
}

static class BooleanScoreVectorBuilder implements ScoreVectorBuilder {

private final BlockFactory blockFactory;
private final int size;

private BooleanVector.Builder builder;

BooleanScoreVectorBuilder(BlockFactory blockFactory, int size) {
this.blockFactory = blockFactory;
this.size = size;
}

@Override
public Vector createNoMatchVector() {
return blockFactory.newConstantBooleanVector(false, size);
}

@Override
public void initVector() {
assert builder == null : "initVector called twice";
builder = blockFactory.newBooleanVectorFixedBuilder(size);
}

@Override
public void appendNoMatch() {
assert builder != null : "appendNoMatch called before initVector";
builder.appendBoolean(false);
}
@Override
protected BooleanVector.Builder createBuilder(BlockFactory blockFactory, int size) {
return blockFactory.newBooleanVectorFixedBuilder(size);
}

@Override
public void appendMatch(Scorable scorer) {
assert builder != null : "appendMatch called before initVector";
builder.appendBoolean(true);
}
@Override
protected void appendNoMatch(BooleanVector.Builder builder) {
builder.appendBoolean(false);
}

@Override
public Vector build() {
assert builder != null : "build called before initVector";
return builder.build();
}
@Override
protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) throws IOException {
builder.appendBoolean(true);
}

public record Factory(ShardConfig[] shardConfigs) implements EvalOperator.ExpressionEvaluator.Factory {
@Override
public void close() {
Releasables.closeExpectNoException(builder);
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
return new LuceneQueryExpressionEvaluator(context.blockFactory(), shardConfigs);
}
}
}