use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use datafusion::execution::session_state::SessionStateBuilder;
use object_store::ObjectStore;
use url::Url;
use uuid::Uuid;

use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::CatalogProvider;
use datafusion::common::TableReference;
use datafusion::common::{exec_err, ScalarValue};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{
    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
use tokio::task::JoinHandle;

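/// Python-facing wrapper around DataFusion's [`SessionConfig`].
///
/// Every `with_*` method clones the wrapped config and returns a new
/// `PySessionConfig`, mirroring the builder style of the Rust API. A rough
/// sketch of that chaining from the Rust side (illustrative only, not a
/// doctest):
///
/// ```ignore
/// let cfg = PySessionConfig::default()
///     .with_information_schema(true)
///     .with_batch_size(8192)
///     .with_target_partitions(4);
/// ```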
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    pub config: SessionConfig,
}

impl From<SessionConfig> for PySessionConfig {
    fn from(config: SessionConfig) -> Self {
        Self { config }
    }
}

#[pymethods]
impl PySessionConfig {
    #[pyo3(signature = (config_options=None))]
    #[new]
    fn new(config_options: Option<HashMap<String, String>>) -> Self {
        let mut config = SessionConfig::new();
        if let Some(hash_map) = config_options {
            for (k, v) in &hash_map {
                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
            }
        }

        Self { config }
    }

    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
        Self::from(
            self.config
                .clone()
                .with_create_default_catalog_and_schema(enabled),
        )
    }

    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
        Self::from(
            self.config
                .clone()
                .with_default_catalog_and_schema(catalog, schema),
        )
    }

    fn with_information_schema(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_information_schema(enabled))
    }

    fn with_batch_size(&self, batch_size: usize) -> Self {
        Self::from(self.config.clone().with_batch_size(batch_size))
    }

    fn with_target_partitions(&self, target_partitions: usize) -> Self {
        Self::from(
            self.config
                .clone()
                .with_target_partitions(target_partitions),
        )
    }

    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_aggregations(enabled))
    }

    fn with_repartition_joins(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_joins(enabled))
    }

    fn with_repartition_windows(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_windows(enabled))
    }

    fn with_repartition_sorts(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_sorts(enabled))
    }

    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_file_scans(enabled))
    }

    fn with_repartition_file_min_size(&self, size: usize) -> Self {
        Self::from(self.config.clone().with_repartition_file_min_size(size))
    }

    fn with_parquet_pruning(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_parquet_pruning(enabled))
    }

    fn set(&self, key: &str, value: &str) -> Self {
        Self::from(self.config.clone().set_str(key, value))
    }
}

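/// Python-facing wrapper around DataFusion's [`RuntimeEnvBuilder`].
///
/// Exposes the disk-manager modes (disabled, OS temp directory, or explicit
/// directories) and the unbounded, fair-spill, and greedy memory pools. As
/// with `PySessionConfig`, each method clones the builder and returns a new
/// instance. Illustrative sketch (not a doctest):
///
/// ```ignore
/// // Cap memory at 1 GiB with a fair spill pool and spill to the OS temp dir.
/// let runtime = PyRuntimeEnvBuilder::new()
///     .with_fair_spill_pool(1024 * 1024 * 1024)
///     .with_disk_manager_os();
/// ```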
#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    pub builder: RuntimeEnvBuilder,
}

#[pymethods]
impl PyRuntimeEnvBuilder {
    #[new]
    fn new() -> Self {
        Self {
            builder: RuntimeEnvBuilder::default(),
        }
    }

    fn with_disk_manager_disabled(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_os(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
        let paths = paths.iter().map(|s| s.into()).collect();
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_unbounded_memory_pool(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
        Self { builder }
    }

    fn with_fair_spill_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
        Self { builder }
    }

    fn with_greedy_memory_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
        Self { builder }
    }

    fn with_temp_file_path(&self, path: &str) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_temp_file_path(path);
        Self { builder }
    }
}

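/// Python-facing wrapper around DataFusion's [`SQLOptions`], consumed by
/// `sql_with_options` to restrict what a SQL string may do. A rough sketch of
/// locking queries down to read-only statements (illustrative only, not a
/// doctest):
///
/// ```ignore
/// let read_only = PySQLOptions::new()
///     .with_allow_ddl(false)
///     .with_allow_dml(false)
///     .with_allow_statements(false);
/// ```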
#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}

#[pymethods]
impl PySQLOptions {
    #[new]
    fn new() -> Self {
        let options = SQLOptions::new();
        Self { options }
    }

    fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

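/// The main entry point of the Python bindings: a wrapper around DataFusion's
/// [`SessionContext`] that registers tables, object stores, and user-defined
/// functions, and turns files or in-memory Arrow data into `PyDataFrame`s.
/// When no `SessionConfig` is supplied, the context is created with the
/// information schema enabled.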
#[pyclass(name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    pub ctx: SessionContext,
}

#[pymethods]
impl PySessionContext {
    #[pyo3(signature = (config=None, runtime=None))]
    #[new]
    pub fn new(
        config: Option<PySessionConfig>,
        runtime: Option<PyRuntimeEnvBuilder>,
    ) -> PyDataFusionResult<Self> {
        let config = if let Some(c) = config {
            c.config
        } else {
            SessionConfig::default().with_information_schema(true)
        };
        let runtime_env_builder = if let Some(c) = runtime {
            c.builder
        } else {
            RuntimeEnvBuilder::default()
        };
        let runtime = Arc::new(runtime_env_builder.build()?);
        let session_state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Ok(PySessionContext {
            ctx: SessionContext::new_with_state(session_state),
        })
    }

    pub fn enable_url_table(&self) -> PyResult<Self> {
        Ok(PySessionContext {
            ctx: self.ctx.clone().enable_url_table(),
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }

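    /// Registers an object store (S3, GCS, Azure, local filesystem, or HTTP)
    /// under `scheme`. The registration URL is built as
    /// `format!("{scheme}{host}")`, where `host` falls back to the bucket or
    /// container name carried by the `StorageContexts` variant, so the scheme
    /// string presumably already carries its `://` separator.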
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
        &mut self,
        scheme: &str,
        store: StorageContexts,
        host: Option<&str>,
    ) -> PyResult<()> {
        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
            StorageContexts::HTTP(http) => (http.store, http.url),
        };

        let derived_host = if let Some(host) = host {
            host
        } else {
            &upstream_host
        };
        let url_string = format!("{scheme}{derived_host}");
        let url = Url::parse(&url_string).unwrap();
        self.ctx.runtime_env().register_object_store(&url, store);
        Ok(())
    }

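    /// Registers a [`ListingTable`] over Parquet files at `path`. If no schema
    /// is supplied it is inferred from the files, an async operation that
    /// blocks on the Tokio runtime via `wait_for_future`. Partition columns
    /// and a file sort order can be declared so DataFusion can prune files and
    /// skip redundant sorts.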
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
        file_extension=".parquet",
        schema=None,
        file_sort_order=None))]
    pub fn register_listing_table(
        &mut self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                wait_for_future(py, schema)??
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.register_table(
            name,
            &PyTable {
                table: Arc::new(table),
            },
        )?;
        Ok(())
    }

    pub fn register_udtf(&mut self, func: PyTableFunction) {
        let name = func.name.clone();
        let func = Arc::new(func);
        self.ctx.register_udtf(&name, func);
    }

    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
        let result = self.ctx.sql(query);
        let df = wait_for_future(py, result)??;
        Ok(PyDataFrame::new(df))
    }

    #[pyo3(signature = (query, options=None))]
    pub fn sql_with_options(
        &mut self,
        query: &str,
        options: Option<PySQLOptions>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };
        let result = self.ctx.sql_with_options(query, options);
        let df = wait_for_future(py, result)??;
        Ok(PyDataFrame::new(df))
    }

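    /// Builds a `PyDataFrame` from partitions of `RecordBatch`es by
    /// registering a [`MemTable`]. When no `name` is given, a unique table
    /// name of the form `"c" + <simple UUID>` is generated so repeated calls
    /// do not collide.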
    #[pyo3(signature = (partitions, name=None, schema=None))]
    pub fn create_dataframe(
        &mut self,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
        name: Option<&str>,
        schema: Option<PyArrowType<Schema>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let schema = if let Some(schema) = schema {
            SchemaRef::from(schema.0)
        } else {
            partitions.0[0][0].schema()
        };

        let table = MemTable::try_new(schema, partitions.0)?;

        let table_name = match name {
            Some(val) => val.to_owned(),
            None => {
                "c".to_owned()
                    + Uuid::new_v4()
                        .simple()
                        .encode_lower(&mut Uuid::encode_buffer())
            }
        };

        self.ctx.register_table(&*table_name, Arc::new(table))?;

        let table = wait_for_future(py, self._table(&table_name))??;

        let df = PyDataFrame::new(table);
        Ok(df)
    }

    pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pylist(
        &mut self,
        data: Bound<'_, PyList>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pylist", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pydict(
        &mut self,
        data: Bound<'_, PyDict>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pydict", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

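    /// Converts Arrow-compatible Python data into a `PyDataFrame`. The input
    /// is first tried as an Arrow stream (anything PyArrow can expose as an
    /// `ArrowArrayStreamReader`), then as a single `RecordBatch`; anything
    /// else is rejected. `from_pylist`, `from_pydict`, `from_pandas`, and
    /// `from_polars` all funnel through this method after converting their
    /// input to a PyArrow table.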
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(crate::errors::PyDataFusionError::Common(
                    "Expected either an Arrow Array or an Arrow Stream in from_arrow().".to_string(),
                ));
            };

        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }

    #[allow(clippy::wrong_self_convention)]
    #[pyo3(signature = (data, name=None))]
    pub fn from_pandas(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pandas", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_polars(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let table = data.call_method0("to_arrow")?;

        let df = self.from_arrow(table, name, data.py())?;
        Ok(df)
    }

    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
        self.ctx.register_table(name, table.table())?;
        Ok(())
    }

    pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
        self.ctx.deregister_table(name)?;
        Ok(())
    }

    pub fn register_catalog_provider(
        &mut self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
            let capsule = provider
                .getattr("__datafusion_catalog_provider__")?
                .call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_catalog_provider")?;

            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
            let provider: ForeignCatalogProvider = provider.into();
            Arc::new(provider) as Arc<dyn CatalogProvider>
        } else {
            match provider.extract::<PyCatalog>() {
                Ok(py_catalog) => py_catalog.catalog,
                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
                    as Arc<dyn CatalogProvider>,
            }
        };

        let _ = self.ctx.register_catalog(name, provider);

        Ok(())
    }

    pub fn register_table_provider(
        &mut self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        if provider.hasattr("__datafusion_table_provider__")? {
            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_table_provider")?;

            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
            let provider: ForeignTableProvider = provider.into();

            let _ = self.ctx.register_table(name, Arc::new(provider))?;

            Ok(())
        } else {
            Err(crate::errors::PyDataFusionError::Common(
                "__datafusion_table_provider__ does not exist on Table Provider object."
                    .to_string(),
            ))
        }
    }

    pub fn register_record_batches(
        &mut self,
        name: &str,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
    ) -> PyDataFusionResult<()> {
        let schema = partitions.0[0][0].schema();
        let table = MemTable::try_new(schema, partitions.0)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

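    /// Registers a Parquet file or directory as a named table. Pruning,
    /// metadata skipping, partition columns, an explicit schema, and a file
    /// sort order are all configurable and map directly onto
    /// [`ParquetReadOptions`].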
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn register_parquet(
        &mut self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        wait_for_future(py, result)??;
        Ok(())
    }

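    /// Registers one or more CSV files as a named table. `path` may be a
    /// single string or a Python list of strings; the multi-path case goes
    /// through `register_csv_from_multiple_paths`, which builds a listing
    /// table over all of the URLs. The delimiter must be a single byte.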
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        file_compression_type=None))]
    pub fn register_csv(
        &mut self,
        name: &str,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)??;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)??;
        }

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
        path,
        schema=None,
        schema_infer_max_records=1000,
        file_extension=".json",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn register_json(
        &mut self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            );
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
        path,
        schema=None,
        file_extension=".avro",
        table_partition_cols=vec![]))]
    pub fn register_avro(
        &mut self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    pub fn register_dataset(
        &self,
        name: &str,
        dataset: &Bound<'_, PyAny>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

        self.ctx.register_table(name, table)?;

        Ok(())
    }

    pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }

    pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }

    pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }

    #[pyo3(signature = (name="datafusion"))]
    pub fn catalog(&self, name: &str) -> PyResult<PyObject> {
        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
            "Catalog with name {name} doesn't exist."
        )))?;

        Python::with_gil(|py| {
            match catalog
                .as_any()
                .downcast_ref::<RustWrappedPyCatalogProvider>()
            {
                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
                None => PyCatalog::from(catalog).into_py_any(py),
            }
        })
    }

    pub fn catalog_names(&self) -> HashSet<String> {
        self.ctx.catalog_names().into_iter().collect()
    }

    pub fn tables(&self) -> HashSet<String> {
        self.ctx
            .catalog_names()
            .into_iter()
            .filter_map(|name| self.ctx.catalog(&name))
            .flat_map(move |catalog| {
                catalog
                    .schema_names()
                    .into_iter()
                    .filter_map(move |name| catalog.schema(&name))
            })
            .flat_map(|schema| schema.table_names())
            .collect()
    }

    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        let res = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        match res {
            Ok(df) => Ok(PyDataFrame::new(df)),
            Err(e) => {
                if let datafusion::error::DataFusionError::Plan(msg) = &e {
                    if msg.contains("No table named") {
                        return Err(PyKeyError::new_err(msg.to_string()));
                    }
                }
                Err(py_datafusion_err(e))
            }
        }
    }

    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
        Ok(self.ctx.table_exist(name)?)
    }

    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
        Ok(PyDataFrame::new(self.ctx.read_empty()?))
    }

    pub fn session_id(&self) -> String {
        self.ctx.session_id()
    }

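    /// Creates a `PyDataFrame` over newline-delimited JSON without registering
    /// a named table. The schema is either supplied explicitly or inferred
    /// from up to `schema_infer_max_records` records; compression is resolved
    /// through `parse_file_compression_type`.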
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &mut self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        };
        Ok(PyDataFrame::new(df))
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn read_csv(
        &self,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        };

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
            let result = self.ctx.read_csv(paths, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.read_csv(path, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        }
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result)??);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        };
        Ok(PyDataFrame::new(df))
    }

    pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
        let df = self.ctx.read_table(table.table())?;
        Ok(PyDataFrame::new(df))
    }

    fn __repr__(&self) -> PyResult<String> {
        let config = self.ctx.copied_config();
        let mut config_entries = config
            .options()
            .entries()
            .iter()
            .filter(|e| e.value.is_some())
            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
            .collect::<Vec<_>>();
        config_entries.sort();
        Ok(format!(
            "SessionContext: id={}; configs=[\n\t{}]",
            self.session_id(),
            config_entries.join("\n\t")
        ))
    }

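    /// Executes one partition of a physical plan and returns the resulting
    /// stream of record batches. The plan is spawned onto the global Tokio
    /// runtime and `wait_for_future` blocks until the stream is ready, so
    /// Python callers get a `PyRecordBatchStream` they can iterate lazily.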
    pub fn execute(
        &self,
        plan: PyExecutionPlan,
        part: usize,
        py: Python,
    ) -> PyDataFusionResult<PyRecordBatchStream> {
        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
        let rt = &get_tokio_runtime().0;
        let plan = plan.plan.clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
        Ok(PyRecordBatchStream::new(stream))
    }
}

impl PySessionContext {
    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
        self.ctx.table(name).await
    }

    async fn register_csv_from_multiple_paths(
        &self,
        name: &str,
        table_paths: Vec<String>,
        options: CsvReadOptions<'_>,
    ) -> datafusion::common::Result<()> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.ctx.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.ctx.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
            .await?;

        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx
            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
        Ok(())
    }
}

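/// Maps an optional Python string onto DataFusion's [`FileCompressionType`].
/// `None` is treated as the empty string (i.e. no compression); unrecognized
/// values are reported to Python as a `PyValueError`. A rough sketch of the
/// expected behaviour, based on the values named in the error message
/// (illustrative only, not a doctest):
///
/// ```ignore
/// let gzip = parse_file_compression_type(Some("gzip".to_string()))?;
/// let uncompressed = parse_file_compression_type(None)?;
/// assert!(parse_file_compression_type(Some("not-a-codec".to_string())).is_err());
/// ```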
pub fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    FileCompressionType::from_str(file_compression_type.unwrap_or_default().as_str()).map_err(
        |_| PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd"),
    )
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
    }
}

impl From<SessionContext> for PySessionContext {
    fn from(ctx: SessionContext) -> PySessionContext {
        PySessionContext { ctx }
    }
}