datafusion_python/
dataframe.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::collections::HashMap;
19use std::ffi::CString;
20use std::sync::Arc;
21
22use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
23use arrow::compute::can_cast_types;
24use arrow::error::ArrowError;
25use arrow::ffi::FFI_ArrowSchema;
26use arrow::ffi_stream::FFI_ArrowArrayStream;
27use arrow::pyarrow::FromPyArrow;
28use datafusion::arrow::datatypes::Schema;
29use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
30use datafusion::arrow::util::pretty;
31use datafusion::common::UnnestOptions;
32use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
33use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
34use datafusion::datasource::TableProvider;
35use datafusion::error::DataFusionError;
36use datafusion::execution::SendableRecordBatchStream;
37use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
38use datafusion::prelude::*;
39use datafusion_ffi::table_provider::FFI_TableProvider;
40use futures::{StreamExt, TryStreamExt};
41use pyo3::exceptions::PyValueError;
42use pyo3::prelude::*;
43use pyo3::pybacked::PyBackedStr;
44use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45use tokio::task::JoinHandle;
46
47use crate::catalog::PyTable;
48use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
49use crate::expr::sort_expr::to_sort_expressions;
50use crate::physical_plan::PyExecutionPlan;
51use crate::record_batch::PyRecordBatchStream;
52use crate::sql::logical::PyLogicalPlan;
53use crate::utils::{
54    get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
55};
56use crate::{
57    errors::PyDataFusionResult,
58    expr::{sort_expr::PySortExpr, PyExpr},
59};
60
61// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
62// - we have not decided on the table_provider approach yet
63// this is an interim implementation
64#[pyclass(name = "TableProvider", module = "datafusion")]
65pub struct PyTableProvider {
66    provider: Arc<dyn TableProvider + Send>,
67}
68
69impl PyTableProvider {
70    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
71        Self { provider }
72    }
73
74    pub fn as_table(&self) -> PyTable {
75        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
76        PyTable::new(table_provider)
77    }
78}
79
80#[pymethods]
81impl PyTableProvider {
82    fn __datafusion_table_provider__<'py>(
83        &self,
84        py: Python<'py>,
85    ) -> PyResult<Bound<'py, PyCapsule>> {
86        let name = CString::new("datafusion_table_provider").unwrap();
87
88        let runtime = get_tokio_runtime().0.handle().clone();
89        let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));
90
91        PyCapsule::new(py, provider, Some(name.clone()))
92    }
93}
94
/// Display settings consulted when rendering a DataFrame.
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Upper bound, in bytes, on the memory used for display (default: 2MB)
    pub max_bytes: usize,
    /// Smallest number of rows to display (default: 20)
    pub min_rows: usize,
    /// How many rows the `__repr__` output includes (default: 10)
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    /// Returns the stock settings: 2MB, 20 rows minimum, 10 repr rows.
    fn default() -> Self {
        const BYTES_PER_MIB: usize = 1024 * 1024;

        Self {
            max_bytes: 2 * BYTES_PER_MIB,
            min_rows: 20,
            repr_rows: 10,
        }
    }
}

impl FormatterConfig {
    /// Checks that every setting is strictly greater than zero.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first zero-valued field, checked in the
    /// order `max_bytes`, `min_rows`, `repr_rows`.
    pub fn validate(&self) -> Result<(), String> {
        let settings = [
            ("max_bytes", self.max_bytes),
            ("min_rows", self.min_rows),
            ("repr_rows", self.repr_rows),
        ];

        match settings.iter().find(|(_, value)| *value == 0) {
            Some((field, _)) => Err(format!("{field} must be a positive integer")),
            None => Ok(()),
        }
    }
}
138
139/// Holds the Python formatter and its configuration
140struct PythonFormatter<'py> {
141    /// The Python formatter object
142    formatter: Bound<'py, PyAny>,
143    /// The formatter configuration
144    config: FormatterConfig,
145}
146
147/// Get the Python formatter and its configuration
148fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
149    let formatter = import_python_formatter(py)?;
150    let config = build_formatter_config_from_python(&formatter)?;
151    Ok(PythonFormatter { formatter, config })
152}
153
154/// Get the Python formatter from the datafusion.dataframe_formatter module
155fn import_python_formatter(py: Python) -> PyResult<Bound<'_, PyAny>> {
156    let formatter_module = py.import("datafusion.dataframe_formatter")?;
157    let get_formatter = formatter_module.getattr("get_formatter")?;
158    get_formatter.call0()
159}
160
161// Helper function to extract attributes with fallback to default
162fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
163where
164    T: for<'py> pyo3::FromPyObject<'py> + Clone,
165{
166    py_object
167        .getattr(attr_name)
168        .and_then(|v| v.extract::<T>())
169        .unwrap_or_else(|_| default_value.clone())
170}
171
172/// Helper function to create a FormatterConfig from a Python formatter object
173fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
174    let default_config = FormatterConfig::default();
175    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
176    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
177    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);
178
179    let config = FormatterConfig {
180        max_bytes,
181        min_rows,
182        repr_rows,
183    };
184
185    // Return the validated config, converting String error to PyErr
186    config.validate().map_err(PyValueError::new_err)?;
187    Ok(config)
188}
189
190/// Python mapping of `ParquetOptions` (includes just the writer-related options).
191#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)]
192#[derive(Clone, Default)]
193pub struct PyParquetWriterOptions {
194    options: ParquetOptions,
195}
196
197#[pymethods]
198impl PyParquetWriterOptions {
199    #[new]
200    #[allow(clippy::too_many_arguments)]
201    pub fn new(
202        data_pagesize_limit: usize,
203        write_batch_size: usize,
204        writer_version: String,
205        skip_arrow_metadata: bool,
206        compression: Option<String>,
207        dictionary_enabled: Option<bool>,
208        dictionary_page_size_limit: usize,
209        statistics_enabled: Option<String>,
210        max_row_group_size: usize,
211        created_by: String,
212        column_index_truncate_length: Option<usize>,
213        statistics_truncate_length: Option<usize>,
214        data_page_row_count_limit: usize,
215        encoding: Option<String>,
216        bloom_filter_on_write: bool,
217        bloom_filter_fpp: Option<f64>,
218        bloom_filter_ndv: Option<u64>,
219        allow_single_file_parallelism: bool,
220        maximum_parallel_row_group_writers: usize,
221        maximum_buffered_record_batches_per_stream: usize,
222    ) -> Self {
223        Self {
224            options: ParquetOptions {
225                data_pagesize_limit,
226                write_batch_size,
227                writer_version,
228                skip_arrow_metadata,
229                compression,
230                dictionary_enabled,
231                dictionary_page_size_limit,
232                statistics_enabled,
233                max_row_group_size,
234                created_by,
235                column_index_truncate_length,
236                statistics_truncate_length,
237                data_page_row_count_limit,
238                encoding,
239                bloom_filter_on_write,
240                bloom_filter_fpp,
241                bloom_filter_ndv,
242                allow_single_file_parallelism,
243                maximum_parallel_row_group_writers,
244                maximum_buffered_record_batches_per_stream,
245                ..Default::default()
246            },
247        }
248    }
249}
250
251/// Python mapping of `ParquetColumnOptions`.
252#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)]
253#[derive(Clone, Default)]
254pub struct PyParquetColumnOptions {
255    options: ParquetColumnOptions,
256}
257
258#[pymethods]
259impl PyParquetColumnOptions {
260    #[new]
261    pub fn new(
262        bloom_filter_enabled: Option<bool>,
263        encoding: Option<String>,
264        dictionary_enabled: Option<bool>,
265        compression: Option<String>,
266        statistics_enabled: Option<String>,
267        bloom_filter_fpp: Option<f64>,
268        bloom_filter_ndv: Option<u64>,
269    ) -> Self {
270        Self {
271            options: ParquetColumnOptions {
272                bloom_filter_enabled,
273                encoding,
274                dictionary_enabled,
275                compression,
276                statistics_enabled,
277                bloom_filter_fpp,
278                bloom_filter_ndv,
279                ..Default::default()
280            },
281        }
282    }
283}
284
285/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
286/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
287/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
288#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
289#[derive(Clone)]
290pub struct PyDataFrame {
291    df: Arc<DataFrame>,
292
293    // In IPython environment cache batches between __repr__ and _repr_html_ calls.
294    batches: Option<(Vec<RecordBatch>, bool)>,
295}
296
297impl PyDataFrame {
298    /// creates a new PyDataFrame
299    pub fn new(df: DataFrame) -> Self {
300        Self {
301            df: Arc::new(df),
302            batches: None,
303        }
304    }
305
306    fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
307        // Get the Python formatter and config
308        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
309
310        let should_cache = *is_ipython_env(py) && self.batches.is_none();
311        let (batches, has_more) = match self.batches.take() {
312            Some(b) => b,
313            None => wait_for_future(
314                py,
315                collect_record_batches_to_display(self.df.as_ref().clone(), config),
316            )??,
317        };
318
319        if batches.is_empty() {
320            // This should not be reached, but do it for safety since we index into the vector below
321            return Ok("No data to display".to_string());
322        }
323
324        let table_uuid = uuid::Uuid::new_v4().to_string();
325
326        // Convert record batches to PyObject list
327        let py_batches = batches
328            .iter()
329            .map(|rb| rb.to_pyarrow(py))
330            .collect::<PyResult<Vec<PyObject>>>()?;
331
332        let py_schema = self.schema().into_pyobject(py)?;
333
334        let kwargs = pyo3::types::PyDict::new(py);
335        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
336        kwargs.set_item("batches", py_batches_list)?;
337        kwargs.set_item("schema", py_schema)?;
338        kwargs.set_item("has_more", has_more)?;
339        kwargs.set_item("table_uuid", table_uuid)?;
340
341        let method_name = match as_html {
342            true => "format_html",
343            false => "format_str",
344        };
345
346        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
347        let html_str: String = html_result.extract()?;
348
349        if should_cache {
350            self.batches = Some((batches, has_more));
351        }
352
353        Ok(html_str)
354    }
355}
356
357#[pymethods]
358impl PyDataFrame {
359    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
360    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
361        if let Ok(key) = key.extract::<PyBackedStr>() {
362            // df[col]
363            self.select_columns(vec![key])
364        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
365            // df[col1, col2, col3]
366            let keys = tuple
367                .iter()
368                .map(|item| item.extract::<PyBackedStr>())
369                .collect::<PyResult<Vec<PyBackedStr>>>()?;
370            self.select_columns(keys)
371        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
372            // df[[col1, col2, col3]]
373            self.select_columns(keys)
374        } else {
375            let message = "DataFrame can only be indexed by string index or indices".to_string();
376            Err(PyDataFusionError::Common(message))
377        }
378    }
379
380    fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
381        self.prepare_repr_string(py, false)
382    }
383
384    #[staticmethod]
385    #[expect(unused_variables)]
386    fn default_str_repr<'py>(
387        batches: Vec<Bound<'py, PyAny>>,
388        schema: &Bound<'py, PyAny>,
389        has_more: bool,
390        table_uuid: &str,
391    ) -> PyResult<String> {
392        let batches = batches
393            .into_iter()
394            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
395            .collect::<PyResult<Vec<RecordBatch>>>()?
396            .into_iter()
397            .filter(|batch| batch.num_rows() > 0)
398            .collect::<Vec<_>>();
399
400        if batches.is_empty() {
401            return Ok("No data to display".to_owned());
402        }
403
404        let batches_as_displ =
405            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
406
407        let additional_str = match has_more {
408            true => "\nData truncated.",
409            false => "",
410        };
411
412        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
413    }
414
415    fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
416        self.prepare_repr_string(py, true)
417    }
418
419    /// Calculate summary statistics for a DataFrame
420    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
421        let df = self.df.as_ref().clone();
422        let stat_df = wait_for_future(py, df.describe())??;
423        Ok(Self::new(stat_df))
424    }
425
426    /// Returns the schema from the logical plan
427    fn schema(&self) -> PyArrowType<Schema> {
428        PyArrowType(self.df.schema().into())
429    }
430
431    /// Convert this DataFrame into a Table that can be used in register_table
432    /// By convention, into_... methods consume self and return the new object.
433    /// Disabling the clippy lint, so we can use &self
434    /// because we're working with Python bindings
435    /// where objects are shared
436    /// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
437    /// - we have not decided on the table_provider approach yet
438    #[allow(clippy::wrong_self_convention)]
439    fn into_view(&self) -> PyDataFusionResult<PyTable> {
440        // Call the underlying Rust DataFrame::into_view method.
441        // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
442        // so that we don’t invalidate this PyDataFrame.
443        let table_provider = self.df.as_ref().clone().into_view();
444        let table_provider = PyTableProvider::new(table_provider);
445
446        Ok(table_provider.as_table())
447    }
448
449    #[pyo3(signature = (*args))]
450    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
451        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
452        let df = self.df.as_ref().clone().select_columns(&args)?;
453        Ok(Self::new(df))
454    }
455
456    #[pyo3(signature = (*args))]
457    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
458        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
459        let df = self.df.as_ref().clone().select(expr)?;
460        Ok(Self::new(df))
461    }
462
463    #[pyo3(signature = (*args))]
464    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
465        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
466        let df = self.df.as_ref().clone().drop_columns(&cols)?;
467        Ok(Self::new(df))
468    }
469
470    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
471        let df = self.df.as_ref().clone().filter(predicate.into())?;
472        Ok(Self::new(df))
473    }
474
475    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
476        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
477        Ok(Self::new(df))
478    }
479
480    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
481        let mut df = self.df.as_ref().clone();
482        for expr in exprs {
483            let expr: Expr = expr.into();
484            let name = format!("{}", expr.schema_name());
485            df = df.with_column(name.as_str(), expr)?
486        }
487        Ok(Self::new(df))
488    }
489
490    /// Rename one column by applying a new projection. This is a no-op if the column to be
491    /// renamed does not exist.
492    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
493        let df = self
494            .df
495            .as_ref()
496            .clone()
497            .with_column_renamed(old_name, new_name)?;
498        Ok(Self::new(df))
499    }
500
501    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
502        let group_by = group_by.into_iter().map(|e| e.into()).collect();
503        let aggs = aggs.into_iter().map(|e| e.into()).collect();
504        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
505        Ok(Self::new(df))
506    }
507
508    #[pyo3(signature = (*exprs))]
509    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
510        let exprs = to_sort_expressions(exprs);
511        let df = self.df.as_ref().clone().sort(exprs)?;
512        Ok(Self::new(df))
513    }
514
515    #[pyo3(signature = (count, offset=0))]
516    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
517        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
518        Ok(Self::new(df))
519    }
520
521    /// Executes the plan, returning a list of `RecordBatch`es.
522    /// Unless some order is specified in the plan, there is no
523    /// guarantee of the order of the result.
524    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
525        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
526            .map_err(PyDataFusionError::from)?;
527        // cannot use PyResult<Vec<RecordBatch>> return type due to
528        // https://github.com/PyO3/pyo3/issues/1813
529        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
530    }
531
532    /// Cache DataFrame.
533    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
534        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
535        Ok(Self::new(df))
536    }
537
538    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
539    /// maintaining the input partitioning.
540    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
541        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
542            .map_err(PyDataFusionError::from)?;
543
544        batches
545            .into_iter()
546            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
547            .collect()
548    }
549
550    /// Print the result, 20 lines by default
551    #[pyo3(signature = (num=20))]
552    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
553        let df = self.df.as_ref().clone().limit(0, Some(num))?;
554        print_dataframe(py, df)
555    }
556
557    /// Filter out duplicate rows
558    fn distinct(&self) -> PyDataFusionResult<Self> {
559        let df = self.df.as_ref().clone().distinct()?;
560        Ok(Self::new(df))
561    }
562
563    fn join(
564        &self,
565        right: PyDataFrame,
566        how: &str,
567        left_on: Vec<PyBackedStr>,
568        right_on: Vec<PyBackedStr>,
569    ) -> PyDataFusionResult<Self> {
570        let join_type = match how {
571            "inner" => JoinType::Inner,
572            "left" => JoinType::Left,
573            "right" => JoinType::Right,
574            "full" => JoinType::Full,
575            "semi" => JoinType::LeftSemi,
576            "anti" => JoinType::LeftAnti,
577            how => {
578                return Err(PyDataFusionError::Common(format!(
579                    "The join type {how} does not exist or is not implemented"
580                )));
581            }
582        };
583
584        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
585        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
586
587        let df = self.df.as_ref().clone().join(
588            right.df.as_ref().clone(),
589            join_type,
590            &left_keys,
591            &right_keys,
592            None,
593        )?;
594        Ok(Self::new(df))
595    }
596
597    fn join_on(
598        &self,
599        right: PyDataFrame,
600        on_exprs: Vec<PyExpr>,
601        how: &str,
602    ) -> PyDataFusionResult<Self> {
603        let join_type = match how {
604            "inner" => JoinType::Inner,
605            "left" => JoinType::Left,
606            "right" => JoinType::Right,
607            "full" => JoinType::Full,
608            "semi" => JoinType::LeftSemi,
609            "anti" => JoinType::LeftAnti,
610            how => {
611                return Err(PyDataFusionError::Common(format!(
612                    "The join type {how} does not exist or is not implemented"
613                )));
614            }
615        };
616        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
617
618        let df = self
619            .df
620            .as_ref()
621            .clone()
622            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
623        Ok(Self::new(df))
624    }
625
626    /// Print the query plan
627    #[pyo3(signature = (verbose=false, analyze=false))]
628    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
629        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
630        print_dataframe(py, df)
631    }
632
633    /// Get the logical plan for this `DataFrame`
634    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
635        Ok(self.df.as_ref().clone().logical_plan().clone().into())
636    }
637
638    /// Get the optimized logical plan for this `DataFrame`
639    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
640        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
641    }
642
643    /// Get the execution plan for this `DataFrame`
644    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
645        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
646        Ok(plan.into())
647    }
648
649    /// Repartition a `DataFrame` based on a logical partitioning scheme.
650    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
651        let new_df = self
652            .df
653            .as_ref()
654            .clone()
655            .repartition(Partitioning::RoundRobinBatch(num))?;
656        Ok(Self::new(new_df))
657    }
658
659    /// Repartition a `DataFrame` based on a logical partitioning scheme.
660    #[pyo3(signature = (*args, num))]
661    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
662        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
663        let new_df = self
664            .df
665            .as_ref()
666            .clone()
667            .repartition(Partitioning::Hash(expr, num))?;
668        Ok(Self::new(new_df))
669    }
670
671    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
672    /// two `DataFrame`s must have exactly the same schema
673    #[pyo3(signature = (py_df, distinct=false))]
674    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
675        let new_df = if distinct {
676            self.df
677                .as_ref()
678                .clone()
679                .union_distinct(py_df.df.as_ref().clone())?
680        } else {
681            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
682        };
683
684        Ok(Self::new(new_df))
685    }
686
687    /// Calculate the distinct union of two `DataFrame`s.  The
688    /// two `DataFrame`s must have exactly the same schema
689    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
690        let new_df = self
691            .df
692            .as_ref()
693            .clone()
694            .union_distinct(py_df.df.as_ref().clone())?;
695        Ok(Self::new(new_df))
696    }
697
698    #[pyo3(signature = (column, preserve_nulls=true))]
699    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
700        // TODO: expose RecursionUnnestOptions
701        // REF: https://github.com/apache/datafusion/pull/11577
702        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
703        let df = self
704            .df
705            .as_ref()
706            .clone()
707            .unnest_columns_with_options(&[column], unnest_options)?;
708        Ok(Self::new(df))
709    }
710
711    #[pyo3(signature = (columns, preserve_nulls=true))]
712    fn unnest_columns(
713        &self,
714        columns: Vec<String>,
715        preserve_nulls: bool,
716    ) -> PyDataFusionResult<Self> {
717        // TODO: expose RecursionUnnestOptions
718        // REF: https://github.com/apache/datafusion/pull/11577
719        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
720        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
721        let df = self
722            .df
723            .as_ref()
724            .clone()
725            .unnest_columns_with_options(&cols, unnest_options)?;
726        Ok(Self::new(df))
727    }
728
729    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
730    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
731        let new_df = self
732            .df
733            .as_ref()
734            .clone()
735            .intersect(py_df.df.as_ref().clone())?;
736        Ok(Self::new(new_df))
737    }
738
739    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
740    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
741        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
742        Ok(Self::new(new_df))
743    }
744
745    /// Write a `DataFrame` to a CSV file.
746    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
747        let csv_options = CsvOptions {
748            has_header: Some(with_header),
749            ..Default::default()
750        };
751        wait_for_future(
752            py,
753            self.df.as_ref().clone().write_csv(
754                path,
755                DataFrameWriteOptions::new(),
756                Some(csv_options),
757            ),
758        )??;
759        Ok(())
760    }
761
762    /// Write a `DataFrame` to a Parquet file.
763    #[pyo3(signature = (
764        path,
765        compression="zstd",
766        compression_level=None
767        ))]
768    fn write_parquet(
769        &self,
770        path: &str,
771        compression: &str,
772        compression_level: Option<u32>,
773        py: Python,
774    ) -> PyDataFusionResult<()> {
775        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
776            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
777        }
778
779        let _validated = match compression.to_lowercase().as_str() {
780            "snappy" => Compression::SNAPPY,
781            "gzip" => Compression::GZIP(
782                GzipLevel::try_new(compression_level.unwrap_or(6))
783                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
784            ),
785            "brotli" => Compression::BROTLI(
786                BrotliLevel::try_new(verify_compression_level(compression_level)?)
787                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
788            ),
789            "zstd" => Compression::ZSTD(
790                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
791                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
792            ),
793            "lzo" => Compression::LZO,
794            "lz4" => Compression::LZ4,
795            "lz4_raw" => Compression::LZ4_RAW,
796            "uncompressed" => Compression::UNCOMPRESSED,
797            _ => {
798                return Err(PyDataFusionError::Common(format!(
799                    "Unrecognized compression type {compression}"
800                )));
801            }
802        };
803
804        let mut compression_string = compression.to_string();
805        if let Some(level) = compression_level {
806            compression_string.push_str(&format!("({level})"));
807        }
808
809        let mut options = TableParquetOptions::default();
810        options.global.compression = Some(compression_string);
811
812        wait_for_future(
813            py,
814            self.df.as_ref().clone().write_parquet(
815                path,
816                DataFrameWriteOptions::new(),
817                Option::from(options),
818            ),
819        )??;
820        Ok(())
821    }
822
823    /// Write a `DataFrame` to a Parquet file, using advanced options.
824    fn write_parquet_with_options(
825        &self,
826        path: &str,
827        options: PyParquetWriterOptions,
828        column_specific_options: HashMap<String, PyParquetColumnOptions>,
829        py: Python,
830    ) -> PyDataFusionResult<()> {
831        let table_options = TableParquetOptions {
832            global: options.options,
833            column_specific_options: column_specific_options
834                .into_iter()
835                .map(|(k, v)| (k, v.options))
836                .collect(),
837            ..Default::default()
838        };
839
840        wait_for_future(
841            py,
842            self.df.as_ref().clone().write_parquet(
843                path,
844                DataFrameWriteOptions::new(),
845                Option::from(table_options),
846            ),
847        )??;
848        Ok(())
849    }
850
851    /// Executes a query and writes the results to a partitioned JSON file.
852    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
853        wait_for_future(
854            py,
855            self.df
856                .as_ref()
857                .clone()
858                .write_json(path, DataFrameWriteOptions::new(), None),
859        )??;
860        Ok(())
861    }
862
863    /// Convert to Arrow Table
864    /// Collect the batches and pass to Arrow Table
865    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
866        let batches = self.collect(py)?.into_pyobject(py)?;
867        let schema = self.schema().into_pyobject(py)?;
868
869        // Instantiate pyarrow Table object and use its from_batches method
870        let table_class = py.import("pyarrow")?.getattr("Table")?;
871        let args = PyTuple::new(py, &[batches, schema])?;
872        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
873        Ok(table)
874    }
875
876    #[pyo3(signature = (requested_schema=None))]
877    fn __arrow_c_stream__<'py>(
878        &'py mut self,
879        py: Python<'py>,
880        requested_schema: Option<Bound<'py, PyCapsule>>,
881    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
882        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
883        let mut schema: Schema = self.df.schema().to_owned().into();
884
885        if let Some(schema_capsule) = requested_schema {
886            validate_pycapsule(&schema_capsule, "arrow_schema")?;
887
888            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
889            let desired_schema = Schema::try_from(schema_ptr)?;
890
891            schema = project_schema(schema, desired_schema)?;
892
893            batches = batches
894                .into_iter()
895                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
896                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
897        }
898
899        let batches_wrapped = batches.into_iter().map(Ok);
900
901        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
902        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
903
904        let ffi_stream = FFI_ArrowArrayStream::new(reader);
905        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
906        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
907    }
908
909    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
910        // create a Tokio runtime to run the async code
911        let rt = &get_tokio_runtime().0;
912        let df = self.df.as_ref().clone();
913        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
914            rt.spawn(async move { df.execute_stream().await });
915        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
916        Ok(PyRecordBatchStream::new(stream))
917    }
918
919    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
920        // create a Tokio runtime to run the async code
921        let rt = &get_tokio_runtime().0;
922        let df = self.df.as_ref().clone();
923        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
924            rt.spawn(async move { df.execute_stream_partitioned().await });
925        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
926            .map_err(py_datafusion_err)?
927            .map_err(py_datafusion_err)?;
928
929        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
930    }
931
932    /// Convert to pandas dataframe with pyarrow
933    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
934    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
935        let table = self.to_arrow_table(py)?;
936
937        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
938        let result = table.call_method0(py, "to_pandas")?;
939        Ok(result)
940    }
941
942    /// Convert to Python list using pyarrow
943    /// Each list item represents one row encoded as dictionary
944    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
945        let table = self.to_arrow_table(py)?;
946
947        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
948        let result = table.call_method0(py, "to_pylist")?;
949        Ok(result)
950    }
951
952    /// Convert to Python dictionary using pyarrow
953    /// Each dictionary key is a column and the dictionary value represents the column values
954    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
955        let table = self.to_arrow_table(py)?;
956
957        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
958        let result = table.call_method0(py, "to_pydict")?;
959        Ok(result)
960    }
961
962    /// Convert to polars dataframe with pyarrow
963    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
964    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
965        let table = self.to_arrow_table(py)?;
966        let dataframe = py.import("polars")?.getattr("DataFrame")?;
967        let args = PyTuple::new(py, &[table])?;
968        let result: PyObject = dataframe.call1(args)?.into();
969        Ok(result)
970    }
971
972    // Executes this DataFrame to get the total number of rows.
973    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
974        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
975    }
976
977    /// Fill null values with a specified value for specific columns
978    #[pyo3(signature = (value, columns=None))]
979    fn fill_null(
980        &self,
981        value: PyObject,
982        columns: Option<Vec<PyBackedStr>>,
983        py: Python,
984    ) -> PyDataFusionResult<Self> {
985        let scalar_value = py_obj_to_scalar_value(py, value)?;
986
987        let cols = match columns {
988            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
989            None => Vec::new(), // Empty vector means fill null for all columns
990        };
991
992        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
993        Ok(Self::new(df))
994    }
995}
996
997/// Print DataFrame
998fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
999    // Get string representation of record batches
1000    let batches = wait_for_future(py, df.collect())??;
1001    let batches_as_string = pretty::pretty_format_batches(&batches);
1002    let result = match batches_as_string {
1003        Ok(batch) => format!("DataFrame()\n{batch}"),
1004        Err(err) => format!("Error: {:?}", err.to_string()),
1005    };
1006
1007    // Import the Python 'builtins' module to access the print function
1008    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
1009    let print = py.import("builtins")?.getattr("print")?;
1010    print.call1((result,))?;
1011    Ok(())
1012}
1013
1014fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
1015    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
1016
1017    let project_indices: Vec<usize> = to_schema
1018        .fields
1019        .iter()
1020        .map(|field| field.name())
1021        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
1022        .collect();
1023
1024    merged_schema.project(&project_indices)
1025}
1026
1027fn record_batch_into_schema(
1028    record_batch: RecordBatch,
1029    schema: &Schema,
1030) -> Result<RecordBatch, ArrowError> {
1031    let schema = Arc::new(schema.clone());
1032    let base_schema = record_batch.schema();
1033    if base_schema.fields().is_empty() {
1034        // Nothing to project
1035        return Ok(RecordBatch::new_empty(schema));
1036    }
1037
1038    let array_size = record_batch.column(0).len();
1039    let mut data_arrays = Vec::with_capacity(schema.fields().len());
1040
1041    for field in schema.fields() {
1042        let desired_data_type = field.data_type();
1043        if let Some(original_data) = record_batch.column_by_name(field.name()) {
1044            let original_data_type = original_data.data_type();
1045
1046            if can_cast_types(original_data_type, desired_data_type) {
1047                data_arrays.push(arrow::compute::kernels::cast(
1048                    original_data,
1049                    desired_data_type,
1050                )?);
1051            } else if field.is_nullable() {
1052                data_arrays.push(new_null_array(desired_data_type, array_size));
1053            } else {
1054                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
1055            }
1056        } else {
1057            if !field.is_nullable() {
1058                return Err(ArrowError::CastError(format!(
1059                    "Attempting to set null to non-nullable field {} during schema projection.",
1060                    field.name()
1061                )));
1062            }
1063            data_arrays.push(new_null_array(desired_data_type, array_size));
1064        }
1065    }
1066
1067    RecordBatch::try_new(schema, data_arrays)
1068}
1069
1070/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
1071/// It additionally returns a bool, which indicates if there are more record batches available.
1072/// We do this so we can determine if we should indicate to the user that the data has been
1073/// truncated. This collects until we have achived both of these two conditions
1074///
1075/// - We have collected our minimum number of rows
1076/// - We have reached our limit, either data size or maximum number of rows
1077///
1078/// Otherwise it will return when the stream has exhausted. If you want a specific number of
1079/// rows, set min_rows == max_rows.
1080async fn collect_record_batches_to_display(
1081    df: DataFrame,
1082    config: FormatterConfig,
1083) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
1084    let FormatterConfig {
1085        max_bytes,
1086        min_rows,
1087        repr_rows,
1088    } = config;
1089
1090    let partitioned_stream = df.execute_stream_partitioned().await?;
1091    let mut stream = futures::stream::iter(partitioned_stream).flatten();
1092    let mut size_estimate_so_far = 0;
1093    let mut rows_so_far = 0;
1094    let mut record_batches = Vec::default();
1095    let mut has_more = false;
1096
1097    // ensure minimum rows even if memory/row limits are hit
1098    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
1099        let mut rb = match stream.next().await {
1100            None => {
1101                break;
1102            }
1103            Some(Ok(r)) => r,
1104            Some(Err(e)) => return Err(e),
1105        };
1106
1107        let mut rows_in_rb = rb.num_rows();
1108        if rows_in_rb > 0 {
1109            size_estimate_so_far += rb.get_array_memory_size();
1110
1111            if size_estimate_so_far > max_bytes {
1112                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
1113                let total_rows = rows_in_rb + rows_so_far;
1114
1115                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
1116                if reduced_row_num < min_rows {
1117                    reduced_row_num = min_rows.min(total_rows);
1118                }
1119
1120                let limited_rows_this_rb = reduced_row_num - rows_so_far;
1121                if limited_rows_this_rb < rows_in_rb {
1122                    rows_in_rb = limited_rows_this_rb;
1123                    rb = rb.slice(0, limited_rows_this_rb);
1124                    has_more = true;
1125                }
1126            }
1127
1128            if rows_in_rb + rows_so_far > repr_rows {
1129                rb = rb.slice(0, repr_rows - rows_so_far);
1130                has_more = true;
1131            }
1132
1133            rows_so_far += rb.num_rows();
1134            record_batches.push(rb);
1135        }
1136    }
1137
1138    if record_batches.is_empty() {
1139        return Ok((Vec::default(), false));
1140    }
1141
1142    if !has_more {
1143        // Data was not already truncated, so check to see if more record batches remain
1144        has_more = match stream.try_next().await {
1145            Ok(None) => false, // reached end
1146            Ok(Some(_)) => true,
1147            Err(_) => false, // Stream disconnected
1148        };
1149    }
1150
1151    Ok((record_batches, has_more))
1152}