datafusion_expr/
expr_fn.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Functions for creating logical expressions
19
20use crate::expr::{
21    AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22    Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams,
23};
24use crate::function::{
25    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26    StateFieldsArgs,
27};
28use crate::select_expr::SelectExpr;
29use crate::{
30    conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
31    AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
32    ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
33};
34use crate::{
35    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
36};
37use arrow::compute::kernels::cast_utils::{
38    parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
39};
40use arrow::datatypes::{DataType, Field};
41use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
42use datafusion_functions_window_common::field::WindowUDFFieldArgs;
43use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
44use sqlparser::ast::NullTreatment;
45use std::any::Any;
46use std::fmt::Debug;
47use std::ops::Not;
48use std::sync::Arc;
49
50/// Create a column expression based on a qualified or unqualified column name. Will
51/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase).
52///
53/// For example:
54///
55/// ```rust
56/// # use datafusion_expr::col;
57/// let c1 = col("a");
58/// let c2 = col("A");
59/// assert_eq!(c1, c2);
60///
61/// // note how quoting with double quotes preserves the case
62/// let c3 = col(r#""A""#);
63/// assert_ne!(c1, c3);
64/// ```
65pub fn col(ident: impl Into<Column>) -> Expr {
66    Expr::Column(ident.into())
67}
68
69/// Create an out reference column which hold a reference that has been resolved to a field
70/// outside of the current plan.
71pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
72    Expr::OuterReferenceColumn(dt, ident.into())
73}
74
75/// Create an unqualified column expression from the provided name, without normalizing
76/// the column.
77///
78/// For example:
79///
80/// ```rust
81/// # use datafusion_expr::{col, ident};
82/// let c1 = ident("A"); // not normalized staying as column 'A'
83/// let c2 = col("A"); // normalized via SQL rules becoming column 'a'
84/// assert_ne!(c1, c2);
85///
86/// let c3 = col(r#""A""#);
87/// assert_eq!(c1, c3);
88///
89/// let c4 = col("t1.a"); // parses as relation 't1' column 'a'
90/// let c5 = ident("t1.a"); // parses as column 't1.a'
91/// assert_ne!(c4, c5);
92/// ```
93pub fn ident(name: impl Into<String>) -> Expr {
94    Expr::Column(Column::from_name(name))
95}
96
97/// Create placeholder value that will be filled in (such as `$1`)
98///
99/// Note the parameter type can be inferred using [`Expr::infer_placeholder_types`]
100///
101/// # Example
102///
103/// ```rust
104/// # use datafusion_expr::{placeholder};
105/// let p = placeholder("$0"); // $0, refers to parameter 1
106/// assert_eq!(p.to_string(), "$0")
107/// ```
108pub fn placeholder(id: impl Into<String>) -> Expr {
109    Expr::Placeholder(Placeholder {
110        id: id.into(),
111        data_type: None,
112    })
113}
114
115/// Create an '*' [`Expr::Wildcard`] expression that matches all columns
116///
117/// # Example
118///
119/// ```rust
120/// # use datafusion_expr::{wildcard};
121/// let p = wildcard();
122/// assert_eq!(p.to_string(), "*")
123/// ```
124pub fn wildcard() -> SelectExpr {
125    SelectExpr::Wildcard(WildcardOptions::default())
126}
127
128/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options
129pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
130    SelectExpr::Wildcard(options)
131}
132
133/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table
134///
135/// # Example
136///
137/// ```rust
138/// # use datafusion_common::TableReference;
139/// # use datafusion_expr::{qualified_wildcard};
140/// let p = qualified_wildcard(TableReference::bare("t"));
141/// assert_eq!(p.to_string(), "t.*")
142/// ```
143pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
144    SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
145}
146
147/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options
148pub fn qualified_wildcard_with_options(
149    qualifier: impl Into<TableReference>,
150    options: WildcardOptions,
151) -> SelectExpr {
152    SelectExpr::QualifiedWildcard(qualifier.into(), options)
153}
154
155/// Return a new expression `left <op> right`
156pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
157    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
158}
159
160/// Return a new expression with a logical AND
161pub fn and(left: Expr, right: Expr) -> Expr {
162    Expr::BinaryExpr(BinaryExpr::new(
163        Box::new(left),
164        Operator::And,
165        Box::new(right),
166    ))
167}
168
169/// Return a new expression with a logical OR
170pub fn or(left: Expr, right: Expr) -> Expr {
171    Expr::BinaryExpr(BinaryExpr::new(
172        Box::new(left),
173        Operator::Or,
174        Box::new(right),
175    ))
176}
177
178/// Return a new expression with a logical NOT
179pub fn not(expr: Expr) -> Expr {
180    expr.not()
181}
182
183/// Return a new expression with bitwise AND
184pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
185    Expr::BinaryExpr(BinaryExpr::new(
186        Box::new(left),
187        Operator::BitwiseAnd,
188        Box::new(right),
189    ))
190}
191
192/// Return a new expression with bitwise OR
193pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
194    Expr::BinaryExpr(BinaryExpr::new(
195        Box::new(left),
196        Operator::BitwiseOr,
197        Box::new(right),
198    ))
199}
200
201/// Return a new expression with bitwise XOR
202pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
203    Expr::BinaryExpr(BinaryExpr::new(
204        Box::new(left),
205        Operator::BitwiseXor,
206        Box::new(right),
207    ))
208}
209
210/// Return a new expression with bitwise SHIFT RIGHT
211pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
212    Expr::BinaryExpr(BinaryExpr::new(
213        Box::new(left),
214        Operator::BitwiseShiftRight,
215        Box::new(right),
216    ))
217}
218
219/// Return a new expression with bitwise SHIFT LEFT
220pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
221    Expr::BinaryExpr(BinaryExpr::new(
222        Box::new(left),
223        Operator::BitwiseShiftLeft,
224        Box::new(right),
225    ))
226}
227
228/// Create an in_list expression
229pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
230    Expr::InList(InList::new(Box::new(expr), list, negated))
231}
232
233/// Create an EXISTS subquery expression
234pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
235    let outer_ref_columns = subquery.all_out_ref_exprs();
236    Expr::Exists(Exists {
237        subquery: Subquery {
238            subquery,
239            outer_ref_columns,
240            spans: Spans::new(),
241        },
242        negated: false,
243    })
244}
245
246/// Create a NOT EXISTS subquery expression
247pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
248    let outer_ref_columns = subquery.all_out_ref_exprs();
249    Expr::Exists(Exists {
250        subquery: Subquery {
251            subquery,
252            outer_ref_columns,
253            spans: Spans::new(),
254        },
255        negated: true,
256    })
257}
258
259/// Create an IN subquery expression
260pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
261    let outer_ref_columns = subquery.all_out_ref_exprs();
262    Expr::InSubquery(InSubquery::new(
263        Box::new(expr),
264        Subquery {
265            subquery,
266            outer_ref_columns,
267            spans: Spans::new(),
268        },
269        false,
270    ))
271}
272
273/// Create a NOT IN subquery expression
274pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
275    let outer_ref_columns = subquery.all_out_ref_exprs();
276    Expr::InSubquery(InSubquery::new(
277        Box::new(expr),
278        Subquery {
279            subquery,
280            outer_ref_columns,
281            spans: Spans::new(),
282        },
283        true,
284    ))
285}
286
287/// Create a scalar subquery expression
288pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
289    let outer_ref_columns = subquery.all_out_ref_exprs();
290    Expr::ScalarSubquery(Subquery {
291        subquery,
292        outer_ref_columns,
293        spans: Spans::new(),
294    })
295}
296
297/// Create a grouping set
298pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
299    Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
300}
301
302/// Create a grouping set for all combination of `exprs`
303pub fn cube(exprs: Vec<Expr>) -> Expr {
304    Expr::GroupingSet(GroupingSet::Cube(exprs))
305}
306
307/// Create a grouping set for rollup
308pub fn rollup(exprs: Vec<Expr>) -> Expr {
309    Expr::GroupingSet(GroupingSet::Rollup(exprs))
310}
311
312/// Create a cast expression
313pub fn cast(expr: Expr, data_type: DataType) -> Expr {
314    Expr::Cast(Cast::new(Box::new(expr), data_type))
315}
316
317/// Create a try cast expression
318pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
319    Expr::TryCast(TryCast::new(Box::new(expr), data_type))
320}
321
322/// Create is null expression
323pub fn is_null(expr: Expr) -> Expr {
324    Expr::IsNull(Box::new(expr))
325}
326
327/// Create is true expression
328pub fn is_true(expr: Expr) -> Expr {
329    Expr::IsTrue(Box::new(expr))
330}
331
332/// Create is not true expression
333pub fn is_not_true(expr: Expr) -> Expr {
334    Expr::IsNotTrue(Box::new(expr))
335}
336
337/// Create is false expression
338pub fn is_false(expr: Expr) -> Expr {
339    Expr::IsFalse(Box::new(expr))
340}
341
342/// Create is not false expression
343pub fn is_not_false(expr: Expr) -> Expr {
344    Expr::IsNotFalse(Box::new(expr))
345}
346
347/// Create is unknown expression
348pub fn is_unknown(expr: Expr) -> Expr {
349    Expr::IsUnknown(Box::new(expr))
350}
351
352/// Create is not unknown expression
353pub fn is_not_unknown(expr: Expr) -> Expr {
354    Expr::IsNotUnknown(Box::new(expr))
355}
356
357/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
358pub fn case(expr: Expr) -> CaseBuilder {
359    CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
360}
361
362/// Create a CASE WHEN statement with boolean WHEN expressions and no base expression.
363pub fn when(when: Expr, then: Expr) -> CaseBuilder {
364    CaseBuilder::new(None, vec![when], vec![then], None)
365}
366
367/// Create a Unnest expression
368pub fn unnest(expr: Expr) -> Expr {
369    Expr::Unnest(Unnest {
370        expr: Box::new(expr),
371    })
372}
373
374/// Convenience method to create a new user defined scalar function (UDF) with a
375/// specific signature and specific return type.
376///
377/// Note this function does not expose all available features of [`ScalarUDF`],
378/// such as
379///
380/// * computing return types based on input types
381/// * multiple [`Signature`]s
382/// * aliases
383///
384/// See [`ScalarUDF`] for details and examples on how to use the full
385/// functionality.
386pub fn create_udf(
387    name: &str,
388    input_types: Vec<DataType>,
389    return_type: DataType,
390    volatility: Volatility,
391    fun: ScalarFunctionImplementation,
392) -> ScalarUDF {
393    ScalarUDF::from(SimpleScalarUDF::new(
394        name,
395        input_types,
396        return_type,
397        volatility,
398        fun,
399    ))
400}
401
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
/// return type.
pub struct SimpleScalarUDF {
    /// Name returned by `name()`
    name: String,
    /// Signature returned by `signature()`
    signature: Signature,
    /// Fixed type returned by `return_type()`, regardless of the argument types
    return_type: DataType,
    /// Implementation called by `invoke_with_args()` with the argument values
    fun: ScalarFunctionImplementation,
}
410
411impl Debug for SimpleScalarUDF {
412    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
413        f.debug_struct("SimpleScalarUDF")
414            .field("name", &self.name)
415            .field("signature", &self.signature)
416            .field("return_type", &self.return_type)
417            .field("fun", &"<FUNC>")
418            .finish()
419    }
420}
421
422impl SimpleScalarUDF {
423    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
424    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
425    pub fn new(
426        name: impl Into<String>,
427        input_types: Vec<DataType>,
428        return_type: DataType,
429        volatility: Volatility,
430        fun: ScalarFunctionImplementation,
431    ) -> Self {
432        Self::new_with_signature(
433            name,
434            Signature::exact(input_types, volatility),
435            return_type,
436            fun,
437        )
438    }
439
440    /// Create a new `SimpleScalarUDF` from a name, signature, return type and
441    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
442    pub fn new_with_signature(
443        name: impl Into<String>,
444        signature: Signature,
445        return_type: DataType,
446        fun: ScalarFunctionImplementation,
447    ) -> Self {
448        Self {
449            name: name.into(),
450            signature,
451            return_type,
452            fun,
453        }
454    }
455}
456
457impl ScalarUDFImpl for SimpleScalarUDF {
458    fn as_any(&self) -> &dyn Any {
459        self
460    }
461
462    fn name(&self) -> &str {
463        &self.name
464    }
465
466    fn signature(&self) -> &Signature {
467        &self.signature
468    }
469
470    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
471        Ok(self.return_type.clone())
472    }
473
474    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
475        (self.fun)(&args.args)
476    }
477}
478
479/// Creates a new UDAF with a specific signature, state type and return type.
480/// The signature and state type must match the `Accumulator's implementation`.
481pub fn create_udaf(
482    name: &str,
483    input_type: Vec<DataType>,
484    return_type: Arc<DataType>,
485    volatility: Volatility,
486    accumulator: AccumulatorFactoryFunction,
487    state_type: Arc<Vec<DataType>>,
488) -> AggregateUDF {
489    let return_type = Arc::unwrap_or_clone(return_type);
490    let state_type = Arc::unwrap_or_clone(state_type);
491    let state_fields = state_type
492        .into_iter()
493        .enumerate()
494        .map(|(i, t)| Field::new(format!("{i}"), t, true))
495        .collect::<Vec<_>>();
496    AggregateUDF::from(SimpleAggregateUDF::new(
497        name,
498        input_type,
499        return_type,
500        volatility,
501        accumulator,
502        state_fields,
503    ))
504}
505
/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
/// return type.
pub struct SimpleAggregateUDF {
    /// Name returned by `name()`
    name: String,
    /// Signature returned by `signature()`
    signature: Signature,
    /// Fixed type returned by `return_type()`, regardless of the argument types
    return_type: DataType,
    /// Factory called by `accumulator()` to create the accumulator
    accumulator: AccumulatorFactoryFunction,
    /// Fields returned (cloned) by `state_fields()`
    state_fields: Vec<Field>,
}
515
516impl Debug for SimpleAggregateUDF {
517    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
518        f.debug_struct("SimpleAggregateUDF")
519            .field("name", &self.name)
520            .field("signature", &self.signature)
521            .field("return_type", &self.return_type)
522            .field("fun", &"<FUNC>")
523            .finish()
524    }
525}
526
527impl SimpleAggregateUDF {
528    /// Create a new `SimpleAggregateUDF` from a name, input types, return type, state type and
529    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
530    pub fn new(
531        name: impl Into<String>,
532        input_type: Vec<DataType>,
533        return_type: DataType,
534        volatility: Volatility,
535        accumulator: AccumulatorFactoryFunction,
536        state_fields: Vec<Field>,
537    ) -> Self {
538        let name = name.into();
539        let signature = Signature::exact(input_type, volatility);
540        Self {
541            name,
542            signature,
543            return_type,
544            accumulator,
545            state_fields,
546        }
547    }
548
549    /// Create a new `SimpleAggregateUDF` from a name, signature, return type, state type and
550    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
551    pub fn new_with_signature(
552        name: impl Into<String>,
553        signature: Signature,
554        return_type: DataType,
555        accumulator: AccumulatorFactoryFunction,
556        state_fields: Vec<Field>,
557    ) -> Self {
558        let name = name.into();
559        Self {
560            name,
561            signature,
562            return_type,
563            accumulator,
564            state_fields,
565        }
566    }
567}
568
569impl AggregateUDFImpl for SimpleAggregateUDF {
570    fn as_any(&self) -> &dyn Any {
571        self
572    }
573
574    fn name(&self) -> &str {
575        &self.name
576    }
577
578    fn signature(&self) -> &Signature {
579        &self.signature
580    }
581
582    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
583        Ok(self.return_type.clone())
584    }
585
586    fn accumulator(
587        &self,
588        acc_args: AccumulatorArgs,
589    ) -> Result<Box<dyn crate::Accumulator>> {
590        (self.accumulator)(acc_args)
591    }
592
593    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
594        Ok(self.state_fields.clone())
595    }
596}
597
598/// Creates a new UDWF with a specific signature, state type and return type.
599///
600/// The signature and state type must match the [`PartitionEvaluator`]'s implementation`.
601///
602/// [`PartitionEvaluator`]: crate::PartitionEvaluator
603pub fn create_udwf(
604    name: &str,
605    input_type: DataType,
606    return_type: Arc<DataType>,
607    volatility: Volatility,
608    partition_evaluator_factory: PartitionEvaluatorFactory,
609) -> WindowUDF {
610    let return_type = Arc::unwrap_or_clone(return_type);
611    WindowUDF::from(SimpleWindowUDF::new(
612        name,
613        input_type,
614        return_type,
615        volatility,
616        partition_evaluator_factory,
617    ))
618}
619
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
/// return type.
pub struct SimpleWindowUDF {
    /// Name returned by `name()`
    name: String,
    /// Signature returned by `signature()`
    signature: Signature,
    /// Data type of the field produced by `field()`
    return_type: DataType,
    /// Factory called by `partition_evaluator()` to create the evaluator
    partition_evaluator_factory: PartitionEvaluatorFactory,
}
628
629impl Debug for SimpleWindowUDF {
630    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
631        f.debug_struct("WindowUDF")
632            .field("name", &self.name)
633            .field("signature", &self.signature)
634            .field("return_type", &"<func>")
635            .field("partition_evaluator_factory", &"<FUNC>")
636            .finish()
637    }
638}
639
640impl SimpleWindowUDF {
641    /// Create a new `SimpleWindowUDF` from a name, input types, return type and
642    /// implementation. Implementing [`WindowUDFImpl`] allows more flexibility
643    pub fn new(
644        name: impl Into<String>,
645        input_type: DataType,
646        return_type: DataType,
647        volatility: Volatility,
648        partition_evaluator_factory: PartitionEvaluatorFactory,
649    ) -> Self {
650        let name = name.into();
651        let signature = Signature::exact([input_type].to_vec(), volatility);
652        Self {
653            name,
654            signature,
655            return_type,
656            partition_evaluator_factory,
657        }
658    }
659}
660
661impl WindowUDFImpl for SimpleWindowUDF {
662    fn as_any(&self) -> &dyn Any {
663        self
664    }
665
666    fn name(&self) -> &str {
667        &self.name
668    }
669
670    fn signature(&self) -> &Signature {
671        &self.signature
672    }
673
674    fn partition_evaluator(
675        &self,
676        _partition_evaluator_args: PartitionEvaluatorArgs,
677    ) -> Result<Box<dyn PartitionEvaluator>> {
678        (self.partition_evaluator_factory)()
679    }
680
681    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
682        Ok(Field::new(
683            field_args.name(),
684            self.return_type.clone(),
685            true,
686        ))
687    }
688}
689
690pub fn interval_year_month_lit(value: &str) -> Expr {
691    let interval = parse_interval_year_month(value).ok();
692    Expr::Literal(ScalarValue::IntervalYearMonth(interval))
693}
694
695pub fn interval_datetime_lit(value: &str) -> Expr {
696    let interval = parse_interval_day_time(value).ok();
697    Expr::Literal(ScalarValue::IntervalDayTime(interval))
698}
699
700pub fn interval_month_day_nano_lit(value: &str) -> Expr {
701    let interval = parse_interval_month_day_nano(value).ok();
702    Expr::Literal(ScalarValue::IntervalMonthDayNano(interval))
703}
704
/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
///
/// Adds methods to [`Expr`] that make it easy to set optional options
/// such as `ORDER BY`, `FILTER` and `DISTINCT`
///
/// Every method returns an [`ExprFuncBuilder`]; call
/// [`ExprFuncBuilder::build`] on it to obtain the configured [`Expr`].
///
/// # Example
/// ```no_run
/// # use datafusion_common::Result;
/// # use datafusion_expr::test::function_stub::count;
/// # use sqlparser::ast::NullTreatment;
/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
/// # // first_value is an aggregate function in another crate
/// # fn first_value(_arg: Expr) -> Expr { unimplemented!() }
/// # fn main() -> Result<()> {
/// // Create an aggregate count, filtering on column y > 5
/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?;
///
/// // Find the first value in an aggregate sorted by column y
/// // equivalent to:
/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)`
/// let sort_expr = col("y").sort(true, true);
/// let agg = first_value(col("x"))
///     .order_by(vec![sort_expr])
///     .null_treatment(NullTreatment::IgnoreNulls)
///     .build()?;
///
/// // Create a window expression for percent rank partitioned on column a
/// // equivalent to:
/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
/// // percent_rank is an udwf function in another crate
/// # fn percent_rank() -> Expr { unimplemented!() }
/// let window = percent_rank()
///     .partition_by(vec![col("a")])
///     .order_by(vec![col("b").sort(true, true)])
///     .null_treatment(NullTreatment::IgnoreNulls)
///     .build()?;
/// #     Ok(())
/// # }
/// ```
pub trait ExprFunctionExt {
    /// Add `ORDER BY <order_by>`
    ///
    /// On an [`Expr`], accepted for both aggregate and window functions.
    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
    /// Add `FILTER <filter>`
    ///
    /// On an [`Expr`], accepted for aggregate functions only; any other
    /// variant yields a builder whose `build` returns an error.
    fn filter(self, filter: Expr) -> ExprFuncBuilder;
    /// Add `DISTINCT`
    ///
    /// On an [`Expr`], accepted for aggregate functions only; any other
    /// variant yields a builder whose `build` returns an error.
    fn distinct(self) -> ExprFuncBuilder;
    /// Add `RESPECT NULLS` or `IGNORE NULLS`
    ///
    /// On an [`Expr`], accepted for both aggregate and window functions.
    fn null_treatment(
        self,
        null_treatment: impl Into<Option<NullTreatment>>,
    ) -> ExprFuncBuilder;
    /// Add `PARTITION BY`
    ///
    /// On an [`Expr`], accepted for window functions only; any other variant
    /// yields a builder whose `build` returns an error.
    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
    /// Add appropriate window frame conditions
    ///
    /// On an [`Expr`], accepted for window functions only; any other variant
    /// yields a builder whose `build` returns an error.
    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
}
763
/// The function expression an [`ExprFuncBuilder`] is configuring.
#[derive(Debug, Clone)]
pub enum ExprFuncKind {
    /// The contents of an [`Expr::AggregateFunction`]
    Aggregate(AggregateFunction),
    /// The contents of an [`Expr::WindowFunction`]
    Window(WindowFunction),
}
769
/// Implementation of [`ExprFunctionExt`].
///
/// See [`ExprFunctionExt`] for usage and examples
#[derive(Debug, Clone)]
pub struct ExprFuncBuilder {
    /// Function being configured; `None` when the builder was started from an
    /// `Expr` that does not support the requested option, which makes `build` fail
    fun: Option<ExprFuncKind>,
    /// `ORDER BY` expressions, if set
    order_by: Option<Vec<Sort>>,
    /// `FILTER` expression, if set (only applied to aggregate functions)
    filter: Option<Expr>,
    /// Whether `DISTINCT` was requested (only applied to aggregate functions)
    distinct: bool,
    /// `RESPECT NULLS` / `IGNORE NULLS`, if set
    null_treatment: Option<NullTreatment>,
    /// `PARTITION BY` expressions, if set (only applied to window functions)
    partition_by: Option<Vec<Expr>>,
    /// Window frame, if set (only applied to window functions)
    window_frame: Option<WindowFrame>,
}
783
784impl ExprFuncBuilder {
785    /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`]
786    fn new(fun: Option<ExprFuncKind>) -> Self {
787        Self {
788            fun,
789            order_by: None,
790            filter: None,
791            distinct: false,
792            null_treatment: None,
793            partition_by: None,
794            window_frame: None,
795        }
796    }
797
798    /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
799    ///
800    /// # Errors:
801    ///
802    /// Returns an error if this builder  [`ExprFunctionExt`] was used with an
803    /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`]
804    pub fn build(self) -> Result<Expr> {
805        let Self {
806            fun,
807            order_by,
808            filter,
809            distinct,
810            null_treatment,
811            partition_by,
812            window_frame,
813        } = self;
814
815        let Some(fun) = fun else {
816            return plan_err!(
817                "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
818            );
819        };
820
821        let fun_expr = match fun {
822            ExprFuncKind::Aggregate(mut udaf) => {
823                udaf.params.order_by = order_by;
824                udaf.params.filter = filter.map(Box::new);
825                udaf.params.distinct = distinct;
826                udaf.params.null_treatment = null_treatment;
827                Expr::AggregateFunction(udaf)
828            }
829            ExprFuncKind::Window(WindowFunction {
830                fun,
831                params: WindowFunctionParams { args, .. },
832            }) => {
833                let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
834                Expr::WindowFunction(WindowFunction {
835                    fun,
836                    params: WindowFunctionParams {
837                        args,
838                        partition_by: partition_by.unwrap_or_default(),
839                        order_by: order_by.unwrap_or_default(),
840                        window_frame: window_frame
841                            .unwrap_or(WindowFrame::new(has_order_by)),
842                        null_treatment,
843                    },
844                })
845            }
846        };
847
848        Ok(fun_expr)
849    }
850}
851
852impl ExprFunctionExt for ExprFuncBuilder {
853    /// Add `ORDER BY <order_by>`
854    fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
855        self.order_by = Some(order_by);
856        self
857    }
858
859    /// Add `FILTER <filter>`
860    fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
861        self.filter = Some(filter);
862        self
863    }
864
865    /// Add `DISTINCT`
866    fn distinct(mut self) -> ExprFuncBuilder {
867        self.distinct = true;
868        self
869    }
870
871    /// Add `RESPECT NULLS` or `IGNORE NULLS`
872    fn null_treatment(
873        mut self,
874        null_treatment: impl Into<Option<NullTreatment>>,
875    ) -> ExprFuncBuilder {
876        self.null_treatment = null_treatment.into();
877        self
878    }
879
880    fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
881        self.partition_by = Some(partition_by);
882        self
883    }
884
885    fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
886        self.window_frame = Some(window_frame);
887        self
888    }
889}
890
891impl ExprFunctionExt for Expr {
892    fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
893        let mut builder = match self {
894            Expr::AggregateFunction(udaf) => {
895                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
896            }
897            Expr::WindowFunction(udwf) => {
898                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
899            }
900            _ => ExprFuncBuilder::new(None),
901        };
902        if builder.fun.is_some() {
903            builder.order_by = Some(order_by);
904        }
905        builder
906    }
907    fn filter(self, filter: Expr) -> ExprFuncBuilder {
908        match self {
909            Expr::AggregateFunction(udaf) => {
910                let mut builder =
911                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
912                builder.filter = Some(filter);
913                builder
914            }
915            _ => ExprFuncBuilder::new(None),
916        }
917    }
918    fn distinct(self) -> ExprFuncBuilder {
919        match self {
920            Expr::AggregateFunction(udaf) => {
921                let mut builder =
922                    ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
923                builder.distinct = true;
924                builder
925            }
926            _ => ExprFuncBuilder::new(None),
927        }
928    }
929    fn null_treatment(
930        self,
931        null_treatment: impl Into<Option<NullTreatment>>,
932    ) -> ExprFuncBuilder {
933        let mut builder = match self {
934            Expr::AggregateFunction(udaf) => {
935                ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
936            }
937            Expr::WindowFunction(udwf) => {
938                ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)))
939            }
940            _ => ExprFuncBuilder::new(None),
941        };
942        if builder.fun.is_some() {
943            builder.null_treatment = null_treatment.into();
944        }
945        builder
946    }
947
948    fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
949        match self {
950            Expr::WindowFunction(udwf) => {
951                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
952                builder.partition_by = Some(partition_by);
953                builder
954            }
955            _ => ExprFuncBuilder::new(None),
956        }
957    }
958
959    fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
960        match self {
961            Expr::WindowFunction(udwf) => {
962                let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf)));
963                builder.window_frame = Some(window_frame);
964                builder
965            }
966            _ => ExprFuncBuilder::new(None),
967        }
968    }
969}
970
#[cfg(test)]
mod test {
    use super::*;

    /// `IS NULL` / `IS NOT NULL` render as expected for columns created with
    /// both `col` and `ident`.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let rendered_is_null = col("col1").is_null().to_string();
        assert_eq!(rendered_is_null, "col1 IS NULL");

        let rendered_is_not_null = ident("col2").is_not_null().to_string();
        assert_eq!(rendered_is_not_null, "col2 IS NOT NULL");
    }
}