datafusion_expr/expr_rewriter/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression rewriter
19
20use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::TableReference;
32use datafusion_common::{Column, DFSchema, Result};
33
34mod order_by;
35pub use order_by::rewrite_sort_cols_by_aggs;
36
/// Trait for rewriting [`Expr`]s into function calls.
///
/// This trait is used with `FunctionRegistry::register_function_rewrite` to
/// evaluate `Expr`s using functions that may not be built in to DataFusion.
///
/// For example, concatenating arrays `a || b` is represented as
/// `Operator::ArrowAt`, but can be implemented by calling a function
/// `array_concat` from the `functions-nested` crate.
// This trait is not used inside DataFusion itself, but it is still helpful for
// downstream projects, so do not remove it.
pub trait FunctionRewrite: Debug {
    /// Return a human readable name for this rewrite
    fn name(&self) -> &str;

    /// Potentially rewrite `expr` to some other expression
    ///
    /// Note that recursion is handled by the caller -- this method should only
    /// handle `expr`, not recurse to its children.
    ///
    /// Implementations receive the `schema` and `config` supplied by the
    /// caller and report through the returned [`Transformed`] whether `expr`
    /// was changed.
    fn rewrite(
        &self,
        expr: Expr,
        schema: &DFSchema,
        config: &ConfigOptions,
    ) -> Result<Transformed<Expr>>;
}
61
62/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions
63/// in the `expr` expression tree.
64pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
65    expr.transform(|expr| {
66        Ok({
67            if let Expr::Column(c) = expr {
68                let col = LogicalPlanBuilder::normalize(plan, c)?;
69                Transformed::yes(Expr::Column(col))
70            } else {
71                Transformed::no(expr)
72            }
73        })
74    })
75    .data()
76}
77
78/// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage
79pub fn normalize_col_with_schemas_and_ambiguity_check(
80    expr: Expr,
81    schemas: &[&[&DFSchema]],
82    using_columns: &[HashSet<Column>],
83) -> Result<Expr> {
84    // Normalize column inside Unnest
85    if let Expr::Unnest(Unnest { expr }) = expr {
86        let e = normalize_col_with_schemas_and_ambiguity_check(
87            expr.as_ref().clone(),
88            schemas,
89            using_columns,
90        )?;
91        return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
92    }
93
94    expr.transform(|expr| {
95        Ok({
96            if let Expr::Column(c) = expr {
97                let col =
98                    c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
99                Transformed::yes(Expr::Column(col))
100            } else {
101                Transformed::no(expr)
102            }
103        })
104    })
105    .data()
106}
107
108/// Recursively normalize all [`Column`] expressions in a list of expression trees
109pub fn normalize_cols(
110    exprs: impl IntoIterator<Item = impl Into<Expr>>,
111    plan: &LogicalPlan,
112) -> Result<Vec<Expr>> {
113    exprs
114        .into_iter()
115        .map(|e| normalize_col(e.into(), plan))
116        .collect()
117}
118
119pub fn normalize_sorts(
120    sorts: impl IntoIterator<Item = impl Into<Sort>>,
121    plan: &LogicalPlan,
122) -> Result<Vec<Sort>> {
123    sorts
124        .into_iter()
125        .map(|e| {
126            let sort = e.into();
127            normalize_col(sort.expr, plan)
128                .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
129        })
130        .collect()
131}
132
133/// Recursively replace all [`Column`] expressions in a given expression tree with
134/// `Column` expressions provided by the hash map argument.
135pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
136    expr.transform(|expr| {
137        Ok({
138            if let Expr::Column(c) = &expr {
139                match replace_map.get(c) {
140                    Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
141                    None => Transformed::no(expr),
142                }
143            } else {
144                Transformed::no(expr)
145            }
146        })
147    })
148    .data()
149}
150
151/// Recursively 'unnormalize' (remove all qualifiers) from an
152/// expression tree.
153///
154/// For example, if there were expressions like `foo.bar` this would
155/// rewrite it to just `bar`.
156pub fn unnormalize_col(expr: Expr) -> Expr {
157    expr.transform(|expr| {
158        Ok({
159            if let Expr::Column(c) = expr {
160                let col = Column::new_unqualified(c.name);
161                Transformed::yes(Expr::Column(col))
162            } else {
163                Transformed::no(expr)
164            }
165        })
166    })
167    .data()
168    .expect("Unnormalize is infallible")
169}
170
171/// Create a Column from the Scalar Expr
172pub fn create_col_from_scalar_expr(
173    scalar_expr: &Expr,
174    subqry_alias: String,
175) -> Result<Column> {
176    match scalar_expr {
177        Expr::Alias(Alias { name, .. }) => Ok(Column::new(
178            Some::<TableReference>(subqry_alias.into()),
179            name,
180        )),
181        Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
182        _ => {
183            let scalar_column = scalar_expr.schema_name().to_string();
184            Ok(Column::new(
185                Some::<TableReference>(subqry_alias.into()),
186                scalar_column,
187            ))
188        }
189    }
190}
191
192/// Recursively un-normalize all [`Column`] expressions in a list of expression trees
193#[inline]
194pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
195    exprs.into_iter().map(unnormalize_col).collect()
196}
197
198/// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column
199/// in the expression tree.
200pub fn strip_outer_reference(expr: Expr) -> Expr {
201    expr.transform(|expr| {
202        Ok({
203            if let Expr::OuterReferenceColumn(_, col) = expr {
204                Transformed::yes(Expr::Column(col))
205            } else {
206                Transformed::no(expr)
207            }
208        })
209    })
210    .data()
211    .expect("strip_outer_reference is infallible")
212}
213
214/// Returns plan with expressions coerced to types compatible with
215/// schema types
216pub fn coerce_plan_expr_for_schema(
217    plan: LogicalPlan,
218    schema: &DFSchema,
219) -> Result<LogicalPlan> {
220    match plan {
221        // special case Projection to avoid adding multiple projections
222        LogicalPlan::Projection(Projection { expr, input, .. }) => {
223            let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
224            let projection = Projection::try_new(new_exprs, input)?;
225            Ok(LogicalPlan::Projection(projection))
226        }
227        _ => {
228            let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
229            let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
230            let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
231            if add_project {
232                let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
233                Ok(LogicalPlan::Projection(projection))
234            } else {
235                Ok(plan)
236            }
237        }
238    }
239}
240
241fn coerce_exprs_for_schema(
242    exprs: Vec<Expr>,
243    src_schema: &DFSchema,
244    dst_schema: &DFSchema,
245) -> Result<Vec<Expr>> {
246    exprs
247        .into_iter()
248        .enumerate()
249        .map(|(idx, expr)| {
250            let new_type = dst_schema.field(idx).data_type();
251            if new_type != &expr.get_type(src_schema)? {
252                match expr {
253                    Expr::Alias(Alias { expr, name, .. }) => {
254                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
255                    }
256                    #[expect(deprecated)]
257                    Expr::Wildcard { .. } => Ok(expr),
258                    _ => expr.cast_to(new_type, src_schema),
259                }
260            } else {
261                Ok(expr)
262            }
263        })
264        .collect::<Result<_>>()
265}
266
267/// Recursively un-alias an expressions
268#[inline]
269pub fn unalias(expr: Expr) -> Expr {
270    match expr {
271        Expr::Alias(Alias { expr, .. }) => unalias(*expr),
272        _ => expr,
273    }
274}
275
/// Handles ensuring the name of rewritten expressions is not changed.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization.
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
pub struct NamePreserver {
    /// When `true`, `save` records the qualified name of the expression so it
    /// can be restored later; when `false`, `save` records nothing.
    use_alias: bool,
}
287
/// If the qualified name of an expression is remembered, it will be preserved
/// when rewriting the expression
#[derive(Debug)]
pub enum SavedName {
    /// Saved qualified name to be preserved
    Saved {
        /// Relation (qualifier) part of the saved name, if there was one
        relation: Option<TableReference>,
        /// Name part of the saved qualified name
        name: String,
    },
    /// Name is not preserved
    None,
}
300
301impl NamePreserver {
302    /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
303    pub fn new(plan: &LogicalPlan) -> Self {
304        Self {
305            // The expressions of these plans do not contribute to their output schema,
306            // so there is no need to preserve expression names to prevent a schema change.
307            use_alias: !matches!(
308                plan,
309                LogicalPlan::Filter(_)
310                    | LogicalPlan::Join(_)
311                    | LogicalPlan::TableScan(_)
312                    | LogicalPlan::Limit(_)
313                    | LogicalPlan::Statement(_)
314            ),
315        }
316    }
317
318    /// Create a new NamePreserver for rewriting the `expr`s in `Projection`
319    ///
320    /// This will use aliases
321    pub fn new_for_projection() -> Self {
322        Self { use_alias: true }
323    }
324
325    pub fn save(&self, expr: &Expr) -> SavedName {
326        if self.use_alias {
327            let (relation, name) = expr.qualified_name();
328            SavedName::Saved { relation, name }
329        } else {
330            SavedName::None
331        }
332    }
333}
334
335impl SavedName {
336    /// Ensures the qualified name of the rewritten expression is preserved
337    pub fn restore(self, expr: Expr) -> Expr {
338        match self {
339            SavedName::Saved { relation, name } => {
340                let (new_relation, new_name) = expr.qualified_name();
341                if new_relation != relation || new_name != name {
342                    expr.alias_qualified(relation, name)
343                } else {
344                    expr
345                }
346            }
347            SavedName::None => expr,
348        }
349    }
350}
351
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;
    use crate::literal::lit_with_metadata;
    use crate::{col, lit, Cast};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeRewriter;
    use datafusion_common::ScalarValue;

    /// Rewriter that changes nothing and only records, in order, each
    /// `f_down` ("Previsited") and `f_up` ("Mutated") call it receives.
    #[derive(Default)]
    struct RecordingRewriter {
        v: Vec<String>,
    }

    impl TreeNodeRewriter for RecordingRewriter {
        type Node = Expr;

        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Previsited {expr}"));
            Ok(Transformed::no(expr))
        }

        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Mutated {expr}"));
            Ok(Transformed::no(expr))
        }
    }

    #[test]
    fn rewriter_rewrite() {
        // rewrites all "foo" string literals to "bar"
        // (every Utf8 literal is rebuilt and reported as transformed, even
        // when its value is not "foo")
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
            match expr {
                Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => {
                    let utf8_val = if utf8_val == "foo" {
                        "bar".to_string()
                    } else {
                        utf8_val
                    };
                    Ok(Transformed::yes(lit_with_metadata(
                        utf8_val,
                        metadata
                            .map(|m| m.into_iter().collect::<HashMap<String, String>>()),
                    )))
                }
                // otherwise, leave the expression unchanged
                _ => Ok(Transformed::no(expr)),
            }
        };

        // rewrites "foo" --> "bar"
        let rewritten = col("state")
            .eq(lit("foo"))
            .transform(transformer)
            .data()
            .unwrap();
        assert_eq!(rewritten, col("state").eq(lit("bar")));

        // doesn't rewrite
        let rewritten = col("state")
            .eq(lit("baz"))
            .transform(transformer)
            .data()
            .unwrap();
        assert_eq!(rewritten, col("state").eq(lit("baz")));
    }

    #[test]
    fn normalize_cols() {
        let expr = col("a") + col("b") + col("c");

        // Schemas with some matching and some non matching cols
        let schema_a = make_schema_with_empty_metadata(
            vec![Some("tableA".into()), Some("tableA".into())],
            vec!["a", "aa"],
        );
        let schema_c = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["cc", "c"],
        );
        let schema_b =
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
        // non matching
        let schema_f = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["f", "ff"],
        );
        let schemas = vec![schema_c, schema_f, schema_b, schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let normalized_expr =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap();
        assert_eq!(
            normalized_expr,
            col("tableA.a") + col("tableB.b") + col("tableC.c")
        );
    }

    #[test]
    fn normalize_cols_non_exist() {
        // test normalizing columns when the name doesn't exist
        let expr = col("a") + col("b");
        let schema_a =
            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
        let schemas = [schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let error =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap_err()
                .strip_backtrace();
        let expected = "Schema error: No field named b. \
            Valid fields are \"tableA\".a.";
        assert_eq!(error, expected);
    }

    #[test]
    fn unnormalize_cols() {
        let expr = col("tableA.a") + col("tableB.b");
        let unnormalized_expr = unnormalize_col(expr);
        assert_eq!(unnormalized_expr, col("a") + col("b"));
    }

    /// Builds a `DFSchema` with one non-nullable `Int8` field per entry of
    /// `fields`, pairing each field with the qualifier at the same position
    /// in `qualifiers`.
    fn make_schema_with_empty_metadata(
        qualifiers: Vec<Option<TableReference>>,
        fields: Vec<&str>,
    ) -> DFSchema {
        let fields = fields
            .iter()
            .map(|f| Arc::new(Field::new(f.to_string(), DataType::Int8, false)))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
    }

    #[test]
    fn rewriter_visit() {
        let mut rewriter = RecordingRewriter::default();
        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();

        // Each node is pre-visited before its children and mutated after them.
        assert_eq!(
            rewriter.v,
            vec![
                "Previsited state = Utf8(\"CO\")",
                "Previsited state",
                "Mutated state",
                "Previsited Utf8(\"CO\")",
                "Mutated Utf8(\"CO\")",
                "Mutated state = Utf8(\"CO\")"
            ]
        )
    }

    #[test]
    fn test_rewrite_preserving_name() {
        test_rewrite(col("a"), col("a"));

        test_rewrite(col("a"), col("b"));

        // cast data types
        test_rewrite(
            col("a"),
            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
        );

        // change literal type from i32 to i64
        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

        // test preserve qualifier
        test_rewrite(
            Expr::Column(Column::new(Some("test"), "a")),
            Expr::Column(Column::new_unqualified("test.a")),
        );
        test_rewrite(
            Expr::Column(Column::new_unqualified("test.a")),
            Expr::Column(Column::new(Some("test"), "a")),
        );
    }

    /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
    /// by using the `NamePreserver`
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        // Rewriter that replaces every node it visits on the way up with
        // `rewrite_to`.
        struct TestRewriter {
            rewrite_to: Expr,
        }

        impl TreeNodeRewriter for TestRewriter {
            type Node = Expr;

            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
                Ok(Transformed::yes(self.rewrite_to.clone()))
            }
        }

        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
        let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
        let new_expr = saved_name.restore(new_expr);

        let original_name = expr_from.qualified_name();
        let new_name = new_expr.qualified_name();
        assert_eq!(
            original_name, new_name,
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }
}