1use std::any::Any;
21use std::fmt::Debug;
22use std::mem::size_of_val;
23use std::sync::Arc;
24
25use arrow::array::{
26 Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder,
27 PrimitiveArray,
28};
29use arrow::buffer::{BooleanBuffer, NullBuffer};
30use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
31use arrow::datatypes::{
32 DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type,
33 Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
34 Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
35 TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
36 TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
37 UInt8Type,
38};
39use datafusion_common::cast::as_boolean_array;
40use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
41use datafusion_common::{
42 arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
43};
44use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
45use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
46use datafusion_expr::{
47 Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
48 GroupsAccumulator, Signature, SortExpr, Volatility,
49};
50use datafusion_functions_aggregate_common::utils::get_sort_options;
51use datafusion_macros::user_doc;
52use datafusion_physical_expr_common::sort_expr::LexOrdering;
53
// Declare the `first_value_udaf` / `last_value_udaf` helpers that the rest of this file calls.
// NOTE(review): `create_func!` is defined outside this chunk; presumably each invocation
// produces a function returning the UDAF built from the named type — confirm against the macro.
create_func!(FirstValue, first_value_udaf);
create_func!(LastValue, last_value_udaf);
56
57pub fn first_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
59 if let Some(order_by) = order_by {
60 first_value_udaf()
61 .call(vec![expression])
62 .order_by(order_by)
63 .build()
64 .unwrap()
66 } else {
67 first_value_udaf().call(vec![expression])
68 }
69}
70
71pub fn last_value(expression: Expr, order_by: Option<Vec<SortExpr>>) -> Expr {
73 if let Some(order_by) = order_by {
74 last_value_udaf()
75 .call(vec![expression])
76 .order_by(order_by)
77 .build()
78 .unwrap()
80 } else {
81 last_value_udaf().call(vec![expression])
82 }
83}
84
85#[user_doc(
86 doc_section(label = "General Functions"),
87 description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
88 syntax_example = "first_value(expression [ORDER BY expression])",
89 sql_example = r#"```sql
90> SELECT first_value(column_name ORDER BY other_column) FROM table_name;
91+-----------------------------------------------+
92| first_value(column_name ORDER BY other_column)|
93+-----------------------------------------------+
94| first_element |
95+-----------------------------------------------+
96```"#,
97 standard_argument(name = "expression",)
98)]
/// The `first_value` aggregate function definition.
pub struct FirstValue {
    /// Accepted argument signature (exactly one argument of any type, see `new`).
    signature: Signature,
    /// Flag forwarded to the row accumulator: when true (or when there is no ordering
    /// requirement) the accumulator takes rows in arrival order instead of comparing
    /// ordering values.
    requirement_satisfied: bool,
}
103
104impl Debug for FirstValue {
105 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
106 f.debug_struct("FirstValue")
107 .field("name", &self.name())
108 .field("signature", &self.signature)
109 .field("accumulator", &"<FUNC>")
110 .finish()
111 }
112}
113
114impl Default for FirstValue {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl FirstValue {
121 pub fn new() -> Self {
122 Self {
123 signature: Signature::any(1, Volatility::Immutable),
124 requirement_satisfied: false,
125 }
126 }
127
128 fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
129 self.requirement_satisfied = requirement_satisfied;
130 self
131 }
132}
133
134impl AggregateUDFImpl for FirstValue {
135 fn as_any(&self) -> &dyn Any {
136 self
137 }
138
139 fn name(&self) -> &str {
140 "first_value"
141 }
142
143 fn signature(&self) -> &Signature {
144 &self.signature
145 }
146
147 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
148 Ok(arg_types[0].clone())
149 }
150
151 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
152 let ordering_dtypes = acc_args
153 .ordering_req
154 .iter()
155 .map(|e| e.expr.data_type(acc_args.schema))
156 .collect::<Result<Vec<_>>>()?;
157
158 let requirement_satisfied =
161 acc_args.ordering_req.is_empty() || self.requirement_satisfied;
162
163 FirstValueAccumulator::try_new(
164 acc_args.return_type,
165 &ordering_dtypes,
166 acc_args.ordering_req.clone(),
167 acc_args.ignore_nulls,
168 )
169 .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
170 }
171
172 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
173 let mut fields = vec![Field::new(
174 format_state_name(args.name, "first_value"),
175 args.return_type.clone(),
176 true,
177 )];
178 fields.extend(args.ordering_fields.to_vec());
179 fields.push(Field::new("is_set", DataType::Boolean, true));
180 Ok(fields)
181 }
182
183 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
184 use DataType::*;
186 matches!(
187 args.return_type,
188 Int8 | Int16
189 | Int32
190 | Int64
191 | UInt8
192 | UInt16
193 | UInt32
194 | UInt64
195 | Float16
196 | Float32
197 | Float64
198 | Decimal128(_, _)
199 | Decimal256(_, _)
200 | Date32
201 | Date64
202 | Time32(_)
203 | Time64(_)
204 | Timestamp(_, _)
205 )
206 }
207
208 fn create_groups_accumulator(
209 &self,
210 args: AccumulatorArgs,
211 ) -> Result<Box<dyn GroupsAccumulator>> {
212 fn create_accumulator<T>(
214 args: AccumulatorArgs,
215 ) -> Result<Box<dyn GroupsAccumulator>>
216 where
217 T: ArrowPrimitiveType + Send,
218 {
219 let ordering_dtypes = args
220 .ordering_req
221 .iter()
222 .map(|e| e.expr.data_type(args.schema))
223 .collect::<Result<Vec<_>>>()?;
224
225 Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
226 args.ordering_req.clone(),
227 args.ignore_nulls,
228 args.return_type,
229 &ordering_dtypes,
230 true,
231 )?))
232 }
233
234 match args.return_type {
235 DataType::Int8 => create_accumulator::<Int8Type>(args),
236 DataType::Int16 => create_accumulator::<Int16Type>(args),
237 DataType::Int32 => create_accumulator::<Int32Type>(args),
238 DataType::Int64 => create_accumulator::<Int64Type>(args),
239 DataType::UInt8 => create_accumulator::<UInt8Type>(args),
240 DataType::UInt16 => create_accumulator::<UInt16Type>(args),
241 DataType::UInt32 => create_accumulator::<UInt32Type>(args),
242 DataType::UInt64 => create_accumulator::<UInt64Type>(args),
243 DataType::Float16 => create_accumulator::<Float16Type>(args),
244 DataType::Float32 => create_accumulator::<Float32Type>(args),
245 DataType::Float64 => create_accumulator::<Float64Type>(args),
246
247 DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
248 DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
249
250 DataType::Timestamp(TimeUnit::Second, _) => {
251 create_accumulator::<TimestampSecondType>(args)
252 }
253 DataType::Timestamp(TimeUnit::Millisecond, _) => {
254 create_accumulator::<TimestampMillisecondType>(args)
255 }
256 DataType::Timestamp(TimeUnit::Microsecond, _) => {
257 create_accumulator::<TimestampMicrosecondType>(args)
258 }
259 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
260 create_accumulator::<TimestampNanosecondType>(args)
261 }
262
263 DataType::Date32 => create_accumulator::<Date32Type>(args),
264 DataType::Date64 => create_accumulator::<Date64Type>(args),
265 DataType::Time32(TimeUnit::Second) => {
266 create_accumulator::<Time32SecondType>(args)
267 }
268 DataType::Time32(TimeUnit::Millisecond) => {
269 create_accumulator::<Time32MillisecondType>(args)
270 }
271
272 DataType::Time64(TimeUnit::Microsecond) => {
273 create_accumulator::<Time64MicrosecondType>(args)
274 }
275 DataType::Time64(TimeUnit::Nanosecond) => {
276 create_accumulator::<Time64NanosecondType>(args)
277 }
278
279 _ => {
280 internal_err!(
281 "GroupsAccumulator not supported for first_value({})",
282 args.return_type
283 )
284 }
285 }
286 }
287
288 fn aliases(&self) -> &[String] {
289 &[]
290 }
291
292 fn with_beneficial_ordering(
293 self: Arc<Self>,
294 beneficial_ordering: bool,
295 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
296 Ok(Some(Arc::new(
297 FirstValue::new().with_requirement_satisfied(beneficial_ordering),
298 )))
299 }
300
301 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
302 AggregateOrderSensitivity::Beneficial
303 }
304
305 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
306 datafusion_expr::ReversedUDAF::Reversed(last_value_udaf())
307 }
308
309 fn documentation(&self) -> Option<&Documentation> {
310 self.doc()
311 }
312}
313
/// Groups accumulator shared by `first_value` and `last_value` for primitive Arrow
/// types. It keeps one candidate value, together with its ordering key, per group.
struct FirstPrimitiveGroupsAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    /// Current candidate value of each group, indexed by group index.
    vals: Vec<T::Native>,
    /// Ordering values stored row-wise: `orderings[group_idx]` is the ordering key
    /// recorded together with `vals[group_idx]`.
    orderings: Vec<Vec<ScalarValue>>,
    /// `is_sets[group_idx]` becomes true once a value has been recorded for the group.
    is_sets: BooleanBufferBuilder,
    /// Validity bits: a cleared bit at `group_idx` means `vals[group_idx]` is null.
    null_builder: BooleanBufferBuilder,
    /// Cached sum of `ScalarValue::size_of_vec` over `orderings`, updated incrementally
    /// and reported by `size`.
    size_of_orderings: usize,

    /// Scratch buffer for `get_filtered_min_of_each_group`: `.0[group_idx]` is the row
    /// index chosen in the current batch and is only meaningful when bit `group_idx`
    /// of `.1` is set.
    min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),

    /// The ordering requirement (`ORDER BY`) of the aggregate.
    ordering_req: LexOrdering,
    /// `true`: keep the first element of each group under the ordering;
    /// `false`: keep the last one.
    pick_first_in_group: bool,
    /// Sort options derived from `ordering_req`.
    sort_options: Vec<SortOptions>,
    /// When true, a group that already holds a value is not updated again (unless
    /// nulls are ignored and the held value is null); see `need_update`.
    input_requirement_satisfied: bool,
    /// Skip null input values.
    ignore_nulls: bool,
    /// Data type applied to the emitted value array.
    data_type: DataType,
    /// Ordering key assigned to newly created groups.
    default_orderings: Vec<ScalarValue>,
}
360
361impl<T> FirstPrimitiveGroupsAccumulator<T>
362where
363 T: ArrowPrimitiveType + Send,
364{
365 fn try_new(
366 ordering_req: LexOrdering,
367 ignore_nulls: bool,
368 data_type: &DataType,
369 ordering_dtypes: &[DataType],
370 pick_first_in_group: bool,
371 ) -> Result<Self> {
372 let requirement_satisfied = ordering_req.is_empty();
373
374 let default_orderings = ordering_dtypes
375 .iter()
376 .map(ScalarValue::try_from)
377 .collect::<Result<Vec<_>>>()?;
378
379 let sort_options = get_sort_options(ordering_req.as_ref());
380
381 Ok(Self {
382 null_builder: BooleanBufferBuilder::new(0),
383 ordering_req,
384 sort_options,
385 input_requirement_satisfied: requirement_satisfied,
386 ignore_nulls,
387 default_orderings,
388 data_type: data_type.clone(),
389 vals: Vec::new(),
390 orderings: Vec::new(),
391 is_sets: BooleanBufferBuilder::new(0),
392 size_of_orderings: 0,
393 min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
394 pick_first_in_group,
395 })
396 }
397
398 fn need_update(&self, group_idx: usize) -> bool {
399 if !self.is_sets.get_bit(group_idx) {
400 return true;
401 }
402
403 if self.ignore_nulls && !self.null_builder.get_bit(group_idx) {
404 return true;
405 }
406
407 !self.input_requirement_satisfied
408 }
409
410 fn should_update_state(
411 &self,
412 group_idx: usize,
413 new_ordering_values: &[ScalarValue],
414 ) -> Result<bool> {
415 if !self.is_sets.get_bit(group_idx) {
416 return Ok(true);
417 }
418
419 assert!(new_ordering_values.len() == self.ordering_req.len());
420 let current_ordering = &self.orderings[group_idx];
421 compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
422 if self.pick_first_in_group {
423 x.is_gt()
424 } else {
425 x.is_lt()
426 }
427 })
428 }
429
430 fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
431 let result = emit_to.take_needed(&mut self.orderings);
432
433 match emit_to {
434 EmitTo::All => self.size_of_orderings = 0,
435 EmitTo::First(_) => {
436 self.size_of_orderings -=
437 result.iter().map(ScalarValue::size_of_vec).sum::<usize>()
438 }
439 }
440
441 result
442 }
443
444 fn take_need(
445 bool_buf_builder: &mut BooleanBufferBuilder,
446 emit_to: EmitTo,
447 ) -> BooleanBuffer {
448 let bool_buf = bool_buf_builder.finish();
449 match emit_to {
450 EmitTo::All => bool_buf,
451 EmitTo::First(n) => {
452 let first_n: BooleanBuffer = bool_buf.iter().take(n).collect();
457 for b in bool_buf.iter().skip(n) {
459 bool_buf_builder.append(b);
460 }
461 first_n
462 }
463 }
464 }
465
466 fn resize_states(&mut self, new_size: usize) {
467 self.vals.resize(new_size, T::default_value());
468
469 self.null_builder.resize(new_size);
470
471 if self.orderings.len() < new_size {
472 let current_len = self.orderings.len();
473
474 self.orderings
475 .resize(new_size, self.default_orderings.clone());
476
477 self.size_of_orderings += (new_size - current_len)
478 * ScalarValue::size_of_vec(
479 self.orderings.last().unwrap(),
483 );
484 }
485
486 self.is_sets.resize(new_size);
487
488 self.min_of_each_group_buf.0.resize(new_size, 0);
489 self.min_of_each_group_buf.1.resize(new_size);
490 }
491
492 fn update_state(
493 &mut self,
494 group_idx: usize,
495 orderings: &[ScalarValue],
496 new_val: T::Native,
497 is_null: bool,
498 ) {
499 self.vals[group_idx] = new_val;
500 self.is_sets.set_bit(group_idx, true);
501
502 self.null_builder.set_bit(group_idx, !is_null);
503
504 assert!(orderings.len() == self.ordering_req.len());
505 let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
506 self.orderings[group_idx].clear();
507 self.orderings[group_idx].extend_from_slice(orderings);
508 let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]);
509 self.size_of_orderings = self.size_of_orderings - old_size + new_size;
510 }
511
512 fn take_state(
513 &mut self,
514 emit_to: EmitTo,
515 ) -> (ArrayRef, Vec<Vec<ScalarValue>>, BooleanBuffer) {
516 emit_to.take_needed(&mut self.min_of_each_group_buf.0);
517 self.min_of_each_group_buf
518 .1
519 .truncate(self.min_of_each_group_buf.0.len());
520
521 (
522 self.take_vals_and_null_buf(emit_to),
523 self.take_orderings(emit_to),
524 Self::take_need(&mut self.is_sets, emit_to),
525 )
526 }
527
528 #[cfg(test)]
530 fn compute_size_of_orderings(&self) -> usize {
531 self.orderings
532 .iter()
533 .map(ScalarValue::size_of_vec)
534 .sum::<usize>()
535 }
536 fn get_filtered_min_of_each_group(
541 &mut self,
542 orderings: &[ArrayRef],
543 group_indices: &[usize],
544 opt_filter: Option<&BooleanArray>,
545 vals: &PrimitiveArray<T>,
546 is_set_arr: Option<&BooleanArray>,
547 ) -> Result<Vec<(usize, usize)>> {
548 self.min_of_each_group_buf.1.truncate(0);
550 self.min_of_each_group_buf
551 .1
552 .append_n(self.vals.len(), false);
553
554 let comparator = {
558 assert_eq!(orderings.len(), self.ordering_req.len());
559 let sort_columns = orderings
560 .iter()
561 .zip(self.ordering_req.iter())
562 .map(|(array, req)| SortColumn {
563 values: Arc::clone(array),
564 options: Some(req.options),
565 })
566 .collect::<Vec<_>>();
567
568 LexicographicalComparator::try_new(&sort_columns)?
569 };
570
571 for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
572 let group_idx = *group_idx;
573
574 let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
575
576 let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));
577
578 if !passed_filter || !is_set {
579 continue;
580 }
581
582 if !self.need_update(group_idx) {
583 continue;
584 }
585
586 if self.ignore_nulls && vals.is_null(idx_in_val) {
587 continue;
588 }
589
590 let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
591
592 if !is_valid {
593 self.min_of_each_group_buf.1.set_bit(group_idx, true);
594 self.min_of_each_group_buf.0[group_idx] = idx_in_val;
595 } else {
596 let ordering = comparator
597 .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);
598
599 if (ordering.is_gt() && self.pick_first_in_group)
600 || (ordering.is_lt() && !self.pick_first_in_group)
601 {
602 self.min_of_each_group_buf.0[group_idx] = idx_in_val;
603 }
604 }
605 }
606
607 Ok(self
608 .min_of_each_group_buf
609 .0
610 .iter()
611 .enumerate()
612 .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx))
613 .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val))
614 .collect::<Vec<_>>())
615 }
616
617 fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef {
618 let r = emit_to.take_needed(&mut self.vals);
619
620 let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to));
621
622 let values = PrimitiveArray::<T>::new(r.into(), Some(null_buf)) .with_data_type(self.data_type.clone());
624 Arc::new(values)
625 }
626}
627
628impl<T> GroupsAccumulator for FirstPrimitiveGroupsAccumulator<T>
629where
630 T: ArrowPrimitiveType + Send,
631{
632 fn update_batch(
633 &mut self,
634 values_and_order_cols: &[ArrayRef],
636 group_indices: &[usize],
637 opt_filter: Option<&BooleanArray>,
638 total_num_groups: usize,
639 ) -> Result<()> {
640 self.resize_states(total_num_groups);
641
642 let vals = values_and_order_cols[0].as_primitive::<T>();
643
644 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
645
646 for (group_idx, idx) in self
648 .get_filtered_min_of_each_group(
649 &values_and_order_cols[1..],
650 group_indices,
651 opt_filter,
652 vals,
653 None,
654 )?
655 .into_iter()
656 {
657 extract_row_at_idx_to_buf(
658 &values_and_order_cols[1..],
659 idx,
660 &mut ordering_buf,
661 )?;
662
663 if self.should_update_state(group_idx, &ordering_buf)? {
664 self.update_state(
665 group_idx,
666 &ordering_buf,
667 vals.value(idx),
668 vals.is_null(idx),
669 );
670 }
671 }
672
673 Ok(())
674 }
675
676 fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
677 Ok(self.take_state(emit_to).0)
678 }
679
680 fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
681 let (val_arr, orderings, is_sets) = self.take_state(emit_to);
682 let mut result = Vec::with_capacity(self.orderings.len() + 2);
683
684 result.push(val_arr);
685
686 let ordering_cols = {
687 let mut ordering_cols = Vec::with_capacity(self.ordering_req.len());
688 for _ in 0..self.ordering_req.len() {
689 ordering_cols.push(Vec::with_capacity(self.orderings.len()));
690 }
691 for row in orderings.into_iter() {
692 assert_eq!(row.len(), self.ordering_req.len());
693 for (col_idx, ordering) in row.into_iter().enumerate() {
694 ordering_cols[col_idx].push(ordering);
695 }
696 }
697
698 ordering_cols
699 };
700 for ordering_col in ordering_cols {
701 result.push(ScalarValue::iter_to_array(ordering_col)?);
702 }
703
704 result.push(Arc::new(BooleanArray::new(is_sets, None)));
705
706 Ok(result)
707 }
708
709 fn merge_batch(
710 &mut self,
711 values: &[ArrayRef],
712 group_indices: &[usize],
713 opt_filter: Option<&BooleanArray>,
714 total_num_groups: usize,
715 ) -> Result<()> {
716 self.resize_states(total_num_groups);
717
718 let mut ordering_buf = Vec::with_capacity(self.ordering_req.len());
719
720 let (is_set_arr, val_and_order_cols) = match values.split_last() {
721 Some(result) => result,
722 None => return internal_err!("Empty row in FISRT_VALUE"),
723 };
724
725 let is_set_arr = as_boolean_array(is_set_arr)?;
726
727 let vals = values[0].as_primitive::<T>();
728 let groups = self.get_filtered_min_of_each_group(
730 &val_and_order_cols[1..],
731 group_indices,
732 opt_filter,
733 vals,
734 Some(is_set_arr),
735 )?;
736
737 for (group_idx, idx) in groups.into_iter() {
738 extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?;
739
740 if self.should_update_state(group_idx, &ordering_buf)? {
741 self.update_state(
742 group_idx,
743 &ordering_buf,
744 vals.value(idx),
745 vals.is_null(idx),
746 );
747 }
748 }
749
750 Ok(())
751 }
752
753 fn size(&self) -> usize {
754 self.vals.capacity() * size_of::<T::Native>()
755 + self.null_builder.capacity() / 8 + self.is_sets.capacity() / 8
757 + self.size_of_orderings
758 + self.min_of_each_group_buf.0.capacity() * size_of::<usize>()
759 + self.min_of_each_group_buf.1.capacity() / 8
760 }
761
762 fn supports_convert_to_state(&self) -> bool {
763 true
764 }
765
766 fn convert_to_state(
767 &self,
768 values: &[ArrayRef],
769 opt_filter: Option<&BooleanArray>,
770 ) -> Result<Vec<ArrayRef>> {
771 let mut result = values.to_vec();
772 match opt_filter {
773 Some(f) => {
774 result.push(Arc::new(f.clone()));
775 Ok(result)
776 }
777 None => {
778 result.push(Arc::new(BooleanArray::from(vec![true; values[0].len()])));
779 Ok(result)
780 }
781 }
782 }
783}
/// Row-based accumulator for `first_value`.
#[derive(Debug)]
pub struct FirstValueAccumulator {
    /// The current first value.
    first: ScalarValue,
    /// False until a value has been recorded in `first`.
    is_set: bool,
    /// Ordering values of the row that produced `first`.
    orderings: Vec<ScalarValue>,
    /// The ordering requirement (`ORDER BY`) of the aggregate.
    ordering_req: LexOrdering,
    /// When true, rows are taken in arrival order without comparing ordering values.
    requirement_satisfied: bool,
    /// Skip null input values.
    ignore_nulls: bool,
}
800
801impl FirstValueAccumulator {
802 pub fn try_new(
804 data_type: &DataType,
805 ordering_dtypes: &[DataType],
806 ordering_req: LexOrdering,
807 ignore_nulls: bool,
808 ) -> Result<Self> {
809 let orderings = ordering_dtypes
810 .iter()
811 .map(ScalarValue::try_from)
812 .collect::<Result<Vec<_>>>()?;
813 let requirement_satisfied = ordering_req.is_empty();
814 ScalarValue::try_from(data_type).map(|first| Self {
815 first,
816 is_set: false,
817 orderings,
818 ordering_req,
819 requirement_satisfied,
820 ignore_nulls,
821 })
822 }
823
824 pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
825 self.requirement_satisfied = requirement_satisfied;
826 self
827 }
828
829 fn update_with_new_row(&mut self, row: &[ScalarValue]) {
831 self.first = row[0].clone();
832 self.orderings = row[1..].to_vec();
833 self.is_set = true;
834 }
835
836 fn get_first_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
837 let [value, ordering_values @ ..] = values else {
838 return internal_err!("Empty row in FIRST_VALUE");
839 };
840 if self.requirement_satisfied {
841 if self.ignore_nulls {
843 for i in 0..value.len() {
845 if !value.is_null(i) {
846 return Ok(Some(i));
847 }
848 }
849 return Ok(None);
850 } else {
851 return Ok((!value.is_empty()).then_some(0));
853 }
854 }
855
856 let sort_columns = ordering_values
857 .iter()
858 .zip(self.ordering_req.iter())
859 .map(|(values, req)| SortColumn {
860 values: Arc::clone(values),
861 options: Some(req.options),
862 })
863 .collect::<Vec<_>>();
864
865 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
866
867 let min_index = if self.ignore_nulls {
868 (0..value.len())
869 .filter(|&index| !value.is_null(index))
870 .min_by(|&a, &b| comparator.compare(a, b))
871 } else {
872 (0..value.len()).min_by(|&a, &b| comparator.compare(a, b))
873 };
874
875 Ok(min_index)
876 }
877}
878
879impl Accumulator for FirstValueAccumulator {
880 fn state(&mut self) -> Result<Vec<ScalarValue>> {
881 let mut result = vec![self.first.clone()];
882 result.extend(self.orderings.iter().cloned());
883 result.push(ScalarValue::Boolean(Some(self.is_set)));
884 Ok(result)
885 }
886
887 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
888 if !self.is_set {
889 if let Some(first_idx) = self.get_first_idx(values)? {
890 let row = get_row_at_idx(values, first_idx)?;
891 self.update_with_new_row(&row);
892 }
893 } else if !self.requirement_satisfied {
894 if let Some(first_idx) = self.get_first_idx(values)? {
895 let row = get_row_at_idx(values, first_idx)?;
896 let orderings = &row[1..];
897 if compare_rows(
898 &self.orderings,
899 orderings,
900 &get_sort_options(self.ordering_req.as_ref()),
901 )?
902 .is_gt()
903 {
904 self.update_with_new_row(&row);
905 }
906 }
907 }
908 Ok(())
909 }
910
911 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
912 let is_set_idx = states.len() - 1;
915 let flags = states[is_set_idx].as_boolean();
916 let filtered_states =
917 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
918 let sort_columns = convert_to_sort_cols(
920 &filtered_states[1..is_set_idx],
921 self.ordering_req.as_ref(),
922 );
923
924 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
925 let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b));
926
927 if let Some(first_idx) = min {
928 let first_row = get_row_at_idx(&filtered_states, first_idx)?;
929 let first_ordering = &first_row[1..is_set_idx];
931 let sort_options = get_sort_options(self.ordering_req.as_ref());
932 if !self.is_set
934 || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt()
935 {
936 self.update_with_new_row(&first_row[0..is_set_idx]);
940 }
941 }
942 Ok(())
943 }
944
945 fn evaluate(&mut self) -> Result<ScalarValue> {
946 Ok(self.first.clone())
947 }
948
949 fn size(&self) -> usize {
950 size_of_val(self) - size_of_val(&self.first)
951 + self.first.size()
952 + ScalarValue::size_of_vec(&self.orderings)
953 - size_of_val(&self.orderings)
954 }
955}
956
957#[user_doc(
958 doc_section(label = "General Functions"),
959 description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.",
960 syntax_example = "last_value(expression [ORDER BY expression])",
961 sql_example = r#"```sql
962> SELECT last_value(column_name ORDER BY other_column) FROM table_name;
963+-----------------------------------------------+
964| last_value(column_name ORDER BY other_column) |
965+-----------------------------------------------+
966| last_element |
967+-----------------------------------------------+
968```"#,
969 standard_argument(name = "expression",)
970)]
/// The `last_value` aggregate function definition.
pub struct LastValue {
    /// Accepted argument signature (exactly one argument of any type, see `new`).
    signature: Signature,
    /// Flag forwarded to the row accumulator: when true (or when there is no ordering
    /// requirement) the accumulator takes rows in arrival order instead of comparing
    /// ordering values.
    requirement_satisfied: bool,
}
975
976impl Debug for LastValue {
977 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
978 f.debug_struct("LastValue")
979 .field("name", &self.name())
980 .field("signature", &self.signature)
981 .field("accumulator", &"<FUNC>")
982 .finish()
983 }
984}
985
986impl Default for LastValue {
987 fn default() -> Self {
988 Self::new()
989 }
990}
991
992impl LastValue {
993 pub fn new() -> Self {
994 Self {
995 signature: Signature::any(1, Volatility::Immutable),
996 requirement_satisfied: false,
997 }
998 }
999
1000 fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
1001 self.requirement_satisfied = requirement_satisfied;
1002 self
1003 }
1004}
1005
1006impl AggregateUDFImpl for LastValue {
1007 fn as_any(&self) -> &dyn Any {
1008 self
1009 }
1010
1011 fn name(&self) -> &str {
1012 "last_value"
1013 }
1014
1015 fn signature(&self) -> &Signature {
1016 &self.signature
1017 }
1018
1019 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1020 Ok(arg_types[0].clone())
1021 }
1022
1023 fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
1024 let ordering_dtypes = acc_args
1025 .ordering_req
1026 .iter()
1027 .map(|e| e.expr.data_type(acc_args.schema))
1028 .collect::<Result<Vec<_>>>()?;
1029
1030 let requirement_satisfied =
1031 acc_args.ordering_req.is_empty() || self.requirement_satisfied;
1032
1033 LastValueAccumulator::try_new(
1034 acc_args.return_type,
1035 &ordering_dtypes,
1036 acc_args.ordering_req.clone(),
1037 acc_args.ignore_nulls,
1038 )
1039 .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
1040 }
1041
1042 fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
1043 let StateFieldsArgs {
1044 name,
1045 input_types,
1046 return_type: _,
1047 ordering_fields,
1048 is_distinct: _,
1049 } = args;
1050 let mut fields = vec![Field::new(
1051 format_state_name(name, "last_value"),
1052 input_types[0].clone(),
1053 true,
1054 )];
1055 fields.extend(ordering_fields.to_vec());
1056 fields.push(Field::new("is_set", DataType::Boolean, true));
1057 Ok(fields)
1058 }
1059
1060 fn aliases(&self) -> &[String] {
1061 &[]
1062 }
1063
1064 fn with_beneficial_ordering(
1065 self: Arc<Self>,
1066 beneficial_ordering: bool,
1067 ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
1068 Ok(Some(Arc::new(
1069 LastValue::new().with_requirement_satisfied(beneficial_ordering),
1070 )))
1071 }
1072
1073 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
1074 AggregateOrderSensitivity::Beneficial
1075 }
1076
1077 fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
1078 datafusion_expr::ReversedUDAF::Reversed(first_value_udaf())
1079 }
1080
1081 fn documentation(&self) -> Option<&Documentation> {
1082 self.doc()
1083 }
1084
1085 fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1086 use DataType::*;
1087 matches!(
1088 args.return_type,
1089 Int8 | Int16
1090 | Int32
1091 | Int64
1092 | UInt8
1093 | UInt16
1094 | UInt32
1095 | UInt64
1096 | Float16
1097 | Float32
1098 | Float64
1099 | Decimal128(_, _)
1100 | Decimal256(_, _)
1101 | Date32
1102 | Date64
1103 | Time32(_)
1104 | Time64(_)
1105 | Timestamp(_, _)
1106 )
1107 }
1108
1109 fn create_groups_accumulator(
1110 &self,
1111 args: AccumulatorArgs,
1112 ) -> Result<Box<dyn GroupsAccumulator>> {
1113 fn create_accumulator<T>(
1114 args: AccumulatorArgs,
1115 ) -> Result<Box<dyn GroupsAccumulator>>
1116 where
1117 T: ArrowPrimitiveType + Send,
1118 {
1119 let ordering_dtypes = args
1120 .ordering_req
1121 .iter()
1122 .map(|e| e.expr.data_type(args.schema))
1123 .collect::<Result<Vec<_>>>()?;
1124
1125 Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
1126 args.ordering_req.clone(),
1127 args.ignore_nulls,
1128 args.return_type,
1129 &ordering_dtypes,
1130 false,
1131 )?))
1132 }
1133
1134 match args.return_type {
1135 DataType::Int8 => create_accumulator::<Int8Type>(args),
1136 DataType::Int16 => create_accumulator::<Int16Type>(args),
1137 DataType::Int32 => create_accumulator::<Int32Type>(args),
1138 DataType::Int64 => create_accumulator::<Int64Type>(args),
1139 DataType::UInt8 => create_accumulator::<UInt8Type>(args),
1140 DataType::UInt16 => create_accumulator::<UInt16Type>(args),
1141 DataType::UInt32 => create_accumulator::<UInt32Type>(args),
1142 DataType::UInt64 => create_accumulator::<UInt64Type>(args),
1143 DataType::Float16 => create_accumulator::<Float16Type>(args),
1144 DataType::Float32 => create_accumulator::<Float32Type>(args),
1145 DataType::Float64 => create_accumulator::<Float64Type>(args),
1146
1147 DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
1148 DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
1149
1150 DataType::Timestamp(TimeUnit::Second, _) => {
1151 create_accumulator::<TimestampSecondType>(args)
1152 }
1153 DataType::Timestamp(TimeUnit::Millisecond, _) => {
1154 create_accumulator::<TimestampMillisecondType>(args)
1155 }
1156 DataType::Timestamp(TimeUnit::Microsecond, _) => {
1157 create_accumulator::<TimestampMicrosecondType>(args)
1158 }
1159 DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1160 create_accumulator::<TimestampNanosecondType>(args)
1161 }
1162
1163 DataType::Date32 => create_accumulator::<Date32Type>(args),
1164 DataType::Date64 => create_accumulator::<Date64Type>(args),
1165 DataType::Time32(TimeUnit::Second) => {
1166 create_accumulator::<Time32SecondType>(args)
1167 }
1168 DataType::Time32(TimeUnit::Millisecond) => {
1169 create_accumulator::<Time32MillisecondType>(args)
1170 }
1171
1172 DataType::Time64(TimeUnit::Microsecond) => {
1173 create_accumulator::<Time64MicrosecondType>(args)
1174 }
1175 DataType::Time64(TimeUnit::Nanosecond) => {
1176 create_accumulator::<Time64NanosecondType>(args)
1177 }
1178
1179 _ => {
1180 internal_err!(
1181 "GroupsAccumulator not supported for last_value({})",
1182 args.return_type
1183 )
1184 }
1185 }
1186 }
1187}
1188
/// Row-based accumulator for `last_value`.
#[derive(Debug)]
struct LastValueAccumulator {
    /// The current last value.
    last: ScalarValue,
    /// False until a value has been recorded in `last`.
    is_set: bool,
    /// Ordering values of the row that produced `last`.
    orderings: Vec<ScalarValue>,
    /// The ordering requirement (`ORDER BY`) of the aggregate.
    ordering_req: LexOrdering,
    /// When true, rows are taken in arrival order without comparing ordering values.
    requirement_satisfied: bool,
    /// Skip null input values.
    ignore_nulls: bool,
}
1204
1205impl LastValueAccumulator {
1206 pub fn try_new(
1208 data_type: &DataType,
1209 ordering_dtypes: &[DataType],
1210 ordering_req: LexOrdering,
1211 ignore_nulls: bool,
1212 ) -> Result<Self> {
1213 let orderings = ordering_dtypes
1214 .iter()
1215 .map(ScalarValue::try_from)
1216 .collect::<Result<Vec<_>>>()?;
1217 let requirement_satisfied = ordering_req.is_empty();
1218 ScalarValue::try_from(data_type).map(|last| Self {
1219 last,
1220 is_set: false,
1221 orderings,
1222 ordering_req,
1223 requirement_satisfied,
1224 ignore_nulls,
1225 })
1226 }
1227
1228 fn update_with_new_row(&mut self, row: &[ScalarValue]) {
1230 self.last = row[0].clone();
1231 self.orderings = row[1..].to_vec();
1232 self.is_set = true;
1233 }
1234
1235 fn get_last_idx(&self, values: &[ArrayRef]) -> Result<Option<usize>> {
1236 let [value, ordering_values @ ..] = values else {
1237 return internal_err!("Empty row in LAST_VALUE");
1238 };
1239 if self.requirement_satisfied {
1240 if self.ignore_nulls {
1242 for i in (0..value.len()).rev() {
1244 if !value.is_null(i) {
1245 return Ok(Some(i));
1246 }
1247 }
1248 return Ok(None);
1249 } else {
1250 return Ok((!value.is_empty()).then_some(value.len() - 1));
1251 }
1252 }
1253 let sort_columns = ordering_values
1254 .iter()
1255 .zip(self.ordering_req.iter())
1256 .map(|(values, req)| SortColumn {
1257 values: Arc::clone(values),
1258 options: Some(req.options),
1259 })
1260 .collect::<Vec<_>>();
1261
1262 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1263 let max_ind = if self.ignore_nulls {
1264 (0..value.len())
1265 .filter(|&index| !(value.is_null(index)))
1266 .max_by(|&a, &b| comparator.compare(a, b))
1267 } else {
1268 (0..value.len()).max_by(|&a, &b| comparator.compare(a, b))
1269 };
1270
1271 Ok(max_ind)
1272 }
1273
1274 fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self {
1275 self.requirement_satisfied = requirement_satisfied;
1276 self
1277 }
1278}
1279
1280impl Accumulator for LastValueAccumulator {
1281 fn state(&mut self) -> Result<Vec<ScalarValue>> {
1282 let mut result = vec![self.last.clone()];
1283 result.extend(self.orderings.clone());
1284 result.push(ScalarValue::Boolean(Some(self.is_set)));
1285 Ok(result)
1286 }
1287
1288 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
1289 if !self.is_set || self.requirement_satisfied {
1290 if let Some(last_idx) = self.get_last_idx(values)? {
1291 let row = get_row_at_idx(values, last_idx)?;
1292 self.update_with_new_row(&row);
1293 }
1294 } else if let Some(last_idx) = self.get_last_idx(values)? {
1295 let row = get_row_at_idx(values, last_idx)?;
1296 let orderings = &row[1..];
1297 if compare_rows(
1299 &self.orderings,
1300 orderings,
1301 &get_sort_options(self.ordering_req.as_ref()),
1302 )?
1303 .is_lt()
1304 {
1305 self.update_with_new_row(&row);
1306 }
1307 }
1308
1309 Ok(())
1310 }
1311
1312 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
1313 let is_set_idx = states.len() - 1;
1316 let flags = states[is_set_idx].as_boolean();
1317 let filtered_states =
1318 filter_states_according_to_is_set(&states[0..is_set_idx], flags)?;
1319 let sort_columns = convert_to_sort_cols(
1321 &filtered_states[1..is_set_idx],
1322 self.ordering_req.as_ref(),
1323 );
1324
1325 let comparator = LexicographicalComparator::try_new(&sort_columns)?;
1326 let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b));
1327
1328 if let Some(last_idx) = max {
1329 let last_row = get_row_at_idx(&filtered_states, last_idx)?;
1330 let last_ordering = &last_row[1..is_set_idx];
1332 let sort_options = get_sort_options(self.ordering_req.as_ref());
1333 if !self.is_set
1336 || self.requirement_satisfied
1337 || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt()
1338 {
1339 self.update_with_new_row(&last_row[0..is_set_idx]);
1343 }
1344 }
1345 Ok(())
1346 }
1347
1348 fn evaluate(&mut self) -> Result<ScalarValue> {
1349 Ok(self.last.clone())
1350 }
1351
1352 fn size(&self) -> usize {
1353 size_of_val(self) - size_of_val(&self.last)
1354 + self.last.size()
1355 + ScalarValue::size_of_vec(&self.orderings)
1356 - size_of_val(&self.orderings)
1357 }
1358}
1359
1360fn filter_states_according_to_is_set(
1363 states: &[ArrayRef],
1364 flags: &BooleanArray,
1365) -> Result<Vec<ArrayRef>> {
1366 states
1367 .iter()
1368 .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e)))
1369 .collect::<Result<Vec<_>>>()
1370}
1371
1372fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec<SortColumn> {
1374 arrs.iter()
1375 .zip(sort_exprs.iter())
1376 .map(|(item, sort_expr)| SortColumn {
1377 values: Arc::clone(item),
1378 options: Some(sort_expr.options),
1379 })
1380 .collect::<Vec<_>>()
1381}
1382
#[cfg(test)]
mod tests {
    use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema};
    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};

    use super::*;

    /// Builds one non-null `Int64Array` per half-open `(start, end)` range.
    fn int64_range_arrays(ranges: &[(i64, i64)]) -> Vec<ArrayRef> {
        ranges
            .iter()
            .map(|&(start, end)| {
                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
            })
            .collect()
    }

    /// Concatenates two scalar states column by column into the array layout
    /// that `merge_batch` consumes. Asserts both states have the same width.
    fn concat_states(
        state1: &[ScalarValue],
        state2: &[ScalarValue],
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(state1.len(), state2.len());

        let mut states = Vec::with_capacity(state1.len());
        for (left, right) in state1.iter().zip(state2) {
            states.push(compute::concat(&[&left.to_array()?, &right.to_array()?])?);
        }
        Ok(states)
    }

    /// Creates the `Int64` groups accumulator shared by the group tests: it is
    /// ordered by column `c` of a five-column schema with default sort
    /// options, and `true` is passed as the second `try_new` argument.
    ///
    /// `pick_first` is forwarded as the final `try_new` argument.
    /// NOTE(review): presumably this selects first-vs-last semantics; confirm
    /// against `FirstPrimitiveGroupsAccumulator::try_new`.
    fn new_group_acc(
        pick_first: bool,
    ) -> Result<FirstPrimitiveGroupsAccumulator<Int64Type>> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Int64, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Boolean, true),
        ]);
        let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("c", &schema)?,
            options: SortOptions::default(),
        }]);

        FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
            sort_key,
            true,
            &DataType::Int64,
            &[DataType::Int64],
            pick_first,
        )
    }

    /// Sample batch `[values, orderings]` shared by the group tests.
    fn sample_vals_with_orderings() -> Vec<ArrayRef> {
        let vals: ArrayRef =
            Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
        let orderings: ArrayRef = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
        vec![vals, orderings]
    }

    /// Asserts that the cached `size_of_orderings` matches a fresh recount.
    #[track_caller]
    fn assert_ordering_size_in_sync(acc: &FirstPrimitiveGroupsAccumulator<Int64Type>) {
        assert_eq!(acc.size_of_orderings, acc.compute_size_of_orderings());
    }

    #[test]
    fn test_first_last_value_value() -> Result<()> {
        let mut first_accumulator = FirstValueAccumulator::try_new(
            &DataType::Int64,
            &[],
            LexOrdering::default(),
            false,
        )?;
        let mut last_accumulator = LastValueAccumulator::try_new(
            &DataType::Int64,
            &[],
            LexOrdering::default(),
            false,
        )?;

        // Feed the same three batches to both accumulators.
        for arr in int64_range_arrays(&[(0, 10), (1, 11), (2, 13)]) {
            first_accumulator.update_batch(&[Arc::clone(&arr)])?;
            last_accumulator.update_batch(&[arr])?;
        }

        // First value of the first batch, last value of the last batch.
        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
        Ok(())
    }

    #[test]
    fn test_first_last_state_after_merge() -> Result<()> {
        let arrs = int64_range_arrays(&[(0, 10), (1, 11), (2, 13)]);

        // FIRST_VALUE: merging two partial states must not change the width
        // of the state.
        let new_first = || {
            FirstValueAccumulator::try_new(
                &DataType::Int64,
                &[],
                LexOrdering::default(),
                false,
            )
        };

        let mut first_accumulator = new_first()?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = first_accumulator.state()?;

        let mut first_accumulator = new_first()?;
        first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = first_accumulator.state()?;

        let states = concat_states(&state1, &state2)?;

        let mut first_accumulator = new_first()?;
        first_accumulator.merge_batch(&states)?;

        let merged_state = first_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        // LAST_VALUE: same check.
        let new_last = || {
            LastValueAccumulator::try_new(
                &DataType::Int64,
                &[],
                LexOrdering::default(),
                false,
            )
        };

        let mut last_accumulator = new_last()?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?;
        let state1 = last_accumulator.state()?;

        let mut last_accumulator = new_last()?;
        last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?;
        let state2 = last_accumulator.state()?;

        let states = concat_states(&state1, &state2)?;

        let mut last_accumulator = new_last()?;
        last_accumulator.merge_batch(&states)?;

        let merged_state = last_accumulator.state()?;
        assert_eq!(merged_state.len(), state1.len());

        Ok(())
    }

    #[test]
    fn test_first_group_acc() -> Result<()> {
        let mut group_acc = new_group_acc(true)?;
        let val_with_orderings = sample_vals_with_orderings();

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;
        assert_ordering_size_in_sync(&group_acc);

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);
        assert_ordering_size_in_sync(&group_acc);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;
        assert_ordering_size_in_sync(&group_acc);

        let val_with_orderings: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![6, 6])),
            Arc::new(Int64Array::from(vec![6, 6])),
        ];
        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(6), Some(6), None]);
        assert_eq!(eval_result, &expect);
        assert_ordering_size_in_sync(&group_acc);

        Ok(())
    }

    #[test]
    fn test_group_acc_size_of_ordering() -> Result<()> {
        let mut group_acc = new_group_acc(true)?;
        let val_with_orderings = sample_vals_with_orderings();
        let filter = BooleanArray::from(vec![true, true, false, true]);

        // The cached ordering size must stay in sync after every operation,
        // across repeated update / emit / merge cycles.
        for _ in 0..10 {
            group_acc.update_batch(
                &val_with_orderings,
                &[0, 1, 2, 1],
                Some(&filter),
                100,
            )?;
            assert_ordering_size_in_sync(&group_acc);

            group_acc.state(EmitTo::First(2))?;
            assert_ordering_size_in_sync(&group_acc);

            let s = group_acc.state(EmitTo::All)?;
            assert_ordering_size_in_sync(&group_acc);

            let group_indices = (0..s[0].len()).collect::<Vec<usize>>();
            group_acc.merge_batch(&s, &group_indices, None, 100)?;
            assert_ordering_size_in_sync(&group_acc);

            group_acc.evaluate(EmitTo::First(2))?;
            assert_ordering_size_in_sync(&group_acc);

            group_acc.evaluate(EmitTo::All)?;
            assert_ordering_size_in_sync(&group_acc);
        }

        Ok(())
    }

    #[test]
    fn test_last_group_acc() -> Result<()> {
        let mut group_acc = new_group_acc(false)?;
        let val_with_orderings = sample_vals_with_orderings();

        group_acc.update_batch(
            &val_with_orderings,
            &[0, 1, 2, 1],
            Some(&BooleanArray::from(vec![true, true, false, true])),
            3,
        )?;

        let state = group_acc.state(EmitTo::All)?;

        let expected_state: Vec<Arc<dyn Array>> = vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
            Arc::new(BooleanArray::from(vec![true, true, false])),
        ];
        assert_eq!(state, expected_state);

        group_acc.merge_batch(
            &state,
            &[0, 1, 2],
            Some(&BooleanArray::from(vec![true, false, false])),
            3,
        )?;

        let val_with_orderings: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![66, 6])),
            Arc::new(Int64Array::from(vec![66, 6])),
        ];
        group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;

        let binding = group_acc.evaluate(EmitTo::All)?;
        let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();

        let expect: PrimitiveArray<Int64Type> =
            Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
        assert_eq!(eval_result, &expect);

        Ok(())
    }
}