datafusion_functions_window/
row_number.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! `row_number` window function implementation
19
20use arrow::datatypes::FieldRef;
21use datafusion_common::arrow::array::ArrayRef;
22use datafusion_common::arrow::array::UInt64Array;
23use datafusion_common::arrow::compute::SortOptions;
24use datafusion_common::arrow::datatypes::DataType;
25use datafusion_common::arrow::datatypes::Field;
26use datafusion_common::{Result, ScalarValue};
27use datafusion_expr::{
28    Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl,
29};
30use datafusion_functions_window_common::field;
31use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
32use datafusion_macros::user_doc;
33use field::WindowUDFFieldArgs;
34use std::any::Any;
35use std::fmt::Debug;
36use std::ops::Range;
37
// Invokes `define_udwf_and_expr!` for the `RowNumber` type under the name
// `row_number`, passing the user-facing description string below.
// NOTE(review): the macro is defined outside this file; presumably it
// generates the UDWF accessor and the `row_number()` expression helper --
// confirm against the macro definition.
define_udwf_and_expr!(
    RowNumber,
    row_number,
    "Returns a unique row number for each row in window partition beginning at 1."
);
43
44/// row_number expression
45#[user_doc(
46    doc_section(label = "Ranking Functions"),
47    description = "Number of the current row within its partition, counting from 1.",
48    syntax_example = "row_number()",
49    sql_example = r"```sql
50    --Example usage of the row_number window function:
51    SELECT department,
52           salary,
53           row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num
54    FROM employees;
55```
56
57```sql
58+-------------+--------+---------+
59| department  | salary | row_num |
60+-------------+--------+---------+
61| Sales       | 70000  | 1       |
62| Sales       | 50000  | 2       |
63| Sales       | 50000  | 3       |
64| Sales       | 30000  | 4       |
65| Engineering | 90000  | 1       |
66| Engineering | 80000  | 2       |
67+-------------+--------+---------+
68```#"
69)]
70#[derive(Debug)]
71pub struct RowNumber {
72    signature: Signature,
73}
74
75impl RowNumber {
76    /// Create a new `row_number` function
77    pub fn new() -> Self {
78        Self {
79            signature: Signature::nullary(Volatility::Immutable),
80        }
81    }
82}
83
84impl Default for RowNumber {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl WindowUDFImpl for RowNumber {
91    fn as_any(&self) -> &dyn Any {
92        self
93    }
94
95    fn name(&self) -> &str {
96        "row_number"
97    }
98
99    fn signature(&self) -> &Signature {
100        &self.signature
101    }
102
103    fn partition_evaluator(
104        &self,
105        _partition_evaluator_args: PartitionEvaluatorArgs,
106    ) -> Result<Box<dyn PartitionEvaluator>> {
107        Ok(Box::<NumRowsEvaluator>::default())
108    }
109
110    fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
111        Ok(Field::new(field_args.name(), DataType::UInt64, false).into())
112    }
113
114    fn sort_options(&self) -> Option<SortOptions> {
115        Some(SortOptions {
116            descending: false,
117            nulls_first: false,
118        })
119    }
120
121    fn documentation(&self) -> Option<&Documentation> {
122        self.doc()
123    }
124}
125
/// Running state of the `row_number` window function for one evaluator.
#[derive(Default, Debug)]
struct NumRowsEvaluator {
    /// Count of rows numbered so far via `evaluate`; starts at zero through
    /// the derived `Default`.
    n_rows: usize,
}
131
132impl PartitionEvaluator for NumRowsEvaluator {
133    fn is_causal(&self) -> bool {
134        // The row_number function doesn't need "future" values to emit results:
135        true
136    }
137
138    fn evaluate_all(
139        &mut self,
140        _values: &[ArrayRef],
141        num_rows: usize,
142    ) -> Result<ArrayRef> {
143        Ok(std::sync::Arc::new(UInt64Array::from_iter_values(
144            1..(num_rows as u64) + 1,
145        )))
146    }
147
148    fn evaluate(
149        &mut self,
150        _values: &[ArrayRef],
151        _range: &Range<usize>,
152    ) -> Result<ScalarValue> {
153        self.n_rows += 1;
154        Ok(ScalarValue::UInt64(Some(self.n_rows as u64)))
155    }
156
157    fn supports_bounded_execution(&self) -> bool {
158        true
159    }
160}
161
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{Array, BooleanArray};
    use datafusion_common::cast::as_uint64_array;

    use super::*;

    /// Evaluates `row_number` over `input` as a single partition via
    /// `evaluate_all` and returns the produced numbers.
    fn row_numbers_for(input: ArrayRef) -> Result<Vec<u64>> {
        let num_rows = input.len();
        let mut evaluator = RowNumber::default()
            .partition_evaluator(PartitionEvaluatorArgs::default())?;
        let output = evaluator.evaluate_all(&[input], num_rows)?;
        let output = as_uint64_array(&output)?;
        Ok(output.values().to_vec())
    }

    /// Null inputs are still numbered 1 through 8.
    #[test]
    fn row_number_all_null() -> Result<()> {
        let nulls: Vec<Option<bool>> = vec![None; 8];
        let values: ArrayRef = Arc::new(BooleanArray::from(nulls));

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], row_numbers_for(values)?);
        Ok(())
    }

    /// Non-null inputs are numbered 1 through 8 regardless of their values.
    #[test]
    fn row_number_all_values() -> Result<()> {
        let flags = vec![true, false, true, false, false, true, false, true];
        let values: ArrayRef = Arc::new(BooleanArray::from(flags));

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], row_numbers_for(values)?);
        Ok(())
    }
}