Skip to content

Commit 2325993

Browse files
committed
Clone pyarrow.Field objects for FFI handoff
Wrap scalar UDF inputs/outputs to maintain extension types during execution. Enhance UUID extension regression test to ensure metadata retention and normalize results for accurate comparison.
1 parent d0b02a6 commit 2325993

File tree

2 files changed

+60
-18
lines changed

2 files changed

+60
-18
lines changed

python/datafusion/user_defined.py

Lines changed: 43 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -73,11 +73,17 @@ def __str__(self) -> str:
7373
return self.name.lower()
7474

7575

76+
def _clone_field(field: pa.Field) -> pa.Field:
    """Return a new ``pa.Field`` obtained by round-tripping ``field`` through a schema.

    The field is placed in a one-element ``pa.Schema`` and read back from
    position 0, so callers receive the object produced by the schema lookup
    rather than the one they passed in.

    NOTE(review): this was originally described as a deep copy "including its
    DataType"; whether pyarrow actually duplicates the underlying type here is
    not visible from this code -- confirm before relying on it.
    """
    single_field_schema = pa.schema([field])
    return single_field_schema.field(0)
80+
81+
7682
def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
    """Coerce ``value`` into a ``pa.Field`` suitable for handing to the UDF layer.

    Args:
        value: Either an existing ``pa.Field`` or a bare ``pa.DataType``.
        default_name: Field name to use when ``value`` is a bare data type.

    Returns:
        The result of ``_clone_field`` applied to ``value`` itself (when it is
        already a field) or to a new field named ``default_name`` built from
        the data type.

    Raises:
        TypeError: If ``value`` is neither a ``pa.Field`` nor a ``pa.DataType``.
    """
    if isinstance(value, pa.Field):
        # Hand out a clone instead of the caller's own Field object.
        return _clone_field(value)
    if isinstance(value, pa.DataType):
        return _clone_field(pa.field(default_name, value))
    msg = "Expected a pyarrow.DataType or pyarrow.Field"
    raise TypeError(msg)
8389

@@ -107,6 +113,39 @@ def _normalize_return_field(
107113
return _normalize_field(value, default_name=default_name)
108114

109115

116+
def _wrap_extension_value(value: Any, data_type: pa.DataType) -> Any:
117+
storage_type = getattr(data_type, "storage_type", None)
118+
wrap_array = getattr(data_type, "wrap_array", None)
119+
if storage_type is None or wrap_array is None:
120+
return value
121+
if isinstance(value, pa.Array) and value.type.equals(storage_type):
122+
return wrap_array(value)
123+
if isinstance(value, pa.ChunkedArray) and value.type.equals(storage_type):
124+
wrapped_chunks = [wrap_array(chunk) for chunk in value.chunks]
125+
return pa.chunked_array(wrapped_chunks)
126+
return value
127+
128+
129+
def _wrap_udf_function(
    func: Callable[..., Any],
    input_fields: Sequence[pa.Field],
    return_field: pa.Field,
) -> Callable[..., Any]:
    """Build a callable that applies ``_wrap_extension_value`` around ``func``.

    Args:
        func: The user-supplied function to invoke.
        input_fields: Declared input fields, matched to positional arguments
            by position.
        return_field: Declared return field; its type is applied to the result.

    Returns:
        A wrapper that passes each positional argument through
        ``_wrap_extension_value`` with the type of the field at the same
        position, calls ``func``, and passes the result through
        ``_wrap_extension_value`` with ``return_field.type``.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        rewrapped = list(args)
        # zip() stops at the shorter sequence, so positional arguments beyond
        # the declared fields are forwarded untouched. Keyword arguments are
        # never converted.
        for position, (arg, field) in enumerate(zip(args, input_fields)):
            rewrapped[position] = _wrap_extension_value(arg, field.type)
        outcome = func(*rewrapped, **kwargs)
        return _wrap_extension_value(outcome, return_field.type)

    return wrapper
147+
148+
110149
class ScalarUDFExportable(Protocol):
111150
"""Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
112151

@@ -137,8 +176,9 @@ def __init__(
137176
return
138177
normalized_inputs = _normalize_input_fields(input_types)
139178
normalized_return = _normalize_return_field(return_type, name=name)
179+
wrapped_func = _wrap_udf_function(func, normalized_inputs, normalized_return)
140180
self._udf = df_internal.ScalarUDF(
141-
name, func, normalized_inputs, normalized_return, str(volatility)
181+
name, wrapped_func, normalized_inputs, normalized_return, str(volatility)
142182
)
143183

144184
def __repr__(self) -> str:

python/tests/test_udf.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import uuid
19+
1820
import pyarrow as pa
1921
import pytest
2022
from datafusion import column, udf
@@ -150,16 +152,19 @@ def ensure_extension(values: pa.Array) -> pa.Array:
150152
name="uuid_assert",
151153
)
152154

153-
batch = pa.RecordBatch.from_arrays(
155+
# The UUID extension metadata should survive UDF registration.
156+
assert getattr(uuid_type, "extension_name", None) == "arrow.uuid"
157+
assert getattr(uuid_field.type, "extension_name", None) == "arrow.uuid"
158+
159+
storage = pa.array(
154160
[
155-
pa.array(
156-
[
157-
"00000000-0000-0000-0000-000000000000",
158-
"00000000-0000-0000-0000-000000000001",
159-
],
160-
type=uuid_type,
161-
)
161+
uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
162+
uuid.UUID("00000000-0000-0000-0000-000000000001").bytes,
162163
],
164+
type=uuid_type.storage_type,
165+
)
166+
batch = pa.RecordBatch.from_arrays(
167+
[uuid_type.wrap_array(storage)],
163168
names=["uuid_col"],
164169
)
165170

@@ -170,12 +175,9 @@ def ensure_extension(values: pa.Array) -> pa.Array:
170175
.column(0)
171176
)
172177

173-
expected = pa.array(
174-
[
175-
"00000000-0000-0000-0000-000000000000",
176-
"00000000-0000-0000-0000-000000000001",
177-
],
178-
type=uuid_type,
179-
)
178+
expected = uuid_type.wrap_array(storage)
179+
180+
if isinstance(result, pa.Array) and result.type.equals(uuid_type.storage_type):
181+
result = uuid_type.wrap_array(result)
180182

181183
assert result.equals(expected)

0 commit comments

Comments
 (0)