Skip to content

Commit 3a4797f

Browse files
authored
Add validation of UDF parameters and return types (#83)
* Add validation of UDF parameters and return types * Fix polars issues
1 parent 3a91a21 commit 3a4797f

File tree

4 files changed

+409
-277
lines changed

4 files changed

+409
-277
lines changed

singlestoredb/functions/decorator.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,20 @@ def is_valid_type(obj: Any) -> bool:
4545
return False
4646

4747

48-
def is_valid_callable(obj: Any) -> bool:
48+
def is_sqlstr_callable(obj: Any) -> bool:
4949
"""Check if the object is a valid callable for a parameter type."""
5050
if not callable(obj):
5151
return False
5252

5353
returns = utils.get_annotations(obj).get('return', None)
5454

55-
if inspect.isclass(returns) and issubclass(returns, str):
55+
if inspect.isclass(returns) and issubclass(returns, SQLString):
5656
return True
5757

58-
raise TypeError(
59-
f'callable {obj} must return a str, '
60-
f'but got {returns}',
61-
)
58+
return False
6259

6360

64-
def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
61+
def expand_types(args: Any) -> Optional[List[Any]]:
6562
"""Expand the types for the function arguments / return values."""
6663
if args is None:
6764
return None
@@ -70,28 +67,32 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
7067
if isinstance(args, str):
7168
return [args]
7269

73-
# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
74-
elif is_valid_type(args):
75-
return args
76-
7770
# List of SQL strings or callables
7871
elif isinstance(args, list):
79-
new_args = []
72+
new_args: List[Any] = []
8073
for arg in args:
8174
if isinstance(arg, str):
8275
new_args.append(arg)
83-
elif callable(arg):
76+
elif is_sqlstr_callable(arg):
8477
new_args.append(arg())
78+
elif type(arg) is type:
79+
new_args.append(arg)
80+
elif is_valid_type(arg):
81+
new_args.append(arg)
8582
else:
8683
raise TypeError(f'unrecognized type for parameter: {arg}')
8784
return new_args
8885

8986
# Callable that returns a SQL string
90-
elif is_valid_callable(args):
91-
out = args()
92-
if not isinstance(out, str):
93-
raise TypeError(f'unrecognized type for parameter: {args}')
94-
return [out]
87+
elif is_sqlstr_callable(args):
88+
return [args()]
89+
90+
# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
91+
elif is_valid_type(args):
92+
return [args]
93+
94+
elif type(args) is type:
95+
return [args]
9596

9697
raise TypeError(f'unrecognized type for parameter: {args}')
9798

0 commit comments

Comments
 (0)