use std::collections::HashMap;
use std::ffi::CString;
use std::sync::Arc;

use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow::compute::can_cast_types;
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::TableProvider;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
use datafusion_ffi::table_provider::FFI_TableProvider;
use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;

use crate::catalog::PyTable;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{
    get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
};
use crate::{
    errors::PyDataFusionResult,
    expr::{sort_expr::PySortExpr, PyExpr},
};

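/// Python-visible wrapper around an `Arc<dyn TableProvider>`.
///
/// It can be converted into a `PyTable` or exported to other Python libraries
/// through the `__datafusion_table_provider__` PyCapsule interface.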
#[pyclass(name = "TableProvider", module = "datafusion")]
pub struct PyTableProvider {
    provider: Arc<dyn TableProvider + Send>,
}

impl PyTableProvider {
    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
        Self { provider }
    }

    pub fn as_table(&self) -> PyTable {
        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
        PyTable::new(table_provider)
    }
}

#[pymethods]
impl PyTableProvider {
    fn __datafusion_table_provider__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        let name = CString::new("datafusion_table_provider").unwrap();

        let runtime = get_tokio_runtime().0.handle().clone();
        let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));

        PyCapsule::new(py, provider, Some(name.clone()))
    }
}

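/// Configuration for how record batches are gathered when rendering a DataFrame.
///
/// `max_bytes` caps the estimated memory collected for display, `min_rows` is the
/// minimum number of rows to show when data is available, and `repr_rows` is the
/// maximum number of rows rendered in `__repr__`.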
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    pub max_bytes: usize,
    pub min_rows: usize,
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    fn default() -> Self {
        Self {
            max_bytes: 2 * 1024 * 1024,
            min_rows: 20,
            repr_rows: 10,
        }
    }
}

impl FormatterConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.max_bytes == 0 {
            return Err("max_bytes must be a positive integer".to_string());
        }

        if self.min_rows == 0 {
            return Err("min_rows must be a positive integer".to_string());
        }

        if self.repr_rows == 0 {
            return Err("repr_rows must be a positive integer".to_string());
        }

        Ok(())
    }
}

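/// Pairs the Python formatter object from `datafusion.dataframe_formatter`
/// with the `FormatterConfig` values extracted from it.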
struct PythonFormatter<'py> {
    formatter: Bound<'py, PyAny>,
    config: FormatterConfig,
}

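/// Imports the Python formatter and builds a validated `FormatterConfig` from its attributes.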
fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
    let formatter = import_python_formatter(py)?;
    let config = build_formatter_config_from_python(&formatter)?;
    Ok(PythonFormatter { formatter, config })
}

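/// Gets the shared formatter instance via `datafusion.dataframe_formatter.get_formatter()`.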
fn import_python_formatter(py: Python) -> PyResult<Bound<'_, PyAny>> {
    let formatter_module = py.import("datafusion.dataframe_formatter")?;
    let get_formatter = formatter_module.getattr("get_formatter")?;
    get_formatter.call0()
}

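/// Reads `attr_name` from a Python object, falling back to `default_value` when the
/// attribute is missing or cannot be extracted as `T`.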
fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
where
    T: for<'py> pyo3::FromPyObject<'py> + Clone,
{
    py_object
        .getattr(attr_name)
        .and_then(|v| v.extract::<T>())
        .unwrap_or_else(|_| default_value.clone())
}

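/// Builds a `FormatterConfig` from the formatter's `max_memory_bytes`,
/// `min_rows_display`, and `repr_rows` attributes, validating the result.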
fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
    let default_config = FormatterConfig::default();
    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);

    let config = FormatterConfig {
        max_bytes,
        min_rows,
        repr_rows,
    };

    config.validate().map_err(PyValueError::new_err)?;

    Ok(config)
}

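/// Python wrapper for DataFusion's global Parquet writer options
/// (`ParquetOptions`), used by `write_parquet_with_options`.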
#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetWriterOptions {
    options: ParquetOptions,
}

#[pymethods]
impl PyParquetWriterOptions {
    #[new]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        data_pagesize_limit: usize,
        write_batch_size: usize,
        writer_version: String,
        skip_arrow_metadata: bool,
        compression: Option<String>,
        dictionary_enabled: Option<bool>,
        dictionary_page_size_limit: usize,
        statistics_enabled: Option<String>,
        max_row_group_size: usize,
        created_by: String,
        column_index_truncate_length: Option<usize>,
        statistics_truncate_length: Option<usize>,
        data_page_row_count_limit: usize,
        encoding: Option<String>,
        bloom_filter_on_write: bool,
        bloom_filter_fpp: Option<f64>,
        bloom_filter_ndv: Option<u64>,
        allow_single_file_parallelism: bool,
        maximum_parallel_row_group_writers: usize,
        maximum_buffered_record_batches_per_stream: usize,
    ) -> Self {
        Self {
            options: ParquetOptions {
                data_pagesize_limit,
                write_batch_size,
                writer_version,
                skip_arrow_metadata,
                compression,
                dictionary_enabled,
                dictionary_page_size_limit,
                statistics_enabled,
                max_row_group_size,
                created_by,
                column_index_truncate_length,
                statistics_truncate_length,
                data_page_row_count_limit,
                encoding,
                bloom_filter_on_write,
                bloom_filter_fpp,
                bloom_filter_ndv,
                allow_single_file_parallelism,
                maximum_parallel_row_group_writers,
                maximum_buffered_record_batches_per_stream,
                ..Default::default()
            },
        }
    }
}

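/// Python wrapper for per-column Parquet writer options (`ParquetColumnOptions`).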
#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetColumnOptions {
    options: ParquetColumnOptions,
}

#[pymethods]
impl PyParquetColumnOptions {
    #[new]
    pub fn new(
        bloom_filter_enabled: Option<bool>,
        encoding: Option<String>,
        dictionary_enabled: Option<bool>,
        compression: Option<String>,
        statistics_enabled: Option<String>,
        bloom_filter_fpp: Option<f64>,
        bloom_filter_ndv: Option<u64>,
    ) -> Self {
        Self {
            options: ParquetColumnOptions {
                bloom_filter_enabled,
                encoding,
                dictionary_enabled,
                compression,
                statistics_enabled,
                bloom_filter_fpp,
                bloom_filter_ndv,
                ..Default::default()
            },
        }
    }
}

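/// Python wrapper around a DataFusion `DataFrame`.
///
/// `batches` optionally caches the record batches collected for display, together
/// with a flag indicating whether more rows exist, so repeated repr calls do not
/// re-execute the plan.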
#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyDataFrame {
    df: Arc<DataFrame>,

    batches: Option<(Vec<RecordBatch>, bool)>,
}

impl PyDataFrame {
    pub fn new(df: DataFrame) -> Self {
        Self {
            df: Arc::new(df),
            batches: None,
        }
    }

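    /// Renders the DataFrame through the Python formatter, returning HTML or plain
    /// text. Collected batches are cached when running in an IPython environment so
    /// that subsequent calls can reuse them.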
    fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;

        let should_cache = *is_ipython_env(py) && self.batches.is_none();
        let (batches, has_more) = match self.batches.take() {
            Some(b) => b,
            None => wait_for_future(
                py,
                collect_record_batches_to_display(self.df.as_ref().clone(), config),
            )??,
        };

        if batches.is_empty() {
            return Ok("No data to display".to_string());
        }

        let table_uuid = uuid::Uuid::new_v4().to_string();

        let py_batches = batches
            .iter()
            .map(|rb| rb.to_pyarrow(py))
            .collect::<PyResult<Vec<PyObject>>>()?;

        let py_schema = self.schema().into_pyobject(py)?;

        let kwargs = pyo3::types::PyDict::new(py);
        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
        kwargs.set_item("batches", py_batches_list)?;
        kwargs.set_item("schema", py_schema)?;
        kwargs.set_item("has_more", has_more)?;
        kwargs.set_item("table_uuid", table_uuid)?;

        let method_name = match as_html {
            true => "format_html",
            false => "format_str",
        };

        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
        let html_str: String = html_result.extract()?;

        if should_cache {
            self.batches = Some((batches, has_more));
        }

        Ok(html_str)
    }
}

#[pymethods]
impl PyDataFrame {
    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if let Ok(key) = key.extract::<PyBackedStr>() {
            self.select_columns(vec![key])
        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
            let keys = tuple
                .iter()
                .map(|item| item.extract::<PyBackedStr>())
                .collect::<PyResult<Vec<PyBackedStr>>>()?;
            self.select_columns(keys)
        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
            self.select_columns(keys)
        } else {
            let message = "DataFrame can only be indexed by string index or indices".to_string();
            Err(PyDataFusionError::Common(message))
        }
    }

    fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
        self.prepare_repr_string(py, false)
    }

    #[staticmethod]
    #[expect(unused_variables)]
    fn default_str_repr<'py>(
        batches: Vec<Bound<'py, PyAny>>,
        schema: &Bound<'py, PyAny>,
        has_more: bool,
        table_uuid: &str,
    ) -> PyResult<String> {
        let batches = batches
            .into_iter()
            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
            .collect::<PyResult<Vec<RecordBatch>>>()?
            .into_iter()
            .filter(|batch| batch.num_rows() > 0)
            .collect::<Vec<_>>();

        if batches.is_empty() {
            return Ok("No data to display".to_owned());
        }

        let batches_as_displ =
            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;

        let additional_str = match has_more {
            true => "\nData truncated.",
            false => "",
        };

        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
    }

    fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
        self.prepare_repr_string(py, true)
    }

    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone();
        let stat_df = wait_for_future(py, df.describe())??;
        Ok(Self::new(stat_df))
    }

    fn schema(&self) -> PyArrowType<Schema> {
        PyArrowType(self.df.schema().into())
    }

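    /// Converts this DataFrame into a `TableProvider` view and returns it as a `PyTable`.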
    #[allow(clippy::wrong_self_convention)]
    fn into_view(&self) -> PyDataFusionResult<PyTable> {
        let table_provider = self.df.as_ref().clone().into_view();
        let table_provider = PyTableProvider::new(table_provider);

        Ok(table_provider.as_table())
    }

    #[pyo3(signature = (*args))]
    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().select_columns(&args)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().select(expr)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().drop_columns(&cols)?;
        Ok(Self::new(df))
    }

    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().filter(predicate.into())?;
        Ok(Self::new(df))
    }

    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
        Ok(Self::new(df))
    }

    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let mut df = self.df.as_ref().clone();
        for expr in exprs {
            let expr: Expr = expr.into();
            let name = format!("{}", expr.schema_name());
            df = df.with_column(name.as_str(), expr)?
        }
        Ok(Self::new(df))
    }

    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
        let df = self
            .df
            .as_ref()
            .clone()
            .with_column_renamed(old_name, new_name)?;
        Ok(Self::new(df))
    }

    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let group_by = group_by.into_iter().map(|e| e.into()).collect();
        let aggs = aggs.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*exprs))]
    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
        let exprs = to_sort_expressions(exprs);
        let df = self.df.as_ref().clone().sort(exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (count, offset=0))]
    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
        Ok(Self::new(df))
    }

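    /// Executes the plan and returns the results as a list of PyArrow `RecordBatch`es.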
    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
            .map_err(PyDataFusionError::from)?;
        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
    }

    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
        Ok(Self::new(df))
    }

    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
            .map_err(PyDataFusionError::from)?;

        batches
            .into_iter()
            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
            .collect()
    }

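    /// Prints the first `num` rows of the DataFrame to stdout.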
    #[pyo3(signature = (num=20))]
    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().limit(0, Some(num))?;
        print_dataframe(py, df)
    }

    fn distinct(&self) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().distinct()?;
        Ok(Self::new(df))
    }

    fn join(
        &self,
        right: PyDataFrame,
        how: &str,
        left_on: Vec<PyBackedStr>,
        right_on: Vec<PyBackedStr>,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };

        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

        let df = self.df.as_ref().clone().join(
            right.df.as_ref().clone(),
            join_type,
            &left_keys,
            &right_keys,
            None,
        )?;
        Ok(Self::new(df))
    }

    fn join_on(
        &self,
        right: PyDataFrame,
        on_exprs: Vec<PyExpr>,
        how: &str,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };
        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();

        let df = self
            .df
            .as_ref()
            .clone()
            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (verbose=false, analyze=false))]
    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
        print_dataframe(py, df)
    }

    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().logical_plan().clone().into())
    }

    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
    }

    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
        Ok(plan.into())
    }

    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::RoundRobinBatch(num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (*args, num))]
    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::Hash(expr, num))?;
        Ok(Self::new(new_df))
    }

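    /// Unions this DataFrame with another, removing duplicate rows when `distinct` is true.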
    #[pyo3(signature = (py_df, distinct=false))]
    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
        let new_df = if distinct {
            self.df
                .as_ref()
                .clone()
                .union_distinct(py_df.df.as_ref().clone())?
        } else {
            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
        };

        Ok(Self::new(new_df))
    }

    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .union_distinct(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (column, preserve_nulls=true))]
    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&[column], unnest_options)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (columns, preserve_nulls=true))]
    fn unnest_columns(
        &self,
        columns: Vec<String>,
        preserve_nulls: bool,
    ) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&cols, unnest_options)?;
        Ok(Self::new(df))
    }

    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .intersect(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
        let csv_options = CsvOptions {
            has_header: Some(with_header),
            ..Default::default()
        };
        wait_for_future(
            py,
            self.df.as_ref().clone().write_csv(
                path,
                DataFrameWriteOptions::new(),
                Some(csv_options),
            ),
        )??;
        Ok(())
    }

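    /// Writes the DataFrame to Parquet. The compression string is validated against the
    /// supported Parquet codecs before being passed on through `TableParquetOptions`.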
    #[pyo3(signature = (
        path,
        compression="zstd",
        compression_level=None
    ))]
    fn write_parquet(
        &self,
        path: &str,
        compression: &str,
        compression_level: Option<u32>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
        }

        let _validated = match compression.to_lowercase().as_str() {
            "snappy" => Compression::SNAPPY,
            "gzip" => Compression::GZIP(
                GzipLevel::try_new(compression_level.unwrap_or(6))
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "brotli" => Compression::BROTLI(
                BrotliLevel::try_new(verify_compression_level(compression_level)?)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "zstd" => Compression::ZSTD(
                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "lzo" => Compression::LZO,
            "lz4" => Compression::LZ4,
            "lz4_raw" => Compression::LZ4_RAW,
            "uncompressed" => Compression::UNCOMPRESSED,
            _ => {
                return Err(PyDataFusionError::Common(format!(
                    "Unrecognized compression type {compression}"
                )));
            }
        };

        let mut compression_string = compression.to_string();
        if let Some(level) = compression_level {
            compression_string.push_str(&format!("({level})"));
        }

        let mut options = TableParquetOptions::default();
        options.global.compression = Some(compression_string);

        wait_for_future(
            py,
            self.df.as_ref().clone().write_parquet(
                path,
                DataFrameWriteOptions::new(),
                Option::from(options),
            ),
        )??;
        Ok(())
    }

    fn write_parquet_with_options(
        &self,
        path: &str,
        options: PyParquetWriterOptions,
        column_specific_options: HashMap<String, PyParquetColumnOptions>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table_options = TableParquetOptions {
            global: options.options,
            column_specific_options: column_specific_options
                .into_iter()
                .map(|(k, v)| (k, v.options))
                .collect(),
            ..Default::default()
        };

        wait_for_future(
            py,
            self.df.as_ref().clone().write_parquet(
                path,
                DataFrameWriteOptions::new(),
                Option::from(table_options),
            ),
        )??;
        Ok(())
    }

    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
        wait_for_future(
            py,
            self.df
                .as_ref()
                .clone()
                .write_json(path, DataFrameWriteOptions::new(), None),
        )??;
        Ok(())
    }

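    /// Collects the results and converts them into a `pyarrow.Table`.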
    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
        let batches = self.collect(py)?.into_pyobject(py)?;
        let schema = self.schema().into_pyobject(py)?;

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[batches, schema])?;
        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
        Ok(table)
    }

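    /// Arrow PyCapsule interface: exports the collected batches as an
    /// `FFI_ArrowArrayStream` capsule, optionally projected to a requested schema.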
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py mut self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
        let mut schema: Schema = self.df.schema().to_owned().into();

        if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_schema = Schema::try_from(schema_ptr)?;

            schema = project_schema(schema, desired_schema)?;

            batches = batches
                .into_iter()
                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
        }

        let batches_wrapped = batches.into_iter().map(Ok);

        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        let ffi_stream = FFI_ArrowArrayStream::new(reader);
        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
    }

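    /// Executes the plan on the shared Tokio runtime and returns a stream of record
    /// batches that can be consumed lazily from Python.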
    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { df.execute_stream().await });
        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
        Ok(PyRecordBatchStream::new(stream))
    }

    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
            rt.spawn(async move { df.execute_stream_partitioned().await });
        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
            .map_err(py_datafusion_err)?
            .map_err(py_datafusion_err)?;

        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
    }

    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pandas")?;
        Ok(result)
    }

    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pylist")?;
        Ok(result)
    }

    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pydict")?;
        Ok(result)
    }

    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;
        let dataframe = py.import("polars")?.getattr("DataFrame")?;
        let args = PyTuple::new(py, &[table])?;
        let result: PyObject = dataframe.call1(args)?.into();
        Ok(result)
    }

    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
    }

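    /// Replaces null values with `value`, optionally restricted to the given columns.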
    #[pyo3(signature = (value, columns=None))]
    fn fill_null(
        &self,
        value: PyObject,
        columns: Option<Vec<PyBackedStr>>,
        py: Python,
    ) -> PyDataFusionResult<Self> {
        let scalar_value = py_obj_to_scalar_value(py, value)?;

        let cols = match columns {
            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
            None => Vec::new(),
        };

        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
        Ok(Self::new(df))
    }
}

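/// Collects a DataFrame and prints the pretty-formatted batches using Python's
/// built-in `print`.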
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
    let batches = wait_for_future(py, df.collect())??;
    let batches_as_string = pretty::pretty_format_batches(&batches);
    let result = match batches_as_string {
        Ok(batch) => format!("DataFrame()\n{batch}"),
        Err(err) => format!("Error: {:?}", err.to_string()),
    };

    let print = py.import("builtins")?.getattr("print")?;
    print.call1((result,))?;
    Ok(())
}

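/// Merges `from_schema` with `to_schema` and projects the result onto the fields
/// requested by `to_schema`, keeping only those present in the merged schema.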
fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;

    let project_indices: Vec<usize> = to_schema
        .fields
        .iter()
        .map(|field| field.name())
        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
        .collect();

    merged_schema.project(&project_indices)
}

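/// Converts a `RecordBatch` to the target schema, casting columns where possible and
/// filling missing nullable fields with nulls; errors if a non-nullable field cannot
/// be produced.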
fn record_batch_into_schema(
    record_batch: RecordBatch,
    schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
    let schema = Arc::new(schema.clone());
    let base_schema = record_batch.schema();
    if base_schema.fields().is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    let array_size = record_batch.column(0).len();
    let mut data_arrays = Vec::with_capacity(schema.fields().len());

    for field in schema.fields() {
        let desired_data_type = field.data_type();
        if let Some(original_data) = record_batch.column_by_name(field.name()) {
            let original_data_type = original_data.data_type();

            if can_cast_types(original_data_type, desired_data_type) {
                data_arrays.push(arrow::compute::kernels::cast(
                    original_data,
                    desired_data_type,
                )?);
            } else if field.is_nullable() {
                data_arrays.push(new_null_array(desired_data_type, array_size));
            } else {
                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
            }
        } else {
            if !field.is_nullable() {
                return Err(ArrowError::CastError(format!(
                    "Attempting to set null to non-nullable field {} during schema projection.",
                    field.name()
                )));
            }
            data_arrays.push(new_null_array(desired_data_type, array_size));
        }
    }

    RecordBatch::try_new(schema, data_arrays)
}

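/// Executes the DataFrame and collects just enough record batches for display,
/// bounded by the formatter's `max_bytes`, `min_rows`, and `repr_rows` limits.
/// Returns the collected batches plus a flag indicating whether more rows exist
/// beyond what was collected (determined by peeking at the stream when no
/// truncation occurred).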
async fn collect_record_batches_to_display(
    df: DataFrame,
    config: FormatterConfig,
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
    let FormatterConfig {
        max_bytes,
        min_rows,
        repr_rows,
    } = config;

    let partitioned_stream = df.execute_stream_partitioned().await?;
    let mut stream = futures::stream::iter(partitioned_stream).flatten();
    let mut size_estimate_so_far = 0;
    let mut rows_so_far = 0;
    let mut record_batches = Vec::default();
    let mut has_more = false;

    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
        let mut rb = match stream.next().await {
            None => {
                break;
            }
            Some(Ok(r)) => r,
            Some(Err(e)) => return Err(e),
        };

        let mut rows_in_rb = rb.num_rows();
        if rows_in_rb > 0 {
            size_estimate_so_far += rb.get_array_memory_size();

            if size_estimate_so_far > max_bytes {
                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
                let total_rows = rows_in_rb + rows_so_far;

                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
                if reduced_row_num < min_rows {
                    reduced_row_num = min_rows.min(total_rows);
                }

                let limited_rows_this_rb = reduced_row_num - rows_so_far;
                if limited_rows_this_rb < rows_in_rb {
                    rows_in_rb = limited_rows_this_rb;
                    rb = rb.slice(0, limited_rows_this_rb);
                    has_more = true;
                }
            }

            if rows_in_rb + rows_so_far > repr_rows {
                rb = rb.slice(0, repr_rows - rows_so_far);
                has_more = true;
            }

            rows_so_far += rb.num_rows();
            record_batches.push(rb);
        }
    }

    if record_batches.is_empty() {
        return Ok((Vec::default(), false));
    }

    if !has_more {
        has_more = match stream.try_next().await {
            Ok(None) => false,
            Ok(Some(_)) => true,
            Err(_) => false,
        };
    }

    Ok((record_batches, has_more))
}