Skip to content

Commit 5aacb41

Browse files
committed
Add shared PyArrowArray alias and refine ScalarUDFs
Introduce a shared alias for PyArrowArray and update the extension wrapping helpers to ensure scalar UDF return types are preserved when handling PyArrow arrays. Enhance ScalarUDF signatures, overloads, and documentation to align with the PyArrow array contract for Python scalar UDFs.
1 parent 6a89977 commit 5aacb41

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

python/datafusion/user_defined.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import Any, Callable, Optional, Protocol, Sequence, overload
25+
from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, cast, overload
2626

2727
import pyarrow as pa
2828

2929
import datafusion._internal as df_internal
3030
from datafusion.expr import Expr
3131

32+
PyArrowArray = pa.Array | pa.ChunkedArray
33+
# Type alias for array batches exchanged with Python scalar UDFs.
34+
PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray)
35+
3236
class Volatility(Enum):
3337
"""Defines how stable or volatile a function is.
3438
@@ -113,7 +117,7 @@ def _normalize_return_field(
113117
return _normalize_field(value, default_name=default_name)
114118

115119

116-
def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any:
120+
def _wrap_extension_value(value: PyArrowArrayT, data_type: pa.DataType) -> PyArrowArrayT:
117121
storage_type = getattr(data_type, "storage_type", None)
118122
wrap_array = getattr(data_type, "wrap_array", None)
119123
if storage_type is None or wrap_array is None:
@@ -127,17 +131,20 @@ def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any:
127131

128132

129133
def _wrap_udf_function(
130-
func: Callable[..., Any],
134+
func: Callable[..., PyArrowArrayT],
131135
input_fields: Sequence[pa.Field],
132136
return_field: pa.Field,
133-
) -> Callable[..., Any]:
134-
def wrapper(*args: Any, **kwargs: Any) -> Any:
137+
) -> Callable[..., PyArrowArrayT]:
138+
def wrapper(*args: Any, **kwargs: Any) -> PyArrowArrayT:
135139
if args:
136-
converted_args = list(args)
140+
converted_args: list[Any] = list(args)
137141
for idx, field in enumerate(input_fields):
138142
if idx >= len(converted_args):
139143
break
140-
converted_args[idx] = _wrap_extension_value(converted_args[idx], field.type)
144+
converted_args[idx] = _wrap_extension_value(
145+
cast(PyArrowArray, converted_args[idx]),
146+
field.type,
147+
)
141148
else:
142149
converted_args = []
143150
result = func(*converted_args, **kwargs)
@@ -162,7 +169,7 @@ class ScalarUDF:
162169
def __init__(
163170
self,
164171
name: str,
165-
func: Callable[..., Any],
172+
func: Callable[..., PyArrowArray] | ScalarUDFExportable,
166173
input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
167174
return_type: pa.DataType | pa.Field,
168175
volatility: Volatility | str,
@@ -201,12 +208,12 @@ def udf(
201208
return_type: pa.DataType | pa.Field,
202209
volatility: Volatility | str,
203210
name: Optional[str] = None,
204-
) -> Callable[..., ScalarUDF]: ...
211+
) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]: ...
205212

206213
@overload
207214
@staticmethod
208215
def udf(
209-
func: Callable[..., Any],
216+
func: Callable[..., PyArrowArray],
210217
input_types: list[pa.DataType | pa.Field],
211218
return_type: pa.DataType | pa.Field,
212219
volatility: Volatility | str,
@@ -234,6 +241,8 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
234241
backed ScalarUDF within a PyCapsule, you can pass this parameter
235242
and ignore the rest. They will be determined directly from the
236243
underlying function. See the online documentation for more information.
244+
The callable should accept and return :class:`pyarrow.Array` or
245+
:class:`pyarrow.ChunkedArray` values.
237246
input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
238247
This list must be of the same length as the number of arguments. Pass
239248
:class:`pyarrow.Field` instances to preserve extension metadata.
@@ -261,7 +270,7 @@ def double_udf(x):
261270
"""
262271

263272
def _function(
264-
func: Callable[..., Any],
273+
func: Callable[..., PyArrowArray],
265274
input_types: list[pa.DataType | pa.Field],
266275
return_type: pa.DataType | pa.Field,
267276
volatility: Volatility | str,
@@ -288,14 +297,14 @@ def _decorator(
288297
return_type: pa.DataType | pa.Field,
289298
volatility: Volatility | str,
290299
name: Optional[str] = None,
291-
) -> Callable:
292-
def decorator(func: Callable):
300+
) -> Callable[[Callable[..., PyArrowArray]], Callable[..., Expr]]:
301+
def decorator(func: Callable[..., PyArrowArray]) -> Callable[..., Expr]:
293302
udf_caller = ScalarUDF.udf(
294303
func, input_types, return_type, volatility, name
295304
)
296305

297306
@functools.wraps(func)
298-
def wrapper(*args: Any, **kwargs: Any):
307+
def wrapper(*args: Any, **kwargs: Any) -> Expr:
299308
return udf_caller(*args, **kwargs)
300309

301310
return wrapper

0 commit comments

Comments
 (0)