diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index 280e6274d84de..ab8d4a61d753a 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -21,10 +21,13 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.OrdinalBytesRefVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.DriverContext; @@ -275,11 +278,18 @@ private static Block dataBlock(int groups, String dataType) { int blockLength = blockLength(groups); return switch (dataType) { case BYTES_REF -> { - try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) { + try ( + BytesRefVector.Builder dict = blockFactory.newBytesRefVectorBuilder(blockLength); + IntVector.Builder ords = blockFactory.newIntVectorBuilder(blockLength) + ) { + final int dictLength = Math.min(blockLength, KEYWORDS.length); + for (int i = 0; i < dictLength; i++) { + dict.appendBytesRef(KEYWORDS[i]); + } for (int i = 0; i < blockLength; i++) { - builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]); + ords.appendInt(i % dictLength); } - yield builder.build(); + yield new OrdinalBytesRefVector(ords.build(), dict.build()).asBlock(); } } case INT -> { diff --git a/docs/changelog/127849.yaml b/docs/changelog/127849.yaml new file mode 100644 index 0000000000000..4d5b747b35011 --- /dev/null +++ b/docs/changelog/127849.yaml @@ -0,0 +1,5 @@ +pr: 127849 +summary: Optimize ordinal inputs in Values aggregation +area: "ES|QL" +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 3d8e5b1dcb756..0182b29b71271 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -35,6 +35,7 @@ import static java.util.stream.Collectors.joining; import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize; +import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod; import static org.elasticsearch.compute.gen.Methods.requireAnyArgs; import static org.elasticsearch.compute.gen.Methods.requireAnyType; import static org.elasticsearch.compute.gen.Methods.requireArgs; @@ -332,10 +333,32 @@ private MethodSpec prepareProcessPage() { builder.beginControlFlow("if (valuesBlock.mayHaveNulls())"); builder.addStatement("state.enableGroupIdTracking(seenGroupIds)"); builder.endControlFlow(); - builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, 
groupIds, valuesBlock$L)", extra))); + if (shouldWrapAddInput(blockType(aggParam.type()))) { + builder.addStatement( + "var addInput = $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)) + ); + builder.addStatement("return $T.wrapAddInput(addInput, state, valuesBlock)", declarationType); + } else { + builder.addStatement( + "return $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)) + ); + } } builder.endControlFlow(); - builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra))); + if (shouldWrapAddInput(vectorType(aggParam.type()))) { + builder.addStatement( + "var addInput = $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)) + ); + builder.addStatement("return $T.wrapAddInput(addInput, state, valuesVector)", declarationType); + } else { + builder.addStatement( + "return $L", + addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)) + ); + } return builder.build(); } @@ -525,6 +548,15 @@ private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVar warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable)); } + private boolean shouldWrapAddInput(TypeName valuesType) { + return optionalStaticMethod( + declarationType, + requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT), + requireName("wrapAddInput"), + requireArgs(requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT), requireType(aggState.declaredType()), requireType(valuesType)) + ) != null; + } + private void warningsBlock(MethodSpec.Builder builder, Runnable block) { if (warnExceptions.isEmpty() == false) { builder.beginControlFlow("try"); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java index f2fa7b8084448..b94eb13433a15 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java @@ -59,6 +59,23 @@ static ExecutableElement requireStaticMethod( TypeMatcher returnTypeMatcher, NameMatcher nameMatcher, ArgumentMatcher argumentMatcher + ) { + ExecutableElement method = optionalStaticMethod(declarationType, returnTypeMatcher, nameMatcher, argumentMatcher); + if (method == null) { + var message = nameMatcher.names.size() == 1 ? 
"Requires method: " : "Requires one of methods: "; + var signatures = nameMatcher.names.stream() + .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")") + .collect(joining(" or ")); + throw new IllegalArgumentException(message + signatures); + } + return method; + } + + static ExecutableElement optionalStaticMethod( + TypeElement declarationType, + TypeMatcher returnTypeMatcher, + NameMatcher nameMatcher, + ArgumentMatcher argumentMatcher ) { return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream()) .filter(method -> method.getModifiers().contains(Modifier.STATIC)) @@ -66,13 +83,7 @@ static ExecutableElement requireStaticMethod( .filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType()))) .filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList())) .findFirst() - .orElseThrow(() -> { - var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: "; - var signatures = nameMatcher.names.stream() - .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")") - .collect(joining(" or ")); - return new IllegalArgumentException(message + signatures); - }); + .orElse(null); } static NameMatcher requireName(String... names) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index f326492664fb8..ca7a68e333960 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -56,6 +57,22 @@ public static GroupingState initGrouping(BigArrays bigArrays) { return new GroupingState(bigArrays); } + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefBlock values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } + + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefVector values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } + public static void combine(GroupingState state, int groupId, BytesRef v) { state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); } @@ -127,8 +144,8 @@ public void close() { * collector operation. But at least it's fairly simple. 
*/ public static class GroupingState implements GroupingAggregatorState { - private final LongLongHash values; - private final BytesRefHash bytes; + final LongLongHash values; + BytesRefHash bytes; private GroupingState(BigArrays bigArrays) { LongLongHash _values = null; diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java index 6db44ffce8faf..142ab77e725c0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java @@ -63,7 +63,7 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG if (valuesBlock.mayHaveNulls()) { state.enableGroupIdTracking(seenGroupIds); } - return new GroupingAggregatorFunction.AddInput() { + var addInput = new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock); @@ -78,8 +78,9 @@ public void add(int positionOffset, IntVector groupIds) { public void close() { } }; + return ValuesBytesRefAggregator.wrapAddInput(addInput, state, valuesBlock); } - return new GroupingAggregatorFunction.AddInput() { + var addInput = new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector); @@ -94,6 +95,7 @@ public void add(int positionOffset, IntVector groupIds) { public void close() { } }; + return ValuesBytesRefAggregator.wrapAddInput(addInput, state, valuesVector); } private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java new file mode 100644 index 0000000000000..84047875f06ad --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; +import org.elasticsearch.core.Releasables; + +final class ValuesBytesRefAggregators { + static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + ValuesBytesRefAggregator.GroupingState state, + BytesRefBlock values + ) { + OrdinalBytesRefBlock valuesOrdinal = values.asOrdinals(); + if (valuesOrdinal == null) { + return delegate; + } + BytesRefVector dict = valuesOrdinal.getDictionaryVector(); + final IntVector hashIds; + BytesRef spare = new BytesRef(); + try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { + for (int p = 0; p < dict.getPositionCount(); p++) { + hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); + } + hashIds = hashIdsBuilder.build(); + } + IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock(); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + if (ordinalIds.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + } + } + } + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int groupId = groupIds.getInt(groupPosition); + if (ordinalIds.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + } + } + } + + @Override + public void close() { + Releasables.close(hashIds, delegate); + } + }; + } + + static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + ValuesBytesRefAggregator.GroupingState state, + BytesRefVector values + ) { + var valuesOrdinal = values.asOrdinals(); + if (valuesOrdinal == null) { + return delegate; + } + BytesRefVector dict = valuesOrdinal.getDictionaryVector(); + final IntVector hashIds; + BytesRef spare = new BytesRef(); + try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) { + for (int p = 0; p < dict.getPositionCount(); p++) { + hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare))))); + 
} + hashIds = hashIdsBuilder.build(); + } + var ordinalIds = valuesOrdinal.getOrdinalsVector(); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + if (groupIds.isNull(groupPosition)) { + continue; + } + int groupStart = groupIds.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groupIds.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groupIds.getInt(g); + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + } + } + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { + int groupId = groupIds.getInt(groupPosition); + state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + } + } + + @Override + public void close() { + Releasables.close(hashIds, delegate); + } + }; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 68c6a8640cbd0..f94dd86df7b31 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -28,7 +28,10 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -$if(int||double||float||BytesRef)$ +$if(BytesRef)$ +import org.elasticsearch.compute.data.$Type$Block; +import org.elasticsearch.compute.data.$Type$Vector; +$elseif(int||double||float)$ import org.elasticsearch.compute.data.$Type$Block; $endif$ import org.elasticsearch.compute.data.IntVector; @@ -87,6 +90,24 @@ $endif$ return new GroupingState(bigArrays); } +$if(BytesRef)$ + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefBlock values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } + + public static GroupingAggregatorFunction.AddInput wrapAddInput( + GroupingAggregatorFunction.AddInput delegate, + GroupingState state, + BytesRefVector values + ) { + return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values); + } +$endif$ + public static void combine(GroupingState state, int groupId, $type$ v) { $if(long)$ state.values.add(groupId, v); @@ -234,8 +255,8 @@ $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ - private final LongLongHash values; - private final BytesRefHash bytes; + final LongLongHash values; + BytesRefHash bytes; $elseif(int||float)$ private final LongHash values;
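
Reviewer note (not part of the patch): the heart of this change is the new wrapAddInput hook. When a BytesRef column arrives ordinal-encoded (asOrdinals() returns a dictionary vector plus per-row ordinal ids), the wrapper hashes each distinct dictionary entry into the grouping state's BytesRefHash once, and then every row only looks up its precomputed hash id instead of re-hashing the BytesRef per row. The standalone sketch below illustrates that idea with plain Java collections; all class and method names here (ValueHash, GroupingState, addOrdinals, ...) are illustrative stand-ins, not the compute engine's OrdinalBytesRefBlock / BytesRefHash / LongLongHash API.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Standalone sketch of the ordinal fast path used by the VALUES aggregator.
 * Names and types are illustrative only; the real implementation operates on
 * OrdinalBytesRefBlock/Vector, BytesRefHash and LongLongHash inside the
 * compute engine.
 */
public class OrdinalValuesSketch {

    /** Stand-in for BytesRefHash: deduplicates values and assigns stable ids. */
    static final class ValueHash {
        private final Map<String, Integer> ids = new HashMap<>();
        private final List<String> values = new ArrayList<>();

        int add(String value) {
            return ids.computeIfAbsent(value, v -> {
                values.add(v);
                return values.size() - 1;
            });
        }
    }

    /** Stand-in for the (groupId, valueId) pairs kept by the grouping state. */
    static final class GroupingState {
        final ValueHash bytes = new ValueHash();
        final Set<Long> values = new HashSet<>(); // packs (groupId, valueId)

        void add(int groupId, int valueId) {
            values.add(((long) groupId << 32) | (valueId & 0xFFFFFFFFL));
        }
    }

    /**
     * Slow path: every row re-hashes its string, even though an ordinal-encoded
     * column repeats the same handful of dictionary entries over and over.
     */
    static void addRowByRow(GroupingState state, int[] groupIds, String[] rowValues) {
        for (int row = 0; row < rowValues.length; row++) {
            state.add(groupIds[row], state.bytes.add(rowValues[row]));
        }
    }

    /**
     * Ordinal fast path: hash each dictionary entry once, then map every row's
     * ordinal through the precomputed ids. This mirrors what wrapAddInput does
     * when the incoming block or vector exposes an ordinal view.
     */
    static void addOrdinals(GroupingState state, int[] groupIds, String[] dictionary, int[] ordinals) {
        int[] hashIds = new int[dictionary.length];
        for (int d = 0; d < dictionary.length; d++) {
            hashIds[d] = state.bytes.add(dictionary[d]); // one hash per distinct value
        }
        for (int row = 0; row < ordinals.length; row++) {
            state.add(groupIds[row], hashIds[ordinals[row]]); // no per-row hashing
        }
    }

    public static void main(String[] args) {
        String[] dictionary = { "alpha", "beta" };
        int[] ordinals = { 0, 1, 0, 0, 1 };   // five rows, two distinct values
        int[] groupIds = { 0, 0, 1, 1, 1 };

        GroupingState state = new GroupingState();
        addOrdinals(state, groupIds, dictionary, ordinals);
        System.out.println("distinct (group, value) pairs: " + state.values.size()); // 4
    }
}

In the patch itself, the same precomputation is wired in through the generated ValuesBytesRefGroupingAggregatorFunction, which now routes its AddInput through ValuesBytesRefAggregator.wrapAddInput; the code generator only emits that wrapping when the aggregator declares a matching static wrapAddInput method (the new optionalStaticMethod check), and the benchmark was changed to feed ordinal-encoded BytesRef blocks so this path is actually exercised.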