1use crate::expr::{
21 AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22 Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams,
23};
24use crate::function::{
25 AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
26 StateFieldsArgs,
27};
28use crate::select_expr::SelectExpr;
29use crate::{
30 conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
31 AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
32 ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
33};
34use crate::{
35 AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
36};
37use arrow::compute::kernels::cast_utils::{
38 parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
39};
40use arrow::datatypes::{DataType, Field, FieldRef};
41use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
42use datafusion_functions_window_common::field::WindowUDFFieldArgs;
43use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
44use sqlparser::ast::NullTreatment;
45use std::any::Any;
46use std::fmt::Debug;
47use std::ops::Not;
48use std::sync::Arc;
49
50pub fn col(ident: impl Into<Column>) -> Expr {
66 Expr::Column(ident.into())
67}
68
69pub fn out_ref_col(dt: DataType, ident: impl Into<Column>) -> Expr {
72 Expr::OuterReferenceColumn(dt, ident.into())
73}
74
75pub fn ident(name: impl Into<String>) -> Expr {
94 Expr::Column(Column::from_name(name))
95}
96
97pub fn placeholder(id: impl Into<String>) -> Expr {
109 Expr::Placeholder(Placeholder {
110 id: id.into(),
111 data_type: None,
112 })
113}
114
115pub fn wildcard() -> SelectExpr {
125 SelectExpr::Wildcard(WildcardOptions::default())
126}
127
128pub fn wildcard_with_options(options: WildcardOptions) -> SelectExpr {
130 SelectExpr::Wildcard(options)
131}
132
133pub fn qualified_wildcard(qualifier: impl Into<TableReference>) -> SelectExpr {
144 SelectExpr::QualifiedWildcard(qualifier.into(), WildcardOptions::default())
145}
146
147pub fn qualified_wildcard_with_options(
149 qualifier: impl Into<TableReference>,
150 options: WildcardOptions,
151) -> SelectExpr {
152 SelectExpr::QualifiedWildcard(qualifier.into(), options)
153}
154
155pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr {
157 Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
158}
159
160pub fn and(left: Expr, right: Expr) -> Expr {
162 Expr::BinaryExpr(BinaryExpr::new(
163 Box::new(left),
164 Operator::And,
165 Box::new(right),
166 ))
167}
168
169pub fn or(left: Expr, right: Expr) -> Expr {
171 Expr::BinaryExpr(BinaryExpr::new(
172 Box::new(left),
173 Operator::Or,
174 Box::new(right),
175 ))
176}
177
178pub fn not(expr: Expr) -> Expr {
180 expr.not()
181}
182
183pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
185 Expr::BinaryExpr(BinaryExpr::new(
186 Box::new(left),
187 Operator::BitwiseAnd,
188 Box::new(right),
189 ))
190}
191
192pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
194 Expr::BinaryExpr(BinaryExpr::new(
195 Box::new(left),
196 Operator::BitwiseOr,
197 Box::new(right),
198 ))
199}
200
201pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
203 Expr::BinaryExpr(BinaryExpr::new(
204 Box::new(left),
205 Operator::BitwiseXor,
206 Box::new(right),
207 ))
208}
209
210pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
212 Expr::BinaryExpr(BinaryExpr::new(
213 Box::new(left),
214 Operator::BitwiseShiftRight,
215 Box::new(right),
216 ))
217}
218
219pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
221 Expr::BinaryExpr(BinaryExpr::new(
222 Box::new(left),
223 Operator::BitwiseShiftLeft,
224 Box::new(right),
225 ))
226}
227
228pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
230 Expr::InList(InList::new(Box::new(expr), list, negated))
231}
232
233pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
235 let outer_ref_columns = subquery.all_out_ref_exprs();
236 Expr::Exists(Exists {
237 subquery: Subquery {
238 subquery,
239 outer_ref_columns,
240 spans: Spans::new(),
241 },
242 negated: false,
243 })
244}
245
246pub fn not_exists(subquery: Arc<LogicalPlan>) -> Expr {
248 let outer_ref_columns = subquery.all_out_ref_exprs();
249 Expr::Exists(Exists {
250 subquery: Subquery {
251 subquery,
252 outer_ref_columns,
253 spans: Spans::new(),
254 },
255 negated: true,
256 })
257}
258
259pub fn in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
261 let outer_ref_columns = subquery.all_out_ref_exprs();
262 Expr::InSubquery(InSubquery::new(
263 Box::new(expr),
264 Subquery {
265 subquery,
266 outer_ref_columns,
267 spans: Spans::new(),
268 },
269 false,
270 ))
271}
272
273pub fn not_in_subquery(expr: Expr, subquery: Arc<LogicalPlan>) -> Expr {
275 let outer_ref_columns = subquery.all_out_ref_exprs();
276 Expr::InSubquery(InSubquery::new(
277 Box::new(expr),
278 Subquery {
279 subquery,
280 outer_ref_columns,
281 spans: Spans::new(),
282 },
283 true,
284 ))
285}
286
287pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
289 let outer_ref_columns = subquery.all_out_ref_exprs();
290 Expr::ScalarSubquery(Subquery {
291 subquery,
292 outer_ref_columns,
293 spans: Spans::new(),
294 })
295}
296
297pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
299 Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
300}
301
302pub fn cube(exprs: Vec<Expr>) -> Expr {
304 Expr::GroupingSet(GroupingSet::Cube(exprs))
305}
306
307pub fn rollup(exprs: Vec<Expr>) -> Expr {
309 Expr::GroupingSet(GroupingSet::Rollup(exprs))
310}
311
312pub fn cast(expr: Expr, data_type: DataType) -> Expr {
314 Expr::Cast(Cast::new(Box::new(expr), data_type))
315}
316
317pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
319 Expr::TryCast(TryCast::new(Box::new(expr), data_type))
320}
321
322pub fn is_null(expr: Expr) -> Expr {
324 Expr::IsNull(Box::new(expr))
325}
326
327pub fn is_true(expr: Expr) -> Expr {
329 Expr::IsTrue(Box::new(expr))
330}
331
332pub fn is_not_true(expr: Expr) -> Expr {
334 Expr::IsNotTrue(Box::new(expr))
335}
336
337pub fn is_false(expr: Expr) -> Expr {
339 Expr::IsFalse(Box::new(expr))
340}
341
342pub fn is_not_false(expr: Expr) -> Expr {
344 Expr::IsNotFalse(Box::new(expr))
345}
346
347pub fn is_unknown(expr: Expr) -> Expr {
349 Expr::IsUnknown(Box::new(expr))
350}
351
352pub fn is_not_unknown(expr: Expr) -> Expr {
354 Expr::IsNotUnknown(Box::new(expr))
355}
356
357pub fn case(expr: Expr) -> CaseBuilder {
359 CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None)
360}
361
362pub fn when(when: Expr, then: Expr) -> CaseBuilder {
364 CaseBuilder::new(None, vec![when], vec![then], None)
365}
366
367pub fn unnest(expr: Expr) -> Expr {
369 Expr::Unnest(Unnest {
370 expr: Box::new(expr),
371 })
372}
373
374pub fn create_udf(
387 name: &str,
388 input_types: Vec<DataType>,
389 return_type: DataType,
390 volatility: Volatility,
391 fun: ScalarFunctionImplementation,
392) -> ScalarUDF {
393 ScalarUDF::from(SimpleScalarUDF::new(
394 name,
395 input_types,
396 return_type,
397 volatility,
398 fun,
399 ))
400}
401
402pub struct SimpleScalarUDF {
405 name: String,
406 signature: Signature,
407 return_type: DataType,
408 fun: ScalarFunctionImplementation,
409}
410
411impl Debug for SimpleScalarUDF {
412 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
413 f.debug_struct("SimpleScalarUDF")
414 .field("name", &self.name)
415 .field("signature", &self.signature)
416 .field("return_type", &self.return_type)
417 .field("fun", &"<FUNC>")
418 .finish()
419 }
420}
421
422impl SimpleScalarUDF {
423 pub fn new(
426 name: impl Into<String>,
427 input_types: Vec<DataType>,
428 return_type: DataType,
429 volatility: Volatility,
430 fun: ScalarFunctionImplementation,
431 ) -> Self {
432 Self::new_with_signature(
433 name,
434 Signature::exact(input_types, volatility),
435 return_type,
436 fun,
437 )
438 }
439
440 pub fn new_with_signature(
443 name: impl Into<String>,
444 signature: Signature,
445 return_type: DataType,
446 fun: ScalarFunctionImplementation,
447 ) -> Self {
448 Self {
449 name: name.into(),
450 signature,
451 return_type,
452 fun,
453 }
454 }
455}
456
457impl ScalarUDFImpl for SimpleScalarUDF {
458 fn as_any(&self) -> &dyn Any {
459 self
460 }
461
462 fn name(&self) -> &str {
463 &self.name
464 }
465
466 fn signature(&self) -> &Signature {
467 &self.signature
468 }
469
470 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
471 Ok(self.return_type.clone())
472 }
473
474 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
475 (self.fun)(&args.args)
476 }
477}
478
479pub fn create_udaf(
482 name: &str,
483 input_type: Vec<DataType>,
484 return_type: Arc<DataType>,
485 volatility: Volatility,
486 accumulator: AccumulatorFactoryFunction,
487 state_type: Arc<Vec<DataType>>,
488) -> AggregateUDF {
489 let return_type = Arc::unwrap_or_clone(return_type);
490 let state_type = Arc::unwrap_or_clone(state_type);
491 let state_fields = state_type
492 .into_iter()
493 .enumerate()
494 .map(|(i, t)| Field::new(format!("{i}"), t, true))
495 .map(Arc::new)
496 .collect::<Vec<_>>();
497 AggregateUDF::from(SimpleAggregateUDF::new(
498 name,
499 input_type,
500 return_type,
501 volatility,
502 accumulator,
503 state_fields,
504 ))
505}
506
507pub struct SimpleAggregateUDF {
510 name: String,
511 signature: Signature,
512 return_type: DataType,
513 accumulator: AccumulatorFactoryFunction,
514 state_fields: Vec<FieldRef>,
515}
516
517impl Debug for SimpleAggregateUDF {
518 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
519 f.debug_struct("SimpleAggregateUDF")
520 .field("name", &self.name)
521 .field("signature", &self.signature)
522 .field("return_type", &self.return_type)
523 .field("fun", &"<FUNC>")
524 .finish()
525 }
526}
527
528impl SimpleAggregateUDF {
529 pub fn new(
532 name: impl Into<String>,
533 input_type: Vec<DataType>,
534 return_type: DataType,
535 volatility: Volatility,
536 accumulator: AccumulatorFactoryFunction,
537 state_fields: Vec<FieldRef>,
538 ) -> Self {
539 let name = name.into();
540 let signature = Signature::exact(input_type, volatility);
541 Self {
542 name,
543 signature,
544 return_type,
545 accumulator,
546 state_fields,
547 }
548 }
549
550 pub fn new_with_signature(
553 name: impl Into<String>,
554 signature: Signature,
555 return_type: DataType,
556 accumulator: AccumulatorFactoryFunction,
557 state_fields: Vec<FieldRef>,
558 ) -> Self {
559 let name = name.into();
560 Self {
561 name,
562 signature,
563 return_type,
564 accumulator,
565 state_fields,
566 }
567 }
568}
569
570impl AggregateUDFImpl for SimpleAggregateUDF {
571 fn as_any(&self) -> &dyn Any {
572 self
573 }
574
575 fn name(&self) -> &str {
576 &self.name
577 }
578
579 fn signature(&self) -> &Signature {
580 &self.signature
581 }
582
583 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
584 Ok(self.return_type.clone())
585 }
586
587 fn accumulator(
588 &self,
589 acc_args: AccumulatorArgs,
590 ) -> Result<Box<dyn crate::Accumulator>> {
591 (self.accumulator)(acc_args)
592 }
593
594 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
595 Ok(self.state_fields.clone())
596 }
597}
598
599pub fn create_udwf(
605 name: &str,
606 input_type: DataType,
607 return_type: Arc<DataType>,
608 volatility: Volatility,
609 partition_evaluator_factory: PartitionEvaluatorFactory,
610) -> WindowUDF {
611 let return_type = Arc::unwrap_or_clone(return_type);
612 WindowUDF::from(SimpleWindowUDF::new(
613 name,
614 input_type,
615 return_type,
616 volatility,
617 partition_evaluator_factory,
618 ))
619}
620
621pub struct SimpleWindowUDF {
624 name: String,
625 signature: Signature,
626 return_type: DataType,
627 partition_evaluator_factory: PartitionEvaluatorFactory,
628}
629
630impl Debug for SimpleWindowUDF {
631 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
632 f.debug_struct("WindowUDF")
633 .field("name", &self.name)
634 .field("signature", &self.signature)
635 .field("return_type", &"<func>")
636 .field("partition_evaluator_factory", &"<FUNC>")
637 .finish()
638 }
639}
640
641impl SimpleWindowUDF {
642 pub fn new(
645 name: impl Into<String>,
646 input_type: DataType,
647 return_type: DataType,
648 volatility: Volatility,
649 partition_evaluator_factory: PartitionEvaluatorFactory,
650 ) -> Self {
651 let name = name.into();
652 let signature = Signature::exact([input_type].to_vec(), volatility);
653 Self {
654 name,
655 signature,
656 return_type,
657 partition_evaluator_factory,
658 }
659 }
660}
661
662impl WindowUDFImpl for SimpleWindowUDF {
663 fn as_any(&self) -> &dyn Any {
664 self
665 }
666
667 fn name(&self) -> &str {
668 &self.name
669 }
670
671 fn signature(&self) -> &Signature {
672 &self.signature
673 }
674
675 fn partition_evaluator(
676 &self,
677 _partition_evaluator_args: PartitionEvaluatorArgs,
678 ) -> Result<Box<dyn PartitionEvaluator>> {
679 (self.partition_evaluator_factory)()
680 }
681
682 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
683 Ok(Arc::new(Field::new(
684 field_args.name(),
685 self.return_type.clone(),
686 true,
687 )))
688 }
689}
690
691pub fn interval_year_month_lit(value: &str) -> Expr {
692 let interval = parse_interval_year_month(value).ok();
693 Expr::Literal(ScalarValue::IntervalYearMonth(interval), None)
694}
695
696pub fn interval_datetime_lit(value: &str) -> Expr {
697 let interval = parse_interval_day_time(value).ok();
698 Expr::Literal(ScalarValue::IntervalDayTime(interval), None)
699}
700
701pub fn interval_month_day_nano_lit(value: &str) -> Expr {
702 let interval = parse_interval_month_day_nano(value).ok();
703 Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None)
704}
705
706pub trait ExprFunctionExt {
748 fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder;
750 fn filter(self, filter: Expr) -> ExprFuncBuilder;
752 fn distinct(self) -> ExprFuncBuilder;
754 fn null_treatment(
756 self,
757 null_treatment: impl Into<Option<NullTreatment>>,
758 ) -> ExprFuncBuilder;
759 fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder;
761 fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder;
763}
764
765#[derive(Debug, Clone)]
766pub enum ExprFuncKind {
767 Aggregate(AggregateFunction),
768 Window(WindowFunction),
769}
770
771#[derive(Debug, Clone)]
775pub struct ExprFuncBuilder {
776 fun: Option<ExprFuncKind>,
777 order_by: Option<Vec<Sort>>,
778 filter: Option<Expr>,
779 distinct: bool,
780 null_treatment: Option<NullTreatment>,
781 partition_by: Option<Vec<Expr>>,
782 window_frame: Option<WindowFrame>,
783}
784
785impl ExprFuncBuilder {
786 fn new(fun: Option<ExprFuncKind>) -> Self {
788 Self {
789 fun,
790 order_by: None,
791 filter: None,
792 distinct: false,
793 null_treatment: None,
794 partition_by: None,
795 window_frame: None,
796 }
797 }
798
799 pub fn build(self) -> Result<Expr> {
806 let Self {
807 fun,
808 order_by,
809 filter,
810 distinct,
811 null_treatment,
812 partition_by,
813 window_frame,
814 } = self;
815
816 let Some(fun) = fun else {
817 return plan_err!(
818 "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction"
819 );
820 };
821
822 let fun_expr = match fun {
823 ExprFuncKind::Aggregate(mut udaf) => {
824 udaf.params.order_by = order_by;
825 udaf.params.filter = filter.map(Box::new);
826 udaf.params.distinct = distinct;
827 udaf.params.null_treatment = null_treatment;
828 Expr::AggregateFunction(udaf)
829 }
830 ExprFuncKind::Window(WindowFunction {
831 fun,
832 params: WindowFunctionParams { args, .. },
833 }) => {
834 let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
835 Expr::from(WindowFunction {
836 fun,
837 params: WindowFunctionParams {
838 args,
839 partition_by: partition_by.unwrap_or_default(),
840 order_by: order_by.unwrap_or_default(),
841 window_frame: window_frame
842 .unwrap_or_else(|| WindowFrame::new(has_order_by)),
843 null_treatment,
844 },
845 })
846 }
847 };
848
849 Ok(fun_expr)
850 }
851}
852
853impl ExprFunctionExt for ExprFuncBuilder {
854 fn order_by(mut self, order_by: Vec<Sort>) -> ExprFuncBuilder {
856 self.order_by = Some(order_by);
857 self
858 }
859
860 fn filter(mut self, filter: Expr) -> ExprFuncBuilder {
862 self.filter = Some(filter);
863 self
864 }
865
866 fn distinct(mut self) -> ExprFuncBuilder {
868 self.distinct = true;
869 self
870 }
871
872 fn null_treatment(
874 mut self,
875 null_treatment: impl Into<Option<NullTreatment>>,
876 ) -> ExprFuncBuilder {
877 self.null_treatment = null_treatment.into();
878 self
879 }
880
881 fn partition_by(mut self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
882 self.partition_by = Some(partition_by);
883 self
884 }
885
886 fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder {
887 self.window_frame = Some(window_frame);
888 self
889 }
890}
891
892impl ExprFunctionExt for Expr {
893 fn order_by(self, order_by: Vec<Sort>) -> ExprFuncBuilder {
894 let mut builder = match self {
895 Expr::AggregateFunction(udaf) => {
896 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
897 }
898 Expr::WindowFunction(udwf) => {
899 ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)))
900 }
901 _ => ExprFuncBuilder::new(None),
902 };
903 if builder.fun.is_some() {
904 builder.order_by = Some(order_by);
905 }
906 builder
907 }
908 fn filter(self, filter: Expr) -> ExprFuncBuilder {
909 match self {
910 Expr::AggregateFunction(udaf) => {
911 let mut builder =
912 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
913 builder.filter = Some(filter);
914 builder
915 }
916 _ => ExprFuncBuilder::new(None),
917 }
918 }
919 fn distinct(self) -> ExprFuncBuilder {
920 match self {
921 Expr::AggregateFunction(udaf) => {
922 let mut builder =
923 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)));
924 builder.distinct = true;
925 builder
926 }
927 _ => ExprFuncBuilder::new(None),
928 }
929 }
930 fn null_treatment(
931 self,
932 null_treatment: impl Into<Option<NullTreatment>>,
933 ) -> ExprFuncBuilder {
934 let mut builder = match self {
935 Expr::AggregateFunction(udaf) => {
936 ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf)))
937 }
938 Expr::WindowFunction(udwf) => {
939 ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)))
940 }
941 _ => ExprFuncBuilder::new(None),
942 };
943 if builder.fun.is_some() {
944 builder.null_treatment = null_treatment.into();
945 }
946 builder
947 }
948
949 fn partition_by(self, partition_by: Vec<Expr>) -> ExprFuncBuilder {
950 match self {
951 Expr::WindowFunction(udwf) => {
952 let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)));
953 builder.partition_by = Some(partition_by);
954 builder
955 }
956 _ => ExprFuncBuilder::new(None),
957 }
958 }
959
960 fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder {
961 match self {
962 Expr::WindowFunction(udwf) => {
963 let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf)));
964 builder.window_frame = Some(window_frame);
965 builder
966 }
967 _ => ExprFuncBuilder::new(None),
968 }
969 }
970}
971
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the display form of `IS NULL` / `IS NOT NULL` on columns built
    /// with `col` and `ident`.
    #[test]
    fn filter_is_null_and_is_not_null() {
        let null_check = col("col1").is_null();
        assert_eq!(null_check.to_string(), "col1 IS NULL");

        let not_null_check = ident("col2").is_not_null();
        assert_eq!(not_null_check.to_string(), "col2 IS NOT NULL");
    }
}