diff --git a/docs/source/user-guide/configuration.rst b/docs/source/user-guide/configuration.rst index 5425a040d..f8e613cd4 100644 --- a/docs/source/user-guide/configuration.rst +++ b/docs/source/user-guide/configuration.rst @@ -15,6 +15,8 @@ .. specific language governing permissions and limitations .. under the License. +.. _configuration: + Configuration ============= diff --git a/docs/source/user-guide/sql.rst b/docs/source/user-guide/sql.rst index 6fa7f0c6a..b4bfb9611 100644 --- a/docs/source/user-guide/sql.rst +++ b/docs/source/user-guide/sql.rst @@ -23,17 +23,100 @@ DataFusion also offers a SQL API, read the full reference `here `_, +but allow passing named parameters into a SQL query. Consider this simple +example. + +.. ipython:: python + + def show_attacks(ctx: SessionContext, threshold: int) -> None: + ctx.sql( + 'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val', val=threshold + ).show(num=5) + show_attacks(ctx, 75) + +When passing parameters like the example above we convert the Python objects +into their string representation. We also have special case handling +for :py:class:`~datafusion.dataframe.DataFrame` objects, since they cannot simply +be turned into string representations for an SQL query. In these cases we +will register a temporary view in the :py:class:`~datafusion.context.SessionContext` +using a generated table name. + +The formatting for passing string replacement objects is to precede the +variable name with a single ``$``. This works for all dialects in +the SQL parser except ``hive`` and ``mysql``. Since these dialects do not +support named placeholders, we are unable to do this type of replacement. +We recommend either switching to another dialect or using Python +f-string style replacement. + +.. warning:: + + To support DataFrame parameterized queries, your session must support + registration of temporary views. 
The default + :py:class:`~datafusion.catalog.CatalogProvider` and + :py:class:`~datafusion.catalog.SchemaProvider` do have this capability. + If you have implemented custom providers, it is important that temporary + views do not persist across :py:class:`~datafusion.context.SessionContext` + or you may get unintended consequences. + +The following example shows passing in both a :py:class:`~datafusion.dataframe.DataFrame` +object as well as a Python object to be used in parameterized replacement. + +.. ipython:: python + + def show_column( + ctx: SessionContext, column: str, df: DataFrame, threshold: int + ) -> None: + ctx.sql( + 'SELECT "Name", $col FROM $df WHERE $col > $val', + col=column, + df=df, + val=threshold, + ).show(num=5) + df = ctx.table("pokemon") + show_column(ctx, '"Defense"', df, 75) + +The approach implemented for conversion of variables into a SQL query +relies on string conversion. This has the potential for data loss, +specifically for cases like floating point numbers. If you need to pass +variables into a parameterized query and it is important to maintain the +original value without conversion to a string, then you can use the +optional parameter ``param_values`` to specify these. This parameter +expects a dictionary mapping from the parameter name to a Python +object. Those objects will be cast into a +`PyArrow Scalar Value <https://arrow.apache.org/docs/python/generated/pyarrow.Scalar.html>`_. + +Using ``param_values`` will rely on the SQL dialect you have configured +for your session. This can be set using the :ref:`configuration options <configuration>` +of your :py:class:`~datafusion.context.SessionContext`. Similar to how +`prepared statements <https://datafusion.apache.org/user-guide/sql/prepared_statements.html>`_ +work, these parameters are limited to places where you would pass in a +scalar value, such as a comparison. + +.. 
ipython:: python + + def param_attacks(ctx: SessionContext, threshold: int) -> None: + ctx.sql( + 'SELECT "Name", "Attack" FROM pokemon WHERE "Attack" > $val', + param_values={"val": threshold}, + ).show(num=5) + param_attacks(ctx, 75) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 0aa2f27c4..2466e338c 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,6 +19,7 @@ from __future__ import annotations +import uuid import warnings from typing import TYPE_CHECKING, Any, Protocol @@ -27,6 +28,7 @@ except ImportError: from typing_extensions import deprecated # Python 3.12 + import pyarrow as pa from datafusion.catalog import Catalog @@ -592,9 +594,19 @@ def register_listing_table( self._convert_file_sort_order(file_sort_order), ) - def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: + def sql( + self, + query: str, + options: SQLOptions | None = None, + param_values: dict[str, Any] | None = None, + **named_params: Any, + ) -> DataFrame: """Create a :py:class:`~datafusion.DataFrame` from SQL query text. + See the online documentation for a description of how to perform + parameterized substitution via either the param_values option + or passing in named parameters. + Note: This API implements DDL statements such as ``CREATE TABLE`` and ``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory default implementation.See @@ -603,15 +615,57 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: Args: query: SQL query text. options: If provided, the query will be validated against these options. + param_values: Provides substitution of scalar values in the query + after parsing. + named_params: Provides string or DataFrame substitution in the query string. Returns: DataFrame representation of the SQL query. 
""" - if options is None: - return DataFrame(self.ctx.sql(query)) - return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) - def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: + def value_to_scalar(value) -> pa.Scalar: + if isinstance(value, pa.Scalar): + return value + return pa.scalar(value) + + def value_to_string(value) -> str: + if isinstance(value, DataFrame): + view_name = str(uuid.uuid4()).replace("-", "_") + view_name = f"view_{view_name}" + view = value.df.into_view(temporary=True) + self.ctx.register_table(view_name, view) + return view_name + return str(value) + + param_values = ( + {name: value_to_scalar(value) for (name, value) in param_values.items()} + if param_values is not None + else {} + ) + param_strings = ( + {name: value_to_string(value) for (name, value) in named_params.items()} + if named_params is not None + else {} + ) + + options_raw = options.options_internal if options is not None else None + + return DataFrame( + self.ctx.sql_with_options( + query, + options=options_raw, + param_values=param_values, + param_strings=param_strings, + ) + ) + + def sql_with_options( + self, + query: str, + options: SQLOptions, + param_values: dict[str, Any] | None = None, + **named_params: Any, + ) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. This function will first validate that the query is allowed by the @@ -620,11 +674,16 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: Args: query: SQL query text. options: SQL options. + param_values: Provides substitution of scalar values in the query + after parsing. + named_params: Provides string or DataFrame substitution in the query string. Returns: DataFrame representation of the SQL query. 
""" - return self.sql(query, options) + return self.sql( + query, options=options, param_values=param_values, **named_params + ) def create_dataframe( self, diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index c383edc60..14ab7636b 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -21,7 +21,7 @@ import pyarrow as pa import pyarrow.dataset as ds import pytest -from datafusion import col, udf +from datafusion import SessionContext, col, udf from datafusion.object_store import Http from pyarrow.csv import write_csv @@ -533,3 +533,46 @@ def test_register_listing_table( rd = result.to_pydict() assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2} + + +def test_parameterized_named_params(ctx, tmp_path) -> None: + path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) + + df = ctx.read_parquet(path) + result = ctx.sql( + "SELECT COUNT(a) AS cnt, $lit_val as lit_val FROM $replaced_df", + lit_val=3, + replaced_df=df, + ).collect() + result = pa.Table.from_batches(result) + assert result.to_pydict() == {"cnt": [100], "lit_val": [3]} + + +def test_parameterized_param_values(ctx: SessionContext) -> None: + # Test the parameters that should be handled by the parser rather + # than our manipulation of the query string by searching for tokens + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 4])], + names=["a"], + ) + + ctx.register_record_batches("t", [[batch]]) + result = ctx.sql("SELECT a FROM t WHERE a < $val", param_values={"val": 3}) + assert result.to_pydict() == {"a": [1, 2]} + + +def test_parameterized_mixed_query(ctx: SessionContext) -> None: + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 4])], + names=["a"], + ) + ctx.register_record_batches("t", [[batch]]) + registered_df = ctx.table("t") + + result = ctx.sql( + "SELECT $col_name FROM $df WHERE a < $val", + param_values={"val": 3}, + df=registered_df, + col_name="a", + ) + assert result.to_pydict() == {"a": [1, 2]} diff --git 
a/src/context.rs b/src/context.rs index dc18a7676..f4008af6d 100644 --- a/src/context.rs +++ b/src/context.rs @@ -32,6 +32,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider}; +use crate::common::data_type::PyScalarValue; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; @@ -40,6 +41,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::exceptions::py_value_err; use crate::sql::logical::PyLogicalPlan; +use crate::sql::util::replace_placeholders_with_strings; use crate::store::StorageContexts; use crate::table::PyTable; use crate::udaf::PyAggregateUDF; @@ -427,27 +429,41 @@ impl PySessionContext { self.ctx.register_udtf(&name, func); } - /// Returns a PyDataFrame whose plan corresponds to the SQL statement. - pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult { - let result = self.ctx.sql(query); - let df = wait_for_future(py, result)??; - Ok(PyDataFrame::new(df)) - } - - #[pyo3(signature = (query, options=None))] + #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))] pub fn sql_with_options( &self, - query: &str, - options: Option, py: Python, + mut query: String, + options: Option, + param_values: HashMap, + param_strings: HashMap, ) -> PyDataFusionResult { let options = if let Some(options) = options { options.options } else { SQLOptions::new() }; - let result = self.ctx.sql_with_options(query, options); - let df = wait_for_future(py, result)??; + + let param_values = param_values + .into_iter() + .map(|(name, value)| (name, ScalarValue::from(value))) + .collect::>(); + + let state = self.ctx.state(); + let dialect = state.config().options().sql_parser.dialect.as_str(); + + if !param_strings.is_empty() { + query = replace_placeholders_with_strings(&query, 
dialect, param_strings)?; + } + + let mut df = wait_for_future(py, async { + self.ctx.sql_with_options(&query, options).await + })??; + + if !param_values.is_empty() { + df = df.with_param_values(param_values)?; + } + Ok(PyDataFrame::new(df)) } diff --git a/src/sql.rs b/src/sql.rs index 9f1fe81be..dea9b566a 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -17,3 +17,4 @@ pub mod exceptions; pub mod logical; +pub(crate) mod util; diff --git a/src/sql/util.rs b/src/sql/util.rs new file mode 100644 index 000000000..438c526a2 --- /dev/null +++ b/src/sql/util.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError}; +use datafusion::logical_expr::sqlparser::dialect::dialect_from_str; +use datafusion::sql::sqlparser::dialect::Dialect; +use datafusion::sql::sqlparser::parser::Parser; +use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer}; +use std::collections::HashMap; + +fn tokens_from_replacements( + placeholder: &str, + replacements: &HashMap>, +) -> Option> { + if let Some(pattern) = placeholder.strip_prefix("$") { + replacements.get(pattern).cloned() + } else { + None + } +} + +fn get_tokens_for_string_replacement( + dialect: &dyn Dialect, + replacements: HashMap, +) -> Result>, DataFusionError> { + replacements + .into_iter() + .map(|(name, value)| { + let tokens = Tokenizer::new(dialect, &value) + .tokenize() + .map_err(|err| DataFusionError::External(err.into()))?; + Ok((name, tokens)) + }) + .collect() +} + +pub(crate) fn replace_placeholders_with_strings( + query: &str, + dialect: &str, + replacements: HashMap, +) -> Result { + let dialect = dialect_from_str(dialect) + .ok_or_else(|| plan_datafusion_err!("Unsupported SQL dialect: {dialect}."))?; + + let replacements = get_tokens_for_string_replacement(dialect.as_ref(), replacements)?; + + let tokens = Tokenizer::new(dialect.as_ref(), query) + .tokenize() + .map_err(|err| DataFusionError::External(err.into()))?; + + let replaced_tokens = tokens + .into_iter() + .flat_map(|token| { + if let Token::Placeholder(placeholder) = &token { + tokens_from_replacements(placeholder, &replacements).unwrap_or(vec![token]) + } else { + vec![token] + } + }) + .collect::>(); + + let statement = Parser::new(dialect.as_ref()) + .with_tokens(replaced_tokens) + .parse_statements() + .map_err(|err| DataFusionError::External(Box::new(err)))?; + + if statement.len() != 1 { + return exec_err!("placeholder replacement should return exactly one statement"); + } + + Ok(statement[0].to_string()) +}