Skip to content

Commit bfb4df8

Browse files
committed
Use temp dir for UDF config and executables
1 parent 1a7b699 commit bfb4df8

File tree

1 file changed

+65
-21
lines changed

1 file changed

+65
-21
lines changed

chdb/udf/udf.py

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
import functools
22
import inspect
33
import os
4+
import sys
45
import tempfile
6+
import atexit
7+
import shutil
8+
import textwrap
59
from xml.etree import ElementTree as ET
10+
import chdb
611

7-
tempdir = tempfile.TemporaryDirectory()
8-
# os.chdir(tempdir.name)
9-
os.chdir("user_scripts")
10-
# print(f"Current working directory: {os.getcwd()}")
1112

1213
def generate_udf(func_name, args, return_type, udf_body):
1314
# generate python script
14-
with open(f"{func_name}.py", "w") as f:
15-
f.write("#!/usr/bin/python3\n")
15+
with open(f"{chdb.g_udf_path}/{func_name}.py", "w") as f:
16+
f.write(f"#!{sys.executable}\n")
1617
f.write("import sys\n")
1718
f.write("\n")
1819
for line in udf_body.split("\n"):
@@ -25,38 +26,81 @@ def generate_udf(func_name, args, return_type, udf_body):
2526
f.write(f" {arg} = args[{i}]\n")
2627
f.write(f" print({func_name}({', '.join(args)}))\n")
2728
f.write(" sys.stdout.flush()\n")
28-
os.chmod(f"{func_name}.py", 0o755)
29+
os.chmod(f"{chdb.g_udf_path}/{func_name}.py", 0o755)
2930
# generate xml file
30-
xml_file = "udf_config.xml"
31-
root = ET.Element('functions')
31+
xml_file = f"{chdb.g_udf_path}/udf_config.xml"
32+
root = ET.Element("functions")
3233
if os.path.exists(xml_file):
3334
tree = ET.parse(xml_file)
3435
root = tree.getroot()
35-
function = ET.SubElement(root, 'function')
36-
ET.SubElement(function, 'type').text = 'executable'
37-
ET.SubElement(function, 'name').text = func_name
38-
ET.SubElement(function, 'return_type').text = return_type
39-
ET.SubElement(function, 'format').text = 'TabSeparated'
40-
ET.SubElement(function, 'command').text = f"{func_name}.py"
36+
function = ET.SubElement(root, "function")
37+
ET.SubElement(function, "type").text = "executable"
38+
ET.SubElement(function, "name").text = func_name
39+
ET.SubElement(function, "return_type").text = return_type
40+
ET.SubElement(function, "format").text = "TabSeparated"
41+
ET.SubElement(function, "command").text = f"{func_name}.py"
4142
for arg in args:
42-
argument = ET.SubElement(function, 'argument')
43+
argument = ET.SubElement(function, "argument")
4344
# We use TabSeparated format, so assume all arguments are strings
44-
ET.SubElement(argument, 'type').text = 'String'
45-
ET.SubElement(argument, 'name').text = arg
45+
ET.SubElement(argument, "type").text = "String"
46+
ET.SubElement(argument, "name").text = arg
4647
tree = ET.ElementTree(root)
4748
tree.write(xml_file)
4849

49-
def to_clickhouse_udf(return_type="String"):
50+
51+
def chdb_udf(return_type="String"):
52+
"""
53+
Decorator for chDB Python UDF(User Defined Function).
54+
1. The function should be stateless. So, only UDFs are supported, not UDAFs(User Defined Aggregation Function).
55+
2. Default return type is String. If you want to change the return type, you can pass in the return type as an argument.
56+
The return type should be one of the following: https://clickhouse.com/docs/en/sql-reference/data-types
57+
3. The function should take in arguments of type String. As the input is TabSeparated, all arguments are strings.
58+
4. The function will be called for each line of input. Something like this:
59+
```
60+
def sum_udf(lhs, rhs):
61+
return int(lhs) + int(rhs)
62+
63+
for line in sys.stdin:
64+
args = line.strip().split('\t')
65+
lhs = args[0]
66+
rhs = args[1]
67+
print(sum_udf(lhs, rhs))
68+
sys.stdout.flush()
69+
```
70+
5. The function should be pure python function. You SHOULD import all python modules used IN THE FUNCTION.
71+
```
72+
def func_use_json(arg):
73+
import json
74+
...
75+
```
76+
6. Python interpertor used is the same as the one used to run the script. Get from `sys.executable`
77+
"""
78+
5079
def decorator(func):
5180
func_name = func.__name__
5281
sig = inspect.signature(func)
5382
args = list(sig.parameters.keys())
5483
src = inspect.getsource(func)
55-
udf_body = src.split("\n", 1)[1]
84+
src = textwrap.dedent(src)
85+
udf_body = src.split("\n", 1)[1] # remove the first line "@chdb_udf()"
86+
# create tmp dir and make sure the dir is deleted when the process exits
87+
if chdb.g_udf_path == "":
88+
chdb.g_udf_path = tempfile.mkdtemp()
89+
90+
# clean up the tmp dir on exit
91+
@atexit.register
92+
def _cleanup():
93+
try:
94+
shutil.rmtree(chdb.g_udf_path)
95+
except:
96+
pass
97+
5698
generate_udf(func_name, args, return_type, udf_body)
99+
57100
@functools.wraps(func)
58101
def wrapper(*args, **kwargs):
59102
return func(*args, **kwargs)
103+
60104
return wrapper
61-
return decorator
62105

106+
return decorator

0 commit comments

Comments
 (0)