Skip to content

ESQL: Enable physical plan verification #118114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/118114.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118114
summary: Enable physical plan verification
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ public Attribute(Source source, String name, Nullability nullability, @Nullable
this.nullability = nullability;
}

public static String rawTemporaryName(String inner, String outer, String suffix) {
return SYNTHETIC_ATTRIBUTE_NAME_PREFIX + inner + "$" + outer + "$" + suffix;
public static String rawTemporaryName(String... parts) {
var name = String.join("$", parts);
return name.isEmpty() || name.startsWith(SYNTHETIC_ATTRIBUTE_NAME_PREFIX) ? name : SYNTHETIC_ATTRIBUTE_NAME_PREFIX + name;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ protected List<Batch<PhysicalPlan>> batches() {
}

protected List<Batch<PhysicalPlan>> rules(boolean optimizeForEsSource) {
List<Rule<?, PhysicalPlan>> esSourceRules = new ArrayList<>(4);
List<Rule<?, PhysicalPlan>> esSourceRules = new ArrayList<>(6);
esSourceRules.add(new ReplaceSourceAttributes());

if (optimizeForEsSource) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
package org.elasticsearch.xpack.esql.optimizer;

import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.optimizer.rules.PlanConsistencyChecker;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;

Expand All @@ -31,10 +33,14 @@ private PhysicalVerifier() {}
/** Verifies the physical plan. */
public Collection<Failure> verify(PhysicalPlan plan) {
Set<Failure> failures = new LinkedHashSet<>();
Failures depFailures = new Failures();

plan.forEachDown(p -> {
// FIXME: re-enable
// DEPENDENCY_CHECK.checkPlan(p, failures);
if (p instanceof AggregateExec agg) {
var exclude = Expressions.references(agg.ordinalAttributes());
DEPENDENCY_CHECK.checkPlan(p, exclude, depFailures);
return;
}
if (p instanceof FieldExtractExec fieldExtractExec) {
Attribute sourceAttribute = fieldExtractExec.sourceAttribute();
if (sourceAttribute == null) {
Expand All @@ -48,8 +54,13 @@ public Collection<Failure> verify(PhysicalPlan plan) {
);
}
}
DEPENDENCY_CHECK.checkPlan(p, depFailures);
});

if (depFailures.hasFailures()) {
throw new IllegalStateException(depFailures.toString());
}

return failures;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ public class PlanConsistencyChecker<P extends QueryPlan<P>> {
* {@link org.elasticsearch.xpack.esql.common.Failure Failure}s to the {@link Failures} object.
*/
public void checkPlan(P p, Failures failures) {
    // Convenience overload: run the consistency check with no excluded attributes.
    checkPlan(p, AttributeSet.EMPTY, failures);
}

public void checkPlan(P p, AttributeSet exclude, Failures failures) {
AttributeSet refs = p.references();
AttributeSet input = p.inputSet();
AttributeSet missing = refs.subtract(input);
AttributeSet missing = refs.subtract(input).subtract(exclude);
// TODO: for Joins, we should probably check if the required fields from the left child are actually in the left child, not
// just any child (and analogously for the right child).
if (missing.isEmpty() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
Expand All @@ -22,7 +21,6 @@

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -54,18 +52,9 @@ public PhysicalPlan apply(PhysicalPlan plan) {
* it loads the field lazily. If we have more than one field we need to
* make sure the fields are loaded for the standard hash aggregator.
*/
if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
// CATEGORIZE requires the standard hash aggregator as well.
if (agg.groupings().get(0).anyMatch(e -> e instanceof Categorize) == false) {
var leaves = new LinkedList<>();
// TODO: this seems out of place
agg.aggregates()
.stream()
.filter(a -> agg.groupings().contains(a) == false)
.forEach(a -> leaves.addAll(a.collectLeaves()));
var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
missing.removeAll(Expressions.references(remove));
}
if (p instanceof AggregateExec agg) {
var ordinalAttributes = agg.ordinalAttributes();
missing.removeAll(Expressions.references(ordinalAttributes));
Comment on lines +55 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better.

}

// add extractor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -184,6 +187,24 @@ protected AttributeSet computeReferences() {
return mode.isInputPartial() ? new AttributeSet(intermediateAttributes) : Aggregate.computeReferences(aggregates, groupings);
}

/** Returns the attributes that can be loaded from ordinals -- no explicit extraction is needed */
public List<Attribute> ordinalAttributes() {
    // Fix: local variable was misspelled ("orginalAttributs") and a comment typo ("laoded") corrected.
    List<Attribute> ordinalAttributes = new ArrayList<>(groupings.size());
    // Ordinals can be leveraged just for a single grouping. If there are multiple groupings, fields need to be loaded for the
    // hash aggregator.
    // CATEGORIZE requires the standard hash aggregator as well.
    if (groupings().size() == 1 && groupings.get(0).anyMatch(e -> e instanceof Categorize) == false) {
        var leaves = new HashSet<>();
        // Collect the leaves of all aggregates that are not themselves groupings: a grouping that is also
        // referenced by an aggregate must still be extracted and therefore cannot be served from ordinals.
        aggregates.stream().filter(a -> groupings.contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves()));
        groupings.forEach(g -> {
            if (leaves.contains(g) == false) {
                ordinalAttributes.add((Attribute) g);
            }
        });
    }
    return ordinalAttributes;
}

@Override
public int hashCode() {
return Objects.hash(groupings, aggregates, mode, intermediateAttributes, estimatedRowSize, child());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
Expand Down Expand Up @@ -72,6 +73,12 @@ public boolean inBetweenAggs() {
return inBetweenAggs;
}

@Override
protected AttributeSet computeReferences() {
    // ExchangeExec references no input attributes: it only emits synthetic attributes "sourced" from remote
    // exchanges, so from the local plan's perspective it has an empty reference set.
    return AttributeSet.EMPTY;
}

@Override
public UnaryExec replaceChild(PhysicalPlan newChild) {
return new ExchangeExec(source(), output, inBetweenAggs, newChild);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,7 @@ public static Attribute extractSourceAttributesFrom(PhysicalPlan plan) {

@Override
protected AttributeSet computeReferences() {
    // Only the doc-source attribute is consumed as input; the extracted fields are outputs, not references.
    if (sourceAttribute == null) {
        return AttributeSet.EMPTY;
    }
    return new AttributeSet(sourceAttribute);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ public List<Attribute> addedFields() {
public List<Attribute> output() {
if (lazyOutput == null) {
lazyOutput = new ArrayList<>(left().output());
for (Attribute attr : addedFields) {
lazyOutput.add(attr);
}
var addedFieldsNames = addedFields.stream().map(Attribute::name).toList();
lazyOutput.removeIf(a -> addedFieldsNames.contains(a.name()));
lazyOutput.addAll(addedFields);
Comment on lines +96 to +98
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, this is very interesting. Thanks for catching this.

I'm wondering if enforcing unique names for PhysicalPlan outputs is actually the right way to go. It is what we assumed so far, so I don't wanna depart from here, but there's also a reason to be more lenient:

In PhysicalPlan land, this output method was technically correct, because it represented exactly what the physical operators will do: they're gonna append the blocks for the added fields to the pages that are processed. The blocks corresponding to shadowed attributes/channels will not be stripped from incoming pages.

This means that the output layout technically has duplicate attribute names. This change just hides them and makes them inaccessible during the physical planning stage.

FWIW, EvalExec has the exact same situation. Eval performs shadowing, but EvalOperator can only ever append blocks, so the output pages actually still contain the blocks corresponding to shadowed attributes.

}
return lazyOutput;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ private void aggregatesToFactory(
// coordinator/exchange phase
else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(aggregateFunction);
sourceAttr = aggregateMapper.mapGrouping(ne);
} else {
sourceAttr = aggregateMapper.mapNonGrouping(aggregateFunction);
sourceAttr = aggregateMapper.mapNonGrouping(ne);
}
} else {
throw new EsqlIllegalArgumentException("illegal aggregation mode");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
Expand Down Expand Up @@ -91,7 +92,7 @@ final class AggregateMapper {
private record AggDef(Class<?> aggClazz, String type, String extra, boolean grouping) {}

/** Map of AggDef types to intermediate named expressions. */
private static final Map<AggDef, List<IntermediateStateDesc>> mapper = AGG_FUNCTIONS.stream()
private static final Map<AggDef, List<IntermediateStateDesc>> MAPPER = AGG_FUNCTIONS.stream()
.flatMap(AggregateMapper::typeAndNames)
.flatMap(AggregateMapper::groupingAndNonGrouping)
.collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));
Expand All @@ -103,50 +104,57 @@ private record AggDef(Class<?> aggClazz, String type, String extra, boolean grou
cache = new HashMap<>();
}

public List<NamedExpression> mapNonGrouping(List<? extends Expression> aggregates) {
public List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) {
return doMapping(aggregates, false);
}

public List<NamedExpression> mapNonGrouping(Expression aggregate) {
public List<NamedExpression> mapNonGrouping(NamedExpression aggregate) {
return map(aggregate, false).toList();
}

public List<NamedExpression> mapGrouping(List<? extends Expression> aggregates) {
public List<NamedExpression> mapGrouping(List<? extends NamedExpression> aggregates) {
return doMapping(aggregates, true);
}

private List<NamedExpression> doMapping(List<? extends Expression> aggregates, boolean grouping) {
private List<NamedExpression> doMapping(List<? extends NamedExpression> aggregates, boolean grouping) {
AttributeMap<NamedExpression> attrToExpressions = new AttributeMap<>();
aggregates.stream().flatMap(agg -> map(agg, grouping)).forEach(ne -> attrToExpressions.put(ne.toAttribute(), ne));
aggregates.stream().flatMap(ne -> map(ne, grouping)).forEach(ne -> attrToExpressions.put(ne.toAttribute(), ne));
return attrToExpressions.values().stream().toList();
}

public List<NamedExpression> mapGrouping(Expression aggregate) {
public List<NamedExpression> mapGrouping(NamedExpression aggregate) {
return map(aggregate, true).toList();
}

private Stream<NamedExpression> map(Expression aggregate, boolean grouping) {
return cache.computeIfAbsent(Alias.unwrap(aggregate), aggKey -> computeEntryForAgg(aggKey, grouping)).stream();
private Stream<NamedExpression> map(NamedExpression ne, boolean grouping) {
return cache.computeIfAbsent(Alias.unwrap(ne), aggKey -> computeEntryForAgg(ne.name(), aggKey, grouping)).stream();
}

private static List<NamedExpression> computeEntryForAgg(Expression aggregate, boolean grouping) {
var aggDef = aggDefOrNull(aggregate, grouping);
if (aggDef != null) {
var is = getNonNull(aggDef);
var exp = isToNE(is).toList();
return exp;
private static List<NamedExpression> computeEntryForAgg(String aggAlias, Expression aggregate, boolean grouping) {
if (aggregate instanceof AggregateFunction aggregateFunction) {
return entryForAgg(aggAlias, aggregateFunction, grouping);
}
if (aggregate instanceof FieldAttribute || aggregate instanceof MetadataAttribute || aggregate instanceof ReferenceAttribute) {
// This condition is a little pedantic, but do we expected other expressions here? if so, then add them
// This condition is a little pedantic, but do we expect other expressions here? if so, then add them
return List.of();
} else {
throw new EsqlIllegalArgumentException("unknown agg: " + aggregate.getClass() + ": " + aggregate);
}
throw new EsqlIllegalArgumentException("unknown agg: " + aggregate.getClass() + ": " + aggregate);
}

/**
 * Resolves the intermediate-state attributes for a single aggregate function. The resulting attributes are
 * qualified with {@code aggAlias} (via a synthetic temporary name) so intermediate state of different
 * aggregations cannot collide by name.
 */
private static List<NamedExpression> entryForAgg(String aggAlias, AggregateFunction aggregateFunction, boolean grouping) {
    var aggDef = new AggDef(
        aggregateFunction.getClass(),
        dataTypeToString(aggregateFunction.field().dataType(), aggregateFunction.getClass()),
        // SpatialCentroid uses a dedicated "SourceValues" flavor of intermediate state; all others use the default.
        aggregateFunction instanceof SpatialCentroid ? "SourceValues" : "",
        grouping
    );
    // getNonNull fails with an informative message if no intermediate state is registered for this AggDef.
    var is = getNonNull(aggDef);
    return isToNE(is, aggAlias).toList();
}

/** Gets the agg from the mapper - wrapper around map::get for more informative failure.*/
private static List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
var l = mapper.get(aggDef);
var l = MAPPER.get(aggDef);
if (l == null) {
throw new EsqlIllegalArgumentException("Cannot find intermediate state for: " + aggDef);
}
Expand Down Expand Up @@ -199,18 +207,6 @@ private static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, Tuple<Strin
}
}

private static AggDef aggDefOrNull(Expression aggregate, boolean grouping) {
if (aggregate instanceof AggregateFunction aggregateFunction) {
return new AggDef(
aggregateFunction.getClass(),
dataTypeToString(aggregateFunction.field().dataType(), aggregateFunction.getClass()),
aggregate instanceof SpatialCentroid ? "SourceValues" : "",
grouping
);
}
return null;
}

/** Retrieves the intermediate state description for a given class, type, and grouping. */
private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
try {
Expand Down Expand Up @@ -257,15 +253,15 @@ private static String determinePackageName(Class<?> clazz) {
}

/** Maps intermediate state description to named expressions. */
private static Stream<NamedExpression> isToNE(List<IntermediateStateDesc> intermediateStateDescs) {
private static Stream<NamedExpression> isToNE(List<IntermediateStateDesc> intermediateStateDescs, String aggAlias) {
return intermediateStateDescs.stream().map(is -> {
final DataType dataType;
if (Strings.isEmpty(is.dataType())) {
dataType = toDataType(is.type());
} else {
dataType = DataType.fromEs(is.dataType());
}
return new ReferenceAttribute(Source.EMPTY, is.name(), dataType);
return new ReferenceAttribute(Source.EMPTY, Attribute.rawTemporaryName(aggAlias, is.name()), dataType);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ public void testCountFieldWithEval() {
var esStatsQuery = as(exg.child(), EsStatsQueryExec.class);

assertThat(esStatsQuery.limit(), is(nullValue()));
assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
assertThat(Expressions.names(esStatsQuery.output()), contains("$$c$count", "$$c$seen"));
var stat = as(esStatsQuery.stats().get(0), Stat.class);
assertThat(stat.query(), is(QueryBuilders.existsQuery("salary")));
}
Expand All @@ -276,7 +276,7 @@ public void testCountOneFieldWithFilter() {
var exchange = as(agg.child(), ExchangeExec.class);
var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class);
assertThat(esStatsQuery.limit(), is(nullValue()));
assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
assertThat(Expressions.names(esStatsQuery.output()), contains("$$c$count", "$$c$seen"));
var stat = as(esStatsQuery.stats().get(0), Stat.class);
Source source = new Source(2, 8, "salary > 1000");
var exists = QueryBuilders.existsQuery("salary");
Expand Down Expand Up @@ -386,7 +386,7 @@ public void testAnotherCountAllWithFilter() {
var exchange = as(agg.child(), ExchangeExec.class);
var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class);
assertThat(esStatsQuery.limit(), is(nullValue()));
assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
assertThat(Expressions.names(esStatsQuery.output()), contains("$$c$count", "$$c$seen"));
var source = ((SingleValueQuery.Builder) esStatsQuery.query()).source();
var expected = wrapWithSingleQuery(query, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", source);
assertThat(expected.toString(), is(esStatsQuery.query().toString()));
Expand Down Expand Up @@ -997,7 +997,7 @@ public boolean exists(String field) {
var exchange = as(agg.child(), ExchangeExec.class);
assertThat(exchange.inBetweenAggs(), is(true));
var localSource = as(exchange.child(), LocalSourceExec.class);
assertThat(Expressions.names(localSource.output()), contains("count", "seen"));
assertThat(Expressions.names(localSource.output()), contains("$$c$count", "$$c$seen"));
}

/**
Expand Down Expand Up @@ -1152,7 +1152,7 @@ public void testIsNotNull_TextField_Pushdown_WithCount() {
var exg = as(agg.child(), ExchangeExec.class);
var esStatsQuery = as(exg.child(), EsStatsQueryExec.class);
assertThat(esStatsQuery.limit(), is(nullValue()));
assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
assertThat(Expressions.names(esStatsQuery.output()), contains("$$c$count", "$$c$seen"));
var stat = as(esStatsQuery.stats().get(0), Stat.class);
assertThat(stat.query(), is(QueryBuilders.existsQuery("job")));
}
Expand Down
Loading