 from datafusion.record_batch import RecordBatchStream
 from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF

+import pathlib
 from typing import Any, TYPE_CHECKING, Protocol
 from typing_extensions import deprecated

 if TYPE_CHECKING:
     import pyarrow
     import pandas
     import polars
-    import pathlib
     from datafusion.plan import LogicalPlan, ExecutionPlan


@@ -523,9 +523,18 @@ def register_listing_table(
             file_sort_order_raw,
         )

-    def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
+    def sql(
+        self, query: str, options: SQLOptions | None = None, **named_dfs: DataFrame
+    ) -> DataFrame:
         """Create a :py:class:`~datafusion.DataFrame` from SQL query text.

+        The query string can optionally take a DataFrame as a parameter by
+        assigning a variable inside curly braces. In the following example, if we
+        have a DataFrame called `my_df`, then the DataFrame's logical plan will be
+        converted into an SQL query string and inserted as a substitution::
+
+            ctx.sql("SELECT name FROM {df}", df=my_df)
+
         Note: This API implements DDL statements such as ``CREATE TABLE`` and
         ``CREATE VIEW`` and DML statements such as ``INSERT INTO`` with an
         in-memory default implementation. See
@@ -534,12 +543,20 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
         Args:
             query: SQL query text.
             options: If provided, the query will be validated against these options.
+            named_dfs: When provided, used to replace parameterized query variables
+                in the query string.

         Returns:
             DataFrame representation of the SQL query.
         """
+        if named_dfs:
+            for alias, df in named_dfs.items():
+                df_sql = f"({df.logical_plan().to_sql()})"
+                query = query.replace(f"{{{alias}}}", df_sql)
+
         if options is None:
             return DataFrame(self.ctx.sql(query))
+
         return DataFrame(self.ctx.sql_with_options(query, options.options_internal))

     def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
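
A minimal usage sketch of the new parameterized form (assuming a `SessionContext` named `ctx`; `my_df` and the column names are illustrative)::

    from datafusion import SessionContext

    ctx = SessionContext()
    my_df = ctx.from_pydict({"name": ["alice", "bob"], "age": [30, 40]})

    # {df} is swapped for a subquery generated from my_df's logical plan
    result = ctx.sql("SELECT name FROM {df} WHERE age > 35", df=my_df)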
@@ -753,7 +770,7 @@ def register_parquet(
     def register_csv(
         self,
         name: str,
-        path: str | pathlib.Path | list[str | pathlib.Path],
+        path: str | pathlib.Path | list[str] | list[pathlib.Path],
         schema: pyarrow.Schema | None = None,
         has_header: bool = True,
         delimiter: str = ",",
@@ -917,6 +934,7 @@ def read_json(
         file_extension: str = ".json",
         table_partition_cols: list[tuple[str, str]] | None = None,
         file_compression_type: str | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.

@@ -929,22 +947,23 @@ def read_json(
                 selected for data input.
             table_partition_cols: Partition columns.
             file_compression_type: File compression type.
+            table_name: Name to register the table as for SQL queries.

         Returns:
             DataFrame representation of the read JSON files.
         """
-        if table_partition_cols is None:
-            table_partition_cols = []
-        return DataFrame(
-            self.ctx.read_json(
-                str(path),
-                schema,
-                schema_infer_max_records,
-                file_extension,
-                table_partition_cols,
-                file_compression_type,
-            )
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_json(
+            table_name,
+            path,
+            schema=schema,
+            schema_infer_max_records=schema_infer_max_records,
+            file_extension=file_extension,
+            table_partition_cols=table_partition_cols,
+            file_compression_type=file_compression_type,
         )
+        return self.table(table_name)

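Because read_json now registers the data under table_name before returning, the result is immediately queryable by name; a short sketch (the path and name are illustrative)::

    df = ctx.read_json("records.json", table_name="records")
    # the same data can now be referenced from SQL
    ctx.sql("SELECT COUNT(*) FROM records").show()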
     def read_csv(
         self,
@@ -956,6 +975,7 @@ def read_csv(
         file_extension: str = ".csv",
         table_partition_cols: list[tuple[str, str]] | None = None,
         file_compression_type: str | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.

@@ -973,27 +993,24 @@ def read_csv(
                 selected for data input.
             table_partition_cols: Partition columns.
             file_compression_type: File compression type.
+            table_name: Name to register the table as for SQL queries.

         Returns:
             DataFrame representation of the read CSV files.
         """
-        if table_partition_cols is None:
-            table_partition_cols = []
-
-        path = [str(p) for p in path] if isinstance(path, list) else str(path)
-
-        return DataFrame(
-            self.ctx.read_csv(
-                path,
-                schema,
-                has_header,
-                delimiter,
-                schema_infer_max_records,
-                file_extension,
-                table_partition_cols,
-                file_compression_type,
-            )
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_csv(
+            table_name,
+            path,
+            schema=schema,
+            has_header=has_header,
+            delimiter=delimiter,
+            schema_infer_max_records=schema_infer_max_records,
+            file_extension=file_extension,
+            file_compression_type=file_compression_type,
         )
+        return self.table(table_name)

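When table_name is omitted, generate_table_name derives one from the file stem, replacing dots with underscores; a sketch with an illustrative path::

    df = ctx.read_csv("sales.2024.csv")
    # the stem "sales.2024" becomes the table name "sales_2024"
    ctx.sql("SELECT * FROM sales_2024 LIMIT 5").show()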
     def read_parquet(
         self,
@@ -1004,6 +1021,7 @@ def read_parquet(
         skip_metadata: bool = True,
         schema: pyarrow.Schema | None = None,
         file_sort_order: list[list[Expr]] | None = None,
+        table_name: str | None = None,
     ) -> DataFrame:
         """Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.

@@ -1021,30 +1039,32 @@ def read_parquet(
                 the parquet reader will try to infer it based on data in the
                 file.
             file_sort_order: Sort order for the file.
+            table_name: Name to register the table as for SQL queries.

         Returns:
             DataFrame representation of the read Parquet files.
         """
-        if table_partition_cols is None:
-            table_partition_cols = []
-        return DataFrame(
-            self.ctx.read_parquet(
-                str(path),
-                table_partition_cols,
-                parquet_pruning,
-                file_extension,
-                skip_metadata,
-                schema,
-                file_sort_order,
-            )
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_parquet(
+            table_name,
+            path,
+            table_partition_cols=table_partition_cols,
+            parquet_pruning=parquet_pruning,
+            file_extension=file_extension,
+            skip_metadata=skip_metadata,
+            schema=schema,
+            file_sort_order=file_sort_order,
         )
+        return self.table(table_name)

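An explicit table_name also composes with the existing options, since they are forwarded to register_parquet as keyword arguments; a sketch (the path, partition column, and type string are illustrative assumptions)::

    df = ctx.read_parquet(
        "events/",
        table_partition_cols=[("year", "string")],  # hypothetical partition column
        table_name="events",
    )
    ctx.sql("SELECT year, COUNT(*) FROM events GROUP BY year").show()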
     def read_avro(
         self,
         path: str | pathlib.Path,
         schema: pyarrow.Schema | None = None,
         file_partition_cols: list[tuple[str, str]] | None = None,
         file_extension: str = ".avro",
+        table_name: str | None = None,
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading an Avro data source.

@@ -1053,15 +1073,21 @@ def read_avro(
             schema: The data source schema.
             file_partition_cols: Partition columns.
             file_extension: File extension to select.
+            table_name: Name to register the table as for SQL queries.

         Returns:
             DataFrame representation of the read Avro file.
         """
-        if file_partition_cols is None:
-            file_partition_cols = []
-        return DataFrame(
-            self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
+        if table_name is None:
+            table_name = self.generate_table_name(path)
+        self.register_avro(
+            table_name,
+            path,
+            schema=schema,
+            file_extension=file_extension,
+            table_partition_cols=file_partition_cols,
         )
+        return self.table(table_name)

     def read_table(self, table: Table) -> DataFrame:
         """Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table.
@@ -1075,3 +1101,22 @@ def read_table(self, table: Table) -> DataFrame:
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    def generate_table_name(
+        self, path: str | pathlib.Path | list[str] | list[pathlib.Path]
+    ) -> str:
+        """Generate a table name based on the file name or a uuid."""
+        import uuid
+
+        if isinstance(path, list):
+            path = path[0]
+
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+
+        table_name = path.stem.replace(".", "_")
+
+        if self.table_exist(table_name):
+            table_name = uuid.uuid4().hex
+
+        return table_name
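
The name derivation and its collision fallback can be exercised directly; a sketch assuming an illustrative local file "data.csv"::

    name = ctx.generate_table_name("data.csv")   # "data", if that name is free
    ctx.register_csv(name, "data.csv")
    # a second call collides with the registered name, so a uuid4 hex is used
    other = ctx.generate_table_name("data.csv")
    assert other != name and len(other) == 32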