Skip to content

Commit 4d3d602

Browse files
committed
Reworking to do token parsing of sql query instead of string manipulation
1 parent eab2793 commit 4d3d602

File tree

5 files changed

+186
-33
lines changed

5 files changed

+186
-33
lines changed

python/datafusion/context.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
except ImportError:
2828
from typing_extensions import deprecated # Python 3.12
2929

30-
import uuid
3130

3231
import pyarrow as pa
3332

@@ -612,25 +611,41 @@ def sql(
612611
Returns:
613612
DataFrame representation of the SQL query.
614613
"""
615-
if named_params:
616-
for alias, param in named_params.items():
617-
if isinstance(param, DataFrame):
618-
view_name = str(uuid.uuid4()).replace("-", "_")
619-
view_name = f"view_{view_name}"
620-
self.ctx.create_temporary_view(
621-
view_name, param.df, replace_if_exists=True
622-
)
623-
replace_str = view_name
624-
else:
625-
replace_str = str(param)
626614

627-
query = query.replace(f"{{{alias}}}", replace_str)
615+
def scalar_params(**p: Any) -> list[tuple[str, pa.Scalar]]:
616+
if p is None:
617+
return []
618+
619+
return [
620+
(name, pa.scalar(value))
621+
for (name, value) in p.items()
622+
if not isinstance(value, DataFrame)
623+
]
628624

629-
if options is None:
630-
return DataFrame(self.ctx.sql(query))
631-
return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
625+
def dataframe_params(**p: Any) -> list[tuple[str, DataFrame]]:
626+
if p is None:
627+
return []
632628

633-
def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
629+
return [
630+
(name, value.df)
631+
for (name, value) in p.items()
632+
if isinstance(value, DataFrame)
633+
]
634+
635+
options_raw = options.options_internal if options is not None else None
636+
637+
return DataFrame(
638+
self.ctx.sql_with_options(
639+
query,
640+
options=options_raw,
641+
scalar_params=scalar_params(**named_params),
642+
dataframe_params=dataframe_params(**named_params),
643+
)
644+
)
645+
646+
def sql_with_options(
647+
self, query: str, options: SQLOptions, **named_params: Any
648+
) -> DataFrame:
634649
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
635650
636651
This function will first validate that the query is allowed by the
@@ -639,11 +654,12 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
639654
Args:
640655
query: SQL query text.
641656
options: SQL options.
657+
named_params: Provides substitution in the query string.
642658
643659
Returns:
644660
DataFrame representation of the SQL query.
645661
"""
646-
return self.sql(query, options)
662+
return self.sql(query, options, **named_params)
647663

648664
def create_dataframe(
649665
self,

python/tests/test_sql.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pyarrow as pa
2222
import pyarrow.dataset as ds
2323
import pytest
24-
from datafusion import col, udf
24+
from datafusion import SessionContext, col, udf
2525
from datafusion.object_store import Http
2626
from pyarrow.csv import write_csv
2727

@@ -535,11 +535,25 @@ def test_register_listing_table(
535535
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
536536

537537

538-
def test_parameterized_sql(ctx, tmp_path) -> None:
538+
def test_parameterized_df_in_sql(ctx, tmp_path) -> None:
539539
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
540+
540541
df = ctx.read_parquet(path)
541542
result = ctx.sql(
542-
"SELECT COUNT(a) AS cnt FROM {replaced_df}", replaced_df=df
543+
"SELECT COUNT(a) AS cnt FROM $replaced_df", replaced_df=df
543544
).collect()
544545
result = pa.Table.from_batches(result)
545546
assert result.to_pydict() == {"cnt": [100]}
547+
548+
549+
def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
550+
# Test the parameters that should be handled by the parser rather
551+
# than our manipulation of the query string by searching for tokens
552+
batch = pa.RecordBatch.from_arrays(
553+
[pa.array([1, 2, 3, 4])],
554+
names=["a"],
555+
)
556+
557+
ctx.register_record_batches("t", [[batch]])
558+
result = ctx.sql("SELECT a FROM t WHERE a < $val", val=3)
559+
assert result.to_pydict() == {"a": [1, 2]}

src/context.rs

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
3232
use pyo3::prelude::*;
3333

3434
use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
35+
use crate::common::data_type::PyScalarValue;
3536
use crate::dataframe::PyDataFrame;
3637
use crate::dataset::Dataset;
37-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
38+
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
3839
use crate::expr::sort_expr::PySortExpr;
3940
use crate::physical_plan::PyExecutionPlan;
4041
use crate::record_batch::PyRecordBatchStream;
4142
use crate::sql::exceptions::py_value_err;
4243
use crate::sql::logical::PyLogicalPlan;
44+
use crate::sql::util::replace_placeholders_with_table_names;
4345
use crate::store::StorageContexts;
4446
use crate::table::{PyTable, TempViewTable};
4547
use crate::udaf::PyAggregateUDF;
@@ -427,27 +429,54 @@ impl PySessionContext {
427429
self.ctx.register_udtf(&name, func);
428430
}
429431

430-
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
431-
pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
432-
let result = self.ctx.sql(query);
433-
let df = wait_for_future(py, result)??;
434-
Ok(PyDataFrame::new(df))
435-
}
436-
437-
#[pyo3(signature = (query, options=None))]
432+
#[pyo3(signature = (query, options=None, scalar_params=vec![], dataframe_params=vec![]))]
438433
pub fn sql_with_options(
439434
&self,
435+
py: Python,
440436
query: &str,
441437
options: Option<PySQLOptions>,
442-
py: Python,
438+
scalar_params: Vec<(String, PyScalarValue)>,
439+
dataframe_params: Vec<(String, PyDataFrame)>,
443440
) -> PyDataFusionResult<PyDataFrame> {
444441
let options = if let Some(options) = options {
445442
options.options
446443
} else {
447444
SQLOptions::new()
448445
};
449-
let result = self.ctx.sql_with_options(query, options);
450-
let df = wait_for_future(py, result)??;
446+
447+
let scalar_params = scalar_params
448+
.into_iter()
449+
.map(|(name, value)| (name, ScalarValue::from(value)))
450+
.collect::<Vec<_>>();
451+
452+
let dataframe_params = dataframe_params
453+
.into_iter()
454+
.map(|(name, df)| {
455+
let uuid = Uuid::new_v4().to_string().replace("-", "");
456+
let view_name = format!("view_{uuid}");
457+
458+
self.create_temporary_view(py, view_name.as_str(), df, true)?;
459+
Ok((name, view_name))
460+
})
461+
.collect::<Result<HashMap<_, _>, PyDataFusionError>>()?;
462+
463+
let state = self.ctx.state();
464+
let dialect = state.config().options().sql_parser.dialect.as_str();
465+
466+
let query = replace_placeholders_with_table_names(query, dialect, dataframe_params)?;
467+
468+
println!("using scalar params: {scalar_params:?}");
469+
let df = wait_for_future(py, async {
470+
self.ctx
471+
.sql_with_options(&query, options)
472+
.await
473+
.map_err(|err| {
474+
println!("error before param replacement: {}", err);
475+
err
476+
})?
477+
.with_param_values(scalar_params)
478+
})??;
479+
451480
Ok(PyDataFrame::new(df))
452481
}
453482

src/sql.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717

1818
pub mod exceptions;
1919
pub mod logical;
20+
pub(crate) mod util;

src/sql/util.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use datafusion::common::{internal_err, plan_datafusion_err, DataFusionError};
2+
use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
3+
use datafusion::sql::sqlparser::dialect::Dialect;
4+
use datafusion::sql::sqlparser::keywords::Keyword;
5+
use datafusion::sql::sqlparser::parser::Parser;
6+
use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer, Word};
7+
use std::collections::HashMap;
8+
9+
fn value_from_replacements(
10+
placeholder: &str,
11+
replacements: &HashMap<String, String>,
12+
) -> Option<Token> {
13+
if let Some(pattern) = placeholder.strip_prefix("$") {
14+
replacements.get(pattern).map(|replacement| {
15+
Token::Word(Word {
16+
value: replacement.to_owned(),
17+
quote_style: None,
18+
keyword: Keyword::NoKeyword,
19+
})
20+
})
21+
} else {
22+
None
23+
}
24+
}
25+
26+
fn table_names_are_valid(dialect: &dyn Dialect, replacements: &HashMap<String, String>) -> bool {
27+
for name in replacements.values() {
28+
let tokens = Tokenizer::new(dialect, name).tokenize().unwrap();
29+
if tokens.len() != 1 {
30+
// We should get exactly one token for our temporary table name
31+
return false;
32+
}
33+
34+
if let Token::Word(word) = &tokens[0] {
35+
// Generated table names should be not quoted or have keywords
36+
if word.quote_style.is_some() || word.keyword != Keyword::NoKeyword {
37+
return false;
38+
}
39+
} else {
40+
// We should always parse table names to a Word
41+
return false;
42+
}
43+
}
44+
45+
true
46+
}
47+
48+
pub(crate) fn replace_placeholders_with_table_names(
49+
query: &str,
50+
dialect: &str,
51+
replacements: HashMap<String, String>,
52+
) -> Result<String, DataFusionError> {
53+
let dialect = dialect_from_str(dialect).ok_or_else(|| {
54+
plan_datafusion_err!(
55+
"Unsupported SQL dialect: {dialect}. Available dialects: \
56+
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
57+
MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
58+
)
59+
})?;
60+
61+
if !table_names_are_valid(dialect.as_ref(), &replacements) {
62+
return internal_err!("Invalid generated table name when replacing placeholders");
63+
}
64+
let tokens = Tokenizer::new(dialect.as_ref(), query).tokenize().unwrap();
65+
66+
let replaced_tokens = tokens
67+
.into_iter()
68+
.map(|token| {
69+
if let Token::Word(word) = &token {
70+
let Word {
71+
value,
72+
quote_style: _,
73+
keyword: _,
74+
} = word;
75+
76+
value_from_replacements(value, &replacements).unwrap_or(token)
77+
} else if let Token::Placeholder(placeholder) = &token {
78+
value_from_replacements(placeholder, &replacements).unwrap_or(token)
79+
} else {
80+
token
81+
}
82+
})
83+
.collect::<Vec<Token>>();
84+
85+
Ok(Parser::new(dialect.as_ref())
86+
.with_tokens(replaced_tokens)
87+
.parse_statements()
88+
.map_err(|err| DataFusionError::External(Box::new(err)))?
89+
.into_iter()
90+
.map(|s| s.to_string())
91+
.collect::<Vec<_>>()
92+
.join(" "))
93+
}

0 commit comments

Comments
 (0)