Skip to content

Commit eab2793

Browse files
committed
Intermediate work on parameterizing queries
1 parent 94b6f55 commit eab2793

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

python/datafusion/context.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
except ImportError:
2828
from typing_extensions import deprecated # Python 3.12
2929

30+
import uuid
31+
3032
import pyarrow as pa
3133

3234
from datafusion.catalog import Catalog
@@ -592,7 +594,9 @@ def register_listing_table(
592594
self._convert_file_sort_order(file_sort_order),
593595
)
594596

595-
def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
597+
def sql(
598+
self, query: str, options: SQLOptions | None = None, **named_params: Any
599+
) -> DataFrame:
596600
"""Create a :py:class:`~datafusion.DataFrame` from SQL query text.
597601
598602
Note: This API implements DDL statements such as ``CREATE TABLE`` and
@@ -603,10 +607,25 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
603607
Args:
604608
query: SQL query text.
605609
options: If provided, the query will be validated against these options.
610+
named_params: Provides substitution in the query string.
606611
607612
Returns:
608613
DataFrame representation of the SQL query.
609614
"""
615+
if named_params:
616+
for alias, param in named_params.items():
617+
if isinstance(param, DataFrame):
618+
view_name = str(uuid.uuid4()).replace("-", "_")
619+
view_name = f"view_{view_name}"
620+
self.ctx.create_temporary_view(
621+
view_name, param.df, replace_if_exists=True
622+
)
623+
replace_str = view_name
624+
else:
625+
replace_str = str(param)
626+
627+
query = query.replace(f"{{{alias}}}", replace_str)
628+
610629
if options is None:
611630
return DataFrame(self.ctx.sql(query))
612631
return DataFrame(self.ctx.sql_with_options(query, options.options_internal))

python/tests/test_sql.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,13 @@ def test_register_listing_table(
533533

534534
rd = result.to_pydict()
535535
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
536+
537+
538+
def test_parameterized_sql(ctx, tmp_path) -> None:
539+
path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
540+
df = ctx.read_parquet(path)
541+
result = ctx.sql(
542+
"SELECT COUNT(a) AS cnt FROM {replaced_df}", replaced_df=df
543+
).collect()
544+
result = pa.Table.from_batches(result)
545+
assert result.to_pydict() == {"cnt": [100]}

src/context.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use crate::record_batch::PyRecordBatchStream;
4141
use crate::sql::exceptions::py_value_err;
4242
use crate::sql::logical::PyLogicalPlan;
4343
use crate::store::StorageContexts;
44-
use crate::table::PyTable;
44+
use crate::table::{PyTable, TempViewTable};
4545
use crate::udaf::PyAggregateUDF;
4646
use crate::udf::PyScalarUDF;
4747
use crate::udtf::PyTableFunction;
@@ -492,6 +492,29 @@ impl PySessionContext {
492492
PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
493493
}
494494

495+
pub fn create_temporary_view(
496+
&self,
497+
py: Python,
498+
name: &str,
499+
df: PyDataFrame,
500+
replace_if_exists: bool,
501+
) -> PyDataFusionResult<()> {
502+
if self.table(name, py).is_ok() {
503+
if replace_if_exists {
504+
let _ = self.deregister_table(name);
505+
} else {
506+
exec_err!(
507+
"Unable to create temporary view. Table with name {name} already exists."
508+
)?;
509+
}
510+
}
511+
512+
let table = Arc::new(TempViewTable::new(df.inner_df()));
513+
self.ctx.register_table(name, table)?;
514+
515+
Ok(())
516+
}
517+
495518
/// Construct datafusion dataframe from Python list
496519
#[pyo3(signature = (data, name=None))]
497520
pub fn from_pylist(

src/table.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ use async_trait::async_trait;
2424
use datafusion::catalog::Session;
2525
use datafusion::common::Column;
2626
use datafusion::datasource::{TableProvider, TableType};
27+
<<<<<<< HEAD
2728
use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown};
29+
=======
30+
use datafusion::logical_expr::{Expr, LogicalPlanBuilder};
31+
>>>>>>> 3c9a96c (Intermediate work on temp views and parameterizing queries)
2832
use datafusion::physical_plan::ExecutionPlan;
2933
use datafusion::prelude::DataFrame;
3034
use pyo3::prelude::*;

0 commit comments

Comments
 (0)