datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::expr::{
21    AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22    GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23    WindowFunction, WindowFunctionParams,
24};
25use crate::{Expr, ExprFunctionExt};
26
27use datafusion_common::tree_node::{
28    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
29};
30use datafusion_common::Result;
31
32/// Implementation of the [`TreeNode`] trait
33///
34/// This allows logical expressions (`Expr`) to be traversed and transformed
35/// Facilitates tasks such as optimization and rewriting during query
36/// planning.
37impl TreeNode for Expr {
38    /// Applies a function `f` to each child expression of `self`.
39    ///
40    /// The function `f` determines whether to continue traversing the tree or to stop.
41    /// This method collects all child expressions and applies `f` to each.
42    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43        &'n self,
44        f: F,
45    ) -> Result<TreeNodeRecursion> {
46        match self {
47            Expr::Alias(Alias { expr, .. })
48            | Expr::Unnest(Unnest { expr })
49            | Expr::Not(expr)
50            | Expr::IsNotNull(expr)
51            | Expr::IsTrue(expr)
52            | Expr::IsFalse(expr)
53            | Expr::IsUnknown(expr)
54            | Expr::IsNotTrue(expr)
55            | Expr::IsNotFalse(expr)
56            | Expr::IsNotUnknown(expr)
57            | Expr::IsNull(expr)
58            | Expr::Negative(expr)
59            | Expr::Cast(Cast { expr, .. })
60            | Expr::TryCast(TryCast { expr, .. })
61            | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
62            Expr::GroupingSet(GroupingSet::Rollup(exprs))
63            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
64            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
65                args.apply_elements(f)
66            }
67            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
68                lists_of_exprs.apply_elements(f)
69            }
70            // TODO: remove the next line after `Expr::Wildcard` is removed
71            #[expect(deprecated)]
72            Expr::Column(_)
73            // Treat OuterReferenceColumn as a leaf expression
74            | Expr::OuterReferenceColumn(_, _)
75            | Expr::ScalarVariable(_, _)
76            | Expr::Literal(_)
77            | Expr::Exists { .. }
78            | Expr::ScalarSubquery(_)
79            | Expr::Wildcard { .. }
80            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
81            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
82                (left, right).apply_ref_elements(f)
83            }
84            Expr::Like(Like { expr, pattern, .. })
85            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
86                (expr, pattern).apply_ref_elements(f)
87            }
88            Expr::Between(Between {
89                              expr, low, high, ..
90                          }) => (expr, low, high).apply_ref_elements(f),
91            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
92                (expr, when_then_expr, else_expr).apply_ref_elements(f),
93            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
94                (args, filter, order_by).apply_ref_elements(f),
95            Expr::WindowFunction(WindowFunction {
96                params : WindowFunctionParams {
97                    args,
98                    partition_by,
99                    order_by,
100                    ..}, ..}) => {
101                (args, partition_by, order_by).apply_ref_elements(f)
102            }
103            Expr::InList(InList { expr, list, .. }) => {
104                (expr, list).apply_ref_elements(f)
105            }
106        }
107    }
108
109    /// Maps each child of `self` using the provided closure `f`.
110    ///
111    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
112    /// indicating whether the expression was transformed or left unchanged.
113    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
114        self,
115        mut f: F,
116    ) -> Result<Transformed<Self>> {
117        Ok(match self {
118            // TODO: remove the next line after `Expr::Wildcard` is removed
119            #[expect(deprecated)]
120            Expr::Column(_)
121            | Expr::Wildcard { .. }
122            | Expr::Placeholder(Placeholder { .. })
123            | Expr::OuterReferenceColumn(_, _)
124            | Expr::Exists { .. }
125            | Expr::ScalarSubquery(_)
126            | Expr::ScalarVariable(_, _)
127            | Expr::Literal(_) => Transformed::no(self),
128            Expr::Unnest(Unnest { expr, .. }) => expr
129                .map_elements(f)?
130                .update_data(|expr| Expr::Unnest(Unnest { expr })),
131            Expr::Alias(Alias {
132                expr,
133                relation,
134                name,
135                metadata,
136            }) => f(*expr)?.update_data(|e| {
137                e.alias_qualified_with_metadata(relation, name, metadata)
138            }),
139            Expr::InSubquery(InSubquery {
140                expr,
141                subquery,
142                negated,
143            }) => expr.map_elements(f)?.update_data(|be| {
144                Expr::InSubquery(InSubquery::new(be, subquery, negated))
145            }),
146            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
147                .map_elements(f)?
148                .update_data(|(new_left, new_right)| {
149                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
150                }),
151            Expr::Like(Like {
152                negated,
153                expr,
154                pattern,
155                escape_char,
156                case_insensitive,
157            }) => {
158                (expr, pattern)
159                    .map_elements(f)?
160                    .update_data(|(new_expr, new_pattern)| {
161                        Expr::Like(Like::new(
162                            negated,
163                            new_expr,
164                            new_pattern,
165                            escape_char,
166                            case_insensitive,
167                        ))
168                    })
169            }
170            Expr::SimilarTo(Like {
171                negated,
172                expr,
173                pattern,
174                escape_char,
175                case_insensitive,
176            }) => {
177                (expr, pattern)
178                    .map_elements(f)?
179                    .update_data(|(new_expr, new_pattern)| {
180                        Expr::SimilarTo(Like::new(
181                            negated,
182                            new_expr,
183                            new_pattern,
184                            escape_char,
185                            case_insensitive,
186                        ))
187                    })
188            }
189            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
190            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
191            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
192            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
193            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
194            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
195            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
196            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
197            Expr::IsNotUnknown(expr) => {
198                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
199            }
200            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
201            Expr::Between(Between {
202                expr,
203                negated,
204                low,
205                high,
206            }) => (expr, low, high).map_elements(f)?.update_data(
207                |(new_expr, new_low, new_high)| {
208                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
209                },
210            ),
211            Expr::Case(Case {
212                expr,
213                when_then_expr,
214                else_expr,
215            }) => (expr, when_then_expr, else_expr)
216                .map_elements(f)?
217                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
218                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
219                }),
220            Expr::Cast(Cast { expr, data_type }) => expr
221                .map_elements(f)?
222                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
223            Expr::TryCast(TryCast { expr, data_type }) => expr
224                .map_elements(f)?
225                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
226            Expr::ScalarFunction(ScalarFunction { func, args }) => {
227                args.map_elements(f)?.map_data(|new_args| {
228                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
229                        func, new_args,
230                    )))
231                })?
232            }
233            Expr::WindowFunction(WindowFunction {
234                fun,
235                params:
236                    WindowFunctionParams {
237                        args,
238                        partition_by,
239                        order_by,
240                        window_frame,
241                        null_treatment,
242                    },
243            }) => (args, partition_by, order_by).map_elements(f)?.update_data(
244                |(new_args, new_partition_by, new_order_by)| {
245                    Expr::WindowFunction(WindowFunction::new(fun, new_args))
246                        .partition_by(new_partition_by)
247                        .order_by(new_order_by)
248                        .window_frame(window_frame)
249                        .null_treatment(null_treatment)
250                        .build()
251                        .unwrap()
252                },
253            ),
254            Expr::AggregateFunction(AggregateFunction {
255                func,
256                params:
257                    AggregateFunctionParams {
258                        args,
259                        distinct,
260                        filter,
261                        order_by,
262                        null_treatment,
263                    },
264            }) => (args, filter, order_by).map_elements(f)?.map_data(
265                |(new_args, new_filter, new_order_by)| {
266                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
267                        func,
268                        new_args,
269                        distinct,
270                        new_filter,
271                        new_order_by,
272                        null_treatment,
273                    )))
274                },
275            )?,
276            Expr::GroupingSet(grouping_set) => match grouping_set {
277                GroupingSet::Rollup(exprs) => exprs
278                    .map_elements(f)?
279                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
280                GroupingSet::Cube(exprs) => exprs
281                    .map_elements(f)?
282                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
283                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
284                    .map_elements(f)?
285                    .update_data(|new_lists_of_exprs| {
286                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
287                    }),
288            },
289            Expr::InList(InList {
290                expr,
291                list,
292                negated,
293            }) => (expr, list)
294                .map_elements(f)?
295                .update_data(|(new_expr, new_list)| {
296                    Expr::InList(InList::new(new_expr, new_list, negated))
297                }),
298        })
299    }
300}