11import functools
22import inspect
33import os
4+ import sys
45import tempfile
6+ import atexit
7+ import shutil
8+ import textwrap
59from xml .etree import ElementTree as ET
10+ import chdb
611
7- tempdir = tempfile .TemporaryDirectory ()
8- # os.chdir(tempdir.name)
9- os .chdir ("user_scripts" )
10- # print(f"Current working directory: {os.getcwd()}")
1112
1213def generate_udf (func_name , args , return_type , udf_body ):
1314 # generate python script
14- with open (f"{ func_name } .py" , "w" ) as f :
15- f .write ("#!/usr/bin/python3 \n " )
15+ with open (f"{ chdb . g_udf_path } / { func_name } .py" , "w" ) as f :
16+ f .write (f "#!{ sys . executable } \n " )
1617 f .write ("import sys\n " )
1718 f .write ("\n " )
1819 for line in udf_body .split ("\n " ):
@@ -25,38 +26,81 @@ def generate_udf(func_name, args, return_type, udf_body):
2526 f .write (f" { arg } = args[{ i } ]\n " )
2627 f .write (f" print({ func_name } ({ ', ' .join (args )} ))\n " )
2728 f .write (" sys.stdout.flush()\n " )
28- os .chmod (f"{ func_name } .py" , 0o755 )
29+ os .chmod (f"{ chdb . g_udf_path } / { func_name } .py" , 0o755 )
2930 # generate xml file
30- xml_file = " udf_config.xml"
31- root = ET .Element (' functions' )
31+ xml_file = f" { chdb . g_udf_path } / udf_config.xml"
32+ root = ET .Element (" functions" )
3233 if os .path .exists (xml_file ):
3334 tree = ET .parse (xml_file )
3435 root = tree .getroot ()
35- function = ET .SubElement (root , ' function' )
36- ET .SubElement (function , ' type' ).text = ' executable'
37- ET .SubElement (function , ' name' ).text = func_name
38- ET .SubElement (function , ' return_type' ).text = return_type
39- ET .SubElement (function , ' format' ).text = ' TabSeparated'
40- ET .SubElement (function , ' command' ).text = f"{ func_name } .py"
36+ function = ET .SubElement (root , " function" )
37+ ET .SubElement (function , " type" ).text = " executable"
38+ ET .SubElement (function , " name" ).text = func_name
39+ ET .SubElement (function , " return_type" ).text = return_type
40+ ET .SubElement (function , " format" ).text = " TabSeparated"
41+ ET .SubElement (function , " command" ).text = f"{ func_name } .py"
4142 for arg in args :
42- argument = ET .SubElement (function , ' argument' )
43+ argument = ET .SubElement (function , " argument" )
4344 # We use TabSeparated format, so assume all arguments are strings
44- ET .SubElement (argument , ' type' ).text = ' String'
45- ET .SubElement (argument , ' name' ).text = arg
45+ ET .SubElement (argument , " type" ).text = " String"
46+ ET .SubElement (argument , " name" ).text = arg
4647 tree = ET .ElementTree (root )
4748 tree .write (xml_file )
4849
49- def to_clickhouse_udf (return_type = "String" ):
50+
51+ def chdb_udf (return_type = "String" ):
52+ """
53+ Decorator for chDB Python UDF(User Defined Function).
54+ 1. The function should be stateless. So, only UDFs are supported, not UDAFs(User Defined Aggregation Function).
55+ 2. Default return type is String. If you want to change the return type, you can pass in the return type as an argument.
56+ The return type should be one of the following: https://clickhouse.com/docs/en/sql-reference/data-types
57+ 3. The function should take in arguments of type String. As the input is TabSeparated, all arguments are strings.
58+ 4. The function will be called for each line of input. Something like this:
59+ ```
60+ def sum_udf(lhs, rhs):
61+ return int(lhs) + int(rhs)
62+
63+ for line in sys.stdin:
64+ args = line.strip().split('\t ')
65+ lhs = args[0]
66+ rhs = args[1]
67+ print(sum_udf(lhs, rhs))
68+ sys.stdout.flush()
69+ ```
70+ 5. The function should be pure python function. You SHOULD import all python modules used IN THE FUNCTION.
71+ ```
72+ def func_use_json(arg):
73+ import json
74+ ...
75+ ```
76+ 6. Python interpertor used is the same as the one used to run the script. Get from `sys.executable`
77+ """
78+
5079 def decorator (func ):
5180 func_name = func .__name__
5281 sig = inspect .signature (func )
5382 args = list (sig .parameters .keys ())
5483 src = inspect .getsource (func )
55- udf_body = src .split ("\n " , 1 )[1 ]
84+ src = textwrap .dedent (src )
85+ udf_body = src .split ("\n " , 1 )[1 ] # remove the first line "@chdb_udf()"
86+ # create tmp dir and make sure the dir is deleted when the process exits
87+ if chdb .g_udf_path == "" :
88+ chdb .g_udf_path = tempfile .mkdtemp ()
89+
90+ # clean up the tmp dir on exit
91+ @atexit .register
92+ def _cleanup ():
93+ try :
94+ shutil .rmtree (chdb .g_udf_path )
95+ except :
96+ pass
97+
5698 generate_udf (func_name , args , return_type , udf_body )
99+
57100 @functools .wraps (func )
58101 def wrapper (* args , ** kwargs ):
59102 return func (* args , ** kwargs )
103+
60104 return wrapper
61- return decorator
62105
106+ return decorator
0 commit comments