@@ -32,14 +32,16 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
3232use pyo3:: prelude:: * ;
3333
3434use crate :: catalog:: { PyCatalog , RustWrappedPyCatalogProvider } ;
35+ use crate :: common:: data_type:: PyScalarValue ;
3536use crate :: dataframe:: PyDataFrame ;
3637use crate :: dataset:: Dataset ;
37- use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionResult } ;
38+ use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
3839use crate :: expr:: sort_expr:: PySortExpr ;
3940use crate :: physical_plan:: PyExecutionPlan ;
4041use crate :: record_batch:: PyRecordBatchStream ;
4142use crate :: sql:: exceptions:: py_value_err;
4243use crate :: sql:: logical:: PyLogicalPlan ;
44+ use crate :: sql:: util:: replace_placeholders_with_table_names;
4345use crate :: store:: StorageContexts ;
4446use crate :: table:: { PyTable , TempViewTable } ;
4547use crate :: udaf:: PyAggregateUDF ;
@@ -427,27 +429,54 @@ impl PySessionContext {
427429 self . ctx . register_udtf ( & name, func) ;
428430 }
429431
430- /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
431- pub fn sql ( & self , query : & str , py : Python ) -> PyDataFusionResult < PyDataFrame > {
432- let result = self . ctx . sql ( query) ;
433- let df = wait_for_future ( py, result) ??;
434- Ok ( PyDataFrame :: new ( df) )
435- }
436-
437- #[ pyo3( signature = ( query, options=None ) ) ]
432+ #[ pyo3( signature = ( query, options=None , scalar_params=vec![ ] , dataframe_params=vec![ ] ) ) ]
438433 pub fn sql_with_options (
439434 & self ,
435+ py : Python ,
440436 query : & str ,
441437 options : Option < PySQLOptions > ,
442- py : Python ,
438+ scalar_params : Vec < ( String , PyScalarValue ) > ,
439+ dataframe_params : Vec < ( String , PyDataFrame ) > ,
443440 ) -> PyDataFusionResult < PyDataFrame > {
444441 let options = if let Some ( options) = options {
445442 options. options
446443 } else {
447444 SQLOptions :: new ( )
448445 } ;
449- let result = self . ctx . sql_with_options ( query, options) ;
450- let df = wait_for_future ( py, result) ??;
446+
447+ let scalar_params = scalar_params
448+ . into_iter ( )
449+ . map ( |( name, value) | ( name, ScalarValue :: from ( value) ) )
450+ . collect :: < Vec < _ > > ( ) ;
451+
452+ let dataframe_params = dataframe_params
453+ . into_iter ( )
454+ . map ( |( name, df) | {
455+ let uuid = Uuid :: new_v4 ( ) . to_string ( ) . replace ( "-" , "" ) ;
456+ let view_name = format ! ( "view_{uuid}" ) ;
457+
458+ self . create_temporary_view ( py, view_name. as_str ( ) , df, true ) ?;
459+ Ok ( ( name, view_name) )
460+ } )
461+ . collect :: < Result < HashMap < _ , _ > , PyDataFusionError > > ( ) ?;
462+
463+ let state = self . ctx . state ( ) ;
464+ let dialect = state. config ( ) . options ( ) . sql_parser . dialect . as_str ( ) ;
465+
466+ let query = replace_placeholders_with_table_names ( query, dialect, dataframe_params) ?;
467+
468+ println ! ( "using scalar params: {scalar_params:?}" ) ;
469+ let df = wait_for_future ( py, async {
470+ self . ctx
471+ . sql_with_options ( & query, options)
472+ . await
473+ . map_err ( |err| {
474+ println ! ( "error before param replacement: {}" , err) ;
475+ err
476+ } ) ?
477+ . with_param_values ( scalar_params)
478+ } ) ??;
479+
451480 Ok ( PyDataFrame :: new ( df) )
452481 }
453482
0 commit comments