From 63ead5192b8c05cc66f078b076391ba4691befbd Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 6 Oct 2025 07:23:19 -0400
Subject: [PATCH 1/7] Intermediate work on parameterizing queries

---
 python/datafusion/context.py | 21 ++++++++++++++++++++-
 python/tests/test_sql.py     | 10 ++++++++++
 src/context.rs               | 25 ++++++++++++++++++++++++-
 src/table.rs                 |  4 ++++
 4 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 0aa2f27c4..91402c909 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -27,6 +27,8 @@
 except ImportError:
     from typing_extensions import deprecated  # Python 3.12
 
+import uuid
+
 import pyarrow as pa
 
 from datafusion.catalog import Catalog
@@ -592,7 +594,9 @@ def register_listing_table(
             self._convert_file_sort_order(file_sort_order),
         )
 
-    def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
+    def sql(
+        self, query: str, options: SQLOptions | None = None, **named_params: Any
+    ) -> DataFrame:
         """Create a :py:class:`~datafusion.DataFrame` from SQL query text.
 
         Note: This API implements DDL statements such as ``CREATE TABLE`` and
@@ -603,10 +607,25 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
         Args:
             query: SQL query text.
             options: If provided, the query will be validated against these options.
+            named_params: Provides substitution in the query string.
 
         Returns:
             DataFrame representation of the SQL query.
         """
+        if named_params:
+            for alias, param in named_params.items():
+                if isinstance(param, DataFrame):
+                    view_name = str(uuid.uuid4()).replace("-", "_")
+                    view_name = f"view_{view_name}"
+                    self.ctx.create_temporary_view(
+                        view_name, param.df, replace_if_exists=True
+                    )
+                    replace_str = view_name
+                else:
+                    replace_str = str(param)
+
+                query = query.replace(f"{{{alias}}}", replace_str)
+
         if options is None:
             return DataFrame(self.ctx.sql(query))
         return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index c383edc60..e3e3e0ac8 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -533,3 +533,13 @@ def test_register_listing_table(
     rd = result.to_pydict()
 
     assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
+
+
+def test_parameterized_sql(ctx, tmp_path) -> None:
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    df = ctx.read_parquet(path)
+    result = ctx.sql(
+        "SELECT COUNT(a) AS cnt FROM {replaced_df}", replaced_df=df
+    ).collect()
+    result = pa.Table.from_batches(result)
+    assert result.to_pydict() == {"cnt": [100]}
diff --git a/src/context.rs b/src/context.rs
index dc18a7676..5b42e6ad4 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -41,7 +41,7 @@ use crate::record_batch::PyRecordBatchStream;
 use crate::sql::exceptions::py_value_err;
 use crate::sql::logical::PyLogicalPlan;
 use crate::store::StorageContexts;
-use crate::table::PyTable;
+use crate::table::{PyTable, TempViewTable};
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::udtf::PyTableFunction;
@@ -492,6 +492,29 @@ impl PySessionContext {
         PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
     }
 
+    pub fn create_temporary_view(
+        &self,
+        py: Python,
+        name: &str,
+        df: PyDataFrame,
+        replace_if_exists: bool,
+    ) -> PyDataFusionResult<()> {
+        if self.table(name, py).is_ok() {
+            if replace_if_exists {
+                let _ = self.deregister_table(name);
+            } else {
+                exec_err!(
+                    "Unable to create temporary view. Table with name {name} already exists."
+                )?;
+            }
+        }
+
+        let table = Arc::new(TempViewTable::new(df.inner_df()));
+        self.ctx.register_table(name, table)?;
+
+        Ok(())
+    }
+
     /// Construct datafusion dataframe from Python list
     #[pyo3(signature = (data, name=None))]
     pub fn from_pylist(
diff --git a/src/table.rs b/src/table.rs
index fdca4d3e6..713307c02 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -24,7 +24,11 @@ use async_trait::async_trait;
 use datafusion::catalog::Session;
 use datafusion::common::Column;
 use datafusion::datasource::{TableProvider, TableType};
+<<<<<<< HEAD
 use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown};
+=======
+use datafusion::logical_expr::{Expr, LogicalPlanBuilder};
+>>>>>>> 3c9a96c (Intermediate work on temp views and parameterizing queries)
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::DataFrame;
 use pyo3::prelude::*;

From dd68364106a47acb0f1705845596a299f0320e95 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Thu, 9 Oct 2025 18:48:46 -0400
Subject: [PATCH 2/7] Reworking to do token parsing of sql query instead of
 string manipulation

---
 python/datafusion/context.py | 52 +++++++++++++-------
 python/tests/test_sql.py     | 20 ++++++--
 src/context.rs               | 53 +++++++++++++++-----
 src/sql.rs                   |  1 +
 src/sql/util.rs              | 93 ++++++++++++++++++++++++++++++++++++
 5 files changed, 186 insertions(+), 33 deletions(-)
 create mode 100644 src/sql/util.rs

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 91402c909..822f7364e 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -27,7 +27,6 @@
 except ImportError:
     from typing_extensions import deprecated  # Python 3.12
 
-import uuid
 
 import pyarrow as pa
 
 from datafusion.catalog import Catalog
@@ -612,25 +611,41 @@ def sql(
         Returns:
             DataFrame representation of the SQL query.
         """
-        if named_params:
-            for alias, param in named_params.items():
-                if isinstance(param, DataFrame):
-                    view_name = str(uuid.uuid4()).replace("-", "_")
-                    view_name = f"view_{view_name}"
-                    self.ctx.create_temporary_view(
-                        view_name, param.df, replace_if_exists=True
-                    )
-                    replace_str = view_name
-                else:
-                    replace_str = str(param)
+
+        def scalar_params(**p: Any) -> list[tuple[str, pa.Scalar]]:
+            if p is None:
+                return []
+
+            return [
+                (name, pa.scalar(value))
+                for (name, value) in p.items()
+                if not isinstance(value, DataFrame)
+            ]
 
-                query = query.replace(f"{{{alias}}}", replace_str)
+        def dataframe_params(**p: Any) -> list[tuple[str, DataFrame]]:
+            if p is None:
+                return []
 
-        if options is None:
-            return DataFrame(self.ctx.sql(query))
-        return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
+            return [
+                (name, value.df)
+                for (name, value) in p.items()
+                if isinstance(value, DataFrame)
+            ]
+
+        options_raw = options.options_internal if options is not None else None
+
+        return DataFrame(
+            self.ctx.sql_with_options(
+                query,
+                options=options_raw,
+                scalar_params=scalar_params(**named_params),
+                dataframe_params=dataframe_params(**named_params),
+            )
+        )
 
-    def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
+    def sql_with_options(
+        self, query: str, options: SQLOptions, **named_params: Any
+    ) -> DataFrame:
         """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
 
         This function will first validate that the query is allowed by the
@@ -639,11 +654,12 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
         Args:
             query: SQL query text.
             options: SQL options.
+            named_params: Provides substitution in the query string.
 
         Returns:
             DataFrame representation of the SQL query.
         """
-        return self.sql(query, options)
+        return self.sql(query, options, **named_params)
 
     def create_dataframe(
         self,
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index e3e3e0ac8..688f95b07 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -21,7 +21,7 @@
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
-from datafusion import col, udf
+from datafusion import SessionContext, col, udf
 from datafusion.object_store import Http
 from pyarrow.csv import write_csv
 
@@ -535,11 +535,25 @@ def test_register_listing_table(
     assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
 
 
-def test_parameterized_sql(ctx, tmp_path) -> None:
+def test_parameterized_df_in_sql(ctx, tmp_path) -> None:
     path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+
     df = ctx.read_parquet(path)
     result = ctx.sql(
-        "SELECT COUNT(a) AS cnt FROM {replaced_df}", replaced_df=df
+        "SELECT COUNT(a) AS cnt FROM $replaced_df", replaced_df=df
     ).collect()
     result = pa.Table.from_batches(result)
     assert result.to_pydict() == {"cnt": [100]}
+
+
+def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
+    # Test the parameters that should be handled by the parser rather
+    # than our manipulation of the query string by searching for tokens
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3, 4])],
+        names=["a"],
+    )
+
+    ctx.register_record_batches("t", [[batch]])
+    result = ctx.sql("SELECT a FROM t WHERE a < $val", val=3)
+    assert result.to_pydict() == {"a": [1, 2]}
diff --git a/src/context.rs b/src/context.rs
index 5b42e6ad4..40f1ba0a8 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -32,14 +32,16 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
 
 use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
+use crate::common::data_type::PyScalarValue;
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::expr::sort_expr::PySortExpr;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
 use crate::sql::exceptions::py_value_err;
 use crate::sql::logical::PyLogicalPlan;
+use crate::sql::util::replace_placeholders_with_table_names;
 use crate::store::StorageContexts;
 use crate::table::{PyTable, TempViewTable};
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::udtf::PyTableFunction;
@@ -427,27 +429,54 @@ impl PySessionContext {
         self.ctx.register_udtf(&name, func);
     }
 
-    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
-    pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
-        let result = self.ctx.sql(query);
-        let df = wait_for_future(py, result)??;
-        Ok(PyDataFrame::new(df))
-    }
-
-    #[pyo3(signature = (query, options=None))]
+    #[pyo3(signature = (query, options=None, scalar_params=vec![], dataframe_params=vec![]))]
     pub fn sql_with_options(
         &self,
+        py: Python,
         query: &str,
         options: Option<PySQLOptions>,
-        py: Python,
+        scalar_params: Vec<(String, PyScalarValue)>,
+        dataframe_params: Vec<(String, PyDataFrame)>,
     ) -> PyDataFusionResult<PyDataFrame> {
         let options = if let Some(options) = options {
             options.options
         } else {
             SQLOptions::new()
         };
-        let result = self.ctx.sql_with_options(query, options);
-        let df = wait_for_future(py, result)??;
+
+        let scalar_params = scalar_params
+            .into_iter()
+            .map(|(name, value)| (name, ScalarValue::from(value)))
+            .collect::<Vec<_>>();
+
+        let dataframe_params = dataframe_params
+            .into_iter()
+            .map(|(name, df)| {
+                let uuid = Uuid::new_v4().to_string().replace("-", "");
+                let view_name = format!("view_{uuid}");
+
+                self.create_temporary_view(py, view_name.as_str(), df, true)?;
+                Ok((name, view_name))
+            })
+            .collect::<Result<HashMap<_, _>, PyDataFusionError>>()?;
+
+        let state = self.ctx.state();
+        let dialect = state.config().options().sql_parser.dialect.as_str();
+
+        let query = replace_placeholders_with_table_names(query, dialect, dataframe_params)?;
+
+        println!("using scalar params: {scalar_params:?}");
+        let df = wait_for_future(py, async {
+            self.ctx
+                .sql_with_options(&query, options)
+                .await
+                .map_err(|err| {
+                    println!("error before param replacement: {}", err);
+                    err
+                })?
+                .with_param_values(scalar_params)
+        })??;
+
         Ok(PyDataFrame::new(df))
     }
 
diff --git a/src/sql.rs b/src/sql.rs
index 9f1fe81be..dea9b566a 100644
--- a/src/sql.rs
+++ b/src/sql.rs
@@ -17,3 +17,4 @@
 
 pub mod exceptions;
 pub mod logical;
+pub(crate) mod util;
diff --git a/src/sql/util.rs b/src/sql/util.rs
new file mode 100644
index 000000000..34cd62ecd
--- /dev/null
+++ b/src/sql/util.rs
@@ -0,0 +1,93 @@
+use datafusion::common::{internal_err, plan_datafusion_err, DataFusionError};
+use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
+use datafusion::sql::sqlparser::dialect::Dialect;
+use datafusion::sql::sqlparser::keywords::Keyword;
+use datafusion::sql::sqlparser::parser::Parser;
+use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer, Word};
+use std::collections::HashMap;
+
+fn value_from_replacements(
+    placeholder: &str,
+    replacements: &HashMap<String, String>,
+) -> Option<Token> {
+    if let Some(pattern) = placeholder.strip_prefix("$") {
+        replacements.get(pattern).map(|replacement| {
+            Token::Word(Word {
+                value: replacement.to_owned(),
+                quote_style: None,
+                keyword: Keyword::NoKeyword,
+            })
+        })
+    } else {
+        None
+    }
+}
+
+fn table_names_are_valid(dialect: &dyn Dialect, replacements: &HashMap<String, String>) -> bool {
+    for name in replacements.values() {
+        let tokens = Tokenizer::new(dialect, name).tokenize().unwrap();
+        if tokens.len() != 1 {
+            // We should get exactly one token for our temporary table name
+            return false;
+        }
+
+        if let Token::Word(word) = &tokens[0] {
+            // Generated table names should be not quoted or have keywords
+            if word.quote_style.is_some() || word.keyword != Keyword::NoKeyword {
+                return false;
+            }
+        } else {
+            // We should always parse table names to a Word
+            return false;
+        }
+    }
+
+    true
+}
+
+pub(crate) fn replace_placeholders_with_table_names(
+    query: &str,
+    dialect: &str,
+    replacements: HashMap<String, String>,
+) -> Result<String, DataFusionError> {
+    let dialect = dialect_from_str(dialect).ok_or_else(|| {
+        plan_datafusion_err!(
"Unsupported SQL dialect: {dialect}. Available dialects: \ + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks." + ) + })?; + + if !table_names_are_valid(dialect.as_ref(), &replacements) { + return internal_err!("Invalid generated table name when replacing placeholders"); + } + let tokens = Tokenizer::new(dialect.as_ref(), query).tokenize().unwrap(); + + let replaced_tokens = tokens + .into_iter() + .map(|token| { + if let Token::Word(word) = &token { + let Word { + value, + quote_style: _, + keyword: _, + } = word; + + value_from_replacements(value, &replacements).unwrap_or(token) + } else if let Token::Placeholder(placeholder) = &token { + value_from_replacements(placeholder, &replacements).unwrap_or(token) + } else { + token + } + }) + .collect::>(); + + Ok(Parser::new(dialect.as_ref()) + .with_tokens(replaced_tokens) + .parse_statements() + .map_err(|err| DataFusionError::External(Box::new(err)))? + .into_iter() + .map(|s| s.to_string()) + .collect::>() + .join(" ")) +} From cefbdb0055d0b97c33b9d683f6a26cd42c7c6455 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 12 Oct 2025 07:39:40 -0400 Subject: [PATCH 3/7] Switching to explicit param_values or named parameters that will perform string replacement via parsed tokens --- python/datafusion/context.py | 76 +++++++++++++++++--------- src/context.rs | 74 +++++++------------------ src/sql/util.rs | 102 ++++++++++++++--------------------- 3 files changed, 108 insertions(+), 144 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 822f7364e..384b14ed4 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,6 +19,7 @@ from __future__ import annotations +import uuid import warnings from typing import TYPE_CHECKING, Any, Protocol @@ -594,10 +595,18 @@ def register_listing_table( ) def sql( - self, query: str, options: SQLOptions | None = None, **named_params: Any + self, + query: str, + options: SQLOptions | None = None, + param_values: dict[str, Any] | None = None, + **named_params: Any, ) -> DataFrame: """Create a :py:class:`~datafusion.DataFrame` from SQL query text. + See the online documentation for a description of how to perform + parameterized substitution via either the param_values option + or passing in named parameters. + Note: This API implements DDL statements such as ``CREATE TABLE`` and ``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory default implementation.See @@ -606,31 +615,38 @@ def sql( Args: query: SQL query text. options: If provided, the query will be validated against these options. - named_params: Provides substitution in the query string. + param_values: Provides substitution of scalar values in the query + after parsing. + named_params: Provides string or DataFrame substitution in the query string. Returns: DataFrame representation of the SQL query. 
""" - def scalar_params(**p: Any) -> list[tuple[str, pa.Scalar]]: - if p is None: - return [] - - return [ - (name, pa.scalar(value)) - for (name, value) in p.items() - if not isinstance(value, DataFrame) - ] - - def dataframe_params(**p: Any) -> list[tuple[str, DataFrame]]: - if p is None: - return [] - - return [ - (name, value.df) - for (name, value) in p.items() - if isinstance(value, DataFrame) - ] + def value_to_scalar(value) -> pa.Scalar: + if isinstance(value, pa.Scalar): + return value + return pa.scalar(value) + + def value_to_string(value) -> str: + if isinstance(value, DataFrame): + view_name = str(uuid.uuid4()).replace("-", "_") + view_name = f"view_{view_name}" + view = value.df.into_view(temporary=True) + self.ctx.register_table(view_name, view) + return view_name + return str(value) + + param_values = ( + {name: value_to_scalar(value) for (name, value) in param_values} + if param_values is not None + else {} + ) + param_strings = ( + {name: value_to_string(value) for (name, value) in named_params.items()} + if named_params is not None + else {} + ) options_raw = options.options_internal if options is not None else None @@ -638,13 +654,17 @@ def dataframe_params(**p: Any) -> list[tuple[str, DataFrame]]: self.ctx.sql_with_options( query, options=options_raw, - scalar_params=scalar_params(**named_params), - dataframe_params=dataframe_params(**named_params), + param_values=param_values, + param_strings=param_strings, ) ) def sql_with_options( - self, query: str, options: SQLOptions, **named_params: Any + self, + query: str, + options: SQLOptions, + param_values: dict[str, Any] | None = None, + **named_params: Any, ) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. @@ -654,12 +674,16 @@ def sql_with_options( Args: query: SQL query text. options: SQL options. - named_params: Provides substitution in the query string. + param_values: Provides substitution of scalar values in the query + after parsing. + named_params: Provides string or DataFrame substitution in the query string. Returns: DataFrame representation of the SQL query. 
""" - return self.sql(query, options, **named_params) + return self.sql( + query, options=options, param_values=param_values, **named_params + ) def create_dataframe( self, diff --git a/src/context.rs b/src/context.rs index 40f1ba0a8..f4008af6d 100644 --- a/src/context.rs +++ b/src/context.rs @@ -35,15 +35,15 @@ use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider}; use crate::common::data_type::PyScalarValue; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; -use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; use crate::expr::sort_expr::PySortExpr; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::exceptions::py_value_err; use crate::sql::logical::PyLogicalPlan; -use crate::sql::util::replace_placeholders_with_table_names; +use crate::sql::util::replace_placeholders_with_strings; use crate::store::StorageContexts; -use crate::table::{PyTable, TempViewTable}; +use crate::table::PyTable; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udtf::PyTableFunction; @@ -429,14 +429,14 @@ impl PySessionContext { self.ctx.register_udtf(&name, func); } - #[pyo3(signature = (query, options=None, scalar_params=vec![], dataframe_params=vec![]))] + #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))] pub fn sql_with_options( &self, py: Python, - query: &str, + mut query: String, options: Option, - scalar_params: Vec<(String, PyScalarValue)>, - dataframe_params: Vec<(String, PyDataFrame)>, + param_values: HashMap, + param_strings: HashMap, ) -> PyDataFusionResult { let options = if let Some(options) = options { options.options @@ -444,39 +444,26 @@ impl PySessionContext { SQLOptions::new() }; - let scalar_params = scalar_params + let param_values = param_values .into_iter() .map(|(name, value)| (name, ScalarValue::from(value))) - .collect::>(); - - let dataframe_params = dataframe_params - .into_iter() - .map(|(name, df)| { - let uuid = Uuid::new_v4().to_string().replace("-", ""); - let view_name = format!("view_{uuid}"); - - self.create_temporary_view(py, view_name.as_str(), df, true)?; - Ok((name, view_name)) - }) - .collect::, PyDataFusionError>>()?; + .collect::>(); let state = self.ctx.state(); let dialect = state.config().options().sql_parser.dialect.as_str(); - let query = replace_placeholders_with_table_names(query, dialect, dataframe_params)?; - - println!("using scalar params: {scalar_params:?}"); - let df = wait_for_future(py, async { - self.ctx - .sql_with_options(&query, options) - .await - .map_err(|err| { - println!("error before param replacement: {}", err); - err - })? 
-                .with_param_values(scalar_params)
-        })??;
+        if !param_strings.is_empty() {
+            query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
+        }
+
+        let mut df = wait_for_future(py, async {
+            self.ctx.sql_with_options(&query, options).await
+        })??;
+
+        if !param_values.is_empty() {
+            df = df.with_param_values(param_values)?;
+        }
 
         Ok(PyDataFrame::new(df))
     }
@@ -521,29 +508,6 @@ impl PySessionContext {
         PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
     }
 
-    pub fn create_temporary_view(
-        &self,
-        py: Python,
-        name: &str,
-        df: PyDataFrame,
-        replace_if_exists: bool,
-    ) -> PyDataFusionResult<()> {
-        if self.table(name, py).is_ok() {
-            if replace_if_exists {
-                let _ = self.deregister_table(name);
-            } else {
-                exec_err!(
-                    "Unable to create temporary view. Table with name {name} already exists."
-                )?;
-            }
-        }
-
-        let table = Arc::new(TempViewTable::new(df.inner_df()));
-        self.ctx.register_table(name, table)?;
-
-        Ok(())
-    }
-
     /// Construct datafusion dataframe from Python list
     #[pyo3(signature = (data, name=None))]
     pub fn from_pylist(
diff --git a/src/sql/util.rs b/src/sql/util.rs
index 34cd62ecd..dfaeb9c38 100644
--- a/src/sql/util.rs
+++ b/src/sql/util.rs
-use datafusion::common::{internal_err, plan_datafusion_err, DataFusionError};
+use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError};
 use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
 use datafusion::sql::sqlparser::dialect::Dialect;
-use datafusion::sql::sqlparser::keywords::Keyword;
 use datafusion::sql::sqlparser::parser::Parser;
-use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer, Word};
+use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer};
 use std::collections::HashMap;
 
-fn value_from_replacements(
+fn tokens_from_replacements(
     placeholder: &str,
-    replacements: &HashMap<String, String>,
-) -> Option<Token> {
+    replacements: &HashMap<String, Vec<Token>>,
+) -> Option<Vec<Token>> {
     if let Some(pattern) = placeholder.strip_prefix("$") {
-        replacements.get(pattern).map(|replacement| {
-            Token::Word(Word {
-                value: replacement.to_owned(),
-                quote_style: None,
-                keyword: Keyword::NoKeyword,
-            })
-        })
+        replacements.get(pattern).cloned()
     } else {
         None
     }
 }
 
-fn table_names_are_valid(dialect: &dyn Dialect, replacements: &HashMap<String, String>) -> bool {
-    for name in replacements.values() {
-        let tokens = Tokenizer::new(dialect, name).tokenize().unwrap();
-        if tokens.len() != 1 {
-            // We should get exactly one token for our temporary table name
-            return false;
-        }
-
-        if let Token::Word(word) = &tokens[0] {
-            // Generated table names should be not quoted or have keywords
-            if word.quote_style.is_some() || word.keyword != Keyword::NoKeyword {
-                return false;
-            }
-        } else {
-            // We should always parse table names to a Word
-            return false;
-        }
-    }
-
-    true
+fn get_tokens_for_string_replacement(
+    dialect: &dyn Dialect,
+    replacements: HashMap<String, String>,
+) -> Result<HashMap<String, Vec<Token>>, DataFusionError> {
+    replacements
+        .into_iter()
+        .map(|(name, value)| {
+            let tokens = Tokenizer::new(dialect, &value)
+                .tokenize()
+                .map_err(|err| DataFusionError::External(err.into()))?;
+            Ok((name, tokens))
+        })
+        .collect()
 }
 
-pub(crate) fn replace_placeholders_with_table_names(
+pub(crate) fn replace_placeholders_with_strings(
     query: &str,
     dialect: &str,
     replacements: HashMap<String, String>,
 ) -> Result<String, DataFusionError> {
-    let dialect = dialect_from_str(dialect).ok_or_else(|| {
-        plan_datafusion_err!(
-            "Unsupported SQL dialect: {dialect}. Available dialects: \
-             Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
-             MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
-        )
-    })?;
+    let dialect = dialect_from_str(dialect)
+        .ok_or_else(|| plan_datafusion_err!("Unsupported SQL dialect: {dialect}."))?;
 
-    if !table_names_are_valid(dialect.as_ref(), &replacements) {
-        return internal_err!("Invalid generated table name when replacing placeholders");
-    }
-    let tokens = Tokenizer::new(dialect.as_ref(), query).tokenize().unwrap();
+    let replacements = get_tokens_for_string_replacement(dialect.as_ref(), replacements)?;
+
+    let tokens = Tokenizer::new(dialect.as_ref(), query)
+        .tokenize()
+        .map_err(|err| DataFusionError::External(err.into()))?;
 
     let replaced_tokens = tokens
         .into_iter()
-        .map(|token| {
-            if let Token::Word(word) = &token {
-                let Word {
-                    value,
-                    quote_style: _,
-                    keyword: _,
-                } = word;
-
-                value_from_replacements(value, &replacements).unwrap_or(token)
-            } else if let Token::Placeholder(placeholder) = &token {
-                value_from_replacements(placeholder, &replacements).unwrap_or(token)
+        .flat_map(|token| {
+            if let Token::Placeholder(placeholder) = &token {
+                tokens_from_replacements(placeholder, &replacements).unwrap_or(vec![token])
             } else {
-                token
+                vec![token]
             }
         })
         .collect::<Vec<_>>();
 
-    Ok(Parser::new(dialect.as_ref())
+    let statement = Parser::new(dialect.as_ref())
         .with_tokens(replaced_tokens)
         .parse_statements()
-        .map_err(|err| DataFusionError::External(Box::new(err)))?
-        .into_iter()
-        .map(|s| s.to_string())
-        .collect::<Vec<_>>()
-        .join(" "))
+        .map_err(|err| DataFusionError::External(Box::new(err)))?;
+
+    if statement.len() != 1 {
+        return exec_err!("placeholder replacement should return exactly one statement");
+    }
+
+    Ok(statement[0].to_string())
 }

From 23721b50b1e5ded93cfc19285882828bc4dd6de5 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 12 Oct 2025 08:48:27 -0400
Subject: [PATCH 4/7] Add additional unit tests for parameterized queries

---
 python/datafusion/context.py |  2 +-
 python/tests/test_sql.py     | 29 ++++++++++++++++++++++++-----
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 384b14ed4..2466e338c 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -638,7 +638,7 @@ def value_to_string(value) -> str:
             return str(value)
 
         param_values = (
-            {name: value_to_scalar(value) for (name, value) in param_values}
+            {name: value_to_scalar(value) for (name, value) in param_values.items()}
             if param_values is not None
             else {}
         )
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 688f95b07..14ab7636b 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -535,18 +535,20 @@ def test_register_listing_table(
     assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
 
 
-def test_parameterized_df_in_sql(ctx, tmp_path) -> None:
+def test_parameterized_named_params(ctx, tmp_path) -> None:
     path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
 
     df = ctx.read_parquet(path)
     result = ctx.sql(
-        "SELECT COUNT(a) AS cnt FROM $replaced_df", replaced_df=df
+        "SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df",
+        lit_val=3,
+        replaced_df=df,
     ).collect()
     result = pa.Table.from_batches(result)
-    assert result.to_pydict() == {"cnt": [100]}
+    assert result.to_pydict() == {"cnt": [100], "lit_val": [3]}
 
 
-def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
+def test_parameterized_param_values(ctx: SessionContext) -> None:
     # Test the parameters that should be handled by the parser rather
     # than our manipulation of the query string by searching for tokens
     batch = pa.RecordBatch.from_arrays(
@@ -555,5 +557,22 @@ def test_parameterized_param_values(ctx: SessionContext) -> None:
     )
 
     ctx.register_record_batches("t", [[batch]])
-    result = ctx.sql("SELECT a FROM t WHERE a < $val", val=3)
+    result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3})
+    assert result.to_pydict() == {"a": [1, 2]}
+
+
+def test_parameterized_mixed_query(ctx: SessionContext) -> None:
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3, 4])],
+        names=["a"],
+    )
+    ctx.register_record_batches("t", [[batch]])
+    registered_df = ctx.table("t")
+
+    result = ctx.sql(
+        "SELECT $col_name FROM $df WHERE a < $val",
+        param_values={"val": 3},
+        df=registered_df,
+        col_name="a",
+    )
     assert result.to_pydict() == {"a": [1, 2]}

From aa660ff4e87a997b527a7b0e0416ec28393fcf11 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 12 Oct 2025 10:34:07 -0400
Subject: [PATCH 5/7] merge conflict

---
 src/table.rs | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/table.rs b/src/table.rs
index 713307c02..fdca4d3e6 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -24,11 +24,7 @@ use async_trait::async_trait;
 use datafusion::catalog::Session;
 use datafusion::common::Column;
 use datafusion::datasource::{TableProvider, TableType};
-<<<<<<< HEAD
 use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown};
-=======
-use datafusion::logical_expr::{Expr, LogicalPlanBuilder};
->>>>>>> 3c9a96c (Intermediate work on temp views and parameterizing queries)
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::DataFrame;
 use pyo3::prelude::*;

From 6979c0a65027f7e1ba5e3dfeb85691dbce8fe785 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 12 Oct 2025 11:44:26 -0400
Subject: [PATCH 6/7] license text

---
 src/sql/util.rs | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/sql/util.rs b/src/sql/util.rs
index dfaeb9c38..438c526a2 100644
--- a/src/sql/util.rs
+++ b/src/sql/util.rs
@@ -1,3 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
 use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError};
 use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
 use datafusion::sql::sqlparser::dialect::Dialect;

From 3c73ed3a04135a84e17a47b9f98b51177528452b Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Mon, 13 Oct 2025 08:06:36 -0400
Subject: [PATCH 7/7] Add documentation

---
 docs/source/user-guide/configuration.rst |  2 +
 docs/source/user-guide/sql.rst           | 91 ++++++++++++++++++++++--
 2 files changed, 89 insertions(+), 4 deletions(-)

diff --git a/docs/source/user-guide/configuration.rst b/docs/source/user-guide/configuration.rst
index 5425a040d..f8e613cd4 100644
--- a/docs/source/user-guide/configuration.rst
+++ b/docs/source/user-guide/configuration.rst
@@ -15,6 +15,8 @@
 .. specific language governing permissions and limitations
 .. under the License.
 
+.. _configuration:
+
 Configuration
 =============
 
diff --git a/docs/source/user-guide/sql.rst b/docs/source/user-guide/sql.rst
index 6fa7f0c6a..b4bfb9611 100644
--- a/docs/source/user-guide/sql.rst
+++ b/docs/source/user-guide/sql.rst
@@ -23,17 +23,100 @@
 DataFusion also offers a SQL API, read the full reference `here `_.
 
+Parameterized Queries
+---------------------
+
+The SQL usage in DataFusion Python allows for parameterized queries. These are
+similar to `prepared statements `_,
+but allow passing named parameters into a SQL query. Consider this simple
+example.
+
+.. ipython:: python
+
+    def show_attacks(ctx: SessionContext, threshold: int) -> None:
+        ctx.sql(
+            'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val', val=threshold
+        ).show(num=5)
+
+    show_attacks(ctx, 75)
+
+When passing parameters like the example above, we convert the Python objects
+into their string representations. We also have special case handling
+for :py:class:`~datafusion.dataframe.DataFrame` objects, since they cannot simply
+be turned into string representations for an SQL query. In these cases we
+will register a temporary view in the :py:class:`~datafusion.context.SessionContext`
+using a generated table name.
+
+The formatting for passing string replacement objects is to precede the
+variable name with a single ``$``. This works for all dialects in
+the SQL parser except ``hive`` and ``mysql``. Since these dialects do not
+support named placeholders, we are unable to do this type of replacement.
+We recommend either switching to another dialect or using Python
+f-string style replacement.
+
+.. warning::
+
+    To support DataFrame parameterized queries, your session must support
+    registration of temporary views. The default
+    :py:class:`~datafusion.catalog.CatalogProvider` and
+    :py:class:`~datafusion.catalog.SchemaProvider` do have this capability.
+    If you have implemented custom providers, it is important that temporary
+    views do not persist across :py:class:`~datafusion.context.SessionContext`
+    instances, or you may get unintended consequences.
+
+The following example shows passing in both a :py:class:`~datafusion.dataframe.DataFrame`
+object and a Python object to be used in parameterized replacement.
+
+.. ipython:: python
+
+    def show_column(
+        ctx: SessionContext, column: str, df: DataFrame, threshold: int
+    ) -> None:
+        ctx.sql(
+            'SELECT "Name", $col FROM $df WHERE $col > $val',
+            col=column,
+            df=df,
+            val=threshold,
+        ).show(num=5)
+
+    df = ctx.table("pokemon")
+    show_column(ctx, '"Defense"', df, 75)
+
+The approach implemented for passing variables into a SQL query relies on
+string conversion. This has the potential for data loss,
+specifically for cases like floating point numbers. If you need to pass
+variables into a parameterized query and it is important to maintain the
+original value without conversion to a string, then you can use the
+optional parameter ``param_values`` to specify these. This parameter
+expects a dictionary mapping from the parameter name to a Python
+object. Those objects will be cast into a
+`PyArrow Scalar Value `_.
+
+Using ``param_values`` will rely on the SQL dialect you have configured
+for your session. This can be set using the :ref:`configuration options `
+of your :py:class:`~datafusion.context.SessionContext`. Similar to how
+`prepared statements `_
+work, these parameters are limited to places where you would pass in a
+scalar value, such as a comparison.
+
+.. ipython:: python
+
+    def param_attacks(ctx: SessionContext, threshold: int) -> None:
+        ctx.sql(
+            'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val',
+            param_values={"val": threshold},
+        ).show(num=5)
+
+    param_attacks(ctx, 75)
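+
+Both styles of substitution can be combined in a single query. Below is a
+minimal sketch that assumes the ``pokemon`` table from the earlier examples
+is still registered; the column name and DataFrame are substituted as named
+parameters, while the threshold is bound through ``param_values``. The helper
+name ``mixed_params`` is illustrative only.
+
+.. ipython:: python
+
+    def mixed_params(ctx: SessionContext, df: DataFrame, threshold: int) -> None:
+        # "$col" and "$df" are replaced as strings before parsing; "$val" is
+        # left as a placeholder and bound after parsing via param_values.
+        ctx.sql(
+            'SELECT "Name", $col FROM $df WHERE $col > $val',
+            col='"Attack"',
+            df=df,
+            param_values={"val": threshold},
+        ).show(num=5)
+
+    mixed_params(ctx, ctx.table("pokemon"), 75)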