1use crate::expr::{
21 AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22 GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23 WindowFunction, WindowFunctionParams,
24};
25use crate::{Expr, ExprFunctionExt};
26
27use datafusion_common::tree_node::{
28 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
29};
30use datafusion_common::Result;
31
32impl TreeNode for Expr {
38 fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43 &'n self,
44 f: F,
45 ) -> Result<TreeNodeRecursion> {
46 match self {
47 Expr::Alias(Alias { expr, .. })
48 | Expr::Unnest(Unnest { expr })
49 | Expr::Not(expr)
50 | Expr::IsNotNull(expr)
51 | Expr::IsTrue(expr)
52 | Expr::IsFalse(expr)
53 | Expr::IsUnknown(expr)
54 | Expr::IsNotTrue(expr)
55 | Expr::IsNotFalse(expr)
56 | Expr::IsNotUnknown(expr)
57 | Expr::IsNull(expr)
58 | Expr::Negative(expr)
59 | Expr::Cast(Cast { expr, .. })
60 | Expr::TryCast(TryCast { expr, .. })
61 | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
62 Expr::GroupingSet(GroupingSet::Rollup(exprs))
63 | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
64 Expr::ScalarFunction(ScalarFunction { args, .. }) => {
65 args.apply_elements(f)
66 }
67 Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
68 lists_of_exprs.apply_elements(f)
69 }
70 #[expect(deprecated)]
72 Expr::Column(_)
73 | Expr::OuterReferenceColumn(_, _)
75 | Expr::ScalarVariable(_, _)
76 | Expr::Literal(_)
77 | Expr::Exists { .. }
78 | Expr::ScalarSubquery(_)
79 | Expr::Wildcard { .. }
80 | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
81 Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
82 (left, right).apply_ref_elements(f)
83 }
84 Expr::Like(Like { expr, pattern, .. })
85 | Expr::SimilarTo(Like { expr, pattern, .. }) => {
86 (expr, pattern).apply_ref_elements(f)
87 }
88 Expr::Between(Between {
89 expr, low, high, ..
90 }) => (expr, low, high).apply_ref_elements(f),
91 Expr::Case(Case { expr, when_then_expr, else_expr }) =>
92 (expr, when_then_expr, else_expr).apply_ref_elements(f),
93 Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
94 (args, filter, order_by).apply_ref_elements(f),
95 Expr::WindowFunction(WindowFunction {
96 params : WindowFunctionParams {
97 args,
98 partition_by,
99 order_by,
100 ..}, ..}) => {
101 (args, partition_by, order_by).apply_ref_elements(f)
102 }
103 Expr::InList(InList { expr, list, .. }) => {
104 (expr, list).apply_ref_elements(f)
105 }
106 }
107 }
108
109 fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
114 self,
115 mut f: F,
116 ) -> Result<Transformed<Self>> {
117 Ok(match self {
118 #[expect(deprecated)]
120 Expr::Column(_)
121 | Expr::Wildcard { .. }
122 | Expr::Placeholder(Placeholder { .. })
123 | Expr::OuterReferenceColumn(_, _)
124 | Expr::Exists { .. }
125 | Expr::ScalarSubquery(_)
126 | Expr::ScalarVariable(_, _)
127 | Expr::Literal(_) => Transformed::no(self),
128 Expr::Unnest(Unnest { expr, .. }) => expr
129 .map_elements(f)?
130 .update_data(|expr| Expr::Unnest(Unnest { expr })),
131 Expr::Alias(Alias {
132 expr,
133 relation,
134 name,
135 metadata,
136 }) => f(*expr)?.update_data(|e| {
137 e.alias_qualified_with_metadata(relation, name, metadata)
138 }),
139 Expr::InSubquery(InSubquery {
140 expr,
141 subquery,
142 negated,
143 }) => expr.map_elements(f)?.update_data(|be| {
144 Expr::InSubquery(InSubquery::new(be, subquery, negated))
145 }),
146 Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
147 .map_elements(f)?
148 .update_data(|(new_left, new_right)| {
149 Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
150 }),
151 Expr::Like(Like {
152 negated,
153 expr,
154 pattern,
155 escape_char,
156 case_insensitive,
157 }) => {
158 (expr, pattern)
159 .map_elements(f)?
160 .update_data(|(new_expr, new_pattern)| {
161 Expr::Like(Like::new(
162 negated,
163 new_expr,
164 new_pattern,
165 escape_char,
166 case_insensitive,
167 ))
168 })
169 }
170 Expr::SimilarTo(Like {
171 negated,
172 expr,
173 pattern,
174 escape_char,
175 case_insensitive,
176 }) => {
177 (expr, pattern)
178 .map_elements(f)?
179 .update_data(|(new_expr, new_pattern)| {
180 Expr::SimilarTo(Like::new(
181 negated,
182 new_expr,
183 new_pattern,
184 escape_char,
185 case_insensitive,
186 ))
187 })
188 }
189 Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
190 Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
191 Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
192 Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
193 Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
194 Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
195 Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
196 Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
197 Expr::IsNotUnknown(expr) => {
198 expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
199 }
200 Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
201 Expr::Between(Between {
202 expr,
203 negated,
204 low,
205 high,
206 }) => (expr, low, high).map_elements(f)?.update_data(
207 |(new_expr, new_low, new_high)| {
208 Expr::Between(Between::new(new_expr, negated, new_low, new_high))
209 },
210 ),
211 Expr::Case(Case {
212 expr,
213 when_then_expr,
214 else_expr,
215 }) => (expr, when_then_expr, else_expr)
216 .map_elements(f)?
217 .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
218 Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
219 }),
220 Expr::Cast(Cast { expr, data_type }) => expr
221 .map_elements(f)?
222 .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
223 Expr::TryCast(TryCast { expr, data_type }) => expr
224 .map_elements(f)?
225 .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
226 Expr::ScalarFunction(ScalarFunction { func, args }) => {
227 args.map_elements(f)?.map_data(|new_args| {
228 Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
229 func, new_args,
230 )))
231 })?
232 }
233 Expr::WindowFunction(WindowFunction {
234 fun,
235 params:
236 WindowFunctionParams {
237 args,
238 partition_by,
239 order_by,
240 window_frame,
241 null_treatment,
242 },
243 }) => (args, partition_by, order_by).map_elements(f)?.update_data(
244 |(new_args, new_partition_by, new_order_by)| {
245 Expr::WindowFunction(WindowFunction::new(fun, new_args))
246 .partition_by(new_partition_by)
247 .order_by(new_order_by)
248 .window_frame(window_frame)
249 .null_treatment(null_treatment)
250 .build()
251 .unwrap()
252 },
253 ),
254 Expr::AggregateFunction(AggregateFunction {
255 func,
256 params:
257 AggregateFunctionParams {
258 args,
259 distinct,
260 filter,
261 order_by,
262 null_treatment,
263 },
264 }) => (args, filter, order_by).map_elements(f)?.map_data(
265 |(new_args, new_filter, new_order_by)| {
266 Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
267 func,
268 new_args,
269 distinct,
270 new_filter,
271 new_order_by,
272 null_treatment,
273 )))
274 },
275 )?,
276 Expr::GroupingSet(grouping_set) => match grouping_set {
277 GroupingSet::Rollup(exprs) => exprs
278 .map_elements(f)?
279 .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
280 GroupingSet::Cube(exprs) => exprs
281 .map_elements(f)?
282 .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
283 GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
284 .map_elements(f)?
285 .update_data(|new_lists_of_exprs| {
286 Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
287 }),
288 },
289 Expr::InList(InList {
290 expr,
291 list,
292 negated,
293 }) => (expr, list)
294 .map_elements(f)?
295 .update_data(|(new_expr, new_list)| {
296 Expr::InList(InList::new(new_expr, new_list, negated))
297 }),
298 })
299 }
300}