datafusion_functions_aggregate/
first_last.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the FIRST_VALUE/LAST_VALUE aggregations.
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26    Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
27    PrimitiveArray,
28};
29use arrow::buffer::{BooleanBuffer, NullBuffer};
30use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
31use arrow::datatypes::{
32    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef,
33    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
34    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
35    TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
36    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
37    UInt8Type,
38};
39use datafusion_common::cast::as_boolean_array;
40use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
41use datafusion_common::{
42    arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
43};
44use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
45use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
46use datafusion_expr::{
47    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
48    GroupsAccumulator, Signature, SortExpr, Volatility,
49};
50use datafusion_functions_aggregate_common::utils::get_sort_options;
51use datafusion_macros::user_doc;
52use datafusion_physical_expr_common::sort_expr::LexOrdering;
53
// Declare the shared-UDAF accessors `first_value_udaf()` / `last_value_udaf()`
// that the expression helpers and `reverse_expr` below call. `create_func!` is a
// crate-level macro defined outside this file.
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
56
57/// Returns the first value in a group of values.
58pub fn first_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
59    if let Some(order_by) = order_by {
60        first_value_udaf()
61            .call(vec![expression])
62            .order_by(order_by)
63            .build()
64            // guaranteed to be `Expr::AggregateFunction`
65            .unwrap()
66    } else {
67        first_value_udaf().call(vec![expression])
68    }
69}
70
71/// Returns the last value in a group of values.
72pub fn last_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
73    if let Some(order_by) = order_by {
74        last_value_udaf()
75            .call(vec![expression])
76            .order_by(order_by)
77            .build()
78            // guaranteed to be `Expr::AggregateFunction`
79            .unwrap()
80    } else {
81        last_value_udaf().call(vec![expression])
82    }
83}
84
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element                                 |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct FirstValue {
    // Accepts exactly one argument of any type (built in `FirstValue::new`).
    signature: Signature,
    // Set through `with_beneficial_ordering`; when true, the row-based
    // accumulator is told that its input already satisfies the ORDER BY
    // requirement, so it can take the first row it sees.
    requirement_satisfied: bool,
}
103
104impl Debug for FirstValue {
105    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
106        f.debug_struct("FirstValue")
107            .field("name", &self.name())
108            .field("signature", &self.signature)
109            .field("accumulator", &"<FUNC>")
110            .finish()
111    }
112}
113
114impl Default for FirstValue {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120impl FirstValue {
121    pub fn new() -> Self {
122        Self {
123            signature: Signature::any(1, Volatility::Immutable),
124            requirement_satisfied: false,
125        }
126    }
127
128    fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
129        self.requirement_satisfied = requirement_satisfied;
130        self
131    }
132}
133
134impl AggregateUDFImpl for FirstValue {
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138
139    fn name(&self) -> &str {
140        "first_value"
141    }
142
143    fn signature(&self) -> &Signature {
144        &self.signature
145    }
146
147    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
148        Ok(arg_types[0].clone())
149    }
150
151    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
152        let ordering_dtypes = acc_args
153            .ordering_req
154            .iter()
155            .map(|e| e.expr.data_type(acc_args.schema))
156            .collect::<Result<Vec<_>>>()?;
157
158        // When requirement is empty, or it is signalled by outside caller that
159        // the ordering requirement is/will be satisfied.
160        let requirement_satisfied =
161            acc_args.ordering_req.is_empty() || self.requirement_satisfied;
162
163        FirstValueAccumulator::try_new(
164            acc_args.return_field.data_type(),
165            &ordering_dtypes,
166            acc_args.ordering_req.clone(),
167            acc_args.ignore_nulls,
168        )
169        .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
170    }
171
172    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
173        let mut fields = vec![Field::new(
174            format_state_name(args.name, "first_value"),
175            args.return_type().clone(),
176            true,
177        )
178        .into()];
179        fields.extend(args.ordering_fields.to_vec());
180        fields.push(Field::new("is_set", DataType::Boolean, true).into());
181        Ok(fields)
182    }
183
184    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
185        // TODO: extract to function
186        use DataType::*;
187        matches!(
188            args.return_field.data_type(),
189            Int8 | Int16
190                | Int32
191                | Int64
192                | UInt8
193                | UInt16
194                | UInt32
195                | UInt64
196                | Float16
197                | Float32
198                | Float64
199                | Decimal128(_, _)
200                | Decimal256(_, _)
201                | Date32
202                | Date64
203                | Time32(_)
204                | Time64(_)
205                | Timestamp(_, _)
206        )
207    }
208
209    fn create_groups_accumulator(
210        &self,
211        args: AccumulatorArgs,
212    ) -> Result<Box<dyn GroupsAccumulator>> {
213        // TODO: extract to function
214        fn create_accumulator<T>(
215            args: AccumulatorArgs,
216        ) -> Result<Box<dyn GroupsAccumulator>>
217        where
218            T: ArrowPrimitiveType + Send,
219        {
220            let ordering_dtypes = args
221                .ordering_req
222                .iter()
223                .map(|e| e.expr.data_type(args.schema))
224                .collect::<Result<Vec<_>>>()?;
225
226            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
227                args.ordering_req.clone(),
228                args.ignore_nulls,
229                args.return_field.data_type(),
230                &ordering_dtypes,
231                true,
232            )?))
233        }
234
235        match args.return_field.data_type() {
236            DataType::Int8 => create_accumulator::<Int8Type>(args),
237            DataType::Int16 => create_accumulator::<Int16Type>(args),
238            DataType::Int32 => create_accumulator::<Int32Type>(args),
239            DataType::Int64 => create_accumulator::<Int64Type>(args),
240            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
241            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
242            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
243            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
244            DataType::Float16 => create_accumulator::<Float16Type>(args),
245            DataType::Float32 => create_accumulator::<Float32Type>(args),
246            DataType::Float64 => create_accumulator::<Float64Type>(args),
247
248            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
249            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
250
251            DataType::Timestamp(TimeUnit::Second, _) => {
252                create_accumulator::<TimestampSecondType>(args)
253            }
254            DataType::Timestamp(TimeUnit::Millisecond, _) => {
255                create_accumulator::<TimestampMillisecondType>(args)
256            }
257            DataType::Timestamp(TimeUnit::Microsecond, _) => {
258                create_accumulator::<TimestampMicrosecondType>(args)
259            }
260            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
261                create_accumulator::<TimestampNanosecondType>(args)
262            }
263
264            DataType::Date32 => create_accumulator::<Date32Type>(args),
265            DataType::Date64 => create_accumulator::<Date64Type>(args),
266            DataType::Time32(TimeUnit::Second) => {
267                create_accumulator::<Time32SecondType>(args)
268            }
269            DataType::Time32(TimeUnit::Millisecond) => {
270                create_accumulator::<Time32MillisecondType>(args)
271            }
272
273            DataType::Time64(TimeUnit::Microsecond) => {
274                create_accumulator::<Time64MicrosecondType>(args)
275            }
276            DataType::Time64(TimeUnit::Nanosecond) => {
277                create_accumulator::<Time64NanosecondType>(args)
278            }
279
280            _ => {
281                internal_err!(
282                    "GroupsAccumulator not supported for first_value({})",
283                    args.return_field.data_type()
284                )
285            }
286        }
287    }
288
289    fn aliases(&self) -> &[String] {
290        &[]
291    }
292
293    fn with_beneficial_ordering(
294        self: Arc<Self>,
295        beneficial_ordering: bool,
296    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
297        Ok(Some(Arc::new(
298            FirstValue::new().with_requirement_satisfied(beneficial_ordering),
299        )))
300    }
301
302    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
303        AggregateOrderSensitivity::Beneficial
304    }
305
306    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
307        datafusion_expr::ReversedUDAF::Reversed(last_value_udaf())
308    }
309
310    fn documentation(&self) -> Option<&Documentation> {
311        self.doc()
312    }
313}
314
// TODO: rename to PrimitiveGroupsAccumulator
/// Grouped (`GroupsAccumulator`) implementation for primitive-typed values.
///
/// Per group it keeps the selected value plus the ordering values of the row it
/// came from. `pick_first_in_group` decides whether the first (`true`) or the
/// last (`false`) row according to `ordering_req` wins.
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    // ================ state ===========
    // vals[group_idx] is the currently selected value of group `group_idx`
    // (its validity lives in `null_builder`).
    vals: Vec<T::Native>,
    // Stores ordering values, of the aggregator requirement corresponding to first value
    // of the aggregator.
    // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
    // represents the ordering values corresponding to the `group_idx`-th group.
    orderings: Vec<Vec<ScalarValue>>,
    // At the beginning, `is_sets[group_idx]` is false, which means no value has been
    // selected for the group yet. Once a value is stored, `is_sets[group_idx]` is set.
    is_sets: BooleanBufferBuilder,
    // null_builder[group_idx] == false => vals[group_idx] is null
    null_builder: BooleanBufferBuilder,
    // Cached total size of `self.orderings`.
    // Calculating it with `ScalarValue::size_of_vec` over every group is quite costly,
    // so it is adjusted incrementally whenever `orderings` is grown, updated or emitted,
    // instead of being recomputed by `Self::size`.
    size_of_orderings: usize,

    // Scratch buffer for `get_filtered_min_of_each_group`:
    // min_of_each_group_buf.0[group_idx] -> idx_in_val (row index in the current batch)
    // only valid if min_of_each_group_buf.1[group_idx] == true
    // TODO: rename to extreme_of_each_group_buf
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    // =========== option ============

    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // true: take first element in an aggregation group according to the requested ordering.
    // false: take last element in an aggregation group according to the requested ordering.
    pick_first_in_group: bool,
    // derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    // Stores whether incoming data already satisfies the ordering requirement.
    input_requirement_satisfied: bool,
    // Ignore null values.
    ignore_nulls: bool,
    /// The output type
    data_type: DataType,
    // Typed-null ordering values (one per ordering column) used to initialize
    // `orderings` for newly created groups.
    default_orderings: Vec<ScalarValue>,
}
361
362impl<T> FirstPrimitiveGroupsAccumulator<T>
363where
364    T: ArrowPrimitiveType + Send,
365{
366    fn try_new(
367        ordering_req: LexOrdering,
368        ignore_nulls: bool,
369        data_type: &DataType,
370        ordering_dtypes: &[DataType],
371        pick_first_in_group: bool,
372    ) -> Result<Self> {
373        let requirement_satisfied = ordering_req.is_empty();
374
375        let default_orderings = ordering_dtypes
376            .iter()
377            .map(ScalarValue::try_from)
378            .collect::<Result<Vec<_>>>()?;
379
380        let sort_options = get_sort_options(ordering_req.as_ref());
381
382        Ok(Self {
383            null_builder: BooleanBufferBuilder::new(0),
384            ordering_req,
385            sort_options,
386            input_requirement_satisfied: requirement_satisfied,
387            ignore_nulls,
388            default_orderings,
389            data_type: data_type.clone(),
390            vals: Vec::new(),
391            orderings: Vec::new(),
392            is_sets: BooleanBufferBuilder::new(0),
393            size_of_orderings: 0,
394            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
395            pick_first_in_group,
396        })
397    }
398
399    fn need_update(&self, group_idx: usize) -> bool {
400        if !self.is_sets.get_bit(group_idx) {
401            return true;
402        }
403
404        if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
405            return true;
406        }
407
408        !self.input_requirement_satisfied
409    }
410
411    fn should_update_state(
412        &self,
413        group_idx: usize,
414        new_ordering_values: &[ScalarValue],
415    ) -> Result<bool> {
416        if !self.is_sets.get_bit(group_idx) {
417            return Ok(true);
418        }
419
420        assert!(new_ordering_values.len() == self.ordering_req.len());
421        let current_ordering = &self.orderings[group_idx];
422        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
423            if self.pick_first_in_group {
424                x.is_gt()
425            } else {
426                x.is_lt()
427            }
428        })
429    }
430
431    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
432        let result = emit_to.take_needed(&mut self.orderings);
433
434        match emit_to {
435            EmitTo::All => self.size_of_orderings = 0,
436            EmitTo::First(_) => {
437                self.size_of_orderings -=
438                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
439            }
440        }
441
442        result
443    }
444
445    fn take_need(
446        bool_buf_builder: &mut BooleanBufferBuilder,
447        emit_to: EmitTo,
448    ) -> BooleanBuffer {
449        let bool_buf = bool_buf_builder.finish();
450        match emit_to {
451            EmitTo::All => bool_buf,
452            EmitTo::First(n) => {
453                // split off the first N values in seen_values
454                //
455                // TODO make this more efficient rather than two
456                // copies and bitwise manipulation
457                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
458                // reset the existing buffer
459                for b in bool_buf.iter().skip(n) {
460                    bool_buf_builder.append(b);
461                }
462                first_n
463            }
464        }
465    }
466
467    fn resize_states(&mut self, new_size: usize) {
468        self.vals.resize(new_size, T::default_value());
469
470        self.null_builder.resize(new_size);
471
472        if self.orderings.len() < new_size {
473            let current_len = self.orderings.len();
474
475            self.orderings
476                .resize(new_size, self.default_orderings.clone());
477
478            self.size_of_orderings += (new_size - current_len)
479                * ScalarValue::size_of_vec(
480                    // Note: In some cases (such as in the unit test below)
481                    // ScalarValue::size_of_vec(&self.default_orderings) != ScalarValue::size_of_vec(&self.default_orderings.clone())
482                    // This may be caused by the different vec.capacity() values?
483                    self.orderings.last().unwrap(),
484                );
485        }
486
487        self.is_sets.resize(new_size);
488
489        self.min_of_each_group_buf.0.resize(new_size, 0);
490        self.min_of_each_group_buf.1.resize(new_size);
491    }
492
493    fn update_state(
494        &mut self,
495        group_idx: usize,
496        orderings: &[ScalarValue],
497        new_val: T::Native,
498        is_null: bool,
499    ) {
500        self.vals[group_idx] = new_val;
501        self.is_sets.set_bit(group_idx, true);
502
503        self.null_builder.set_bit(group_idx, !is_null);
504
505        assert!(orderings.len() == self.ordering_req.len());
506        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
507        self.orderings[group_idx].clear();
508        self.orderings[group_idx].extend_from_slice(orderings);
509        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
510        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
511    }
512
513    fn take_state(
514        &mut self,
515        emit_to: EmitTo,
516    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
517        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
518        self.min_of_each_group_buf
519            .1
520            .truncate(self.min_of_each_group_buf.0.len());
521
522        (
523            self.take_vals_and_null_buf(emit_to),
524            self.take_orderings(emit_to),
525            Self::take_need(&mut self.is_sets, emit_to),
526        )
527    }
528
529    // should be used in test only
530    #[cfg(test)]
531    fn compute_size_of_orderings(&self) -> usize {
532        self.orderings
533            .iter()
534            .map(ScalarValue::size_of_vec)
535            .sum::<usize>()
536    }
537    /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
538    /// minimum value in `orderings` for each group, using lexicographical comparison.
539    /// Values are filtered using `opt_filter` and `is_set_arr` if provided.
540    /// TODO: rename to get_filtered_extreme_of_each_group
541    fn get_filtered_min_of_each_group(
542        &mut self,
543        orderings: &[ArrayRef],
544        group_indices: &[usize],
545        opt_filter: Option<&BooleanArray>,
546        vals: &PrimitiveArray<T>,
547        is_set_arr: Option<&BooleanArray>,
548    ) -> Result<Vec<(usize, usize)>> {
549        // Set all values in min_of_each_group_buf.1 to false.
550        self.min_of_each_group_buf.1.truncate(0);
551        self.min_of_each_group_buf
552            .1
553            .append_n(self.vals.len(), false);
554
555        // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]`
556        // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`.
557
558        let comparator = {
559            assert_eq!(orderings.len(), self.ordering_req.len());
560            let sort_columns = orderings
561                .iter()
562                .zip(self.ordering_req.iter())
563                .map(|(array, req)| SortColumn {
564                    values: Arc::clone(array),
565                    options: Some(req.options),
566                })
567                .collect::<Vec<_>>();
568
569            LexicographicalComparator::try_new(&sort_columns)?
570        };
571
572        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
573            let group_idx = *group_idx;
574
575            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
576
577            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
578
579            if !passed_filter || !is_set {
580                continue;
581            }
582
583            if !self.need_update(group_idx) {
584                continue;
585            }
586
587            if self.ignore_nulls && vals.is_null(idx_in_val) {
588                continue;
589            }
590
591            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
592
593            if !is_valid {
594                self.min_of_each_group_buf.1.set_bit(group_idx, true);
595                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
596            } else {
597                let ordering = comparator
598                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);
599
600                if (ordering.is_gt() && self.pick_first_in_group)
601                    || (ordering.is_lt() && !self.pick_first_in_group)
602                {
603                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
604                }
605            }
606        }
607
608        Ok(self
609            .min_of_each_group_buf
610            .0
611            .iter()
612            .enumerate()
613            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
614            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
615            .collect::<Vec<_>>())
616    }
617
618    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
619        let r = emit_to.take_needed(&mut self.vals);
620
621        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
622
623        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no copy
624            .with_data_type(self.data_type.clone());
625        Arc::new(values)
626    }
627}
628
629impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
630where
631    T: ArrowPrimitiveType + Send,
632{
633    fn update_batch(
634        &mut self,
635        // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
636        values_and_order_cols: &[ArrayRef],
637        group_indices: &[usize],
638        opt_filter: Option<&BooleanArray>,
639        total_num_groups: usize,
640    ) -> Result<()> {
641        self.resize_states(total_num_groups);
642
643        let vals = values_and_order_cols[0].as_primitive::<T>();
644
645        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
646
647        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
648        for (group_idx, idx) in self
649            .get_filtered_min_of_each_group(
650                &values_and_order_cols[1..],
651                group_indices,
652                opt_filter,
653                vals,
654                None,
655            )?
656            .into_iter()
657        {
658            extract_row_at_idx_to_buf(
659                &values_and_order_cols[1..],
660                idx,
661                &mut ordering_buf,
662            )?;
663
664            if self.should_update_state(group_idx, &ordering_buf)? {
665                self.update_state(
666                    group_idx,
667                    &ordering_buf,
668                    vals.value(idx),
669                    vals.is_null(idx),
670                );
671            }
672        }
673
674        Ok(())
675    }
676
677    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
678        Ok(self.take_state(emit_to).0)
679    }
680
681    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
682        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
683        let mut result = Vec::with_capacity(self.orderings.len() + 2);
684
685        result.push(val_arr);
686
687        let ordering_cols = {
688            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
689            for _ in 0..self.ordering_req.len() {
690                ordering_cols.push(Vec::with_capacity(self.orderings.len()));
691            }
692            for row in orderings.into_iter() {
693                assert_eq!(row.len(), self.ordering_req.len());
694                for (col_idx, ordering) in row.into_iter().enumerate() {
695                    ordering_cols[col_idx].push(ordering);
696                }
697            }
698
699            ordering_cols
700        };
701        for ordering_col in ordering_cols {
702            result.push(ScalarValue::iter_to_array(ordering_col)?);
703        }
704
705        result.push(Arc::new(BooleanArray::new(is_sets, None)));
706
707        Ok(result)
708    }
709
710    fn merge_batch(
711        &mut self,
712        values: &[ArrayRef],
713        group_indices: &[usize],
714        opt_filter: Option<&BooleanArray>,
715        total_num_groups: usize,
716    ) -> Result<()> {
717        self.resize_states(total_num_groups);
718
719        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
720
721        let (is_set_arr, val_and_order_cols) = match values.split_last() {
722            Some(result) => result,
723            None => return internal_err!("Empty row in FISRT_VALUE"),
724        };
725
726        let is_set_arr = as_boolean_array(is_set_arr)?;
727
728        let vals = values[0].as_primitive::<T>();
729        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
730        let groups = self.get_filtered_min_of_each_group(
731            &val_and_order_cols[1..],
732            group_indices,
733            opt_filter,
734            vals,
735            Some(is_set_arr),
736        )?;
737
738        for (group_idx, idx) in groups.into_iter() {
739            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
740
741            if self.should_update_state(group_idx, &ordering_buf)? {
742                self.update_state(
743                    group_idx,
744                    &ordering_buf,
745                    vals.value(idx),
746                    vals.is_null(idx),
747                );
748            }
749        }
750
751        Ok(())
752    }
753
754    fn size(&self) -> usize {
755        self.vals.capacity() * size_of::<T::Native>()
756            + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes
757            + self.is_sets.capacity() / 8
758            + self.size_of_orderings
759            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
760            + self.min_of_each_group_buf.1.capacity() / 8
761    }
762
763    fn supports_convert_to_state(&self) -> bool {
764        true
765    }
766
767    fn convert_to_state(
768        &self,
769        values: &[ArrayRef],
770        opt_filter: Option<&BooleanArray>,
771    ) -> Result<Vec<ArrayRef>> {
772        let mut result = values.to_vec();
773        match opt_filter {
774            Some(f) => {
775                result.push(Arc::new(f.clone()));
776                Ok(result)
777            }
778            None => {
779                result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
780                Ok(result)
781            }
782        }
783    }
784}
/// Row-based (non-grouped) `first_value` accumulator.
///
/// Keeps the selected value together with the ordering values of the row it
/// came from. `state` emits `[first, orderings..., is_set]` so that partial
/// results can be merged.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    first: ScalarValue,
    // At the beginning, `is_set` is false, which means `first` is not seen yet.
    // Once a value has been selected the flag is set. After that, `first` is only
    // replaced when the input is not known to satisfy the ordering requirement
    // (`requirement_satisfied == false`), in which case a better-ordered row may
    // still arrive.
    is_set: bool,
    // Stores ordering values, of the aggregator requirement corresponding to first value
    // of the aggregator. These values are used during merging of multiple partitions.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Stores whether incoming data already satisfies the ordering requirement.
    requirement_satisfied: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
801
802impl FirstValueAccumulator {
803    /// Creates a new `FirstValueAccumulator` for the given `data_type`.
804    pub fn try_new(
805        data_type: &DataType,
806        ordering_dtypes: &[DataType],
807        ordering_req: LexOrdering,
808        ignore_nulls: bool,
809    ) -> Result<Self> {
810        let orderings = ordering_dtypes
811            .iter()
812            .map(ScalarValue::try_from)
813            .collect::<Result<Vec<_>>>()?;
814        let requirement_satisfied = ordering_req.is_empty();
815        ScalarValue::try_from(data_type).map(|first| Self {
816            first,
817            is_set: false,
818            orderings,
819            ordering_req,
820            requirement_satisfied,
821            ignore_nulls,
822        })
823    }
824
825    pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
826        self.requirement_satisfied = requirement_satisfied;
827        self
828    }
829
830    // Updates state with the values in the given row.
831    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
832        // Ensure any Array based scalars hold have a single value to reduce memory pressure
833        row.iter_mut().for_each(|s| {
834            s.compact();
835        });
836
837        self.first = row.remove(0);
838        self.orderings = row;
839        self.is_set = true;
840    }
841
842    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
843        let [value, ordering_values @ ..] = values else {
844            return internal_err!("Empty row in FIRST_VALUE");
845        };
846        if self.requirement_satisfied {
847            // Get first entry according to the pre-existing ordering (0th index):
848            if self.ignore_nulls {
849                // If ignoring nulls, find the first non-null value.
850                for i in 0..value.len() {
851                    if !value.is_null(i) {
852                        return Ok(Some(i));
853                    }
854                }
855                return Ok(None);
856            } else {
857                // If not ignoring nulls, return the first value if it exists.
858                return Ok((!value.is_empty()).then_some(0));
859            }
860        }
861
862        let sort_columns = ordering_values
863            .iter()
864            .zip(self.ordering_req.iter())
865            .map(|(values, req)| SortColumn {
866                values: Arc::clone(values),
867                options: Some(req.options),
868            })
869            .collect::<Vec<_>>();
870
871        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
872
873        let min_index = if self.ignore_nulls {
874            (0..value.len())
875                .filter(|&index| !value.is_null(index))
876                .min_by(|&a, &b| comparator.compare(a, b))
877        } else {
878            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
879        };
880
881        Ok(min_index)
882    }
883}
884
885impl Accumulator for FirstValueAccumulator {
886    fn state(&mut self) -> Result<Vec<ScalarValue>> {
887        let mut result = vec![self.first.clone()];
888        result.extend(self.orderings.iter().cloned());
889        result.push(ScalarValue::Boolean(Some(self.is_set)));
890        Ok(result)
891    }
892
893    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
894        if !self.is_set {
895            if let Some(first_idx) = self.get_first_idx(values)? {
896                let row = get_row_at_idx(values, first_idx)?;
897                self.update_with_new_row(row);
898            }
899        } else if !self.requirement_satisfied {
900            if let Some(first_idx) = self.get_first_idx(values)? {
901                let row = get_row_at_idx(values, first_idx)?;
902                let orderings = &row[1..];
903                if compare_rows(
904                    &self.orderings,
905                    orderings,
906                    &get_sort_options(self.ordering_req.as_ref()),
907                )?
908                .is_gt()
909                {
910                    self.update_with_new_row(row);
911                }
912            }
913        }
914        Ok(())
915    }
916
917    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
918        // FIRST_VALUE(first1, first2, first3, ...)
919        // last index contains is_set flag.
920        let is_set_idx = states.len() - 1;
921        let flags = states[is_set_idx].as_boolean();
922        let filtered_states =
923            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
924        // 1..is_set_idx range corresponds to ordering section
925        let sort_columns = convert_to_sort_cols(
926            &filtered_states[1..is_set_idx],
927            self.ordering_req.as_ref(),
928        );
929
930        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
931        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
932
933        if let Some(first_idx) = min {
934            let mut first_row = get_row_at_idx(&filtered_states, first_idx)?;
935            // When collecting orderings, we exclude the is_set flag from the state.
936            let first_ordering = &first_row[1..is_set_idx];
937            let sort_options = get_sort_options(self.ordering_req.as_ref());
938            // Either there is no existing value, or there is an earlier version in new data.
939            if !self.is_set
940                || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
941            {
942                // Update with first value in the state. Note that we should exclude the
943                // is_set flag from the state. Otherwise, we will end up with a state
944                // containing two is_set flags.
945                assert!(is_set_idx <= first_row.len());
946                first_row.resize(is_set_idx, ScalarValue::Null);
947                self.update_with_new_row(first_row);
948            }
949        }
950        Ok(())
951    }
952
953    fn evaluate(&mut self) -> Result<ScalarValue> {
954        Ok(self.first.clone())
955    }
956
957    fn size(&self) -> usize {
958        size_of_val(self) - size_of_val(&self.first)
959            + self.first.size()
960            + ScalarValue::size_of_vec(&self.orderings)
961            - size_of_val(&self.orderings)
962    }
963}
964
// UDAF implementation of `last_value`. The SQL-facing documentation is
// declared in the `user_doc` attribute below (presumably what `documentation()`
// returns through `self.doc()`; confirm against the macro).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element                                  |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct LastValue {
    // Accepts exactly one argument of any type (see `LastValue::new`).
    signature: Signature,
    // When true, `accumulator` treats incoming data as already satisfying the
    // `ORDER BY` requirement. It is set via `with_beneficial_ordering`.
    requirement_satisfied: bool,
}
983
984impl Debug for LastValue {
985    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
986        f.debug_struct("LastValue")
987            .field("name", &self.name())
988            .field("signature", &self.signature)
989            .field("accumulator", &"<FUNC>")
990            .finish()
991    }
992}
993
994impl Default for LastValue {
995    fn default() -> Self {
996        Self::new()
997    }
998}
999
1000impl LastValue {
1001    pub fn new() -> Self {
1002        Self {
1003            signature: Signature::any(1, Volatility::Immutable),
1004            requirement_satisfied: false,
1005        }
1006    }
1007
1008    fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
1009        self.requirement_satisfied = requirement_satisfied;
1010        self
1011    }
1012}
1013
1014impl AggregateUDFImpl for LastValue {
1015    fn as_any(&self) -> &dyn Any {
1016        self
1017    }
1018
1019    fn name(&self) -> &str {
1020        "last_value"
1021    }
1022
1023    fn signature(&self) -> &Signature {
1024        &self.signature
1025    }
1026
1027    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1028        Ok(arg_types[0].clone())
1029    }
1030
1031    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1032        let ordering_dtypes = acc_args
1033            .ordering_req
1034            .iter()
1035            .map(|e| e.expr.data_type(acc_args.schema))
1036            .collect::<Result<Vec<_>>>()?;
1037
1038        let requirement_satisfied =
1039            acc_args.ordering_req.is_empty() || self.requirement_satisfied;
1040
1041        LastValueAccumulator::try_new(
1042            acc_args.return_field.data_type(),
1043            &ordering_dtypes,
1044            acc_args.ordering_req.clone(),
1045            acc_args.ignore_nulls,
1046        )
1047        .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
1048    }
1049
1050    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
1051        let StateFieldsArgs {
1052            name,
1053            input_fields,
1054            return_field: _,
1055            ordering_fields,
1056            is_distinct: _,
1057        } = args;
1058        let mut fields = vec![Field::new(
1059            format_state_name(name, "last_value"),
1060            input_fields[0].data_type().clone(),
1061            true,
1062        )
1063        .into()];
1064        fields.extend(ordering_fields.to_vec());
1065        fields.push(Field::new("is_set", DataType::Boolean, true).into());
1066        Ok(fields)
1067    }
1068
1069    fn aliases(&self) -> &[String] {
1070        &[]
1071    }
1072
1073    fn with_beneficial_ordering(
1074        self: Arc<Self>,
1075        beneficial_ordering: bool,
1076    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1077        Ok(Some(Arc::new(
1078            LastValue::new().with_requirement_satisfied(beneficial_ordering),
1079        )))
1080    }
1081
1082    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1083        AggregateOrderSensitivity::Beneficial
1084    }
1085
1086    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
1087        datafusion_expr::ReversedUDAF::Reversed(first_value_udaf())
1088    }
1089
1090    fn documentation(&self) -> Option<&Documentation> {
1091        self.doc()
1092    }
1093
1094    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1095        use DataType::*;
1096        matches!(
1097            args.return_field.data_type(),
1098            Int8 | Int16
1099                | Int32
1100                | Int64
1101                | UInt8
1102                | UInt16
1103                | UInt32
1104                | UInt64
1105                | Float16
1106                | Float32
1107                | Float64
1108                | Decimal128(_, _)
1109                | Decimal256(_, _)
1110                | Date32
1111                | Date64
1112                | Time32(_)
1113                | Time64(_)
1114                | Timestamp(_, _)
1115        )
1116    }
1117
1118    fn create_groups_accumulator(
1119        &self,
1120        args: AccumulatorArgs,
1121    ) -> Result<Box<dyn GroupsAccumulator>> {
1122        fn create_accumulator<T>(
1123            args: AccumulatorArgs,
1124        ) -> Result<Box<dyn GroupsAccumulator>>
1125        where
1126            T: ArrowPrimitiveType + Send,
1127        {
1128            let ordering_dtypes = args
1129                .ordering_req
1130                .iter()
1131                .map(|e| e.expr.data_type(args.schema))
1132                .collect::<Result<Vec<_>>>()?;
1133
1134            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
1135                args.ordering_req.clone(),
1136                args.ignore_nulls,
1137                args.return_field.data_type(),
1138                &ordering_dtypes,
1139                false,
1140            )?))
1141        }
1142
1143        match args.return_field.data_type() {
1144            DataType::Int8 => create_accumulator::<Int8Type>(args),
1145            DataType::Int16 => create_accumulator::<Int16Type>(args),
1146            DataType::Int32 => create_accumulator::<Int32Type>(args),
1147            DataType::Int64 => create_accumulator::<Int64Type>(args),
1148            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
1149            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
1150            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
1151            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
1152            DataType::Float16 => create_accumulator::<Float16Type>(args),
1153            DataType::Float32 => create_accumulator::<Float32Type>(args),
1154            DataType::Float64 => create_accumulator::<Float64Type>(args),
1155
1156            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
1157            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
1158
1159            DataType::Timestamp(TimeUnit::Second, _) => {
1160                create_accumulator::<TimestampSecondType>(args)
1161            }
1162            DataType::Timestamp(TimeUnit::Millisecond, _) => {
1163                create_accumulator::<TimestampMillisecondType>(args)
1164            }
1165            DataType::Timestamp(TimeUnit::Microsecond, _) => {
1166                create_accumulator::<TimestampMicrosecondType>(args)
1167            }
1168            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1169                create_accumulator::<TimestampNanosecondType>(args)
1170            }
1171
1172            DataType::Date32 => create_accumulator::<Date32Type>(args),
1173            DataType::Date64 => create_accumulator::<Date64Type>(args),
1174            DataType::Time32(TimeUnit::Second) => {
1175                create_accumulator::<Time32SecondType>(args)
1176            }
1177            DataType::Time32(TimeUnit::Millisecond) => {
1178                create_accumulator::<Time32MillisecondType>(args)
1179            }
1180
1181            DataType::Time64(TimeUnit::Microsecond) => {
1182                create_accumulator::<Time64MicrosecondType>(args)
1183            }
1184            DataType::Time64(TimeUnit::Nanosecond) => {
1185                create_accumulator::<Time64NanosecondType>(args)
1186            }
1187
1188            _ => {
1189                internal_err!(
1190                    "GroupsAccumulator not supported for last_value({})",
1191                    args.return_field.data_type()
1192                )
1193            }
1194        }
1195    }
1196}
1197
// Row-based accumulator for `LAST_VALUE`. It remembers the latest value seen
// so far together with the ordering-column values of the row it came from.
#[derive(Debug)]
struct LastValueAccumulator {
    // Current last value. It starts as the scalar produced by
    // `ScalarValue::try_from(data_type)` (presumably a typed NULL) and is
    // replaced by `update_with_new_row`.
    last: ScalarValue,
    // The `is_set` flag becomes true once any row has been recorded. It
    // distinguishes genuine NULL values from the initial placeholder of an
    // accumulator that never saw a row (e.g. an empty partition). It is emitted
    // as the final state column so that `merge_batch` can filter on it.
    is_set: bool,
    // Ordering-column values of the row that produced `last`, in the same order
    // as `ordering_req`.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Stores whether incoming data already satisfies the ordering requirement.
    // When true, the last row of each batch wins without comparing orderings.
    requirement_satisfied: bool,
    // When true, NULL input values are never chosen as the last value.
    ignore_nulls: bool,
}
1213
1214impl LastValueAccumulator {
1215    /// Creates a new `LastValueAccumulator` for the given `data_type`.
1216    pub fn try_new(
1217        data_type: &DataType,
1218        ordering_dtypes: &[DataType],
1219        ordering_req: LexOrdering,
1220        ignore_nulls: bool,
1221    ) -> Result<Self> {
1222        let orderings = ordering_dtypes
1223            .iter()
1224            .map(ScalarValue::try_from)
1225            .collect::<Result<Vec<_>>>()?;
1226        let requirement_satisfied = ordering_req.is_empty();
1227        ScalarValue::try_from(data_type).map(|last| Self {
1228            last,
1229            is_set: false,
1230            orderings,
1231            ordering_req,
1232            requirement_satisfied,
1233            ignore_nulls,
1234        })
1235    }
1236
1237    // Updates state with the values in the given row.
1238    fn update_with_new_row(&mut self, mut row: Vec<ScalarValue>) {
1239        // Ensure any Array based scalars hold have a single value to reduce memory pressure
1240        row.iter_mut().for_each(|s| {
1241            s.compact();
1242        });
1243
1244        self.last = row.remove(0);
1245        self.orderings = row;
1246        self.is_set = true;
1247    }
1248
1249    fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1250        let [value, ordering_values @ ..] = values else {
1251            return internal_err!("Empty row in LAST_VALUE");
1252        };
1253        if self.requirement_satisfied {
1254            // Get last entry according to the order of data:
1255            if self.ignore_nulls {
1256                // If ignoring nulls, find the last non-null value.
1257                for i in (0..value.len()).rev() {
1258                    if !value.is_null(i) {
1259                        return Ok(Some(i));
1260                    }
1261                }
1262                return Ok(None);
1263            } else {
1264                return Ok((!value.is_empty()).then_some(value.len() - 1));
1265            }
1266        }
1267        let sort_columns = ordering_values
1268            .iter()
1269            .zip(self.ordering_req.iter())
1270            .map(|(values, req)| SortColumn {
1271                values: Arc::clone(values),
1272                options: Some(req.options),
1273            })
1274            .collect::<Vec<_>>();
1275
1276        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1277        let max_ind = if self.ignore_nulls {
1278            (0..value.len())
1279                .filter(|&index| !(value.is_null(index)))
1280                .max_by(|&a, &b| comparator.compare(a, b))
1281        } else {
1282            (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1283        };
1284
1285        Ok(max_ind)
1286    }
1287
1288    fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
1289        self.requirement_satisfied = requirement_satisfied;
1290        self
1291    }
1292}
1293
1294impl Accumulator for LastValueAccumulator {
1295    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1296        let mut result = vec![self.last.clone()];
1297        result.extend(self.orderings.clone());
1298        result.push(ScalarValue::Boolean(Some(self.is_set)));
1299        Ok(result)
1300    }
1301
1302    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1303        if !self.is_set || self.requirement_satisfied {
1304            if let Some(last_idx) = self.get_last_idx(values)? {
1305                let row = get_row_at_idx(values, last_idx)?;
1306                self.update_with_new_row(row);
1307            }
1308        } else if let Some(last_idx) = self.get_last_idx(values)? {
1309            let row = get_row_at_idx(values, last_idx)?;
1310            let orderings = &row[1..];
1311            // Update when there is a more recent entry
1312            if compare_rows(
1313                &self.orderings,
1314                orderings,
1315                &get_sort_options(self.ordering_req.as_ref()),
1316            )?
1317            .is_lt()
1318            {
1319                self.update_with_new_row(row);
1320            }
1321        }
1322
1323        Ok(())
1324    }
1325
1326    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1327        // LAST_VALUE(last1, last2, last3, ...)
1328        // last index contains is_set flag.
1329        let is_set_idx = states.len() - 1;
1330        let flags = states[is_set_idx].as_boolean();
1331        let filtered_states =
1332            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1333        // 1..is_set_idx range corresponds to ordering section
1334        let sort_columns = convert_to_sort_cols(
1335            &filtered_states[1..is_set_idx],
1336            self.ordering_req.as_ref(),
1337        );
1338
1339        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1340        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1341
1342        if let Some(last_idx) = max {
1343            let mut last_row = get_row_at_idx(&filtered_states, last_idx)?;
1344            // When collecting orderings, we exclude the is_set flag from the state.
1345            let last_ordering = &last_row[1..is_set_idx];
1346            let sort_options = get_sort_options(self.ordering_req.as_ref());
1347            // Either there is no existing value, or there is a newer (latest)
1348            // version in the new data:
1349            if !self.is_set
1350                || self.requirement_satisfied
1351                || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
1352            {
1353                // Update with last value in the state. Note that we should exclude the
1354                // is_set flag from the state. Otherwise, we will end up with a state
1355                // containing two is_set flags.
1356                assert!(is_set_idx <= last_row.len());
1357                last_row.resize(is_set_idx, ScalarValue::Null);
1358                self.update_with_new_row(last_row);
1359            }
1360        }
1361        Ok(())
1362    }
1363
1364    fn evaluate(&mut self) -> Result<ScalarValue> {
1365        Ok(self.last.clone())
1366    }
1367
1368    fn size(&self) -> usize {
1369        size_of_val(self) - size_of_val(&self.last)
1370            + self.last.size()
1371            + ScalarValue::size_of_vec(&self.orderings)
1372            - size_of_val(&self.orderings)
1373    }
1374}
1375
1376/// Filters states according to the `is_set` flag at the last column and returns
1377/// the resulting states.
1378fn filter_states_according_to_is_set(
1379    states: &[ArrayRef],
1380    flags: &BooleanArray,
1381) -> Result<Vec<ArrayRef>> {
1382    states
1383        .iter()
1384        .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1385        .collect::<Result<Vec<_>>>()
1386}
1387
1388/// Combines array refs and their corresponding orderings to construct `SortColumn`s.
1389fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1390    arrs.iter()
1391        .zip(sort_exprs.iter())
1392        .map(|(item, sort_expr)| SortColumn {
1393            values: Arc::clone(item),
1394            options: Some(sort_expr.options),
1395        })
1396        .collect::<Vec<_>>()
1397}
1398
1399#[cfg(test)]
1400mod tests {
1401    use std::iter::repeat_with;
1402
1403    use arrow::{
1404        array::{Int64Array, ListArray},
1405        compute::SortOptions,
1406        datatypes::Schema,
1407    };
1408    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
1409
1410    use super::*;
1411
1412    #[test]
1413    fn test_first_last_value_value() -> Result<()> {
1414        let mut first_accumulator = FirstValueAccumulator::try_new(
1415            &DataType::Int64,
1416            &[],
1417            LexOrdering::default(),
1418            false,
1419        )?;
1420        let mut last_accumulator = LastValueAccumulator::try_new(
1421            &DataType::Int64,
1422            &[],
1423            LexOrdering::default(),
1424            false,
1425        )?;
1426        // first value in the tuple is start of the range (inclusive),
1427        // second value in the tuple is end of the range (exclusive)
1428        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1429        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1430        let arrs = ranges
1431            .into_iter()
1432            .map(|(start, end)| {
1433                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
1434            })
1435            .collect::<Vec<_>>();
1436        for arr in arrs {
1437            // Once first_value is set, accumulator should remember it.
1438            // It shouldn't update first_value for each new batch
1439            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
1440            // last_value should be updated for each new batch.
1441            last_accumulator.update_batch(&[arr])?;
1442        }
1443        // First Value comes from the first value of the first batch which is 0
1444        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
1445        // Last value comes from the last value of the last batch which is 12
1446        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
1447        Ok(())
1448    }
1449
1450    #[test]
1451    fn test_first_last_state_after_merge() -> Result<()> {
1452        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1453        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1454        let arrs = ranges
1455            .into_iter()
1456            .map(|(start, end)| {
1457                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
1458            })
1459            .collect::<Vec<_>>();
1460
1461        // FirstValueAccumulator
1462        let mut first_accumulator = FirstValueAccumulator::try_new(
1463            &DataType::Int64,
1464            &[],
1465            LexOrdering::default(),
1466            false,
1467        )?;
1468
1469        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1470        let state1 = first_accumulator.state()?;
1471
1472        let mut first_accumulator = FirstValueAccumulator::try_new(
1473            &DataType::Int64,
1474            &[],
1475            LexOrdering::default(),
1476            false,
1477        )?;
1478        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1479        let state2 = first_accumulator.state()?;
1480
1481        assert_eq!(state1.len(), state2.len());
1482
1483        let mut states = vec![];
1484
1485        for idx in 0..state1.len() {
1486            states.push(compute::concat(&[
1487                &state1[idx].to_array()?,
1488                &state2[idx].to_array()?,
1489            ])?);
1490        }
1491
1492        let mut first_accumulator = FirstValueAccumulator::try_new(
1493            &DataType::Int64,
1494            &[],
1495            LexOrdering::default(),
1496            false,
1497        )?;
1498        first_accumulator.merge_batch(&states)?;
1499
1500        let merged_state = first_accumulator.state()?;
1501        assert_eq!(merged_state.len(), state1.len());
1502
1503        // LastValueAccumulator
1504        let mut last_accumulator = LastValueAccumulator::try_new(
1505            &DataType::Int64,
1506            &[],
1507            LexOrdering::default(),
1508            false,
1509        )?;
1510
1511        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1512        let state1 = last_accumulator.state()?;
1513
1514        let mut last_accumulator = LastValueAccumulator::try_new(
1515            &DataType::Int64,
1516            &[],
1517            LexOrdering::default(),
1518            false,
1519        )?;
1520        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1521        let state2 = last_accumulator.state()?;
1522
1523        assert_eq!(state1.len(), state2.len());
1524
1525        let mut states = vec![];
1526
1527        for idx in 0..state1.len() {
1528            states.push(compute::concat(&[
1529                &state1[idx].to_array()?,
1530                &state2[idx].to_array()?,
1531            ])?);
1532        }
1533
1534        let mut last_accumulator = LastValueAccumulator::try_new(
1535            &DataType::Int64,
1536            &[],
1537            LexOrdering::default(),
1538            false,
1539        )?;
1540        last_accumulator.merge_batch(&states)?;
1541
1542        let merged_state = last_accumulator.state()?;
1543        assert_eq!(merged_state.len(), state1.len());
1544
1545        Ok(())
1546    }
1547
1548    #[test]
1549    fn test_frist_group_acc() -> Result<()> {
1550        let schema = Arc::new(Schema::new(vec![
1551            Field::new("a", DataType::Int64, true),
1552            Field::new("b", DataType::Int64, true),
1553            Field::new("c", DataType::Int64, true),
1554            Field::new("d", DataType::Int32, true),
1555            Field::new("e", DataType::Boolean, true),
1556        ]));
1557
1558        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1559            expr: col("c", &schema).unwrap(),
1560            options: SortOptions::default(),
1561        }]);
1562
1563        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1564            sort_key,
1565            true,
1566            &DataType::Int64,
1567            &[DataType::Int64],
1568            true,
1569        )?;
1570
1571        let mut val_with_orderings = {
1572            let mut val_with_orderings = Vec::<ArrayRef>::new();
1573
1574            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1575            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1576
1577            val_with_orderings.push(vals);
1578            val_with_orderings.push(orderings);
1579
1580            val_with_orderings
1581        };
1582
1583        group_acc.update_batch(
1584            &val_with_orderings,
1585            &[0, 1, 2, 1],
1586            Some(&BooleanArray::from(vec![true, true, false, true])),
1587            3,
1588        )?;
1589        assert_eq!(
1590            group_acc.size_of_orderings,
1591            group_acc.compute_size_of_orderings()
1592        );
1593
1594        let state = group_acc.state(EmitTo::All)?;
1595
1596        let expected_state: Vec<Arc<dyn Array>> = vec![
1597            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1598            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1599            Arc::new(BooleanArray::from(vec![true, true, false])),
1600        ];
1601        assert_eq!(state, expected_state);
1602
1603        assert_eq!(
1604            group_acc.size_of_orderings,
1605            group_acc.compute_size_of_orderings()
1606        );
1607
1608        group_acc.merge_batch(
1609            &state,
1610            &[0, 1, 2],
1611            Some(&BooleanArray::from(vec![true, false, false])),
1612            3,
1613        )?;
1614
1615        assert_eq!(
1616            group_acc.size_of_orderings,
1617            group_acc.compute_size_of_orderings()
1618        );
1619
1620        val_with_orderings.clear();
1621        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1622        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1623
1624        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1625
1626        let binding = group_acc.evaluate(EmitTo::All)?;
1627        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1628
1629        let expect: PrimitiveArray<Int64Type> =
1630            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
1631
1632        assert_eq!(eval_result, &expect);
1633
1634        assert_eq!(
1635            group_acc.size_of_orderings,
1636            group_acc.compute_size_of_orderings()
1637        );
1638
1639        Ok(())
1640    }
1641
1642    #[test]
1643    fn test_group_acc_size_of_ordering() -> Result<()> {
1644        let schema = Arc::new(Schema::new(vec![
1645            Field::new("a", DataType::Int64, true),
1646            Field::new("b", DataType::Int64, true),
1647            Field::new("c", DataType::Int64, true),
1648            Field::new("d", DataType::Int32, true),
1649            Field::new("e", DataType::Boolean, true),
1650        ]));
1651
1652        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1653            expr: col("c", &schema).unwrap(),
1654            options: SortOptions::default(),
1655        }]);
1656
1657        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1658            sort_key,
1659            true,
1660            &DataType::Int64,
1661            &[DataType::Int64],
1662            true,
1663        )?;
1664
1665        let val_with_orderings = {
1666            let mut val_with_orderings = Vec::<ArrayRef>::new();
1667
1668            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1669            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1670
1671            val_with_orderings.push(vals);
1672            val_with_orderings.push(orderings);
1673
1674            val_with_orderings
1675        };
1676
1677        for _ in 0..10 {
1678            group_acc.update_batch(
1679                &val_with_orderings,
1680                &[0, 1, 2, 1],
1681                Some(&BooleanArray::from(vec![true, true, false, true])),
1682                100,
1683            )?;
1684            assert_eq!(
1685                group_acc.size_of_orderings,
1686                group_acc.compute_size_of_orderings()
1687            );
1688
1689            group_acc.state(EmitTo::First(2))?;
1690            assert_eq!(
1691                group_acc.size_of_orderings,
1692                group_acc.compute_size_of_orderings()
1693            );
1694
1695            let s = group_acc.state(EmitTo::All)?;
1696            assert_eq!(
1697                group_acc.size_of_orderings,
1698                group_acc.compute_size_of_orderings()
1699            );
1700
1701            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
1702            assert_eq!(
1703                group_acc.size_of_orderings,
1704                group_acc.compute_size_of_orderings()
1705            );
1706
1707            group_acc.evaluate(EmitTo::First(2))?;
1708            assert_eq!(
1709                group_acc.size_of_orderings,
1710                group_acc.compute_size_of_orderings()
1711            );
1712
1713            group_acc.evaluate(EmitTo::All)?;
1714            assert_eq!(
1715                group_acc.size_of_orderings,
1716                group_acc.compute_size_of_orderings()
1717            );
1718        }
1719
1720        Ok(())
1721    }
1722
1723    #[test]
1724    fn test_last_group_acc() -> Result<()> {
1725        let schema = Arc::new(Schema::new(vec![
1726            Field::new("a", DataType::Int64, true),
1727            Field::new("b", DataType::Int64, true),
1728            Field::new("c", DataType::Int64, true),
1729            Field::new("d", DataType::Int32, true),
1730            Field::new("e", DataType::Boolean, true),
1731        ]));
1732
1733        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1734            expr: col("c", &schema).unwrap(),
1735            options: SortOptions::default(),
1736        }]);
1737
1738        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1739            sort_key,
1740            true,
1741            &DataType::Int64,
1742            &[DataType::Int64],
1743            false,
1744        )?;
1745
1746        let mut val_with_orderings = {
1747            let mut val_with_orderings = Vec::<ArrayRef>::new();
1748
1749            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1750            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1751
1752            val_with_orderings.push(vals);
1753            val_with_orderings.push(orderings);
1754
1755            val_with_orderings
1756        };
1757
1758        group_acc.update_batch(
1759            &val_with_orderings,
1760            &[0, 1, 2, 1],
1761            Some(&BooleanArray::from(vec![true, true, false, true])),
1762            3,
1763        )?;
1764
1765        let state = group_acc.state(EmitTo::All)?;
1766
1767        let expected_state: Vec<Arc<dyn Array>> = vec![
1768            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1769            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1770            Arc::new(BooleanArray::from(vec![true, true, false])),
1771        ];
1772        assert_eq!(state, expected_state);
1773
1774        group_acc.merge_batch(
1775            &state,
1776            &[0, 1, 2],
1777            Some(&BooleanArray::from(vec![true, false, false])),
1778            3,
1779        )?;
1780
1781        val_with_orderings.clear();
1782        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1783        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1784
1785        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1786
1787        let binding = group_acc.evaluate(EmitTo::All)?;
1788        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1789
1790        let expect: PrimitiveArray<Int64Type> =
1791            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
1792
1793        assert_eq!(eval_result, &expect);
1794
1795        Ok(())
1796    }
1797
1798    #[test]
1799    fn test_first_list_acc_size() -> Result<()> {
1800        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1801            let mut first_accumulator = FirstValueAccumulator::try_new(
1802                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1803                &[],
1804                LexOrdering::default(),
1805                false,
1806            )?;
1807
1808            first_accumulator.update_batch(values)?;
1809
1810            Ok(first_accumulator.size())
1811        }
1812
1813        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1814            repeat_with(|| Some(vec![Some(1)])).take(10000),
1815        );
1816        let batch2 =
1817            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1818
1819        let size1 = size_after_batch(&[Arc::new(batch1)])?;
1820        let size2 = size_after_batch(&[Arc::new(batch2)])?;
1821        assert_eq!(size1, size2);
1822
1823        Ok(())
1824    }
1825
1826    #[test]
1827    fn test_last_list_acc_size() -> Result<()> {
1828        fn size_after_batch(values: &[ArrayRef]) -> Result<usize> {
1829            let mut last_accumulator = LastValueAccumulator::try_new(
1830                &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
1831                &[],
1832                LexOrdering::default(),
1833                false,
1834            )?;
1835
1836            last_accumulator.update_batch(values)?;
1837
1838            Ok(last_accumulator.size())
1839        }
1840
1841        let batch1 = ListArray::from_iter_primitive::<Int32Type, _, _>(
1842            repeat_with(|| Some(vec![Some(1)])).take(10000),
1843        );
1844        let batch2 =
1845            ListArray::from_iter_primitive::<Int32Type, _, _>([Some(vec![Some(1)])]);
1846
1847        let size1 = size_after_batch(&[Arc::new(batch1)])?;
1848        let size2 = size_after_batch(&[Arc::new(batch2)])?;
1849        assert_eq!(size1, size2);
1850
1851        Ok(())
1852    }
1853}