datafusion_functions_aggregate/
min_max.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`Max`] and [`MaxAccumulator`]: the `max` aggregate function and its accumulator.
//! [`Min`] and [`MinAccumulator`]: the `min` aggregate function and its accumulator.
20
21mod min_max_bytes;
22
23use arrow::array::{
24    ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
25    Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
26    DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array,
27    Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
28    IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
29    LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
30    Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
31    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
32    TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
33};
34use arrow::compute;
35use arrow::datatypes::{
36    DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType,
37    DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type,
38    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit,
39    UInt16Type, UInt32Type, UInt64Type, UInt8Type,
40};
41use datafusion_common::stats::Precision;
42use datafusion_common::{
43    downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
44};
45use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
46use datafusion_physical_expr::expressions;
47use std::cmp::Ordering;
48use std::fmt::Debug;
49
50use arrow::datatypes::i256;
51use arrow::datatypes::{
52    Date32Type, Date64Type, Time32MillisecondType, Time32SecondType,
53    Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
54    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
55};
56
57use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
58use datafusion_common::ScalarValue;
59use datafusion_expr::{
60    function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation,
61    SetMonotonicity, Signature, Volatility,
62};
63use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
64use datafusion_macros::user_doc;
65use half::f16;
66use std::mem::size_of_val;
67use std::ops::Deref;
68
69fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
70    // make sure that the input types only has one element.
71    if input_types.len() != 1 {
72        return exec_err!(
73            "min/max was called with {} arguments. It requires only 1.",
74            input_types.len()
75        );
76    }
77    // min and max support the dictionary data type
78    // unpack the dictionary to get the value
79    match &input_types[0] {
80        DataType::Dictionary(_, dict_value_type) => {
81            // TODO add checker, if the value type is complex data type
82            Ok(vec![dict_value_type.deref().clone()])
83        }
84        // TODO add checker for datatype which min and max supported
85        // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
86        _ => Ok(input_types.to_vec()),
87    }
88}
89
90#[user_doc(
91    doc_section(label = "General Functions"),
92    description = "Returns the maximum value in the specified column.",
93    syntax_example = "max(expression)",
94    sql_example = r#"```sql
95> SELECT max(column_name) FROM table_name;
96+----------------------+
97| max(column_name)      |
98+----------------------+
99| 150                  |
100+----------------------+
101```"#,
102    standard_argument(name = "expression",)
103)]
104// MAX aggregate UDF
105#[derive(Debug)]
106pub struct Max {
107    signature: Signature,
108}
109
110impl Max {
111    pub fn new() -> Self {
112        Self {
113            signature: Signature::user_defined(Volatility::Immutable),
114        }
115    }
116}
117
118impl Default for Max {
119    fn default() -> Self {
120        Self::new()
121    }
122}
/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` over
/// the specified [`ArrowPrimitiveType`].
///
/// Arguments: the values' `DataType` (passed through to the accumulator), the
/// native Rust type, and the arrow primitive type.
///
/// Floating point types (`f16`, `f32`, `f64`) are compared with the IEEE 754
/// total order (`total_cmp`, the same order `typed_min_max_float` uses for
/// scalar values) and start from the smallest value in that order. All other
/// types use their natural order and start from `$NATIVE::MIN`.
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! primitive_max_accumulator {
    // Shared implementation for the floating point types. `$BITS` is the
    // unsigned integer type with the same width as `$NATIVE`.
    (@float $DATA_TYPE:ident, $NATIVE:ident, $BITS:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
                // `partial_cmp` yields `None` whenever a NaN is involved, which
                // made the result depend on the order values arrived in. The
                // total order is deterministic and ranks positive NaN above
                // every other value.
                if new.total_cmp(cur) == Ordering::Greater {
                    *cur = new
                }
            })
            // All bits set is a negative NaN, the smallest value in the total
            // order, so every input (including `-inf`) replaces it.
            // `$NATIVE::MIN` is finite for floats and would be reported for
            // groups whose only values are `-inf`.
            .with_starting_value($NATIVE::from_bits($BITS::MAX)),
        ))
    }};
    ($DATA_TYPE:ident, f16, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@float $DATA_TYPE, f16, u16, $PRIMTYPE)
    };
    ($DATA_TYPE:ident, f32, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@float $DATA_TYPE, f32, u32, $PRIMTYPE)
    };
    ($DATA_TYPE:ident, f64, $PRIMTYPE:ident) => {
        primitive_max_accumulator!(@float $DATA_TYPE, f64, u64, $PRIMTYPE)
    };
    // Integer-like types (integers, decimals, dates, times, timestamps,
    // durations), which are totally ordered by `PartialOrd`.
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
                match (new).partial_cmp(cur) {
                    Some(Ordering::Greater) | None => {
                        // new is Greater (None cannot occur for these types)
                        *cur = new
                    }
                    _ => {}
                }
            })
            // Initialize each accumulator to $NATIVE::MIN
            .with_starting_value($NATIVE::MIN),
        ))
    }};
}
144
/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` over
/// the specified [`ArrowPrimitiveType`].
///
/// Arguments: the values' `DataType` (a reference to it is passed to the
/// accumulator), the native Rust type, and the arrow primitive type.
///
/// Floating point types (`f16`, `f32`, `f64`) are compared with the IEEE 754
/// total order (`total_cmp`, the same order `typed_min_max_float` uses for
/// scalar values) and start from the largest value in that order. All other
/// types use their natural order and start from `$NATIVE::MAX`.
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! primitive_min_accumulator {
    // Shared implementation for the floating point types. `$BITS` is the
    // unsigned integer type with the same width as `$NATIVE`.
    (@float $DATA_TYPE:ident, $NATIVE:ident, $BITS:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
                // `partial_cmp` yields `None` whenever a NaN is involved, which
                // made the result depend on the order values arrived in. The
                // total order is deterministic and ranks positive NaN above
                // every other value.
                if new.total_cmp(cur) == Ordering::Less {
                    *cur = new
                }
            })
            // Clearing only the sign bit gives a positive NaN with every
            // payload bit set, the largest value in the total order, so every
            // input (including `+inf`) replaces it. `$NATIVE::MAX` is finite
            // for floats and would be reported for groups whose only values
            // are `+inf`.
            .with_starting_value($NATIVE::from_bits($BITS::MAX >> 1)),
        ))
    }};
    ($DATA_TYPE:ident, f16, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@float $DATA_TYPE, f16, u16, $PRIMTYPE)
    };
    ($DATA_TYPE:ident, f32, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@float $DATA_TYPE, f32, u32, $PRIMTYPE)
    };
    ($DATA_TYPE:ident, f64, $PRIMTYPE:ident) => {
        primitive_min_accumulator!(@float $DATA_TYPE, f64, u64, $PRIMTYPE)
    };
    // Integer-like types (integers, decimals, dates, times, timestamps,
    // durations), which are totally ordered by `PartialOrd`.
    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
        Ok(Box::new(
            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
                match (new).partial_cmp(cur) {
                    Some(Ordering::Less) | None => {
                        // new is Less (None cannot occur for these types)
                        *cur = new
                    }
                    _ => {}
                }
            })
            // Initialize each accumulator to $NATIVE::MAX
            .with_starting_value($NATIVE::MAX),
        ))
    }};
}
167
168trait FromColumnStatistics {
169    fn value_from_column_statistics(
170        &self,
171        stats: &ColumnStatistics,
172    ) -> Option<ScalarValue>;
173
174    fn value_from_statistics(
175        &self,
176        statistics_args: &StatisticsArgs,
177    ) -> Option<ScalarValue> {
178        if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
179            match *num_rows {
180                0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
181                value if value > 0 => {
182                    let col_stats = &statistics_args.statistics.column_statistics;
183                    if statistics_args.exprs.len() == 1 {
184                        // TODO optimize with exprs other than Column
185                        if let Some(col_expr) = statistics_args.exprs[0]
186                            .as_any()
187                            .downcast_ref::<expressions::Column>()
188                        {
189                            return self.value_from_column_statistics(
190                                &col_stats[col_expr.index()],
191                            );
192                        }
193                    }
194                }
195                _ => {}
196            }
197        }
198        None
199    }
200}
201
202impl FromColumnStatistics for Max {
203    fn value_from_column_statistics(
204        &self,
205        col_stats: &ColumnStatistics,
206    ) -> Option<ScalarValue> {
207        if let Precision::Exact(ref val) = col_stats.max_value {
208            if !val.is_null() {
209                return Some(val.clone());
210            }
211        }
212        None
213    }
214}
215
impl AggregateUDFImpl for Max {
    /// Returns `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// The function name, `"max"`.
    fn name(&self) -> &str {
        "max"
    }

    /// Returns the user-defined signature created in `Max::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// `max` returns the same type as its (single) argument.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // NOTE(review): indexing panics on an empty slice; presumably the types
        // have already passed `coerce_types`, which requires exactly one
        // argument — confirm.
        Ok(arg_types[0].to_owned())
    }

    /// Creates a `MaxAccumulator` for the aggregate's return type.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?))
    }

    /// `max` has no aliases.
    fn aliases(&self) -> &[String] {
        &[]
    }

    /// Returns `true` for every return type `create_groups_accumulator`
    /// handles; the two lists must be kept in sync.
    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        use DataType::*;
        matches!(
            args.return_type,
            Int8 | Int16
                | Int32
                | Int64
                | UInt8
                | UInt16
                | UInt32
                | UInt64
                | Float16
                | Float32
                | Float64
                | Decimal128(_, _)
                | Decimal256(_, _)
                | Date32
                | Date64
                | Time32(_)
                | Time64(_)
                | Timestamp(_, _)
                | Utf8
                | LargeUtf8
                | Utf8View
                | Binary
                | LargeBinary
                | BinaryView
                | Duration(_)
        )
    }

    /// Creates a type-specialized `GroupsAccumulator` for the return type.
    ///
    /// Numeric, decimal, date, time, timestamp and duration types get a
    /// `PrimitiveGroupsAccumulator` built by `primitive_max_accumulator!`;
    /// string and binary types get `MinMaxBytesAccumulator::new_max`. Any
    /// other type is an internal error.
    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        use DataType::*;
        use TimeUnit::*;
        let data_type = args.return_type;
        match data_type {
            Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
            Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
            Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
            Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
            UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
            UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
            UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
            UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
            Float16 => {
                primitive_max_accumulator!(data_type, f16, Float16Type)
            }
            Float32 => {
                primitive_max_accumulator!(data_type, f32, Float32Type)
            }
            Float64 => {
                primitive_max_accumulator!(data_type, f64, Float64Type)
            }
            Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
            Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
            Time32(Second) => {
                primitive_max_accumulator!(data_type, i32, Time32SecondType)
            }
            Time32(Millisecond) => {
                primitive_max_accumulator!(data_type, i32, Time32MillisecondType)
            }
            Time64(Microsecond) => {
                primitive_max_accumulator!(data_type, i64, Time64MicrosecondType)
            }
            Time64(Nanosecond) => {
                primitive_max_accumulator!(data_type, i64, Time64NanosecondType)
            }
            Timestamp(Second, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampSecondType)
            }
            Timestamp(Millisecond, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampMillisecondType)
            }
            Timestamp(Microsecond, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType)
            }
            Timestamp(Nanosecond, _) => {
                primitive_max_accumulator!(data_type, i64, TimestampNanosecondType)
            }
            Duration(Second) => {
                primitive_max_accumulator!(data_type, i64, DurationSecondType)
            }
            Duration(Millisecond) => {
                primitive_max_accumulator!(data_type, i64, DurationMillisecondType)
            }
            Duration(Microsecond) => {
                primitive_max_accumulator!(data_type, i64, DurationMicrosecondType)
            }
            Duration(Nanosecond) => {
                primitive_max_accumulator!(data_type, i64, DurationNanosecondType)
            }
            Decimal128(_, _) => {
                primitive_max_accumulator!(data_type, i128, Decimal128Type)
            }
            Decimal256(_, _) => {
                primitive_max_accumulator!(data_type, i256, Decimal256Type)
            }
            // Variable-width types share one bytes-based accumulator.
            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
                Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
            }

            // This is only reached if groups_accumulator_supported is out of sync
            _ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
        }
    }

    /// Creates a `SlidingMaxAccumulator` for the aggregate's return type.
    fn create_sliding_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?))
    }

    /// Returns `Some(true)` for `max`.
    fn is_descending(&self) -> Option<bool> {
        // NOTE(review): presumably marks `max` as the descending counterpart
        // of `min` for order-based rewrites — confirm against the
        // `AggregateUDFImpl::is_descending` docs.
        Some(true)
    }

    /// The result of `max` does not depend on the order of its input.
    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
    }

    /// Requires exactly one argument and unwraps dictionary types to their
    /// value type (see `get_min_max_result_type`).
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        get_min_max_result_type(arg_types)
    }
    /// Reversing the input order does not change `max`, so the reversed
    /// aggregate is identical.
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
        datafusion_expr::ReversedUDAF::Identical
    }
    /// Tries to answer `max` from exact statistics via the
    /// `FromColumnStatistics` helper trait.
    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
        self.value_from_statistics(statistics_args)
    }

    /// Returns `self.doc()`, presumably generated by the `#[user_doc]`
    /// attribute on `Max`.
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }

    /// `max` never decreases as more values are added to the set.
    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        // `MAX` is monotonically increasing as it always increases or stays
        // the same as new values are seen.
        SetMonotonicity::Increasing
    }
}
385
// Statically-typed version of min/max(array) -> ScalarValue for string types.
// Downcasts `$VALUES` to `$ARRAYTYPE`, runs the arrow `compute::$OP` kernel and
// converts the borrowed result into an owned `String` for `ScalarValue::$SCALAR`.
macro_rules! typed_min_max_batch_string {
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
        let array = downcast_value!($VALUES, $ARRAYTYPE);
        let value = compute::$OP(array);
        // `map`, not `and_then(|e| Some(..))` (clippy::bind_instead_of_map).
        let value = value.map(|e| e.to_string());
        ScalarValue::$SCALAR(value)
    }};
}
// Statically-typed version of min/max(array) -> ScalarValue for binary types.
// Downcasts `$VALUES` to `$ARRAYTYPE`, runs the arrow `compute::$OP` kernel and
// copies the borrowed bytes into an owned `Vec<u8>` for `ScalarValue::$SCALAR`.
macro_rules! typed_min_max_batch_binary {
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
        let array = downcast_value!($VALUES, $ARRAYTYPE);
        let value = compute::$OP(array);
        // `map`, not `and_then(|e| Some(..))` (clippy::bind_instead_of_map).
        let value = value.map(|e| e.to_vec());
        ScalarValue::$SCALAR(value)
    }};
}
404
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
// Any trailing `$EXTRA_ARGS` (e.g. decimal precision/scale or a timezone) are
// cloned and passed to `ScalarValue::$SCALAR` after the computed value.
macro_rules! typed_min_max_batch {
    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
        let typed = downcast_value!($VALUES, $ARRAYTYPE);
        ScalarValue::$SCALAR(compute::$OP(typed), $($EXTRA_ARGS.clone()),*)
    }};
}
413
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
// This is a macro so that one body serves both operations: `$OP` names the
// arrow `compute` kernel to call (`min` or `max`).
//
// A `Null` array yields `ScalarValue::Null`; decimal and timestamp arrays carry
// their precision/scale or timezone into the result; any data type without an
// arm makes the enclosing function return an internal error.
macro_rules! min_max_batch {
    ($VALUES:expr, $OP:ident) => {{
        match $VALUES.data_type() {
            DataType::Null => ScalarValue::Null,
            DataType::Decimal128(precision, scale) => {
                typed_min_max_batch!(
                    $VALUES,
                    Decimal128Array,
                    Decimal128,
                    $OP,
                    precision,
                    scale
                )
            }
            DataType::Decimal256(precision, scale) => {
                typed_min_max_batch!(
                    $VALUES,
                    Decimal256Array,
                    Decimal256,
                    $OP,
                    precision,
                    scale
                )
            }
            // remaining primitive and temporal types: one arm per arrow array type
            DataType::Float64 => {
                typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
            }
            DataType::Float32 => {
                typed_min_max_batch!($VALUES, Float32Array, Float32, $OP)
            }
            DataType::Float16 => {
                typed_min_max_batch!($VALUES, Float16Array, Float16, $OP)
            }
            DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP),
            DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP),
            DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP),
            DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP),
            DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP),
            DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP),
            DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP),
            DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP),
            // timestamps keep their (optional) timezone in the result
            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
                typed_min_max_batch!(
                    $VALUES,
                    TimestampSecondArray,
                    TimestampSecond,
                    $OP,
                    tz_opt
                )
            }
            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!(
                $VALUES,
                TimestampMillisecondArray,
                TimestampMillisecond,
                $OP,
                tz_opt
            ),
            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!(
                $VALUES,
                TimestampMicrosecondArray,
                TimestampMicrosecond,
                $OP,
                tz_opt
            ),
            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!(
                $VALUES,
                TimestampNanosecondArray,
                TimestampNanosecond,
                $OP,
                tz_opt
            ),
            DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP),
            DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP),
            DataType::Time32(TimeUnit::Second) => {
                typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP)
            }
            DataType::Time32(TimeUnit::Millisecond) => {
                typed_min_max_batch!(
                    $VALUES,
                    Time32MillisecondArray,
                    Time32Millisecond,
                    $OP
                )
            }
            DataType::Time64(TimeUnit::Microsecond) => {
                typed_min_max_batch!(
                    $VALUES,
                    Time64MicrosecondArray,
                    Time64Microsecond,
                    $OP
                )
            }
            DataType::Time64(TimeUnit::Nanosecond) => {
                typed_min_max_batch!(
                    $VALUES,
                    Time64NanosecondArray,
                    Time64Nanosecond,
                    $OP
                )
            }
            DataType::Interval(IntervalUnit::YearMonth) => {
                typed_min_max_batch!(
                    $VALUES,
                    IntervalYearMonthArray,
                    IntervalYearMonth,
                    $OP
                )
            }
            DataType::Interval(IntervalUnit::DayTime) => {
                typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP)
            }
            DataType::Interval(IntervalUnit::MonthDayNano) => {
                typed_min_max_batch!(
                    $VALUES,
                    IntervalMonthDayNanoArray,
                    IntervalMonthDayNano,
                    $OP
                )
            }
            DataType::Duration(TimeUnit::Second) => {
                typed_min_max_batch!($VALUES, DurationSecondArray, DurationSecond, $OP)
            }
            DataType::Duration(TimeUnit::Millisecond) => {
                typed_min_max_batch!(
                    $VALUES,
                    DurationMillisecondArray,
                    DurationMillisecond,
                    $OP
                )
            }
            DataType::Duration(TimeUnit::Microsecond) => {
                typed_min_max_batch!(
                    $VALUES,
                    DurationMicrosecondArray,
                    DurationMicrosecond,
                    $OP
                )
            }
            DataType::Duration(TimeUnit::Nanosecond) => {
                typed_min_max_batch!(
                    $VALUES,
                    DurationNanosecondArray,
                    DurationNanosecond,
                    $OP
                )
            }
            other => {
                // This should have been handled before (strings, binaries and
                // booleans are dispatched by `min_batch` / `max_batch` before
                // falling back to this macro)
                return internal_err!(
                    "Min/Max accumulator not implemented for type {:?}",
                    other
                );
            }
        }
    }};
}
573
574/// dynamically-typed min(array) -> ScalarValue
575fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
576    Ok(match values.data_type() {
577        DataType::Utf8 => {
578            typed_min_max_batch_string!(values, StringArray, Utf8, min_string)
579        }
580        DataType::LargeUtf8 => {
581            typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string)
582        }
583        DataType::Utf8View => {
584            typed_min_max_batch_string!(
585                values,
586                StringViewArray,
587                Utf8View,
588                min_string_view
589            )
590        }
591        DataType::Boolean => {
592            typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean)
593        }
594        DataType::Binary => {
595            typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary)
596        }
597        DataType::LargeBinary => {
598            typed_min_max_batch_binary!(
599                &values,
600                LargeBinaryArray,
601                LargeBinary,
602                min_binary
603            )
604        }
605        DataType::BinaryView => {
606            typed_min_max_batch_binary!(
607                &values,
608                BinaryViewArray,
609                BinaryView,
610                min_binary_view
611            )
612        }
613        _ => min_max_batch!(values, min),
614    })
615}
616
617/// dynamically-typed max(array) -> ScalarValue
618pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
619    Ok(match values.data_type() {
620        DataType::Utf8 => {
621            typed_min_max_batch_string!(values, StringArray, Utf8, max_string)
622        }
623        DataType::LargeUtf8 => {
624            typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string)
625        }
626        DataType::Utf8View => {
627            typed_min_max_batch_string!(
628                values,
629                StringViewArray,
630                Utf8View,
631                max_string_view
632            )
633        }
634        DataType::Boolean => {
635            typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean)
636        }
637        DataType::Binary => {
638            typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary)
639        }
640        DataType::BinaryView => {
641            typed_min_max_batch_binary!(
642                &values,
643                BinaryViewArray,
644                BinaryView,
645                max_binary_view
646            )
647        }
648        DataType::LargeBinary => {
649            typed_min_max_batch_binary!(
650                &values,
651                LargeBinaryArray,
652                LargeBinary,
653                max_binary
654            )
655        }
656        _ => min_max_batch!(values, max),
657    })
658}
659
// min/max of two non-string scalar values with a `Copy` payload. `$OP` is the
// method (`min` or `max`) applied when both sides are present; a missing side
// is ignored, and two missing sides yield None. Any trailing `$EXTRA_ARGS`
// (e.g. decimal precision/scale or a timezone) are cloned into the result.
macro_rules! typed_min_max {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
        let combined = match ($VALUE, $DELTA) {
            (Some(lhs), Some(rhs)) => Some((*lhs).$OP(*rhs)),
            (Some(lhs), None) => Some(*lhs),
            (None, Some(rhs)) => Some(*rhs),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(combined, $($EXTRA_ARGS.clone()),*)
    }};
}
// min/max of two float scalar values, compared with `total_cmp` (the IEEE 754
// total order) instead of `PartialOrd`, so NaN never makes the comparison
// fail. The right-hand value wins only when the ordering equals
// `choose_min_max!($OP)`; on a tie the left-hand value is kept.
macro_rules! typed_min_max_float {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
        let chosen = match ($VALUE, $DELTA) {
            (Some(lhs), Some(rhs)) => {
                if lhs.total_cmp(rhs) == choose_min_max!($OP) {
                    Some(*rhs)
                } else {
                    Some(*lhs)
                }
            }
            (Some(lhs), None) => Some(*lhs),
            (None, Some(rhs)) => Some(*rhs),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(chosen)
    }};
}
687
// min/max of two scalar string or binary values with owned payloads, compared
// with `Ord`. A missing side is ignored; the chosen value is cloned into the
// resulting `ScalarValue::$SCALAR`.
macro_rules! typed_min_max_string {
    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
        let picked = match ($VALUE, $DELTA) {
            (Some(lhs), Some(rhs)) => Some(lhs.$OP(rhs)),
            (Some(lhs), None) => Some(lhs),
            (None, Some(rhs)) => Some(rhs),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(picked.cloned())
    }};
}
699
// Maps an operation name to the `Ordering` (of lhs compared with rhs) at which
// the right-hand operand should be chosen: for `max` that is `Less` (rhs is
// bigger), for `min` it is `Greater` (rhs is smaller).
macro_rules! choose_min_max {
    (max) => {
        std::cmp::Ordering::Less
    };
    (min) => {
        std::cmp::Ordering::Greater
    };
}
708
// min/max of two interval values via `PartialOrd`. When the values cannot be
// compared (`partial_cmp` returns None) the enclosing function returns an
// internal error. On a tie the left-hand value is kept.
macro_rules! interval_min_max {
    ($OP:tt, $LHS:expr, $RHS:expr) => {{
        let Some(ordering) = $LHS.partial_cmp(&$RHS) else {
            return internal_err!("Comparison error while computing interval min/max");
        };
        if ordering == choose_min_max!($OP) {
            $RHS.clone()
        } else {
            $LHS.clone()
        }
    }};
}
720
721// min/max of two scalar values of the same type
722macro_rules! min_max {
723    ($VALUE:expr, $DELTA:expr, $OP:ident) => {{
724        Ok(match ($VALUE, $DELTA) {
725            (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
726            (
727                lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
728                rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
729            ) => {
730                if lhsp.eq(rhsp) && lhss.eq(rhss) {
731                    typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss)
732                } else {
733                    return internal_err!(
734                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
735                    (lhs, rhs)
736                );
737                }
738            }
739            (
740                lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss),
741                rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss)
742            ) => {
743                if lhsp.eq(rhsp) && lhss.eq(rhss) {
744                    typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss)
745                } else {
746                    return internal_err!(
747                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
748                    (lhs, rhs)
749                );
750                }
751            }
752            (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => {
753                typed_min_max!(lhs, rhs, Boolean, $OP)
754            }
755            (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
756                typed_min_max_float!(lhs, rhs, Float64, $OP)
757            }
758            (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
759                typed_min_max_float!(lhs, rhs, Float32, $OP)
760            }
761            (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => {
762                typed_min_max_float!(lhs, rhs, Float16, $OP)
763            }
764            (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
765                typed_min_max!(lhs, rhs, UInt64, $OP)
766            }
767            (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
768                typed_min_max!(lhs, rhs, UInt32, $OP)
769            }
770            (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
771                typed_min_max!(lhs, rhs, UInt16, $OP)
772            }
773            (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
774                typed_min_max!(lhs, rhs, UInt8, $OP)
775            }
776            (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
777                typed_min_max!(lhs, rhs, Int64, $OP)
778            }
779            (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
780                typed_min_max!(lhs, rhs, Int32, $OP)
781            }
782            (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
783                typed_min_max!(lhs, rhs, Int16, $OP)
784            }
785            (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
786                typed_min_max!(lhs, rhs, Int8, $OP)
787            }
788            (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => {
789                typed_min_max_string!(lhs, rhs, Utf8, $OP)
790            }
791            (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => {
792                typed_min_max_string!(lhs, rhs, LargeUtf8, $OP)
793            }
794            (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => {
795                typed_min_max_string!(lhs, rhs, Utf8View, $OP)
796            }
797            (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => {
798                typed_min_max_string!(lhs, rhs, Binary, $OP)
799            }
800            (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => {
801                typed_min_max_string!(lhs, rhs, LargeBinary, $OP)
802            }
803            (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => {
804                typed_min_max_string!(lhs, rhs, BinaryView, $OP)
805            }
806            (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => {
807                typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz)
808            }
809            (
810                ScalarValue::TimestampMillisecond(lhs, l_tz),
811                ScalarValue::TimestampMillisecond(rhs, _),
812            ) => {
813                typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz)
814            }
815            (
816                ScalarValue::TimestampMicrosecond(lhs, l_tz),
817                ScalarValue::TimestampMicrosecond(rhs, _),
818            ) => {
819                typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz)
820            }
821            (
822                ScalarValue::TimestampNanosecond(lhs, l_tz),
823                ScalarValue::TimestampNanosecond(rhs, _),
824            ) => {
825                typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz)
826            }
827            (
828                ScalarValue::Date32(lhs),
829                ScalarValue::Date32(rhs),
830            ) => {
831                typed_min_max!(lhs, rhs, Date32, $OP)
832            }
833            (
834                ScalarValue::Date64(lhs),
835                ScalarValue::Date64(rhs),
836            ) => {
837                typed_min_max!(lhs, rhs, Date64, $OP)
838            }
839            (
840                ScalarValue::Time32Second(lhs),
841                ScalarValue::Time32Second(rhs),
842            ) => {
843                typed_min_max!(lhs, rhs, Time32Second, $OP)
844            }
845            (
846                ScalarValue::Time32Millisecond(lhs),
847                ScalarValue::Time32Millisecond(rhs),
848            ) => {
849                typed_min_max!(lhs, rhs, Time32Millisecond, $OP)
850            }
851            (
852                ScalarValue::Time64Microsecond(lhs),
853                ScalarValue::Time64Microsecond(rhs),
854            ) => {
855                typed_min_max!(lhs, rhs, Time64Microsecond, $OP)
856            }
857            (
858                ScalarValue::Time64Nanosecond(lhs),
859                ScalarValue::Time64Nanosecond(rhs),
860            ) => {
861                typed_min_max!(lhs, rhs, Time64Nanosecond, $OP)
862            }
863            (
864                ScalarValue::IntervalYearMonth(lhs),
865                ScalarValue::IntervalYearMonth(rhs),
866            ) => {
867                typed_min_max!(lhs, rhs, IntervalYearMonth, $OP)
868            }
869            (
870                ScalarValue::IntervalMonthDayNano(lhs),
871                ScalarValue::IntervalMonthDayNano(rhs),
872            ) => {
873                typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP)
874            }
875            (
876                ScalarValue::IntervalDayTime(lhs),
877                ScalarValue::IntervalDayTime(rhs),
878            ) => {
879                typed_min_max!(lhs, rhs, IntervalDayTime, $OP)
880            }
881            (
882                ScalarValue::IntervalYearMonth(_),
883                ScalarValue::IntervalMonthDayNano(_),
884            ) | (
885                ScalarValue::IntervalYearMonth(_),
886                ScalarValue::IntervalDayTime(_),
887            ) | (
888                ScalarValue::IntervalMonthDayNano(_),
889                ScalarValue::IntervalDayTime(_),
890            ) | (
891                ScalarValue::IntervalMonthDayNano(_),
892                ScalarValue::IntervalYearMonth(_),
893            ) | (
894                ScalarValue::IntervalDayTime(_),
895                ScalarValue::IntervalYearMonth(_),
896            ) | (
897                ScalarValue::IntervalDayTime(_),
898                ScalarValue::IntervalMonthDayNano(_),
899            ) => {
900                interval_min_max!($OP, $VALUE, $DELTA)
901            }
902                    (
903                ScalarValue::DurationSecond(lhs),
904                ScalarValue::DurationSecond(rhs),
905            ) => {
906                typed_min_max!(lhs, rhs, DurationSecond, $OP)
907            }
908                                (
909                ScalarValue::DurationMillisecond(lhs),
910                ScalarValue::DurationMillisecond(rhs),
911            ) => {
912                typed_min_max!(lhs, rhs, DurationMillisecond, $OP)
913            }
914                                (
915                ScalarValue::DurationMicrosecond(lhs),
916                ScalarValue::DurationMicrosecond(rhs),
917            ) => {
918                typed_min_max!(lhs, rhs, DurationMicrosecond, $OP)
919            }
920                                        (
921                ScalarValue::DurationNanosecond(lhs),
922                ScalarValue::DurationNanosecond(rhs),
923            ) => {
924                typed_min_max!(lhs, rhs, DurationNanosecond, $OP)
925            }
926            e => {
927                return internal_err!(
928                    "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
929                    e
930                )
931            }
932        })
933    }};
934}
935
936/// An accumulator to compute the maximum value
937#[derive(Debug)]
938pub struct MaxAccumulator {
939    max: ScalarValue,
940}
941
942impl MaxAccumulator {
943    /// new max accumulator
944    pub fn try_new(datatype: &DataType) -> Result<Self> {
945        Ok(Self {
946            max: ScalarValue::try_from(datatype)?,
947        })
948    }
949}
950
951impl Accumulator for MaxAccumulator {
952    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
953        let values = &values[0];
954        let delta = &max_batch(values)?;
955        let new_max: Result<ScalarValue, DataFusionError> =
956            min_max!(&self.max, delta, max);
957        self.max = new_max?;
958        Ok(())
959    }
960
961    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
962        self.update_batch(states)
963    }
964
965    fn state(&mut self) -> Result<Vec<ScalarValue>> {
966        Ok(vec![self.evaluate()?])
967    }
968    fn evaluate(&mut self) -> Result<ScalarValue> {
969        Ok(self.max.clone())
970    }
971
972    fn size(&self) -> usize {
973        size_of_val(self) - size_of_val(&self.max) + self.max.size()
974    }
975}
976
977#[derive(Debug)]
978pub struct SlidingMaxAccumulator {
979    max: ScalarValue,
980    moving_max: MovingMax<ScalarValue>,
981}
982
983impl SlidingMaxAccumulator {
984    /// new max accumulator
985    pub fn try_new(datatype: &DataType) -> Result<Self> {
986        Ok(Self {
987            max: ScalarValue::try_from(datatype)?,
988            moving_max: MovingMax::<ScalarValue>::new(),
989        })
990    }
991}
992
993impl Accumulator for SlidingMaxAccumulator {
994    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
995        for idx in 0..values[0].len() {
996            let val = ScalarValue::try_from_array(&values[0], idx)?;
997            self.moving_max.push(val);
998        }
999        if let Some(res) = self.moving_max.max() {
1000            self.max = res.clone();
1001        }
1002        Ok(())
1003    }
1004
1005    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1006        for _idx in 0..values[0].len() {
1007            (self.moving_max).pop();
1008        }
1009        if let Some(res) = self.moving_max.max() {
1010            self.max = res.clone();
1011        }
1012        Ok(())
1013    }
1014
1015    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1016        self.update_batch(states)
1017    }
1018
1019    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1020        Ok(vec![self.max.clone()])
1021    }
1022
1023    fn evaluate(&mut self) -> Result<ScalarValue> {
1024        Ok(self.max.clone())
1025    }
1026
1027    fn supports_retract_batch(&self) -> bool {
1028        true
1029    }
1030
1031    fn size(&self) -> usize {
1032        size_of_val(self) - size_of_val(&self.max) + self.max.size()
1033    }
1034}
1035
/// Aggregate UDF implementing the SQL `min` function.
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the minimum value in the specified column.",
    syntax_example = "min(expression)",
    sql_example = r#"```sql
> SELECT min(column_name) FROM table_name;
+----------------------+
| min(column_name)      |
+----------------------+
| 12                   |
+----------------------+
```"#,
    standard_argument(name = "expression",)
)]
#[derive(Debug)]
pub struct Min {
    /// User-defined signature; concrete argument types are resolved by
    /// `coerce_types`.
    signature: Signature,
}
1054
1055impl Min {
1056    pub fn new() -> Self {
1057        Self {
1058            signature: Signature::user_defined(Volatility::Immutable),
1059        }
1060    }
1061}
1062
1063impl Default for Min {
1064    fn default() -> Self {
1065        Self::new()
1066    }
1067}
1068
1069impl FromColumnStatistics for Min {
1070    fn value_from_column_statistics(
1071        &self,
1072        col_stats: &ColumnStatistics,
1073    ) -> Option<ScalarValue> {
1074        if let Precision::Exact(ref val) = col_stats.min_value {
1075            if !val.is_null() {
1076                return Some(val.clone());
1077            }
1078        }
1079        None
1080    }
1081}
1082
1083impl AggregateUDFImpl for Min {
1084    fn as_any(&self) -> &dyn std::any::Any {
1085        self
1086    }
1087
1088    fn name(&self) -> &str {
1089        "min"
1090    }
1091
1092    fn signature(&self) -> &Signature {
1093        &self.signature
1094    }
1095
1096    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1097        Ok(arg_types[0].to_owned())
1098    }
1099
1100    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1101        Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?))
1102    }
1103
1104    fn aliases(&self) -> &[String] {
1105        &[]
1106    }
1107
1108    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1109        use DataType::*;
1110        matches!(
1111            args.return_type,
1112            Int8 | Int16
1113                | Int32
1114                | Int64
1115                | UInt8
1116                | UInt16
1117                | UInt32
1118                | UInt64
1119                | Float16
1120                | Float32
1121                | Float64
1122                | Decimal128(_, _)
1123                | Decimal256(_, _)
1124                | Date32
1125                | Date64
1126                | Time32(_)
1127                | Time64(_)
1128                | Timestamp(_, _)
1129                | Utf8
1130                | LargeUtf8
1131                | Utf8View
1132                | Binary
1133                | LargeBinary
1134                | BinaryView
1135                | Duration(_)
1136        )
1137    }
1138
1139    fn create_groups_accumulator(
1140        &self,
1141        args: AccumulatorArgs,
1142    ) -> Result<Box<dyn GroupsAccumulator>> {
1143        use DataType::*;
1144        use TimeUnit::*;
1145        let data_type = args.return_type;
1146        match data_type {
1147            Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
1148            Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
1149            Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
1150            Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
1151            UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
1152            UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
1153            UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
1154            UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
1155            Float16 => {
1156                primitive_min_accumulator!(data_type, f16, Float16Type)
1157            }
1158            Float32 => {
1159                primitive_min_accumulator!(data_type, f32, Float32Type)
1160            }
1161            Float64 => {
1162                primitive_min_accumulator!(data_type, f64, Float64Type)
1163            }
1164            Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
1165            Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
1166            Time32(Second) => {
1167                primitive_min_accumulator!(data_type, i32, Time32SecondType)
1168            }
1169            Time32(Millisecond) => {
1170                primitive_min_accumulator!(data_type, i32, Time32MillisecondType)
1171            }
1172            Time64(Microsecond) => {
1173                primitive_min_accumulator!(data_type, i64, Time64MicrosecondType)
1174            }
1175            Time64(Nanosecond) => {
1176                primitive_min_accumulator!(data_type, i64, Time64NanosecondType)
1177            }
1178            Timestamp(Second, _) => {
1179                primitive_min_accumulator!(data_type, i64, TimestampSecondType)
1180            }
1181            Timestamp(Millisecond, _) => {
1182                primitive_min_accumulator!(data_type, i64, TimestampMillisecondType)
1183            }
1184            Timestamp(Microsecond, _) => {
1185                primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType)
1186            }
1187            Timestamp(Nanosecond, _) => {
1188                primitive_min_accumulator!(data_type, i64, TimestampNanosecondType)
1189            }
1190            Duration(Second) => {
1191                primitive_min_accumulator!(data_type, i64, DurationSecondType)
1192            }
1193            Duration(Millisecond) => {
1194                primitive_min_accumulator!(data_type, i64, DurationMillisecondType)
1195            }
1196            Duration(Microsecond) => {
1197                primitive_min_accumulator!(data_type, i64, DurationMicrosecondType)
1198            }
1199            Duration(Nanosecond) => {
1200                primitive_min_accumulator!(data_type, i64, DurationNanosecondType)
1201            }
1202            Decimal128(_, _) => {
1203                primitive_min_accumulator!(data_type, i128, Decimal128Type)
1204            }
1205            Decimal256(_, _) => {
1206                primitive_min_accumulator!(data_type, i256, Decimal256Type)
1207            }
1208            Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
1209                Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
1210            }
1211
1212            // This is only reached if groups_accumulator_supported is out of sync
1213            _ => internal_err!("GroupsAccumulator not supported for min({})", data_type),
1214        }
1215    }
1216
1217    fn create_sliding_accumulator(
1218        &self,
1219        args: AccumulatorArgs,
1220    ) -> Result<Box<dyn Accumulator>> {
1221        Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?))
1222    }
1223
1224    fn is_descending(&self) -> Option<bool> {
1225        Some(false)
1226    }
1227
1228    fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
1229        self.value_from_statistics(statistics_args)
1230    }
1231    fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
1232        datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
1233    }
1234
1235    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1236        get_min_max_result_type(arg_types)
1237    }
1238
1239    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
1240        datafusion_expr::ReversedUDAF::Identical
1241    }
1242
1243    fn documentation(&self) -> Option<&Documentation> {
1244        self.doc()
1245    }
1246
1247    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
1248        // `MIN` is monotonically decreasing as it always decreases or stays
1249        // the same as new values are seen.
1250        SetMonotonicity::Decreasing
1251    }
1252}
1253
1254/// An accumulator to compute the minimum value
1255#[derive(Debug)]
1256pub struct MinAccumulator {
1257    min: ScalarValue,
1258}
1259
1260impl MinAccumulator {
1261    /// new min accumulator
1262    pub fn try_new(datatype: &DataType) -> Result<Self> {
1263        Ok(Self {
1264            min: ScalarValue::try_from(datatype)?,
1265        })
1266    }
1267}
1268
1269impl Accumulator for MinAccumulator {
1270    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1271        Ok(vec![self.evaluate()?])
1272    }
1273
1274    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1275        let values = &values[0];
1276        let delta = &min_batch(values)?;
1277        let new_min: Result<ScalarValue, DataFusionError> =
1278            min_max!(&self.min, delta, min);
1279        self.min = new_min?;
1280        Ok(())
1281    }
1282
1283    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1284        self.update_batch(states)
1285    }
1286
1287    fn evaluate(&mut self) -> Result<ScalarValue> {
1288        Ok(self.min.clone())
1289    }
1290
1291    fn size(&self) -> usize {
1292        size_of_val(self) - size_of_val(&self.min) + self.min.size()
1293    }
1294}
1295
1296#[derive(Debug)]
1297pub struct SlidingMinAccumulator {
1298    min: ScalarValue,
1299    moving_min: MovingMin<ScalarValue>,
1300}
1301
1302impl SlidingMinAccumulator {
1303    pub fn try_new(datatype: &DataType) -> Result<Self> {
1304        Ok(Self {
1305            min: ScalarValue::try_from(datatype)?,
1306            moving_min: MovingMin::<ScalarValue>::new(),
1307        })
1308    }
1309}
1310
1311impl Accumulator for SlidingMinAccumulator {
1312    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1313        Ok(vec![self.min.clone()])
1314    }
1315
1316    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1317        for idx in 0..values[0].len() {
1318            let val = ScalarValue::try_from_array(&values[0], idx)?;
1319            if !val.is_null() {
1320                self.moving_min.push(val);
1321            }
1322        }
1323        if let Some(res) = self.moving_min.min() {
1324            self.min = res.clone();
1325        }
1326        Ok(())
1327    }
1328
1329    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1330        for idx in 0..values[0].len() {
1331            let val = ScalarValue::try_from_array(&values[0], idx)?;
1332            if !val.is_null() {
1333                (self.moving_min).pop();
1334            }
1335        }
1336        if let Some(res) = self.moving_min.min() {
1337            self.min = res.clone();
1338        }
1339        Ok(())
1340    }
1341
1342    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1343        self.update_batch(states)
1344    }
1345
1346    fn evaluate(&mut self) -> Result<ScalarValue> {
1347        Ok(self.min.clone())
1348    }
1349
1350    fn supports_retract_batch(&self) -> bool {
1351        true
1352    }
1353
1354    fn size(&self) -> usize {
1355        size_of_val(self) - size_of_val(&self.min) + self.min.size()
1356    }
1357}
1358
/// Keep track of the minimum value in a sliding window.
///
/// The implementation is taken from <https://github.com/spebern/moving_min_max/blob/master/src/lib.rs>
///
/// `moving min max` provides one data structure for keeping track of the
/// minimum value and one for keeping track of the maximum value in a sliding
/// window.
///
/// The window is a FIFO queue built from two stacks.
///
/// - Each entry stores an element together with the minimum of that element
///   and every entry beneath it on the same stack.
/// - New elements go onto the push stack.
/// - When the pop stack runs empty, the push stack is drained into it. This
///   reverses the order, so the oldest element ends up on top, and the
///   per-entry minimums are recomputed along the way.
/// - The window minimum is the smaller of the two stack-top minimums.
///
/// The complexity of the operations are
/// - O(1) for getting the minimum/maximum
/// - O(1) for push
/// - amortized O(1) for pop
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMin;
/// let mut moving_min = MovingMin::<i32>::new();
/// moving_min.push(2);
/// moving_min.push(1);
/// moving_min.push(3);
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(2));
///
/// assert_eq!(moving_min.min(), Some(&1));
/// assert_eq!(moving_min.pop(), Some(1));
///
/// assert_eq!(moving_min.min(), Some(&3));
/// assert_eq!(moving_min.pop(), Some(3));
///
/// assert_eq!(moving_min.min(), None);
/// assert_eq!(moving_min.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMin<T> {
    /// `(value, min of value and everything below it)`; newest on top.
    push_stack: Vec<(T, T)>,
    /// `(value, min of value and everything below it)`; oldest on top.
    pop_stack: Vec<(T, T)>,
}

impl<T> Default for MovingMin<T> {
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMin<T> {
    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMin` to keep track of the minimum in a sliding
    /// window with `capacity` allocated slots.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the minimum of the sliding window or `None` if the window is
    /// empty.
    #[inline]
    pub fn min(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, min)), None) => Some(min),
            (None, Some((_, min))) => Some(min),
            (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }),
        }
    }

    /// Pushes a new element into the sliding window.
    #[inline]
    pub fn push(&mut self, val: T) {
        self.push_stack.push(match self.push_stack.last() {
            Some((_, min)) => {
                if val > *min {
                    (val, min.clone())
                } else {
                    (val.clone(), val)
                }
            }
            None => (val.clone(), val),
        });
    }

    /// Removes and returns the oldest element of the sliding window (FIFO
    /// order), or `None` if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Drain newest-first so the oldest element lands on top. Each
            // entry's running minimum covers itself and every newer entry
            // already transferred, so it is derived from the current top with
            // a single clone (no intermediate tuple copies).
            while let Some((val, _)) = self.push_stack.pop() {
                let min = match self.pop_stack.last() {
                    Some((_, cur_min)) if *cur_min < val => cur_min.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, min));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1492
/// Keep track of the maximum value in a sliding window.
///
/// See [`MovingMin`] for more details; this is the same two-stack FIFO
/// queue with every comparison reversed.
///
/// ```
/// # use datafusion_functions_aggregate::min_max::MovingMax;
/// let mut moving_max = MovingMax::<i32>::new();
/// moving_max.push(2);
/// moving_max.push(3);
/// moving_max.push(1);
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(2));
///
/// assert_eq!(moving_max.max(), Some(&3));
/// assert_eq!(moving_max.pop(), Some(3));
///
/// assert_eq!(moving_max.max(), Some(&1));
/// assert_eq!(moving_max.pop(), Some(1));
///
/// assert_eq!(moving_max.max(), None);
/// assert_eq!(moving_max.pop(), None);
/// ```
#[derive(Debug)]
pub struct MovingMax<T> {
    /// `(value, max of value and everything below it)`; newest on top.
    push_stack: Vec<(T, T)>,
    /// `(value, max of value and everything below it)`; oldest on top.
    pop_stack: Vec<(T, T)>,
}

impl<T> Default for MovingMax<T> {
    fn default() -> Self {
        Self {
            push_stack: Vec::new(),
            pop_stack: Vec::new(),
        }
    }
}

impl<T: Clone + PartialOrd> MovingMax<T> {
    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with
    /// `capacity` allocated slots.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            push_stack: Vec::with_capacity(capacity),
            pop_stack: Vec::with_capacity(capacity),
        }
    }

    /// Returns the maximum of the sliding window or `None` if the window is empty.
    #[inline]
    pub fn max(&self) -> Option<&T> {
        match (self.push_stack.last(), self.pop_stack.last()) {
            (None, None) => None,
            (Some((_, max)), None) => Some(max),
            (None, Some((_, max))) => Some(max),
            (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }),
        }
    }

    /// Pushes a new element into the sliding window.
    #[inline]
    pub fn push(&mut self, val: T) {
        self.push_stack.push(match self.push_stack.last() {
            Some((_, max)) => {
                if val < *max {
                    (val, max.clone())
                } else {
                    (val.clone(), val)
                }
            }
            None => (val.clone(), val),
        });
    }

    /// Removes and returns the oldest element of the sliding window (FIFO
    /// order), or `None` if the window is empty.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.pop_stack.is_empty() {
            // Drain newest-first so the oldest element lands on top. Each
            // entry's running maximum is derived from the current top with a
            // single clone (no intermediate tuple copies).
            while let Some((val, _)) = self.push_stack.pop() {
                let max = match self.pop_stack.last() {
                    Some((_, cur_max)) if *cur_max > val => cur_max.clone(),
                    _ => val.clone(),
                };
                self.pop_stack.push((val, max));
            }
        }
        self.pop_stack.pop().map(|(val, _)| val)
    }

    /// Returns the number of elements stored in the sliding window.
    #[inline]
    pub fn len(&self) -> usize {
        self.push_stack.len() + self.pop_stack.len()
    }

    /// Returns `true` if the moving window contains no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
1610
// Exposes [`Max`] under the names `max` / `max_udaf` (presumably an
// expression-builder fn and a UDAF accessor; see `make_udaf_expr_and_func!`).
make_udaf_expr_and_func!(
    Max,
    max,
    expression,
    "Returns the maximum of a group of values.",
    max_udaf
);

// Exposes [`Min`] under the names `min` / `min_udaf` (presumably an
// expression-builder fn and a UDAF accessor; see `make_udaf_expr_and_func!`).
make_udaf_expr_and_func!(
    Min,
    min,
    expression,
    "Returns the minimum of a group of values.",
    min_udaf
);
1626
1627#[cfg(test)]
1628mod tests {
1629    use super::*;
1630    use arrow::datatypes::{
1631        IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
1632    };
1633    use std::sync::Arc;
1634
1635    #[test]
1636    fn interval_min_max() {
1637        // IntervalYearMonth
1638        let b = IntervalYearMonthArray::from(vec![
1639            IntervalYearMonthType::make_value(0, 1),
1640            IntervalYearMonthType::make_value(5, 34),
1641            IntervalYearMonthType::make_value(-2, 4),
1642            IntervalYearMonthType::make_value(7, -4),
1643            IntervalYearMonthType::make_value(0, 1),
1644        ]);
1645        let b: ArrayRef = Arc::new(b);
1646
1647        let mut min =
1648            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
1649                .unwrap();
1650        min.update_batch(&[Arc::clone(&b)]).unwrap();
1651        let min_res = min.evaluate().unwrap();
1652        assert_eq!(
1653            min_res,
1654            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
1655                -2, 4,
1656            )))
1657        );
1658
1659        let mut max =
1660            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth))
1661                .unwrap();
1662        max.update_batch(&[Arc::clone(&b)]).unwrap();
1663        let max_res = max.evaluate().unwrap();
1664        assert_eq!(
1665            max_res,
1666            ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value(
1667                5, 34,
1668            )))
1669        );
1670
1671        // IntervalDayTime
1672        let b = IntervalDayTimeArray::from(vec![
1673            IntervalDayTimeType::make_value(0, 0),
1674            IntervalDayTimeType::make_value(5, 454000),
1675            IntervalDayTimeType::make_value(-34, 0),
1676            IntervalDayTimeType::make_value(7, -4000),
1677            IntervalDayTimeType::make_value(1, 0),
1678        ]);
1679        let b: ArrayRef = Arc::new(b);
1680
1681        let mut min =
1682            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
1683        min.update_batch(&[Arc::clone(&b)]).unwrap();
1684        let min_res = min.evaluate().unwrap();
1685        assert_eq!(
1686            min_res,
1687            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0)))
1688        );
1689
1690        let mut max =
1691            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap();
1692        max.update_batch(&[Arc::clone(&b)]).unwrap();
1693        let max_res = max.evaluate().unwrap();
1694        assert_eq!(
1695            max_res,
1696            ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000)))
1697        );
1698
1699        // IntervalMonthDayNano
1700        let b = IntervalMonthDayNanoArray::from(vec![
1701            IntervalMonthDayNanoType::make_value(1, 0, 0),
1702            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
1703            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
1704            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
1705            IntervalMonthDayNanoType::make_value(1, 0, 0),
1706        ]);
1707        let b: ArrayRef = Arc::new(b);
1708
1709        let mut min =
1710            MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
1711                .unwrap();
1712        min.update_batch(&[Arc::clone(&b)]).unwrap();
1713        let min_res = min.evaluate().unwrap();
1714        assert_eq!(
1715            min_res,
1716            ScalarValue::IntervalMonthDayNano(Some(
1717                IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000)
1718            ))
1719        );
1720
1721        let mut max =
1722            MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano))
1723                .unwrap();
1724        max.update_batch(&[Arc::clone(&b)]).unwrap();
1725        let max_res = max.evaluate().unwrap();
1726        assert_eq!(
1727            max_res,
1728            ScalarValue::IntervalMonthDayNano(Some(
1729                IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000)
1730            ))
1731        );
1732    }
1733
1734    #[test]
1735    fn float_min_max_with_nans() {
1736        let pos_nan = f32::NAN;
1737        let zero = 0_f32;
1738        let neg_inf = f32::NEG_INFINITY;
1739
1740        let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
1741            for batch in values.iter() {
1742                let batch =
1743                    Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
1744                acc.update_batch(&[batch]).unwrap();
1745            }
1746            let result = acc.evaluate().unwrap();
1747            assert_eq!(result, ScalarValue::Float32(Some(expected)));
1748        };
1749
1750        // This test checks both comparison between batches (which uses the min_max macro
1751        // defined above) and within a batch (which uses the arrow min/max compute function
1752        // and verifies both respect the total order comparison for floats)
1753
1754        let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
1755        let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();
1756
1757        check(&mut min(), &[&[zero], &[pos_nan]], zero);
1758        check(&mut min(), &[&[zero, pos_nan]], zero);
1759        check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
1760        check(&mut min(), &[&[zero, neg_inf]], neg_inf);
1761        check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
1762        check(&mut max(), &[&[zero, pos_nan]], pos_nan);
1763        check(&mut max(), &[&[zero], &[neg_inf]], zero);
1764        check(&mut max(), &[&[zero, neg_inf]], zero);
1765    }
1766
1767    use datafusion_common::Result;
1768    use rand::Rng;
1769
1770    fn get_random_vec_i32(len: usize) -> Vec<i32> {
1771        let mut rng = rand::thread_rng();
1772        let mut input = Vec::with_capacity(len);
1773        for _i in 0..len {
1774            input.push(rng.gen_range(0..100));
1775        }
1776        input
1777    }
1778
1779    fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> {
1780        let data = get_random_vec_i32(len);
1781        let mut expected = Vec::with_capacity(len);
1782        let mut moving_min = MovingMin::<i32>::new();
1783        let mut res = Vec::with_capacity(len);
1784        for i in 0..len {
1785            let start = i.saturating_sub(n_sliding_window);
1786            expected.push(*data[start..i + 1].iter().min().unwrap());
1787
1788            moving_min.push(data[i]);
1789            if i > n_sliding_window {
1790                moving_min.pop();
1791            }
1792            res.push(*moving_min.min().unwrap());
1793        }
1794        assert_eq!(res, expected);
1795        Ok(())
1796    }
1797
1798    fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> {
1799        let data = get_random_vec_i32(len);
1800        let mut expected = Vec::with_capacity(len);
1801        let mut moving_max = MovingMax::<i32>::new();
1802        let mut res = Vec::with_capacity(len);
1803        for i in 0..len {
1804            let start = i.saturating_sub(n_sliding_window);
1805            expected.push(*data[start..i + 1].iter().max().unwrap());
1806
1807            moving_max.push(data[i]);
1808            if i > n_sliding_window {
1809                moving_max.pop();
1810            }
1811            res.push(*moving_max.max().unwrap());
1812        }
1813        assert_eq!(res, expected);
1814        Ok(())
1815    }
1816
1817    #[test]
1818    fn moving_min_tests() -> Result<()> {
1819        moving_min_i32(100, 10)?;
1820        moving_min_i32(100, 20)?;
1821        moving_min_i32(100, 50)?;
1822        moving_min_i32(100, 100)?;
1823        Ok(())
1824    }
1825
1826    #[test]
1827    fn moving_max_tests() -> Result<()> {
1828        moving_max_i32(100, 10)?;
1829        moving_max_i32(100, 20)?;
1830        moving_max_i32(100, 50)?;
1831        moving_max_i32(100, 100)?;
1832        Ok(())
1833    }
1834
1835    #[test]
1836    fn test_min_max_coerce_types() {
1837        // the coerced types is same with input types
1838        let funs: Vec<Box<dyn AggregateUDFImpl>> =
1839            vec![Box::new(Min::new()), Box::new(Max::new())];
1840        let input_types = vec![
1841            vec![DataType::Int32],
1842            vec![DataType::Decimal128(10, 2)],
1843            vec![DataType::Decimal256(1, 1)],
1844            vec![DataType::Utf8],
1845        ];
1846        for fun in funs {
1847            for input_type in &input_types {
1848                let result = fun.coerce_types(input_type);
1849                assert_eq!(*input_type, result.unwrap());
1850            }
1851        }
1852    }
1853
1854    #[test]
1855    fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
1856        let data_type =
1857            DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32));
1858        let result = get_min_max_result_type(&[data_type])?;
1859        assert_eq!(result, vec![DataType::Int32]);
1860        Ok(())
1861    }
1862}