datafusion_functions_window/
lead_lag.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `lead` and `lag` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21use arrow::datatypes::FieldRef;
22use datafusion_common::arrow::array::ArrayRef;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_common::arrow::datatypes::Field;
25use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::{
28    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
29    Volatility, WindowUDFImpl,
30};
31use datafusion_functions_window_common::expr::ExpressionArgs;
32use datafusion_functions_window_common::field::WindowUDFFieldArgs;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
35use std::any::Any;
36use std::cmp::min;
37use std::collections::VecDeque;
38use std::ops::{Neg, Range};
39use std::sync::{Arc, LazyLock};
40
// Declares the shared `lag` window UDF via a crate-level macro (its definition
// is not in this file). It presumably generates the `lag_udwf()` accessor that
// is called below, built from `WindowShift::lag` — TODO confirm against the
// macro definition.
get_or_init_udwf!(
    Lag,
    lag,
    "Returns the row value that precedes the current row by a specified \
    offset within partition. If no such row exists, then returns the \
    default value.",
    WindowShift::lag
);
// Same as above for `lead`; presumably generates the `lead_udwf()` accessor
// built from `WindowShift::lead`.
get_or_init_udwf!(
    Lead,
    lead,
    "Returns the value from a row that follows the current row by a \
    specified offset within the partition. If no such row exists, then \
    returns the default value.",
    WindowShift::lead
);
57
58/// Create an expression to represent the `lag` window function
59///
60/// returns value evaluated at the row that is offset rows before the current row within the partition;
61/// if there is no such row, instead return default (which must be of the same type as value).
62/// Both offset and default are evaluated with respect to the current row.
63/// If omitted, offset defaults to 1 and default to null
64pub fn lag(
65    arg: datafusion_expr::Expr,
66    shift_offset: Option<i64>,
67    default_value: Option<ScalarValue>,
68) -> datafusion_expr::Expr {
69    let shift_offset_lit = shift_offset
70        .map(|v| v.lit())
71        .unwrap_or(ScalarValue::Null.lit());
72    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
73
74    lag_udwf().call(vec![arg, shift_offset_lit, default_lit])
75}
76
77/// Create an expression to represent the `lead` window function
78///
79/// returns value evaluated at the row that is offset rows after the current row within the partition;
80/// if there is no such row, instead return default (which must be of the same type as value).
81/// Both offset and default are evaluated with respect to the current row.
82/// If omitted, offset defaults to 1 and default to null
83pub fn lead(
84    arg: datafusion_expr::Expr,
85    shift_offset: Option<i64>,
86    default_value: Option<ScalarValue>,
87) -> datafusion_expr::Expr {
88    let shift_offset_lit = shift_offset
89        .map(|v| v.lit())
90        .unwrap_or(ScalarValue::Null.lit());
91    let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
92
93    lead_udwf().call(vec![arg, shift_offset_lit, default_lit])
94}
95
/// Selects which of the two shift functions a `WindowShift` implements.
#[derive(Debug)]
enum WindowShiftKind {
    Lag,
    Lead,
}

impl WindowShiftKind {
    /// Function name as exposed to SQL.
    fn name(&self) -> &'static str {
        match self {
            Self::Lag => "lag",
            Self::Lead => "lead",
        }
    }

    /// Converts the user-supplied offset (default 1) into the signed offset
    /// used by `WindowShiftEvaluator`.
    ///
    /// The evaluator treats a positive offset as `lag`. The requested value is
    /// therefore kept as-is for `lag` and negated for `lead`.
    fn shift_offset(&self, value: Option<i64>) -> i64 {
        let requested = value.unwrap_or(1);
        match self {
            Self::Lag => requested,
            Self::Lead => requested.neg(),
        }
    }
}
120
121/// window shift expression
122#[derive(Debug)]
123pub struct WindowShift {
124    signature: Signature,
125    kind: WindowShiftKind,
126}
127
128impl WindowShift {
129    fn new(kind: WindowShiftKind) -> Self {
130        Self {
131            signature: Signature::one_of(
132                vec![
133                    TypeSignature::Any(1),
134                    TypeSignature::Any(2),
135                    TypeSignature::Any(3),
136                ],
137                Volatility::Immutable,
138            ),
139            kind,
140        }
141    }
142
143    pub fn lag() -> Self {
144        Self::new(WindowShiftKind::Lag)
145    }
146
147    pub fn lead() -> Self {
148        Self::new(WindowShiftKind::Lead)
149    }
150}
151
/// Lazily built user documentation for `lag`, returned by `get_lag_doc`.
///
/// It holds the description, the argument docs and a SQL example. These
/// strings are user-facing documentation text. The raw SQL example string is
/// whitespace-sensitive, so edit it with care.
static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL, "Returns value evaluated at the row that is offset rows before the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).", "lag(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows back \
        the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"```sql
    --Example usage of the lag window function:
    SELECT employee_id,
           salary,
           lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
    FROM employees;
```

```sql
+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1           | 30000  | 0           |
| 2           | 50000  | 30000       |
| 3           | 70000  | 50000       |
| 4           | 60000  | 70000       |
+-------------+--------+-------------+
```"#)
        .build()
});
181
182fn get_lag_doc() -> &'static Documentation {
183    &LAG_DOCUMENTATION
184}
185
/// Lazily built user documentation for `lead`, returned by `get_lead_doc`.
///
/// It holds the description, the argument docs and a SQL example. These
/// strings are user-facing documentation text. The raw SQL example string is
/// whitespace-sensitive, so edit it with care.
static LEAD_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
    Documentation::builder(DOC_SECTION_ANALYTICAL,
            "Returns value evaluated at the row that is offset rows after the \
            current row within the partition; if there is no such row, instead return default \
            (which must be of the same type as value).",
        "lead(expression, offset, default)")
        .with_argument("expression", "Expression to operate on")
        .with_argument("offset", "Integer. Specifies how many rows \
        forward the value of expression should be retrieved. Defaults to 1.")
        .with_argument("default", "The default value if the offset is \
        not within the partition. Must be of the same type as expression.")
        .with_sql_example(r#"```sql
-- Example usage of lead() :
SELECT
    employee_id,
    department,
    salary,
    lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
FROM employees;
```

```sql
+-------------+-------------+--------+--------------+
| employee_id | department  | salary | next_salary  |
+-------------+-------------+--------+--------------+
| 1           | Sales       | 30000  | 50000        |
| 2           | Sales       | 50000  | 70000        |
| 3           | Sales       | 70000  | 0            |
| 4           | Engineering | 40000  | 60000        |
| 5           | Engineering | 60000  | 0            |
+-------------+-------------+--------+--------------+
```"#)
        .build()
});
220
221fn get_lead_doc() -> &'static Documentation {
222    &LEAD_DOCUMENTATION
223}
224
225impl WindowUDFImpl for WindowShift {
226    fn as_any(&self) -> &dyn Any {
227        self
228    }
229
230    fn name(&self) -> &str {
231        self.kind.name()
232    }
233
234    fn signature(&self) -> &Signature {
235        &self.signature
236    }
237
238    /// Handles the case where `NULL` expression is passed as an
239    /// argument to `lead`/`lag`. The type is refined depending
240    /// on the default value argument.
241    ///
242    /// For more details see: <https://github.com/apache/datafusion/issues/12717>
243    fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
244        parse_expr(expr_args.input_exprs(), expr_args.input_fields())
245            .into_iter()
246            .collect::<Vec<_>>()
247    }
248
249    fn partition_evaluator(
250        &self,
251        partition_evaluator_args: PartitionEvaluatorArgs,
252    ) -> Result<Box<dyn PartitionEvaluator>> {
253        let shift_offset =
254            get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)?
255                .map(get_signed_integer)
256                .map_or(Ok(None), |v| v.map(Some))
257                .map(|n| self.kind.shift_offset(n))
258                .map(|offset| {
259                    if partition_evaluator_args.is_reversed() {
260                        -offset
261                    } else {
262                        offset
263                    }
264                })?;
265        let default_value = parse_default_value(
266            partition_evaluator_args.input_exprs(),
267            partition_evaluator_args.input_fields(),
268        )?;
269
270        Ok(Box::new(WindowShiftEvaluator {
271            shift_offset,
272            default_value,
273            ignore_nulls: partition_evaluator_args.ignore_nulls(),
274            non_null_offsets: VecDeque::new(),
275        }))
276    }
277
278    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
279        let return_field = parse_expr_field(field_args.input_fields())?;
280
281        Ok(return_field
282            .as_ref()
283            .clone()
284            .with_name(field_args.name())
285            .into())
286    }
287
288    fn reverse_expr(&self) -> ReversedUDWF {
289        match self.kind {
290            WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()),
291            WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()),
292        }
293    }
294
295    fn documentation(&self) -> Option<&Documentation> {
296        match self.kind {
297            WindowShiftKind::Lag => Some(get_lag_doc()),
298            WindowShiftKind::Lead => Some(get_lead_doc()),
299        }
300    }
301}
302
303/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to
304/// refine it by matching it with the type of the default value.
305///
306/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null`
307/// is refined into `ScalarValue::Boolean(None)`. Only the type is
308/// refined, the expression value remains `NULL`.
309///
310/// When the window function is evaluated with `NULL` expression
311/// this guarantees that the type matches with that of the default
312/// value.
313///
314/// For more details see: <https://github.com/apache/datafusion/issues/12717>
315fn parse_expr(
316    input_exprs: &[Arc<dyn PhysicalExpr>],
317    input_fields: &[FieldRef],
318) -> Result<Arc<dyn PhysicalExpr>> {
319    assert!(!input_exprs.is_empty());
320    assert!(!input_fields.is_empty());
321
322    let expr = Arc::clone(input_exprs.first().unwrap());
323    let expr_field = input_fields.first().unwrap();
324
325    // Handles the most common case where NULL is unexpected
326    if !expr_field.data_type().is_null() {
327        return Ok(expr);
328    }
329
330    let default_value = get_scalar_value_from_args(input_exprs, 2)?;
331    default_value.map_or(Ok(expr), |value| {
332        ScalarValue::try_from(&value.data_type()).map(|v| {
333            Arc::new(datafusion_physical_expr::expressions::Literal::new(v))
334                as Arc<dyn PhysicalExpr>
335        })
336    })
337}
338
339static NULL_FIELD: LazyLock<FieldRef> =
340    LazyLock::new(|| Field::new("value", DataType::Null, true).into());
341
342/// Returns the field of the default value(if provided) when the
343/// expression is `NULL`.
344///
345/// Otherwise, returns the expression field unchanged.
346fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
347    assert!(!input_fields.is_empty());
348    let expr_field = input_fields.first().unwrap_or(&NULL_FIELD);
349
350    // Handles the most common case where NULL is unexpected
351    if !expr_field.data_type().is_null() {
352        return Ok(expr_field.as_ref().clone().with_nullable(true).into());
353    }
354
355    let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD);
356    Ok(default_value_field
357        .as_ref()
358        .clone()
359        .with_nullable(true)
360        .into())
361}
362
363/// Handles type coercion and null value refinement for default value
364/// argument depending on the data type of the input expression.
365fn parse_default_value(
366    input_exprs: &[Arc<dyn PhysicalExpr>],
367    input_types: &[FieldRef],
368) -> Result<ScalarValue> {
369    let expr_field = parse_expr_field(input_types)?;
370    let unparsed = get_scalar_value_from_args(input_exprs, 2)?;
371
372    unparsed
373        .filter(|v| !v.data_type().is_null())
374        .map(|v| v.cast_to(expr_field.data_type()))
375        .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
376}
377
378#[derive(Debug)]
379struct WindowShiftEvaluator {
380    shift_offset: i64,
381    default_value: ScalarValue,
382    ignore_nulls: bool,
383    // VecDeque contains offset values that between non-null entries
384    non_null_offsets: VecDeque<usize>,
385}
386
387impl WindowShiftEvaluator {
388    fn is_lag(&self) -> bool {
389        // Mode is LAG, when shift_offset is positive
390        self.shift_offset > 0
391    }
392}
393
394// implement ignore null for evaluate_all
395fn evaluate_all_with_ignore_null(
396    array: &ArrayRef,
397    offset: i64,
398    default_value: &ScalarValue,
399    is_lag: bool,
400) -> Result<ArrayRef, DataFusionError> {
401    let valid_indices: Vec<usize> =
402        array.nulls().unwrap().valid_indices().collect::<Vec<_>>();
403    let direction = !is_lag;
404    let new_array_results: Result<Vec<_>, DataFusionError> = (0..array.len())
405        .map(|id| {
406            let result_index = match valid_indices.binary_search(&id) {
407                Ok(pos) => if direction {
408                    pos.checked_add(offset as usize)
409                } else {
410                    pos.checked_sub(offset.unsigned_abs() as usize)
411                }
412                .and_then(|new_pos| {
413                    if new_pos < valid_indices.len() {
414                        Some(valid_indices[new_pos])
415                    } else {
416                        None
417                    }
418                }),
419                Err(pos) => if direction {
420                    pos.checked_add(offset as usize)
421                } else if pos > 0 {
422                    pos.checked_sub(offset.unsigned_abs() as usize)
423                } else {
424                    None
425                }
426                .and_then(|new_pos| {
427                    if new_pos < valid_indices.len() {
428                        Some(valid_indices[new_pos])
429                    } else {
430                        None
431                    }
432                }),
433            };
434
435            match result_index {
436                Some(index) => ScalarValue::try_from_array(array, index),
437                None => Ok(default_value.clone()),
438            }
439        })
440        .collect();
441
442    let new_array = new_array_results?;
443    ScalarValue::iter_to_array(new_array)
444}
445// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
446fn shift_with_default_value(
447    array: &ArrayRef,
448    offset: i64,
449    default_value: &ScalarValue,
450) -> Result<ArrayRef> {
451    use datafusion_common::arrow::compute::concat;
452
453    let value_len = array.len() as i64;
454    if offset == 0 {
455        Ok(Arc::clone(array))
456    } else if offset == i64::MIN || offset.abs() >= value_len {
457        default_value.to_array_of_size(value_len as usize)
458    } else {
459        let slice_offset = (-offset).clamp(0, value_len) as usize;
460        let length = array.len() - offset.unsigned_abs() as usize;
461        let slice = array.slice(slice_offset, length);
462
463        // Generate array with remaining `null` items
464        let nulls = offset.unsigned_abs() as usize;
465        let default_values = default_value.to_array_of_size(nulls)?;
466
467        // Concatenate both arrays, add nulls after if shift > 0 else before
468        if offset > 0 {
469            concat(&[default_values.as_ref(), slice.as_ref()])
470                .map_err(|e| arrow_datafusion_err!(e))
471        } else {
472            concat(&[slice.as_ref(), default_values.as_ref()])
473                .map_err(|e| arrow_datafusion_err!(e))
474        }
475    }
476}
477
/// Evaluates `lag`/`lead` either row by row (`get_range` + `evaluate`) or over
/// a whole partition at once (`evaluate_all`).
impl PartitionEvaluator for WindowShiftEvaluator {
    /// Returns the range of rows that must be available to evaluate row `idx`
    /// of a partition with `n_rows` rows.
    ///
    /// For `lag` the range ends just after the current row (`idx + 1`); for
    /// `lead` it starts at the current row. With `IGNORE NULLS`, the tracked
    /// `non_null_offsets` size the range once enough non-null rows are known.
    /// Until then the range reaches the partition start (`lag`) or end
    /// (`lead`).
    fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
        if self.is_lag() {
            let start = if self.non_null_offsets.len() == self.shift_offset as usize {
                // Number of rows before the current row needed to reach the
                // `shift_offset`-th previous non-null row.
                let offset: usize = self.non_null_offsets.iter().sum();
                idx.saturating_sub(offset)
            } else if !self.ignore_nulls {
                // Without IGNORE NULLS exactly `shift_offset` rows back are needed.
                let offset = self.shift_offset as usize;
                idx.saturating_sub(offset)
            } else {
                // Not enough non-null rows tracked yet: look back to the start.
                0
            };
            let end = idx + 1;
            Ok(Range { start, end })
        } else {
            let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize {
                // Number of rows after the current row needed to reach the
                // `|shift_offset|`-th next non-null row.
                let offset: usize = self.non_null_offsets.iter().sum();
                min(idx + offset + 1, n_rows)
            } else if !self.ignore_nulls {
                let offset = (-self.shift_offset) as usize;
                min(idx + offset, n_rows)
            } else {
                // Not enough non-null rows tracked yet: look ahead to the end.
                n_rows
            };
            Ok(Range { start: idx, end })
        }
    }

    /// `lag` only reads the current and earlier rows (see `get_range`).
    fn is_causal(&self) -> bool {
        // Lagging windows are causal by definition:
        self.is_lag()
    }

    /// Computes the shifted value for a single row.
    ///
    /// The current row is `range.end - 1` for `lag` and `range.start` for
    /// `lead`, matching the ranges produced by `get_range`. `values[0]` holds
    /// the shifted expression. With `IGNORE NULLS`, this method also updates
    /// `non_null_offsets`, so it must be called once per row, in order.
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &Range<usize>,
    ) -> Result<ScalarValue> {
        let array = &values[0];
        let len = array.len();

        // LAG mode
        let i = if self.is_lag() {
            (range.end as i64 - self.shift_offset - 1) as usize
        } else {
            // LEAD mode
            (range.start as i64 - self.shift_offset) as usize
        };

        // A negative target index wraps to a huge `usize` and is rejected here too.
        let mut idx: Option<usize> = if i < len { Some(i) } else { None };

        // LAG with IGNORE NULLS is computed as the current row index minus the
        // offset, counting only non-NULL rows. A NULL current row is NOT counted.
        if self.ignore_nulls && self.is_lag() {
            // LAG when NULLS are ignored.
            // Find the non-NULL row that lies `shift_offset` non-NULL rows before the current row.
            idx = if self.non_null_offsets.len() == self.shift_offset as usize {
                let total_offset: usize = self.non_null_offsets.iter().sum();
                Some(range.end - 1 - total_offset)
            } else {
                None
            };

            // Keep track of offset values between non-null entries
            if array.is_valid(range.end - 1) {
                // Non-null add new offset
                self.non_null_offsets.push_back(1);
                if self.non_null_offsets.len() > self.shift_offset as usize {
                    // We do not need to keep track of more than `shift_offset` offset values.
                    self.non_null_offsets.pop_front();
                }
            } else if !self.non_null_offsets.is_empty() {
                // Entry is null, increment offset value of the last entry.
                let end_idx = self.non_null_offsets.len() - 1;
                self.non_null_offsets[end_idx] += 1;
            }
        } else if self.ignore_nulls && !self.is_lag() {
            // LEAD when NULLS are ignored.
            // Stores the necessary non-null entry number further than the current row.
            let non_null_row_count = (-self.shift_offset) as usize;

            if self.non_null_offsets.is_empty() {
                // When empty, fill non_null offsets with the data further than the current row.
                let mut offset_val = 1;
                for idx in range.start + 1..range.end {
                    if array.is_valid(idx) {
                        self.non_null_offsets.push_back(offset_val);
                        offset_val = 1;
                    } else {
                        offset_val += 1;
                    }
                    // It is enough to keep track of `non_null_row_count + 1` non-null offset.
                    // further data is unnecessary for the result.
                    if self.non_null_offsets.len() == non_null_row_count + 1 {
                        break;
                    }
                }
            } else if range.end < len && array.is_valid(range.end) {
                // Update `non_null_offsets` with the new end data.
                // NOTE(review): the outer condition already requires
                // `array.is_valid(range.end)`, so the `else` branch below is
                // unreachable and a null row at `range.end` is never counted
                // here — confirm whether that is intended.
                if array.is_valid(range.end) {
                    // When non-null, append a new offset.
                    self.non_null_offsets.push_back(1);
                } else {
                    // When null, increment offset count of the last entry
                    let last_idx = self.non_null_offsets.len() - 1;
                    self.non_null_offsets[last_idx] += 1;
                }
            }

            // Find the non-NULL row that lies `non_null_row_count` non-NULL rows after the current row.
            idx = if self.non_null_offsets.len() >= non_null_row_count {
                let total_offset: usize =
                    self.non_null_offsets.iter().take(non_null_row_count).sum();
                Some(range.start + total_offset)
            } else {
                None
            };
            // Prune `self.non_null_offsets` from the start, so that at the next
            // iteration the start of `self.non_null_offsets` matches the current row.
            if !self.non_null_offsets.is_empty() {
                self.non_null_offsets[0] -= 1;
                if self.non_null_offsets[0] == 0 {
                    // When offset is 0. Remove it.
                    self.non_null_offsets.pop_front();
                }
            }
        }

        // Set the default value if
        // - index is out of window bounds
        // OR
        // - ignore nulls mode and current value is null and is within window bounds
        // .unwrap() is safe here as there is a none check in front
        #[allow(clippy::unnecessary_unwrap)]
        if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
            ScalarValue::try_from_array(array, idx.unwrap())
        } else {
            Ok(self.default_value.clone())
        }
    }

    /// Evaluates the whole partition at once.
    ///
    /// Without `IGNORE NULLS` this is a plain shift padded with the default
    /// value. With it, shifting counts only non-null rows.
    fn evaluate_all(
        &mut self,
        values: &[ArrayRef],
        _num_rows: usize,
    ) -> Result<ArrayRef> {
        // LEAD, LAG window functions take single column, values will have size 1
        let value = &values[0];
        if !self.ignore_nulls {
            shift_with_default_value(value, self.shift_offset, &self.default_value)
        } else {
            evaluate_all_with_ignore_null(
                value,
                self.shift_offset,
                &self.default_value,
                self.is_lag(),
            )
        }
    }

    /// The row-by-row path (`get_range` + `evaluate`) is available for
    /// bounded/incremental execution.
    fn supports_bounded_execution(&self) -> bool {
        true
    }
}
644
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// Runs `evaluate_all` for `expr` over a fixed Int32 column and compares
    /// the result with `expected`.
    fn test_i32_result(
        expr: WindowShift,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        // `evaluate_all` expects the number of rows in the partition, not the
        // number of input columns.
        let num_rows = arr.len();
        let values = vec![arr];
        let result = expr
            .partition_evaluator(partition_evaluator_args)?
            .evaluate_all(&values, num_rows)?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    #[test]
    fn lead_lag_get_range() -> Result<()> {
        // LAG(2)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 });
        assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 });

        // LAG(2 ignore nulls)
        let lag_fn = WindowShiftEvaluator {
            shift_offset: 2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
            non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used
        };
        assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 });

        // LEAD(2)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: false,
            non_null_offsets: Default::default(),
        };
        assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 });
        assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 });

        // LEAD(2 ignore nulls)
        let lead_fn = WindowShiftEvaluator {
            shift_offset: -2,
            default_value: ScalarValue::Null,
            ignore_nulls: true,
            // models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
            non_null_offsets: vec![2, 2].into(),
        };
        assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 });

        Ok(())
    }

    #[test]
    fn test_lead_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lead(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
                None,
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_window_shift() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(
                &[expr],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            [
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }

    #[test]
    fn test_lag_with_default() -> Result<()> {
        let expr = Arc::new(Column::new("c3", 0)) as Arc<dyn PhysicalExpr>;
        let shift_offset =
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc<dyn PhysicalExpr>;
        let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100))))
            as Arc<dyn PhysicalExpr>;

        let input_exprs = &[expr, shift_offset, default_value];
        let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32]
            .into_iter()
            .map(|d| Field::new("f", d, true))
            .map(Arc::new)
            .collect::<Vec<_>>();

        test_i32_result(
            WindowShift::lag(),
            PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false),
            [
                Some(100),
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]
            .iter()
            .collect::<Int32Array>(),
        )
    }
}