datafusion_python/expr/
window.rs1use datafusion::common::{DataFusionError, ScalarValue};
19use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits};
20use pyo3::{prelude::*, IntoPyObjectExt};
21use std::fmt::{self, Display, Formatter};
22
23use crate::common::data_type::PyScalarValue;
24use crate::common::df_schema::PyDFSchema;
25use crate::errors::{py_type_err, PyDataFusionResult};
26use crate::expr::logical_node::LogicalNode;
27use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
28use crate::expr::PyExpr;
29use crate::sql::logical::PyLogicalPlan;
30
31use super::py_expr_list;
32
33use crate::errors::py_datafusion_err;
34
35#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
36#[derive(Clone)]
37pub struct PyWindowExpr {
38 window: Window,
39}
40
41#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)]
42#[derive(Clone)]
43pub struct PyWindowFrame {
44 window_frame: WindowFrame,
45}
46
47impl From<PyWindowFrame> for WindowFrame {
48 fn from(window_frame: PyWindowFrame) -> Self {
49 window_frame.window_frame
50 }
51}
52
53impl From<WindowFrame> for PyWindowFrame {
54 fn from(window_frame: WindowFrame) -> PyWindowFrame {
55 PyWindowFrame { window_frame }
56 }
57}
58
59#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)]
60#[derive(Clone)]
61pub struct PyWindowFrameBound {
62 frame_bound: WindowFrameBound,
63}
64
65impl From<PyWindowExpr> for Window {
66 fn from(window: PyWindowExpr) -> Window {
67 window.window
68 }
69}
70
71impl From<Window> for PyWindowExpr {
72 fn from(window: Window) -> PyWindowExpr {
73 PyWindowExpr { window }
74 }
75}
76
77impl From<WindowFrameBound> for PyWindowFrameBound {
78 fn from(frame_bound: WindowFrameBound) -> Self {
79 PyWindowFrameBound { frame_bound }
80 }
81}
82
83impl Display for PyWindowExpr {
84 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
85 write!(
86 f,
87 "Over\n
88 Window Expr: {:?}
89 Schema: {:?}",
90 &self.window.window_expr, &self.window.schema
91 )
92 }
93}
94
95impl Display for PyWindowFrame {
96 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
97 write!(
98 f,
99 "OVER ({} BETWEEN {} AND {})",
100 self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound
101 )
102 }
103}
104
105#[pymethods]
106impl PyWindowExpr {
107 pub fn schema(&self) -> PyResult<PyDFSchema> {
109 Ok(self.window.schema.as_ref().clone().into())
110 }
111
112 pub fn get_window_expr(&self) -> PyResult<Vec<PyExpr>> {
114 py_expr_list(&self.window.window_expr)
115 }
116
117 pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
119 match expr.expr.unalias() {
120 Expr::WindowFunction(boxed_window_fn) => {
121 py_sort_expr_list(&boxed_window_fn.params.order_by)
122 }
123 other => Err(not_window_function_err(other)),
124 }
125 }
126
127 pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
129 match expr.expr.unalias() {
130 Expr::WindowFunction(boxed_window_fn) => {
131 py_expr_list(&boxed_window_fn.params.partition_by)
132 }
133 other => Err(not_window_function_err(other)),
134 }
135 }
136
137 pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
139 match expr.expr.unalias() {
140 Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args),
141 other => Err(not_window_function_err(other)),
142 }
143 }
144
145 pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
147 match expr.expr.unalias() {
148 Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()),
149 other => Err(not_window_function_err(other)),
150 }
151 }
152
153 pub fn get_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
155 match expr.expr.unalias() {
156 Expr::WindowFunction(boxed_window_fn) => {
157 Some(boxed_window_fn.params.window_frame.into())
158 }
159 _ => None,
160 }
161 }
162}
163
164fn not_window_function_err(expr: Expr) -> PyErr {
165 py_type_err(format!(
166 "Provided {} Expr {:?} is not a WindowFunction type",
167 expr.variant_name(),
168 expr
169 ))
170}
171
172#[pymethods]
173impl PyWindowFrame {
174 #[new]
175 #[pyo3(signature=(unit, start_bound, end_bound))]
176 pub fn new(
177 unit: &str,
178 start_bound: Option<PyScalarValue>,
179 end_bound: Option<PyScalarValue>,
180 ) -> PyResult<Self> {
181 let units = unit.to_ascii_lowercase();
182 let units = match units.as_str() {
183 "rows" => WindowFrameUnits::Rows,
184 "range" => WindowFrameUnits::Range,
185 "groups" => WindowFrameUnits::Groups,
186 _ => {
187 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
188 "{units:?}",
189 ))));
190 }
191 };
192 let start_bound = match start_bound {
193 Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
194 None => match units {
195 WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
196 WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
197 WindowFrameUnits::Groups => {
198 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
199 "{units:?}",
200 ))));
201 }
202 },
203 };
204 let end_bound = match end_bound {
205 Some(end_bound) => WindowFrameBound::Following(end_bound.0),
206 None => match units {
207 WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
208 WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
209 WindowFrameUnits::Groups => {
210 return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
211 "{units:?}",
212 ))));
213 }
214 },
215 };
216 Ok(PyWindowFrame {
217 window_frame: WindowFrame::new_bounds(units, start_bound, end_bound),
218 })
219 }
220
221 pub fn get_frame_units(&self) -> PyResult<String> {
223 Ok(self.window_frame.units.to_string())
224 }
225 pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
227 Ok(self.window_frame.start_bound.clone().into())
228 }
229 pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
231 Ok(self.window_frame.end_bound.clone().into())
232 }
233
234 fn __repr__(&self) -> String {
236 format!("{self}")
237 }
238}
239
240#[pymethods]
241impl PyWindowFrameBound {
242 pub fn is_current_row(&self) -> bool {
244 matches!(self.frame_bound, WindowFrameBound::CurrentRow)
245 }
246
247 pub fn is_preceding(&self) -> bool {
249 matches!(self.frame_bound, WindowFrameBound::Preceding(_))
250 }
251
252 pub fn is_following(&self) -> bool {
254 matches!(self.frame_bound, WindowFrameBound::Following(_))
255 }
256 pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
258 match &self.frame_bound {
259 WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
260 x if x.is_null() => Ok(None),
261 ScalarValue::UInt64(v) => Ok(*v),
262 ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)),
264 ScalarValue::Utf8(Some(s)) => match s.parse::<u64>() {
265 Ok(s) => Ok(Some(s)),
266 Err(_e) => Err(DataFusionError::Plan(format!(
267 "Unable to parse u64 from Utf8 value '{s}'"
268 ))
269 .into()),
270 },
271 ref x => {
272 Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into())
273 }
274 },
275 WindowFrameBound::CurrentRow => Ok(None),
276 }
277 }
278 pub fn is_unbounded(&self) -> PyResult<bool> {
280 match &self.frame_bound {
281 WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()),
282 WindowFrameBound::CurrentRow => Ok(false),
283 }
284 }
285}
286
287impl LogicalNode for PyWindowExpr {
288 fn inputs(&self) -> Vec<PyLogicalPlan> {
289 vec![self.window.input.as_ref().clone().into()]
290 }
291
292 fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
293 self.clone().into_bound_py_any(py)
294 }
295}