datafusion_optimizer/
eliminate_cross_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19use crate::{OptimizerConfig, OptimizerRule};
20use std::sync::Arc;
21
22use crate::join_key_set::JoinKeySet;
23use datafusion_common::tree_node::{Transformed, TreeNode};
24use datafusion_common::Result;
25use datafusion_expr::expr::{BinaryExpr, Expr};
26use datafusion_expr::logical_plan::{
27    Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
28};
29use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
30use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator};
31
/// Optimizer rule that converts `CROSS JOIN` to `INNER JOIN` when join
/// predicates are available.
///
/// The rule carries no state; all of its behavior lives in its
/// `OptimizerRule` implementation.
#[derive(Default, Debug)]
pub struct EliminateCrossJoin;

impl EliminateCrossJoin {
    /// Creates a new instance of the rule.
    ///
    /// Equivalent to [`EliminateCrossJoin::default`].
    pub fn new() -> Self {
        Self
    }
}
41
42/// Eliminate cross joins by rewriting them to inner joins when possible.
43///
44/// # Example
45/// The initial plan for this query:
46/// ```sql
47/// select ... from a, b where a.x = b.y and b.xx = 100;
48/// ```
49///
50/// Looks like this:
51/// ```text
52/// Filter(a.x = b.y AND b.xx = 100)
53///  Cross Join
54///   TableScan a
55///   TableScan b
56/// ```
57///
58/// After the rule is applied, the plan will look like this:
59/// ```text
60/// Filter(b.xx = 100)
61///   InnerJoin(a.x = b.y)
62///     TableScan a
63///     TableScan b
64/// ```
65///
66/// # Other Examples
67/// * 'select ... from a, b where a.x = b.y and b.xx = 100;'
68/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
69/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
70/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);'
71/// * 'select ... from a, b where a.x > b.y'
72///
73/// For above queries, the join predicate is available in filters and they are moved to
74/// join nodes appropriately
75///
76/// This fix helps to improve the performance of TPCH Q19. issue#78
77impl OptimizerRule for EliminateCrossJoin {
78    fn supports_rewrite(&self) -> bool {
79        true
80    }
81
82    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
83    fn rewrite(
84        &self,
85        plan: LogicalPlan,
86        config: &dyn OptimizerConfig,
87    ) -> Result<Transformed<LogicalPlan>> {
88        let plan_schema = Arc::clone(plan.schema());
89        let mut possible_join_keys = JoinKeySet::new();
90        let mut all_inputs: Vec<LogicalPlan> = vec![];
91        let mut all_filters: Vec<Expr> = vec![];
92
93        let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
94            // if input isn't a join that can potentially be rewritten
95            // avoid unwrapping the input
96            let rewritable = matches!(
97                filter.input.as_ref(),
98                LogicalPlan::Join(Join {
99                    join_type: JoinType::Inner,
100                    ..
101                })
102            );
103
104            if !rewritable {
105                // recursively try to rewrite children
106                return rewrite_children(self, LogicalPlan::Filter(filter), config);
107            }
108
109            if !can_flatten_join_inputs(&filter.input) {
110                return Ok(Transformed::no(LogicalPlan::Filter(filter)));
111            }
112
113            let Filter {
114                input, predicate, ..
115            } = filter;
116            flatten_join_inputs(
117                Arc::unwrap_or_clone(input),
118                &mut possible_join_keys,
119                &mut all_inputs,
120                &mut all_filters,
121            )?;
122
123            extract_possible_join_keys(&predicate, &mut possible_join_keys);
124            Some(predicate)
125        } else if matches!(
126            plan,
127            LogicalPlan::Join(Join {
128                join_type: JoinType::Inner,
129                ..
130            })
131        ) {
132            if !can_flatten_join_inputs(&plan) {
133                return Ok(Transformed::no(plan));
134            }
135            flatten_join_inputs(
136                plan,
137                &mut possible_join_keys,
138                &mut all_inputs,
139                &mut all_filters,
140            )?;
141            None
142        } else {
143            // recursively try to rewrite children
144            return rewrite_children(self, plan, config);
145        };
146
147        // Join keys are handled locally:
148        let mut all_join_keys = JoinKeySet::new();
149        let mut left = all_inputs.remove(0);
150        while !all_inputs.is_empty() {
151            left = find_inner_join(
152                left,
153                &mut all_inputs,
154                &possible_join_keys,
155                &mut all_join_keys,
156            )?;
157        }
158
159        left = rewrite_children(self, left, config)?.data;
160
161        if &plan_schema != left.schema() {
162            left = LogicalPlan::Projection(Projection::new_from_schema(
163                Arc::new(left),
164                Arc::clone(&plan_schema),
165            ));
166        }
167
168        if !all_filters.is_empty() {
169            // Add any filters on top - PushDownFilter can push filters down to applicable join
170            let first = all_filters.swap_remove(0);
171            let predicate = all_filters.into_iter().fold(first, and);
172            left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
173        }
174
175        let Some(predicate) = parent_predicate else {
176            return Ok(Transformed::yes(left));
177        };
178
179        // If there are no join keys then do nothing:
180        if all_join_keys.is_empty() {
181            Filter::try_new(predicate, Arc::new(left))
182                .map(|filter| Transformed::yes(LogicalPlan::Filter(filter)))
183        } else {
184            // Remove join expressions from filter:
185            match remove_join_expressions(predicate, &all_join_keys) {
186                Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left))
187                    .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))),
188                _ => Ok(Transformed::yes(left)),
189            }
190        }
191    }
192
193    fn name(&self) -> &str {
194        "eliminate_cross_join"
195    }
196}
197
198fn rewrite_children(
199    optimizer: &impl OptimizerRule,
200    plan: LogicalPlan,
201    config: &dyn OptimizerConfig,
202) -> Result<Transformed<LogicalPlan>> {
203    let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
204
205    // recompute schema if the plan was transformed
206    if transformed_plan.transformed {
207        transformed_plan.map_data(|plan| plan.recompute_schema())
208    } else {
209        Ok(transformed_plan)
210    }
211}
212
213/// Recursively accumulate possible_join_keys and inputs from inner joins
214/// (including cross joins).
215///
216/// Assumes can_flatten_join_inputs has returned true and thus the plan can be
217/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to
218/// possible_join_keys
219fn flatten_join_inputs(
220    plan: LogicalPlan,
221    possible_join_keys: &mut JoinKeySet,
222    all_inputs: &mut Vec<LogicalPlan>,
223    all_filters: &mut Vec<Expr>,
224) -> Result<()> {
225    match plan {
226        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
227            if let Some(filter) = join.filter {
228                all_filters.push(filter);
229            }
230            possible_join_keys.insert_all_owned(join.on);
231            flatten_join_inputs(
232                Arc::unwrap_or_clone(join.left),
233                possible_join_keys,
234                all_inputs,
235                all_filters,
236            )?;
237            flatten_join_inputs(
238                Arc::unwrap_or_clone(join.right),
239                possible_join_keys,
240                all_inputs,
241                all_filters,
242            )?;
243        }
244        _ => {
245            all_inputs.push(plan);
246        }
247    };
248    Ok(())
249}
250
251/// Returns true if the plan is a Join or Cross join could be flattened with
252/// `flatten_join_inputs`
253///
254/// Must stay in sync with `flatten_join_inputs`
255fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
256    // can only flatten inner / cross joins
257    match plan {
258        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
259        _ => return false,
260    };
261
262    for child in plan.inputs() {
263        if let LogicalPlan::Join(Join {
264            join_type: JoinType::Inner,
265            ..
266        }) = child
267        {
268            if !can_flatten_join_inputs(child) {
269                return false;
270            }
271        }
272    }
273    true
274}
275
276/// Finds the next to join with the left input plan,
277///
278/// Finds the next `right` from `rights` that can be joined with `left_input`
279/// plan based on the join keys in `possible_join_keys`.
280///
281/// If such a matching `right` is found:
282/// 1. Adds the matching join keys to `all_join_keys`.
283/// 2. Returns `left_input JOIN right ON (all join keys)`.
284///
285/// If no matching `right` is found:
286/// 1. Removes the first plan from `rights`
287/// 2. Returns `left_input CROSS JOIN right`.
288fn find_inner_join(
289    left_input: LogicalPlan,
290    rights: &mut Vec<LogicalPlan>,
291    possible_join_keys: &JoinKeySet,
292    all_join_keys: &mut JoinKeySet,
293) -> Result<LogicalPlan> {
294    for (i, right_input) in rights.iter().enumerate() {
295        let mut join_keys = vec![];
296
297        for (l, r) in possible_join_keys.iter() {
298            let key_pair = find_valid_equijoin_key_pair(
299                l,
300                r,
301                left_input.schema(),
302                right_input.schema(),
303            )?;
304
305            // Save join keys
306            if let Some((valid_l, valid_r)) = key_pair {
307                if can_hash(&valid_l.get_type(left_input.schema())?) {
308                    join_keys.push((valid_l, valid_r));
309                }
310            }
311        }
312
313        // Found one or more matching join keys
314        if !join_keys.is_empty() {
315            all_join_keys.insert_all(join_keys.iter());
316            let right_input = rights.remove(i);
317            let join_schema = Arc::new(build_join_schema(
318                left_input.schema(),
319                right_input.schema(),
320                &JoinType::Inner,
321            )?);
322
323            return Ok(LogicalPlan::Join(Join {
324                left: Arc::new(left_input),
325                right: Arc::new(right_input),
326                join_type: JoinType::Inner,
327                join_constraint: JoinConstraint::On,
328                on: join_keys,
329                filter: None,
330                schema: join_schema,
331                null_equals_null: false,
332            }));
333        }
334    }
335
336    // no matching right plan had any join keys, cross join with the first right
337    // plan
338    let right = rights.remove(0);
339    let join_schema = Arc::new(build_join_schema(
340        left_input.schema(),
341        right.schema(),
342        &JoinType::Inner,
343    )?);
344
345    Ok(LogicalPlan::Join(Join {
346        left: Arc::new(left_input),
347        right: Arc::new(right),
348        schema: join_schema,
349        on: vec![],
350        filter: None,
351        join_type: JoinType::Inner,
352        join_constraint: JoinConstraint::On,
353        null_equals_null: false,
354    }))
355}
356
357/// Extract join keys from a WHERE clause
358fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
359    if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
360        match op {
361            Operator::Eq => {
362                // insert handles ensuring  we don't add the same Join keys multiple times
363                join_keys.insert(left, right);
364            }
365            Operator::And => {
366                extract_possible_join_keys(left, join_keys);
367                extract_possible_join_keys(right, join_keys)
368            }
369            // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
370            Operator::Or => {
371                let mut left_join_keys = JoinKeySet::new();
372                let mut right_join_keys = JoinKeySet::new();
373
374                extract_possible_join_keys(left, &mut left_join_keys);
375                extract_possible_join_keys(right, &mut right_join_keys);
376
377                join_keys.insert_intersection(&left_join_keys, &right_join_keys)
378            }
379            _ => (),
380        };
381    }
382}
383
384/// Remove join expressions from a filter expression
385///
386/// # Returns
387/// * `Some()` when there are few remaining predicates in filter_expr
388/// * `None` otherwise
389fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
390    match expr {
391        Expr::BinaryExpr(BinaryExpr {
392            left,
393            op: Operator::Eq,
394            right,
395        }) if join_keys.contains(&left, &right) => {
396            // was a join key, so remove it
397            None
398        }
399        // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
400        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
401            let l = remove_join_expressions(*left, join_keys);
402            let r = remove_join_expressions(*right, join_keys);
403            match (l, r) {
404                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
405                    Box::new(ll),
406                    op,
407                    Box::new(rr),
408                ))),
409                (Some(ll), _) => Some(ll),
410                (_, Some(rr)) => Some(rr),
411                _ => None,
412            }
413        }
414        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
415            let l = remove_join_expressions(*left, join_keys);
416            let r = remove_join_expressions(*right, join_keys);
417            match (l, r) {
418                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
419                    Box::new(ll),
420                    op,
421                    Box::new(rr),
422                ))),
423                // When either `left` or `right` is empty, it means they are `true`
424                // so OR'ing anything with them will also be true
425                _ => None,
426            }
427        }
428        _ => Some(expr),
429    }
430}
431
432#[cfg(test)]
433mod tests {
434    use super::*;
435    use crate::optimizer::OptimizerContext;
436    use crate::test::*;
437
438    use datafusion_expr::{
439        binary_expr, col, lit,
440        logical_plan::builder::LogicalPlanBuilder,
441        Operator::{And, Or},
442    };
443    use insta::assert_snapshot;
444
445    macro_rules! assert_optimized_plan_equal {
446        (
447            $plan:expr,
448            @ $expected:literal $(,)?
449        ) => {{
450            let starting_schema = Arc::clone($plan.schema());
451            let rule = EliminateCrossJoin::new();
452            let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap();
453            let formatted_plan = optimized_plan.display_indent_schema();
454            // Ensure the rule was actually applied
455            assert!(is_plan_transformed, "failed to optimize plan");
456            // Verify the schema remains unchanged
457            assert_eq!(&starting_schema, optimized_plan.schema());
458            assert_snapshot!(
459                formatted_plan,
460                @ $expected,
461            );
462
463            Ok(())
464        }};
465    }
466
467    #[test]
468    fn eliminate_cross_with_simple_and() -> Result<()> {
469        let t1 = test_table_scan_with_name("t1")?;
470        let t2 = test_table_scan_with_name("t2")?;
471
472        // could eliminate to inner join since filter has Join predicates
473        let plan = LogicalPlanBuilder::from(t1)
474            .cross_join(t2)?
475            .filter(binary_expr(
476                col("t1.a").eq(col("t2.a")),
477                And,
478                col("t2.c").lt(lit(20u32)),
479            ))?
480            .build()?;
481
482        assert_optimized_plan_equal!(
483            plan,
484            @ r"
485        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
486          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
487            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
488            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
489        "
490        )
491    }
492
493    #[test]
494    fn eliminate_cross_with_simple_or() -> Result<()> {
495        let t1 = test_table_scan_with_name("t1")?;
496        let t2 = test_table_scan_with_name("t2")?;
497
498        // could not eliminate to inner join since filter OR expression and there is no common
499        // Join predicates in left and right of OR expr.
500        let plan = LogicalPlanBuilder::from(t1)
501            .cross_join(t2)?
502            .filter(binary_expr(
503                col("t1.a").eq(col("t2.a")),
504                Or,
505                col("t2.b").eq(col("t1.a")),
506            ))?
507            .build()?;
508
509        assert_optimized_plan_equal!(
510            plan,
511            @ r"
512        Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
513          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
514            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
515            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
516        "
517        )
518    }
519
520    #[test]
521    fn eliminate_cross_with_and() -> Result<()> {
522        let t1 = test_table_scan_with_name("t1")?;
523        let t2 = test_table_scan_with_name("t2")?;
524
525        // could eliminate to inner join
526        let plan = LogicalPlanBuilder::from(t1)
527            .cross_join(t2)?
528            .filter(binary_expr(
529                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(20u32))),
530                And,
531                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").eq(lit(10u32))),
532            ))?
533            .build()?;
534
535        assert_optimized_plan_equal!(
536            plan,
537            @ r"
538        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
539          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
540            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
541            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
542        "
543        )
544    }
545
546    #[test]
547    fn eliminate_cross_with_or() -> Result<()> {
548        let t1 = test_table_scan_with_name("t1")?;
549        let t2 = test_table_scan_with_name("t2")?;
550
551        // could eliminate to inner join since Or predicates have common Join predicates
552        let plan = LogicalPlanBuilder::from(t1)
553            .cross_join(t2)?
554            .filter(binary_expr(
555                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
556                Or,
557                binary_expr(
558                    col("t1.a").eq(col("t2.a")),
559                    And,
560                    col("t2.c").eq(lit(688u32)),
561                ),
562            ))?
563            .build()?;
564
565        assert_optimized_plan_equal!(
566            plan,
567            @ r"
568        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
569          Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
570            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
571            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
572        "
573        )
574    }
575
576    #[test]
577    fn eliminate_cross_not_possible_simple() -> Result<()> {
578        let t1 = test_table_scan_with_name("t1")?;
579        let t2 = test_table_scan_with_name("t2")?;
580
581        // could not eliminate to inner join
582        let plan = LogicalPlanBuilder::from(t1)
583            .cross_join(t2)?
584            .filter(binary_expr(
585                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
586                Or,
587                binary_expr(
588                    col("t1.b").eq(col("t2.b")),
589                    And,
590                    col("t2.c").eq(lit(688u32)),
591                ),
592            ))?
593            .build()?;
594
595        assert_optimized_plan_equal!(
596            plan,
597            @ r"
598        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
599          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
600            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
601            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
602        "
603        )
604    }
605
606    #[test]
607    fn eliminate_cross_not_possible() -> Result<()> {
608        let t1 = test_table_scan_with_name("t1")?;
609        let t2 = test_table_scan_with_name("t2")?;
610
611        // could not eliminate to inner join
612        let plan = LogicalPlanBuilder::from(t1)
613            .cross_join(t2)?
614            .filter(binary_expr(
615                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
616                Or,
617                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").eq(lit(688u32))),
618            ))?
619            .build()?;
620
621        assert_optimized_plan_equal!(
622            plan,
623            @ r"
624        Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
625          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
626            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
627            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
628        "
629        )
630    }
631
632    #[test]
633    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
634        let t1 = test_table_scan_with_name("t1")?;
635        let t2 = test_table_scan_with_name("t2")?;
636        let t3 = test_table_scan_with_name("t3")?;
637
638        // could not eliminate to inner join with filter
639        let plan = LogicalPlanBuilder::from(t1)
640            .join(
641                t3,
642                JoinType::Inner,
643                (vec!["t1.a"], vec!["t3.a"]),
644                Some(col("t1.a").gt(lit(20u32))),
645            )?
646            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
647            .filter(col("t1.a").gt(lit(15u32)))?
648            .build()?;
649
650        assert_optimized_plan_equal!(
651            plan,
652            @ r"
653        Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
654          Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
655            Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
656              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
657                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
658                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
659              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
660        "
661        )
662    }
663
664    #[test]
665    /// ```txt
666    /// filter: a.id = b.id and a.id = c.id
667    ///   cross_join a (bc)
668    ///     cross_join b c
669    /// ```
670    /// Without reorder, it will be
671    /// ```txt
672    ///   inner_join a (bc) on a.id = b.id and a.id = c.id
673    ///     cross_join b c
674    /// ```
675    /// Reorder it to be
676    /// ```txt
677    ///   inner_join (ab)c and a.id = c.id
678    ///     inner_join a b on a.id = b.id
679    /// ```
680    fn reorder_join_to_eliminate_cross_join_multi_tables() -> Result<()> {
681        let t1 = test_table_scan_with_name("t1")?;
682        let t2 = test_table_scan_with_name("t2")?;
683        let t3 = test_table_scan_with_name("t3")?;
684
685        // could eliminate to inner join
686        let plan = LogicalPlanBuilder::from(t1)
687            .cross_join(t2)?
688            .cross_join(t3)?
689            .filter(binary_expr(
690                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
691                And,
692                binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
693            ))?
694            .build()?;
695
696        assert_optimized_plan_equal!(
697            plan,
698            @ r"
699        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
700          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
701            Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
702              Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
703                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
704                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
705              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
706        "
707        )
708    }
709
710    #[test]
711    fn eliminate_cross_join_multi_tables() -> Result<()> {
712        let t1 = test_table_scan_with_name("t1")?;
713        let t2 = test_table_scan_with_name("t2")?;
714        let t3 = test_table_scan_with_name("t3")?;
715        let t4 = test_table_scan_with_name("t4")?;
716
717        // could eliminate to inner join
718        let plan1 = LogicalPlanBuilder::from(t1)
719            .cross_join(t2)?
720            .filter(binary_expr(
721                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
722                Or,
723                binary_expr(
724                    col("t1.a").eq(col("t2.a")),
725                    And,
726                    col("t2.c").eq(lit(688u32)),
727                ),
728            ))?
729            .build()?;
730
731        let plan2 = LogicalPlanBuilder::from(t3)
732            .cross_join(t4)?
733            .filter(binary_expr(
734                binary_expr(
735                    binary_expr(
736                        col("t3.a").eq(col("t4.a")),
737                        And,
738                        col("t4.c").lt(lit(15u32)),
739                    ),
740                    Or,
741                    binary_expr(
742                        col("t3.a").eq(col("t4.a")),
743                        And,
744                        col("t3.c").eq(lit(688u32)),
745                    ),
746                ),
747                Or,
748                binary_expr(
749                    col("t3.a").eq(col("t4.a")),
750                    And,
751                    col("t3.b").eq(col("t4.b")),
752                ),
753            ))?
754            .build()?;
755
756        let plan = LogicalPlanBuilder::from(plan1)
757            .cross_join(plan2)?
758            .filter(binary_expr(
759                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
760                Or,
761                binary_expr(
762                    col("t3.a").eq(col("t1.a")),
763                    And,
764                    col("t4.c").eq(lit(688u32)),
765                ),
766            ))?
767            .build()?;
768
769        assert_optimized_plan_equal!(
770            plan,
771            @ r"
772        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
773          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
774            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
775              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
776                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
777                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
778            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
779              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
780                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
781                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
782        "
783        )
784    }
785
786    #[test]
787    fn eliminate_cross_join_multi_tables_1() -> Result<()> {
788        let t1 = test_table_scan_with_name("t1")?;
789        let t2 = test_table_scan_with_name("t2")?;
790        let t3 = test_table_scan_with_name("t3")?;
791        let t4 = test_table_scan_with_name("t4")?;
792
793        // could eliminate to inner join
794        let plan1 = LogicalPlanBuilder::from(t1)
795            .cross_join(t2)?
796            .filter(binary_expr(
797                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
798                Or,
799                binary_expr(
800                    col("t1.a").eq(col("t2.a")),
801                    And,
802                    col("t2.c").eq(lit(688u32)),
803                ),
804            ))?
805            .build()?;
806
807        // could eliminate to inner join
808        let plan2 = LogicalPlanBuilder::from(t3)
809            .cross_join(t4)?
810            .filter(binary_expr(
811                binary_expr(
812                    binary_expr(
813                        col("t3.a").eq(col("t4.a")),
814                        And,
815                        col("t4.c").lt(lit(15u32)),
816                    ),
817                    Or,
818                    binary_expr(
819                        col("t3.a").eq(col("t4.a")),
820                        And,
821                        col("t3.c").eq(lit(688u32)),
822                    ),
823                ),
824                Or,
825                binary_expr(
826                    col("t3.a").eq(col("t4.a")),
827                    And,
828                    col("t3.b").eq(col("t4.b")),
829                ),
830            ))?
831            .build()?;
832
833        // could not eliminate to inner join
834        let plan = LogicalPlanBuilder::from(plan1)
835            .cross_join(plan2)?
836            .filter(binary_expr(
837                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
838                Or,
839                binary_expr(col("t3.a").eq(col("t1.a")), Or, col("t4.c").eq(lit(688u32))),
840            ))?
841            .build()?;
842
843        assert_optimized_plan_equal!(
844            plan,
845            @ r"
846        Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
847          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
848            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
849              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
850                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
851                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
852            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
853              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
854                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
855                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
856        "
857        )
858    }
859
860    #[test]
861    fn eliminate_cross_join_multi_tables_2() -> Result<()> {
862        let t1 = test_table_scan_with_name("t1")?;
863        let t2 = test_table_scan_with_name("t2")?;
864        let t3 = test_table_scan_with_name("t3")?;
865        let t4 = test_table_scan_with_name("t4")?;
866
867        // could eliminate to inner join
868        let plan1 = LogicalPlanBuilder::from(t1)
869            .cross_join(t2)?
870            .filter(binary_expr(
871                binary_expr(col("t1.a").eq(col("t2.a")), And, col("t2.c").lt(lit(15u32))),
872                Or,
873                binary_expr(
874                    col("t1.a").eq(col("t2.a")),
875                    And,
876                    col("t2.c").eq(lit(688u32)),
877                ),
878            ))?
879            .build()?;
880
881        // could not eliminate to inner join
882        let plan2 = LogicalPlanBuilder::from(t3)
883            .cross_join(t4)?
884            .filter(binary_expr(
885                binary_expr(
886                    binary_expr(
887                        col("t3.a").eq(col("t4.a")),
888                        And,
889                        col("t4.c").lt(lit(15u32)),
890                    ),
891                    Or,
892                    binary_expr(
893                        col("t3.a").eq(col("t4.a")),
894                        And,
895                        col("t3.c").eq(lit(688u32)),
896                    ),
897                ),
898                Or,
899                binary_expr(col("t3.a").eq(col("t4.a")), Or, col("t3.b").eq(col("t4.b"))),
900            ))?
901            .build()?;
902
903        // could eliminate to inner join
904        let plan = LogicalPlanBuilder::from(plan1)
905            .cross_join(plan2)?
906            .filter(binary_expr(
907                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
908                Or,
909                binary_expr(
910                    col("t3.a").eq(col("t1.a")),
911                    And,
912                    col("t4.c").eq(lit(688u32)),
913                ),
914            ))?
915            .build()?;
916
917        assert_optimized_plan_equal!(
918            plan,
919            @ r"
920        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
921          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
922            Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
923              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
924                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
925                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
926            Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
927              Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
928                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
929                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
930        "
931        )
932    }
933
934    #[test]
935    fn eliminate_cross_join_multi_tables_3() -> Result<()> {
936        let t1 = test_table_scan_with_name("t1")?;
937        let t2 = test_table_scan_with_name("t2")?;
938        let t3 = test_table_scan_with_name("t3")?;
939        let t4 = test_table_scan_with_name("t4")?;
940
941        // could not eliminate to inner join
942        let plan1 = LogicalPlanBuilder::from(t1)
943            .cross_join(t2)?
944            .filter(binary_expr(
945                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
946                Or,
947                binary_expr(
948                    col("t1.a").eq(col("t2.a")),
949                    And,
950                    col("t2.c").eq(lit(688u32)),
951                ),
952            ))?
953            .build()?;
954
955        // could eliminate to inner join
956        let plan2 = LogicalPlanBuilder::from(t3)
957            .cross_join(t4)?
958            .filter(binary_expr(
959                binary_expr(
960                    binary_expr(
961                        col("t3.a").eq(col("t4.a")),
962                        And,
963                        col("t4.c").lt(lit(15u32)),
964                    ),
965                    Or,
966                    binary_expr(
967                        col("t3.a").eq(col("t4.a")),
968                        And,
969                        col("t3.c").eq(lit(688u32)),
970                    ),
971                ),
972                Or,
973                binary_expr(
974                    col("t3.a").eq(col("t4.a")),
975                    And,
976                    col("t3.b").eq(col("t4.b")),
977                ),
978            ))?
979            .build()?;
980
981        // could eliminate to inner join
982        let plan = LogicalPlanBuilder::from(plan1)
983            .cross_join(plan2)?
984            .filter(binary_expr(
985                binary_expr(col("t3.a").eq(col("t1.a")), And, col("t4.c").lt(lit(15u32))),
986                Or,
987                binary_expr(
988                    col("t3.a").eq(col("t1.a")),
989                    And,
990                    col("t4.c").eq(lit(688u32)),
991                ),
992            ))?
993            .build()?;
994
995        assert_optimized_plan_equal!(
996            plan,
997            @ r"
998        Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
999          Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1000            Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1001              Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1002                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1003                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1004            Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1005              Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1006                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1007                TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1008        "
1009        )
1010    }
1011
1012    #[test]
1013    fn eliminate_cross_join_multi_tables_4() -> Result<()> {
1014        let t1 = test_table_scan_with_name("t1")?;
1015        let t2 = test_table_scan_with_name("t2")?;
1016        let t3 = test_table_scan_with_name("t3")?;
1017        let t4 = test_table_scan_with_name("t4")?;
1018
1019        // could eliminate to inner join
1020        // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
1021        let plan1 = LogicalPlanBuilder::from(t1)
1022            .cross_join(t2)?
1023            .filter(binary_expr(
1024                binary_expr(col("t1.a").eq(col("t2.a")), Or, col("t2.c").lt(lit(15u32))),
1025                And,
1026                binary_expr(
1027                    col("t1.a").eq(col("t2.a")),
1028                    And,
1029                    col("t2.c").eq(lit(688u32)),
1030                ),
1031            ))?
1032            .build()?;
1033
1034        // could eliminate to inner join
1035        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1036
1037        // could eliminate to inner join
1038        // filter:
1039        //   ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1040        //     AND
1041        //   ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1042        let plan = LogicalPlanBuilder::from(plan1)
1043            .cross_join(plan2)?
1044            .filter(binary_expr(
1045                binary_expr(
1046                    binary_expr(
1047                        col("t3.a").eq(col("t1.a")),
1048                        And,
1049                        col("t4.c").lt(lit(15u32)),
1050                    ),
1051                    Or,
1052                    binary_expr(
1053                        col("t3.a").eq(col("t1.a")),
1054                        And,
1055                        col("t4.c").eq(lit(688u32)),
1056                    ),
1057                ),
1058                And,
1059                binary_expr(
1060                    binary_expr(
1061                        binary_expr(
1062                            col("t3.a").eq(col("t4.a")),
1063                            And,
1064                            col("t4.c").lt(lit(15u32)),
1065                        ),
1066                        Or,
1067                        binary_expr(
1068                            col("t3.a").eq(col("t4.a")),
1069                            And,
1070                            col("t3.c").eq(lit(688u32)),
1071                        ),
1072                    ),
1073                    Or,
1074                    binary_expr(
1075                        col("t3.a").eq(col("t4.a")),
1076                        And,
1077                        col("t3.b").eq(col("t4.b")),
1078                    ),
1079                ),
1080            ))?
1081            .build()?;
1082
1083        assert_optimized_plan_equal!(
1084            plan,
1085            @ r"
1086        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1087          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1088            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1089              Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1090                Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1091                  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1092                  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1093              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1094            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1095        "
1096        )
1097    }
1098
1099    #[test]
1100    fn eliminate_cross_join_multi_tables_5() -> Result<()> {
1101        let t1 = test_table_scan_with_name("t1")?;
1102        let t2 = test_table_scan_with_name("t2")?;
1103        let t3 = test_table_scan_with_name("t3")?;
1104        let t4 = test_table_scan_with_name("t4")?;
1105
1106        // could eliminate to inner join
1107        let plan1 = LogicalPlanBuilder::from(t1).cross_join(t2)?.build()?;
1108
1109        // could eliminate to inner join
1110        let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;
1111
1112        // could eliminate to inner join
1113        // Filter:
1114        //  ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
1115        //      AND
1116        //  ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
1117        //      AND
1118        //  ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
1119        let plan = LogicalPlanBuilder::from(plan1)
1120            .cross_join(plan2)?
1121            .filter(binary_expr(
1122                binary_expr(
1123                    binary_expr(
1124                        binary_expr(
1125                            col("t3.a").eq(col("t1.a")),
1126                            And,
1127                            col("t4.c").lt(lit(15u32)),
1128                        ),
1129                        Or,
1130                        binary_expr(
1131                            col("t3.a").eq(col("t1.a")),
1132                            And,
1133                            col("t4.c").eq(lit(688u32)),
1134                        ),
1135                    ),
1136                    And,
1137                    binary_expr(
1138                        binary_expr(
1139                            binary_expr(
1140                                col("t3.a").eq(col("t4.a")),
1141                                And,
1142                                col("t4.c").lt(lit(15u32)),
1143                            ),
1144                            Or,
1145                            binary_expr(
1146                                col("t3.a").eq(col("t4.a")),
1147                                And,
1148                                col("t3.c").eq(lit(688u32)),
1149                            ),
1150                        ),
1151                        Or,
1152                        binary_expr(
1153                            col("t3.a").eq(col("t4.a")),
1154                            And,
1155                            col("t3.b").eq(col("t4.b")),
1156                        ),
1157                    ),
1158                ),
1159                And,
1160                binary_expr(
1161                    binary_expr(
1162                        col("t1.a").eq(col("t2.a")),
1163                        Or,
1164                        col("t2.c").lt(lit(15u32)),
1165                    ),
1166                    And,
1167                    binary_expr(
1168                        col("t1.a").eq(col("t2.a")),
1169                        And,
1170                        col("t2.c").eq(lit(688u32)),
1171                    ),
1172                ),
1173            ))?
1174            .build()?;
1175
1176        assert_optimized_plan_equal!(
1177            plan,
1178            @ r"
1179        Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1180          Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1181            Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1182              Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1183                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1184                TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1185              TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1186            TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]
1187        "
1188        )
1189    }
1190
1191    #[test]
1192    fn eliminate_cross_join_with_expr_and() -> Result<()> {
1193        let t1 = test_table_scan_with_name("t1")?;
1194        let t2 = test_table_scan_with_name("t2")?;
1195
1196        // could eliminate to inner join since filter has Join predicates
1197        let plan = LogicalPlanBuilder::from(t1)
1198            .cross_join(t2)?
1199            .filter(binary_expr(
1200                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1201                And,
1202                col("t2.c").lt(lit(20u32)),
1203            ))?
1204            .build()?;
1205
1206        assert_optimized_plan_equal!(
1207            plan,
1208            @ r"
1209        Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1210          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1211            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1212            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1213        "
1214        )
1215    }
1216
1217    #[test]
1218    fn eliminate_cross_with_expr_or() -> Result<()> {
1219        let t1 = test_table_scan_with_name("t1")?;
1220        let t2 = test_table_scan_with_name("t2")?;
1221
1222        // could not eliminate to inner join since filter OR expression and there is no common
1223        // Join predicates in left and right of OR expr.
1224        let plan = LogicalPlanBuilder::from(t1)
1225            .cross_join(t2)?
1226            .filter(binary_expr(
1227                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1228                Or,
1229                col("t2.b").eq(col("t1.a")),
1230            ))?
1231            .build()?;
1232
1233        assert_optimized_plan_equal!(
1234            plan,
1235            @ r"
1236        Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1237          Cross Join:  [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1238            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1239            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1240        "
1241        )
1242    }
1243
1244    #[test]
1245    fn eliminate_cross_with_common_expr_and() -> Result<()> {
1246        let t1 = test_table_scan_with_name("t1")?;
1247        let t2 = test_table_scan_with_name("t2")?;
1248
1249        // could eliminate to inner join
1250        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1251        let plan = LogicalPlanBuilder::from(t1)
1252            .cross_join(t2)?
1253            .filter(binary_expr(
1254                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
1255                And,
1256                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
1257            ))?
1258            .build()?;
1259
1260        assert_optimized_plan_equal!(
1261            plan,
1262            @ r"
1263        Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1264          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1265            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1266            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1267        "
1268        )
1269    }
1270
1271    #[test]
1272    fn eliminate_cross_with_common_expr_or() -> Result<()> {
1273        let t1 = test_table_scan_with_name("t1")?;
1274        let t2 = test_table_scan_with_name("t2")?;
1275
1276        // could eliminate to inner join since Or predicates have common Join predicates
1277        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
1278        let plan = LogicalPlanBuilder::from(t1)
1279            .cross_join(t2)?
1280            .filter(binary_expr(
1281                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
1282                Or,
1283                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
1284            ))?
1285            .build()?;
1286
1287        assert_optimized_plan_equal!(
1288            plan,
1289            @ r"
1290        Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1291          Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1292            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1293            TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1294        "
1295        )
1296    }
1297
1298    #[test]
1299    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
1300        let t1 = test_table_scan_with_name("t1")?;
1301        let t2 = test_table_scan_with_name("t2")?;
1302        let t3 = test_table_scan_with_name("t3")?;
1303
1304        // could eliminate to inner join
1305        let plan = LogicalPlanBuilder::from(t1)
1306            .cross_join(t2)?
1307            .cross_join(t3)?
1308            .filter(binary_expr(
1309                binary_expr(
1310                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
1311                    And,
1312                    col("t3.c").lt(lit(15u32)),
1313                ),
1314                And,
1315                binary_expr(
1316                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
1317                    And,
1318                    col("t3.b").lt(lit(15u32)),
1319                ),
1320            ))?
1321            .build()?;
1322
1323        assert_optimized_plan_equal!(
1324            plan,
1325            @ r"
1326        Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1327          Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1328            Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1329              Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]
1330                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
1331                TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
1332              TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
1333        "
1334        )
1335    }
1336}