@@ -73,11 +73,17 @@ def __str__(self) -> str:
7373 return self .name .lower ()
7474
7575
def _clone_field(field: pa.Field) -> pa.Field:
    """Return ``field`` after a round trip through a one-field schema.

    A temporary ``pa.Schema`` containing only ``field`` is built and its
    first (and only) entry is handed back.

    NOTE(review): the intent appears to be to obtain an independent copy of
    the field and its DataType; whether pyarrow actually duplicates the
    underlying objects here is not visible from this code -- confirm.
    """
    single_field_schema = pa.schema([field])
    return single_field_schema.field(0)
80+
81+
def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
    """Coerce ``value`` into a ``pa.Field`` and return a clone of it.

    Args:
        value: Either an existing field, or a bare data type that is turned
            into a field named ``default_name``.
        default_name: Field name used only when ``value`` is a data type.

    Returns:
        The result of ``_clone_field`` applied to the resolved field.

    Raises:
        TypeError: If ``value`` is neither a ``pa.Field`` nor a
            ``pa.DataType``.
    """
    if isinstance(value, pa.Field):
        resolved = value
    elif isinstance(value, pa.DataType):
        resolved = pa.field(default_name, value)
    else:
        msg = "Expected a pyarrow.DataType or pyarrow.Field"
        raise TypeError(msg)
    # Both accepted inputs leave through the same cloning step.
    return _clone_field(resolved)
8389
@@ -107,6 +113,39 @@ def _normalize_return_field(
107113 return _normalize_field (value , default_name = default_name )
108114
109115
116+ def _wrap_extension_value (value : Any , data_type : pa .DataType ) -> Any :
117+ storage_type = getattr (data_type , "storage_type" , None )
118+ wrap_array = getattr (data_type , "wrap_array" , None )
119+ if storage_type is None or wrap_array is None :
120+ return value
121+ if isinstance (value , pa .Array ) and value .type .equals (storage_type ):
122+ return wrap_array (value )
123+ if isinstance (value , pa .ChunkedArray ) and value .type .equals (storage_type ):
124+ wrapped_chunks = [wrap_array (chunk ) for chunk in value .chunks ]
125+ return pa .chunked_array (wrapped_chunks )
126+ return value
127+
128+
def _wrap_udf_function(
    func: Callable[..., Any],
    input_fields: Sequence[pa.Field],
    return_field: pa.Field,
) -> Callable[..., Any]:
    """Build a callable that applies ``_wrap_extension_value`` around ``func``.

    Args:
        func: The user function to invoke.
        input_fields: Declared input fields; the i-th field's type is used
            for the i-th positional argument.
        return_field: Declared return field; its type is used for the result.

    Returns:
        A wrapper that converts positional arguments, calls ``func`` with
        them (keyword arguments are forwarded untouched), and converts the
        result before returning it.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # zip stops at the shorter side, so only positions that have both an
        # argument and a declared field are converted.
        converted_args = [
            _wrap_extension_value(arg, declared.type)
            for arg, declared in zip(args, input_fields)
        ]
        # Positional arguments beyond the declared fields pass through as-is.
        converted_args.extend(args[len(converted_args):])
        outcome = func(*converted_args, **kwargs)
        return _wrap_extension_value(outcome, return_field.type)

    return wrapper
147+
148+
110149class ScalarUDFExportable (Protocol ):
111150 """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
112151
@@ -137,8 +176,9 @@ def __init__(
137176 return
138177 normalized_inputs = _normalize_input_fields (input_types )
139178 normalized_return = _normalize_return_field (return_type , name = name )
179+ wrapped_func = _wrap_udf_function (func , normalized_inputs , normalized_return )
140180 self ._udf = df_internal .ScalarUDF (
141- name , func , normalized_inputs , normalized_return , str (volatility )
181+ name , wrapped_func , normalized_inputs , normalized_return , str (volatility )
142182 )
143183
144184 def __repr__ (self ) -> str :
0 commit comments