datafusion_functions_window/
nth_value.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! `nth_value`, `first_value` and `last_value` window function implementations
19
20use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29    Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
30    Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use field::WindowUDFFieldArgs;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::ops::Range;
39use std::sync::LazyLock;
40
41get_or_init_udwf!(
42    First,
43    first_value,
44    "returns the first value in the window frame",
45    NthValue::first
46);
47get_or_init_udwf!(
48    Last,
49    last_value,
50    "returns the last value in the window frame",
51    NthValue::last
52);
53get_or_init_udwf!(
54    NthValue,
55    nth_value,
56    "returns the nth value in the window frame",
57    NthValue::nth
58);
59
60/// Create an expression to represent the `first_value` window function
61///
62pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
63    first_value_udwf().call(vec![arg])
64}
65
66/// Create an expression to represent the `last_value` window function
67///
68pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
69    last_value_udwf().call(vec![arg])
70}
71
72/// Create an expression to represent the `nth_value` window function
73///
74pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
75    nth_value_udwf().call(vec![arg, n.lit()])
76}
77
/// Selects which member of the `NTH_VALUE` family a window function instance
/// implements.
#[derive(Debug, Copy, Clone)]
pub enum NthValueKind {
    /// `first_value(expr)`: the first row of the window frame.
    First,
    /// `last_value(expr)`: the last row of the window frame.
    Last,
    /// `nth_value(expr, n)`: the `n`-th row of the window frame.
    Nth,
}

impl NthValueKind {
    /// The function name reported for this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
95
96#[derive(Debug)]
97pub struct NthValue {
98    signature: Signature,
99    kind: NthValueKind,
100}
101
102impl NthValue {
103    /// Create a new `nth_value` function
104    pub fn new(kind: NthValueKind) -> Self {
105        Self {
106            signature: Signature::one_of(
107                vec![
108                    TypeSignature::Any(0),
109                    TypeSignature::Any(1),
110                    TypeSignature::Any(2),
111                ],
112                Volatility::Immutable,
113            ),
114            kind,
115        }
116    }
117
118    pub fn first() -> Self {
119        Self::new(NthValueKind::First)
120    }
121
122    pub fn last() -> Self {
123        Self::new(NthValueKind::Last)
124    }
125    pub fn nth() -> Self {
126        Self::new(NthValueKind::Nth)
127    }
128}
129
130static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
131    Documentation::builder(
132        DOC_SECTION_ANALYTICAL,
133        "Returns value evaluated at the row that is the first row of the window \
134            frame.",
135        "first_value(expression)",
136    )
137    .with_argument("expression", "Expression to operate on")
138        .with_sql_example(r#"```sql
139    --Example usage of the first_value window function:
140    SELECT department,
141           employee_id,
142           salary,
143           first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
144    FROM employees;
145```
146
147```sql
148+-------------+-------------+--------+------------+
149| department  | employee_id | salary | top_salary |
150+-------------+-------------+--------+------------+
151| Sales       | 1           | 70000  | 70000      |
152| Sales       | 2           | 50000  | 70000      |
153| Sales       | 3           | 30000  | 70000      |
154| Engineering | 4           | 90000  | 90000      |
155| Engineering | 5           | 80000  | 90000      |
156+-------------+-------------+--------+------------+
157```"#)
158    .build()
159});
160
161fn get_first_value_doc() -> &'static Documentation {
162    &FIRST_VALUE_DOCUMENTATION
163}
164
165static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
166    Documentation::builder(
167        DOC_SECTION_ANALYTICAL,
168        "Returns value evaluated at the row that is the last row of the window \
169            frame.",
170        "last_value(expression)",
171    )
172    .with_argument("expression", "Expression to operate on")
173        .with_sql_example(r#"```sql
174-- SQL example of last_value:
175SELECT department,
176       employee_id,
177       salary,
178       last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
179FROM employees;
180```
181
182```sql
183+-------------+-------------+--------+---------------------+
184| department  | employee_id | salary | running_last_salary |
185+-------------+-------------+--------+---------------------+
186| Sales       | 1           | 30000  | 30000               |
187| Sales       | 2           | 50000  | 50000               |
188| Sales       | 3           | 70000  | 70000               |
189| Engineering | 4           | 40000  | 40000               |
190| Engineering | 5           | 60000  | 60000               |
191+-------------+-------------+--------+---------------------+
192```"#)
193    .build()
194});
195
196fn get_last_value_doc() -> &'static Documentation {
197    &LAST_VALUE_DOCUMENTATION
198}
199
200static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
201    Documentation::builder(
202        DOC_SECTION_ANALYTICAL,
203        "Returns the value evaluated at the nth row of the window frame \
204         (counting from 1). Returns NULL if no such row exists.",
205        "nth_value(expression, n)",
206    )
207    .with_argument(
208        "expression",
209        "The column from which to retrieve the nth value.",
210    )
211    .with_argument(
212        "n",
213        "Integer. Specifies the row number (starting from 1) in the window frame.",
214    )
215    .with_sql_example(
216        r#"```sql
217-- Sample employees table:
218CREATE TABLE employees (id INT, salary INT);
219INSERT INTO employees (id, salary) VALUES
220(1, 30000),
221(2, 40000),
222(3, 50000),
223(4, 60000),
224(5, 70000);
225
226-- Example usage of nth_value:
227SELECT nth_value(salary, 2) OVER (
228  ORDER BY salary
229  ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
230) AS nth_value
231FROM employees;
232```
233
234```text
235+-----------+
236| nth_value |
237+-----------+
238| 40000     |
239| 40000     |
240| 40000     |
241| 40000     |
242| 40000     |
243+-----------+
244```"#,
245    )
246    .build()
247});
248
249fn get_nth_value_doc() -> &'static Documentation {
250    &NTH_VALUE_DOCUMENTATION
251}
252
253impl WindowUDFImpl for NthValue {
254    fn as_any(&self) -> &dyn Any {
255        self
256    }
257
258    fn name(&self) -> &str {
259        self.kind.name()
260    }
261
262    fn signature(&self) -> &Signature {
263        &self.signature
264    }
265
266    fn partition_evaluator(
267        &self,
268        partition_evaluator_args: PartitionEvaluatorArgs,
269    ) -> Result<Box<dyn PartitionEvaluator>> {
270        let state = NthValueState {
271            finalized_result: None,
272            kind: self.kind,
273        };
274
275        if !matches!(self.kind, NthValueKind::Nth) {
276            return Ok(Box::new(NthValueEvaluator {
277                state,
278                ignore_nulls: partition_evaluator_args.ignore_nulls(),
279                n: 0,
280            }));
281        }
282
283        let n =
284            match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
285                .map_err(|_e| {
286                    exec_datafusion_err!(
287                "Expected a signed integer literal for the second argument of nth_value")
288                })?
289                .map(get_signed_integer)
290            {
291                Some(Ok(n)) => {
292                    if partition_evaluator_args.is_reversed() {
293                        -n
294                    } else {
295                        n
296                    }
297                }
298                _ => {
299                    return exec_err!(
300                "Expected a signed integer literal for the second argument of nth_value"
301            )
302                }
303            };
304
305        Ok(Box::new(NthValueEvaluator {
306            state,
307            ignore_nulls: partition_evaluator_args.ignore_nulls(),
308            n,
309        }))
310    }
311
312    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
313        let return_type = field_args
314            .input_fields()
315            .first()
316            .map(|f| f.data_type())
317            .cloned()
318            .unwrap_or(DataType::Null);
319
320        Ok(Field::new(field_args.name(), return_type, true).into())
321    }
322
323    fn reverse_expr(&self) -> ReversedUDWF {
324        match self.kind {
325            NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
326            NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
327            NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
328        }
329    }
330
331    fn documentation(&self) -> Option<&Documentation> {
332        match self.kind {
333            NthValueKind::First => Some(get_first_value_doc()),
334            NthValueKind::Last => Some(get_last_value_doc()),
335            NthValueKind::Nth => Some(get_nth_value_doc()),
336        }
337    }
338}
339
340#[derive(Debug, Clone)]
341pub struct NthValueState {
342    // In certain cases, we can finalize the result early. Consider this usage:
343    // ```
344    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
345    //  WINDOW (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) AS window
346    // ```
347    // The result will always be the first entry in the table. We can store such
348    // early-finalizing results and then just reuse them as necessary. This opens
349    // opportunities to prune our datasets.
350    pub finalized_result: Option<ScalarValue>,
351    pub kind: NthValueKind,
352}
353
354#[derive(Debug)]
355pub(crate) struct NthValueEvaluator {
356    state: NthValueState,
357    ignore_nulls: bool,
358    n: i64,
359}
360
361impl PartitionEvaluator for NthValueEvaluator {
362    /// When the window frame has a fixed beginning (e.g UNBOUNDED PRECEDING),
363    /// for some functions such as FIRST_VALUE, LAST_VALUE and NTH_VALUE, we
364    /// can memoize the result.  Once result is calculated, it will always stay
365    /// same. Hence, we do not need to keep past data as we process the entire
366    /// dataset.
367    fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
368        let out = &state.out_col;
369        let size = out.len();
370        let mut buffer_size = 1;
371        // Decide if we arrived at a final result yet:
372        let (is_prunable, is_reverse_direction) = match self.state.kind {
373            NthValueKind::First => {
374                let n_range =
375                    state.window_frame_range.end - state.window_frame_range.start;
376                (n_range > 0 && size > 0, false)
377            }
378            NthValueKind::Last => (true, true),
379            NthValueKind::Nth => {
380                let n_range =
381                    state.window_frame_range.end - state.window_frame_range.start;
382                match self.n.cmp(&0) {
383                    Ordering::Greater => (
384                        n_range >= (self.n as usize) && size > (self.n as usize),
385                        false,
386                    ),
387                    Ordering::Less => {
388                        let reverse_index = (-self.n) as usize;
389                        buffer_size = reverse_index;
390                        // Negative index represents reverse direction.
391                        (n_range >= reverse_index, true)
392                    }
393                    Ordering::Equal => (false, false),
394                }
395            }
396        };
397        // Do not memoize results when nulls are ignored.
398        if is_prunable && !self.ignore_nulls {
399            if self.state.finalized_result.is_none() && !is_reverse_direction {
400                let result = ScalarValue::try_from_array(out, size - 1)?;
401                self.state.finalized_result = Some(result);
402            }
403            state.window_frame_range.start =
404                state.window_frame_range.end.saturating_sub(buffer_size);
405        }
406        Ok(())
407    }
408
409    fn evaluate(
410        &mut self,
411        values: &[ArrayRef],
412        range: &Range<usize>,
413    ) -> Result<ScalarValue> {
414        if let Some(ref result) = self.state.finalized_result {
415            Ok(result.clone())
416        } else {
417            // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
418            let arr = &values[0];
419            let n_range = range.end - range.start;
420            if n_range == 0 {
421                // We produce None if the window is empty.
422                return ScalarValue::try_from(arr.data_type());
423            }
424
425            // If null values exist and need to be ignored, extract the valid indices.
426            let valid_indices = if self.ignore_nulls {
427                // Calculate valid indices, inside the window frame boundaries.
428                let slice = arr.slice(range.start, n_range);
429                match slice.nulls() {
430                    Some(nulls) => {
431                        let valid_indices = nulls
432                            .valid_indices()
433                            .map(|idx| {
434                                // Add offset `range.start` to valid indices, to point correct index in the original arr.
435                                idx + range.start
436                            })
437                            .collect::<Vec<_>>();
438                        if valid_indices.is_empty() {
439                            // If all values are null, return directly.
440                            return ScalarValue::try_from(arr.data_type());
441                        }
442                        Some(valid_indices)
443                    }
444                    None => None,
445                }
446            } else {
447                None
448            };
449            match self.state.kind {
450                NthValueKind::First => {
451                    if let Some(valid_indices) = &valid_indices {
452                        ScalarValue::try_from_array(arr, valid_indices[0])
453                    } else {
454                        ScalarValue::try_from_array(arr, range.start)
455                    }
456                }
457                NthValueKind::Last => {
458                    if let Some(valid_indices) = &valid_indices {
459                        ScalarValue::try_from_array(
460                            arr,
461                            valid_indices[valid_indices.len() - 1],
462                        )
463                    } else {
464                        ScalarValue::try_from_array(arr, range.end - 1)
465                    }
466                }
467                NthValueKind::Nth => {
468                    match self.n.cmp(&0) {
469                        Ordering::Greater => {
470                            // SQL indices are not 0-based.
471                            let index = (self.n as usize) - 1;
472                            if index >= n_range {
473                                // Outside the range, return NULL:
474                                ScalarValue::try_from(arr.data_type())
475                            } else if let Some(valid_indices) = valid_indices {
476                                if index >= valid_indices.len() {
477                                    return ScalarValue::try_from(arr.data_type());
478                                }
479                                ScalarValue::try_from_array(&arr, valid_indices[index])
480                            } else {
481                                ScalarValue::try_from_array(arr, range.start + index)
482                            }
483                        }
484                        Ordering::Less => {
485                            let reverse_index = (-self.n) as usize;
486                            if n_range < reverse_index {
487                                // Outside the range, return NULL:
488                                ScalarValue::try_from(arr.data_type())
489                            } else if let Some(valid_indices) = valid_indices {
490                                if reverse_index > valid_indices.len() {
491                                    return ScalarValue::try_from(arr.data_type());
492                                }
493                                let new_index =
494                                    valid_indices[valid_indices.len() - reverse_index];
495                                ScalarValue::try_from_array(&arr, new_index)
496                            } else {
497                                ScalarValue::try_from_array(
498                                    arr,
499                                    range.start + n_range - reverse_index,
500                                )
501                            }
502                        }
503                        Ordering::Equal => ScalarValue::try_from(arr.data_type()),
504                    }
505                }
506            }
507        }
508    }
509
510    fn supports_bounded_execution(&self) -> bool {
511        true
512    }
513
514    fn uses_window_frame(&self) -> bool {
515        true
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Evaluates `expr` over the growing frames `[0, 1)`, `[0, 2)`, ...,
    /// `[0, values.len())` of `values` and asserts the results equal `expected`.
    fn test_i32_result_on(
        values: ArrayRef,
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let len = values.len();
        let values = vec![values];
        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let result = (1..=len)
            .map(|end| evaluator.evaluate(&values, &(0..end)))
            .collect::<Result<Vec<ScalarValue>>>()?;
        let result = ScalarValue::iter_to_array(result.into_iter())?;
        let result = as_int32_array(&result)?;
        assert_eq!(expected, *result);
        Ok(())
    }

    /// Runs [`test_i32_result_on`] against the column `[1, -2, 3, -4, 5, -6, 7, 8]`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        test_i32_result_on(arr, expr, partition_evaluator_args, expected)
    }

    /// The value expression: column 0.
    fn value_column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// A literal `n` argument for `nth_value`.
    fn n_literal(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    #[test]
    fn first_value() -> Result<()> {
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[value_column()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn last_value() -> Result<()> {
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[value_column()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
                Some(8),
            ]),
        )
    }

    #[test]
    fn nth_value_1() -> Result<()> {
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[value_column(), n_literal(1)],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![1; 8]),
        )
    }

    #[test]
    fn nth_value_2() -> Result<()> {
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[value_column(), n_literal(2)],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                false,
            ),
            Int32Array::from(vec![
                None,
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
                Some(-2),
            ]),
        )
    }

    #[test]
    fn nth_value_2_reversed() -> Result<()> {
        // A reversed evaluator negates `n`, i.e. picks the 2nd row from the
        // end of each frame.
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[value_column(), n_literal(2)],
                &[Field::new("f", DataType::Int32, true).into()],
                true,
                false,
            ),
            Int32Array::from(vec![
                None,
                Some(1),
                Some(-2),
                Some(3),
                Some(-4),
                Some(5),
                Some(-6),
                Some(7),
            ]),
        )
    }

    #[test]
    fn first_value_ignore_nulls() -> Result<()> {
        let arr: ArrayRef =
            Arc::new(Int32Array::from(vec![None, None, Some(3), None, Some(5)]));
        test_i32_result_on(
            arr,
            NthValue::first(),
            PartitionEvaluatorArgs::new(
                &[value_column()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                true,
            ),
            Int32Array::from(vec![None, None, Some(3), Some(3), Some(3)]),
        )
    }

    #[test]
    fn last_value_ignore_nulls() -> Result<()> {
        let arr: ArrayRef =
            Arc::new(Int32Array::from(vec![None, None, Some(3), None, Some(5)]));
        test_i32_result_on(
            arr,
            NthValue::last(),
            PartitionEvaluatorArgs::new(
                &[value_column()],
                &[Field::new("f", DataType::Int32, true).into()],
                false,
                true,
            ),
            Int32Array::from(vec![None, None, Some(3), Some(3), Some(5)]),
        )
    }
}