2222import functools
2323from abc import ABCMeta , abstractmethod
2424from enum import Enum
25- from typing import TYPE_CHECKING , Any , Callable , Optional , Protocol , TypeVar , overload
25+ from typing import Any , Callable , Optional , Protocol , Sequence , overload
2626
2727import pyarrow as pa
2828
2929import datafusion ._internal as df_internal
3030from datafusion .expr import Expr
3131
32- if TYPE_CHECKING :
33- _R = TypeVar ("_R" , bound = pa .DataType )
34-
35-
3632class Volatility (Enum ):
3733 """Defines how stable or volatile a function is.
3834
@@ -77,6 +73,40 @@ def __str__(self) -> str:
7773 return self .name .lower ()
7874
7975
def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
    """Coerce a type-or-field argument into a :class:`pyarrow.Field`.

    Args:
        value: Either a ``pyarrow.DataType`` or an already-built
            ``pyarrow.Field``.
        default_name: Field name used when ``value`` is a bare data type.
            Ignored when ``value`` is already a field.

    Returns:
        ``value`` itself when it is a ``pyarrow.Field``; otherwise a new field
        named ``default_name`` wrapping the given data type.

    Raises:
        TypeError: If ``value`` is neither a ``pyarrow.DataType`` nor a
            ``pyarrow.Field``.
    """
    # A bare data type carries no name, so give it the caller-provided one.
    if isinstance(value, pa.DataType):
        return pa.field(default_name, value)

    if not isinstance(value, pa.Field):
        msg = "Expected a pyarrow.DataType or pyarrow.Field"
        raise TypeError(msg)

    # Fields are passed through untouched.
    return value
83+
84+
def _normalize_input_fields(
    values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
) -> list[pa.Field]:
    """Coerce the ``input_types`` argument into a list of :class:`pyarrow.Field`.

    Args:
        values: A single ``pyarrow.DataType`` / ``pyarrow.Field``, or a
            sequence of them. ``str`` and ``bytes`` are rejected even though
            they are sequences.

    Returns:
        One field per input, in the original order. Items that are bare data
        types are wrapped in a field named ``arg_<index>`` (``arg_0``,
        ``arg_1``, ...); items that are already fields are kept as they are.

    Raises:
        TypeError: If ``values`` is not a data type, a field, or a non-string
            sequence. Element-level validation is delegated to
            ``_normalize_field``.
    """
    if isinstance(values, (pa.DataType, pa.Field)):
        # A lone type/field is shorthand for a single-argument function.
        sequence: Sequence[pa.DataType | pa.Field] = [values]
    elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
        sequence = values
    else:
        msg = "input_types must be a DataType, Field, or a sequence of them"
        raise TypeError(msg)

    # The default name must not contain stray whitespace: "arg_0", not "arg_0 ".
    return [
        _normalize_field(value, default_name=f"arg_{idx}")
        for idx, value in enumerate(sequence)
    ]
99+
100+
def _normalize_return_field(
    value: pa.DataType | pa.Field,
    *,
    name: str,
) -> pa.Field:
    """Coerce the ``return_type`` argument into a :class:`pyarrow.Field`.

    Args:
        value: A ``pyarrow.DataType`` or ``pyarrow.Field`` describing the
            function's result.
        name: Name of the function the return field belongs to. Used only to
            build the default field name.

    Returns:
        The result of ``_normalize_field`` called with a default name of
        ``"<name>_result"``, or plain ``"result"`` when ``name`` is empty.
    """
    # No whitespace between the function name and the suffix: "myfn_result".
    default_name = f"{name}_result" if name else "result"
    return _normalize_field(value, default_name=default_name)
108+
109+
80110class ScalarUDFExportable (Protocol ):
81111 """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
82112
@@ -93,9 +123,9 @@ class ScalarUDF:
93123 def __init__ (
94124 self ,
95125 name : str ,
96- func : Callable [..., _R ],
97- input_types : pa .DataType | list [pa .DataType ],
98- return_type : _R ,
126+ func : Callable [..., Any ],
127+ input_types : pa .DataType | pa . Field | Sequence [pa .DataType | pa . Field ],
128+ return_type : pa . DataType | pa . Field ,
99129 volatility : Volatility | str ,
100130 ) -> None :
101131 """Instantiate a scalar user-defined function (UDF).
@@ -105,10 +135,10 @@ def __init__(
105135 if hasattr (func , "__datafusion_scalar_udf__" ):
106136 self ._udf = df_internal .ScalarUDF .from_pycapsule (func )
107137 return
108- if isinstance (input_types , pa . DataType ):
109- input_types = [ input_types ]
138+ normalized_inputs = _normalize_input_fields (input_types )
139+ normalized_return = _normalize_return_field ( return_type , name = name )
110140 self ._udf = df_internal .ScalarUDF (
111- name , func , input_types , return_type , str (volatility )
141+ name , func , normalized_inputs , normalized_return , str (volatility )
112142 )
113143
114144 def __repr__ (self ) -> str :
@@ -127,18 +157,18 @@ def __call__(self, *args: Expr) -> Expr:
127157 @overload
128158 @staticmethod
129159 def udf (
130- input_types : list [pa .DataType ],
131- return_type : _R ,
160+ input_types : list [pa .DataType | pa . Field ],
161+ return_type : pa . DataType | pa . Field ,
132162 volatility : Volatility | str ,
133163 name : Optional [str ] = None ,
134164 ) -> Callable [..., ScalarUDF ]: ...
135165
136166 @overload
137167 @staticmethod
138168 def udf (
139- func : Callable [..., _R ],
140- input_types : list [pa .DataType ],
141- return_type : _R ,
169+ func : Callable [..., Any ],
170+ input_types : list [pa .DataType | pa . Field ],
171+ return_type : pa . DataType | pa . Field ,
142172 volatility : Volatility | str ,
143173 name : Optional [str ] = None ,
144174 ) -> ScalarUDF : ...
@@ -164,10 +194,11 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
164194 backed ScalarUDF within a PyCapsule, you can pass this parameter
165195 and ignore the rest. They will be determined directly from the
166196 underlying function. See the online documentation for more information.
167- input_types (list[pa.DataType]): The data types of the arguments
168- to ``func``. This list must be of the same length as the number of
169- arguments.
170- return_type (_R): The data type of the return value from the function.
197+ input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
198+ This list must be of the same length as the number of arguments. Pass
199+ :class:`pyarrow.Field` instances to preserve extension metadata.
200+ return_type (pa.DataType | pa.Field): The return type of the function. Use a
201+ :class:`pyarrow.Field` to preserve metadata on extension arrays.
171202 volatility (Volatility | str): See `Volatility` for allowed values.
172203 name (Optional[str]): A descriptive name for the function.
173204
0 commit comments