Skip to content

Commit f10d958

Browse files
committed
Switching to explicit param_values or named parameters that will perform string replacement via parsed tokens
1 parent 4d3d602 commit f10d958

File tree

3 files changed

+108
-144
lines changed

3 files changed

+108
-144
lines changed

python/datafusion/context.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from __future__ import annotations
2121

22+
import uuid
2223
import warnings
2324
from typing import TYPE_CHECKING, Any, Protocol
2425

@@ -594,10 +595,18 @@ def register_listing_table(
594595
)
595596

596597
def sql(
597-
self, query: str, options: SQLOptions | None = None, **named_params: Any
598+
self,
599+
query: str,
600+
options: SQLOptions | None = None,
601+
param_values: dict[str, Any] | None = None,
602+
**named_params: Any,
598603
) -> DataFrame:
599604
"""Create a :py:class:`~datafusion.DataFrame` from SQL query text.
600605
606+
See the online documentation for a description of how to perform
607+
parameterized substitution via either the param_values option
608+
or passing in named parameters.
609+
601610
Note: This API implements DDL statements such as ``CREATE TABLE`` and
602611
``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
603612
default implementation. See
@@ -606,45 +615,56 @@ def sql(
606615
Args:
607616
query: SQL query text.
608617
options: If provided, the query will be validated against these options.
609-
named_params: Provides substitution in the query string.
618+
param_values: Provides substitution of scalar values in the query
619+
after parsing.
620+
named_params: Provides string or DataFrame substitution in the query string.
610621
611622
Returns:
612623
DataFrame representation of the SQL query.
613624
"""
614625

615-
def scalar_params(**p: Any) -> list[tuple[str, pa.Scalar]]:
616-
if p is None:
617-
return []
618-
619-
return [
620-
(name, pa.scalar(value))
621-
for (name, value) in p.items()
622-
if not isinstance(value, DataFrame)
623-
]
624-
625-
def dataframe_params(**p: Any) -> list[tuple[str, DataFrame]]:
626-
if p is None:
627-
return []
628-
629-
return [
630-
(name, value.df)
631-
for (name, value) in p.items()
632-
if isinstance(value, DataFrame)
633-
]
626+
def value_to_scalar(value) -> pa.Scalar:
    """Return ``value`` as a ``pa.Scalar``.

    Values that are already ``pa.Scalar`` instances are passed through
    untouched; anything else is wrapped with ``pa.scalar``.
    """
    return value if isinstance(value, pa.Scalar) else pa.scalar(value)
630+
631+
def value_to_string(value) -> str:
    """Return the text that should replace a named parameter in the query.

    A ``DataFrame`` is turned into a temporary view, registered on this
    context under a generated ``view_<uuid>`` name, and that name is
    returned. Every other value is converted with ``str``.
    """
    if not isinstance(value, DataFrame):
        return str(value)

    # Hyphens in the UUID are swapped for underscores when building the name.
    suffix = str(uuid.uuid4()).replace("-", "_")
    generated_name = f"view_{suffix}"
    self.ctx.register_table(generated_name, value.df.into_view(temporary=True))
    return generated_name
639+
640+
# Normalize both parameter groups into plain dicts before handing them to
# the native ``sql_with_options`` call.
#
# ``param_values`` is a mapping, so it must be walked with ``.items()``:
# iterating the dict itself yields only the keys, and unpacking a key into
# ``(name, value)`` fails (or silently mis-binds a two-character key).
param_values = (
    {name: value_to_scalar(value) for (name, value) in param_values.items()}
    if param_values is not None
    else {}
)
# Named parameters are substituted into the query text, so each one is
# reduced to a string (DataFrames become the name of a registered view).
param_strings = (
    {name: value_to_string(value) for (name, value) in named_params.items()}
    if named_params is not None
    else {}
)
634650

635651
options_raw = options.options_internal if options is not None else None
636652

637653
return DataFrame(
638654
self.ctx.sql_with_options(
639655
query,
640656
options=options_raw,
641-
scalar_params=scalar_params(**named_params),
642-
dataframe_params=dataframe_params(**named_params),
657+
param_values=param_values,
658+
param_strings=param_strings,
643659
)
644660
)
645661

646662
def sql_with_options(
647-
self, query: str, options: SQLOptions, **named_params: Any
663+
self,
664+
query: str,
665+
options: SQLOptions,
666+
param_values: dict[str, Any] | None = None,
667+
**named_params: Any,
648668
) -> DataFrame:
649669
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
650670
@@ -654,12 +674,16 @@ def sql_with_options(
654674
Args:
655675
query: SQL query text.
656676
options: SQL options.
657-
named_params: Provides substitution in the query string.
677+
param_values: Provides substitution of scalar values in the query
678+
after parsing.
679+
named_params: Provides string or DataFrame substitution in the query string.
658680
659681
Returns:
660682
DataFrame representation of the SQL query.
661683
"""
662-
return self.sql(query, options, **named_params)
684+
return self.sql(
685+
query, options=options, param_values=param_values, **named_params
686+
)
663687

664688
def create_dataframe(
665689
self,

src/context.rs

Lines changed: 19 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
3535
use crate::common::data_type::PyScalarValue;
3636
use crate::dataframe::PyDataFrame;
3737
use crate::dataset::Dataset;
38-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
38+
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3939
use crate::expr::sort_expr::PySortExpr;
4040
use crate::physical_plan::PyExecutionPlan;
4141
use crate::record_batch::PyRecordBatchStream;
4242
use crate::sql::exceptions::py_value_err;
4343
use crate::sql::logical::PyLogicalPlan;
44-
use crate::sql::util::replace_placeholders_with_table_names;
44+
use crate::sql::util::replace_placeholders_with_strings;
4545
use crate::store::StorageContexts;
46-
use crate::table::{PyTable, TempViewTable};
46+
use crate::table::PyTable;
4747
use crate::udaf::PyAggregateUDF;
4848
use crate::udf::PyScalarUDF;
4949
use crate::udtf::PyTableFunction;
@@ -429,54 +429,41 @@ impl PySessionContext {
429429
self.ctx.register_udtf(&name, func);
430430
}
431431

432-
#[pyo3(signature = (query, options=None, scalar_params=vec![], dataframe_params=vec![]))]
432+
#[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
433433
pub fn sql_with_options(
434434
&self,
435435
py: Python,
436-
query: &str,
436+
mut query: String,
437437
options: Option<PySQLOptions>,
438-
scalar_params: Vec<(String, PyScalarValue)>,
439-
dataframe_params: Vec<(String, PyDataFrame)>,
438+
param_values: HashMap<String, PyScalarValue>,
439+
param_strings: HashMap<String, String>,
440440
) -> PyDataFusionResult<PyDataFrame> {
441441
let options = if let Some(options) = options {
442442
options.options
443443
} else {
444444
SQLOptions::new()
445445
};
446446

447-
let scalar_params = scalar_params
447+
let param_values = param_values
448448
.into_iter()
449449
.map(|(name, value)| (name, ScalarValue::from(value)))
450-
.collect::<Vec<_>>();
451-
452-
let dataframe_params = dataframe_params
453-
.into_iter()
454-
.map(|(name, df)| {
455-
let uuid = Uuid::new_v4().to_string().replace("-", "");
456-
let view_name = format!("view_{uuid}");
457-
458-
self.create_temporary_view(py, view_name.as_str(), df, true)?;
459-
Ok((name, view_name))
460-
})
461-
.collect::<Result<HashMap<_, _>, PyDataFusionError>>()?;
450+
.collect::<HashMap<_, _>>();
462451

463452
let state = self.ctx.state();
464453
let dialect = state.config().options().sql_parser.dialect.as_str();
465454

466-
let query = replace_placeholders_with_table_names(query, dialect, dataframe_params)?;
467-
468-
println!("using scalar params: {scalar_params:?}");
469-
let df = wait_for_future(py, async {
470-
self.ctx
471-
.sql_with_options(&query, options)
472-
.await
473-
.map_err(|err| {
474-
println!("error before param replacement: {}", err);
475-
err
476-
})?
477-
.with_param_values(scalar_params)
455+
if !param_strings.is_empty() {
456+
query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
457+
}
458+
459+
let mut df = wait_for_future(py, async {
460+
self.ctx.sql_with_options(&query, options).await
478461
})??;
479462

463+
if !param_values.is_empty() {
464+
df = df.with_param_values(param_values)?;
465+
}
466+
480467
Ok(PyDataFrame::new(df))
481468
}
482469

@@ -521,29 +508,6 @@ impl PySessionContext {
521508
PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
522509
}
523510

524-
pub fn create_temporary_view(
525-
&self,
526-
py: Python,
527-
name: &str,
528-
df: PyDataFrame,
529-
replace_if_exists: bool,
530-
) -> PyDataFusionResult<()> {
531-
if self.table(name, py).is_ok() {
532-
if replace_if_exists {
533-
let _ = self.deregister_table(name);
534-
} else {
535-
exec_err!(
536-
"Unable to create temporary view. Table with name {name} already exists."
537-
)?;
538-
}
539-
}
540-
541-
let table = Arc::new(TempViewTable::new(df.inner_df()));
542-
self.ctx.register_table(name, table)?;
543-
544-
Ok(())
545-
}
546-
547511
/// Construct datafusion dataframe from Python list
548512
#[pyo3(signature = (data, name=None))]
549513
pub fn from_pylist(

src/sql/util.rs

Lines changed: 39 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,69 @@
1-
use datafusion::common::{internal_err, plan_datafusion_err, DataFusionError};
1+
use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError};
22
use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
33
use datafusion::sql::sqlparser::dialect::Dialect;
4-
use datafusion::sql::sqlparser::keywords::Keyword;
54
use datafusion::sql::sqlparser::parser::Parser;
6-
use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer, Word};
5+
use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer};
76
use std::collections::HashMap;
87

9-
fn value_from_replacements(
8+
fn tokens_from_replacements(
109
placeholder: &str,
11-
replacements: &HashMap<String, String>,
12-
) -> Option<Token> {
10+
replacements: &HashMap<String, Vec<Token>>,
11+
) -> Option<Vec<Token>> {
1312
if let Some(pattern) = placeholder.strip_prefix("$") {
14-
replacements.get(pattern).map(|replacement| {
15-
Token::Word(Word {
16-
value: replacement.to_owned(),
17-
quote_style: None,
18-
keyword: Keyword::NoKeyword,
19-
})
20-
})
13+
replacements.get(pattern).cloned()
2114
} else {
2215
None
2316
}
2417
}
2518

26-
fn table_names_are_valid(dialect: &dyn Dialect, replacements: &HashMap<String, String>) -> bool {
27-
for name in replacements.values() {
28-
let tokens = Tokenizer::new(dialect, name).tokenize().unwrap();
29-
if tokens.len() != 1 {
30-
// We should get exactly one token for our temporary table name
31-
return false;
32-
}
33-
34-
if let Token::Word(word) = &tokens[0] {
35-
// Generated table names should be not quoted or have keywords
36-
if word.quote_style.is_some() || word.keyword != Keyword::NoKeyword {
37-
return false;
38-
}
39-
} else {
40-
// We should always parse table names to a Word
41-
return false;
42-
}
43-
}
44-
45-
true
19+
fn get_tokens_for_string_replacement(
20+
dialect: &dyn Dialect,
21+
replacements: HashMap<String, String>,
22+
) -> Result<HashMap<String, Vec<Token>>, DataFusionError> {
23+
replacements
24+
.into_iter()
25+
.map(|(name, value)| {
26+
let tokens = Tokenizer::new(dialect, &value)
27+
.tokenize()
28+
.map_err(|err| DataFusionError::External(err.into()))?;
29+
Ok((name, tokens))
30+
})
31+
.collect()
4632
}
4733

48-
pub(crate) fn replace_placeholders_with_table_names(
34+
pub(crate) fn replace_placeholders_with_strings(
4935
query: &str,
5036
dialect: &str,
5137
replacements: HashMap<String, String>,
5238
) -> Result<String, DataFusionError> {
53-
let dialect = dialect_from_str(dialect).ok_or_else(|| {
54-
plan_datafusion_err!(
55-
"Unsupported SQL dialect: {dialect}. Available dialects: \
56-
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
57-
MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
58-
)
59-
})?;
39+
let dialect = dialect_from_str(dialect)
40+
.ok_or_else(|| plan_datafusion_err!("Unsupported SQL dialect: {dialect}."))?;
6041

61-
if !table_names_are_valid(dialect.as_ref(), &replacements) {
62-
return internal_err!("Invalid generated table name when replacing placeholders");
63-
}
64-
let tokens = Tokenizer::new(dialect.as_ref(), query).tokenize().unwrap();
42+
let replacements = get_tokens_for_string_replacement(dialect.as_ref(), replacements)?;
43+
44+
let tokens = Tokenizer::new(dialect.as_ref(), query)
45+
.tokenize()
46+
.map_err(|err| DataFusionError::External(err.into()))?;
6547

6648
let replaced_tokens = tokens
6749
.into_iter()
68-
.map(|token| {
69-
if let Token::Word(word) = &token {
70-
let Word {
71-
value,
72-
quote_style: _,
73-
keyword: _,
74-
} = word;
75-
76-
value_from_replacements(value, &replacements).unwrap_or(token)
77-
} else if let Token::Placeholder(placeholder) = &token {
78-
value_from_replacements(placeholder, &replacements).unwrap_or(token)
50+
.flat_map(|token| {
51+
if let Token::Placeholder(placeholder) = &token {
52+
tokens_from_replacements(placeholder, &replacements).unwrap_or(vec![token])
7953
} else {
80-
token
54+
vec![token]
8155
}
8256
})
8357
.collect::<Vec<Token>>();
8458

85-
Ok(Parser::new(dialect.as_ref())
59+
let statement = Parser::new(dialect.as_ref())
8660
.with_tokens(replaced_tokens)
8761
.parse_statements()
88-
.map_err(|err| DataFusionError::External(Box::new(err)))?
89-
.into_iter()
90-
.map(|s| s.to_string())
91-
.collect::<Vec<_>>()
92-
.join(" "))
62+
.map_err(|err| DataFusionError::External(Box::new(err)))?;
63+
64+
if statement.len() != 1 {
65+
return exec_err!("placeholder replacement should return exactly one statement");
66+
}
67+
68+
Ok(statement[0].to_string())
9369
}

0 commit comments

Comments
 (0)