datafusion_python/expr/
window.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::common::{DataFusionError, ScalarValue};
19use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
20use pyo3::{prelude::*, IntoPyObjectExt};
21use std::fmt::{self, Display, Formatter};
22
23use crate::common::data_type::PyScalarValue;
24use crate::common::df_schema::PyDFSchema;
25use crate::errors::{py_type_err, PyDataFusionResult};
26use crate::expr::logical_node::LogicalNode;
27use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
28use crate::expr::PyExpr;
29use crate::sql::logical::PyLogicalPlan;
30
31use super::py_expr_list;
32
33use crate::errors::py_datafusion_err;
34
/// Python wrapper, exposed as `datafusion.expr.WindowExpr`, around DataFusion's
/// `Window` logical plan node.
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowExpr {
    // The wrapped DataFusion `Window` plan node (window expressions, schema, input).
    window: Window,
}
40
/// Python wrapper, exposed as `datafusion.expr.WindowFrame`, around a DataFusion
/// `WindowFrame` (frame units plus a start and an end bound).
#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrame {
    // The wrapped DataFusion window frame.
    window_frame: WindowFrame,
}
46
47impl From<PyWindowFrame> for WindowFrame {
48    fn from(window_frame: PyWindowFrame) -> Self {
49        window_frame.window_frame
50    }
51}
52
53impl From<WindowFrame> for PyWindowFrame {
54    fn from(window_frame: WindowFrame) -> PyWindowFrame {
55        PyWindowFrame { window_frame }
56    }
57}
58
/// Python wrapper, exposed as `datafusion.expr.WindowFrameBound`, around one bound of a
/// window frame: `CURRENT ROW`, or a `PRECEDING` / `FOLLOWING` bound carrying an offset value.
#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindowFrameBound {
    // The wrapped DataFusion frame bound.
    frame_bound: WindowFrameBound,
}
64
65impl From<PyWindowExpr> for Window {
66    fn from(window: PyWindowExpr) -> Window {
67        window.window
68    }
69}
70
71impl From<Window> for PyWindowExpr {
72    fn from(window: Window) -> PyWindowExpr {
73        PyWindowExpr { window }
74    }
75}
76
77impl From<WindowFrameBound> for PyWindowFrameBound {
78    fn from(frame_bound: WindowFrameBound) -> Self {
79        PyWindowFrameBound { frame_bound }
80    }
81}
82
83impl Display for PyWindowExpr {
84    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
85        write!(
86            f,
87            "Over\n
88            Window Expr: {:?}
89            Schema: {:?}",
90            &self.window.window_expr, &self.window.schema
91        )
92    }
93}
94
95impl Display for PyWindowFrame {
96    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
97        write!(
98            f,
99            "OVER ({} BETWEEN {} AND {})",
100            self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
101        )
102    }
103}
104
105#[pymethods]
106impl PyWindowExpr {
107    /// Returns the schema of the Window
108    pub fn schema(&self) -> PyResult<PyDFSchema> {
109        Ok(self.window.schema.as_ref().clone().into())
110    }
111
112    /// Returns window expressions
113    pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
114        py_expr_list(&self.window.window_expr)
115    }
116
117    /// Returns order by columns in a window function expression
118    pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
119        match expr.expr.unalias() {
120            Expr::WindowFunction(boxed_window_fn) => {
121                py_sort_expr_list(&boxed_window_fn.params.order_by)
122            }
123            other => Err(not_window_function_err(other)),
124        }
125    }
126
127    /// Return partition by columns in a window function expression
128    pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
129        match expr.expr.unalias() {
130            Expr::WindowFunction(boxed_window_fn) => {
131                py_expr_list(&boxed_window_fn.params.partition_by)
132            }
133            other => Err(not_window_function_err(other)),
134        }
135    }
136
137    /// Return input args for window function
138    pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
139        match expr.expr.unalias() {
140            Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
141            other => Err(not_window_function_err(other)),
142        }
143    }
144
145    /// Return window function name
146    pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
147        match expr.expr.unalias() {
148            Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
149            other => Err(not_window_function_err(other)),
150        }
151    }
152
153    /// Returns a Pywindow frame for a given window function expression
154    pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
155        match expr.expr.unalias() {
156            Expr::WindowFunction(boxed_window_fn) => {
157                Some(boxed_window_fn.params.window_frame.into())
158            }
159            _ => None,
160        }
161    }
162}
163
164fn not_window_function_err(expr: Expr) -> PyErr {
165    py_type_err(format!(
166        "Provided {} Expr {:?} is not a WindowFunction type",
167        expr.variant_name(),
168        expr
169    ))
170}
171
172#[pymethods]
173impl PyWindowFrame {
174    #[new]
175    #[pyo3(signature=(unit, start_bound, end_bound))]
176    pub fn new(
177        unit: &str,
178        start_bound: Option<PyScalarValue>,
179        end_bound: Option<PyScalarValue>,
180    ) -> PyResult<Self> {
181        let units = unit.to_ascii_lowercase();
182        let units = match units.as_str() {
183            "rows" => WindowFrameUnits::Rows,
184            "range" => WindowFrameUnits::Range,
185            "groups" => WindowFrameUnits::Groups,
186            _ => {
187                return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
188                    "{units:?}",
189                ))));
190            }
191        };
192        let start_bound = match start_bound {
193            Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
194            None => match units {
195                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
196                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
197                WindowFrameUnits::Groups => {
198                    return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
199                        "{units:?}",
200                    ))));
201                }
202            },
203        };
204        let end_bound = match end_bound {
205            Some(end_bound) => WindowFrameBound::Following(end_bound.0),
206            None => match units {
207                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
208                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
209                WindowFrameUnits::Groups => {
210                    return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
211                        "{units:?}",
212                    ))));
213                }
214            },
215        };
216        Ok(PyWindowFrame {
217            window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
218        })
219    }
220
221    /// Returns the window frame units for the bounds
222    pub fn get_frame_units(&self) -> PyResult<String> {
223        Ok(self.window_frame.units.to_string())
224    }
225    /// Returns starting bound
226    pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
227        Ok(self.window_frame.start_bound.clone().into())
228    }
229    /// Returns end bound
230    pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
231        Ok(self.window_frame.end_bound.clone().into())
232    }
233
234    /// Get a String representation of this window frame
235    fn __repr__(&self) -> String {
236        format!("{self}")
237    }
238}
239
240#[pymethods]
241impl PyWindowFrameBound {
242    /// Returns if the frame bound is current row
243    pub fn is_current_row(&self) -> bool {
244        matches!(self.frame_bound, WindowFrameBound::CurrentRow)
245    }
246
247    /// Returns if the frame bound is preceding
248    pub fn is_preceding(&self) -> bool {
249        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
250    }
251
252    /// Returns if the frame bound is following
253    pub fn is_following(&self) -> bool {
254        matches!(self.frame_bound, WindowFrameBound::Following(_))
255    }
256    /// Returns the offset of the window frame
257    pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
258        match &self.frame_bound {
259            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
260                x if x.is_null() => Ok(None),
261                ScalarValue::UInt64(v) => Ok(*v),
262                // The cast below is only safe because window bounds cannot be negative
263                ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
264                ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
265                    Ok(s) => Ok(Some(s)),
266                    Err(_e) => Err(DataFusionError::Plan(format!(
267                        "Unable to parse u64 from Utf8 value '{s}'"
268                    ))
269                    .into()),
270                },
271                ref x => {
272                    Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
273                }
274            },
275            WindowFrameBound::CurrentRow => Ok(None),
276        }
277    }
278    /// Returns if the frame bound is unbounded
279    pub fn is_unbounded(&self) -> PyResult<bool> {
280        match &self.frame_bound {
281            WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
282            WindowFrameBound::CurrentRow => Ok(false),
283        }
284    }
285}
286
287impl LogicalNode for PyWindowExpr {
288    fn inputs(&self) -> Vec<PyLogicalPlan> {
289        vec![self.window.input.as_ref().clone().into()]
290    }
291
292    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
293        self.clone().into_bound_py_any(py)
294    }
295}