datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::select_expr::SelectExpr;
29use crate::{
30    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
31    AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
32    ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
33};
34use crate::{
35    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
36};
37use arrow::compute::kernels::cast_utils::{
38    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
39};
40use arrow::datatypes::{DataType, Field, FieldRef};
41use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
42use datafusion_functions_window_common::field::WindowUDFFieldArgs;
43use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
44use sqlparser::ast::NullTreatment;
45use std::any::Any;
46use std::fmt::Debug;
47use std::ops::Not;
48use std::sync::Arc;
49
50/// Create a column expression based on a qualified or unqualified column name. Will
51/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
52///
53/// For example:
54///
55/// ```rust
56/// # use datafusion_expr::col;
57/// let c1 = col("a");
58/// let c2 = col("A");
59/// assert_eq!(c1, c2);
60///
61/// // note how quoting with double quotes preserves the case
62/// let c3 = col(r#""A""#);
63/// assert_ne!(c1, c3);
64/// ```
65pub fn col(ident: impl Into<Column>) -> Expr {
66    Expr::Column(ident.into())
67}
68
69/// Create an out reference column which hold a reference that has been resolved to a field
70/// outside of the current plan.
71pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
72    Expr::OuterReferenceColumn(dt, ident.into())
73}
74
75/// Create an unqualified column expression from the provided name, without normalizing
76/// the column.
77///
78/// For example:
79///
80/// ```rust
81/// # use datafusion_expr::{col, ident};
82/// let c1 = ident("A"); // not normalized staying as column 'A'
83/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
84/// assert_ne!(c1, c2);
85///
86/// let c3 = col(r#""A""#);
87/// assert_eq!(c1, c3);
88///
89/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
90/// let c5 = ident("t1.a"); // parses as column 't1.a'
91/// assert_ne!(c4, c5);
92/// ```
93pub fn ident(name: impl Into<String>) -> Expr {
94    Expr::Column(Column::from_name(name))
95}
96
97/// Create placeholder value that will be filled in (such as `$1`)
98///
99/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
100///
101/// # Example
102///
103/// ```rust
104/// # use datafusion_expr::{placeholder};
105/// let p = placeholder("$0"); // $0, refers to parameter 1
106/// assert_eq!(p.to_string(), "$0")
107/// ```
108pub fn placeholder(id: impl Into<String>) -> Expr {
109    Expr::Placeholder(Placeholder {
110        id: id.into(),
111        data_type: None,
112    })
113}
114
115/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
116///
117/// # Example
118///
119/// ```rust
120/// # use datafusion_expr::{wildcard};
121/// let p = wildcard();
122/// assert_eq!(p.to_string(), "*")
123/// ```
124pub fn wildcard() -> SelectExpr {
125    SelectExpr::Wildcard(WildcardOptions::default())
126}
127
128/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
129pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
130    SelectExpr::Wildcard(options)
131}
132
133/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
134///
135/// # Example
136///
137/// ```rust
138/// # use datafusion_common::TableReference;
139/// # use datafusion_expr::{qualified_wildcard};
140/// let p = qualified_wildcard(TableReference::bare("t"));
141/// assert_eq!(p.to_string(), "t.*")
142/// ```
143pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
144    SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
145}
146
147/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
148pub fn qualified_wildcard_with_options(
149    qualifier: impl Into<TableReference>,
150    options: WildcardOptions,
151) -> SelectExpr {
152    SelectExpr::QualifiedWildcard(qualifier.into(), options)
153}
154
155/// Return a new expression `left <op> right`
156pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
157    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
158}
159
160/// Return a new expression with a logical AND
161pub fn and(left: Expr, right: Expr) -> Expr {
162    Expr::BinaryExpr(BinaryExpr::new(
163        Box::new(left),
164        Operator::And,
165        Box::new(right),
166    ))
167}
168
169/// Return a new expression with a logical OR
170pub fn or(left: Expr, right: Expr) -> Expr {
171    Expr::BinaryExpr(BinaryExpr::new(
172        Box::new(left),
173        Operator::Or,
174        Box::new(right),
175    ))
176}
177
178/// Return a new expression with a logical NOT
179pub fn not(expr: Expr) -> Expr {
180    expr.not()
181}
182
183/// Return a new expression with bitwise AND
184pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
185    Expr::BinaryExpr(BinaryExpr::new(
186        Box::new(left),
187        Operator::BitwiseAnd,
188        Box::new(right),
189    ))
190}
191
192/// Return a new expression with bitwise OR
193pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
194    Expr::BinaryExpr(BinaryExpr::new(
195        Box::new(left),
196        Operator::BitwiseOr,
197        Box::new(right),
198    ))
199}
200
201/// Return a new expression with bitwise XOR
202pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
203    Expr::BinaryExpr(BinaryExpr::new(
204        Box::new(left),
205        Operator::BitwiseXor,
206        Box::new(right),
207    ))
208}
209
210/// Return a new expression with bitwise SHIFT RIGHT
211pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
212    Expr::BinaryExpr(BinaryExpr::new(
213        Box::new(left),
214        Operator::BitwiseShiftRight,
215        Box::new(right),
216    ))
217}
218
219/// Return a new expression with bitwise SHIFT LEFT
220pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
221    Expr::BinaryExpr(BinaryExpr::new(
222        Box::new(left),
223        Operator::BitwiseShiftLeft,
224        Box::new(right),
225    ))
226}
227
228/// Create an in_list expression
229pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
230    Expr::InList(InList::new(Box::new(expr), list, negated))
231}
232
233/// Create an EXISTS subquery expression
234pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
235    let outer_ref_columns = subquery.all_out_ref_exprs();
236    Expr::Exists(Exists {
237        subquery: Subquery {
238            subquery,
239            outer_ref_columns,
240            spans: Spans::new(),
241        },
242        negated: false,
243    })
244}
245
246/// Create a NOT EXISTS subquery expression
247pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
248    let outer_ref_columns = subquery.all_out_ref_exprs();
249    Expr::Exists(Exists {
250        subquery: Subquery {
251            subquery,
252            outer_ref_columns,
253            spans: Spans::new(),
254        },
255        negated: true,
256    })
257}
258
259/// Create an IN subquery expression
260pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
261    let outer_ref_columns = subquery.all_out_ref_exprs();
262    Expr::InSubquery(InSubquery::new(
263        Box::new(expr),
264        Subquery {
265            subquery,
266            outer_ref_columns,
267            spans: Spans::new(),
268        },
269        false,
270    ))
271}
272
273/// Create a NOT IN subquery expression
274pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
275    let outer_ref_columns = subquery.all_out_ref_exprs();
276    Expr::InSubquery(InSubquery::new(
277        Box::new(expr),
278        Subquery {
279            subquery,
280            outer_ref_columns,
281            spans: Spans::new(),
282        },
283        true,
284    ))
285}
286
287/// Create a scalar subquery expression
288pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
289    let outer_ref_columns = subquery.all_out_ref_exprs();
290    Expr::ScalarSubquery(Subquery {
291        subquery,
292        outer_ref_columns,
293        spans: Spans::new(),
294    })
295}
296
297/// Create a grouping set
298pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
299    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
300}
301
302/// Create a grouping set for all combination of `exprs`
303pub fn cube(exprs: Vec<Expr>) -> Expr {
304    Expr::GroupingSet(GroupingSet::Cube(exprs))
305}
306
307/// Create a grouping set for rollup
308pub fn rollup(exprs: Vec<Expr>) -> Expr {
309    Expr::GroupingSet(GroupingSet::Rollup(exprs))
310}
311
312/// Create a cast expression
313pub fn cast(expr: Expr, data_type: DataType) -> Expr {
314    Expr::Cast(Cast::new(Box::new(expr), data_type))
315}
316
317/// Create a try cast expression
318pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
319    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
320}
321
322/// Create is null expression
323pub fn is_null(expr: Expr) -> Expr {
324    Expr::IsNull(Box::new(expr))
325}
326
327/// Create is true expression
328pub fn is_true(expr: Expr) -> Expr {
329    Expr::IsTrue(Box::new(expr))
330}
331
332/// Create is not true expression
333pub fn is_not_true(expr: Expr) -> Expr {
334    Expr::IsNotTrue(Box::new(expr))
335}
336
337/// Create is false expression
338pub fn is_false(expr: Expr) -> Expr {
339    Expr::IsFalse(Box::new(expr))
340}
341
342/// Create is not false expression
343pub fn is_not_false(expr: Expr) -> Expr {
344    Expr::IsNotFalse(Box::new(expr))
345}
346
347/// Create is unknown expression
348pub fn is_unknown(expr: Expr) -> Expr {
349    Expr::IsUnknown(Box::new(expr))
350}
351
352/// Create is not unknown expression
353pub fn is_not_unknown(expr: Expr) -> Expr {
354    Expr::IsNotUnknown(Box::new(expr))
355}
356
357/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
358pub fn case(expr: Expr) -> CaseBuilder {
359    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
360}
361
362/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
363pub fn when(when: Expr, then: Expr) -> CaseBuilder {
364    CaseBuilder::new(None, vec![when], vec![then], None)
365}
366
367/// Create a Unnest expression
368pub fn unnest(expr: Expr) -> Expr {
369    Expr::Unnest(Unnest {
370        expr: Box::new(expr),
371    })
372}
373
374/// Convenience method to create a new user defined scalar function (UDF) with a
375/// specific signature and specific return type.
376///
377/// Note this function does not expose all available features of [`ScalarUDF`],
378/// such as
379///
380/// * computing return types based on input types
381/// * multiple [`Signature`]s
382/// * aliases
383///
384/// See [`ScalarUDF`] for details and examples on how to use the full
385/// functionality.
386pub fn create_udf(
387    name: &str,
388    input_types: Vec<DataType>,
389    return_type: DataType,
390    volatility: Volatility,
391    fun: ScalarFunctionImplementation,
392) -> ScalarUDF {
393    ScalarUDF::from(SimpleScalarUDF::new(
394        name,
395        input_types,
396        return_type,
397        volatility,
398        fun,
399    ))
400}
401
402/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
403/// return type.
404pub struct SimpleScalarUDF {
405    name: String,
406    signature: Signature,
407    return_type: DataType,
408    fun: ScalarFunctionImplementation,
409}
410
411impl Debug for SimpleScalarUDF {
412    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
413        f.debug_struct("SimpleScalarUDF")
414            .field("name", &self.name)
415            .field("signature", &self.signature)
416            .field("return_type", &self.return_type)
417            .field("fun", &"<FUNC>")
418            .finish()
419    }
420}
421
422impl SimpleScalarUDF {
423    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
424    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
425    pub fn new(
426        name: impl Into<String>,
427        input_types: Vec<DataType>,
428        return_type: DataType,
429        volatility: Volatility,
430        fun: ScalarFunctionImplementation,
431    ) -> Self {
432        Self::new_with_signature(
433            name,
434            Signature::exact(input_types, volatility),
435            return_type,
436            fun,
437        )
438    }
439
440    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
441    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
442    pub fn new_with_signature(
443        name: impl Into<String>,
444        signature: Signature,
445        return_type: DataType,
446        fun: ScalarFunctionImplementation,
447    ) -> Self {
448        Self {
449            name: name.into(),
450            signature,
451            return_type,
452            fun,
453        }
454    }
455}
456
457impl ScalarUDFImpl for SimpleScalarUDF {
458    fn as_any(&self) -> &dyn Any {
459        self
460    }
461
462    fn name(&self) -> &str {
463        &self.name
464    }
465
466    fn signature(&self) -> &Signature {
467        &self.signature
468    }
469
470    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
471        Ok(self.return_type.clone())
472    }
473
474    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
475        (self.fun)(&args.args)
476    }
477}
478
479/// Creates a new UDAF with a specific signature, state type and return type.
480/// The signature and state type must match the `Accumulator's implementation`.
481pub fn create_udaf(
482    name: &str,
483    input_type: Vec<DataType>,
484    return_type: Arc<DataType>,
485    volatility: Volatility,
486    accumulator: AccumulatorFactoryFunction,
487    state_type: Arc<Vec<DataType>>,
488) -> AggregateUDF {
489    let return_type = Arc::unwrap_or_clone(return_type);
490    let state_type = Arc::unwrap_or_clone(state_type);
491    let state_fields = state_type
492        .into_iter()
493        .enumerate()
494        .map(|(i, t)| Field::new(format!("{i}"), t, true))
495        .map(Arc::new)
496        .collect::<Vec<_>>();
497    AggregateUDF::from(SimpleAggregateUDF::new(
498        name,
499        input_type,
500        return_type,
501        volatility,
502        accumulator,
503        state_fields,
504    ))
505}
506
507/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
508/// return type.
509pub struct SimpleAggregateUDF {
510    name: String,
511    signature: Signature,
512    return_type: DataType,
513    accumulator: AccumulatorFactoryFunction,
514    state_fields: Vec<FieldRef>,
515}
516
517impl Debug for SimpleAggregateUDF {
518    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
519        f.debug_struct("SimpleAggregateUDF")
520            .field("name", &self.name)
521            .field("signature", &self.signature)
522            .field("return_type", &self.return_type)
523            .field("fun", &"<FUNC>")
524            .finish()
525    }
526}
527
528impl SimpleAggregateUDF {
529    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
530    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
531    pub fn new(
532        name: impl Into<String>,
533        input_type: Vec<DataType>,
534        return_type: DataType,
535        volatility: Volatility,
536        accumulator: AccumulatorFactoryFunction,
537        state_fields: Vec<FieldRef>,
538    ) -> Self {
539        let name = name.into();
540        let signature = Signature::exact(input_type, volatility);
541        Self {
542            name,
543            signature,
544            return_type,
545            accumulator,
546            state_fields,
547        }
548    }
549
550    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
551    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
552    pub fn new_with_signature(
553        name: impl Into<String>,
554        signature: Signature,
555        return_type: DataType,
556        accumulator: AccumulatorFactoryFunction,
557        state_fields: Vec<FieldRef>,
558    ) -> Self {
559        let name = name.into();
560        Self {
561            name,
562            signature,
563            return_type,
564            accumulator,
565            state_fields,
566        }
567    }
568}
569
570impl AggregateUDFImpl for SimpleAggregateUDF {
571    fn as_any(&self) -> &dyn Any {
572        self
573    }
574
575    fn name(&self) -> &str {
576        &self.name
577    }
578
579    fn signature(&self) -> &Signature {
580        &self.signature
581    }
582
583    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
584        Ok(self.return_type.clone())
585    }
586
587    fn accumulator(
588        &self,
589        acc_args: AccumulatorArgs,
590    ) -> Result<Box<dyn crate::Accumulator>> {
591        (self.accumulator)(acc_args)
592    }
593
594    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
595        Ok(self.state_fields.clone())
596    }
597}
598
599/// Creates a new UDWF with a specific signature, state type and return type.
600///
601/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
602///
603/// [`PartitionEvaluator`]: crate::PartitionEvaluator
604pub fn create_udwf(
605    name: &str,
606    input_type: DataType,
607    return_type: Arc<DataType>,
608    volatility: Volatility,
609    partition_evaluator_factory: PartitionEvaluatorFactory,
610) -> WindowUDF {
611    let return_type = Arc::unwrap_or_clone(return_type);
612    WindowUDF::from(SimpleWindowUDF::new(
613        name,
614        input_type,
615        return_type,
616        volatility,
617        partition_evaluator_factory,
618    ))
619}
620
621/// Implements [`WindowUDFImpl`] for functions that have a single signature and
622/// return type.
623pub struct SimpleWindowUDF {
624    name: String,
625    signature: Signature,
626    return_type: DataType,
627    partition_evaluator_factory: PartitionEvaluatorFactory,
628}
629
630impl Debug for SimpleWindowUDF {
631    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
632        f.debug_struct("WindowUDF")
633            .field("name", &self.name)
634            .field("signature", &self.signature)
635            .field("return_type", &"<func>")
636            .field("partition_evaluator_factory", &"<FUNC>")
637            .finish()
638    }
639}
640
641impl SimpleWindowUDF {
642    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
643    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
644    pub fn new(
645        name: impl Into<String>,
646        input_type: DataType,
647        return_type: DataType,
648        volatility: Volatility,
649        partition_evaluator_factory: PartitionEvaluatorFactory,
650    ) -> Self {
651        let name = name.into();
652        let signature = Signature::exact([input_type].to_vec(), volatility);
653        Self {
654            name,
655            signature,
656            return_type,
657            partition_evaluator_factory,
658        }
659    }
660}
661
662impl WindowUDFImpl for SimpleWindowUDF {
663    fn as_any(&self) -> &dyn Any {
664        self
665    }
666
667    fn name(&self) -> &str {
668        &self.name
669    }
670
671    fn signature(&self) -> &Signature {
672        &self.signature
673    }
674
675    fn partition_evaluator(
676        &self,
677        _partition_evaluator_args: PartitionEvaluatorArgs,
678    ) -> Result<Box<dyn PartitionEvaluator>> {
679        (self.partition_evaluator_factory)()
680    }
681
682    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
683        Ok(Arc::new(Field::new(
684            field_args.name(),
685            self.return_type.clone(),
686            true,
687        )))
688    }
689}
690
691pub fn interval_year_month_lit(value: &str) -> Expr {
692    let interval = parse_interval_year_month(value).ok();
693    Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
694}
695
696pub fn interval_datetime_lit(value: &str) -> Expr {
697    let interval = parse_interval_day_time(value).ok();
698    Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
699}
700
701pub fn interval_month_day_nano_lit(value: &str) -> Expr {
702    let interval = parse_interval_month_day_nano(value).ok();
703    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
704}
705
706/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
707///
708/// Adds methods to [`Expr`] that make it easy to set optional options
709/// such as `ORDER BY`, `FILTER` and `DISTINCT`
710///
711/// # Example
712/// ```no_run
713/// # use datafusion_common::Result;
714/// # use datafusion_expr::test::function_stub::count;
715/// # use sqlparser::ast::NullTreatment;
716/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
717/// # // first_value is an aggregate function in another crate
718/// # fn first_value(_arg: Expr) -> Expr {
719/// unimplemented!() }
720/// # fn main() -> Result<()> {
721/// // Create an aggregate count, filtering on column y > 5
722/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
723///
724/// // Find the first value in an aggregate sorted by column y
725/// // equivalent to:
726/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
727/// let sort_expr = col("y").sort(true, true);
728/// let agg = first_value(col("x"))
729///     .order_by(vec![sort_expr])
730///     .null_treatment(NullTreatment::IgnoreNulls)
731///     .build()?;
732///
733/// // Create a window expression for percent rank partitioned on column a
734/// // equivalent to:
735/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
736/// // percent_rank is an udwf function in another crate
737/// # fn percent_rank() -> Expr {
738/// unimplemented!() }
739/// let window = percent_rank()
740///     .partition_by(vec![col("a")])
741///     .order_by(vec![col("b").sort(true, true)])
742///     .null_treatment(NullTreatment::IgnoreNulls)
743///     .build()?;
744/// #     Ok(())
745/// # }
746/// ```
747pub trait ExprFunctionExt {
748    /// Add `ORDER BY <order_by>`
749    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
750    /// Add `FILTER <filter>`
751    fn filter(self, filter: Expr) -> ExprFuncBuilder;
752    /// Add `DISTINCT`
753    fn distinct(self) -> ExprFuncBuilder;
754    /// Add `RESPECT NULLS` or `IGNORE NULLS`
755    fn null_treatment(
756        self,
757        null_treatment: impl Into<Option<NullTreatment>>,
758    ) -> ExprFuncBuilder;
759    /// Add `PARTITION BY`
760    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
761    /// Add appropriate window frame conditions
762    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
763}
764
765#[derive(Debug, Clone)]
766pub enum ExprFuncKind {
767    Aggregate(AggregateFunction),
768    Window(WindowFunction),
769}
770
771/// Implementation of [`ExprFunctionExt`].
772///
773/// See [`ExprFunctionExt`] for usage and examples
774#[derive(Debug, Clone)]
775pub struct ExprFuncBuilder {
776    fun: Option<ExprFuncKind>,
777    order_by: Option<Vec<Sort>>,
778    filter: Option<Expr>,
779    distinct: bool,
780    null_treatment: Option<NullTreatment>,
781    partition_by: Option<Vec<Expr>>,
782    window_frame: Option<WindowFrame>,
783}
784
785impl ExprFuncBuilder {
786    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
787    fn new(fun: Option<ExprFuncKind>) -> Self {
788        Self {
789            fun,
790            order_by: None,
791            filter: None,
792            distinct: false,
793            null_treatment: None,
794            partition_by: None,
795            window_frame: None,
796        }
797    }
798
799    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
800    ///
801    /// # Errors:
802    ///
803    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
804    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
805    pub fn build(self) -> Result<Expr> {
806        let Self {
807            fun,
808            order_by,
809            filter,
810            distinct,
811            null_treatment,
812            partition_by,
813            window_frame,
814        } = self;
815
816        let Some(fun) = fun else {
817            return plan_err!(
818                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
819            );
820        };
821
822        let fun_expr = match fun {
823            ExprFuncKind::Aggregate(mut udaf) => {
824                udaf.params.order_by = order_by;
825                udaf.params.filter = filter.map(Box::new);
826                udaf.params.distinct = distinct;
827                udaf.params.null_treatment = null_treatment;
828                Expr::AggregateFunction(udaf)
829            }
830            ExprFuncKind::Window(WindowFunction {
831                fun,
832                params: WindowFunctionParams { args, .. },
833            }) => {
834                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
835                Expr::from(WindowFunction {
836                    fun,
837                    params: WindowFunctionParams {
838                        args,
839                        partition_by: partition_by.unwrap_or_default(),
840                        order_by: order_by.unwrap_or_default(),
841                        window_frame: window_frame
842                            .unwrap_or_else(|| WindowFrame::new(has_order_by)),
843                        null_treatment,
844                    },
845                })
846            }
847        };
848
849        Ok(fun_expr)
850    }
851}
852
853impl ExprFunctionExt for ExprFuncBuilder {
854    /// Add `ORDER BY <order_by>`
855    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
856        self.order_by = Some(order_by);
857        self
858    }
859
860    /// Add `FILTER <filter>`
861    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
862        self.filter = Some(filter);
863        self
864    }
865
866    /// Add `DISTINCT`
867    fn distinct(mut self) -> ExprFuncBuilder {
868        self.distinct = true;
869        self
870    }
871
872    /// Add `RESPECT NULLS` or `IGNORE NULLS`
873    fn null_treatment(
874        mut self,
875        null_treatment: impl Into<Option<NullTreatment>>,
876    ) -> ExprFuncBuilder {
877        self.null_treatment = null_treatment.into();
878        self
879    }
880
881    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
882        self.partition_by = Some(partition_by);
883        self
884    }
885
886    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
887        self.window_frame = Some(window_frame);
888        self
889    }
890}
891
892impl ExprFunctionExt for Expr {
893    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
894        let mut builder = match self {
895            Expr::AggregateFunction(udaf) => {
896                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
897            }
898            Expr::WindowFunction(udwf) => {
899                ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)))
900            }
901            _ => ExprFuncBuilder::new(None),
902        };
903        if builder.fun.is_some() {
904            builder.order_by = Some(order_by);
905        }
906        builder
907    }
908    fn filter(self, filter: Expr) -> ExprFuncBuilder {
909        match self {
910            Expr::AggregateFunction(udaf) => {
911                let mut builder =
912                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
913                builder.filter = Some(filter);
914                builder
915            }
916            _ => ExprFuncBuilder::new(None),
917        }
918    }
919    fn distinct(self) -> ExprFuncBuilder {
920        match self {
921            Expr::AggregateFunction(udaf) => {
922                let mut builder =
923                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
924                builder.distinct = true;
925                builder
926            }
927            _ => ExprFuncBuilder::new(None),
928        }
929    }
930    fn null_treatment(
931        self,
932        null_treatment: impl Into<Option<NullTreatment>>,
933    ) -> ExprFuncBuilder {
934        let mut builder = match self {
935            Expr::AggregateFunction(udaf) => {
936                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
937            }
938            Expr::WindowFunction(udwf) => {
939                ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)))
940            }
941            _ => ExprFuncBuilder::new(None),
942        };
943        if builder.fun.is_some() {
944            builder.null_treatment = null_treatment.into();
945        }
946        builder
947    }
948
949    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
950        match self {
951            Expr::WindowFunction(udwf) => {
952                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)));
953                builder.partition_by = Some(partition_by);
954                builder
955            }
956            _ => ExprFuncBuilder::new(None),
957        }
958    }
959
960    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
961        match self {
962            Expr::WindowFunction(udwf) => {
963                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)));
964                builder.window_frame = Some(window_frame);
965                builder
966            }
967            _ => ExprFuncBuilder::new(None),
968        }
969    }
970}
971
#[cfg(test)]
mod test {
    use super::*;

    /// `col` and `ident` both render as plain column names inside
    /// IS NULL / IS NOT NULL predicates.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let normalized = col("col1");
        let verbatim = ident("col2");
        assert_eq!(normalized.is_null().to_string(), "col1 IS NULL");
        assert_eq!(verbatim.is_not_null().to_string(), "col2 IS NOT NULL");
    }
}