datafusion_expr/expr_rewriter/
mod.rs1use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::TableReference;
32use datafusion_common::{Column, DFSchema, Result};
33
34mod order_by;
35pub use order_by::rewrite_sort_cols_by_aggs;
36
37pub trait FunctionRewrite: Debug {
47 fn name(&self) -> &str;
49
50 fn rewrite(
55 &self,
56 expr: Expr,
57 schema: &DFSchema,
58 config: &ConfigOptions,
59 ) -> Result<Transformed<Expr>>;
60}
61
62pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
65 expr.transform(|expr| {
66 Ok({
67 if let Expr::Column(c) = expr {
68 let col = LogicalPlanBuilder::normalize(plan, c)?;
69 Transformed::yes(Expr::Column(col))
70 } else {
71 Transformed::no(expr)
72 }
73 })
74 })
75 .data()
76}
77
78pub fn normalize_col_with_schemas_and_ambiguity_check(
80 expr: Expr,
81 schemas: &[&[&DFSchema]],
82 using_columns: &[HashSet<Column>],
83) -> Result<Expr> {
84 if let Expr::Unnest(Unnest { expr }) = expr {
86 let e = normalize_col_with_schemas_and_ambiguity_check(
87 expr.as_ref().clone(),
88 schemas,
89 using_columns,
90 )?;
91 return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
92 }
93
94 expr.transform(|expr| {
95 Ok({
96 if let Expr::Column(c) = expr {
97 let col =
98 c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
99 Transformed::yes(Expr::Column(col))
100 } else {
101 Transformed::no(expr)
102 }
103 })
104 })
105 .data()
106}
107
108pub fn normalize_cols(
110 exprs: impl IntoIterator<Item = impl Into<Expr>>,
111 plan: &LogicalPlan,
112) -> Result<Vec<Expr>> {
113 exprs
114 .into_iter()
115 .map(|e| normalize_col(e.into(), plan))
116 .collect()
117}
118
119pub fn normalize_sorts(
120 sorts: impl IntoIterator<Item = impl Into<Sort>>,
121 plan: &LogicalPlan,
122) -> Result<Vec<Sort>> {
123 sorts
124 .into_iter()
125 .map(|e| {
126 let sort = e.into();
127 normalize_col(sort.expr, plan)
128 .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
129 })
130 .collect()
131}
132
133pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
136 expr.transform(|expr| {
137 Ok({
138 if let Expr::Column(c) = &expr {
139 match replace_map.get(c) {
140 Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
141 None => Transformed::no(expr),
142 }
143 } else {
144 Transformed::no(expr)
145 }
146 })
147 })
148 .data()
149}
150
151pub fn unnormalize_col(expr: Expr) -> Expr {
157 expr.transform(|expr| {
158 Ok({
159 if let Expr::Column(c) = expr {
160 let col = Column::new_unqualified(c.name);
161 Transformed::yes(Expr::Column(col))
162 } else {
163 Transformed::no(expr)
164 }
165 })
166 })
167 .data()
168 .expect("Unnormalize is infallible")
169}
170
171pub fn create_col_from_scalar_expr(
173 scalar_expr: &Expr,
174 subqry_alias: String,
175) -> Result<Column> {
176 match scalar_expr {
177 Expr::Alias(Alias { name, .. }) => Ok(Column::new(
178 Some::<TableReference>(subqry_alias.into()),
179 name,
180 )),
181 Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
182 _ => {
183 let scalar_column = scalar_expr.schema_name().to_string();
184 Ok(Column::new(
185 Some::<TableReference>(subqry_alias.into()),
186 scalar_column,
187 ))
188 }
189 }
190}
191
192#[inline]
194pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
195 exprs.into_iter().map(unnormalize_col).collect()
196}
197
198pub fn strip_outer_reference(expr: Expr) -> Expr {
201 expr.transform(|expr| {
202 Ok({
203 if let Expr::OuterReferenceColumn(_, col) = expr {
204 Transformed::yes(Expr::Column(col))
205 } else {
206 Transformed::no(expr)
207 }
208 })
209 })
210 .data()
211 .expect("strip_outer_reference is infallible")
212}
213
214pub fn coerce_plan_expr_for_schema(
217 plan: LogicalPlan,
218 schema: &DFSchema,
219) -> Result<LogicalPlan> {
220 match plan {
221 LogicalPlan::Projection(Projection { expr, input, .. }) => {
223 let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
224 let projection = Projection::try_new(new_exprs, input)?;
225 Ok(LogicalPlan::Projection(projection))
226 }
227 _ => {
228 let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
229 let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
230 let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
231 if add_project {
232 let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
233 Ok(LogicalPlan::Projection(projection))
234 } else {
235 Ok(plan)
236 }
237 }
238 }
239}
240
241fn coerce_exprs_for_schema(
242 exprs: Vec<Expr>,
243 src_schema: &DFSchema,
244 dst_schema: &DFSchema,
245) -> Result<Vec<Expr>> {
246 exprs
247 .into_iter()
248 .enumerate()
249 .map(|(idx, expr)| {
250 let new_type = dst_schema.field(idx).data_type();
251 if new_type != &expr.get_type(src_schema)? {
252 match expr {
253 Expr::Alias(Alias { expr, name, .. }) => {
254 Ok(expr.cast_to(new_type, src_schema)?.alias(name))
255 }
256 #[expect(deprecated)]
257 Expr::Wildcard { .. } => Ok(expr),
258 _ => expr.cast_to(new_type, src_schema),
259 }
260 } else {
261 Ok(expr)
262 }
263 })
264 .collect::<Result<_>>()
265}
266
267#[inline]
269pub fn unalias(expr: Expr) -> Expr {
270 match expr {
271 Expr::Alias(Alias { expr, .. }) => unalias(*expr),
272 _ => expr,
273 }
274}
275
/// Decides whether the name of an expression should be saved before a rewrite
/// so that it can be restored afterwards with an alias.
///
/// Use `save` to capture an expression's name and `SavedName::restore` to
/// re-apply it after the expression has been rewritten.
#[derive(Debug)]
pub struct NamePreserver {
    /// When `false`, `save` records nothing and restoring is a no-op.
    use_alias: bool,
}
287
288#[derive(Debug)]
291pub enum SavedName {
292 Saved {
294 relation: Option<TableReference>,
295 name: String,
296 },
297 None,
299}
300
301impl NamePreserver {
302 pub fn new(plan: &LogicalPlan) -> Self {
304 Self {
305 use_alias: !matches!(
308 plan,
309 LogicalPlan::Filter(_)
310 | LogicalPlan::Join(_)
311 | LogicalPlan::TableScan(_)
312 | LogicalPlan::Limit(_)
313 | LogicalPlan::Statement(_)
314 ),
315 }
316 }
317
318 pub fn new_for_projection() -> Self {
322 Self { use_alias: true }
323 }
324
325 pub fn save(&self, expr: &Expr) -> SavedName {
326 if self.use_alias {
327 let (relation, name) = expr.qualified_name();
328 SavedName::Saved { relation, name }
329 } else {
330 SavedName::None
331 }
332 }
333}
334
335impl SavedName {
336 pub fn restore(self, expr: Expr) -> Expr {
338 match self {
339 SavedName::Saved { relation, name } => {
340 let (new_relation, new_name) = expr.qualified_name();
341 if new_relation != relation || new_name != name {
342 expr.alias_qualified(relation, name)
343 } else {
344 expr
345 }
346 }
347 SavedName::None => expr,
348 }
349 }
350}
351
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;
    use crate::literal::lit_with_metadata;
    use crate::{col, lit, Cast};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeRewriter;
    use datafusion_common::ScalarValue;

    /// Rewriter that changes nothing and only logs the order of its visits.
    #[derive(Default)]
    struct RecordingRewriter {
        v: Vec<String>,
    }

    impl TreeNodeRewriter for RecordingRewriter {
        type Node = Expr;

        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            let entry = format!("Previsited {expr}");
            self.v.push(entry);
            Ok(Transformed::no(expr))
        }

        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            let entry = format!("Mutated {expr}");
            self.v.push(entry);
            Ok(Transformed::no(expr))
        }
    }

    #[test]
    fn rewriter_rewrite() {
        // Rewrites the Utf8 literal "foo" to "bar"; other Utf8 literals keep
        // their value and all remaining expressions are left alone.
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
            match expr {
                Expr::Literal(ScalarValue::Utf8(Some(text)), metadata) => {
                    let replaced = if text == "foo" {
                        "bar".to_string()
                    } else {
                        text
                    };
                    let metadata = metadata
                        .map(|m| m.into_iter().collect::<HashMap<String, String>>());
                    Ok(Transformed::yes(lit_with_metadata(replaced, metadata)))
                }
                _ => Ok(Transformed::no(expr)),
            }
        };

        // (literal before the rewrite, literal expected after the rewrite)
        for (input, expected) in [("foo", "bar"), ("baz", "baz")] {
            let rewritten = col("state")
                .eq(lit(input))
                .transform(transformer)
                .data()
                .unwrap();
            assert_eq!(rewritten, col("state").eq(lit(expected)));
        }
    }

    #[test]
    fn normalize_cols() {
        let schema_a = make_schema_with_empty_metadata(
            vec![Some("tableA".into()), Some("tableA".into())],
            vec!["a", "aa"],
        );
        let schema_b =
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
        let schema_c = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["cc", "c"],
        );
        let schema_f = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["f", "ff"],
        );
        let owned_schemas = vec![schema_c, schema_f, schema_b, schema_a];
        let schema_refs = owned_schemas.iter().collect::<Vec<_>>();

        let unqualified = col("a") + col("b") + col("c");
        let normalized = normalize_col_with_schemas_and_ambiguity_check(
            unqualified,
            &[&schema_refs],
            &[],
        )
        .unwrap();

        let expected = col("tableA.a") + col("tableB.b") + col("tableC.c");
        assert_eq!(normalized, expected);
    }

    #[test]
    fn normalize_cols_non_exist() {
        // Only column `a` exists in the schema below, so `b` cannot be resolved.
        let expr = col("a") + col("b");
        let owned_schemas =
            [make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"])];
        let schema_refs = owned_schemas.iter().collect::<Vec<_>>();

        let error =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schema_refs], &[])
                .unwrap_err()
                .strip_backtrace();
        let expected = "Schema error: No field named b. \
            Valid fields are \"tableA\".a.";
        assert_eq!(error, expected);
    }

    #[test]
    fn unnormalize_cols() {
        let qualified = col("tableA.a") + col("tableB.b");
        let expected = col("a") + col("b");
        assert_eq!(unnormalize_col(qualified), expected);
    }

    /// Builds a schema of non-nullable `Int8` fields named after `fields`,
    /// each paired with the qualifier at the same position.
    fn make_schema_with_empty_metadata(
        qualifiers: Vec<Option<TableReference>>,
        fields: Vec<&str>,
    ) -> DFSchema {
        let arrow_fields = fields
            .into_iter()
            .map(|name| Arc::new(Field::new(name.to_string(), DataType::Int8, false)))
            .collect::<Vec<_>>();
        let arrow_schema = Arc::new(Schema::new(arrow_fields));
        DFSchema::from_field_specific_qualified_schema(qualifiers, &arrow_schema).unwrap()
    }

    #[test]
    fn rewriter_visit() {
        let mut recorder = RecordingRewriter::default();
        col("state").eq(lit("CO")).rewrite(&mut recorder).unwrap();

        let expected = vec![
            "Previsited state = Utf8(\"CO\")",
            "Previsited state",
            "Mutated state",
            "Previsited Utf8(\"CO\")",
            "Mutated Utf8(\"CO\")",
            "Mutated state = Utf8(\"CO\")",
        ];
        assert_eq!(recorder.v, expected)
    }

    #[test]
    fn test_rewrite_preserving_name() {
        // (expression before the rewrite, expression it is rewritten to)
        let cases = vec![
            // rewritten to itself
            (col("a"), col("a")),
            // rewritten to a different column
            (col("a"), col("b")),
            // rewritten to a cast of the same column
            (
                col("a"),
                Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
            ),
            // same shape, different literal type
            (col("a").add(lit(1i32)), col("a").add(lit(1i64))),
            // qualified column `test`.`a` versus unqualified column "test.a"
            (
                Expr::Column(Column::new(Some("test"), "a")),
                Expr::Column(Column::new_unqualified("test.a")),
            ),
            (
                Expr::Column(Column::new_unqualified("test.a")),
                Expr::Column(Column::new(Some("test"), "a")),
            ),
        ];

        for (expr_from, rewrite_to) in cases {
            test_rewrite(expr_from, rewrite_to);
        }
    }

    /// Rewrites `expr_from` into `rewrite_to`, restores the saved name and
    /// asserts that the qualified name is the same as before the rewrite.
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        /// Replaces every visited node with a clone of `rewrite_to`.
        struct TestRewriter {
            rewrite_to: Expr,
        }

        impl TreeNodeRewriter for TestRewriter {
            type Node = Expr;

            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
                let replacement = self.rewrite_to.clone();
                Ok(Transformed::yes(replacement))
            }
        }

        let preserver = NamePreserver { use_alias: true };
        let saved_name = preserver.save(&expr_from);

        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        let rewritten = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
        let restored = saved_name.restore(rewritten);

        let original_name = expr_from.qualified_name();
        let new_name = restored.qualified_name();
        assert_eq!(
            original_name, new_name,
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }
}