Commit fd5977f

When calling read_x to create DataFrames in the session context, always register the DataFrames so they can be used with SQL queries.
1 parent 3bd30a0 commit fd5977f
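In practice, any DataFrame produced by the session context's read_x helpers is now also registered as a table, so it can be referenced by name in subsequent SQL. A minimal sketch of the new behavior, assuming a hypothetical employees.csv with a name column:

    from datafusion import SessionContext

    ctx = SessionContext()

    # read_csv still returns a DataFrame, but the data is now also
    # registered under a name derived from the file stem ("employees"),
    # so SQL queries can refer to it directly.
    df = ctx.read_csv("employees.csv")
    ctx.sql("SELECT name FROM employees LIMIT 5").show()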

File tree

3 files changed: +102 −46 lines changed

python/datafusion/context.py

Lines changed: 91 additions & 46 deletions
@@ -30,14 +30,14 @@
 from datafusion.record_batch import RecordBatchStream
 from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF
 
+import pathlib
 from typing import Any, TYPE_CHECKING, Protocol
 from typing_extensions import deprecated
 
 if TYPE_CHECKING:
     import pyarrow
     import pandas
     import polars
-    import pathlib
     from datafusion.plan import LogicalPlan, ExecutionPlan
 
 
@@ -523,9 +523,18 @@ def register_listing_table(
             file_sort_order_raw,
         )
 
-    def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
+    def sql(
+        self, query: str, options: SQLOptions | None = None, **named_dfs: DataFrame
+    ) -> DataFrame:
         """Create a :py:class:`~datafusion.DataFrame` from SQL query text.
 
+        The query string can optionally take a DataFrame as a parameter by assigning
+        a variable inside brackets. In the following example, if we have a DataFrame
+        called `my_df`, then the DataFrame's logical plan will be converted into an
+        SQL query string and inserted as a substitution::
+
+            ctx.sql("SELECT name from {df}", df=my_df)
+
         Note: This API implements DDL statements such as ``CREATE TABLE`` and
         ``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with in-memory
         default implementation. See
@@ -534,12 +543,20 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
         Args:
             query: SQL query text.
             options: If provided, the query will be validated against these options.
+            named_dfs: When provided, used to replace parameterized query variables
+                in the query string.
 
         Returns:
             DataFrame representation of the SQL query.
         """
+        if named_dfs:
+            for alias, df in named_dfs.items():
+                df_sql = f"({df.logical_plan().to_sql()})"
+                query = query.replace(f"{{{alias}}}", df_sql)
+
         if options is None:
             return DataFrame(self.ctx.sql(query))
+
         return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
 
     def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
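A runnable sketch of the substitution path above, using a toy in-session DataFrame (names illustrative; the exact unparsed SQL may differ): the keyword argument's logical plan is unparsed to a SQL string and spliced in as a parenthesized subquery before the query is parsed.

    from datafusion import SessionContext

    ctx = SessionContext()
    my_df = ctx.sql("SELECT 1 AS name")

    # "{df}" is textually replaced with something like "(SELECT 1 AS name)",
    # so the query runs as: SELECT name FROM (SELECT 1 AS name)
    ctx.sql("SELECT name FROM {df}", df=my_df).show()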
@@ -753,7 +770,7 @@ def register_parquet(
     def register_csv(
         self,
         name: str,
-        path: str | pathlib.Path | list[str | pathlib.Path],
+        path: str | pathlib.Path | list[str] | list[pathlib.Path],
         schema: pyarrow.Schema | None = None,
         has_header: bool = True,
         delimiter: str = ",",
@@ -917,6 +934,7 @@ def read_json(
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str]] | None = None,
         file_compression_type: str | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -929,22 +947,23 @@
                 selected for data input.
             table_partition_cols: Partition columns.
             file_compression_type: File compression type.
+            table_name: Name to register the table as for SQL queries
 
         Returns:
             DataFrame representation of the read JSON files.
         """
-        if table_partition_cols is None:
-            table_partition_cols = []
-        return DataFrame(
-            self.ctx.read_json(
-                str(path),
-                schema,
-                schema_infer_max_records,
-                file_extension,
-                table_partition_cols,
-                file_compression_type,
-            )
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_json(
+            table_name,
+            path,
+            schema=schema,
+            schema_infer_max_records=schema_infer_max_records,
+            file_extension=file_extension,
+            table_partition_cols=table_partition_cols,
+            file_compression_type=file_compression_type,
         )
+        return self.table(table_name)
 
     def read_csv(
         self,
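With this change each read_* method becomes a thin wrapper over its register_* counterpart followed by a table lookup. A brief sketch, assuming a hypothetical logs.json on disk:

    ctx = SessionContext()

    # Returns a DataFrame as before, but the source is now also registered
    # as a table named "logs" (derived from the file stem).
    df = ctx.read_json("logs.json")
    ctx.sql("SELECT COUNT(*) FROM logs").show()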
@@ -956,6 +975,7 @@
         file_extension: str = ".csv",
         table_partition_cols: list[tuple[str, str]] | None = None,
         file_compression_type: str | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -973,27 +993,24 @@
                 selected for data input.
             table_partition_cols: Partition columns.
             file_compression_type: File compression type.
+            table_name: Name to register the table as for SQL queries
 
         Returns:
             DataFrame representation of the read CSV files
         """
-        if table_partition_cols is None:
-            table_partition_cols = []
-
-        path = [str(p) for p in path] if isinstance(path, list) else str(path)
-
-        return DataFrame(
-            self.ctx.read_csv(
-                path,
-                schema,
-                has_header,
-                delimiter,
-                schema_infer_max_records,
-                file_extension,
-                table_partition_cols,
-                file_compression_type,
-            )
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_csv(
+            table_name,
+            path,
+            schema=schema,
+            has_header=has_header,
+            delimiter=delimiter,
+            schema_infer_max_records=schema_infer_max_records,
+            file_extension=file_extension,
+            file_compression_type=file_compression_type,
         )
+        return self.table(table_name)
 
     def read_parquet(
         self,
@@ -1004,6 +1021,7 @@
         skip_metadata: bool = True,
         schema: pyarrow.Schema | None = None,
         file_sort_order: list[list[Expr]] | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
@@ -1021,30 +1039,32 @@
                 the parquet reader will try to infer it based on data in the
                 file.
             file_sort_order: Sort order for the file.
+            table_name: Name to register the table as for SQL queries
 
         Returns:
             DataFrame representation of the read Parquet files
         """
-        if table_partition_cols is None:
-            table_partition_cols = []
-        return DataFrame(
-            self.ctx.read_parquet(
-                str(path),
-                table_partition_cols,
-                parquet_pruning,
-                file_extension,
-                skip_metadata,
-                schema,
-                file_sort_order,
-            )
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_parquet(
+            table_name,
+            path,
+            table_partition_cols=table_partition_cols,
+            parquet_pruning=parquet_pruning,
+            file_extension=file_extension,
+            skip_metadata=skip_metadata,
+            schema=schema,
+            file_sort_order=file_sort_order,
         )
+        return self.table(table_name)
 
     def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pyarrow.Schema | None = None,
         file_partition_cols: list[tuple[str, str]] | None = None,
         file_extension: str = ".avro",
+        table_name: str | None = None,
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1053,15 +1073,21 @@ def read_avro(
             schema: The data source schema.
             file_partition_cols: Partition columns.
             file_extension: File extension to select.
+            table_name: Name to register the table as for SQL queries
 
         Returns:
             DataFrame representation of the read Avro file
         """
-        if file_partition_cols is None:
-            file_partition_cols = []
-        return DataFrame(
-            self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_avro(
+            table_name,
+            path,
+            schema=schema,
+            file_extension=file_extension,
+            table_partition_cols=file_partition_cols,
         )
+        return self.table(table_name)
 
     def read_table(self, table: Table) -> DataFrame:
         """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.
@@ -1075,3 +1101,22 @@ def read_table(self, table: Table) -> DataFrame:
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    def generate_table_name(
+        self, path: str | pathlib.Path | list[str] | list[pathlib.Path]
+    ) -> str:
+        """Generate a table name based on the file name or a uuid."""
+        import uuid
+
+        if isinstance(path, list):
+            path = path[0]
+
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+
+        table_name = path.stem.replace(".", "_")
+
+        if self.table_exist(table_name):
+            table_name = uuid.uuid4().hex
+
+        return table_name
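Per the helper above, the generated name is the first path's stem with dots replaced by underscores, falling back to a random uuid4 hex string if that name is already taken. A sketch with hypothetical file names:

    ctx = SessionContext()

    # stem "sales.2024" -> dots become underscores -> registered as "sales_2024"
    df1 = ctx.read_csv("sales.2024.csv")

    # Same stem again: "sales_2024" already exists, so this one is
    # registered under a uuid4 hex name instead of colliding.
    df2 = ctx.read_csv("archive/sales.2024.csv")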

python/datafusion/plan.py

Lines changed: 4 additions & 0 deletions
@@ -98,6 +98,10 @@ def to_proto(self) -> bytes:
         """
         return self._raw_plan.to_proto()
 
+    def to_sql(self) -> str:
+        """Return the SQL equivalent statement for this logical plan."""
+        return self._raw_plan.to_sql()
+
 
 class ExecutionPlan:
     """Represent nodes in the DataFusion Physical Plan."""

src/sql/logical.rs

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,7 @@ use crate::expr::table_scan::PyTableScan;
 use crate::expr::unnest::PyUnnest;
 use crate::expr::window::PyWindowExpr;
 use crate::{context::PySessionContext, errors::py_unsupported_variant_err};
+use datafusion::sql::unparser::plan_to_sql;
 use datafusion::{error::DataFusionError, logical_expr::LogicalPlan};
 use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
 use prost::Message;
@@ -153,6 +154,12 @@ impl PyLogicalPlan {
             .map_err(DataFusionError::from)?;
         Ok(Self::new(plan))
     }
+
+    pub fn to_sql(&self) -> PyResult<String> {
+        plan_to_sql(&self.plan)
+            .map(|v| v.to_string())
+            .map_err(|err| PyRuntimeError::new_err(err.to_string()))
+    }
 }
 
 impl From<PyLogicalPlan> for LogicalPlan {
