datafusion_expr/
expr_schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{Between, Expr, Like};
19use crate::expr::{
20    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21    InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22    WindowFunctionParams,
23};
24use crate::type_coercion::functions::{
25    data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
26};
27use crate::udf::ReturnTypeArgs;
28use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
29use arrow::compute::can_cast_types;
30use arrow::datatypes::{DataType, Field};
31use datafusion_common::{
32    not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
33    Result, Spans, TableReference,
34};
35use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
36use datafusion_functions_window_common::field::WindowUDFFieldArgs;
37use std::collections::HashMap;
38use std::sync::Arc;
39
40/// Trait to allow expr to typable with respect to a schema
41pub trait ExprSchemable {
42    /// Given a schema, return the type of the expr
43    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
44
45    /// Given a schema, return the nullability of the expr
46    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
47
48    /// Given a schema, return the expr's optional metadata
49    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;
50
51    /// Convert to a field with respect to a schema
52    fn to_field(
53        &self,
54        input_schema: &dyn ExprSchema,
55    ) -> Result<(Option<TableReference>, Arc<Field>)>;
56
57    /// Cast to a type with respect to a schema
58    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
59
60    /// Given a schema, return the type and nullability of the expr
61    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
62        -> Result<(DataType, bool)>;
63}
64
65impl ExprSchemable for Expr {
66    /// Returns the [arrow::datatypes::DataType] of the expression
67    /// based on [ExprSchema]
68    ///
69    /// Note: [`DFSchema`] implements [ExprSchema].
70    ///
71    /// [`DFSchema`]: datafusion_common::DFSchema
72    ///
73    /// # Examples
74    ///
75    /// Get the type of an expression that adds 2 columns. Adding an Int32
76    /// and Float32 results in Float32 type
77    ///
78    /// ```
79    /// # use arrow::datatypes::{DataType, Field};
80    /// # use datafusion_common::DFSchema;
81    /// # use datafusion_expr::{col, ExprSchemable};
82    /// # use std::collections::HashMap;
83    ///
84    /// fn main() {
85    ///   let expr = col("c1") + col("c2");
86    ///   let schema = DFSchema::from_unqualified_fields(
87    ///     vec![
88    ///       Field::new("c1", DataType::Int32, true),
89    ///       Field::new("c2", DataType::Float32, true),
90    ///       ].into(),
91    ///       HashMap::new(),
92    ///   ).unwrap();
93    ///   assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
94    /// }
95    /// ```
96    ///
97    /// # Errors
98    ///
99    /// This function errors when it is not possible to compute its
100    /// [arrow::datatypes::DataType].  This happens when e.g. the
101    /// expression refers to a column that does not exist in the
102    /// schema, or when the expression is incorrectly typed
103    /// (e.g. `[utf8] + [bool]`).
104    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
105    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
106        match self {
107            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
108                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
109                    None => schema.data_type(&Column::from_name(name)).cloned(),
110                    Some(dt) => Ok(dt.clone()),
111                },
112                _ => expr.get_type(schema),
113            },
114            Expr::Negative(expr) => expr.get_type(schema),
115            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
116            Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
117            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
118            Expr::Literal(l) => Ok(l.data_type()),
119            Expr::Case(case) => {
120                for (_, then_expr) in &case.when_then_expr {
121                    let then_type = then_expr.get_type(schema)?;
122                    if !then_type.is_null() {
123                        return Ok(then_type);
124                    }
125                }
126                case.else_expr
127                    .as_ref()
128                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
129            }
130            Expr::Cast(Cast { data_type, .. })
131            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
132            Expr::Unnest(Unnest { expr }) => {
133                let arg_data_type = expr.get_type(schema)?;
134                // Unnest's output type is the inner type of the list
135                match arg_data_type {
136                    DataType::List(field)
137                    | DataType::LargeList(field)
138                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
139                    DataType::Struct(_) => Ok(arg_data_type),
140                    DataType::Null => {
141                        not_impl_err!("unnest() does not support null yet")
142                    }
143                    _ => {
144                        plan_err!(
145                            "unnest() can only be applied to array, struct and null"
146                        )
147                    }
148                }
149            }
150            Expr::ScalarFunction(_func) => {
151                let (return_type, _) = self.data_type_and_nullable(schema)?;
152                Ok(return_type)
153            }
154            Expr::WindowFunction(window_function) => self
155                .data_type_and_nullable_with_window_function(schema, window_function)
156                .map(|(return_type, _)| return_type),
157            Expr::AggregateFunction(AggregateFunction {
158                func,
159                params: AggregateFunctionParams { args, .. },
160            }) => {
161                let data_types = args
162                    .iter()
163                    .map(|e| e.get_type(schema))
164                    .collect::<Result<Vec<_>>>()?;
165                let new_types = data_types_with_aggregate_udf(&data_types, func)
166                    .map_err(|err| {
167                        plan_datafusion_err!(
168                            "{} {}",
169                            match err {
170                                DataFusionError::Plan(msg) => msg,
171                                err => err.to_string(),
172                            },
173                            utils::generate_signature_error_msg(
174                                func.name(),
175                                func.signature().clone(),
176                                &data_types
177                            )
178                        )
179                    })?;
180                Ok(func.return_type(&new_types)?)
181            }
182            Expr::Not(_)
183            | Expr::IsNull(_)
184            | Expr::Exists { .. }
185            | Expr::InSubquery(_)
186            | Expr::Between { .. }
187            | Expr::InList { .. }
188            | Expr::IsNotNull(_)
189            | Expr::IsTrue(_)
190            | Expr::IsFalse(_)
191            | Expr::IsUnknown(_)
192            | Expr::IsNotTrue(_)
193            | Expr::IsNotFalse(_)
194            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
195            Expr::ScalarSubquery(subquery) => {
196                Ok(subquery.subquery.schema().field(0).data_type().clone())
197            }
198            Expr::BinaryExpr(BinaryExpr {
199                ref left,
200                ref right,
201                ref op,
202            }) => BinaryTypeCoercer::new(
203                &left.get_type(schema)?,
204                op,
205                &right.get_type(schema)?,
206            )
207            .get_result_type(),
208            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
209            Expr::Placeholder(Placeholder { data_type, .. }) => {
210                if let Some(dtype) = data_type {
211                    Ok(dtype.clone())
212                } else {
213                    // If the placeholder's type hasn't been specified, treat it as
214                    // null (unspecified placeholders generate an error during planning)
215                    Ok(DataType::Null)
216                }
217            }
218            #[expect(deprecated)]
219            Expr::Wildcard { .. } => Ok(DataType::Null),
220            Expr::GroupingSet(_) => {
221                // Grouping sets do not really have a type and do not appear in projections
222                Ok(DataType::Null)
223            }
224        }
225    }
226
227    /// Returns the nullability of the expression based on [ExprSchema].
228    ///
229    /// Note: [`DFSchema`] implements [ExprSchema].
230    ///
231    /// [`DFSchema`]: datafusion_common::DFSchema
232    ///
233    /// # Errors
234    ///
235    /// This function errors when it is not possible to compute its
236    /// nullability.  This happens when the expression refers to a
237    /// column that does not exist in the schema.
238    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
239        match self {
240            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
241                expr.nullable(input_schema)
242            }
243
244            Expr::InList(InList { expr, list, .. }) => {
245                // Avoid inspecting too many expressions.
246                const MAX_INSPECT_LIMIT: usize = 6;
247                // Stop if a nullable expression is found or an error occurs.
248                let has_nullable = std::iter::once(expr.as_ref())
249                    .chain(list)
250                    .take(MAX_INSPECT_LIMIT)
251                    .find_map(|e| {
252                        e.nullable(input_schema)
253                            .map(|nullable| if nullable { Some(()) } else { None })
254                            .transpose()
255                    })
256                    .transpose()?;
257                Ok(match has_nullable {
258                    // If a nullable subexpression is found, the result may also be nullable.
259                    Some(_) => true,
260                    // If the list is too long, we assume it is nullable.
261                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
262                    // All the subexpressions are non-nullable, so the result must be non-nullable.
263                    _ => false,
264                })
265            }
266
267            Expr::Between(Between {
268                expr, low, high, ..
269            }) => Ok(expr.nullable(input_schema)?
270                || low.nullable(input_schema)?
271                || high.nullable(input_schema)?),
272
273            Expr::Column(c) => input_schema.nullable(c),
274            Expr::OuterReferenceColumn(_, _) => Ok(true),
275            Expr::Literal(value) => Ok(value.is_null()),
276            Expr::Case(case) => {
277                // This expression is nullable if any of the input expressions are nullable
278                let then_nullable = case
279                    .when_then_expr
280                    .iter()
281                    .map(|(_, t)| t.nullable(input_schema))
282                    .collect::<Result<Vec<_>>>()?;
283                if then_nullable.contains(&true) {
284                    Ok(true)
285                } else if let Some(e) = &case.else_expr {
286                    e.nullable(input_schema)
287                } else {
288                    // CASE produces NULL if there is no `else` expr
289                    // (aka when none of the `when_then_exprs` match)
290                    Ok(true)
291                }
292            }
293            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
294            Expr::ScalarFunction(_func) => {
295                let (_, nullable) = self.data_type_and_nullable(input_schema)?;
296                Ok(nullable)
297            }
298            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
299                Ok(func.is_nullable())
300            }
301            Expr::WindowFunction(window_function) => self
302                .data_type_and_nullable_with_window_function(
303                    input_schema,
304                    window_function,
305                )
306                .map(|(_, nullable)| nullable),
307            Expr::ScalarVariable(_, _)
308            | Expr::TryCast { .. }
309            | Expr::Unnest(_)
310            | Expr::Placeholder(_) => Ok(true),
311            Expr::IsNull(_)
312            | Expr::IsNotNull(_)
313            | Expr::IsTrue(_)
314            | Expr::IsFalse(_)
315            | Expr::IsUnknown(_)
316            | Expr::IsNotTrue(_)
317            | Expr::IsNotFalse(_)
318            | Expr::IsNotUnknown(_)
319            | Expr::Exists { .. } => Ok(false),
320            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
321            Expr::ScalarSubquery(subquery) => {
322                Ok(subquery.subquery.schema().field(0).is_nullable())
323            }
324            Expr::BinaryExpr(BinaryExpr {
325                ref left,
326                ref right,
327                ..
328            }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
329            Expr::Like(Like { expr, pattern, .. })
330            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
331                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
332            }
333            #[expect(deprecated)]
334            Expr::Wildcard { .. } => Ok(false),
335            Expr::GroupingSet(_) => {
336                // Grouping sets do not really have the concept of nullable and do not appear
337                // in projections
338                Ok(true)
339            }
340        }
341    }
342
343    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
344        match self {
345            Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
346            Expr::Alias(Alias { expr, metadata, .. }) => {
347                let mut ret = expr.metadata(schema)?;
348                if let Some(metadata) = metadata {
349                    if !metadata.is_empty() {
350                        ret.extend(metadata.clone());
351                        return Ok(ret);
352                    }
353                }
354                Ok(ret)
355            }
356            Expr::Cast(Cast { expr, .. }) => expr.metadata(schema),
357            _ => Ok(HashMap::new()),
358        }
359    }
360
361    /// Returns the datatype and nullability of the expression based on [ExprSchema].
362    ///
363    /// Note: [`DFSchema`] implements [ExprSchema].
364    ///
365    /// [`DFSchema`]: datafusion_common::DFSchema
366    ///
367    /// # Errors
368    ///
369    /// This function errors when it is not possible to compute its
370    /// datatype or nullability.
371    fn data_type_and_nullable(
372        &self,
373        schema: &dyn ExprSchema,
374    ) -> Result<(DataType, bool)> {
375        match self {
376            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
377                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
378                    None => schema
379                        .data_type_and_nullable(&Column::from_name(name))
380                        .map(|(d, n)| (d.clone(), n)),
381                    Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)),
382                },
383                _ => expr.data_type_and_nullable(schema),
384            },
385            Expr::Negative(expr) => expr.data_type_and_nullable(schema),
386            Expr::Column(c) => schema
387                .data_type_and_nullable(c)
388                .map(|(d, n)| (d.clone(), n)),
389            Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
390            Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
391            Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
392            Expr::IsNull(_)
393            | Expr::IsNotNull(_)
394            | Expr::IsTrue(_)
395            | Expr::IsFalse(_)
396            | Expr::IsUnknown(_)
397            | Expr::IsNotTrue(_)
398            | Expr::IsNotFalse(_)
399            | Expr::IsNotUnknown(_)
400            | Expr::Exists { .. } => Ok((DataType::Boolean, false)),
401            Expr::ScalarSubquery(subquery) => Ok((
402                subquery.subquery.schema().field(0).data_type().clone(),
403                subquery.subquery.schema().field(0).is_nullable(),
404            )),
405            Expr::BinaryExpr(BinaryExpr {
406                ref left,
407                ref right,
408                ref op,
409            }) => {
410                let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?;
411                let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?;
412                let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type);
413                coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
414                coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
415                Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable))
416            }
417            Expr::WindowFunction(window_function) => {
418                self.data_type_and_nullable_with_window_function(schema, window_function)
419            }
420            Expr::ScalarFunction(ScalarFunction { func, args }) => {
421                let (arg_types, nullables): (Vec<DataType>, Vec<bool>) = args
422                    .iter()
423                    .map(|e| e.data_type_and_nullable(schema))
424                    .collect::<Result<Vec<_>>>()?
425                    .into_iter()
426                    .unzip();
427                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
428                let new_data_types = data_types_with_scalar_udf(&arg_types, func)
429                    .map_err(|err| {
430                        plan_datafusion_err!(
431                            "{} {}",
432                            match err {
433                                DataFusionError::Plan(msg) => msg,
434                                err => err.to_string(),
435                            },
436                            utils::generate_signature_error_msg(
437                                func.name(),
438                                func.signature().clone(),
439                                &arg_types,
440                            )
441                        )
442                    })?;
443
444                let arguments = args
445                    .iter()
446                    .map(|e| match e {
447                        Expr::Literal(sv) => Some(sv),
448                        _ => None,
449                    })
450                    .collect::<Vec<_>>();
451                let args = ReturnTypeArgs {
452                    arg_types: &new_data_types,
453                    scalar_arguments: &arguments,
454                    nullables: &nullables,
455                };
456
457                let (return_type, nullable) =
458                    func.return_type_from_args(args)?.into_parts();
459                Ok((return_type, nullable))
460            }
461            _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
462        }
463    }
464
465    /// Returns a [arrow::datatypes::Field] compatible with this expression.
466    ///
467    /// So for example, a projected expression `col(c1) + col(c2)` is
468    /// placed in an output field **named** col("c1 + c2")
469    fn to_field(
470        &self,
471        input_schema: &dyn ExprSchema,
472    ) -> Result<(Option<TableReference>, Arc<Field>)> {
473        let (relation, schema_name) = self.qualified_name();
474        let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
475        let field = Field::new(schema_name, data_type, nullable)
476            .with_metadata(self.metadata(input_schema)?)
477            .into();
478        Ok((relation, field))
479    }
480
481    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
482    ///
483    /// # Errors
484    ///
485    /// This function errors when it is impossible to cast the
486    /// expression to the target [arrow::datatypes::DataType].
487    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
488        let this_type = self.get_type(schema)?;
489        if this_type == *cast_to_type {
490            return Ok(self);
491        }
492
493        // TODO(kszucs): Most of the operations do not validate the type correctness
494        // like all of the binary expressions below. Perhaps Expr should track the
495        // type of the expression?
496
497        if can_cast_types(&this_type, cast_to_type) {
498            match self {
499                Expr::ScalarSubquery(subquery) => {
500                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
501                }
502                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
503            }
504        } else {
505            plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
506        }
507    }
508}
509
510impl Expr {
511    /// Common method for window functions that applies type coercion
512    /// to all arguments of the window function to check if it matches
513    /// its signature.
514    ///
515    /// If successful, this method returns the data type and
516    /// nullability of the window function's result.
517    ///
518    /// Otherwise, returns an error if there's a type mismatch between
519    /// the window function's signature and the provided arguments.
520    fn data_type_and_nullable_with_window_function(
521        &self,
522        schema: &dyn ExprSchema,
523        window_function: &WindowFunction,
524    ) -> Result<(DataType, bool)> {
525        let WindowFunction {
526            fun,
527            params: WindowFunctionParams { args, .. },
528            ..
529        } = window_function;
530
531        let data_types = args
532            .iter()
533            .map(|e| e.get_type(schema))
534            .collect::<Result<Vec<_>>>()?;
535        match fun {
536            WindowFunctionDefinition::AggregateUDF(udaf) => {
537                let new_types = data_types_with_aggregate_udf(&data_types, udaf)
538                    .map_err(|err| {
539                        plan_datafusion_err!(
540                            "{} {}",
541                            match err {
542                                DataFusionError::Plan(msg) => msg,
543                                err => err.to_string(),
544                            },
545                            utils::generate_signature_error_msg(
546                                fun.name(),
547                                fun.signature(),
548                                &data_types
549                            )
550                        )
551                    })?;
552
553                let return_type = udaf.return_type(&new_types)?;
554                let nullable = udaf.is_nullable();
555
556                Ok((return_type, nullable))
557            }
558            WindowFunctionDefinition::WindowUDF(udwf) => {
559                let new_types =
560                    data_types_with_window_udf(&data_types, udwf).map_err(|err| {
561                        plan_datafusion_err!(
562                            "{} {}",
563                            match err {
564                                DataFusionError::Plan(msg) => msg,
565                                err => err.to_string(),
566                            },
567                            utils::generate_signature_error_msg(
568                                fun.name(),
569                                fun.signature(),
570                                &data_types
571                            )
572                        )
573                    })?;
574                let (_, function_name) = self.qualified_name();
575                let field_args = WindowUDFFieldArgs::new(&new_types, &function_name);
576
577                udwf.field(field_args)
578                    .map(|field| (field.data_type().clone(), field.is_nullable()))
579            }
580        }
581    }
582}
583
584/// Cast subquery in InSubquery/ScalarSubquery to a given type.
585///
586/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
587///    columns), it casts the first expression in the projection to the target type and creates a
588///    new projection with the casted expression.
589/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan
590///    with the casted first column.
591///
592pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
593    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
594        return Ok(subquery);
595    }
596
597    let plan = subquery.subquery.as_ref();
598    let new_plan = match plan {
599        LogicalPlan::Projection(projection) => {
600            let cast_expr = projection.expr[0]
601                .clone()
602                .cast_to(cast_to_type, projection.input.schema())?;
603            LogicalPlan::Projection(Projection::try_new(
604                vec![cast_expr],
605                Arc::clone(&projection.input),
606            )?)
607        }
608        _ => {
609            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
610                .cast_to(cast_to_type, subquery.subquery.schema())?;
611            LogicalPlan::Projection(Projection::try_new(
612                vec![cast_expr],
613                subquery.subquery,
614            )?)
615        }
616    };
617    Ok(Subquery {
618        subquery: Arc::new(new_plan),
619        outer_ref_columns: subquery.outer_ref_columns,
620        spans: Spans::new(),
621    })
622}
623
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};

    use datafusion_common::{internal_err, DFSchema, ScalarValue};

    /// Mock schema in which every column has `data_type` and `nullable`.
    fn schema_of(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    #[test]
    fn expr_schema_nullability() {
        let equality = col("foo").eq(lit(1));
        assert!(!equality.nullable(&MockExprSchema::new()).unwrap());
        assert!(equality
            .nullable(&MockExprSchema::new().with_nullable(true))
            .unwrap());

        // Truth-test predicates never yield NULL, even over a NULL literal.
        let predicates: [fn(Expr) -> Expr; 8] = [
            Expr::is_null,
            Expr::is_not_null,
            Expr::is_true,
            Expr::is_not_true,
            Expr::is_false,
            Expr::is_not_false,
            Expr::is_unknown,
            Expr::is_not_unknown,
        ];
        for make_predicate in predicates {
            let predicate = make_predicate(lit(ScalarValue::Null));
            assert!(!predicate.nullable(&MockExprSchema::new()).unwrap());
        }
    }

    #[test]
    fn test_between_nullability() {
        let non_nullable = schema_of(DataType::Int32, false);
        let nullable = schema_of(DataType::Int32, true);

        let between = col("foo").between(lit(1), lit(2));
        assert!(!between.nullable(&non_nullable).unwrap());
        assert!(between.nullable(&nullable).unwrap());

        // Any NULL bound makes BETWEEN nullable, even over a non-nullable column.
        let null_int = lit(ScalarValue::Int32(None));
        for between in [
            col("foo").between(null_int.clone(), lit(2)),
            col("foo").between(lit(1), null_int.clone()),
            col("foo").between(null_int.clone(), null_int),
        ] {
            assert!(between.nullable(&non_nullable).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        let in_five = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!in_five
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());
        assert!(in_five.nullable(&schema_of(DataType::Int32, true)).unwrap());
        // An error from the schema's nullability lookup is propagated.
        assert!(in_five
            .nullable(&schema_of(DataType::Int32, false).with_error_on_nullable(true))
            .is_err());

        let with_null =
            col("foo").in_list(vec![lit(ScalarValue::Int32(None)), lit(1)], false);
        assert!(with_null
            .nullable(&schema_of(DataType::Int32, false))
            .unwrap());

        // Lists past the inspection limit are conservatively reported as nullable.
        let in_six = col("foo").in_list(vec![lit(1); 6], false);
        assert!(in_six.nullable(&schema_of(DataType::Int32, false)).unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let like_literal = col("foo").like(lit("bar"));
        assert!(!like_literal
            .nullable(&schema_of(DataType::Utf8, false))
            .unwrap());
        assert!(like_literal
            .nullable(&schema_of(DataType::Utf8, true))
            .unwrap());

        let like_null = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(like_null.nullable(&schema_of(DataType::Utf8, false)).unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let column = col("foo");
        let schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        assert_eq!(DataType::Utf8, column.get_type(&schema).unwrap());
    }

    #[test]
    fn test_expr_metadata() {
        let meta: HashMap<String, String> =
            [("bar".to_string(), "buzz".to_string())].into_iter().collect();
        let column = col("foo");
        let mock = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // col, alias, and cast should be metadata-preserving
        assert_eq!(meta, column.metadata(&mock).unwrap());
        assert_eq!(meta, column.clone().alias("bar").metadata(&mock).unwrap());
        let casted = column.clone().cast_to(&DataType::Int64, &mock).unwrap();
        assert_eq!(meta, casted.metadata(&mock).unwrap());

        let df_schema = DFSchema::from_unqualified_fields(
            vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())]
                .into(),
            HashMap::new(),
        )
        .unwrap();

        // verify to_field method populates metadata
        let (_, field) = column.to_field(&df_schema).unwrap();
        assert_eq!(&meta, field.metadata());
    }

    /// Test schema that answers every column lookup with the same type,
    /// nullability and metadata, and can be told to fail nullability lookups.
    #[derive(Debug)]
    struct MockExprSchema {
        nullable: bool,
        data_type: DataType,
        error_on_nullable: bool,
        metadata: HashMap<String, String>,
    }

    impl MockExprSchema {
        fn new() -> Self {
            Self {
                nullable: false,
                data_type: DataType::Null,
                error_on_nullable: false,
                metadata: HashMap::new(),
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self { nullable, ..self }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self { data_type, ..self }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
            Self { metadata, ..self }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if self.error_on_nullable {
                return internal_err!("nullable error");
            }
            Ok(self.nullable)
        }

        fn data_type(&self, _col: &Column) -> Result<&DataType> {
            Ok(&self.data_type)
        }

        fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
            Ok(&self.metadata)
        }

        fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
            let data_type = self.data_type(col)?;
            let nullable = self.nullable(col)?;
            Ok((data_type, nullable))
        }
    }
}