|
1 | 1 | import sys |
2 | 2 | import os |
3 | 3 |
|
| 4 | + |
| 5 | +# If any UDF is defined, the path of the UDF will be set to this variable |
| 6 | +# and the path will be deleted when the process exits |
| 7 | +# UDF config path will be f"{g_udf_path}/udf_config.xml" |
| 8 | +# UDF script path will be f"{g_udf_path}/{func_name}.py" |
| 9 | +g_udf_path = "" |
| 10 | + |
4 | 11 | chdb_version = (0, 6, 0) |
5 | 12 | if sys.version_info[:2] >= (3, 7): |
6 | 13 | # get the path of the current file |
@@ -32,27 +39,30 @@ def to_arrowTable(res): |
32 | 39 | import pyarrow as pa |
33 | 40 | import pandas |
34 | 41 | except ImportError as e: |
35 | | - print(f'ImportError: {e}') |
| 42 | + print(f"ImportError: {e}") |
36 | 43 | print('Please install pyarrow and pandas via "pip install pyarrow pandas"') |
37 | | - raise ImportError('Failed to import pyarrow or pandas') from None |
| 44 | + raise ImportError("Failed to import pyarrow or pandas") from None |
38 | 45 | if len(res) == 0: |
39 | 46 | return pa.Table.from_batches([], schema=pa.schema([])) |
40 | 47 | return pa.RecordBatchFileReader(res.bytes()).read_all() |
41 | 48 |
|
42 | 49 |
|
43 | 50 | # return pandas dataframe |
44 | 51 | def to_df(r): |
45 | | - """"convert arrow table to Dataframe""" |
| 52 | + """convert arrow table to Dataframe""" |
46 | 53 | t = to_arrowTable(r) |
47 | 54 | return t.to_pandas(use_threads=True) |
48 | 55 |
|
49 | 56 |
|
50 | 57 | # wrap _chdb functions |
51 | | -def query(sql, output_format="CSV", path=None, udf_path=None): |
| 58 | +def query(sql, output_format="CSV", path="", udf_path=""): |
| 59 | + global g_udf_path |
| 60 | + if udf_path != "": |
| 61 | + g_udf_path = udf_path |
52 | 62 | lower_output_format = output_format.lower() |
53 | 63 | if lower_output_format == "dataframe": |
54 | | - return to_df(_chdb.query(sql, "Arrow", path=path, udf_path=udf_path)) |
55 | | - elif lower_output_format == 'arrowtable': |
56 | | - return to_arrowTable(_chdb.query(sql, "Arrow", path=path, udf_path=udf_path)) |
| 64 | + return to_df(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path)) |
| 65 | + elif lower_output_format == "arrowtable": |
| 66 | + return to_arrowTable(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path)) |
57 | 67 | else: |
58 | | - return _chdb.query(sql, output_format, path=path, udf_path=udf_path) |
| 68 | + return _chdb.query(sql, output_format, path=path, udf_path=g_udf_path) |
0 commit comments