Skip to content

Commit d0b02a6

Browse files
committed
Implement metadata-aware PySimpleScalarUDF
Enhance scalar UDF definitions to retain Arrow Field information, including extension metadata, in DataFusion. Normalize Python UDF signatures to accept pyarrow.Field objects, ensuring metadata survives the round trip through the Rust bindings. Add a regression test for UUID-backed UDFs verifying that the second UDF in a chain receives a pyarrow.ExtensionArray, guarding against the metadata loss observed previously.
1 parent d9c90d2 commit d0b02a6

File tree

3 files changed

+196
-27
lines changed

3 files changed

+196
-27
lines changed

python/datafusion/user_defined.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,13 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
25+
from typing import Any, Callable, Optional, Protocol, Sequence, overload
2626

2727
import pyarrow as pa
2828

2929
import datafusion._internal as df_internal
3030
from datafusion.expr import Expr
3131

32-
if TYPE_CHECKING:
33-
_R = TypeVar("_R", bound=pa.DataType)
34-
35-
3632
class Volatility(Enum):
3733
"""Defines how stable or volatile a function is.
3834
@@ -77,6 +73,40 @@ def __str__(self) -> str:
7773
return self.name.lower()
7874

7975

76+
def _normalize_field(value: pa.DataType | pa.Field, *, default_name: str) -> pa.Field:
77+
if isinstance(value, pa.Field):
78+
return value
79+
if isinstance(value, pa.DataType):
80+
return pa.field(default_name, value)
81+
msg = "Expected a pyarrow.DataType or pyarrow.Field"
82+
raise TypeError(msg)
83+
84+
85+
def _normalize_input_fields(
86+
values: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
87+
) -> list[pa.Field]:
88+
if isinstance(values, (pa.DataType, pa.Field)):
89+
sequence: Sequence[pa.DataType | pa.Field] = [values]
90+
elif isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
91+
sequence = values
92+
else:
93+
msg = "input_types must be a DataType, Field, or a sequence of them"
94+
raise TypeError(msg)
95+
96+
return [
97+
_normalize_field(value, default_name=f"arg_{idx}") for idx, value in enumerate(sequence)
98+
]
99+
100+
101+
def _normalize_return_field(
102+
value: pa.DataType | pa.Field,
103+
*,
104+
name: str,
105+
) -> pa.Field:
106+
default_name = f"{name}_result" if name else "result"
107+
return _normalize_field(value, default_name=default_name)
108+
109+
80110
class ScalarUDFExportable(Protocol):
81111
"""Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
82112

@@ -93,9 +123,9 @@ class ScalarUDF:
93123
def __init__(
94124
self,
95125
name: str,
96-
func: Callable[..., _R],
97-
input_types: pa.DataType | list[pa.DataType],
98-
return_type: _R,
126+
func: Callable[..., Any],
127+
input_types: pa.DataType | pa.Field | Sequence[pa.DataType | pa.Field],
128+
return_type: pa.DataType | pa.Field,
99129
volatility: Volatility | str,
100130
) -> None:
101131
"""Instantiate a scalar user-defined function (UDF).
@@ -105,10 +135,10 @@ def __init__(
105135
if hasattr(func, "__datafusion_scalar_udf__"):
106136
self._udf = df_internal.ScalarUDF.from_pycapsule(func)
107137
return
108-
if isinstance(input_types, pa.DataType):
109-
input_types = [input_types]
138+
normalized_inputs = _normalize_input_fields(input_types)
139+
normalized_return = _normalize_return_field(return_type, name=name)
110140
self._udf = df_internal.ScalarUDF(
111-
name, func, input_types, return_type, str(volatility)
141+
name, func, normalized_inputs, normalized_return, str(volatility)
112142
)
113143

114144
def __repr__(self) -> str:
@@ -127,18 +157,18 @@ def __call__(self, *args: Expr) -> Expr:
127157
@overload
128158
@staticmethod
129159
def udf(
130-
input_types: list[pa.DataType],
131-
return_type: _R,
160+
input_types: list[pa.DataType | pa.Field],
161+
return_type: pa.DataType | pa.Field,
132162
volatility: Volatility | str,
133163
name: Optional[str] = None,
134164
) -> Callable[..., ScalarUDF]: ...
135165

136166
@overload
137167
@staticmethod
138168
def udf(
139-
func: Callable[..., _R],
140-
input_types: list[pa.DataType],
141-
return_type: _R,
169+
func: Callable[..., Any],
170+
input_types: list[pa.DataType | pa.Field],
171+
return_type: pa.DataType | pa.Field,
142172
volatility: Volatility | str,
143173
name: Optional[str] = None,
144174
) -> ScalarUDF: ...
@@ -164,10 +194,11 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
164194
backed ScalarUDF within a PyCapsule, you can pass this parameter
165195
and ignore the rest. They will be determined directly from the
166196
underlying function. See the online documentation for more information.
167-
input_types (list[pa.DataType]): The data types of the arguments
168-
to ``func``. This list must be of the same length as the number of
169-
arguments.
170-
return_type (_R): The data type of the return value from the function.
197+
input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
198+
This list must be of the same length as the number of arguments. Pass
199+
:class:`pyarrow.Field` instances to preserve extension metadata.
200+
return_type (pa.DataType | pa.Field): The return type of the function. Use a
201+
:class:`pyarrow.Field` to preserve metadata on extension arrays.
171202
volatility (Volatility | str): See `Volatility` for allowed values.
172203
name (Optional[str]): A descriptive name for the function.
173204

python/tests/test_udf.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,58 @@ def udf_with_param(values: pa.Array) -> pa.Array:
124124
result = df2.collect()[0].column(0)
125125

126126
assert result == pa.array([False, True, True])
127+
128+
129+
def test_uuid_extension_chain(ctx) -> None:
130+
uuid_type = pa.uuid()
131+
uuid_field = pa.field("uuid_col", uuid_type)
132+
133+
first = udf(
134+
lambda values: values,
135+
[uuid_field],
136+
uuid_field,
137+
volatility="immutable",
138+
name="uuid_identity",
139+
)
140+
141+
def ensure_extension(values: pa.Array) -> pa.Array:
142+
assert isinstance(values, pa.ExtensionArray)
143+
return values
144+
145+
second = udf(
146+
ensure_extension,
147+
[uuid_field],
148+
uuid_field,
149+
volatility="immutable",
150+
name="uuid_assert",
151+
)
152+
153+
batch = pa.RecordBatch.from_arrays(
154+
[
155+
pa.array(
156+
[
157+
"00000000-0000-0000-0000-000000000000",
158+
"00000000-0000-0000-0000-000000000001",
159+
],
160+
type=uuid_type,
161+
)
162+
],
163+
names=["uuid_col"],
164+
)
165+
166+
df = ctx.create_dataframe([[batch]])
167+
result = (
168+
df.select(second(first(column("uuid_col"))))
169+
.collect()[0]
170+
.column(0)
171+
)
172+
173+
expected = pa.array(
174+
[
175+
"00000000-0000-0000-0000-000000000000",
176+
"00000000-0000-0000-0000-000000000001",
177+
],
178+
type=uuid_type,
179+
)
180+
181+
assert result.equals(expected)

src/udf.rs

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,24 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::fmt;
1819
use std::sync::Arc;
1920

2021
use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
2122
use pyo3::types::PyCapsule;
2223
use pyo3::{prelude::*, types::PyTuple};
2324

2425
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
25-
use datafusion::arrow::datatypes::DataType;
26+
use datafusion::arrow::datatypes::{DataType, Field};
2627
use datafusion::arrow::pyarrow::FromPyArrow;
2728
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2829
use datafusion::error::DataFusionError;
2930
use datafusion::logical_expr::function::ScalarFunctionImplementation;
30-
use datafusion::logical_expr::ScalarUDF;
31-
use datafusion::logical_expr::{create_udf, ColumnarValue};
31+
use datafusion::logical_expr::ptr_eq::PtrEq;
32+
use datafusion::logical_expr::{
33+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
34+
Volatility,
35+
};
3236

3337
use crate::errors::to_datafusion_err;
3438
use crate::errors::{py_datafusion_err, PyDataFusionResult};
@@ -80,6 +84,83 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
8084
})
8185
}
8286

87+
#[derive(PartialEq, Eq, Hash)]
88+
struct PySimpleScalarUDF {
89+
name: String,
90+
signature: Signature,
91+
return_field: Arc<Field>,
92+
fun: PtrEq<ScalarFunctionImplementation>,
93+
}
94+
95+
impl fmt::Debug for PySimpleScalarUDF {
96+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97+
f.debug_struct("PySimpleScalarUDF")
98+
.field("name", &self.name)
99+
.field("signature", &self.signature)
100+
.field("return_field", &self.return_field)
101+
.finish()
102+
}
103+
}
104+
105+
impl PySimpleScalarUDF {
106+
fn new(
107+
name: impl Into<String>,
108+
input_fields: Vec<Field>,
109+
return_field: Field,
110+
volatility: Volatility,
111+
fun: ScalarFunctionImplementation,
112+
) -> Self {
113+
let signature_types = input_fields
114+
.into_iter()
115+
.map(|field| field.data_type().clone())
116+
.collect();
117+
let signature = Signature::exact(signature_types, volatility);
118+
Self {
119+
name: name.into(),
120+
signature,
121+
return_field: Arc::new(return_field),
122+
fun: fun.into(),
123+
}
124+
}
125+
}
126+
127+
impl ScalarUDFImpl for PySimpleScalarUDF {
128+
fn as_any(&self) -> &dyn std::any::Any {
129+
self
130+
}
131+
132+
fn name(&self) -> &str {
133+
&self.name
134+
}
135+
136+
fn signature(&self) -> &Signature {
137+
&self.signature
138+
}
139+
140+
fn return_type(&self, _arg_types: &[DataType]) -> datafusion::error::Result<DataType> {
141+
Ok(self.return_field.data_type().clone())
142+
}
143+
144+
fn return_field_from_args(
145+
&self,
146+
_args: ReturnFieldArgs,
147+
) -> datafusion::error::Result<Arc<Field>> {
148+
Ok(Arc::new(
149+
self.return_field
150+
.as_ref()
151+
.clone()
152+
.with_name(self.name.clone()),
153+
))
154+
}
155+
156+
fn invoke_with_args(
157+
&self,
158+
args: ScalarFunctionArgs,
159+
) -> datafusion::error::Result<ColumnarValue> {
160+
(self.fun)(&args.args)
161+
}
162+
}
163+
83164
/// Represents a PyScalarUDF
84165
#[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)]
85166
#[derive(Debug, Clone)]
@@ -94,17 +175,19 @@ impl PyScalarUDF {
94175
fn new(
95176
name: &str,
96177
func: PyObject,
97-
input_types: PyArrowType<Vec<DataType>>,
98-
return_type: PyArrowType<DataType>,
178+
input_types: PyArrowType<Vec<Field>>,
179+
return_type: PyArrowType<Field>,
99180
volatility: &str,
100181
) -> PyResult<Self> {
101-
let function = create_udf(
182+
let volatility = parse_volatility(volatility)?;
183+
let scalar_impl = PySimpleScalarUDF::new(
102184
name,
103185
input_types.0,
104186
return_type.0,
105-
parse_volatility(volatility)?,
187+
volatility,
106188
to_scalar_function_impl(func),
107189
);
190+
let function = ScalarUDF::new_from_impl(scalar_impl);
108191
Ok(Self { function })
109192
}
110193

0 commit comments

Comments (0)