1use crate::utils::{get_scalar_value_from_args, get_signed_integer};
21
22use arrow::datatypes::FieldRef;
23use datafusion_common::arrow::array::ArrayRef;
24use datafusion_common::arrow::datatypes::{DataType, Field};
25use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
26use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
27use datafusion_expr::window_state::WindowAggState;
28use datafusion_expr::{
29 Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature,
30 Volatility, WindowUDFImpl,
31};
32use datafusion_functions_window_common::field;
33use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
34use field::WindowUDFFieldArgs;
35use std::any::Any;
36use std::cmp::Ordering;
37use std::fmt::Debug;
38use std::ops::Range;
39use std::sync::LazyLock;
40
// Registers the three window functions implemented in this file. Each
// invocation passes an identifier, the function name, a one-line description
// and the `NthValue` constructor that builds the matching implementation.
//
// NOTE(review): the macro definition is not visible in this file. The
// `first_value_udwf()`, `last_value_udwf()` and `nth_value_udwf()` accessors
// called further down are not defined anywhere else here, so they are
// presumably generated by these invocations — confirm against the macro.

// `first_value`: built by `NthValue::first`.
get_or_init_udwf!(
    First,
    first_value,
    "returns the first value in the window frame",
    NthValue::first
);
// `last_value`: built by `NthValue::last`.
get_or_init_udwf!(
    Last,
    last_value,
    "returns the last value in the window frame",
    NthValue::last
);
// `nth_value`: built by `NthValue::nth`.
get_or_init_udwf!(
    NthValue,
    nth_value,
    "returns the nth value in the window frame",
    NthValue::nth
);
59
60pub fn first_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
63 first_value_udwf().call(vec![arg])
64}
65
66pub fn last_value(arg: datafusion_expr::Expr) -> datafusion_expr::Expr {
69 last_value_udwf().call(vec![arg])
70}
71
72pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr {
75 nth_value_udwf().call(vec![arg, n.lit()])
76}
77
/// Selects which of the three window functions an [`NthValue`] instance
/// implements.
#[derive(Clone, Copy, Debug)]
pub enum NthValueKind {
    /// `first_value`
    First,
    /// `last_value`
    Last,
    /// `nth_value`
    Nth,
}

impl NthValueKind {
    /// Returns the SQL function name associated with this kind.
    fn name(&self) -> &'static str {
        match *self {
            Self::First => "first_value",
            Self::Last => "last_value",
            Self::Nth => "nth_value",
        }
    }
}
95
96#[derive(Debug)]
97pub struct NthValue {
98 signature: Signature,
99 kind: NthValueKind,
100}
101
102impl NthValue {
103 pub fn new(kind: NthValueKind) -> Self {
105 Self {
106 signature: Signature::one_of(
107 vec![
108 TypeSignature::Any(0),
109 TypeSignature::Any(1),
110 TypeSignature::Any(2),
111 ],
112 Volatility::Immutable,
113 ),
114 kind,
115 }
116 }
117
118 pub fn first() -> Self {
119 Self::new(NthValueKind::First)
120 }
121
122 pub fn last() -> Self {
123 Self::new(NthValueKind::Last)
124 }
125 pub fn nth() -> Self {
126 Self::new(NthValueKind::Nth)
127 }
128}
129
130static FIRST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
131 Documentation::builder(
132 DOC_SECTION_ANALYTICAL,
133 "Returns value evaluated at the row that is the first row of the window \
134 frame.",
135 "first_value(expression)",
136 )
137 .with_argument("expression", "Expression to operate on")
138 .with_sql_example(r#"```sql
139 --Example usage of the first_value window function:
140 SELECT department,
141 employee_id,
142 salary,
143 first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
144 FROM employees;
145```
146
147```sql
148+-------------+-------------+--------+------------+
149| department | employee_id | salary | top_salary |
150+-------------+-------------+--------+------------+
151| Sales | 1 | 70000 | 70000 |
152| Sales | 2 | 50000 | 70000 |
153| Sales | 3 | 30000 | 70000 |
154| Engineering | 4 | 90000 | 90000 |
155| Engineering | 5 | 80000 | 90000 |
156+-------------+-------------+--------+------------+
157```"#)
158 .build()
159});
160
161fn get_first_value_doc() -> &'static Documentation {
162 &FIRST_VALUE_DOCUMENTATION
163}
164
165static LAST_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
166 Documentation::builder(
167 DOC_SECTION_ANALYTICAL,
168 "Returns value evaluated at the row that is the last row of the window \
169 frame.",
170 "last_value(expression)",
171 )
172 .with_argument("expression", "Expression to operate on")
173 .with_sql_example(r#"```sql
174-- SQL example of last_value:
175SELECT department,
176 employee_id,
177 salary,
178 last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary
179FROM employees;
180```
181
182```sql
183+-------------+-------------+--------+---------------------+
184| department | employee_id | salary | running_last_salary |
185+-------------+-------------+--------+---------------------+
186| Sales | 1 | 30000 | 30000 |
187| Sales | 2 | 50000 | 50000 |
188| Sales | 3 | 70000 | 70000 |
189| Engineering | 4 | 40000 | 40000 |
190| Engineering | 5 | 60000 | 60000 |
191+-------------+-------------+--------+---------------------+
192```"#)
193 .build()
194});
195
196fn get_last_value_doc() -> &'static Documentation {
197 &LAST_VALUE_DOCUMENTATION
198}
199
200static NTH_VALUE_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
201 Documentation::builder(
202 DOC_SECTION_ANALYTICAL,
203 "Returns the value evaluated at the nth row of the window frame \
204 (counting from 1). Returns NULL if no such row exists.",
205 "nth_value(expression, n)",
206 )
207 .with_argument(
208 "expression",
209 "The column from which to retrieve the nth value.",
210 )
211 .with_argument(
212 "n",
213 "Integer. Specifies the row number (starting from 1) in the window frame.",
214 )
215 .with_sql_example(
216 r#"```sql
217-- Sample employees table:
218CREATE TABLE employees (id INT, salary INT);
219INSERT INTO employees (id, salary) VALUES
220(1, 30000),
221(2, 40000),
222(3, 50000),
223(4, 60000),
224(5, 70000);
225
226-- Example usage of nth_value:
227SELECT nth_value(salary, 2) OVER (
228 ORDER BY salary
229 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
230) AS nth_value
231FROM employees;
232```
233
234```text
235+-----------+
236| nth_value |
237+-----------+
238| 40000 |
239| 40000 |
240| 40000 |
241| 40000 |
242| 40000 |
243+-----------+
244```"#,
245 )
246 .build()
247});
248
249fn get_nth_value_doc() -> &'static Documentation {
250 &NTH_VALUE_DOCUMENTATION
251}
252
253impl WindowUDFImpl for NthValue {
254 fn as_any(&self) -> &dyn Any {
255 self
256 }
257
258 fn name(&self) -> &str {
259 self.kind.name()
260 }
261
262 fn signature(&self) -> &Signature {
263 &self.signature
264 }
265
266 fn partition_evaluator(
267 &self,
268 partition_evaluator_args: PartitionEvaluatorArgs,
269 ) -> Result<Box<dyn PartitionEvaluator>> {
270 let state = NthValueState {
271 finalized_result: None,
272 kind: self.kind,
273 };
274
275 if !matches!(self.kind, NthValueKind::Nth) {
276 return Ok(Box::new(NthValueEvaluator {
277 state,
278 ignore_nulls: partition_evaluator_args.ignore_nulls(),
279 n: 0,
280 }));
281 }
282
283 let n =
284 match get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)
285 .map_err(|_e| {
286 exec_datafusion_err!(
287 "Expected a signed integer literal for the second argument of nth_value")
288 })?
289 .map(get_signed_integer)
290 {
291 Some(Ok(n)) => {
292 if partition_evaluator_args.is_reversed() {
293 -n
294 } else {
295 n
296 }
297 }
298 _ => {
299 return exec_err!(
300 "Expected a signed integer literal for the second argument of nth_value"
301 )
302 }
303 };
304
305 Ok(Box::new(NthValueEvaluator {
306 state,
307 ignore_nulls: partition_evaluator_args.ignore_nulls(),
308 n,
309 }))
310 }
311
312 fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
313 let return_type = field_args
314 .input_fields()
315 .first()
316 .map(|f| f.data_type())
317 .cloned()
318 .unwrap_or(DataType::Null);
319
320 Ok(Field::new(field_args.name(), return_type, true).into())
321 }
322
323 fn reverse_expr(&self) -> ReversedUDWF {
324 match self.kind {
325 NthValueKind::First => ReversedUDWF::Reversed(last_value_udwf()),
326 NthValueKind::Last => ReversedUDWF::Reversed(first_value_udwf()),
327 NthValueKind::Nth => ReversedUDWF::Reversed(nth_value_udwf()),
328 }
329 }
330
331 fn documentation(&self) -> Option<&Documentation> {
332 match self.kind {
333 NthValueKind::First => Some(get_first_value_doc()),
334 NthValueKind::Last => Some(get_last_value_doc()),
335 NthValueKind::Nth => Some(get_nth_value_doc()),
336 }
337 }
338}
339
340#[derive(Debug, Clone)]
341pub struct NthValueState {
342 pub finalized_result: Option<ScalarValue>,
351 pub kind: NthValueKind,
352}
353
354#[derive(Debug)]
355pub(crate) struct NthValueEvaluator {
356 state: NthValueState,
357 ignore_nulls: bool,
358 n: i64,
359}
360
361impl PartitionEvaluator for NthValueEvaluator {
362 fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
368 let out = &state.out_col;
369 let size = out.len();
370 let mut buffer_size = 1;
371 let (is_prunable, is_reverse_direction) = match self.state.kind {
373 NthValueKind::First => {
374 let n_range =
375 state.window_frame_range.end - state.window_frame_range.start;
376 (n_range > 0 && size > 0, false)
377 }
378 NthValueKind::Last => (true, true),
379 NthValueKind::Nth => {
380 let n_range =
381 state.window_frame_range.end - state.window_frame_range.start;
382 match self.n.cmp(&0) {
383 Ordering::Greater => (
384 n_range >= (self.n as usize) && size > (self.n as usize),
385 false,
386 ),
387 Ordering::Less => {
388 let reverse_index = (-self.n) as usize;
389 buffer_size = reverse_index;
390 (n_range >= reverse_index, true)
392 }
393 Ordering::Equal => (false, false),
394 }
395 }
396 };
397 if is_prunable && !self.ignore_nulls {
399 if self.state.finalized_result.is_none() && !is_reverse_direction {
400 let result = ScalarValue::try_from_array(out, size - 1)?;
401 self.state.finalized_result = Some(result);
402 }
403 state.window_frame_range.start =
404 state.window_frame_range.end.saturating_sub(buffer_size);
405 }
406 Ok(())
407 }
408
409 fn evaluate(
410 &mut self,
411 values: &[ArrayRef],
412 range: &Range<usize>,
413 ) -> Result<ScalarValue> {
414 if let Some(ref result) = self.state.finalized_result {
415 Ok(result.clone())
416 } else {
417 let arr = &values[0];
419 let n_range = range.end - range.start;
420 if n_range == 0 {
421 return ScalarValue::try_from(arr.data_type());
423 }
424
425 let valid_indices = if self.ignore_nulls {
427 let slice = arr.slice(range.start, n_range);
429 match slice.nulls() {
430 Some(nulls) => {
431 let valid_indices = nulls
432 .valid_indices()
433 .map(|idx| {
434 idx + range.start
436 })
437 .collect::<Vec<_>>();
438 if valid_indices.is_empty() {
439 return ScalarValue::try_from(arr.data_type());
441 }
442 Some(valid_indices)
443 }
444 None => None,
445 }
446 } else {
447 None
448 };
449 match self.state.kind {
450 NthValueKind::First => {
451 if let Some(valid_indices) = &valid_indices {
452 ScalarValue::try_from_array(arr, valid_indices[0])
453 } else {
454 ScalarValue::try_from_array(arr, range.start)
455 }
456 }
457 NthValueKind::Last => {
458 if let Some(valid_indices) = &valid_indices {
459 ScalarValue::try_from_array(
460 arr,
461 valid_indices[valid_indices.len() - 1],
462 )
463 } else {
464 ScalarValue::try_from_array(arr, range.end - 1)
465 }
466 }
467 NthValueKind::Nth => {
468 match self.n.cmp(&0) {
469 Ordering::Greater => {
470 let index = (self.n as usize) - 1;
472 if index >= n_range {
473 ScalarValue::try_from(arr.data_type())
475 } else if let Some(valid_indices) = valid_indices {
476 if index >= valid_indices.len() {
477 return ScalarValue::try_from(arr.data_type());
478 }
479 ScalarValue::try_from_array(&arr, valid_indices[index])
480 } else {
481 ScalarValue::try_from_array(arr, range.start + index)
482 }
483 }
484 Ordering::Less => {
485 let reverse_index = (-self.n) as usize;
486 if n_range < reverse_index {
487 ScalarValue::try_from(arr.data_type())
489 } else if let Some(valid_indices) = valid_indices {
490 if reverse_index > valid_indices.len() {
491 return ScalarValue::try_from(arr.data_type());
492 }
493 let new_index =
494 valid_indices[valid_indices.len() - reverse_index];
495 ScalarValue::try_from_array(&arr, new_index)
496 } else {
497 ScalarValue::try_from_array(
498 arr,
499 range.start + n_range - reverse_index,
500 )
501 }
502 }
503 Ordering::Equal => ScalarValue::try_from(arr.data_type()),
504 }
505 }
506 }
507 }
508 }
509
510 fn supports_bounded_execution(&self) -> bool {
511 true
512 }
513
514 fn uses_window_frame(&self) -> bool {
515 true
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::cast::as_int32_array;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Input column shared by every test.
    fn column() -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new("c3", 0))
    }

    /// Nullable `Int32` field describing the input column.
    fn int32_field() -> FieldRef {
        Field::new("f", DataType::Int32, true).into()
    }

    /// Integer literal used as the position argument of `nth_value`.
    fn position(n: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
    }

    /// Evaluates `expr` over the growing frames `0..1`, `0..2`, …, `0..8` of a
    /// fixed eight-element column and compares the outputs with `expected`.
    fn test_i32_result(
        expr: NthValue,
        partition_evaluator_args: PartitionEvaluatorArgs,
        expected: Int32Array,
    ) -> Result<()> {
        let input: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
        let columns = vec![input];
        let frames: Vec<Range<usize>> = (1..=8).map(|end| 0..end).collect();

        let mut evaluator = expr.partition_evaluator(partition_evaluator_args)?;
        let mut scalars = Vec::with_capacity(frames.len());
        for frame in &frames {
            scalars.push(evaluator.evaluate(&columns, frame)?);
        }

        let actual = ScalarValue::iter_to_array(scalars)?;
        let actual = as_int32_array(&actual)?;
        assert_eq!(expected, *actual);
        Ok(())
    }

    /// `first_value` stays at the first element for every frame.
    #[test]
    fn first_value() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]).iter().collect::<Int32Array>();
        test_i32_result(
            NthValue::first(),
            PartitionEvaluatorArgs::new(&[column()], &[int32_field()], false, false),
            expected,
        )
    }

    /// `last_value` tracks the most recent element of each frame.
    #[test]
    fn last_value() -> Result<()> {
        let expected = Int32Array::from(vec![
            Some(1),
            Some(-2),
            Some(3),
            Some(-4),
            Some(5),
            Some(-6),
            Some(7),
            Some(8),
        ]);
        test_i32_result(
            NthValue::last(),
            PartitionEvaluatorArgs::new(&[column()], &[int32_field()], false, false),
            expected,
        )
    }

    /// `nth_value(_, 1)` behaves like `first_value`.
    #[test]
    fn nth_value_1() -> Result<()> {
        let expected = Int32Array::from(vec![1; 8]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[column(), position(1)],
                &[int32_field()],
                false,
                false,
            ),
            expected,
        )
    }

    /// `nth_value(_, 2)` is NULL for the one-row frame and `-2` afterwards.
    #[test]
    fn nth_value_2() -> Result<()> {
        let expected = Int32Array::from(vec![
            None,
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
            Some(-2),
        ]);
        test_i32_result(
            NthValue::nth(),
            PartitionEvaluatorArgs::new(
                &[column(), position(2)],
                &[int32_field()],
                false,
                false,
            ),
            expected,
        )
    }
}