1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
29 Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use std::any::Any;
36use std::cmp::min;
37use std::collections::VecDeque;
38use std::ops::{Neg, Range};
39use std::sync::{Arc, LazyLock};
40
// Registers the `Lag` window UDF, whose implementation is built by
// `WindowShift::lag`. `lag_udwf()` is called further down but is not defined
// in this file view — presumably the macro generates it; TODO confirm against
// the macro definition.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Registers the `Lead` window UDF, built by `WindowShift::lead`; presumably
// also generates the `lead_udwf()` accessor used further down.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
57
58pub fn lag(
65 arg: datafusion_expr::Expr,
66 shift_offset: Option<i64>,
67 default_value: Option<ScalarValue>,
68) -> datafusion_expr::Expr {
69 let shift_offset_lit = shift_offset
70 .map(|v| v.lit())
71 .unwrap_or(ScalarValue::Null.lit());
72 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
73
74 lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
75}
76
77pub fn lead(
84 arg: datafusion_expr::Expr,
85 shift_offset: Option<i64>,
86 default_value: Option<ScalarValue>,
87) -> datafusion_expr::Expr {
88 let shift_offset_lit = shift_offset
89 .map(|v| v.lit())
90 .unwrap_or(ScalarValue::Null.lit());
91 let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
92
93 lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
94}
95
/// Which direction a [`WindowShift`] moves in.
#[derive(Debug)]
enum WindowShiftKind {
    /// `lag`: look at preceding rows.
    Lag,
    /// `lead`: look at following rows.
    Lead,
}

impl WindowShiftKind {
    /// The SQL function name for this kind.
    fn name(&self) -> &'static str {
        match self {
            Self::Lag => "lag",
            Self::Lead => "lead",
        }
    }

    /// Turns the user-supplied offset into the signed offset stored by the
    /// evaluator: positive means "look back" and negative means "look ahead".
    /// A missing offset means one row.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        let rows = value.unwrap_or(1);
        match self {
            Self::Lag => rows,
            Self::Lead => rows.neg(),
        }
    }
}
120
121#[derive(Debug)]
123pub struct WindowShift {
124 signature: Signature,
125 kind: WindowShiftKind,
126}
127
128impl WindowShift {
129 fn new(kind: WindowShiftKind) -> Self {
130 Self {
131 signature: Signature::one_of(
132 vec![
133 TypeSignature::Any(1),
134 TypeSignature::Any(2),
135 TypeSignature::Any(3),
136 ],
137 Volatility::Immutable,
138 ),
139 kind,
140 }
141 }
142
143 pub fn lag() -> Self {
144 Self::new(WindowShiftKind::Lag)
145 }
146
147 pub fn lead() -> Self {
148 Self::new(WindowShiftKind::Lead)
149 }
150}
151
// User-facing documentation for the `lag` window function: description,
// argument list and SQL example. Built on first access through the
// `LazyLock` and handed out by `get_lag_doc`.
static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).", "lag(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows back \
        the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"```sql
 --Example usage of the lag window function:
 SELECT employee_id,
 salary,
 lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
 FROM employees;
```

```sql
+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1 | 30000 | 0 |
| 2 | 50000 | 30000 |
| 3 | 70000 | 50000 |
| 4 | 60000 | 70000 |
+-------------+--------+-------------+
```"#)
        .build()
});
181
182fn get_lag_doc() -> &'static Documentation {
183 &LAG_DOCUMENTATION
184}
185
// User-facing documentation for the `lead` window function: description,
// argument list and SQL example. Built on first access through the
// `LazyLock` and handed out by `get_lead_doc`.
static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL,
            "Returns value evaluated at the row that is offset rows after the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).",
        "lead(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows \
        forward the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"```sql
-- Example usage of lead() :
SELECT
 employee_id,
 department,
 salary,
 lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
FROM employees;
```

```sql
+-------------+-------------+--------+--------------+
| employee_id | department | salary | next_salary |
+-------------+-------------+--------+--------------+
| 1 | Sales | 30000 | 50000 |
| 2 | Sales | 50000 | 70000 |
| 3 | Sales | 70000 | 0 |
| 4 | Engineering | 40000 | 60000 |
| 5 | Engineering | 60000 | 0 |
+-------------+-------------+--------+--------------+
```"#)
        .build()
});
220
221fn get_lead_doc() -> &'static Documentation {
222 &LEAD_DOCUMENTATION
223}
224
225impl WindowUDFImpl for WindowShift {
226 fn as_any(&self) -> &dyn Any {
227 self
228 }
229
230 fn name(&self) -> &str {
231 self.kind.name()
232 }
233
234 fn signature(&self) -> &Signature {
235 &self.signature
236 }
237
238 fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
244 parse_expr(expr_args.input_exprs(), expr_args.input_fields())
245 .into_iter()
246 .collect::<Vec<_>>()
247 }
248
249 fn partition_evaluator(
250 &self,
251 partition_evaluator_args: PartitionEvaluatorArgs,
252 ) -> Result<Box<dyn PartitionEvaluator>> {
253 let shift_offset =
254 get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
255 .map(get_signed_integer)
256 .map_or(Ok(None), |v| v.map(Some))
257 .map(|n| self.kind.shift_offset(n))
258 .map(|offset| {
259 if partition_evaluator_args.is_reversed() {
260 -offset
261 } else {
262 offset
263 }
264 })?;
265 let default_value = parse_default_value(
266 partition_evaluator_args.input_exprs(),
267 partition_evaluator_args.input_fields(),
268 )?;
269
270 Ok(Box::new(WindowShiftEvaluator {
271 shift_offset,
272 default_value,
273 ignore_nulls: partition_evaluator_args.ignore_nulls(),
274 non_null_offsets: VecDeque::new(),
275 }))
276 }
277
278 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
279 let return_field = parse_expr_field(field_args.input_fields())?;
280
281 Ok(return_field
282 .as_ref()
283 .clone()
284 .with_name(field_args.name())
285 .into())
286 }
287
288 fn reverse_expr(&self) -> ReversedUDWF {
289 match self.kind {
290 WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
291 WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
292 }
293 }
294
295 fn documentation(&self) -> Option<&Documentation> {
296 match self.kind {
297 WindowShiftKind::Lag => Some(get_lag_doc()),
298 WindowShiftKind::Lead => Some(get_lead_doc()),
299 }
300 }
301}
302
303fn parse_expr(
316 input_exprs: &[Arc<dyn PhysicalExpr>],
317 input_fields: &[FieldRef],
318) -> Result<Arc<dyn PhysicalExpr>> {
319 assert!(!input_exprs.is_empty());
320 assert!(!input_fields.is_empty());
321
322 let expr = Arc::clone(input_exprs.first().unwrap());
323 let expr_field = input_fields.first().unwrap();
324
325 if !expr_field.data_type().is_null() {
327 return Ok(expr);
328 }
329
330 let default_value = get_scalar_value_from_args(input_exprs, 2)?;
331 default_value.map_or(Ok(expr), |value| {
332 ScalarValue::try_from(&value.data_type()).map(|v| {
333 Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
334 as Arc<dyn PhysicalExpr>
335 })
336 })
337}
338
339static NULL_FIELD: LazyLock<FieldRef> =
340 LazyLock::new(|| Field::new("value", DataType::Null, true).into());
341
342fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
347 assert!(!input_fields.is_empty());
348 let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
349
350 if !expr_field.data_type().is_null() {
352 return Ok(expr_field.as_ref().clone().with_nullable(true).into());
353 }
354
355 let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
356 Ok(default_value_field
357 .as_ref()
358 .clone()
359 .with_nullable(true)
360 .into())
361}
362
363fn parse_default_value(
366 input_exprs: &[Arc<dyn PhysicalExpr>],
367 input_types: &[FieldRef],
368) -> Result<ScalarValue> {
369 let expr_field = parse_expr_field(input_types)?;
370 let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
371
372 unparsed
373 .filter(|v| !v.data_type().is_null())
374 .map(|v| v.cast_to(expr_field.data_type()))
375 .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
376}
377
378#[derive(Debug)]
379struct WindowShiftEvaluator {
380 shift_offset: i64,
381 default_value: ScalarValue,
382 ignore_nulls: bool,
383 non_null_offsets: VecDeque<usize>,
385}
386
387impl WindowShiftEvaluator {
388 fn is_lag(&self) -> bool {
389 self.shift_offset > 0
391 }
392}
393
394fn evaluate_all_with_ignore_null(
396 array: &ArrayRef,
397 offset: i64,
398 default_value: &ScalarValue,
399 is_lag: bool,
400) -> Result<ArrayRef, DataFusionError> {
401 let valid_indices: Vec<usize> =
402 array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
403 let direction = !is_lag;
404 let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
405 .map(|id| {
406 let result_index = match valid_indices.binary_search(&id) {
407 Ok(pos) => if direction {
408 pos.checked_add(offset as usize)
409 } else {
410 pos.checked_sub(offset.unsigned_abs() as usize)
411 }
412 .and_then(|new_pos| {
413 if new_pos < valid_indices.len() {
414 Some(valid_indices[new_pos])
415 } else {
416 None
417 }
418 }),
419 Err(pos) => if direction {
420 pos.checked_add(offset as usize)
421 } else if pos > 0 {
422 pos.checked_sub(offset.unsigned_abs() as usize)
423 } else {
424 None
425 }
426 .and_then(|new_pos| {
427 if new_pos < valid_indices.len() {
428 Some(valid_indices[new_pos])
429 } else {
430 None
431 }
432 }),
433 };
434
435 match result_index {
436 Some(index) => ScalarValue::try_from_array(array, index),
437 None => Ok(default_value.clone()),
438 }
439 })
440 .collect();
441
442 let new_array = new_array_results?;
443 ScalarValue::iter_to_array(new_array)
444}
445fn shift_with_default_value(
447 array: &ArrayRef,
448 offset: i64,
449 default_value: &ScalarValue,
450) -> Result<ArrayRef> {
451 use datafusion_common::arrow::compute::concat;
452
453 let value_len = array.len() as i64;
454 if offset == 0 {
455 Ok(Arc::clone(array))
456 } else if offset == i64::MIN || offset.abs() >= value_len {
457 default_value.to_array_of_size(value_len as usize)
458 } else {
459 let slice_offset = (-offset).clamp(0, value_len) as usize;
460 let length = array.len() - offset.unsigned_abs() as usize;
461 let slice = array.slice(slice_offset, length);
462
463 let nulls = offset.unsigned_abs() as usize;
465 let default_values = default_value.to_array_of_size(nulls)?;
466
467 if offset > 0 {
469 concat(&[default_values.as_ref(), slice.as_ref()])
470 .map_err(|e| arrow_datafusion_err!(e))
471 } else {
472 concat(&[slice.as_ref(), default_values.as_ref()])
473 .map_err(|e| arrow_datafusion_err!(e))
474 }
475 }
476}
477
478impl PartitionEvaluator for WindowShiftEvaluator {
479 fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
480 if self.is_lag() {
481 let start = if self.non_null_offsets.len() == self.shift_offset as usize {
482 let offset: usize = self.non_null_offsets.iter().sum();
484 idx.saturating_sub(offset)
485 } else if !self.ignore_nulls {
486 let offset = self.shift_offset as usize;
487 idx.saturating_sub(offset)
488 } else {
489 0
490 };
491 let end = idx + 1;
492 Ok(Range { start, end })
493 } else {
494 let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
495 let offset: usize = self.non_null_offsets.iter().sum();
497 min(idx + offset + 1, n_rows)
498 } else if !self.ignore_nulls {
499 let offset = (-self.shift_offset) as usize;
500 min(idx + offset, n_rows)
501 } else {
502 n_rows
503 };
504 Ok(Range { start: idx, end })
505 }
506 }
507
508 fn is_causal(&self) -> bool {
509 self.is_lag()
511 }
512
513 fn evaluate(
514 &mut self,
515 values: &[ArrayRef],
516 range: &Range<usize>,
517 ) -> Result<ScalarValue> {
518 let array = &values[0];
519 let len = array.len();
520
521 let i = if self.is_lag() {
523 (range.end as i64 - self.shift_offset - 1) as usize
524 } else {
525 (range.start as i64 - self.shift_offset) as usize
527 };
528
529 let mut idx: Option<usize> = if i < len { Some(i) } else { None };
530
531 if self.ignore_nulls && self.is_lag() {
534 idx = if self.non_null_offsets.len() == self.shift_offset as usize {
537 let total_offset: usize = self.non_null_offsets.iter().sum();
538 Some(range.end - 1 - total_offset)
539 } else {
540 None
541 };
542
543 if array.is_valid(range.end - 1) {
545 self.non_null_offsets.push_back(1);
547 if self.non_null_offsets.len() > self.shift_offset as usize {
548 self.non_null_offsets.pop_front();
550 }
551 } else if !self.non_null_offsets.is_empty() {
552 let end_idx = self.non_null_offsets.len() - 1;
554 self.non_null_offsets[end_idx] += 1;
555 }
556 } else if self.ignore_nulls && !self.is_lag() {
557 let non_null_row_count = (-self.shift_offset) as usize;
560
561 if self.non_null_offsets.is_empty() {
562 let mut offset_val = 1;
564 for idx in range.start + 1..range.end {
565 if array.is_valid(idx) {
566 self.non_null_offsets.push_back(offset_val);
567 offset_val = 1;
568 } else {
569 offset_val += 1;
570 }
571 if self.non_null_offsets.len() == non_null_row_count + 1 {
574 break;
575 }
576 }
577 } else if range.end < len && array.is_valid(range.end) {
578 if array.is_valid(range.end) {
580 self.non_null_offsets.push_back(1);
582 } else {
583 let last_idx = self.non_null_offsets.len() - 1;
585 self.non_null_offsets[last_idx] += 1;
586 }
587 }
588
589 idx = if self.non_null_offsets.len() >= non_null_row_count {
591 let total_offset: usize =
592 self.non_null_offsets.iter().take(non_null_row_count).sum();
593 Some(range.start + total_offset)
594 } else {
595 None
596 };
597 if !self.non_null_offsets.is_empty() {
600 self.non_null_offsets[0] -= 1;
601 if self.non_null_offsets[0] == 0 {
602 self.non_null_offsets.pop_front();
604 }
605 }
606 }
607
608 #[allow(clippy::unnecessary_unwrap)]
614 if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
615 ScalarValue::try_from_array(array, idx.unwrap())
616 } else {
617 Ok(self.default_value.clone())
618 }
619 }
620
621 fn evaluate_all(
622 &mut self,
623 values: &[ArrayRef],
624 _num_rows: usize,
625 ) -> Result<ArrayRef> {
626 let value = &values[0];
628 if !self.ignore_nulls {
629 shift_with_default_value(value, self.shift_offset, &self.default_value)
630 } else {
631 evaluate_all_with_ignore_null(
632 value,
633 self.shift_offset,
634 &self.default_value,
635 self.is_lag(),
636 )
637 }
638 }
639
640 fn supports_bounded_execution(&self) -> bool {
641 true
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Evaluates `expr` over a fixed `Int32` column and compares the result
    /// with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // `evaluate_all` takes the number of rows, not the number of columns.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    /// Shifting by the full length or more (in either direction, including
    /// `i64::MIN`) leaves only default values.
    #[test]
    fn test_shift_beyond_length_yields_defaults() -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let default_value = ScalarValue::Int32(Some(0));
        for offset in [3, -3, 10, -10, i64::MIN] {
            let result = shift_with_default_value(&arr, offset, &default_value)?;
            assert_eq!(*as_int32_array(&result)?, Int32Array::from(vec![0, 0, 0]));
        }
        Ok(())
    }
}