datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
38
/// Optimizer rule that rewrites scalar subqueries appearing in `Filter`
/// predicates and `Projection` expressions into left joins
#[derive(Debug, Default)]
pub struct ScalarSubqueryToJoin {}
42
43impl ScalarSubqueryToJoin {
44    #[allow(missing_docs)]
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Finds expressions that have a scalar subquery in them (and recurses when found)
50    ///
51    /// # Arguments
52    /// * `predicate` - A conjunction to split and search
53    ///
54    /// Returns a tuple (subqueries, alias)
55    fn extract_subquery_exprs(
56        &self,
57        predicate: &Expr,
58        alias_gen: &Arc<AliasGenerator>,
59    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
60        let mut extract = ExtractScalarSubQuery {
61            sub_query_info: vec![],
62            alias_gen,
63        };
64        predicate
65            .clone()
66            .rewrite(&mut extract)
67            .data()
68            .map(|new_expr| (extract.sub_query_info, new_expr))
69    }
70}
71
72impl OptimizerRule for ScalarSubqueryToJoin {
73    fn supports_rewrite(&self) -> bool {
74        true
75    }
76
77    fn rewrite(
78        &self,
79        plan: LogicalPlan,
80        config: &dyn OptimizerConfig,
81    ) -> Result<Transformed<LogicalPlan>> {
82        match plan {
83            LogicalPlan::Filter(filter) => {
84                // Optimization: skip the rest of the rule and its copies if
85                // there are no scalar subqueries
86                if !contains_scalar_subquery(&filter.predicate) {
87                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
88                }
89
90                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
91                    &filter.predicate,
92                    config.alias_generator(),
93                )?;
94
95                if subqueries.is_empty() {
96                    return internal_err!("Expected subqueries not found in filter");
97                }
98
99                // iterate through all subqueries in predicate, turning each into a left join
100                let mut cur_input = filter.input.as_ref().clone();
101                for (subquery, alias) in subqueries {
102                    if let Some((optimized_subquery, expr_check_map)) =
103                        build_join(&subquery, &cur_input, &alias)?
104                    {
105                        if !expr_check_map.is_empty() {
106                            rewrite_expr = rewrite_expr
107                                .transform_up(|expr| {
108                                    // replace column references with entry in map, if it exists
109                                    if let Some(map_expr) = expr
110                                        .try_as_col()
111                                        .and_then(|col| expr_check_map.get(&col.name))
112                                    {
113                                        Ok(Transformed::yes(map_expr.clone()))
114                                    } else {
115                                        Ok(Transformed::no(expr))
116                                    }
117                                })
118                                .data()?;
119                        }
120                        cur_input = optimized_subquery;
121                    } else {
122                        // if we can't handle all of the subqueries then bail for now
123                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
124                    }
125                }
126                let new_plan = LogicalPlanBuilder::from(cur_input)
127                    .filter(rewrite_expr)?
128                    .build()?;
129                Ok(Transformed::yes(new_plan))
130            }
131            LogicalPlan::Projection(projection) => {
132                // Optimization: skip the rest of the rule and its copies if
133                // there are no scalar subqueries
134                if !projection.expr.iter().any(contains_scalar_subquery) {
135                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
136                }
137
138                let mut all_subqueries = vec![];
139                let mut expr_to_rewrite_expr_map = HashMap::new();
140                let mut subquery_to_expr_map = HashMap::new();
141                for expr in projection.expr.iter() {
142                    let (subqueries, rewrite_exprs) =
143                        self.extract_subquery_exprs(expr, config.alias_generator())?;
144                    for (subquery, _) in &subqueries {
145                        subquery_to_expr_map.insert(subquery.clone(), expr.clone());
146                    }
147                    all_subqueries.extend(subqueries);
148                    expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
149                }
150                if all_subqueries.is_empty() {
151                    return internal_err!("Expected subqueries not found in projection");
152                }
153                // iterate through all subqueries in predicate, turning each into a left join
154                let mut cur_input = projection.input.as_ref().clone();
155                for (subquery, alias) in all_subqueries {
156                    if let Some((optimized_subquery, expr_check_map)) =
157                        build_join(&subquery, &cur_input, &alias)?
158                    {
159                        cur_input = optimized_subquery;
160                        if !expr_check_map.is_empty() {
161                            if let Some(expr) = subquery_to_expr_map.get(&subquery) {
162                                if let Some(rewrite_expr) =
163                                    expr_to_rewrite_expr_map.get(expr)
164                                {
165                                    let new_expr = rewrite_expr
166                                        .clone()
167                                        .transform_up(|expr| {
168                                            // replace column references with entry in map, if it exists
169                                            if let Some(map_expr) =
170                                                expr.try_as_col().and_then(|col| {
171                                                    expr_check_map.get(&col.name)
172                                                })
173                                            {
174                                                Ok(Transformed::yes(map_expr.clone()))
175                                            } else {
176                                                Ok(Transformed::no(expr))
177                                            }
178                                        })
179                                        .data()?;
180                                    expr_to_rewrite_expr_map.insert(expr, new_expr);
181                                }
182                            }
183                        }
184                    } else {
185                        // if we can't handle all of the subqueries then bail for now
186                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
187                    }
188                }
189
190                let mut proj_exprs = vec![];
191                for expr in projection.expr.iter() {
192                    let old_expr_name = expr.schema_name().to_string();
193                    let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
194                    let new_expr_name = new_expr.schema_name().to_string();
195                    if new_expr_name != old_expr_name {
196                        proj_exprs.push(new_expr.clone().alias(old_expr_name))
197                    } else {
198                        proj_exprs.push(new_expr.clone());
199                    }
200                }
201                let new_plan = LogicalPlanBuilder::from(cur_input)
202                    .project(proj_exprs)?
203                    .build()?;
204                Ok(Transformed::yes(new_plan))
205            }
206
207            plan => Ok(Transformed::no(plan)),
208        }
209    }
210
211    fn name(&self) -> &str {
212        "scalar_subquery_to_join"
213    }
214
215    fn apply_order(&self) -> Option<ApplyOrder> {
216        Some(ApplyOrder::TopDown)
217    }
218}
219
220/// Returns true if the expression has a scalar subquery somewhere in it
221/// false otherwise
222fn contains_scalar_subquery(expr: &Expr) -> bool {
223    expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_))))
224        .expect("Inner is always Ok")
225}
226
227struct ExtractScalarSubQuery<'a> {
228    sub_query_info: Vec<(Subquery, String)>,
229    alias_gen: &'a Arc<AliasGenerator>,
230}
231
232impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
233    type Node = Expr;
234
235    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
236        match expr {
237            Expr::ScalarSubquery(subquery) => {
238                let subqry_alias = self.alias_gen.next("__scalar_sq");
239                self.sub_query_info
240                    .push((subquery.clone(), subqry_alias.clone()));
241                let scalar_expr = subquery
242                    .subquery
243                    .head_output_expr()?
244                    .map_or(plan_err!("single expression required."), Ok)?;
245                Ok(Transformed::new(
246                    Expr::Column(create_col_from_scalar_expr(
247                        &scalar_expr,
248                        subqry_alias,
249                    )?),
250                    true,
251                    TreeNodeRecursion::Jump,
252                ))
253            }
254            _ => Ok(Transformed::no(expr)),
255        }
256    }
257}
258
259/// Takes a query like:
260///
261/// ```text
262/// select id from customers where balance >
263///     (select avg(total) from orders where orders.c_id = customers.id)
264/// ```
265///
266/// and optimizes it into:
267///
268/// ```text
269/// select c.id from customers c
270/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
271/// where c.balance > o.val
272/// ```
273///
274/// Or a query like:
275///
276/// ```text
277/// select id from customers where balance >
278///     (select avg(total) from orders)
279/// ```
280///
281/// and optimizes it into:
282///
283/// ```text
284/// select c.id from customers c
285/// left join (select avg(total) as val from orders) a
286/// where c.balance > a.val
287/// ```
288///
289/// # Arguments
290///
291/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
292/// * `filter_input` - The non-subquery portion (from customers)
293/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
294/// * `subquery_alias` - Subquery aliases
295fn build_join(
296    subquery: &Subquery,
297    filter_input: &LogicalPlan,
298    subquery_alias: &str,
299) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
300    let subquery_plan = subquery.subquery.as_ref();
301    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
302    let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
303    if !pull_up.can_pull_up {
304        return Ok(None);
305    }
306
307    let collected_count_expr_map =
308        pull_up.collected_count_expr_map.get(&new_plan).cloned();
309    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
310        .alias(subquery_alias.to_string())?
311        .build()?;
312
313    let mut all_correlated_cols = BTreeSet::new();
314    pull_up
315        .correlated_subquery_cols_map
316        .values()
317        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
318
319    // alias the join filter
320    let join_filter_opt =
321        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
322            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
323        })?;
324
325    // join our sub query into the main plan
326    let new_plan = if join_filter_opt.is_none() {
327        match filter_input {
328            LogicalPlan::EmptyRelation(EmptyRelation {
329                produce_one_row: true,
330                schema: _,
331            }) => sub_query_alias,
332            _ => {
333                // if not correlated, group down to 1 row and left join on that (preserving row count)
334                LogicalPlanBuilder::from(filter_input.clone())
335                    .join_on(
336                        sub_query_alias,
337                        JoinType::Left,
338                        vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)],
339                    )?
340                    .build()?
341            }
342        }
343    } else {
344        // left join if correlated, grouping by the join keys so we don't change row count
345        LogicalPlanBuilder::from(filter_input.clone())
346            .join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
347            .build()?
348    };
349    let mut computation_project_expr = HashMap::new();
350    if let Some(expr_map) = collected_count_expr_map {
351        for (name, result) in expr_map {
352            if evaluates_to_null(result.clone(), result.column_refs())? {
353                // If expr always returns null when column is null, skip processing
354                continue;
355            }
356            let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
357                Expr::Case(expr::Case {
358                    expr: None,
359                    when_then_expr: vec![
360                        (
361                            Box::new(Expr::IsNull(Box::new(Expr::Column(
362                                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
363                            )))),
364                            Box::new(result),
365                        ),
366                        (
367                            Box::new(Expr::Not(Box::new(filter.clone()))),
368                            Box::new(Expr::Literal(ScalarValue::Null, None)),
369                        ),
370                    ],
371                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
372                        name.clone(),
373                    )))),
374                })
375            } else {
376                Expr::Case(expr::Case {
377                    expr: None,
378                    when_then_expr: vec![(
379                        Box::new(Expr::IsNull(Box::new(Expr::Column(
380                            Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
381                        )))),
382                        Box::new(result),
383                    )],
384                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
385                        name.clone(),
386                    )))),
387                })
388            };
389            let mut expr_rewrite = TypeCoercionRewriter {
390                schema: new_plan.schema(),
391            };
392            computation_project_expr
393                .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?);
394        }
395    }
396
397    Ok(Some((new_plan, computation_project_expr)))
398}
399
400#[cfg(test)]
401mod tests {
402    use std::ops::Add;
403
404    use super::*;
405    use crate::test::*;
406
407    use arrow::datatypes::DataType;
408    use datafusion_expr::test::function_stub::sum;
409
410    use crate::assert_optimized_plan_eq_display_indent_snapshot;
411    use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between};
412    use datafusion_functions_aggregate::min_max::{max, min};
413
414    macro_rules! assert_optimized_plan_equal {
415        (
416            $plan:expr,
417            @ $expected:literal $(,)?
418        ) => {{
419            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
420            assert_optimized_plan_eq_display_indent_snapshot!(
421                rule,
422                $plan,
423                @ $expected,
424            )
425        }};
426    }
427
428    /// Test multiple correlated subqueries
429    #[test]
430    fn multiple_subqueries() -> Result<()> {
431        let orders = Arc::new(
432            LogicalPlanBuilder::from(scan_tpch_table("orders"))
433                .filter(
434                    col("orders.o_custkey")
435                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
436                )?
437                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
438                .project(vec![max(col("orders.o_custkey"))])?
439                .build()?,
440        );
441
442        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
443            .filter(
444                lit(1)
445                    .lt(scalar_subquery(Arc::clone(&orders)))
446                    .and(lit(1).lt(scalar_subquery(orders))),
447            )?
448            .project(vec![col("customer.c_custkey")])?
449            .build()?;
450
451        assert_optimized_plan_equal!(
452            plan,
453            @r"
454        Projection: customer.c_custkey [c_custkey:Int64]
455          Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
456            Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
457              Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
458                TableScan: customer [c_custkey:Int64, c_name:Utf8]
459                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
460                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
461                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
462                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
463              SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
464                Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
465                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
466                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
467        "
468        )
469    }
470
471    /// Test recursive correlated subqueries
472    #[test]
473    fn recursive_subqueries() -> Result<()> {
474        let lineitem = Arc::new(
475            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
476                .filter(
477                    col("lineitem.l_orderkey")
478                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
479                )?
480                .aggregate(
481                    Vec::<Expr>::new(),
482                    vec![sum(col("lineitem.l_extendedprice"))],
483                )?
484                .project(vec![sum(col("lineitem.l_extendedprice"))])?
485                .build()?,
486        );
487
488        let orders = Arc::new(
489            LogicalPlanBuilder::from(scan_tpch_table("orders"))
490                .filter(
491                    col("orders.o_custkey")
492                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
493                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
494                )?
495                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
496                .project(vec![sum(col("orders.o_totalprice"))])?
497                .build()?,
498        );
499
500        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
501            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
502            .project(vec![col("customer.c_custkey")])?
503            .build()?;
504
505        assert_optimized_plan_equal!(
506            plan,
507            @r"
508        Projection: customer.c_custkey [c_custkey:Int64]
509          Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
510            Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
511              TableScan: customer [c_custkey:Int64, c_name:Utf8]
512              SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
513                Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
514                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]
515                    Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
516                      Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
517                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
518                        SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
519                          Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
520                            Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]
521                              TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
522        "
523        )
524    }
525
526    /// Test for correlated scalar subquery filter with additional subquery filters
527    #[test]
528    fn scalar_subquery_with_subquery_filters() -> Result<()> {
529        let sq = Arc::new(
530            LogicalPlanBuilder::from(scan_tpch_table("orders"))
531                .filter(
532                    out_ref_col(DataType::Int64, "customer.c_custkey")
533                        .eq(col("orders.o_custkey"))
534                        .and(col("o_orderkey").eq(lit(1))),
535                )?
536                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
537                .project(vec![max(col("orders.o_custkey"))])?
538                .build()?,
539        );
540
541        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
542            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
543            .project(vec![col("customer.c_custkey")])?
544            .build()?;
545
546        assert_optimized_plan_equal!(
547            plan,
548            @r"
549        Projection: customer.c_custkey [c_custkey:Int64]
550          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
551            Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
552              TableScan: customer [c_custkey:Int64, c_name:Utf8]
553              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
554                Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
555                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
556                    Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
557                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
558        "
559        )
560    }
561
562    /// Test for correlated scalar subquery with no columns in schema
563    #[test]
564    fn scalar_subquery_no_cols() -> Result<()> {
565        let sq = Arc::new(
566            LogicalPlanBuilder::from(scan_tpch_table("orders"))
567                .filter(
568                    out_ref_col(DataType::Int64, "customer.c_custkey")
569                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
570                )?
571                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
572                .project(vec![max(col("orders.o_custkey"))])?
573                .build()?,
574        );
575
576        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
577            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
578            .project(vec![col("customer.c_custkey")])?
579            .build()?;
580
581        // it will optimize, but fail for the same reason the unoptimized query would
582        assert_optimized_plan_equal!(
583            plan,
584            @r"
585        Projection: customer.c_custkey [c_custkey:Int64]
586          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
587            Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
588              TableScan: customer [c_custkey:Int64, c_name:Utf8]
589              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
590                Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
591                  Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
592                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
593        "
594        )
595    }
596
597    /// Test for scalar subquery with both columns in schema
598    #[test]
599    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
600        let sq = Arc::new(
601            LogicalPlanBuilder::from(scan_tpch_table("orders"))
602                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
603                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
604                .project(vec![max(col("orders.o_custkey"))])?
605                .build()?,
606        );
607
608        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
609            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
610            .project(vec![col("customer.c_custkey")])?
611            .build()?;
612
613        assert_optimized_plan_equal!(
614            plan,
615            @r"
616        Projection: customer.c_custkey [c_custkey:Int64]
617          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
618            Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
619              TableScan: customer [c_custkey:Int64, c_name:Utf8]
620              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
621                Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
622                  Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
623                    Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
624                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
625        "
626        )
627    }
628
629    /// Test for correlated scalar subquery not equal
630    #[test]
631    fn scalar_subquery_where_not_eq() -> Result<()> {
632        let sq = Arc::new(
633            LogicalPlanBuilder::from(scan_tpch_table("orders"))
634                .filter(
635                    out_ref_col(DataType::Int64, "customer.c_custkey")
636                        .not_eq(col("orders.o_custkey")),
637                )?
638                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
639                .project(vec![max(col("orders.o_custkey"))])?
640                .build()?,
641        );
642
643        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
644            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
645            .project(vec![col("customer.c_custkey")])?
646            .build()?;
647
648        // Unsupported predicate, subquery should not be decorrelated
649        assert_optimized_plan_equal!(
650            plan,
651            @r"
652        Projection: customer.c_custkey [c_custkey:Int64]
653          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
654            Subquery: [max(orders.o_custkey):Int64;N]
655              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
656                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
657                  Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
658                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
659            TableScan: customer [c_custkey:Int64, c_name:Utf8]
660        "
661        )
662    }
663
664    /// Test for correlated scalar subquery less than
665    #[test]
666    fn scalar_subquery_where_less_than() -> Result<()> {
667        let sq = Arc::new(
668            LogicalPlanBuilder::from(scan_tpch_table("orders"))
669                .filter(
670                    out_ref_col(DataType::Int64, "customer.c_custkey")
671                        .lt(col("orders.o_custkey")),
672                )?
673                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
674                .project(vec![max(col("orders.o_custkey"))])?
675                .build()?,
676        );
677
678        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
679            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
680            .project(vec![col("customer.c_custkey")])?
681            .build()?;
682
683        // Unsupported predicate, subquery should not be decorrelated
684        assert_optimized_plan_equal!(
685            plan,
686            @r"
687        Projection: customer.c_custkey [c_custkey:Int64]
688          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
689            Subquery: [max(orders.o_custkey):Int64;N]
690              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
691                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
692                  Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
693                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
694            TableScan: customer [c_custkey:Int64, c_name:Utf8]
695        "
696        )
697    }
698
699    /// Test for correlated scalar subquery filter with subquery disjunction
700    #[test]
701    fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
702        let sq = Arc::new(
703            LogicalPlanBuilder::from(scan_tpch_table("orders"))
704                .filter(
705                    out_ref_col(DataType::Int64, "customer.c_custkey")
706                        .eq(col("orders.o_custkey"))
707                        .or(col("o_orderkey").eq(lit(1))),
708                )?
709                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
710                .project(vec![max(col("orders.o_custkey"))])?
711                .build()?,
712        );
713
714        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
715            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
716            .project(vec![col("customer.c_custkey")])?
717            .build()?;
718
719        // Unsupported predicate, subquery should not be decorrelated
720        assert_optimized_plan_equal!(
721            plan,
722            @r"
723        Projection: customer.c_custkey [c_custkey:Int64]
724          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
725            Subquery: [max(orders.o_custkey):Int64;N]
726              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
727                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
728                  Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
729                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
730            TableScan: customer [c_custkey:Int64, c_name:Utf8]
731        "
732        )
733    }
734
735    /// Test for correlated scalar without projection
736    #[test]
737    fn scalar_subquery_no_projection() -> Result<()> {
738        let sq = Arc::new(
739            LogicalPlanBuilder::from(scan_tpch_table("orders"))
740                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
741                .build()?,
742        );
743
744        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
745            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
746            .project(vec![col("customer.c_custkey")])?
747            .build()?;
748
749        let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
750        assert_analyzer_check_err(vec![], plan, expected);
751        Ok(())
752    }
753
754    /// Test for correlated scalar expressions
755    #[test]
756    fn scalar_subquery_project_expr() -> Result<()> {
757        let sq = Arc::new(
758            LogicalPlanBuilder::from(scan_tpch_table("orders"))
759                .filter(
760                    out_ref_col(DataType::Int64, "customer.c_custkey")
761                        .eq(col("orders.o_custkey")),
762                )?
763                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
764                .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
765                .build()?,
766        );
767
768        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
769            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
770            .project(vec![col("customer.c_custkey")])?
771            .build()?;
772
773        assert_optimized_plan_equal!(
774            plan,
775            @r"
776        Projection: customer.c_custkey [c_custkey:Int64]
777          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
778            Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
779              TableScan: customer [c_custkey:Int64, c_name:Utf8]
780              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
781                Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
782                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
783                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
784        "
785        )
786    }
787
788    /// Test for correlated scalar subquery with non-strong project
789    #[test]
790    fn scalar_subquery_with_non_strong_project() -> Result<()> {
791        let case = Expr::Case(expr::Case {
792            expr: None,
793            when_then_expr: vec![(
794                Box::new(col("max(orders.o_totalprice)")),
795                Box::new(lit("a")),
796            )],
797            else_expr: Some(Box::new(lit("b"))),
798        });
799
800        let sq = Arc::new(
801            LogicalPlanBuilder::from(scan_tpch_table("orders"))
802                .filter(
803                    out_ref_col(DataType::Int64, "customer.c_custkey")
804                        .eq(col("orders.o_custkey")),
805                )?
806                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
807                .project(vec![case])?
808                .build()?,
809        );
810
811        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
812            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
813            .build()?;
814
815        assert_optimized_plan_equal!(
816            plan,
817            @r#"
818        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
819          Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
820            TableScan: customer [c_custkey:Int64, c_name:Utf8]
821            SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
822              Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
823                Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]
824                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
825        "#
826        )
827    }
828
829    /// Test for correlated scalar subquery multiple projected columns
830    #[test]
831    fn scalar_subquery_multi_col() -> Result<()> {
832        let sq = Arc::new(
833            LogicalPlanBuilder::from(scan_tpch_table("orders"))
834                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
835                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
836                .build()?,
837        );
838
839        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
840            .filter(
841                col("customer.c_custkey")
842                    .eq(scalar_subquery(sq))
843                    .and(col("c_custkey").eq(lit(1))),
844            )?
845            .project(vec![col("customer.c_custkey")])?
846            .build()?;
847
848        let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
849        assert_analyzer_check_err(vec![], plan, expected);
850        Ok(())
851    }
852
853    /// Test for correlated scalar subquery filter with additional filters
854    #[test]
855    fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
856        let sq = Arc::new(
857            LogicalPlanBuilder::from(scan_tpch_table("orders"))
858                .filter(
859                    out_ref_col(DataType::Int64, "customer.c_custkey")
860                        .eq(col("orders.o_custkey")),
861                )?
862                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
863                .project(vec![max(col("orders.o_custkey"))])?
864                .build()?,
865        );
866
867        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
868            .filter(
869                col("customer.c_custkey")
870                    .gt_eq(scalar_subquery(sq))
871                    .and(col("c_custkey").eq(lit(1))),
872            )?
873            .project(vec![col("customer.c_custkey")])?
874            .build()?;
875
876        assert_optimized_plan_equal!(
877            plan,
878            @r"
879        Projection: customer.c_custkey [c_custkey:Int64]
880          Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
881            Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
882              TableScan: customer [c_custkey:Int64, c_name:Utf8]
883              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
884                Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
885                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
886                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
887        "
888        )
889    }
890
891    #[test]
892    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
893        let sq = Arc::new(
894            LogicalPlanBuilder::from(scan_tpch_table("orders"))
895                .filter(
896                    out_ref_col(DataType::Int64, "customer.c_custkey")
897                        .eq(col("orders.o_custkey")),
898                )?
899                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
900                .project(vec![max(col("orders.o_custkey"))])?
901                .build()?,
902        );
903
904        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
905            .filter(
906                col("customer.c_custkey")
907                    .eq(scalar_subquery(sq))
908                    .and(col("c_custkey").eq(lit(1))),
909            )?
910            .project(vec![col("customer.c_custkey")])?
911            .build()?;
912
913        assert_optimized_plan_equal!(
914            plan,
915            @r"
916        Projection: customer.c_custkey [c_custkey:Int64]
917          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
918            Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
919              TableScan: customer [c_custkey:Int64, c_name:Utf8]
920              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
921                Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
922                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
923                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
924        "
925        )
926    }
927
928    /// Test for correlated scalar subquery filter with disjunctions
929    #[test]
930    fn scalar_subquery_disjunction() -> Result<()> {
931        let sq = Arc::new(
932            LogicalPlanBuilder::from(scan_tpch_table("orders"))
933                .filter(
934                    out_ref_col(DataType::Int64, "customer.c_custkey")
935                        .eq(col("orders.o_custkey")),
936                )?
937                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
938                .project(vec![max(col("orders.o_custkey"))])?
939                .build()?,
940        );
941
942        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
943            .filter(
944                col("customer.c_custkey")
945                    .eq(scalar_subquery(sq))
946                    .or(col("customer.c_custkey").eq(lit(1))),
947            )?
948            .project(vec![col("customer.c_custkey")])?
949            .build()?;
950
951        assert_optimized_plan_equal!(
952            plan,
953            @r"
954        Projection: customer.c_custkey [c_custkey:Int64]
955          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
956            Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
957              TableScan: customer [c_custkey:Int64, c_name:Utf8]
958              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
959                Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
960                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
961                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
962        "
963        )
964    }
965
966    /// Test for correlated scalar subquery filter
967    #[test]
968    fn exists_subquery_correlated() -> Result<()> {
969        let sq = Arc::new(
970            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
971                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
972                .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
973                .project(vec![min(col("c"))])?
974                .build()?,
975        );
976
977        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
978            .filter(col("test.c").lt(scalar_subquery(sq)))?
979            .project(vec![col("test.c")])?
980            .build()?;
981
982        assert_optimized_plan_equal!(
983            plan,
984            @r"
985        Projection: test.c [c:UInt32]
986          Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
987            Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
988              TableScan: test [a:UInt32, b:UInt32, c:UInt32]
989              SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
990                Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
991                  Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]
992                    TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
993        "
994        )
995    }
996
997    /// Test for non-correlated scalar subquery with no filters
998    #[test]
999    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
1000        let sq = Arc::new(
1001            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1002                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1003                .project(vec![max(col("orders.o_custkey"))])?
1004                .build()?,
1005        );
1006
1007        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1008            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
1009            .project(vec![col("customer.c_custkey")])?
1010            .build()?;
1011
1012        assert_optimized_plan_equal!(
1013            plan,
1014            @r"
1015        Projection: customer.c_custkey [c_custkey:Int64]
1016          Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1017            Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1018              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1019              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1020                Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1021                  Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1022                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1023        "
1024        )
1025    }
1026
1027    #[test]
1028    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
1029        let sq = Arc::new(
1030            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1031                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1032                .project(vec![max(col("orders.o_custkey"))])?
1033                .build()?,
1034        );
1035
1036        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1037            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1038            .project(vec![col("customer.c_custkey")])?
1039            .build()?;
1040
1041        assert_optimized_plan_equal!(
1042            plan,
1043            @r"
1044        Projection: customer.c_custkey [c_custkey:Int64]
1045          Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1046            Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1047              TableScan: customer [c_custkey:Int64, c_name:Utf8]
1048              SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1049                Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1050                  Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1051                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1052        "
1053        )
1054    }
1055
1056    #[test]
1057    fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
1058        let sq1 = Arc::new(
1059            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1060                .filter(
1061                    out_ref_col(DataType::Int64, "customer.c_custkey")
1062                        .eq(col("orders.o_custkey")),
1063                )?
1064                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1065                .project(vec![min(col("orders.o_custkey"))])?
1066                .build()?,
1067        );
1068        let sq2 = Arc::new(
1069            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1070                .filter(
1071                    out_ref_col(DataType::Int64, "customer.c_custkey")
1072                        .eq(col("orders.o_custkey")),
1073                )?
1074                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1075                .project(vec![max(col("orders.o_custkey"))])?
1076                .build()?,
1077        );
1078
1079        let between_expr = Expr::Between(Between {
1080            expr: Box::new(col("customer.c_custkey")),
1081            negated: false,
1082            low: Box::new(scalar_subquery(sq1)),
1083            high: Box::new(scalar_subquery(sq2)),
1084        });
1085
1086        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1087            .filter(between_expr)?
1088            .project(vec![col("customer.c_custkey")])?
1089            .build()?;
1090
1091        assert_optimized_plan_equal!(
1092            plan,
1093            @r"
1094        Projection: customer.c_custkey [c_custkey:Int64]
1095          Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1096            Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1097              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1098                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1099                SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1100                  Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1101                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]
1102                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1103              SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1104                Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1105                  Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
1106                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1107        "
1108        )
1109    }
1110
1111    #[test]
1112    fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
1113        let sq1 = Arc::new(
1114            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1115                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1116                .project(vec![min(col("orders.o_custkey"))])?
1117                .build()?,
1118        );
1119        let sq2 = Arc::new(
1120            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1121                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1122                .project(vec![max(col("orders.o_custkey"))])?
1123                .build()?,
1124        );
1125
1126        let between_expr = Expr::Between(Between {
1127            expr: Box::new(col("customer.c_custkey")),
1128            negated: false,
1129            low: Box::new(scalar_subquery(sq1)),
1130            high: Box::new(scalar_subquery(sq2)),
1131        });
1132
1133        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1134            .filter(between_expr)?
1135            .project(vec![col("customer.c_custkey")])?
1136            .build()?;
1137
1138        assert_optimized_plan_equal!(
1139            plan,
1140            @r"
1141        Projection: customer.c_custkey [c_custkey:Int64]
1142          Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1143            Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1144              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]
1145                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1146                SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]
1147                  Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]
1148                    Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]
1149                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1150              SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]
1151                Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1152                  Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1153                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1154        "
1155        )
1156    }
1157}