datafusion_functions_aggregate/
first_last.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the FIRST_VALUE/LAST_VALUE aggregations.
19
20use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26    Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
27    PrimitiveArray,
28};
29use arrow::buffer::{BooleanBuffer, NullBuffer};
30use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
31use arrow::datatypes::{
32    DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type,
33    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
34    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
35    TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
36    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
37    UInt8Type,
38};
39use datafusion_common::cast::as_boolean_array;
40use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
41use datafusion_common::{
42    arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
43};
44use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
45use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
46use datafusion_expr::{
47    Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
48    GroupsAccumulator, Signature, SortExpr, Volatility,
49};
50use datafusion_functions_aggregate_common::utils::get_sort_options;
51use datafusion_macros::user_doc;
52use datafusion_physical_expr_common::sort_expr::LexOrdering;
53
// `create_func!` is defined elsewhere in the crate. The `first_value_udaf()` and
// `last_value_udaf()` functions called further down in this file come from these
// two invocations (presumably each returns a shared handle to the UDAF — confirm
// against the macro definition).
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
56
57/// Returns the first value in a group of values.
58pub fn first_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
59    if let Some(order_by) = order_by {
60        first_value_udaf()
61            .call(vec![expression])
62            .order_by(order_by)
63            .build()
64            // guaranteed to be `Expr::AggregateFunction`
65            .unwrap()
66    } else {
67        first_value_udaf().call(vec![expression])
68    }
69}
70
71/// Returns the last value in a group of values.
72pub fn last_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
73    if let Some(order_by) = order_by {
74        last_value_udaf()
75            .call(vec![expression])
76            .order_by(order_by)
77            .build()
78            // guaranteed to be `Expr::AggregateFunction`
79            .unwrap()
80    } else {
81        last_value_udaf().call(vec![expression])
82    }
83}
84
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "first_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| first_value(column_name ORDER BY other_column)|
+-----------------------------------------------+
| first_element                                 |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
// The `first_value` user-defined aggregate function.
pub struct FirstValue {
    // Accepted argument signature: exactly one argument of any type
    // (see `FirstValue::new`).
    signature: Signature,
    // Set by the caller (via `with_beneficial_ordering`) to signal that the input
    // already arrives in the required order; forwarded to the row accumulator.
    requirement_satisfied: bool,
}
103
104impl Debug for FirstValue {
105    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
106        f.debug_struct("FirstValue")
107            .field("name", &self.name())
108            .field("signature", &self.signature)
109            .field("accumulator", &"<FUNC>")
110            .finish()
111    }
112}
113
114impl Default for FirstValue {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120impl FirstValue {
121    pub fn new() -> Self {
122        Self {
123            signature: Signature::any(1, Volatility::Immutable),
124            requirement_satisfied: false,
125        }
126    }
127
128    fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
129        self.requirement_satisfied = requirement_satisfied;
130        self
131    }
132}
133
134impl AggregateUDFImpl for FirstValue {
    /// Exposes this value as `&dyn Any`, allowing callers to downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }
138
    /// Name of the aggregate function: `"first_value"`.
    fn name(&self) -> &str {
        "first_value"
    }
142
    /// Returns the signature stored at construction time.
    fn signature(&self) -> &Signature {
        &self.signature
    }
146
    /// The result has the same type as the (single) argument.
    ///
    /// Indexes `arg_types[0]` directly, so it panics on an empty slice; the
    /// one-argument signature built in `FirstValue::new` is presumably what
    /// guarantees this never happens — confirm against the planner.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        Ok(arg_types[0].clone())
    }
150
151    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
152        let ordering_dtypes = acc_args
153            .ordering_req
154            .iter()
155            .map(|e| e.expr.data_type(acc_args.schema))
156            .collect::<Result<Vec<_>>>()?;
157
158        // When requirement is empty, or it is signalled by outside caller that
159        // the ordering requirement is/will be satisfied.
160        let requirement_satisfied =
161            acc_args.ordering_req.is_empty() || self.requirement_satisfied;
162
163        FirstValueAccumulator::try_new(
164            acc_args.return_type,
165            &ordering_dtypes,
166            acc_args.ordering_req.clone(),
167            acc_args.ignore_nulls,
168        )
169        .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
170    }
171
172    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
173        let mut fields = vec![Field::new(
174            format_state_name(args.name, "first_value"),
175            args.return_type.clone(),
176            true,
177        )];
178        fields.extend(args.ordering_fields.to_vec());
179        fields.push(Field::new("is_set", DataType::Boolean, true));
180        Ok(fields)
181    }
182
183    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
184        // TODO: extract to function
185        use DataType::*;
186        matches!(
187            args.return_type,
188            Int8 | Int16
189                | Int32
190                | Int64
191                | UInt8
192                | UInt16
193                | UInt32
194                | UInt64
195                | Float16
196                | Float32
197                | Float64
198                | Decimal128(_, _)
199                | Decimal256(_, _)
200                | Date32
201                | Date64
202                | Time32(_)
203                | Time64(_)
204                | Timestamp(_, _)
205        )
206    }
207
208    fn create_groups_accumulator(
209        &self,
210        args: AccumulatorArgs,
211    ) -> Result<Box<dyn GroupsAccumulator>> {
212        // TODO: extract to function
213        fn create_accumulator<T>(
214            args: AccumulatorArgs,
215        ) -> Result<Box<dyn GroupsAccumulator>>
216        where
217            T: ArrowPrimitiveType + Send,
218        {
219            let ordering_dtypes = args
220                .ordering_req
221                .iter()
222                .map(|e| e.expr.data_type(args.schema))
223                .collect::<Result<Vec<_>>>()?;
224
225            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
226                args.ordering_req.clone(),
227                args.ignore_nulls,
228                args.return_type,
229                &ordering_dtypes,
230                true,
231            )?))
232        }
233
234        match args.return_type {
235            DataType::Int8 => create_accumulator::<Int8Type>(args),
236            DataType::Int16 => create_accumulator::<Int16Type>(args),
237            DataType::Int32 => create_accumulator::<Int32Type>(args),
238            DataType::Int64 => create_accumulator::<Int64Type>(args),
239            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
240            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
241            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
242            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
243            DataType::Float16 => create_accumulator::<Float16Type>(args),
244            DataType::Float32 => create_accumulator::<Float32Type>(args),
245            DataType::Float64 => create_accumulator::<Float64Type>(args),
246
247            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
248            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
249
250            DataType::Timestamp(TimeUnit::Second, _) => {
251                create_accumulator::<TimestampSecondType>(args)
252            }
253            DataType::Timestamp(TimeUnit::Millisecond, _) => {
254                create_accumulator::<TimestampMillisecondType>(args)
255            }
256            DataType::Timestamp(TimeUnit::Microsecond, _) => {
257                create_accumulator::<TimestampMicrosecondType>(args)
258            }
259            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
260                create_accumulator::<TimestampNanosecondType>(args)
261            }
262
263            DataType::Date32 => create_accumulator::<Date32Type>(args),
264            DataType::Date64 => create_accumulator::<Date64Type>(args),
265            DataType::Time32(TimeUnit::Second) => {
266                create_accumulator::<Time32SecondType>(args)
267            }
268            DataType::Time32(TimeUnit::Millisecond) => {
269                create_accumulator::<Time32MillisecondType>(args)
270            }
271
272            DataType::Time64(TimeUnit::Microsecond) => {
273                create_accumulator::<Time64MicrosecondType>(args)
274            }
275            DataType::Time64(TimeUnit::Nanosecond) => {
276                create_accumulator::<Time64NanosecondType>(args)
277            }
278
279            _ => {
280                internal_err!(
281                    "GroupsAccumulator not supported for first_value({})",
282                    args.return_type
283                )
284            }
285        }
286    }
287
    /// This function has no aliases; an empty slice is returned.
    fn aliases(&self) -> &[String] {
        &[]
    }
291
292    fn with_beneficial_ordering(
293        self: Arc<Self>,
294        beneficial_ordering: bool,
295    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
296        Ok(Some(Arc::new(
297            FirstValue::new().with_requirement_satisfied(beneficial_ordering),
298        )))
299    }
300
    /// Declares the aggregate's order sensitivity as `Beneficial`.
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Beneficial
    }
304
    /// The reversed counterpart of `first_value` is the `last_value` UDAF.
    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
        datafusion_expr::ReversedUDAF::Reversed(last_value_udaf())
    }
308
    /// Returns the result of `self.doc()` (presumably generated by the
    /// `#[user_doc]` attribute on the struct — confirm against the macro).
    fn documentation(&self) -> Option<&Documentation> {
        self.doc()
    }
312}
313
// TODO: rename to PrimitiveGroupsAccumulator
//
// Groups accumulator for primitive-typed values that keeps, per group, the value
// whose ordering row is the smallest (`pick_first_in_group == true`) or the
// largest (`pick_first_in_group == false`) under `ordering_req`.
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    // ================ state ===========
    // `vals[group_idx]` is the value currently selected for the group.
    vals: Vec<T::Native>,
    // Stores the ordering values (of the aggregator requirement) that belong to
    // the currently selected value of each group.
    // The `orderings` are stored row-wise, meaning that `orderings[group_idx]`
    // represents the ordering values corresponding to the `group_idx`-th group.
    orderings: Vec<Vec<ScalarValue>>,
    // At the beginning, `is_sets[group_idx]` is false, which means no value has
    // been seen for the group yet. Once a value is stored for the group, the
    // `is_sets[group_idx]` flag is set.
    is_sets: BooleanBufferBuilder,
    // null_builder[group_idx] == false => vals[group_idx] is null
    null_builder: BooleanBufferBuilder,
    // Cached size of `self.orderings`.
    // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly.
    // Therefore it is maintained incrementally on each update, so that `Self::size`
    // does not have to call `ScalarValue::size_of_vec` over every row.
    size_of_orderings: usize,

    // Scratch buffer for `get_filtered_min_of_each_group`:
    // min_of_each_group_buf.0[group_idx] -> idx_in_val
    // only valid if min_of_each_group_buf.1[group_idx] == true
    // TODO: rename to extreme_of_each_group_buf
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    // =========== option ============

    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // true: take first element in an aggregation group according to the requested ordering.
    // false: take last element in an aggregation group according to the requested ordering.
    pick_first_in_group: bool,
    // derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    // Stores whether incoming data already satisfies the ordering requirement.
    // `try_new` sets it to true only when `ordering_req` is empty.
    input_requirement_satisfied: bool,
    // Ignore null values.
    ignore_nulls: bool,
    /// The output type
    data_type: DataType,
    // Ordering row used to fill newly created groups in `resize_states`; built by
    // `ScalarValue::try_from` on each ordering data type (presumably typed
    // nulls — confirm).
    default_orderings: Vec<ScalarValue>,
}
360
361impl<T> FirstPrimitiveGroupsAccumulator<T>
362where
363    T: ArrowPrimitiveType + Send,
364{
365    fn try_new(
366        ordering_req: LexOrdering,
367        ignore_nulls: bool,
368        data_type: &DataType,
369        ordering_dtypes: &[DataType],
370        pick_first_in_group: bool,
371    ) -> Result<Self> {
372        let requirement_satisfied = ordering_req.is_empty();
373
374        let default_orderings = ordering_dtypes
375            .iter()
376            .map(ScalarValue::try_from)
377            .collect::<Result<Vec<_>>>()?;
378
379        let sort_options = get_sort_options(ordering_req.as_ref());
380
381        Ok(Self {
382            null_builder: BooleanBufferBuilder::new(0),
383            ordering_req,
384            sort_options,
385            input_requirement_satisfied: requirement_satisfied,
386            ignore_nulls,
387            default_orderings,
388            data_type: data_type.clone(),
389            vals: Vec::new(),
390            orderings: Vec::new(),
391            is_sets: BooleanBufferBuilder::new(0),
392            size_of_orderings: 0,
393            min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
394            pick_first_in_group,
395        })
396    }
397
398    fn need_update(&self, group_idx: usize) -> bool {
399        if !self.is_sets.get_bit(group_idx) {
400            return true;
401        }
402
403        if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
404            return true;
405        }
406
407        !self.input_requirement_satisfied
408    }
409
410    fn should_update_state(
411        &self,
412        group_idx: usize,
413        new_ordering_values: &[ScalarValue],
414    ) -> Result<bool> {
415        if !self.is_sets.get_bit(group_idx) {
416            return Ok(true);
417        }
418
419        assert!(new_ordering_values.len() == self.ordering_req.len());
420        let current_ordering = &self.orderings[group_idx];
421        compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
422            if self.pick_first_in_group {
423                x.is_gt()
424            } else {
425                x.is_lt()
426            }
427        })
428    }
429
430    fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
431        let result = emit_to.take_needed(&mut self.orderings);
432
433        match emit_to {
434            EmitTo::All => self.size_of_orderings = 0,
435            EmitTo::First(_) => {
436                self.size_of_orderings -=
437                    result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
438            }
439        }
440
441        result
442    }
443
444    fn take_need(
445        bool_buf_builder: &mut BooleanBufferBuilder,
446        emit_to: EmitTo,
447    ) -> BooleanBuffer {
448        let bool_buf = bool_buf_builder.finish();
449        match emit_to {
450            EmitTo::All => bool_buf,
451            EmitTo::First(n) => {
452                // split off the first N values in seen_values
453                //
454                // TODO make this more efficient rather than two
455                // copies and bitwise manipulation
456                let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
457                // reset the existing buffer
458                for b in bool_buf.iter().skip(n) {
459                    bool_buf_builder.append(b);
460                }
461                first_n
462            }
463        }
464    }
465
466    fn resize_states(&mut self, new_size: usize) {
467        self.vals.resize(new_size, T::default_value());
468
469        self.null_builder.resize(new_size);
470
471        if self.orderings.len() < new_size {
472            let current_len = self.orderings.len();
473
474            self.orderings
475                .resize(new_size, self.default_orderings.clone());
476
477            self.size_of_orderings += (new_size - current_len)
478                * ScalarValue::size_of_vec(
479                    // Note: In some cases (such as in the unit test below)
480                    // ScalarValue::size_of_vec(&self.default_orderings) != ScalarValue::size_of_vec(&self.default_orderings.clone())
481                    // This may be caused by the different vec.capacity() values?
482                    self.orderings.last().unwrap(),
483                );
484        }
485
486        self.is_sets.resize(new_size);
487
488        self.min_of_each_group_buf.0.resize(new_size, 0);
489        self.min_of_each_group_buf.1.resize(new_size);
490    }
491
492    fn update_state(
493        &mut self,
494        group_idx: usize,
495        orderings: &[ScalarValue],
496        new_val: T::Native,
497        is_null: bool,
498    ) {
499        self.vals[group_idx] = new_val;
500        self.is_sets.set_bit(group_idx, true);
501
502        self.null_builder.set_bit(group_idx, !is_null);
503
504        assert!(orderings.len() == self.ordering_req.len());
505        let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
506        self.orderings[group_idx].clear();
507        self.orderings[group_idx].extend_from_slice(orderings);
508        let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
509        self.size_of_orderings = self.size_of_orderings - old_size + new_size;
510    }
511
512    fn take_state(
513        &mut self,
514        emit_to: EmitTo,
515    ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
516        emit_to.take_needed(&mut self.min_of_each_group_buf.0);
517        self.min_of_each_group_buf
518            .1
519            .truncate(self.min_of_each_group_buf.0.len());
520
521        (
522            self.take_vals_and_null_buf(emit_to),
523            self.take_orderings(emit_to),
524            Self::take_need(&mut self.is_sets, emit_to),
525        )
526    }
527
528    // should be used in test only
529    #[cfg(test)]
530    fn compute_size_of_orderings(&self) -> usize {
531        self.orderings
532            .iter()
533            .map(ScalarValue::size_of_vec)
534            .sum::<usize>()
535    }
    /// Returns a vector of tuples `(group_idx, idx_in_val)` giving, for each group
    /// that received at least one eligible row in this batch, the row index of its
    /// extreme ordering row under lexicographical comparison of `orderings`:
    /// the minimum when `pick_first_in_group` is true, the maximum otherwise.
    ///
    /// A row is skipped when `opt_filter` or `is_set_arr` (if provided) reads
    /// false at that row, when `need_update` says the group cannot change any
    /// more, or when `ignore_nulls` is set and the value at that row is null.
    ///
    /// Returns an error if the comparator cannot be built. Panics if `orderings`
    /// does not have one array per entry of `ordering_req`.
    /// TODO: rename to get_filtered_extreme_of_each_group
    fn get_filtered_min_of_each_group(
        &mut self,
        orderings: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        vals: &PrimitiveArray<T>,
        is_set_arr: Option<&BooleanArray>,
    ) -> Result<Vec<(usize, usize)>> {
        // Set all values in min_of_each_group_buf.1 to false, one flag per group.
        self.min_of_each_group_buf.1.truncate(0);
        self.min_of_each_group_buf
            .1
            .append_n(self.vals.len(), false);

        // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]`
        // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`.

        // Comparator over row indices of this batch's ordering columns.
        let comparator = {
            assert_eq!(orderings.len(), self.ordering_req.len());
            let sort_columns = orderings
                .iter()
                .zip(self.ordering_req.iter())
                .map(|(array, req)| SortColumn {
                    values: Arc::clone(array),
                    options: Some(req.options),
                })
                .collect::<Vec<_>>();

            LexicographicalComparator::try_new(&sort_columns)?
        };

        for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
            let group_idx = *group_idx;

            // NOTE(review): `value()` does not consult the validity bitmap, so a
            // null slot in either boolean array reads as whatever the underlying
            // bit is — confirm callers never pass nulls here.
            let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));

            let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));

            if !passed_filter || !is_set {
                continue;
            }

            if !self.need_update(group_idx) {
                continue;
            }

            if self.ignore_nulls && vals.is_null(idx_in_val) {
                continue;
            }

            let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);

            if !is_valid {
                // First eligible row of this group in the batch.
                self.min_of_each_group_buf.1.set_bit(group_idx, true);
                self.min_of_each_group_buf.0[group_idx] = idx_in_val;
            } else {
                let ordering = comparator
                    .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);

                // Replace only on a strict improvement, so on ties the earliest
                // row in the batch is kept.
                if (ordering.is_gt() && self.pick_first_in_group)
                    || (ordering.is_lt() && !self.pick_first_in_group)
                {
                    self.min_of_each_group_buf.0[group_idx] = idx_in_val;
                }
            }
        }

        // Collect the groups whose scratch entry was written during this call.
        Ok(self
            .min_of_each_group_buf
            .0
            .iter()
            .enumerate()
            .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
            .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
            .collect::<Vec<_>>())
    }
616
617    fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
618        let r = emit_to.take_needed(&mut self.vals);
619
620        let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
621
622        let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) // no copy
623            .with_data_type(self.data_type.clone());
624        Arc::new(values)
625    }
626}
627
628impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
629where
630    T: ArrowPrimitiveType + Send,
631{
632    fn update_batch(
633        &mut self,
634        // e.g. first_value(a order by b): values_and_order_cols will be [a, b]
635        values_and_order_cols: &[ArrayRef],
636        group_indices: &[usize],
637        opt_filter: Option<&BooleanArray>,
638        total_num_groups: usize,
639    ) -> Result<()> {
640        self.resize_states(total_num_groups);
641
642        let vals = values_and_order_cols[0].as_primitive::<T>();
643
644        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
645
646        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
647        for (group_idx, idx) in self
648            .get_filtered_min_of_each_group(
649                &values_and_order_cols[1..],
650                group_indices,
651                opt_filter,
652                vals,
653                None,
654            )?
655            .into_iter()
656        {
657            extract_row_at_idx_to_buf(
658                &values_and_order_cols[1..],
659                idx,
660                &mut ordering_buf,
661            )?;
662
663            if self.should_update_state(group_idx, &ordering_buf)? {
664                self.update_state(
665                    group_idx,
666                    &ordering_buf,
667                    vals.value(idx),
668                    vals.is_null(idx),
669                );
670            }
671        }
672
673        Ok(())
674    }
675
676    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
677        Ok(self.take_state(emit_to).0)
678    }
679
680    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
681        let (val_arr, orderings, is_sets) = self.take_state(emit_to);
682        let mut result = Vec::with_capacity(self.orderings.len() + 2);
683
684        result.push(val_arr);
685
686        let ordering_cols = {
687            let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
688            for _ in 0..self.ordering_req.len() {
689                ordering_cols.push(Vec::with_capacity(self.orderings.len()));
690            }
691            for row in orderings.into_iter() {
692                assert_eq!(row.len(), self.ordering_req.len());
693                for (col_idx, ordering) in row.into_iter().enumerate() {
694                    ordering_cols[col_idx].push(ordering);
695                }
696            }
697
698            ordering_cols
699        };
700        for ordering_col in ordering_cols {
701            result.push(ScalarValue::iter_to_array(ordering_col)?);
702        }
703
704        result.push(Arc::new(BooleanArray::new(is_sets, None)));
705
706        Ok(result)
707    }
708
709    fn merge_batch(
710        &mut self,
711        values: &[ArrayRef],
712        group_indices: &[usize],
713        opt_filter: Option<&BooleanArray>,
714        total_num_groups: usize,
715    ) -> Result<()> {
716        self.resize_states(total_num_groups);
717
718        let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
719
720        let (is_set_arr, val_and_order_cols) = match values.split_last() {
721            Some(result) => result,
722            None => return internal_err!("Empty row in FISRT_VALUE"),
723        };
724
725        let is_set_arr = as_boolean_array(is_set_arr)?;
726
727        let vals = values[0].as_primitive::<T>();
728        // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible.
729        let groups = self.get_filtered_min_of_each_group(
730            &val_and_order_cols[1..],
731            group_indices,
732            opt_filter,
733            vals,
734            Some(is_set_arr),
735        )?;
736
737        for (group_idx, idx) in groups.into_iter() {
738            extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
739
740            if self.should_update_state(group_idx, &ordering_buf)? {
741                self.update_state(
742                    group_idx,
743                    &ordering_buf,
744                    vals.value(idx),
745                    vals.is_null(idx),
746                );
747            }
748        }
749
750        Ok(())
751    }
752
753    fn size(&self) -> usize {
754        self.vals.capacity() * size_of::<T::Native>()
755            + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes 
756            + self.is_sets.capacity() / 8
757            + self.size_of_orderings
758            + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
759            + self.min_of_each_group_buf.1.capacity() / 8
760    }
761
    /// Always true: this accumulator implements `convert_to_state`.
    fn supports_convert_to_state(&self) -> bool {
        true
    }
765
766    fn convert_to_state(
767        &self,
768        values: &[ArrayRef],
769        opt_filter: Option<&BooleanArray>,
770    ) -> Result<Vec<ArrayRef>> {
771        let mut result = values.to_vec();
772        match opt_filter {
773            Some(f) => {
774                result.push(Arc::new(f.clone()));
775                Ok(result)
776            }
777            None => {
778                result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
779                Ok(result)
780            }
781        }
782    }
783}
// Row-based accumulator for `first_value`: tracks a single "first" value together
// with the ordering values it was selected under.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    // The currently selected first value.
    first: ScalarValue,
    // At the beginning, `is_set` is false, which means `first` is not seen yet.
    // Once we see the first value, we set the `is_set` flag. When the ordering
    // requirement is already satisfied, `first` is not updated by later batches;
    // otherwise `update_batch` still compares later candidates against it.
    is_set: bool,
    // Stores ordering values, of the aggregator requirement corresponding to first value
    // of the aggregator. These values are used during merging of multiple partitions.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Stores whether incoming data already satisfies the ordering requirement.
    requirement_satisfied: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
800
801impl FirstValueAccumulator {
802    /// Creates a new `FirstValueAccumulator` for the given `data_type`.
803    pub fn try_new(
804        data_type: &DataType,
805        ordering_dtypes: &[DataType],
806        ordering_req: LexOrdering,
807        ignore_nulls: bool,
808    ) -> Result<Self> {
809        let orderings = ordering_dtypes
810            .iter()
811            .map(ScalarValue::try_from)
812            .collect::<Result<Vec<_>>>()?;
813        let requirement_satisfied = ordering_req.is_empty();
814        ScalarValue::try_from(data_type).map(|first| Self {
815            first,
816            is_set: false,
817            orderings,
818            ordering_req,
819            requirement_satisfied,
820            ignore_nulls,
821        })
822    }
823
824    pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
825        self.requirement_satisfied = requirement_satisfied;
826        self
827    }
828
829    // Updates state with the values in the given row.
830    fn update_with_new_row(&mut self, row: &[ScalarValue]) {
831        self.first = row[0].clone();
832        self.orderings = row[1..].to_vec();
833        self.is_set = true;
834    }
835
836    fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
837        let [value, ordering_values @ ..] = values else {
838            return internal_err!("Empty row in FIRST_VALUE");
839        };
840        if self.requirement_satisfied {
841            // Get first entry according to the pre-existing ordering (0th index):
842            if self.ignore_nulls {
843                // If ignoring nulls, find the first non-null value.
844                for i in 0..value.len() {
845                    if !value.is_null(i) {
846                        return Ok(Some(i));
847                    }
848                }
849                return Ok(None);
850            } else {
851                // If not ignoring nulls, return the first value if it exists.
852                return Ok((!value.is_empty()).then_some(0));
853            }
854        }
855
856        let sort_columns = ordering_values
857            .iter()
858            .zip(self.ordering_req.iter())
859            .map(|(values, req)| SortColumn {
860                values: Arc::clone(values),
861                options: Some(req.options),
862            })
863            .collect::<Vec<_>>();
864
865        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
866
867        let min_index = if self.ignore_nulls {
868            (0..value.len())
869                .filter(|&index| !value.is_null(index))
870                .min_by(|&a, &b| comparator.compare(a, b))
871        } else {
872            (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
873        };
874
875        Ok(min_index)
876    }
877}
878
879impl Accumulator for FirstValueAccumulator {
880    fn state(&mut self) -> Result<Vec<ScalarValue>> {
881        let mut result = vec![self.first.clone()];
882        result.extend(self.orderings.iter().cloned());
883        result.push(ScalarValue::Boolean(Some(self.is_set)));
884        Ok(result)
885    }
886
887    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
888        if !self.is_set {
889            if let Some(first_idx) = self.get_first_idx(values)? {
890                let row = get_row_at_idx(values, first_idx)?;
891                self.update_with_new_row(&row);
892            }
893        } else if !self.requirement_satisfied {
894            if let Some(first_idx) = self.get_first_idx(values)? {
895                let row = get_row_at_idx(values, first_idx)?;
896                let orderings = &row[1..];
897                if compare_rows(
898                    &self.orderings,
899                    orderings,
900                    &get_sort_options(self.ordering_req.as_ref()),
901                )?
902                .is_gt()
903                {
904                    self.update_with_new_row(&row);
905                }
906            }
907        }
908        Ok(())
909    }
910
911    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
912        // FIRST_VALUE(first1, first2, first3, ...)
913        // last index contains is_set flag.
914        let is_set_idx = states.len() - 1;
915        let flags = states[is_set_idx].as_boolean();
916        let filtered_states =
917            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
918        // 1..is_set_idx range corresponds to ordering section
919        let sort_columns = convert_to_sort_cols(
920            &filtered_states[1..is_set_idx],
921            self.ordering_req.as_ref(),
922        );
923
924        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
925        let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
926
927        if let Some(first_idx) = min {
928            let first_row = get_row_at_idx(&filtered_states, first_idx)?;
929            // When collecting orderings, we exclude the is_set flag from the state.
930            let first_ordering = &first_row[1..is_set_idx];
931            let sort_options = get_sort_options(self.ordering_req.as_ref());
932            // Either there is no existing value, or there is an earlier version in new data.
933            if !self.is_set
934                || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
935            {
936                // Update with first value in the state. Note that we should exclude the
937                // is_set flag from the state. Otherwise, we will end up with a state
938                // containing two is_set flags.
939                self.update_with_new_row(&first_row[0..is_set_idx]);
940            }
941        }
942        Ok(())
943    }
944
945    fn evaluate(&mut self) -> Result<ScalarValue> {
946        Ok(self.first.clone())
947    }
948
949    fn size(&self) -> usize {
950        size_of_val(self) - size_of_val(&self.first)
951            + self.first.size()
952            + ScalarValue::size_of_vec(&self.orderings)
953            - size_of_val(&self.orderings)
954    }
955}
956
// User-defined aggregate implementing `last_value`. The `user_doc` attribute
// below carries the user-facing documentation (section, description, syntax
// and SQL example).
#[user_doc(
    doc_section(label = "General Functions"),
    description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
    syntax_example = "last_value(expression [ORDER BY expression])",
    sql_example = r#"```sql
> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
+-----------------------------------------------+
| last_value(column_name ORDER BY other_column) |
+-----------------------------------------------+
| last_element                                  |
+-----------------------------------------------+
```"#,
    standard_argument(name = "expression",)
)]
pub struct LastValue {
    // Argument signature; `new()` builds it with
    // `Signature::any(1, Volatility::Immutable)`.
    signature: Signature,
    // Whether the input is already known to satisfy the ordering requirement.
    // Starts as `false` and is set through `with_requirement_satisfied`;
    // `accumulator()` forwards it (OR-ed with "no ordering requested") to the
    // accumulator it creates.
    requirement_satisfied: bool,
}
975
976impl Debug for LastValue {
977    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
978        f.debug_struct("LastValue")
979            .field("name", &self.name())
980            .field("signature", &self.signature)
981            .field("accumulator", &"<FUNC>")
982            .finish()
983    }
984}
985
986impl Default for LastValue {
987    fn default() -> Self {
988        Self::new()
989    }
990}
991
992impl LastValue {
993    pub fn new() -> Self {
994        Self {
995            signature: Signature::any(1, Volatility::Immutable),
996            requirement_satisfied: false,
997        }
998    }
999
1000    fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
1001        self.requirement_satisfied = requirement_satisfied;
1002        self
1003    }
1004}
1005
1006impl AggregateUDFImpl for LastValue {
1007    fn as_any(&self) -> &dyn Any {
1008        self
1009    }
1010
1011    fn name(&self) -> &str {
1012        "last_value"
1013    }
1014
1015    fn signature(&self) -> &Signature {
1016        &self.signature
1017    }
1018
1019    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1020        Ok(arg_types[0].clone())
1021    }
1022
1023    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1024        let ordering_dtypes = acc_args
1025            .ordering_req
1026            .iter()
1027            .map(|e| e.expr.data_type(acc_args.schema))
1028            .collect::<Result<Vec<_>>>()?;
1029
1030        let requirement_satisfied =
1031            acc_args.ordering_req.is_empty() || self.requirement_satisfied;
1032
1033        LastValueAccumulator::try_new(
1034            acc_args.return_type,
1035            &ordering_dtypes,
1036            acc_args.ordering_req.clone(),
1037            acc_args.ignore_nulls,
1038        )
1039        .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
1040    }
1041
1042    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
1043        let StateFieldsArgs {
1044            name,
1045            input_types,
1046            return_type: _,
1047            ordering_fields,
1048            is_distinct: _,
1049        } = args;
1050        let mut fields = vec![Field::new(
1051            format_state_name(name, "last_value"),
1052            input_types[0].clone(),
1053            true,
1054        )];
1055        fields.extend(ordering_fields.to_vec());
1056        fields.push(Field::new("is_set", DataType::Boolean, true));
1057        Ok(fields)
1058    }
1059
1060    fn aliases(&self) -> &[String] {
1061        &[]
1062    }
1063
1064    fn with_beneficial_ordering(
1065        self: Arc<Self>,
1066        beneficial_ordering: bool,
1067    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1068        Ok(Some(Arc::new(
1069            LastValue::new().with_requirement_satisfied(beneficial_ordering),
1070        )))
1071    }
1072
1073    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1074        AggregateOrderSensitivity::Beneficial
1075    }
1076
1077    fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
1078        datafusion_expr::ReversedUDAF::Reversed(first_value_udaf())
1079    }
1080
1081    fn documentation(&self) -> Option<&Documentation> {
1082        self.doc()
1083    }
1084
1085    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1086        use DataType::*;
1087        matches!(
1088            args.return_type,
1089            Int8 | Int16
1090                | Int32
1091                | Int64
1092                | UInt8
1093                | UInt16
1094                | UInt32
1095                | UInt64
1096                | Float16
1097                | Float32
1098                | Float64
1099                | Decimal128(_, _)
1100                | Decimal256(_, _)
1101                | Date32
1102                | Date64
1103                | Time32(_)
1104                | Time64(_)
1105                | Timestamp(_, _)
1106        )
1107    }
1108
1109    fn create_groups_accumulator(
1110        &self,
1111        args: AccumulatorArgs,
1112    ) -> Result<Box<dyn GroupsAccumulator>> {
1113        fn create_accumulator<T>(
1114            args: AccumulatorArgs,
1115        ) -> Result<Box<dyn GroupsAccumulator>>
1116        where
1117            T: ArrowPrimitiveType + Send,
1118        {
1119            let ordering_dtypes = args
1120                .ordering_req
1121                .iter()
1122                .map(|e| e.expr.data_type(args.schema))
1123                .collect::<Result<Vec<_>>>()?;
1124
1125            Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
1126                args.ordering_req.clone(),
1127                args.ignore_nulls,
1128                args.return_type,
1129                &ordering_dtypes,
1130                false,
1131            )?))
1132        }
1133
1134        match args.return_type {
1135            DataType::Int8 => create_accumulator::<Int8Type>(args),
1136            DataType::Int16 => create_accumulator::<Int16Type>(args),
1137            DataType::Int32 => create_accumulator::<Int32Type>(args),
1138            DataType::Int64 => create_accumulator::<Int64Type>(args),
1139            DataType::UInt8 => create_accumulator::<UInt8Type>(args),
1140            DataType::UInt16 => create_accumulator::<UInt16Type>(args),
1141            DataType::UInt32 => create_accumulator::<UInt32Type>(args),
1142            DataType::UInt64 => create_accumulator::<UInt64Type>(args),
1143            DataType::Float16 => create_accumulator::<Float16Type>(args),
1144            DataType::Float32 => create_accumulator::<Float32Type>(args),
1145            DataType::Float64 => create_accumulator::<Float64Type>(args),
1146
1147            DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
1148            DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
1149
1150            DataType::Timestamp(TimeUnit::Second, _) => {
1151                create_accumulator::<TimestampSecondType>(args)
1152            }
1153            DataType::Timestamp(TimeUnit::Millisecond, _) => {
1154                create_accumulator::<TimestampMillisecondType>(args)
1155            }
1156            DataType::Timestamp(TimeUnit::Microsecond, _) => {
1157                create_accumulator::<TimestampMicrosecondType>(args)
1158            }
1159            DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1160                create_accumulator::<TimestampNanosecondType>(args)
1161            }
1162
1163            DataType::Date32 => create_accumulator::<Date32Type>(args),
1164            DataType::Date64 => create_accumulator::<Date64Type>(args),
1165            DataType::Time32(TimeUnit::Second) => {
1166                create_accumulator::<Time32SecondType>(args)
1167            }
1168            DataType::Time32(TimeUnit::Millisecond) => {
1169                create_accumulator::<Time32MillisecondType>(args)
1170            }
1171
1172            DataType::Time64(TimeUnit::Microsecond) => {
1173                create_accumulator::<Time64MicrosecondType>(args)
1174            }
1175            DataType::Time64(TimeUnit::Nanosecond) => {
1176                create_accumulator::<Time64NanosecondType>(args)
1177            }
1178
1179            _ => {
1180                internal_err!(
1181                    "GroupsAccumulator not supported for last_value({})",
1182                    args.return_type
1183                )
1184            }
1185        }
1186    }
1187}
1188
// Row-based accumulator behind `last_value`: it keeps a single candidate row
// (value plus ordering keys) and replaces it as later rows are seen.
#[derive(Debug)]
struct LastValueAccumulator {
    // The current candidate value; `evaluate` returns a clone of it.
    last: ScalarValue,
    // The `is_set` flag keeps track of whether the last value is finalized.
    // This information is used to discriminate genuine NULLs and NULLS that
    // occur due to empty partitions. It becomes `true` the first time a row
    // is stored, and it is emitted as the final column of the state so that
    // `merge_batch` can filter out states that never stored a row.
    is_set: bool,
    // Ordering-key values of the row currently stored in `last`.
    orderings: Vec<ScalarValue>,
    // Stores the applicable ordering requirement.
    ordering_req: LexOrdering,
    // Stores whether incoming data already satisfies the ordering requirement.
    requirement_satisfied: bool,
    // Ignore null values.
    ignore_nulls: bool,
}
1204
1205impl LastValueAccumulator {
1206    /// Creates a new `LastValueAccumulator` for the given `data_type`.
1207    pub fn try_new(
1208        data_type: &DataType,
1209        ordering_dtypes: &[DataType],
1210        ordering_req: LexOrdering,
1211        ignore_nulls: bool,
1212    ) -> Result<Self> {
1213        let orderings = ordering_dtypes
1214            .iter()
1215            .map(ScalarValue::try_from)
1216            .collect::<Result<Vec<_>>>()?;
1217        let requirement_satisfied = ordering_req.is_empty();
1218        ScalarValue::try_from(data_type).map(|last| Self {
1219            last,
1220            is_set: false,
1221            orderings,
1222            ordering_req,
1223            requirement_satisfied,
1224            ignore_nulls,
1225        })
1226    }
1227
1228    // Updates state with the values in the given row.
1229    fn update_with_new_row(&mut self, row: &[ScalarValue]) {
1230        self.last = row[0].clone();
1231        self.orderings = row[1..].to_vec();
1232        self.is_set = true;
1233    }
1234
1235    fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1236        let [value, ordering_values @ ..] = values else {
1237            return internal_err!("Empty row in LAST_VALUE");
1238        };
1239        if self.requirement_satisfied {
1240            // Get last entry according to the order of data:
1241            if self.ignore_nulls {
1242                // If ignoring nulls, find the last non-null value.
1243                for i in (0..value.len()).rev() {
1244                    if !value.is_null(i) {
1245                        return Ok(Some(i));
1246                    }
1247                }
1248                return Ok(None);
1249            } else {
1250                return Ok((!value.is_empty()).then_some(value.len() - 1));
1251            }
1252        }
1253        let sort_columns = ordering_values
1254            .iter()
1255            .zip(self.ordering_req.iter())
1256            .map(|(values, req)| SortColumn {
1257                values: Arc::clone(values),
1258                options: Some(req.options),
1259            })
1260            .collect::<Vec<_>>();
1261
1262        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1263        let max_ind = if self.ignore_nulls {
1264            (0..value.len())
1265                .filter(|&index| !(value.is_null(index)))
1266                .max_by(|&a, &b| comparator.compare(a, b))
1267        } else {
1268            (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1269        };
1270
1271        Ok(max_ind)
1272    }
1273
1274    fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
1275        self.requirement_satisfied = requirement_satisfied;
1276        self
1277    }
1278}
1279
1280impl Accumulator for LastValueAccumulator {
1281    fn state(&mut self) -> Result<Vec<ScalarValue>> {
1282        let mut result = vec![self.last.clone()];
1283        result.extend(self.orderings.clone());
1284        result.push(ScalarValue::Boolean(Some(self.is_set)));
1285        Ok(result)
1286    }
1287
1288    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1289        if !self.is_set || self.requirement_satisfied {
1290            if let Some(last_idx) = self.get_last_idx(values)? {
1291                let row = get_row_at_idx(values, last_idx)?;
1292                self.update_with_new_row(&row);
1293            }
1294        } else if let Some(last_idx) = self.get_last_idx(values)? {
1295            let row = get_row_at_idx(values, last_idx)?;
1296            let orderings = &row[1..];
1297            // Update when there is a more recent entry
1298            if compare_rows(
1299                &self.orderings,
1300                orderings,
1301                &get_sort_options(self.ordering_req.as_ref()),
1302            )?
1303            .is_lt()
1304            {
1305                self.update_with_new_row(&row);
1306            }
1307        }
1308
1309        Ok(())
1310    }
1311
1312    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1313        // LAST_VALUE(last1, last2, last3, ...)
1314        // last index contains is_set flag.
1315        let is_set_idx = states.len() - 1;
1316        let flags = states[is_set_idx].as_boolean();
1317        let filtered_states =
1318            filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1319        // 1..is_set_idx range corresponds to ordering section
1320        let sort_columns = convert_to_sort_cols(
1321            &filtered_states[1..is_set_idx],
1322            self.ordering_req.as_ref(),
1323        );
1324
1325        let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1326        let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1327
1328        if let Some(last_idx) = max {
1329            let last_row = get_row_at_idx(&filtered_states, last_idx)?;
1330            // When collecting orderings, we exclude the is_set flag from the state.
1331            let last_ordering = &last_row[1..is_set_idx];
1332            let sort_options = get_sort_options(self.ordering_req.as_ref());
1333            // Either there is no existing value, or there is a newer (latest)
1334            // version in the new data:
1335            if !self.is_set
1336                || self.requirement_satisfied
1337                || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
1338            {
1339                // Update with last value in the state. Note that we should exclude the
1340                // is_set flag from the state. Otherwise, we will end up with a state
1341                // containing two is_set flags.
1342                self.update_with_new_row(&last_row[0..is_set_idx]);
1343            }
1344        }
1345        Ok(())
1346    }
1347
1348    fn evaluate(&mut self) -> Result<ScalarValue> {
1349        Ok(self.last.clone())
1350    }
1351
1352    fn size(&self) -> usize {
1353        size_of_val(self) - size_of_val(&self.last)
1354            + self.last.size()
1355            + ScalarValue::size_of_vec(&self.orderings)
1356            - size_of_val(&self.orderings)
1357    }
1358}
1359
1360/// Filters states according to the `is_set` flag at the last column and returns
1361/// the resulting states.
1362fn filter_states_according_to_is_set(
1363    states: &[ArrayRef],
1364    flags: &BooleanArray,
1365) -> Result<Vec<ArrayRef>> {
1366    states
1367        .iter()
1368        .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1369        .collect::<Result<Vec<_>>>()
1370}
1371
1372/// Combines array refs and their corresponding orderings to construct `SortColumn`s.
1373fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1374    arrs.iter()
1375        .zip(sort_exprs.iter())
1376        .map(|(item, sort_expr)| SortColumn {
1377            values: Arc::clone(item),
1378            options: Some(sort_expr.options),
1379        })
1380        .collect::<Vec<_>>()
1381}
1382
1383#[cfg(test)]
1384mod tests {
1385    use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema};
1386    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
1387
1388    use super::*;
1389
1390    #[test]
1391    fn test_first_last_value_value() -> Result<()> {
1392        let mut first_accumulator = FirstValueAccumulator::try_new(
1393            &DataType::Int64,
1394            &[],
1395            LexOrdering::default(),
1396            false,
1397        )?;
1398        let mut last_accumulator = LastValueAccumulator::try_new(
1399            &DataType::Int64,
1400            &[],
1401            LexOrdering::default(),
1402            false,
1403        )?;
1404        // first value in the tuple is start of the range (inclusive),
1405        // second value in the tuple is end of the range (exclusive)
1406        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1407        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1408        let arrs = ranges
1409            .into_iter()
1410            .map(|(start, end)| {
1411                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
1412            })
1413            .collect::<Vec<_>>();
1414        for arr in arrs {
1415            // Once first_value is set, accumulator should remember it.
1416            // It shouldn't update first_value for each new batch
1417            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
1418            // last_value should be updated for each new batch.
1419            last_accumulator.update_batch(&[arr])?;
1420        }
1421        // First Value comes from the first value of the first batch which is 0
1422        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
1423        // Last value comes from the last value of the last batch which is 12
1424        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
1425        Ok(())
1426    }
1427
1428    #[test]
1429    fn test_first_last_state_after_merge() -> Result<()> {
1430        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
1431        // create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
1432        let arrs = ranges
1433            .into_iter()
1434            .map(|(start, end)| {
1435                Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
1436            })
1437            .collect::<Vec<_>>();
1438
1439        // FirstValueAccumulator
1440        let mut first_accumulator = FirstValueAccumulator::try_new(
1441            &DataType::Int64,
1442            &[],
1443            LexOrdering::default(),
1444            false,
1445        )?;
1446
1447        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1448        let state1 = first_accumulator.state()?;
1449
1450        let mut first_accumulator = FirstValueAccumulator::try_new(
1451            &DataType::Int64,
1452            &[],
1453            LexOrdering::default(),
1454            false,
1455        )?;
1456        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1457        let state2 = first_accumulator.state()?;
1458
1459        assert_eq!(state1.len(), state2.len());
1460
1461        let mut states = vec![];
1462
1463        for idx in 0..state1.len() {
1464            states.push(compute::concat(&[
1465                &state1[idx].to_array()?,
1466                &state2[idx].to_array()?,
1467            ])?);
1468        }
1469
1470        let mut first_accumulator = FirstValueAccumulator::try_new(
1471            &DataType::Int64,
1472            &[],
1473            LexOrdering::default(),
1474            false,
1475        )?;
1476        first_accumulator.merge_batch(&states)?;
1477
1478        let merged_state = first_accumulator.state()?;
1479        assert_eq!(merged_state.len(), state1.len());
1480
1481        // LastValueAccumulator
1482        let mut last_accumulator = LastValueAccumulator::try_new(
1483            &DataType::Int64,
1484            &[],
1485            LexOrdering::default(),
1486            false,
1487        )?;
1488
1489        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
1490        let state1 = last_accumulator.state()?;
1491
1492        let mut last_accumulator = LastValueAccumulator::try_new(
1493            &DataType::Int64,
1494            &[],
1495            LexOrdering::default(),
1496            false,
1497        )?;
1498        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
1499        let state2 = last_accumulator.state()?;
1500
1501        assert_eq!(state1.len(), state2.len());
1502
1503        let mut states = vec![];
1504
1505        for idx in 0..state1.len() {
1506            states.push(compute::concat(&[
1507                &state1[idx].to_array()?,
1508                &state2[idx].to_array()?,
1509            ])?);
1510        }
1511
1512        let mut last_accumulator = LastValueAccumulator::try_new(
1513            &DataType::Int64,
1514            &[],
1515            LexOrdering::default(),
1516            false,
1517        )?;
1518        last_accumulator.merge_batch(&states)?;
1519
1520        let merged_state = last_accumulator.state()?;
1521        assert_eq!(merged_state.len(), state1.len());
1522
1523        Ok(())
1524    }
1525
1526    #[test]
1527    fn test_frist_group_acc() -> Result<()> {
1528        let schema = Arc::new(Schema::new(vec![
1529            Field::new("a", DataType::Int64, true),
1530            Field::new("b", DataType::Int64, true),
1531            Field::new("c", DataType::Int64, true),
1532            Field::new("d", DataType::Int32, true),
1533            Field::new("e", DataType::Boolean, true),
1534        ]));
1535
1536        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1537            expr: col("c", &schema).unwrap(),
1538            options: SortOptions::default(),
1539        }]);
1540
1541        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1542            sort_key,
1543            true,
1544            &DataType::Int64,
1545            &[DataType::Int64],
1546            true,
1547        )?;
1548
1549        let mut val_with_orderings = {
1550            let mut val_with_orderings = Vec::<ArrayRef>::new();
1551
1552            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1553            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1554
1555            val_with_orderings.push(vals);
1556            val_with_orderings.push(orderings);
1557
1558            val_with_orderings
1559        };
1560
1561        group_acc.update_batch(
1562            &val_with_orderings,
1563            &[0, 1, 2, 1],
1564            Some(&BooleanArray::from(vec![true, true, false, true])),
1565            3,
1566        )?;
1567        assert_eq!(
1568            group_acc.size_of_orderings,
1569            group_acc.compute_size_of_orderings()
1570        );
1571
1572        let state = group_acc.state(EmitTo::All)?;
1573
1574        let expected_state: Vec<Arc<dyn Array>> = vec![
1575            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1576            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1577            Arc::new(BooleanArray::from(vec![true, true, false])),
1578        ];
1579        assert_eq!(state, expected_state);
1580
1581        assert_eq!(
1582            group_acc.size_of_orderings,
1583            group_acc.compute_size_of_orderings()
1584        );
1585
1586        group_acc.merge_batch(
1587            &state,
1588            &[0, 1, 2],
1589            Some(&BooleanArray::from(vec![true, false, false])),
1590            3,
1591        )?;
1592
1593        assert_eq!(
1594            group_acc.size_of_orderings,
1595            group_acc.compute_size_of_orderings()
1596        );
1597
1598        val_with_orderings.clear();
1599        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1600        val_with_orderings.push(Arc::new(Int64Array::from(vec![6, 6])));
1601
1602        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1603
1604        let binding = group_acc.evaluate(EmitTo::All)?;
1605        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1606
1607        let expect: PrimitiveArray<Int64Type> =
1608            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
1609
1610        assert_eq!(eval_result, &expect);
1611
1612        assert_eq!(
1613            group_acc.size_of_orderings,
1614            group_acc.compute_size_of_orderings()
1615        );
1616
1617        Ok(())
1618    }
1619
1620    #[test]
1621    fn test_group_acc_size_of_ordering() -> Result<()> {
1622        let schema = Arc::new(Schema::new(vec![
1623            Field::new("a", DataType::Int64, true),
1624            Field::new("b", DataType::Int64, true),
1625            Field::new("c", DataType::Int64, true),
1626            Field::new("d", DataType::Int32, true),
1627            Field::new("e", DataType::Boolean, true),
1628        ]));
1629
1630        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1631            expr: col("c", &schema).unwrap(),
1632            options: SortOptions::default(),
1633        }]);
1634
1635        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1636            sort_key,
1637            true,
1638            &DataType::Int64,
1639            &[DataType::Int64],
1640            true,
1641        )?;
1642
1643        let val_with_orderings = {
1644            let mut val_with_orderings = Vec::<ArrayRef>::new();
1645
1646            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1647            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1648
1649            val_with_orderings.push(vals);
1650            val_with_orderings.push(orderings);
1651
1652            val_with_orderings
1653        };
1654
1655        for _ in 0..10 {
1656            group_acc.update_batch(
1657                &val_with_orderings,
1658                &[0, 1, 2, 1],
1659                Some(&BooleanArray::from(vec![true, true, false, true])),
1660                100,
1661            )?;
1662            assert_eq!(
1663                group_acc.size_of_orderings,
1664                group_acc.compute_size_of_orderings()
1665            );
1666
1667            group_acc.state(EmitTo::First(2))?;
1668            assert_eq!(
1669                group_acc.size_of_orderings,
1670                group_acc.compute_size_of_orderings()
1671            );
1672
1673            let s = group_acc.state(EmitTo::All)?;
1674            assert_eq!(
1675                group_acc.size_of_orderings,
1676                group_acc.compute_size_of_orderings()
1677            );
1678
1679            group_acc.merge_batch(&s, &Vec::from_iter(0..s[0].len()), None, 100)?;
1680            assert_eq!(
1681                group_acc.size_of_orderings,
1682                group_acc.compute_size_of_orderings()
1683            );
1684
1685            group_acc.evaluate(EmitTo::First(2))?;
1686            assert_eq!(
1687                group_acc.size_of_orderings,
1688                group_acc.compute_size_of_orderings()
1689            );
1690
1691            group_acc.evaluate(EmitTo::All)?;
1692            assert_eq!(
1693                group_acc.size_of_orderings,
1694                group_acc.compute_size_of_orderings()
1695            );
1696        }
1697
1698        Ok(())
1699    }
1700
1701    #[test]
1702    fn test_last_group_acc() -> Result<()> {
1703        let schema = Arc::new(Schema::new(vec![
1704            Field::new("a", DataType::Int64, true),
1705            Field::new("b", DataType::Int64, true),
1706            Field::new("c", DataType::Int64, true),
1707            Field::new("d", DataType::Int32, true),
1708            Field::new("e", DataType::Boolean, true),
1709        ]));
1710
1711        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1712            expr: col("c", &schema).unwrap(),
1713            options: SortOptions::default(),
1714        }]);
1715
1716        let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1717            sort_key,
1718            true,
1719            &DataType::Int64,
1720            &[DataType::Int64],
1721            false,
1722        )?;
1723
1724        let mut val_with_orderings = {
1725            let mut val_with_orderings = Vec::<ArrayRef>::new();
1726
1727            let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1728            let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1729
1730            val_with_orderings.push(vals);
1731            val_with_orderings.push(orderings);
1732
1733            val_with_orderings
1734        };
1735
1736        group_acc.update_batch(
1737            &val_with_orderings,
1738            &[0, 1, 2, 1],
1739            Some(&BooleanArray::from(vec![true, true, false, true])),
1740            3,
1741        )?;
1742
1743        let state = group_acc.state(EmitTo::All)?;
1744
1745        let expected_state: Vec<Arc<dyn Array>> = vec![
1746            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1747            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1748            Arc::new(BooleanArray::from(vec![true, true, false])),
1749        ];
1750        assert_eq!(state, expected_state);
1751
1752        group_acc.merge_batch(
1753            &state,
1754            &[0, 1, 2],
1755            Some(&BooleanArray::from(vec![true, false, false])),
1756            3,
1757        )?;
1758
1759        val_with_orderings.clear();
1760        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1761        val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1762
1763        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1764
1765        let binding = group_acc.evaluate(EmitTo::All)?;
1766        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1767
1768        let expect: PrimitiveArray<Int64Type> =
1769            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
1770
1771        assert_eq!(eval_result, &expect);
1772
1773        Ok(())
1774    }
1775}