2222import functools
2323from abc import ABCMeta , abstractmethod
2424from enum import Enum
25- from typing import Any , Callable , Optional , Protocol , Sequence , overload
25+ from typing import Any , Callable , Optional , Protocol , Sequence , TypeVar , cast , overload
2626
2727import pyarrow as pa
2828
2929import datafusion ._internal as df_internal
3030from datafusion .expr import Expr
3131
32+ PyArrowArray = pa .Array | pa .ChunkedArray
33+ # Type alias for array batches exchanged with Python scalar UDFs.
34+ PyArrowArrayT = TypeVar ("PyArrowArrayT" , pa .Array , pa .ChunkedArray )
35+
3236class Volatility (Enum ):
3337 """Defines how stable or volatile a function is.
3438
@@ -113,7 +117,7 @@ def _normalize_return_field(
113117 return _normalize_field (value , default_name = default_name )
114118
115119
116- def _wrap_extension_value (value : Any , data_type : pa .DataType ) -> Any :
120+ def _wrap_extension_value (value : PyArrowArrayT , data_type : pa .DataType ) -> PyArrowArrayT :
117121 storage_type = getattr (data_type , "storage_type" , None )
118122 wrap_array = getattr (data_type , "wrap_array" , None )
119123 if storage_type is None or wrap_array is None :
@@ -127,17 +131,20 @@ def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any:
127131
128132
129133def _wrap_udf_function (
130- func : Callable [..., Any ],
134+ func : Callable [..., PyArrowArrayT ],
131135 input_fields : Sequence [pa .Field ],
132136 return_field : pa .Field ,
133- ) -> Callable [..., Any ]:
134- def wrapper (* args : Any , ** kwargs : Any ) -> Any :
137+ ) -> Callable [..., PyArrowArrayT ]:
138+ def wrapper (* args : Any , ** kwargs : Any ) -> PyArrowArrayT :
135139 if args :
136- converted_args = list (args )
140+ converted_args : list [ Any ] = list (args )
137141 for idx , field in enumerate (input_fields ):
138142 if idx >= len (converted_args ):
139143 break
140- converted_args [idx ] = _wrap_extension_value (converted_args [idx ], field .type )
144+ converted_args [idx ] = _wrap_extension_value (
145+ cast (PyArrowArray , converted_args [idx ]),
146+ field .type ,
147+ )
141148 else :
142149 converted_args = []
143150 result = func (* converted_args , ** kwargs )
@@ -162,7 +169,7 @@ class ScalarUDF:
162169 def __init__ (
163170 self ,
164171 name : str ,
165- func : Callable [..., Any ] ,
172+ func : Callable [..., PyArrowArray ] | ScalarUDFExportable ,
166173 input_types : pa .DataType | pa .Field | Sequence [pa .DataType | pa .Field ],
167174 return_type : pa .DataType | pa .Field ,
168175 volatility : Volatility | str ,
@@ -201,12 +208,12 @@ def udf(
201208 return_type : pa .DataType | pa .Field ,
202209 volatility : Volatility | str ,
203210 name : Optional [str ] = None ,
204- ) -> Callable [..., ScalarUDF ]: ...
211+ ) -> Callable [[ Callable [ ..., PyArrowArray ]], Callable [..., Expr ] ]: ...
205212
206213 @overload
207214 @staticmethod
208215 def udf (
209- func : Callable [..., Any ],
216+ func : Callable [..., PyArrowArray ],
210217 input_types : list [pa .DataType | pa .Field ],
211218 return_type : pa .DataType | pa .Field ,
212219 volatility : Volatility | str ,
@@ -234,6 +241,8 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
234241 backed ScalarUDF within a PyCapsule, you can pass this parameter
235242 and ignore the rest. They will be determined directly from the
236243 underlying function. See the online documentation for more information.
244+ The callable should accept and return :class:`pyarrow.Array` or
245+ :class:`pyarrow.ChunkedArray` values.
237246 input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
238247 This list must be of the same length as the number of arguments. Pass
239248 :class:`pyarrow.Field` instances to preserve extension metadata.
@@ -261,7 +270,7 @@ def double_udf(x):
261270 """
262271
263272 def _function (
264- func : Callable [..., Any ],
273+ func : Callable [..., PyArrowArray ],
265274 input_types : list [pa .DataType | pa .Field ],
266275 return_type : pa .DataType | pa .Field ,
267276 volatility : Volatility | str ,
@@ -288,14 +297,14 @@ def _decorator(
288297 return_type : pa .DataType | pa .Field ,
289298 volatility : Volatility | str ,
290299 name : Optional [str ] = None ,
291- ) -> Callable :
292- def decorator (func : Callable ) :
300+ ) -> Callable [[ Callable [..., PyArrowArray ]], Callable [..., Expr ]] :
301+ def decorator (func : Callable [..., PyArrowArray ]) -> Callable [..., Expr ] :
293302 udf_caller = ScalarUDF .udf (
294303 func , input_types , return_type , volatility , name
295304 )
296305
297306 @functools .wraps (func )
298- def wrapper (* args : Any , ** kwargs : Any ):
307+ def wrapper (* args : Any , ** kwargs : Any ) -> Expr :
299308 return udf_caller (* args , ** kwargs )
300309
301310 return wrapper
0 commit comments