
Commit fd48684

Implement transformers for UDF inputs and outputs
1 parent 19a674a · commit fd48684

13 files changed (+1182, -98 lines)

singlestoredb/functions/ext/asgi.py

Lines changed: 11 additions & 8 deletions
@@ -311,6 +311,7 @@ def cancel_on_event(
 
 def build_udf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -352,11 +353,12 @@ async def do_func(
 
         return do_func
 
-    return build_vector_udf_endpoint(func, returns_data_format)
+    return build_vector_udf_endpoint(func, args_data_format, returns_data_format)
 
 
 def build_vector_udf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -422,6 +424,7 @@ async def do_func(
 
 def build_tvf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -451,27 +454,27 @@ async def do_func(
         rows: Sequence[Sequence[Any]],
     ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
         '''Call function on given rows of data.'''
-        out_ids: List[int] = []
-        out = []
+        out: List[Tuple[Any, ...]] = []
         # Call function on each row of data
         async with timer('call_function'):
+            out = []
             for i, row in zip(row_ids, rows):
                 cancel_on_event(cancel_event)
                 if is_async:
                     res = await func(*row)
                 else:
                     res = func(*row)
                 out.extend(as_list_of_tuples(res))
-            out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
-        return out_ids, out
+        return [row_ids[0]] * len(out), out
 
         return do_func
 
-    return build_vector_tvf_endpoint(func, returns_data_format)
+    return build_vector_tvf_endpoint(func, args_data_format, returns_data_format)
 
 
 def build_vector_tvf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -575,9 +578,9 @@ def make_func(
     )
 
     if function_type == 'tvf':
-        do_func = build_tvf_endpoint(func, returns_data_format)
+        do_func = build_tvf_endpoint(func, args_data_format, returns_data_format)
     else:
-        do_func = build_udf_endpoint(func, returns_data_format)
+        do_func = build_udf_endpoint(func, args_data_format, returns_data_format)
 
     do_func.__name__ = name
     do_func.__doc__ = func.__doc__
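
The TVF hunk above also carries a behavioral fix: the old code paired outputs with input ids via out_ids.extend([row_ids[i]] * ...), but i comes from zip(row_ids, rows) and is therefore a row id, not an index. The new code tags every output row in the batch with the first input row id. A minimal sketch of the new pairing logic, with as_list_of_tuples stubbed in as an assumption since its definition is not part of this diff:

from typing import Any, Callable, List, Sequence, Tuple

def as_list_of_tuples(res: Any) -> List[Tuple[Any, ...]]:
    # Assumed behavior: normalize one function result into row tuples.
    if isinstance(res, list):
        return [r if isinstance(r, tuple) else (r,) for r in res]
    return [res if isinstance(res, tuple) else (res,)]

def pair_outputs(
    func: Callable[..., Any],
    row_ids: Sequence[int],
    rows: Sequence[Sequence[Any]],
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
    out: List[Tuple[Any, ...]] = []
    for row in rows:
        out.extend(as_list_of_tuples(func(*row)))
    # Mirrors `return [row_ids[0]] * len(out), out` in the new do_func.
    return [row_ids[0]] * len(out), out

ids, out = pair_outputs(lambda s: [(s, 1), (s, 2)], [7], [('a',)])
print(ids, out)   # [7, 7] [('a', 1), ('a', 2)]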

singlestoredb/functions/ext/rowdat_1.py

Lines changed: 26 additions & 9 deletions
@@ -462,7 +462,7 @@ def _dump_vectors(
             default = DEFAULT_VALUES[rtype]
             try:
                 if rtype in numeric_formats:
-                    if value is None:
+                    if is_null or value is None:
                         out.write(struct.pack(numeric_formats[rtype], default))
                     else:
                         if rtype in int_types:
@@ -486,14 +486,14 @@ def _dump_vectors(
                             ),
                         )
                 elif rtype in string_types:
-                    if value is None:
+                    if is_null or value is None:
                         out.write(struct.pack('<q', 0))
                     else:
                         sval = value.encode('utf-8')
                         out.write(struct.pack('<q', len(sval)))
                         out.write(sval)
                 elif rtype in binary_types:
-                    if value is None:
+                    if is_null or value is None:
                         out.write(struct.pack('<q', 0))
                     else:
                         out.write(struct.pack('<q', len(value)))
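
The is_null or value is None checks let _dump_vectors honor an explicit null mask in addition to plain None values. The pattern, reduced to the numeric branch (the format string and default here are illustrative assumptions, not the module's tables):

import io
import struct
from typing import Optional

def write_bigint(out: io.BytesIO, value: Optional[int], is_null: bool) -> None:
    # A masked-out slot serializes as the type's default (0 here),
    # even if a stale value is present alongside the mask.
    if is_null or value is None:
        out.write(struct.pack('<q', 0))
    else:
        out.write(struct.pack('<q', value))

buf = io.BytesIO()
write_bigint(buf, 123, is_null=True)    # writes 0, not 123
write_bigint(buf, 456, is_null=False)   # writes 456
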
@@ -571,8 +571,18 @@ def _load_numpy_accel(
 
     for i, (_, dtype, transformer) in enumerate(colspec):
         if transformer is not None:
-            t = np.vectorize(transformer)
-            numpy_cols[i] = (t(numpy_cols[i][0]), numpy_cols[i][1])
+            # Numpy will try to be "helpful" and create multidimensional arrays
+            # from nested iterables. We don't usually want that. What we want is
+            # numpy arrays of Python objects (e.g., lists, dicts, etc). To do that,
+            # we have to create an empty array of the correct length and dtype=object,
+            # then fill it in with the transformed values. The transformer may have
+            # an output_type attribute that we can use to create a more specific type.
+            if getattr(transformer, 'output_type', None):
+                new_col = np.empty(len(numpy_cols[i][0]), dtype=transformer.output_type)
+                new_col[:] = list(map(transformer, numpy_cols[i][0]))
+            else:
+                new_col = np.array(list(map(transformer, numpy_cols[i][0])))
+            numpy_cols[i] = (new_col, numpy_cols[i][1])
 
     return numpy_ids, numpy_cols
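
The comment added to _load_numpy_accel is easy to verify in isolation. A standalone sketch (plain numpy, none of the module's own code) of why np.array on nested values is the wrong tool here and what the empty-then-fill idiom buys:

import numpy as np

values = [[1, 2], [3, 4]]       # e.g., transformer outputs that are lists

# np.array "helpfully" broadcasts nested iterables into a 2-D array...
print(np.array(values).shape)   # (2, 2)

# ...while empty-then-fill keeps one Python object per element.
col = np.empty(len(values), dtype=object)
col[:] = values
print(col.shape)                # (2,)
print(col[0])                   # [1, 2]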

@@ -589,8 +599,7 @@ def _dump_numpy_accel(
 
     for i, (_, dtype, transformer) in enumerate(returns):
         if transformer is not None:
-            t = np.vectorize(transformer)
-            cols[i] = (t(cols[i][0]), cols[i][1])
+            cols[i] = (np.array(list(map(transformer, cols[i][0]))), cols[i][1])
 
     return _singlestoredb_accel.dump_rowdat_1_numpy(returns, row_ids, cols)

@@ -678,10 +687,18 @@ def _dump_polars_accel(
     if not has_accel:
         raise RuntimeError('could not load SingleStoreDB extension')
 
+    import numpy as np
+    import polars as pl
+
     numpy_ids = row_ids.to_numpy()
     numpy_cols = [
         (
-            data.to_numpy(),
+            # Polars will try to be "helpful" and convert nested iterables into
+            # multidimensional arrays. We don't usually want that. What we want is
+            # numpy arrays of Python objects (e.g., lists, dicts, etc). To
+            # do that, we have to convert the Series to a list first.
+            np.array(data.to_list())
+            if isinstance(data.dtype, (pl.Struct, pl.Object)) else data.to_numpy(),
            mask.to_numpy() if mask is not None else None,
         )
         for data, mask in cols
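
Same story on the polars side: for Struct or Object columns, Series.to_numpy() would build a structured or multidimensional array, so the new code round-trips through a Python list first. A small sketch, assuming polars is installed:

import numpy as np
import polars as pl

s = pl.Series('doc', [{'x': 1}, {'x': 2}])   # dtype: Struct

# to_list() yields plain dicts; np.array over dicts gives a 1-D object array.
col = np.array(s.to_list())
print(col.dtype, col.shape)                  # object (2,)
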
@@ -722,7 +739,7 @@ def _create_arrow_mask(
     if mask is None:
         return data.is_null().to_numpy(zero_copy_only=False)
 
-    return pc.or_(data.is_null(), mask.is_null()).to_numpy(zero_copy_only=False)
+    return pc.or_(data.is_null(), mask).to_numpy(zero_copy_only=False)
 
 
 def _dump_arrow_accel(
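
The _create_arrow_mask fix stops treating the caller-supplied mask as a data column: it is already a boolean null mask, so it is OR'ed directly with the column's own null bitmap instead of calling .is_null() on it (a boolean mask has no nulls of its own, so the old code effectively discarded it). In isolation:

import pyarrow as pa
import pyarrow.compute as pc

data = pa.array([1, None, 3])
mask = pa.array([False, False, True])   # externally supplied boolean mask

# New behavior: OR the column's null bitmap with the mask itself.
nulls = pc.or_(data.is_null(), mask).to_numpy(zero_copy_only=False)
print(nulls)   # [False  True  True]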

singlestoredb/functions/ext/utils.py

Lines changed: 1 addition & 3 deletions
@@ -7,7 +7,6 @@
 import zipfile
 from copy import copy
 from typing import Any
-from typing import Callable
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -32,8 +31,7 @@ def formatMessage(self, record: logging.LogRecord) -> str:
         recordcopy.__dict__['levelprefix'] = levelname + ':' + seperator
         return super().formatMessage(recordcopy)
 
-
-Transformer = Callable[..., Any]
+from ..typing import Transformer
 
 
 def apply_transformer(func: Optional[Transformer], v: Any) -> Any:
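
For context, apply_transformer (whose signature survives in the context lines above) is the single point where a transformer touches a value, and Transformer is now the shared alias imported from ..typing instead of a local definition. A stand-in consistent with that signature; the real body is not shown in this diff, so treat it as an assumption:

import json
from typing import Any, Callable, Optional

Transformer = Callable[..., Any]

def apply_transformer(func: Optional[Transformer], v: Any) -> Any:
    # Assumed behavior: pass the value through unchanged when no
    # transformer is configured for the column.
    return func(v) if func is not None else v

# A hypothetical args transformer decoding a JSON string parameter.
print(apply_transformer(json.loads, '{"x": 1}'))   # {'x': 1}
print(apply_transformer(None, '{"x": 1}'))         # '{"x": 1}'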

singlestoredb/functions/signature.py

Lines changed: 43 additions & 7 deletions
@@ -921,6 +921,7 @@ def get_schema(
     spec: Any,
     overrides: Optional[List[ParamSpec]] = None,
     mode: str = 'parameter',
+    masks: Optional[List[bool]] = None,
 ) -> Tuple[List[ParamSpec], str, str]:
     """
     Expand a return type annotation into a list of types and field names.
@@ -933,6 +934,8 @@ def get_schema(
         List of SQL type specifications for the return type
     mode : str
         The mode of the function, either 'parameter' or 'return'
+    is_masked : bool
+        Whether the type is wrapped in a Masked type
 
     Returns
     -------
@@ -994,7 +997,13 @@ def get_schema(
                 'dataclass, TypedDict, or pydantic model',
             )
         spec = typing.get_args(unpacked_spec[0])[0]
-        data_format = 'list'
+        # Lists as output from TVFs are considered scalar outputs
+        # since they correspond to individual Python objects, not
+        # a true vector type.
+        if function_type == 'tvf':
+            data_format = 'scalar'
+        else:
+            data_format = 'list'
 
     elif all([utils.is_vector(x, include_masks=True) for x in unpacked_spec]):
         pass
@@ -1111,7 +1120,11 @@ def get_schema(
         _, inner_apply_meta = unpack_annotated(typing.get_args(spec)[0])
         if inner_apply_meta.sql_type:
             udf_attrs = inner_apply_meta
-            colspec = get_schema(typing.get_args(spec)[0], mode=mode)[0]
+            colspec = get_schema(
+                typing.get_args(spec)[0],
+                mode=mode,
+                masks=[masks[0]] if masks else None,
+            )[0]
         else:
             colspec = [
                 ParamSpec(
@@ -1142,6 +1155,7 @@ def get_schema(
             overrides=[overrides[i]] if overrides else [],
             # Always pass UDF mode for individual items
             mode=mode,
+            masks=[masks[i]] if masks else None,
         )
 
         # Use the name from the overrides if specified
@@ -1183,7 +1197,7 @@ def get_schema(
     out = []
 
     # Normalize colspec data types
-    for c in colspec:
+    for i, c in enumerate(colspec):
 
         # if the dtype is a string, it is resolved already
         if isinstance(c.dtype, str):
@@ -1201,13 +1215,27 @@ def get_schema(
                 include_null=c.is_optional,
             )
 
+        sql_type = c.sql_type if isinstance(c.sql_type, str) else udf_attrs.sql_type
+
+        is_optional = (
+            c.is_optional
+            or bool(dtype and dtype.endswith('?'))
+            or bool(masks and masks[i])
+        )
+
+        if is_optional:
+            if dtype and not dtype.endswith('?'):
+                dtype += '?'
+            if sql_type and re.search(r' NOT NULL\b', sql_type):
+                sql_type = re.sub(r' NOT NULL\b', r' NULL', sql_type)
+
         p = ParamSpec(
             name=c.name,
             dtype=dtype,
-            sql_type=c.sql_type if isinstance(c.sql_type, str) else udf_attrs.sql_type,
-            is_optional=c.is_optional or bool(dtype and dtype.endswith('?')),
-            transformer=udf_attrs.input_transformer
-            if mode == 'parameter' else udf_attrs.output_transformer,
+            sql_type=sql_type,
+            is_optional=is_optional,
+            transformer=udf_attrs.args_transformer
+            if mode == 'parameter' else udf_attrs.returns_transformer,
         )
 
         out.append(p)
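
The new block folds three nullability signals (an Optional annotation, a '?'-suffixed dtype, and a mask flag from Masked types) into one is_optional value, then rewrites any explicit SQL type to match. The rewrite on its own, with a hypothetical user-supplied type:

import re

sql_type = 'TEXT NOT NULL'   # hypothetical explicit SQL type
dtype = 'str'
is_optional = True           # e.g., the parameter was wrapped in Masked

if is_optional:
    if dtype and not dtype.endswith('?'):
        dtype += '?'                                        # 'str?'
    if sql_type and re.search(r' NOT NULL\b', sql_type):
        # Same substitution as the diff: flip the column to nullable.
        sql_type = re.sub(r' NOT NULL\b', r' NULL', sql_type)
print(dtype, '|', sql_type)  # str? | TEXT NULL
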
@@ -1345,6 +1373,7 @@ def get_signature(
             unpack_masked_type(param.annotation),
             overrides=[args_colspec[i]] if args_colspec else [],
             mode='parameter',
+            masks=[args_masks[i]] if args_masks else [],
         )
         args_data_formats.append(args_data_format)
 
@@ -1404,6 +1433,7 @@ def get_signature(
         unpack_masked_type(signature.return_annotation),
         overrides=returns_colspec if returns_colspec else None,
         mode='return',
+        masks=ret_masks or [],
     )
 
     rdf = out['returns_data_format'] = out['returns_data_format'] or 'scalar'
@@ -1419,6 +1449,12 @@ def get_signature(
                 'scalar or vector types.',
             )
 
+    # If we have function parameters and the function is a TVF, then
+    # the return type should just match the parameter vector types. This ensures
+    # the output producers for scalars and vectors are consistent.
+    elif function_type == 'tvf' and rdf == 'scalar' and args_schema:
+        out['returns_data_format'] = out['args_data_format']
+
     # All functions have to return a value, so if none was specified try to
     # insert a reasonable default that includes NULLs.
     if not ret_schema:
