datafusion_python/
context.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::{HashMap, HashSet};
19use std::path::PathBuf;
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::RecordBatchReader;
24use arrow::ffi_stream::ArrowArrayStreamReader;
25use arrow::pyarrow::FromPyArrow;
26use datafusion::execution::session_state::SessionStateBuilder;
27use object_store::ObjectStore;
28use url::Url;
29use uuid::Uuid;
30
31use pyo3::exceptions::{PyKeyError, PyValueError};
32use pyo3::prelude::*;
33
34use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
35use crate::dataframe::PyDataFrame;
36use crate::dataset::Dataset;
37use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
38use crate::expr::sort_expr::PySortExpr;
39use crate::physical_plan::PyExecutionPlan;
40use crate::record_batch::PyRecordBatchStream;
41use crate::sql::exceptions::py_value_err;
42use crate::sql::logical::PyLogicalPlan;
43use crate::store::StorageContexts;
44use crate::udaf::PyAggregateUDF;
45use crate::udf::PyScalarUDF;
46use crate::udtf::PyTableFunction;
47use crate::udwf::PyWindowUDF;
48use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
49use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
50use datafusion::arrow::pyarrow::PyArrowType;
51use datafusion::arrow::record_batch::RecordBatch;
52use datafusion::catalog::CatalogProvider;
53use datafusion::common::TableReference;
54use datafusion::common::{exec_err, ScalarValue};
55use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
56use datafusion::datasource::file_format::parquet::ParquetFormat;
57use datafusion::datasource::listing::{
58    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
59};
60use datafusion::datasource::MemTable;
61use datafusion::datasource::TableProvider;
62use datafusion::execution::context::{
63    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
64};
65use datafusion::execution::disk_manager::DiskManagerMode;
66use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
67use datafusion::execution::options::ReadOptions;
68use datafusion::execution::runtime_env::RuntimeEnvBuilder;
69use datafusion::physical_plan::SendableRecordBatchStream;
70use datafusion::prelude::{
71    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
72};
73use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
74use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
75use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
76use pyo3::IntoPyObjectExt;
77use tokio::task::JoinHandle;
78
79/// Configuration options for a SessionContext
80#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
81#[derive(Clone, Default)]
82pub struct PySessionConfig {
83    pub config: SessionConfig,
84}
85
86impl From<SessionConfig> for PySessionConfig {
87    fn from(config: SessionConfig) -> Self {
88        Self { config }
89    }
90}
91
92#[pymethods]
93impl PySessionConfig {
94    #[pyo3(signature = (config_options=None))]
95    #[new]
96    fn new(config_options: Option<HashMap<String, String>>) -> Self {
97        let mut config = SessionConfig::new();
98        if let Some(hash_map) = config_options {
99            for (k, v) in &hash_map {
100                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
101            }
102        }
103
104        Self { config }
105    }
106
107    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
108        Self::from(
109            self.config
110                .clone()
111                .with_create_default_catalog_and_schema(enabled),
112        )
113    }
114
115    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
116        Self::from(
117            self.config
118                .clone()
119                .with_default_catalog_and_schema(catalog, schema),
120        )
121    }
122
123    fn with_information_schema(&self, enabled: bool) -> Self {
124        Self::from(self.config.clone().with_information_schema(enabled))
125    }
126
127    fn with_batch_size(&self, batch_size: usize) -> Self {
128        Self::from(self.config.clone().with_batch_size(batch_size))
129    }
130
131    fn with_target_partitions(&self, target_partitions: usize) -> Self {
132        Self::from(
133            self.config
134                .clone()
135                .with_target_partitions(target_partitions),
136        )
137    }
138
139    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
140        Self::from(self.config.clone().with_repartition_aggregations(enabled))
141    }
142
143    fn with_repartition_joins(&self, enabled: bool) -> Self {
144        Self::from(self.config.clone().with_repartition_joins(enabled))
145    }
146
147    fn with_repartition_windows(&self, enabled: bool) -> Self {
148        Self::from(self.config.clone().with_repartition_windows(enabled))
149    }
150
151    fn with_repartition_sorts(&self, enabled: bool) -> Self {
152        Self::from(self.config.clone().with_repartition_sorts(enabled))
153    }
154
155    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
156        Self::from(self.config.clone().with_repartition_file_scans(enabled))
157    }
158
159    fn with_repartition_file_min_size(&self, size: usize) -> Self {
160        Self::from(self.config.clone().with_repartition_file_min_size(size))
161    }
162
163    fn with_parquet_pruning(&self, enabled: bool) -> Self {
164        Self::from(self.config.clone().with_parquet_pruning(enabled))
165    }
166
167    fn set(&self, key: &str, value: &str) -> Self {
168        Self::from(self.config.clone().set_str(key, value))
169    }
170}
171
172/// Runtime options for a SessionContext
173#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
174#[derive(Clone)]
175pub struct PyRuntimeEnvBuilder {
176    pub builder: RuntimeEnvBuilder,
177}
178
179#[pymethods]
180impl PyRuntimeEnvBuilder {
181    #[new]
182    fn new() -> Self {
183        Self {
184            builder: RuntimeEnvBuilder::default(),
185        }
186    }
187
188    fn with_disk_manager_disabled(&self) -> Self {
189        let mut runtime_builder = self.builder.clone();
190
191        let mut disk_mgr_builder = runtime_builder
192            .disk_manager_builder
193            .clone()
194            .unwrap_or_default();
195        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);
196
197        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
198        Self {
199            builder: runtime_builder,
200        }
201    }
202
203    fn with_disk_manager_os(&self) -> Self {
204        let mut runtime_builder = self.builder.clone();
205
206        let mut disk_mgr_builder = runtime_builder
207            .disk_manager_builder
208            .clone()
209            .unwrap_or_default();
210        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);
211
212        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
213        Self {
214            builder: runtime_builder,
215        }
216    }
217
218    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
219        let paths = paths.iter().map(|s| s.into()).collect();
220        let mut runtime_builder = self.builder.clone();
221
222        let mut disk_mgr_builder = runtime_builder
223            .disk_manager_builder
224            .clone()
225            .unwrap_or_default();
226        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));
227
228        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
229        Self {
230            builder: runtime_builder,
231        }
232    }
233
234    fn with_unbounded_memory_pool(&self) -> Self {
235        let builder = self.builder.clone();
236        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
237        Self { builder }
238    }
239
240    fn with_fair_spill_pool(&self, size: usize) -> Self {
241        let builder = self.builder.clone();
242        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
243        Self { builder }
244    }
245
246    fn with_greedy_memory_pool(&self, size: usize) -> Self {
247        let builder = self.builder.clone();
248        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
249        Self { builder }
250    }
251
252    fn with_temp_file_path(&self, path: &str) -> Self {
253        let builder = self.builder.clone();
254        let builder = builder.with_temp_file_path(path);
255        Self { builder }
256    }
257}
258
259/// `PySQLOptions` allows you to specify options to the sql execution.
260#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
261#[derive(Clone)]
262pub struct PySQLOptions {
263    pub options: SQLOptions,
264}
265
266impl From<SQLOptions> for PySQLOptions {
267    fn from(options: SQLOptions) -> Self {
268        Self { options }
269    }
270}
271
272#[pymethods]
273impl PySQLOptions {
274    #[new]
275    fn new() -> Self {
276        let options = SQLOptions::new();
277        Self { options }
278    }
279
280    /// Should DDL data modification commands  (e.g. `CREATE TABLE`) be run? Defaults to `true`.
281    fn with_allow_ddl(&self, allow: bool) -> Self {
282        Self::from(self.options.with_allow_ddl(allow))
283    }
284
285    /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`
286    pub fn with_allow_dml(&self, allow: bool) -> Self {
287        Self::from(self.options.with_allow_dml(allow))
288    }
289
290    /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true`
291    pub fn with_allow_statements(&self, allow: bool) -> Self {
292        Self::from(self.options.with_allow_statements(allow))
293    }
294}
295
/// `PySessionContext` is able to plan and execute DataFusion plans.
/// It has a powerful optimizer, a physical planner for local execution, and a
/// multi-threaded execution engine to perform the execution.
#[pyclass(name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    /// The underlying DataFusion `SessionContext` that this wrapper's methods
    /// register tables/functions with and run queries against.
    pub ctx: SessionContext,
}
304
305#[pymethods]
306impl PySessionContext {
307    #[pyo3(signature = (config=None, runtime=None))]
308    #[new]
309    pub fn new(
310        config: Option<PySessionConfig>,
311        runtime: Option<PyRuntimeEnvBuilder>,
312    ) -> PyDataFusionResult<Self> {
313        let config = if let Some(c) = config {
314            c.config
315        } else {
316            SessionConfig::default().with_information_schema(true)
317        };
318        let runtime_env_builder = if let Some(c) = runtime {
319            c.builder
320        } else {
321            RuntimeEnvBuilder::default()
322        };
323        let runtime = Arc::new(runtime_env_builder.build()?);
324        let session_state = SessionStateBuilder::new()
325            .with_config(config)
326            .with_runtime_env(runtime)
327            .with_default_features()
328            .build();
329        Ok(PySessionContext {
330            ctx: SessionContext::new_with_state(session_state),
331        })
332    }
333
334    pub fn enable_url_table(&self) -> PyResult<Self> {
335        Ok(PySessionContext {
336            ctx: self.ctx.clone().enable_url_table(),
337        })
338    }
339
340    #[classmethod]
341    #[pyo3(signature = ())]
342    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
343        Ok(Self {
344            ctx: get_global_ctx().clone(),
345        })
346    }
347
348    /// Register an object store with the given name
349    #[pyo3(signature = (scheme, store, host=None))]
350    pub fn register_object_store(
351        &mut self,
352        scheme: &str,
353        store: StorageContexts,
354        host: Option<&str>,
355    ) -> PyResult<()> {
356        // for most stores the "host" is the bucket name and can be inferred from the store
357        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
358            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
359            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
360            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
361            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
362            StorageContexts::HTTP(http) => (http.store, http.url),
363        };
364
365        // let users override the host to match the api signature from upstream
366        let derived_host = if let Some(host) = host {
367            host
368        } else {
369            &upstream_host
370        };
371        let url_string = format!("{scheme}{derived_host}");
372        let url = Url::parse(&url_string).unwrap();
373        self.ctx.runtime_env().register_object_store(&url, store);
374        Ok(())
375    }
376
377    #[allow(clippy::too_many_arguments)]
378    #[pyo3(signature = (name, path, table_partition_cols=vec![],
379    file_extension=".parquet",
380    schema=None,
381    file_sort_order=None))]
382    pub fn register_listing_table(
383        &mut self,
384        name: &str,
385        path: &str,
386        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
387        file_extension: &str,
388        schema: Option<PyArrowType<Schema>>,
389        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
390        py: Python,
391    ) -> PyDataFusionResult<()> {
392        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
393            .with_file_extension(file_extension)
394            .with_table_partition_cols(
395                table_partition_cols
396                    .into_iter()
397                    .map(|(name, ty)| (name, ty.0))
398                    .collect::<Vec<(String, DataType)>>(),
399            )
400            .with_file_sort_order(
401                file_sort_order
402                    .unwrap_or_default()
403                    .into_iter()
404                    .map(|e| e.into_iter().map(|f| f.into()).collect())
405                    .collect(),
406            );
407        let table_path = ListingTableUrl::parse(path)?;
408        let resolved_schema: SchemaRef = match schema {
409            Some(s) => Arc::new(s.0),
410            None => {
411                let state = self.ctx.state();
412                let schema = options.infer_schema(&state, &table_path);
413                wait_for_future(py, schema)??
414            }
415        };
416        let config = ListingTableConfig::new(table_path)
417            .with_listing_options(options)
418            .with_schema(resolved_schema);
419        let table = ListingTable::try_new(config)?;
420        self.register_table(
421            name,
422            &PyTable {
423                table: Arc::new(table),
424            },
425        )?;
426        Ok(())
427    }
428
429    pub fn register_udtf(&mut self, func: PyTableFunction) {
430        let name = func.name.clone();
431        let func = Arc::new(func);
432        self.ctx.register_udtf(&name, func);
433    }
434
435    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
436    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
437        let result = self.ctx.sql(query);
438        let df = wait_for_future(py, result)??;
439        Ok(PyDataFrame::new(df))
440    }
441
442    #[pyo3(signature = (query, options=None))]
443    pub fn sql_with_options(
444        &mut self,
445        query: &str,
446        options: Option<PySQLOptions>,
447        py: Python,
448    ) -> PyDataFusionResult<PyDataFrame> {
449        let options = if let Some(options) = options {
450            options.options
451        } else {
452            SQLOptions::new()
453        };
454        let result = self.ctx.sql_with_options(query, options);
455        let df = wait_for_future(py, result)??;
456        Ok(PyDataFrame::new(df))
457    }
458
459    #[pyo3(signature = (partitions, name=None, schema=None))]
460    pub fn create_dataframe(
461        &mut self,
462        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
463        name: Option<&str>,
464        schema: Option<PyArrowType<Schema>>,
465        py: Python,
466    ) -> PyDataFusionResult<PyDataFrame> {
467        let schema = if let Some(schema) = schema {
468            SchemaRef::from(schema.0)
469        } else {
470            partitions.0[0][0].schema()
471        };
472
473        let table = MemTable::try_new(schema, partitions.0)?;
474
475        // generate a random (unique) name for this table if none is provided
476        // table name cannot start with numeric digit
477        let table_name = match name {
478            Some(val) => val.to_owned(),
479            None => {
480                "c".to_owned()
481                    + Uuid::new_v4()
482                        .simple()
483                        .encode_lower(&mut Uuid::encode_buffer())
484            }
485        };
486
487        self.ctx.register_table(&*table_name, Arc::new(table))?;
488
489        let table = wait_for_future(py, self._table(&table_name))??;
490
491        let df = PyDataFrame::new(table);
492        Ok(df)
493    }
494
495    /// Create a DataFrame from an existing logical plan
496    pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
497        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
498    }
499
500    /// Construct datafusion dataframe from Python list
501    #[pyo3(signature = (data, name=None))]
502    pub fn from_pylist(
503        &mut self,
504        data: Bound<'_, PyList>,
505        name: Option<&str>,
506    ) -> PyResult<PyDataFrame> {
507        // Acquire GIL Token
508        let py = data.py();
509
510        // Instantiate pyarrow Table object & convert to Arrow Table
511        let table_class = py.import("pyarrow")?.getattr("Table")?;
512        let args = PyTuple::new(py, &[data])?;
513        let table = table_class.call_method1("from_pylist", args)?;
514
515        // Convert Arrow Table to datafusion DataFrame
516        let df = self.from_arrow(table, name, py)?;
517        Ok(df)
518    }
519
520    /// Construct datafusion dataframe from Python dictionary
521    #[pyo3(signature = (data, name=None))]
522    pub fn from_pydict(
523        &mut self,
524        data: Bound<'_, PyDict>,
525        name: Option<&str>,
526    ) -> PyResult<PyDataFrame> {
527        // Acquire GIL Token
528        let py = data.py();
529
530        // Instantiate pyarrow Table object & convert to Arrow Table
531        let table_class = py.import("pyarrow")?.getattr("Table")?;
532        let args = PyTuple::new(py, &[data])?;
533        let table = table_class.call_method1("from_pydict", args)?;
534
535        // Convert Arrow Table to datafusion DataFrame
536        let df = self.from_arrow(table, name, py)?;
537        Ok(df)
538    }
539
540    /// Construct datafusion dataframe from Arrow Table
541    #[pyo3(signature = (data, name=None))]
542    pub fn from_arrow(
543        &mut self,
544        data: Bound<'_, PyAny>,
545        name: Option<&str>,
546        py: Python,
547    ) -> PyDataFusionResult<PyDataFrame> {
548        let (schema, batches) =
549            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
550                // Works for any object that implements __arrow_c_stream__ in pycapsule.
551
552                let schema = stream_reader.schema().as_ref().to_owned();
553                let batches = stream_reader
554                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
555
556                (schema, batches)
557            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
558                // While this says RecordBatch, it will work for any object that implements
559                // __arrow_c_array__ and returns a StructArray.
560
561                (array.schema().as_ref().to_owned(), vec![array])
562            } else {
563                return Err(crate::errors::PyDataFusionError::Common(
564                    "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
565                ));
566            };
567
568        // Because create_dataframe() expects a vector of vectors of record batches
569        // here we need to wrap the vector of record batches in an additional vector
570        let list_of_batches = PyArrowType::from(vec![batches]);
571        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
572    }
573
574    /// Construct datafusion dataframe from pandas
575    #[allow(clippy::wrong_self_convention)]
576    #[pyo3(signature = (data, name=None))]
577    pub fn from_pandas(
578        &mut self,
579        data: Bound<'_, PyAny>,
580        name: Option<&str>,
581    ) -> PyResult<PyDataFrame> {
582        // Obtain GIL token
583        let py = data.py();
584
585        // Instantiate pyarrow Table object & convert to Arrow Table
586        let table_class = py.import("pyarrow")?.getattr("Table")?;
587        let args = PyTuple::new(py, &[data])?;
588        let table = table_class.call_method1("from_pandas", args)?;
589
590        // Convert Arrow Table to datafusion DataFrame
591        let df = self.from_arrow(table, name, py)?;
592        Ok(df)
593    }
594
595    /// Construct datafusion dataframe from polars
596    #[pyo3(signature = (data, name=None))]
597    pub fn from_polars(
598        &mut self,
599        data: Bound<'_, PyAny>,
600        name: Option<&str>,
601    ) -> PyResult<PyDataFrame> {
602        // Convert Polars dataframe to Arrow Table
603        let table = data.call_method0("to_arrow")?;
604
605        // Convert Arrow Table to datafusion DataFrame
606        let df = self.from_arrow(table, name, data.py())?;
607        Ok(df)
608    }
609
610    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
611        self.ctx.register_table(name, table.table())?;
612        Ok(())
613    }
614
615    pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
616        self.ctx.deregister_table(name)?;
617        Ok(())
618    }
619
620    pub fn register_catalog_provider(
621        &mut self,
622        name: &str,
623        provider: Bound<'_, PyAny>,
624    ) -> PyDataFusionResult<()> {
625        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
626            let capsule = provider
627                .getattr("__datafusion_catalog_provider__")?
628                .call0()?;
629            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
630            validate_pycapsule(capsule, "datafusion_catalog_provider")?;
631
632            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
633            let provider: ForeignCatalogProvider = provider.into();
634            Arc::new(provider) as Arc<dyn CatalogProvider>
635        } else {
636            match provider.extract::<PyCatalog>() {
637                Ok(py_catalog) => py_catalog.catalog,
638                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
639                    as Arc<dyn CatalogProvider>,
640            }
641        };
642
643        let _ = self.ctx.register_catalog(name, provider);
644
645        Ok(())
646    }
647
648    /// Construct datafusion dataframe from Arrow Table
649    pub fn register_table_provider(
650        &mut self,
651        name: &str,
652        provider: Bound<'_, PyAny>,
653    ) -> PyDataFusionResult<()> {
654        if provider.hasattr("__datafusion_table_provider__")? {
655            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
656            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
657            validate_pycapsule(capsule, "datafusion_table_provider")?;
658
659            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
660            let provider: ForeignTableProvider = provider.into();
661
662            let _ = self.ctx.register_table(name, Arc::new(provider))?;
663
664            Ok(())
665        } else {
666            Err(crate::errors::PyDataFusionError::Common(
667                "__datafusion_table_provider__ does not exist on Table Provider object."
668                    .to_string(),
669            ))
670        }
671    }
672
673    pub fn register_record_batches(
674        &mut self,
675        name: &str,
676        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
677    ) -> PyDataFusionResult<()> {
678        let schema = partitions.0[0][0].schema();
679        let table = MemTable::try_new(schema, partitions.0)?;
680        self.ctx.register_table(name, Arc::new(table))?;
681        Ok(())
682    }
683
684    #[allow(clippy::too_many_arguments)]
685    #[pyo3(signature = (name, path, table_partition_cols=vec![],
686                        parquet_pruning=true,
687                        file_extension=".parquet",
688                        skip_metadata=true,
689                        schema=None,
690                        file_sort_order=None))]
691    pub fn register_parquet(
692        &mut self,
693        name: &str,
694        path: &str,
695        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
696        parquet_pruning: bool,
697        file_extension: &str,
698        skip_metadata: bool,
699        schema: Option<PyArrowType<Schema>>,
700        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
701        py: Python,
702    ) -> PyDataFusionResult<()> {
703        let mut options = ParquetReadOptions::default()
704            .table_partition_cols(
705                table_partition_cols
706                    .into_iter()
707                    .map(|(name, ty)| (name, ty.0))
708                    .collect::<Vec<(String, DataType)>>(),
709            )
710            .parquet_pruning(parquet_pruning)
711            .skip_metadata(skip_metadata);
712        options.file_extension = file_extension;
713        options.schema = schema.as_ref().map(|x| &x.0);
714        options.file_sort_order = file_sort_order
715            .unwrap_or_default()
716            .into_iter()
717            .map(|e| e.into_iter().map(|f| f.into()).collect())
718            .collect();
719
720        let result = self.ctx.register_parquet(name, path, options);
721        wait_for_future(py, result)??;
722        Ok(())
723    }
724
725    #[allow(clippy::too_many_arguments)]
726    #[pyo3(signature = (name,
727                        path,
728                        schema=None,
729                        has_header=true,
730                        delimiter=",",
731                        schema_infer_max_records=1000,
732                        file_extension=".csv",
733                        file_compression_type=None))]
734    pub fn register_csv(
735        &mut self,
736        name: &str,
737        path: &Bound<'_, PyAny>,
738        schema: Option<PyArrowType<Schema>>,
739        has_header: bool,
740        delimiter: &str,
741        schema_infer_max_records: usize,
742        file_extension: &str,
743        file_compression_type: Option<String>,
744        py: Python,
745    ) -> PyDataFusionResult<()> {
746        let delimiter = delimiter.as_bytes();
747        if delimiter.len() != 1 {
748            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
749                "Delimiter must be a single character",
750            )));
751        }
752
753        let mut options = CsvReadOptions::new()
754            .has_header(has_header)
755            .delimiter(delimiter[0])
756            .schema_infer_max_records(schema_infer_max_records)
757            .file_extension(file_extension)
758            .file_compression_type(parse_file_compression_type(file_compression_type)?);
759        options.schema = schema.as_ref().map(|x| &x.0);
760
761        if path.is_instance_of::<PyList>() {
762            let paths = path.extract::<Vec<String>>()?;
763            let result = self.register_csv_from_multiple_paths(name, paths, options);
764            wait_for_future(py, result)??;
765        } else {
766            let path = path.extract::<String>()?;
767            let result = self.ctx.register_csv(name, &path, options);
768            wait_for_future(py, result)??;
769        }
770
771        Ok(())
772    }
773
774    #[allow(clippy::too_many_arguments)]
775    #[pyo3(signature = (name,
776                        path,
777                        schema=None,
778                        schema_infer_max_records=1000,
779                        file_extension=".json",
780                        table_partition_cols=vec![],
781                        file_compression_type=None))]
782    pub fn register_json(
783        &mut self,
784        name: &str,
785        path: PathBuf,
786        schema: Option<PyArrowType<Schema>>,
787        schema_infer_max_records: usize,
788        file_extension: &str,
789        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
790        file_compression_type: Option<String>,
791        py: Python,
792    ) -> PyDataFusionResult<()> {
793        let path = path
794            .to_str()
795            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
796
797        let mut options = NdJsonReadOptions::default()
798            .file_compression_type(parse_file_compression_type(file_compression_type)?)
799            .table_partition_cols(
800                table_partition_cols
801                    .into_iter()
802                    .map(|(name, ty)| (name, ty.0))
803                    .collect::<Vec<(String, DataType)>>(),
804            );
805        options.schema_infer_max_records = schema_infer_max_records;
806        options.file_extension = file_extension;
807        options.schema = schema.as_ref().map(|x| &x.0);
808
809        let result = self.ctx.register_json(name, path, options);
810        wait_for_future(py, result)??;
811
812        Ok(())
813    }
814
815    #[allow(clippy::too_many_arguments)]
816    #[pyo3(signature = (name,
817                        path,
818                        schema=None,
819                        file_extension=".avro",
820                        table_partition_cols=vec![]))]
821    pub fn register_avro(
822        &mut self,
823        name: &str,
824        path: PathBuf,
825        schema: Option<PyArrowType<Schema>>,
826        file_extension: &str,
827        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
828        py: Python,
829    ) -> PyDataFusionResult<()> {
830        let path = path
831            .to_str()
832            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
833
834        let mut options = AvroReadOptions::default().table_partition_cols(
835            table_partition_cols
836                .into_iter()
837                .map(|(name, ty)| (name, ty.0))
838                .collect::<Vec<(String, DataType)>>(),
839        );
840        options.file_extension = file_extension;
841        options.schema = schema.as_ref().map(|x| &x.0);
842
843        let result = self.ctx.register_avro(name, path, options);
844        wait_for_future(py, result)??;
845
846        Ok(())
847    }
848
849    // Registers a PyArrow.Dataset
850    pub fn register_dataset(
851        &self,
852        name: &str,
853        dataset: &Bound<'_, PyAny>,
854        py: Python,
855    ) -> PyDataFusionResult<()> {
856        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
857
858        self.ctx.register_table(name, table)?;
859
860        Ok(())
861    }
862
863    pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
864        self.ctx.register_udf(udf.function);
865        Ok(())
866    }
867
868    pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
869        self.ctx.register_udaf(udaf.function);
870        Ok(())
871    }
872
873    pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
874        self.ctx.register_udwf(udwf.function);
875        Ok(())
876    }
877
878    #[pyo3(signature = (name="datafusion"))]
879    pub fn catalog(&self, name: &str) -> PyResult<PyObject> {
880        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
881            "Catalog with name {name} doesn't exist."
882        )))?;
883
884        Python::with_gil(|py| {
885            match catalog
886                .as_any()
887                .downcast_ref::<RustWrappedPyCatalogProvider>()
888            {
889                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
890                None => PyCatalog::from(catalog).into_py_any(py),
891            }
892        })
893    }
894
895    pub fn catalog_names(&self) -> HashSet<String> {
896        self.ctx.catalog_names().into_iter().collect()
897    }
898
899    pub fn tables(&self) -> HashSet<String> {
900        self.ctx
901            .catalog_names()
902            .into_iter()
903            .filter_map(|name| self.ctx.catalog(&name))
904            .flat_map(move |catalog| {
905                catalog
906                    .schema_names()
907                    .into_iter()
908                    .filter_map(move |name| catalog.schema(&name))
909            })
910            .flat_map(|schema| schema.table_names())
911            .collect()
912    }
913
914    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
915        let res = wait_for_future(py, self.ctx.table(name))
916            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
917        match res {
918            Ok(df) => Ok(PyDataFrame::new(df)),
919            Err(e) => {
920                if let datafusion::error::DataFusionError::Plan(msg) = &e {
921                    if msg.contains("No table named") {
922                        return Err(PyKeyError::new_err(msg.to_string()));
923                    }
924                }
925                Err(py_datafusion_err(e))
926            }
927        }
928    }
929
930    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
931        Ok(self.ctx.table_exist(name)?)
932    }
933
934    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
935        Ok(PyDataFrame::new(self.ctx.read_empty()?))
936    }
937
938    pub fn session_id(&self) -> String {
939        self.ctx.session_id()
940    }
941
942    #[allow(clippy::too_many_arguments)]
943    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
944    pub fn read_json(
945        &mut self,
946        path: PathBuf,
947        schema: Option<PyArrowType<Schema>>,
948        schema_infer_max_records: usize,
949        file_extension: &str,
950        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
951        file_compression_type: Option<String>,
952        py: Python,
953    ) -> PyDataFusionResult<PyDataFrame> {
954        let path = path
955            .to_str()
956            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
957        let mut options = NdJsonReadOptions::default()
958            .table_partition_cols(
959                table_partition_cols
960                    .into_iter()
961                    .map(|(name, ty)| (name, ty.0))
962                    .collect::<Vec<(String, DataType)>>(),
963            )
964            .file_compression_type(parse_file_compression_type(file_compression_type)?);
965        options.schema_infer_max_records = schema_infer_max_records;
966        options.file_extension = file_extension;
967        let df = if let Some(schema) = schema {
968            options.schema = Some(&schema.0);
969            let result = self.ctx.read_json(path, options);
970            wait_for_future(py, result)??
971        } else {
972            let result = self.ctx.read_json(path, options);
973            wait_for_future(py, result)??
974        };
975        Ok(PyDataFrame::new(df))
976    }
977
978    #[allow(clippy::too_many_arguments)]
979    #[pyo3(signature = (
980        path,
981        schema=None,
982        has_header=true,
983        delimiter=",",
984        schema_infer_max_records=1000,
985        file_extension=".csv",
986        table_partition_cols=vec![],
987        file_compression_type=None))]
988    pub fn read_csv(
989        &self,
990        path: &Bound<'_, PyAny>,
991        schema: Option<PyArrowType<Schema>>,
992        has_header: bool,
993        delimiter: &str,
994        schema_infer_max_records: usize,
995        file_extension: &str,
996        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
997        file_compression_type: Option<String>,
998        py: Python,
999    ) -> PyDataFusionResult<PyDataFrame> {
1000        let delimiter = delimiter.as_bytes();
1001        if delimiter.len() != 1 {
1002            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
1003                "Delimiter must be a single character",
1004            )));
1005        };
1006
1007        let mut options = CsvReadOptions::new()
1008            .has_header(has_header)
1009            .delimiter(delimiter[0])
1010            .schema_infer_max_records(schema_infer_max_records)
1011            .file_extension(file_extension)
1012            .table_partition_cols(
1013                table_partition_cols
1014                    .into_iter()
1015                    .map(|(name, ty)| (name, ty.0))
1016                    .collect::<Vec<(String, DataType)>>(),
1017            )
1018            .file_compression_type(parse_file_compression_type(file_compression_type)?);
1019        options.schema = schema.as_ref().map(|x| &x.0);
1020
1021        if path.is_instance_of::<PyList>() {
1022            let paths = path.extract::<Vec<String>>()?;
1023            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
1024            let result = self.ctx.read_csv(paths, options);
1025            let df = PyDataFrame::new(wait_for_future(py, result)??);
1026            Ok(df)
1027        } else {
1028            let path = path.extract::<String>()?;
1029            let result = self.ctx.read_csv(path, options);
1030            let df = PyDataFrame::new(wait_for_future(py, result)??);
1031            Ok(df)
1032        }
1033    }
1034
1035    #[allow(clippy::too_many_arguments)]
1036    #[pyo3(signature = (
1037        path,
1038        table_partition_cols=vec![],
1039        parquet_pruning=true,
1040        file_extension=".parquet",
1041        skip_metadata=true,
1042        schema=None,
1043        file_sort_order=None))]
1044    pub fn read_parquet(
1045        &self,
1046        path: &str,
1047        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1048        parquet_pruning: bool,
1049        file_extension: &str,
1050        skip_metadata: bool,
1051        schema: Option<PyArrowType<Schema>>,
1052        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
1053        py: Python,
1054    ) -> PyDataFusionResult<PyDataFrame> {
1055        let mut options = ParquetReadOptions::default()
1056            .table_partition_cols(
1057                table_partition_cols
1058                    .into_iter()
1059                    .map(|(name, ty)| (name, ty.0))
1060                    .collect::<Vec<(String, DataType)>>(),
1061            )
1062            .parquet_pruning(parquet_pruning)
1063            .skip_metadata(skip_metadata);
1064        options.file_extension = file_extension;
1065        options.schema = schema.as_ref().map(|x| &x.0);
1066        options.file_sort_order = file_sort_order
1067            .unwrap_or_default()
1068            .into_iter()
1069            .map(|e| e.into_iter().map(|f| f.into()).collect())
1070            .collect();
1071
1072        let result = self.ctx.read_parquet(path, options);
1073        let df = PyDataFrame::new(wait_for_future(py, result)??);
1074        Ok(df)
1075    }
1076
1077    #[allow(clippy::too_many_arguments)]
1078    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
1079    pub fn read_avro(
1080        &self,
1081        path: &str,
1082        schema: Option<PyArrowType<Schema>>,
1083        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1084        file_extension: &str,
1085        py: Python,
1086    ) -> PyDataFusionResult<PyDataFrame> {
1087        let mut options = AvroReadOptions::default().table_partition_cols(
1088            table_partition_cols
1089                .into_iter()
1090                .map(|(name, ty)| (name, ty.0))
1091                .collect::<Vec<(String, DataType)>>(),
1092        );
1093        options.file_extension = file_extension;
1094        let df = if let Some(schema) = schema {
1095            options.schema = Some(&schema.0);
1096            let read_future = self.ctx.read_avro(path, options);
1097            wait_for_future(py, read_future)??
1098        } else {
1099            let read_future = self.ctx.read_avro(path, options);
1100            wait_for_future(py, read_future)??
1101        };
1102        Ok(PyDataFrame::new(df))
1103    }
1104
1105    pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
1106        let df = self.ctx.read_table(table.table())?;
1107        Ok(PyDataFrame::new(df))
1108    }
1109
1110    fn __repr__(&self) -> PyResult<String> {
1111        let config = self.ctx.copied_config();
1112        let mut config_entries = config
1113            .options()
1114            .entries()
1115            .iter()
1116            .filter(|e| e.value.is_some())
1117            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
1118            .collect::<Vec<_>>();
1119        config_entries.sort();
1120        Ok(format!(
1121            "SessionContext: id={}; configs=[\n\t{}]",
1122            self.session_id(),
1123            config_entries.join("\n\t")
1124        ))
1125    }
1126
1127    /// Execute a partition of an execution plan and return a stream of record batches
1128    pub fn execute(
1129        &self,
1130        plan: PyExecutionPlan,
1131        part: usize,
1132        py: Python,
1133    ) -> PyDataFusionResult<PyRecordBatchStream> {
1134        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1135        // create a Tokio runtime to run the async code
1136        let rt = &get_tokio_runtime().0;
1137        let plan = plan.plan.clone();
1138        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1139            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1140        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
1141        Ok(PyRecordBatchStream::new(stream))
1142    }
1143}
1144
1145impl PySessionContext {
1146    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
1147        self.ctx.table(name).await
1148    }
1149
1150    async fn register_csv_from_multiple_paths(
1151        &self,
1152        name: &str,
1153        table_paths: Vec<String>,
1154        options: CsvReadOptions<'_>,
1155    ) -> datafusion::common::Result<()> {
1156        let table_paths = table_paths.to_urls()?;
1157        let session_config = self.ctx.copied_config();
1158        let listing_options =
1159            options.to_listing_options(&session_config, self.ctx.copied_table_options());
1160
1161        let option_extension = listing_options.file_extension.clone();
1162
1163        if table_paths.is_empty() {
1164            return exec_err!("No table paths were provided");
1165        }
1166
1167        // check if the file extension matches the expected extension
1168        for path in &table_paths {
1169            let file_path = path.as_str();
1170            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1171                return exec_err!(
1172                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
1173                );
1174            }
1175        }
1176
1177        let resolved_schema = options
1178            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
1179            .await?;
1180
1181        let config = ListingTableConfig::new_with_multi_paths(table_paths)
1182            .with_listing_options(listing_options)
1183            .with_schema(resolved_schema);
1184        let table = ListingTable::try_new(config)?;
1185        self.ctx
1186            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
1187        Ok(())
1188    }
1189}
1190
1191pub fn parse_file_compression_type(
1192    file_compression_type: Option<String>,
1193) -> Result<FileCompressionType, PyErr> {
1194    FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
1195        .map_err(|_| {
1196            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
1197        })
1198}
1199
1200impl From<PySessionContext> for SessionContext {
1201    fn from(ctx: PySessionContext) -> SessionContext {
1202        ctx.ctx
1203    }
1204}
1205
1206impl From<SessionContext> for PySessionContext {
1207    fn from(ctx: SessionContext) -> PySessionContext {
1208        PySessionContext { ctx }
1209    }
1210}