datafusion_optimizer/
single_distinct_to_groupby.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`SingleDistinctToGroupBy`] replaces `AGG(DISTINCT ..)` with `AGG(..) GROUP BY ..`
19
20use std::sync::Arc;
21
22use crate::optimizer::ApplyOrder;
23use crate::{OptimizerConfig, OptimizerRule};
24
25use datafusion_common::{
26    internal_err, tree_node::Transformed, DataFusionError, HashSet, Result,
27};
28use datafusion_expr::builder::project;
29use datafusion_expr::expr::AggregateFunctionParams;
30use datafusion_expr::{
31    col,
32    expr::AggregateFunction,
33    logical_plan::{Aggregate, LogicalPlan},
34    Expr,
35};
36
/// Optimizer rule that replaces `AGG(DISTINCT x)` with a two-level aggregation.
///
/// The inner aggregate groups by the original group-by keys plus the shared
/// `DISTINCT` argument, which de-duplicates it. The outer aggregate then applies
/// each aggregate function without `DISTINCT` per original group.
///
///  ```text
///    Before:
///    SELECT a, count(DISTINCT b), sum(c)
///    FROM t
///    GROUP BY a
///
///    After:
///    SELECT a, count(alias1), sum(alias2)
///    FROM (
///      SELECT a, b as alias1, sum(c) as alias2
///      FROM t
///      GROUP BY a, b
///    )
///    GROUP BY a
///  ```
#[derive(Debug, Default)]
pub struct SingleDistinctToGroupBy {}

/// Name given to the shared `DISTINCT` argument inside the inner aggregate.
const SINGLE_DISTINCT_ALIAS: &str = "alias1";

impl SingleDistinctToGroupBy {
    /// Creates the rule; equivalent to `SingleDistinctToGroupBy::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
64
65/// Check whether all aggregate exprs are distinct on a single field.
66fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
67    let mut fields_set = HashSet::new();
68    let mut aggregate_count = 0;
69    for expr in aggr_expr {
70        if let Expr::AggregateFunction(AggregateFunction {
71            func,
72            params:
73                AggregateFunctionParams {
74                    distinct,
75                    args,
76                    filter,
77                    order_by,
78                    null_treatment: _,
79                },
80        }) = expr
81        {
82            if filter.is_some() || order_by.is_some() {
83                return Ok(false);
84            }
85            aggregate_count += 1;
86            if *distinct {
87                for e in args {
88                    fields_set.insert(e);
89                }
90            } else if func.name() != "sum"
91                && func.name().to_lowercase() != "min"
92                && func.name().to_lowercase() != "max"
93            {
94                return Ok(false);
95            }
96        } else {
97            return Ok(false);
98        }
99    }
100    Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
101}
102
103/// Check if the first expr is [Expr::GroupingSet].
104fn contains_grouping_set(expr: &[Expr]) -> bool {
105    matches!(expr.first(), Some(Expr::GroupingSet(_)))
106}
107
impl OptimizerRule for SingleDistinctToGroupBy {
    /// Returns the rule name, `"single_distinct_aggregation_to_group_by"`.
    fn name(&self) -> &str {
        "single_distinct_aggregation_to_group_by"
    }

    /// Requests `ApplyOrder::TopDown` traversal from the optimizer.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// This rule implements [`OptimizerRule::rewrite`].
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Rewrites an eligible [`LogicalPlan::Aggregate`] into
    /// `Projection -> Aggregate (outer) -> Aggregate (inner) -> input`.
    ///
    /// The rewrite fires only when `is_single_distinct_agg` accepts the
    /// aggregate expressions and the group-by list is not a grouping set. Every
    /// other plan is returned unchanged via `Transformed::no`.
    ///
    /// - Inner aggregate: groups by the original group-by expressions (non-column
    ///   ones aliased as `group_alias_{i}`) plus the distinct argument aliased as
    ///   `alias1`. It computes each non-distinct aggregate aliased as `alias2`,
    ///   `alias3`, ... in input order.
    /// - Outer aggregate: groups by the (aliased) original keys and applies each
    ///   aggregate function, without `DISTINCT`, to the matching alias column.
    /// - Projection: restores the original output names and qualifiers taken
    ///   from the input aggregate's schema.
    ///
    /// # Errors
    /// Propagates errors from the eligibility check and from building the new
    /// plan nodes. Returns an internal error if a `DISTINCT` aggregate does not
    /// have exactly one argument.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
        match plan {
            LogicalPlan::Aggregate(Aggregate {
                input,
                aggr_expr,
                schema,
                group_expr,
                ..
            }) if is_single_distinct_agg(&aggr_expr)?
                && !contains_grouping_set(&group_expr) =>
            {
                // Output fields of the original aggregate are laid out as
                // [group exprs..., aggr exprs...]; this offset locates the
                // aggregate fields in `schema` further below.
                let group_size = group_expr.len();
                // Split each original group-by expr into:
                // - the expression used in the inner aggregate's GROUP BY, and
                // - the expression the outer aggregate groups by, paired with the
                //   original (qualifier, name) when an alias had to be introduced.
                let (mut inner_group_exprs, out_group_expr_with_alias): (
                    Vec<Expr>,
                    Vec<(Expr, _)>,
                ) = group_expr
                    .into_iter()
                    .enumerate()
                    .map(|(i, group_expr)| {
                        if let Expr::Column(_) = group_expr {
                            // A plain column is also an output column of the inner
                            // aggregate, so it can be reused as-is.
                            (group_expr.clone(), (group_expr, None))
                        } else {
                            // A complex expression is aliased in the inner aggregate
                            // so that parent operators can refer to it by name.
                            // Consider the plan below.
                            //
                            // Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]
                            // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]
                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
                            //
                            // The first aggregate (from the bottom) refers to the `test.a` column.
                            // The second aggregate refers to the `group_alias_0` column, which is a valid field of the first aggregate.

                            // Without the alias, the plan above would instead be:
                            //
                            // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]
                            // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]
                            // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
                            //
                            // There the second aggregate refers to the `test.a + Int32(1)` expression, but its input has no `test.a` column.
                            let alias_str = format!("group_alias_{i}");
                            // Remember the original output name so the final projection can restore it.
                            let (qualifier, field) = schema.qualified_field(i);
                            (
                                group_expr.alias(alias_str.clone()),
                                (col(alias_str), Some((qualifier, field.name()))),
                            )
                        }
                    })
                    .unzip();

                // Rewrite every aggregate for the outer aggregate:
                // - DISTINCT aggregates read the shared `alias1` column;
                // - other aggregates are computed in the inner aggregate as
                //   `alias2`, `alias3`, ... and re-aggregated in the outer one.
                let mut index = 1;
                // Schema names of distinct arguments already added to the inner
                // GROUP BY, so the shared argument is only grouped on once.
                let mut group_fields_set = HashSet::new();
                let mut inner_aggr_exprs = vec![];
                let outer_aggr_exprs = aggr_expr
                    .into_iter()
                    .map(|aggr_expr| match aggr_expr {
                        Expr::AggregateFunction(AggregateFunction {
                            func,
                            params: AggregateFunctionParams { mut args, distinct, .. }
                        }) => {
                            if distinct {
                                if args.len() != 1 {
                                    return internal_err!("DISTINCT aggregate should have exactly one argument");
                                }
                                let arg = args.swap_remove(0);

                                if group_fields_set.insert(arg.schema_name().to_string()) {
                                    inner_group_exprs
                                        .push(arg.alias(SINGLE_DISTINCT_ALIAS));
                                }
                                // The inner GROUP BY already de-duplicates the
                                // argument, so DISTINCT is dropped in the outer call.
                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
                                    func,
                                    vec![col(SINGLE_DISTINCT_ALIAS)],
                                    false, // intentional to remove distinct here
                                    None,
                                    None,
                                    None,
                                )))
                            } else {
                                // Non-distinct aggregate: split it into two phases.
                                // The inner aggregate computes `func(args) AS aliasN`
                                // and the outer one applies `func` to `aliasN`.
                                index += 1;
                                let alias_str = format!("alias{index}");
                                inner_aggr_exprs.push(
                                    Expr::AggregateFunction(AggregateFunction::new_udf(
                                        Arc::clone(&func),
                                        args,
                                        false,
                                        None,
                                        None,
                                        None,
                                    ))
                                    .alias(&alias_str),
                                );
                                Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
                                    func,
                                    vec![col(&alias_str)],
                                    false,
                                    None,
                                    None,
                                    None,
                                )))
                            }
                        }
                        // Not reached for plans accepted by `is_single_distinct_agg`,
                        // which admits only aggregate functions; kept for exhaustiveness.
                        _ => Ok(aggr_expr),
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Build the inner aggregate: GROUP BY original keys + distinct argument.
                let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
                    input,
                    inner_group_exprs,
                    inner_aggr_exprs,
                )?);

                let outer_group_exprs = out_group_expr_with_alias
                    .iter()
                    .map(|(expr, _)| expr.clone())
                    .collect();

                // Final projection so the output keeps the same names and
                // qualifiers as the original aggregate. Two kinds of alias are
                // restored here:
                // - aliased group-by expressions (`group_alias_{i}`)
                // - rewritten aggregate expressions
                let alias_expr: Vec<_> = out_group_expr_with_alias
                    .into_iter()
                    .map(|(group_expr, original_name)| match original_name {
                        Some((qualifier, name)) => {
                            group_expr.alias_qualified(qualifier.cloned(), name)
                        }
                        None => group_expr,
                    })
                    .chain(outer_aggr_exprs.iter().cloned().enumerate().map(
                        |(idx, expr)| {
                            let idx = idx + group_size;
                            let (qualifier, field) = schema.qualified_field(idx);
                            expr.alias_qualified(qualifier.cloned(), field.name())
                        },
                    ))
                    .collect();

                let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
                    Arc::new(inner_agg),
                    outer_group_exprs,
                    outer_aggr_exprs,
                )?);
                Ok(Transformed::yes(project(outer_aggr, alias_expr)?))
            }
            _ => Ok(Transformed::no(plan)),
        }
    }
}
279
#[cfg(test)]
mod tests {
    // Snapshot tests for `SingleDistinctToGroupBy`. Each test builds an
    // aggregate over `test_table_scan()` (table `test` with columns a, b, c) and
    // compares the optimized plan, printed with schemas, to an inline snapshot.
    use super::*;
    use crate::assert_optimized_plan_eq_display_indent_snapshot;
    use crate::test::*;
    use datafusion_expr::expr::GroupingSet;
    use datafusion_expr::ExprFunctionExt;
    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum};
    use datafusion_functions_aggregate::min_max::max_udaf;
    use datafusion_functions_aggregate::sum::sum_udaf;

    /// Builds `max(DISTINCT expr)` with no filter, ordering or null treatment.
    fn max_distinct(expr: Expr) -> Expr {
        Expr::AggregateFunction(AggregateFunction::new_udf(
            max_udaf(),
            vec![expr],
            true,
            None,
            None,
            None,
        ))
    }

    /// Optimizes `$plan` with a `SingleDistinctToGroupBy` rule via
    /// `assert_optimized_plan_eq_display_indent_snapshot!` and compares the
    /// result with the inline snapshot `$expected`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(SingleDistinctToGroupBy::new());
            assert_optimized_plan_eq_display_indent_snapshot!(
                rule,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn not_exist_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])?
            .build()?;

        // Not rewritten: there is no DISTINCT aggregate.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn single_distinct() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(col("b"))])?
            .build()?;

        // Rewritten: the distinct argument becomes the inner GROUP BY key `alias1`.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]
            Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_grouping_set() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
            vec![col("a")],
            vec![col("b")],
        ]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized: the group-by is a GROUPING SETS expression.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_cube() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized: the group-by is a CUBE expression.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
    #[test]
    fn single_distinct_and_rollup() -> Result<()> {
        let table_scan = test_table_scan()?;

        let grouping_set =
            Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
            .build()?;

        // Should not be optimized: the group-by is a ROLLUP expression.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn single_distinct_expr() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![count_distinct(lit(2) * col("b"))])?
            .build()?;

        // Rewritten: a non-column DISTINCT argument is aliased as `alias1`.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]
          Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]
            Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn single_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a")], vec![count_distinct(col("b"))])?
            .build()?;

        // Rewritten: the plain group-by column is reused as-is in both aggregates.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count_distinct(col("c"))],
            )?
            .build()?;

        // Not rewritten: DISTINCT is applied to two different columns.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn one_field_two_distinct_and_groupby() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), max_distinct(col("b"))],
            )?
            .build()?;

        // Rewritten: both DISTINCT aggregates share the argument `test.b`,
        // which is grouped on once as `alias1`.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]
          Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn distinct_and_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![count_distinct(col("b")), count(col("c"))],
            )?
            .build()?;

        // Not rewritten: the non-distinct `count` is not one of sum/min/max.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn group_by_with_expr() -> Result<()> {
        let table_scan = test_table_scan().unwrap();

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))])?
            .build()?;

        // Rewritten: the non-column group-by expression is aliased as
        // `group_alias_0` and renamed back in the final projection.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]
          Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]
            Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn two_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![
                    sum(col("c")),
                    count_distinct(col("b")),
                    max_distinct(col("b")),
                ],
            )?
            .build()?;

        // Rewritten: `sum(test.c)` becomes an inner `sum` aliased `alias2`
        // that the outer aggregate sums again.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]
          Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn one_distinct_and_two_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("a")],
                vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
            )?
            .build()?;

        // Rewritten: non-distinct aggregates get `alias2`, `alias3`, ... in input order.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]
            Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn one_distinct_and_one_common() -> Result<()> {
        let table_scan = test_table_scan()?;

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(
                vec![col("c")],
                vec![min(col("a")), count_distinct(col("b"))],
            )?
            .build()?;

        // Rewritten: `min` is also split into an inner and an outer `min`.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]
          Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]
            Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]
              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn common_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // sum(a) FILTER (WHERE a > 5)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            Some(Box::new(col("a").gt(lit(5)))),
            None,
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;

        // Not rewritten: a non-distinct aggregate has a FILTER clause.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn distinct_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Not rewritten: the DISTINCT aggregate has a FILTER clause.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn common_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // SUM(a ORDER BY a)
        let expr = Expr::AggregateFunction(AggregateFunction::new_udf(
            sum_udaf(),
            vec![col("a")],
            false,
            None,
            Some(vec![col("a").sort(true, false)]),
            None,
        ));
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
            .build()?;

        // Not rewritten: a non-distinct aggregate has an ORDER BY clause.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn distinct_with_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Not rewritten: the DISTINCT aggregate has an ORDER BY clause.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }

    #[test]
    fn aggregate_with_filter_and_order_by() -> Result<()> {
        let table_scan = test_table_scan()?;

        // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
        let expr = count_udaf()
            .call(vec![col("a")])
            .distinct()
            .filter(col("a").gt(lit(5)))
            .order_by(vec![col("a").sort(true, false)])
            .build()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(vec![col("c")], vec![sum(col("a")), expr])?
            .build()?;

        // Not rewritten: the DISTINCT aggregate has both FILTER and ORDER BY clauses.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]
          TableScan: test [a:UInt32, b:UInt32, c:UInt32]
        "
        )
    }
}