1use std::{collections::VecDeque, ops::Range, sync::Arc};
21
22use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
23
24use arrow::{
25 array::ArrayRef,
26 compute::{concat, concat_batches, SortOptions},
27 datatypes::{DataType, SchemaRef},
28 record_batch::RecordBatch,
29};
30use datafusion_common::{
31 internal_err,
32 utils::{compare_rows, get_row_at_idx, search_in_slice},
33 DataFusionError, Result, ScalarValue,
34};
35
/// Running state kept while a window function is evaluated incrementally
/// against the rows buffered in a `PartitionBatchState`.
#[derive(Debug)]
pub struct WindowAggState {
    /// Row range of the window frame, expressed as indices into the buffered
    /// rows. Starts as `0..0`; `prune_state` shifts both ends back by the
    /// number of pruned rows.
    pub window_frame_range: Range<usize>,
    /// Frame-boundary calculator for this window. `new` leaves it as `None`.
    pub window_frame_ctx: Option<WindowFrameContext>,
    /// Number of buffered rows whose results have already been produced:
    /// `update` adds the length of every new result chunk and `prune_state`
    /// subtracts the number of pruned rows.
    pub last_calculated_index: usize,
    /// Total number of rows removed so far through `prune_state`.
    pub offset_pruned_rows: usize,
    /// Accumulated output values; `update` appends each newly computed chunk.
    pub out_col: ArrayRef,
    /// Number of buffered rows that do not have a result yet, i.e.
    /// `record_batch.num_rows() - last_calculated_index` as of the last
    /// `update` call.
    pub n_row_result_missing: usize,
    /// Copy of `PartitionBatchState::is_end` taken during the last `update`.
    // NOTE(review): presumably signals that the partition has received all of
    // its input -- confirm against the code that sets the flag.
    pub is_end: bool,
}
54
55impl WindowAggState {
56 pub fn prune_state(&mut self, n_prune: usize) {
57 self.window_frame_range = Range {
58 start: self.window_frame_range.start - n_prune,
59 end: self.window_frame_range.end - n_prune,
60 };
61 self.last_calculated_index -= n_prune;
62 self.offset_pruned_rows += n_prune;
63
64 match self.window_frame_ctx.as_mut() {
65 Some(WindowFrameContext::Rows(_)) => {}
67 Some(WindowFrameContext::Range { .. }) => {}
68 Some(WindowFrameContext::Groups { state, .. }) => {
69 let mut n_group_to_del = 0;
70 for (_, end_idx) in &state.group_end_indices {
71 if n_prune < *end_idx {
72 break;
73 }
74 n_group_to_del += 1;
75 }
76 state.group_end_indices.drain(0..n_group_to_del);
77 state
78 .group_end_indices
79 .iter_mut()
80 .for_each(|(_, start_idx)| *start_idx -= n_prune);
81 state.current_group_idx -= n_group_to_del;
82 }
83 None => {}
84 };
85 }
86
87 pub fn update(
88 &mut self,
89 out_col: &ArrayRef,
90 partition_batch_state: &PartitionBatchState,
91 ) -> Result<()> {
92 self.last_calculated_index += out_col.len();
93 self.out_col = concat(&[&self.out_col, &out_col])?;
94 self.n_row_result_missing =
95 partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
96 self.is_end = partition_batch_state.is_end;
97 Ok(())
98 }
99
100 pub fn new(out_type: &DataType) -> Result<Self> {
101 let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0)?;
102 Ok(Self {
103 window_frame_range: Range { start: 0, end: 0 },
104 window_frame_ctx: None,
105 last_calculated_index: 0,
106 offset_pruned_rows: 0,
107 out_col: empty_out_col,
108 n_row_result_missing: 0,
109 is_end: false,
110 })
111 }
112}
113
/// Per-frame-unit context used to compute window frame boundaries.
///
/// `WindowFrameContext::new` picks the variant from `WindowFrame::units`.
#[derive(Debug)]
pub enum WindowFrameContext {
    /// `ROWS` frames: boundaries are computed from the frame definition alone,
    /// so no extra state is stored.
    Rows(Arc<WindowFrame>),
    /// `RANGE` frames: the frame definition plus the state used to search the
    /// ordering columns.
    Range {
        /// Frame definition (bounds and units).
        window_frame: Arc<WindowFrame>,
        /// Search state; holds the sort options of the ordering columns.
        state: WindowFrameStateRange,
    },
    /// `GROUPS` frames: the frame definition plus the peer-group boundaries
    /// discovered so far.
    Groups {
        /// Frame definition (bounds and units).
        window_frame: Arc<WindowFrame>,
        /// Group bookkeeping that is extended while rows are processed.
        state: WindowFrameStateGroups,
    },
}
134
135impl WindowFrameContext {
136 pub fn new(window_frame: Arc<WindowFrame>, sort_options: Vec<SortOptions>) -> Self {
138 match window_frame.units {
139 WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame),
140 WindowFrameUnits::Range => WindowFrameContext::Range {
141 window_frame,
142 state: WindowFrameStateRange::new(sort_options),
143 },
144 WindowFrameUnits::Groups => WindowFrameContext::Groups {
145 window_frame,
146 state: WindowFrameStateGroups::default(),
147 },
148 }
149 }
150
151 pub fn calculate_range(
153 &mut self,
154 range_columns: &[ArrayRef],
155 last_range: &Range<usize>,
156 length: usize,
157 idx: usize,
158 ) -> Result<Range<usize>> {
159 match self {
160 WindowFrameContext::Rows(window_frame) => {
161 Self::calculate_range_rows(window_frame, length, idx)
162 }
163 WindowFrameContext::Range {
167 window_frame,
168 ref mut state,
169 } => state.calculate_range(
170 window_frame,
171 last_range,
172 range_columns,
173 length,
174 idx,
175 ),
176 WindowFrameContext::Groups {
180 window_frame,
181 ref mut state,
182 } => state.calculate_range(window_frame, range_columns, length, idx),
183 }
184 }
185
186 fn calculate_range_rows(
188 window_frame: &Arc<WindowFrame>,
189 length: usize,
190 idx: usize,
191 ) -> Result<Range<usize>> {
192 let start = match window_frame.start_bound {
193 WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
195 WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
196 idx.saturating_sub(n as usize)
197 }
198 WindowFrameBound::CurrentRow => idx,
199 WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
201 return internal_err!(
202 "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'"
203 )
204 }
205 WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
206 std::cmp::min(idx + n as usize, length)
207 }
208 WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
210 return internal_err!("Rows should be Uint")
211 }
212 };
213 let end = match window_frame.end_bound {
214 WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
216 return internal_err!(
217 "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'"
218 )
219 }
220 WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
221 if idx >= n as usize {
222 idx - n as usize + 1
223 } else {
224 0
225 }
226 }
227 WindowFrameBound::CurrentRow => idx + 1,
228 WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
230 WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
231 std::cmp::min(idx + n as usize + 1, length)
232 }
233 WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
235 return internal_err!("Rows should be Uint")
236 }
237 };
238 Ok(Range { start, end })
239 }
240}
241
/// Buffered input and progress flags tracked for one partition.
#[derive(Debug)]
pub struct PartitionBatchState {
    /// Rows buffered so far; `extend` concatenates every new batch onto it.
    pub record_batch: RecordBatch,
    /// Batch stored by `set_most_recent_row`; `None` until that is called.
    // NOTE(review): presumably a single-row batch holding the latest input row
    // seen for this partition -- confirm against the caller.
    pub most_recent_row: Option<RecordBatch>,
    /// Flag copied into `WindowAggState::is_end` by `WindowAggState::update`.
    // NOTE(review): presumably set once no more input will arrive for this
    // partition -- confirm where it is written.
    pub is_end: bool,
    /// Row counter initialised to `0` by `new`; it is not modified anywhere in
    /// this part of the file.
    // NOTE(review): presumably the number of rows already emitted for this
    // partition -- confirm against its users.
    pub n_out_row: usize,
}
257
258impl PartitionBatchState {
259 pub fn new(schema: SchemaRef) -> Self {
260 Self {
261 record_batch: RecordBatch::new_empty(schema),
262 most_recent_row: None,
263 is_end: false,
264 n_out_row: 0,
265 }
266 }
267
268 pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> {
269 self.record_batch =
270 concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?;
271 Ok(())
272 }
273
274 pub fn set_most_recent_row(&mut self, batch: RecordBatch) {
275 self.most_recent_row = Some(batch);
278 }
279}
280
/// State used to compute `RANGE` frame boundaries.
///
/// Boundaries are located by searching the ordering columns for the current
/// row's values shifted by the frame offset.
#[derive(Debug, Default)]
pub struct WindowFrameStateRange {
    /// Sort options of the ordering columns. They drive row comparison, and
    /// the first entry's `descending` flag decides whether a frame offset is
    /// added to or subtracted from the current row's values.
    sort_options: Vec<SortOptions>,
}
289
290impl WindowFrameStateRange {
291 fn new(sort_options: Vec<SortOptions>) -> Self {
293 Self { sort_options }
294 }
295
296 fn calculate_range(
302 &mut self,
303 window_frame: &Arc<WindowFrame>,
304 last_range: &Range<usize>,
305 range_columns: &[ArrayRef],
306 length: usize,
307 idx: usize,
308 ) -> Result<Range<usize>> {
309 let start = match window_frame.start_bound {
310 WindowFrameBound::Preceding(ref n) => {
311 if n.is_null() {
312 0
314 } else {
315 self.calculate_index_of_row::<true, true>(
316 range_columns,
317 last_range,
318 idx,
319 Some(n),
320 length,
321 )?
322 }
323 }
324 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
325 range_columns,
326 last_range,
327 idx,
328 None,
329 length,
330 )?,
331 WindowFrameBound::Following(ref n) => self
332 .calculate_index_of_row::<true, false>(
333 range_columns,
334 last_range,
335 idx,
336 Some(n),
337 length,
338 )?,
339 };
340 let end = match window_frame.end_bound {
341 WindowFrameBound::Preceding(ref n) => self
342 .calculate_index_of_row::<false, true>(
343 range_columns,
344 last_range,
345 idx,
346 Some(n),
347 length,
348 )?,
349 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
350 range_columns,
351 last_range,
352 idx,
353 None,
354 length,
355 )?,
356 WindowFrameBound::Following(ref n) => {
357 if n.is_null() {
358 length
360 } else {
361 self.calculate_index_of_row::<false, false>(
362 range_columns,
363 last_range,
364 idx,
365 Some(n),
366 length,
367 )?
368 }
369 }
370 };
371 Ok(Range { start, end })
372 }
373
374 fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
378 &mut self,
379 range_columns: &[ArrayRef],
380 last_range: &Range<usize>,
381 idx: usize,
382 delta: Option<&ScalarValue>,
383 length: usize,
384 ) -> Result<usize> {
385 let current_row_values = get_row_at_idx(range_columns, idx)?;
386 let end_range = if let Some(delta) = delta {
387 let is_descending: bool = self
388 .sort_options
389 .first()
390 .ok_or_else(|| {
391 DataFusionError::Internal(
392 "Sort options unexpectedly absent in a window frame".to_string(),
393 )
394 })?
395 .descending;
396
397 current_row_values
398 .iter()
399 .map(|value| {
400 if value.is_null() {
401 return Ok(value.clone());
402 }
403 if SEARCH_SIDE == is_descending {
404 value.add(delta)
406 } else if value.is_unsigned() && value < delta {
407 value.sub(value)
411 } else {
412 value.sub(delta)
414 }
415 })
416 .collect::<Result<Vec<ScalarValue>>>()?
417 } else {
418 current_row_values
419 };
420 let search_start = if SIDE {
421 last_range.start
422 } else {
423 last_range.end
424 };
425 let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
426 let cmp = compare_rows(current, target, &self.sort_options)?;
427 Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
428 };
429 search_in_slice(range_columns, &end_range, compare_fn, search_start, length)
430 }
431}
432
/// State used to compute `GROUPS` frame boundaries.
///
/// Peer groups (runs of consecutive rows with equal ordering-column values)
/// are discovered lazily while rows are processed and are remembered here.
#[derive(Debug, Default)]
pub struct WindowFrameStateGroups {
    /// One entry per discovered group, in row order: the group's key row and
    /// the index one past its last row. Because each group starts where the
    /// previous one ends, the end index of entry `i` is also the start index
    /// of entry `i + 1`.
    pub group_end_indices: VecDeque<(Vec<ScalarValue>, usize)>,
    /// Index into `group_end_indices` of the group that contains the row most
    /// recently passed to `calculate_range`.
    pub current_group_idx: usize,
}
466
467impl WindowFrameStateGroups {
468 fn calculate_range(
469 &mut self,
470 window_frame: &Arc<WindowFrame>,
471 range_columns: &[ArrayRef],
472 length: usize,
473 idx: usize,
474 ) -> Result<Range<usize>> {
475 let start = match window_frame.start_bound {
476 WindowFrameBound::Preceding(ref n) => {
477 if n.is_null() {
478 0
480 } else {
481 self.calculate_index_of_row::<true, true>(
482 range_columns,
483 idx,
484 Some(n),
485 length,
486 )?
487 }
488 }
489 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<true, true>(
490 range_columns,
491 idx,
492 None,
493 length,
494 )?,
495 WindowFrameBound::Following(ref n) => self
496 .calculate_index_of_row::<true, false>(
497 range_columns,
498 idx,
499 Some(n),
500 length,
501 )?,
502 };
503 let end = match window_frame.end_bound {
504 WindowFrameBound::Preceding(ref n) => self
505 .calculate_index_of_row::<false, true>(
506 range_columns,
507 idx,
508 Some(n),
509 length,
510 )?,
511 WindowFrameBound::CurrentRow => self.calculate_index_of_row::<false, false>(
512 range_columns,
513 idx,
514 None,
515 length,
516 )?,
517 WindowFrameBound::Following(ref n) => {
518 if n.is_null() {
519 length
521 } else {
522 self.calculate_index_of_row::<false, false>(
523 range_columns,
524 idx,
525 Some(n),
526 length,
527 )?
528 }
529 }
530 };
531 Ok(Range { start, end })
532 }
533
534 fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
539 &mut self,
540 range_columns: &[ArrayRef],
541 idx: usize,
542 delta: Option<&ScalarValue>,
543 length: usize,
544 ) -> Result<usize> {
545 let delta = if let Some(delta) = delta {
546 if let ScalarValue::UInt64(Some(value)) = delta {
547 *value as usize
548 } else {
549 return internal_err!(
550 "Unexpectedly got a non-UInt64 value in a GROUPS mode window frame"
551 );
552 }
553 } else {
554 0
555 };
556 let mut group_start = 0;
557 let last_group = self.group_end_indices.back_mut();
558 if let Some((group_row, group_end)) = last_group {
559 if *group_end < length {
560 let new_group_row = get_row_at_idx(range_columns, *group_end)?;
561 if new_group_row.eq(group_row) {
563 *group_end = search_in_slice(
565 range_columns,
566 group_row,
567 check_equality,
568 *group_end,
569 length,
570 )?;
571 }
572 }
573 group_start = *group_end;
575 }
576
577 while idx >= group_start {
579 let group_row = get_row_at_idx(range_columns, group_start)?;
580 let group_end = search_in_slice(
582 range_columns,
583 &group_row,
584 check_equality,
585 group_start,
586 length,
587 )?;
588 self.group_end_indices.push_back((group_row, group_end));
589 group_start = group_end;
590 }
591
592 while self.current_group_idx < self.group_end_indices.len()
594 && idx >= self.group_end_indices[self.current_group_idx].1
595 {
596 self.current_group_idx += 1;
597 }
598
599 let group_idx = if SEARCH_SIDE {
601 self.current_group_idx.saturating_sub(delta)
602 } else {
603 self.current_group_idx + delta
604 };
605
606 while self.group_end_indices.len() <= group_idx && group_start < length {
608 let group_row = get_row_at_idx(range_columns, group_start)?;
609 let group_end = search_in_slice(
611 range_columns,
612 &group_row,
613 check_equality,
614 group_start,
615 length,
616 )?;
617 self.group_end_indices.push_back((group_row, group_end));
618 group_start = group_end;
619 }
620
621 Ok(match (SIDE, SEARCH_SIDE) {
623 (true, _) => {
625 let group_idx = std::cmp::min(group_idx, self.group_end_indices.len());
626 if group_idx > 0 {
627 self.group_end_indices[group_idx - 1].1
629 } else {
630 0
632 }
633 }
634 (false, true) => {
636 if self.current_group_idx >= delta {
637 let group_idx = self.current_group_idx - delta;
638 self.group_end_indices[group_idx].1
639 } else {
640 0
642 }
643 }
644 (false, false) => {
646 let group_idx = std::cmp::min(
647 self.current_group_idx + delta,
648 self.group_end_indices.len() - 1,
649 );
650 self.group_end_indices[group_idx].1
651 }
652 })
653 }
654}
655
656fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<bool> {
657 Ok(current == target)
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Float64Array;

    /// Returns one ascending `Float64` ordering column containing several
    /// runs of equal values, together with matching sort options.
    fn get_test_data() -> (Vec<ArrayRef>, Vec<SortOptions>) {
        let ordering_column: ArrayRef = Arc::new(Float64Array::from(vec![
            5.0, 7.0, 8.0, 8.0, 9., 10., 10., 10., 11.,
        ]));
        let ascending_nulls_last = SortOptions {
            descending: false,
            nulls_first: false,
        };

        (vec![ordering_column], vec![ascending_nulls_last])
    }

    /// `n PRECEDING` bound for a `GROUPS` frame.
    fn preceding(n: u64) -> WindowFrameBound {
        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n)))
    }

    /// `n FOLLOWING` bound for a `GROUPS` frame.
    fn following(n: u64) -> WindowFrameBound {
        WindowFrameBound::Following(ScalarValue::UInt64(Some(n)))
    }

    /// Builds a `GROUPS` window frame with the given bounds.
    fn groups_frame(start: WindowFrameBound, end: WindowFrameBound) -> Arc<WindowFrame> {
        Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Groups, start, end))
    }

    /// Walks the test rows in order and checks, for each row, both the
    /// computed frame range and the group index recorded in the state.
    fn assert_expected(
        expected_results: Vec<(Range<usize>, usize)>,
        window_frame: &Arc<WindowFrame>,
    ) -> Result<()> {
        let (range_columns, _) = get_test_data();
        let n_row = range_columns[0].len();
        let mut groups_state = WindowFrameStateGroups::default();
        for (row_idx, (want_range, want_group_idx)) in
            expected_results.into_iter().enumerate()
        {
            let got_range = groups_state.calculate_range(
                window_frame,
                &range_columns,
                n_row,
                row_idx,
            )?;
            assert_eq!(got_range, want_range);
            assert_eq!(groups_state.current_group_idx, want_group_idx);
        }
        Ok(())
    }

    #[test]
    fn test_window_frame_group_boundaries() -> Result<()> {
        let window_frame = groups_frame(preceding(1), following(1));
        let expected_results = vec![
            (0..2, 0),
            (0..4, 1),
            (1..5, 2),
            (1..5, 2),
            (2..8, 3),
            (4..9, 4),
            (4..9, 4),
            (4..9, 4),
            (5..9, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }

    #[test]
    fn test_window_frame_group_boundaries_both_following() -> Result<()> {
        let window_frame = groups_frame(following(1), following(2));
        let expected_results = vec![
            (1..4, 0),
            (2..5, 1),
            (4..8, 2),
            (4..8, 2),
            (5..9, 3),
            (8..9, 4),
            (8..9, 4),
            (8..9, 4),
            (9..9, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }

    #[test]
    fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
        let window_frame = groups_frame(preceding(2), preceding(1));
        let expected_results = vec![
            (0..0, 0),
            (0..1, 1),
            (0..2, 2),
            (0..2, 2),
            (1..4, 3),
            (2..5, 4),
            (2..5, 4),
            (2..5, 4),
            (4..8, 5),
        ];
        assert_expected(expected_results, &window_frame)
    }
}