From 6ca886cf1ca4bc4f8d6020877d330cf898e66851 Mon Sep 17 00:00:00 2001 From: Kevin Smith Date: Thu, 5 Jun 2025 13:45:51 -0500 Subject: [PATCH 1/3] Add async support --- singlestoredb/functions/decorator.py | 40 ++++++--- singlestoredb/functions/ext/asgi.py | 122 ++++++++++++++++++++++++--- 2 files changed, 137 insertions(+), 25 deletions(-) diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index 2280ed401..c95dcf2de 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -1,3 +1,4 @@ +import asyncio import functools import inspect from typing import Any @@ -19,6 +20,7 @@ ] ReturnType = ParameterType +UDFType = Callable[..., Any] def is_valid_type(obj: Any) -> bool: @@ -100,7 +102,8 @@ def _func( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, -) -> Callable[..., Any]: + timeout: Optional[int] = None, +) -> UDFType: """Generic wrapper for UDF and TVF decorators.""" _singlestoredb_attrs = { # type: ignore @@ -115,23 +118,33 @@ def _func( # called later, so the wrapper much be created with the func passed # in at that time. if func is None: - def decorate(func: Callable[..., Any]) -> Callable[..., Any]: + def decorate(func: UDFType) -> UDFType: - def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: - return func(*args, **kwargs) # type: ignore + if asyncio.iscoroutinefunction(func): + async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType: + return await func(*args, **kwargs) # type: ignore + async_wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore + return functools.wraps(func)(async_wrapper) - wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore - - return functools.wraps(func)(wrapper) + else: + def wrapper(*args: Any, **kwargs: Any) -> UDFType: + return func(*args, **kwargs) # type: ignore + wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore + return functools.wraps(func)(wrapper) return decorate - def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]: - return func(*args, **kwargs) # type: ignore - - wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore + if asyncio.iscoroutinefunction(func): + async def async_wrapper(*args: Any, **kwargs: Any) -> UDFType: + return await func(*args, **kwargs) # type: ignore + async_wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore + return functools.wraps(func)(async_wrapper) - return functools.wraps(func)(wrapper) + else: + def wrapper(*args: Any, **kwargs: Any) -> UDFType: + return func(*args, **kwargs) # type: ignore + wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore + return functools.wraps(func)(wrapper) def udf( @@ -140,7 +153,8 @@ def udf( name: Optional[str] = None, args: Optional[ParameterType] = None, returns: Optional[ReturnType] = None, -) -> Callable[..., Any]: + timeout: Optional[int] = None, +) -> UDFType: """ Define a user-defined function (UDF). diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 97997ffc9..041c119de 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -273,12 +273,24 @@ def build_udf_endpoint( """ if returns_data_format in ['scalar', 'list']: + is_async = asyncio.iscoroutinefunction(func) + async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' - return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))] + out = [] + for row in rows: + if cancel_event.is_set(): + raise asyncio.CancelledError( + 'Function call was cancelled', + ) + if is_async: + out.append(await func(*row)) + else: + out.append(func(*row)) + return row_ids, list(zip(out)) return do_func @@ -307,6 +319,7 @@ def build_vector_udf_endpoint( """ masks = get_masked_params(func) array_cls = get_array_class(returns_data_format) + is_async = asyncio.iscoroutinefunction(func) async def do_func( row_ids: Sequence[int], @@ -320,9 +333,15 @@ async def do_func( # Call the function with `cols` as the function parameters if cols and cols[0]: - out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) + if is_async: + out = await func(*[x if m else x[0] for x, m in zip(cols, masks)]) + else: + out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) else: - out = func() + if is_async: + out = await func() + else: + out = func() # Single masked value if isinstance(out, Masked): @@ -360,6 +379,8 @@ def build_tvf_endpoint( """ if returns_data_format in ['scalar', 'list']: + is_async = asyncio.iscoroutinefunction(func) + async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], @@ -368,7 +389,15 @@ async def do_func( out_ids: List[int] = [] out = [] # Call function on each row of data - for i, res in zip(row_ids, func_map(func, rows)): + for i, row in zip(row_ids, rows): + if cancel_event.is_set(): + raise asyncio.CancelledError( + 'Function call was cancelled', + ) + if is_async: + res = await func(*row) + else: + res = func(*row) out.extend(as_list_of_tuples(res)) out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out @@ -413,13 +442,23 @@ async def do_func( # each result row, so we just have to use the same # row ID for all rows in the result. + is_async = asyncio.iscoroutinefunction(func) + # Call function on each column of data if cols and cols[0]: - res = get_dataframe_columns( - func(*[x if m else x[0] for x, m in zip(cols, masks)]), - ) + if is_async: + res = get_dataframe_columns( + await func(*[x if m else x[0] for x, m in zip(cols, masks)]), + ) + else: + res = get_dataframe_columns( + func(*[x if m else x[0] for x, m in zip(cols, masks)]), + ) else: - res = get_dataframe_columns(func()) + if is_async: + res = get_dataframe_columns(await func()) + else: + res = get_dataframe_columns(func()) # Generate row IDs if isinstance(res[0], Masked): @@ -477,6 +516,12 @@ def make_func( # Set function type info['function_type'] = function_type + # Set timeout + info['timeout'] = max(timeout, 1) + + # Set async flag + info['is_async'] = asyncio.iscoroutinefunction(func) + # Setup argument types for rowdat_1 parser colspec = [] for x in sig['args']: @@ -859,11 +904,64 @@ async def __call__( output_handler = self.handlers[(accepts, data_version, returns_data_format)] try: - out = await func( - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), - ), + result = [] + + cancel_event = threading.Event() + + if func_info['is_async']: + func_task = asyncio.create_task( + func( + cancel_event, + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), + ), + ) + else: + func_task = asyncio.create_task( + to_thread( + lambda: asyncio.run( + func( + cancel_event, + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), + ), + ), + ), + ) + disconnect_task = asyncio.create_task( + cancel_on_disconnect(receive), + ) + timeout_task = asyncio.create_task( + cancel_on_timeout(func_info['timeout']), ) + + all_tasks = [func_task, disconnect_task, timeout_task] + + done, pending = await asyncio.wait( + all_tasks, return_when=asyncio.FIRST_COMPLETED, + ) + + cancel_all_tasks(pending) + + for task in done: + if task is disconnect_task: + cancel_event.set() + raise asyncio.CancelledError( + 'Function call was cancelled by client disconnect', + ) + + elif task is timeout_task: + cancel_event.set() + raise asyncio.TimeoutError( + 'Function call was cancelled due to timeout', + ) + + elif task is func_task: + result.extend(task.result()) + + print(result) body = output_handler['dump']( [x[1] for x in func_info['returns']], *out, # type: ignore ) From cccd5f85e2feffa7415896d72763544298c1dbb4 Mon Sep 17 00:00:00 2001 From: Kaushik Kampli Date: Tue, 8 Jul 2025 15:23:15 +0530 Subject: [PATCH 2/3] remove timeout --- singlestoredb/functions/ext/asgi.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 041c119de..e6bf5cad5 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -516,9 +516,6 @@ def make_func( # Set function type info['function_type'] = function_type - # Set timeout - info['timeout'] = max(timeout, 1) - # Set async flag info['is_async'] = asyncio.iscoroutinefunction(func) From 459d8ab7a90812bccbc165d4bd963998ec0aa66a Mon Sep 17 00:00:00 2001 From: Kaushik Kampli Date: Mon, 14 Jul 2025 11:58:33 +0530 Subject: [PATCH 3/3] interactive --- singlestoredb/functions/ext/asgi.py | 137 +++++++--------------------- 1 file changed, 34 insertions(+), 103 deletions(-) diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index e6bf5cad5..264bd0d67 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -66,6 +66,9 @@ from ..signature import signature_to_sql from ..typing import Masked from ..typing import Table +from ...config import get_option +from singlestoredb.connection import build_params + try: import cloudpickle @@ -273,24 +276,12 @@ def build_udf_endpoint( """ if returns_data_format in ['scalar', 'list']: - is_async = asyncio.iscoroutinefunction(func) - async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]: '''Call function on given rows of data.''' - out = [] - for row in rows: - if cancel_event.is_set(): - raise asyncio.CancelledError( - 'Function call was cancelled', - ) - if is_async: - out.append(await func(*row)) - else: - out.append(func(*row)) - return row_ids, list(zip(out)) + return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))] return do_func @@ -319,7 +310,6 @@ def build_vector_udf_endpoint( """ masks = get_masked_params(func) array_cls = get_array_class(returns_data_format) - is_async = asyncio.iscoroutinefunction(func) async def do_func( row_ids: Sequence[int], @@ -332,17 +322,18 @@ async def do_func( row_ids = array_cls(row_ids) # Call the function with `cols` as the function parameters + is_async = inspect.iscoroutinefunction(func) or inspect.iscoroutinefunction(getattr(func, "__wrapped__", None)) if cols and cols[0]: if is_async: out = await func(*[x if m else x[0] for x, m in zip(cols, masks)]) else: - out = func(*[x if m else x[0] for x, m in zip(cols, masks)]) + out = await asyncio.to_thread(func, *[x if m else x[0] for x, m in zip(cols, masks)]) else: if is_async: out = await func() else: - out = func() - + out = await asyncio.to_thread(func()) + # Single masked value if isinstance(out, Masked): return row_ids, [tuple(out)] @@ -379,8 +370,6 @@ def build_tvf_endpoint( """ if returns_data_format in ['scalar', 'list']: - is_async = asyncio.iscoroutinefunction(func) - async def do_func( row_ids: Sequence[int], rows: Sequence[Sequence[Any]], @@ -389,15 +378,7 @@ async def do_func( out_ids: List[int] = [] out = [] # Call function on each row of data - for i, row in zip(row_ids, rows): - if cancel_event.is_set(): - raise asyncio.CancelledError( - 'Function call was cancelled', - ) - if is_async: - res = await func(*row) - else: - res = func(*row) + for i, res in zip(row_ids, func_map(func, rows)): out.extend(as_list_of_tuples(res)) out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out @@ -442,23 +423,13 @@ async def do_func( # each result row, so we just have to use the same # row ID for all rows in the result. - is_async = asyncio.iscoroutinefunction(func) - # Call function on each column of data if cols and cols[0]: - if is_async: - res = get_dataframe_columns( - await func(*[x if m else x[0] for x, m in zip(cols, masks)]), - ) - else: - res = get_dataframe_columns( - func(*[x if m else x[0] for x, m in zip(cols, masks)]), - ) + res = get_dataframe_columns( + func(*[x if m else x[0] for x, m in zip(cols, masks)]), + ) else: - if is_async: - res = get_dataframe_columns(await func()) - else: - res = get_dataframe_columns(func()) + res = get_dataframe_columns(func()) # Generate row IDs if isinstance(res[0], Masked): @@ -516,9 +487,6 @@ def make_func( # Set function type info['function_type'] = function_type - # Set async flag - info['is_async'] = asyncio.iscoroutinefunction(func) - # Setup argument types for rowdat_1 parser colspec = [] for x in sig['args']: @@ -901,64 +869,11 @@ async def __call__( output_handler = self.handlers[(accepts, data_version, returns_data_format)] try: - result = [] - - cancel_event = threading.Event() - - if func_info['is_async']: - func_task = asyncio.create_task( - func( - cancel_event, - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), - ), - ), - ) - else: - func_task = asyncio.create_task( - to_thread( - lambda: asyncio.run( - func( - cancel_event, - *input_handler['load']( # type: ignore - func_info['colspec'], b''.join(data), - ), - ), - ), - ), - ) - disconnect_task = asyncio.create_task( - cancel_on_disconnect(receive), - ) - timeout_task = asyncio.create_task( - cancel_on_timeout(func_info['timeout']), - ) - - all_tasks = [func_task, disconnect_task, timeout_task] - - done, pending = await asyncio.wait( - all_tasks, return_when=asyncio.FIRST_COMPLETED, + out = await func( + *input_handler['load']( # type: ignore + func_info['colspec'], b''.join(data), + ), ) - - cancel_all_tasks(pending) - - for task in done: - if task is disconnect_task: - cancel_event.set() - raise asyncio.CancelledError( - 'Function call was cancelled by client disconnect', - ) - - elif task is timeout_task: - cancel_event.set() - raise asyncio.TimeoutError( - 'Function call was cancelled due to timeout', - ) - - elif task is func_task: - result.extend(task.result()) - - print(result) body = output_handler['dump']( [x[1] for x in func_info['returns']], *out, # type: ignore ) @@ -1066,6 +981,19 @@ def get_function_info( sig = info['signature'] sql_map[sig['name']] = sql + if 'SINGLESTOREDB_URL' in os.environ: + dbname = build_params(host=os.environ['SINGLESTOREDB_URL']).get('database') + elif 'SINGLESTOREDB_HOST' in os.environ: + dbname = build_params(host=os.environ['SINGLESTOREDB_HOST']).get('database') + elif 'SINGLESTOREDB_DATABASE' in os.environ: + dbname = os.environ['SINGLESTOREDB_DATBASE'] + + connection_info = {} + workspace_group_id = os.environ.get('SINGLESTOREDB_WORKSPACE_GROUP') + connection_info['database_name'] = dbname + connection_info['workspace_group_id'] = workspace_group_id + + for key, (_, info) in self.endpoints.items(): if not func_name or key == func_name: sig = info['signature'] @@ -1111,7 +1039,10 @@ def get_function_info( sql_statement=sql, ) - return functions + return { + 'functions': functions, + 'connection_info': connection_info + } def get_create_functions( self,