Skip to content

Commit 16224e2

Browse files
committed
Fix Ruff errors
1 parent 308e774 commit 16224e2

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

python/datafusion/user_defined.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,17 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, cast, overload
25+
from typing import (
26+
Any,
27+
Callable,
28+
Optional,
29+
Protocol,
30+
Sequence,
31+
TypeVar,
32+
Union,
33+
cast,
34+
overload,
35+
)
2636

2737
import pyarrow as pa
2838

@@ -33,6 +43,7 @@
3343
# Type alias for array batches exchanged with Python scalar UDFs.
3444
PyArrowArrayT = TypeVar("PyArrowArrayT", pa.Array, pa.ChunkedArray)
3545

46+
3647
class Volatility(Enum):
3748
"""Defines how stable or volatile a function is.
3849
@@ -79,7 +90,6 @@ def __str__(self) -> str:
7990

8091
def _clone_field(field: pa.Field) -> pa.Field:
8192
"""Return a deep copy of ``field`` including its DataType."""
82-
8393
return pa.schema([field]).field(0)
8494

8595

@@ -104,7 +114,8 @@ def _normalize_input_fields(
104114
raise TypeError(msg)
105115

106116
return [
107-
_normalize_field(value, default_name=f"arg_{idx}") for idx, value in enumerate(sequence)
117+
_normalize_field(value, default_name=f"arg_{idx}")
118+
for idx, value in enumerate(sequence)
108119
]
109120

110121

@@ -117,7 +128,9 @@ def _normalize_return_field(
117128
return _normalize_field(value, default_name=default_name)
118129

119130

120-
def _wrap_extension_value(value: PyArrowArrayT, data_type: pa.DataType) -> PyArrowArrayT:
131+
def _wrap_extension_value(
132+
value: PyArrowArrayT, data_type: pa.DataType
133+
) -> PyArrowArrayT:
121134
storage_type = getattr(data_type, "storage_type", None)
122135
wrap_array = getattr(data_type, "wrap_array", None)
123136
if storage_type is None or wrap_array is None:
@@ -440,10 +453,12 @@ def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
440453
This class allows you to define an aggregate function that can be used in
441454
data aggregation or window function calls.
442455
443-
Usage:
444-
- As a function: ``udaf(accum, input_types, return_type, state_type, volatility, name)``.
445-
- As a decorator: ``@udaf(input_types, return_type, state_type, volatility, name)``.
446-
When using ``udaf`` as a decorator, do not pass ``accum`` explicitly.
456+
Usage:
457+
- As a function: ``udaf(accum, input_types, return_type, state_type,``
458+
``volatility, name)``.
459+
- As a decorator: ``@udaf(input_types, return_type, state_type,``
460+
``volatility, name)``.
461+
When using ``udaf`` as a decorator, do not pass ``accum`` explicitly.
447462
448463
Function example:
449464

python/tests/test_udf.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from __future__ import annotations
19+
1820
import uuid
1921

2022
import pyarrow as pa
2123
import pytest
2224
from datafusion import column, udf
2325

24-
2526
UUID_EXTENSION_AVAILABLE = hasattr(pa, "uuid")
2627

2728

@@ -180,11 +181,7 @@ def ensure_extension(values: pa.Array | pa.ChunkedArray) -> pa.Array:
180181
)
181182

182183
df = ctx.create_dataframe([[batch]])
183-
result = (
184-
df.select(second(first(column("uuid_col"))))
185-
.collect()[0]
186-
.column(0)
187-
)
184+
result = df.select(second(first(column("uuid_col")))).collect()[0].column(0)
188185

189186
expected = uuid_type.wrap_array(storage)
190187

@@ -212,9 +209,7 @@ def ensure_extension(values: pa.Array | pa.ChunkedArray) -> pa.Array:
212209

213210
empty_df = ctx.create_dataframe([[empty_batch]])
214211
empty_result = (
215-
empty_df.select(second(empty_first(column("uuid_col"))))
216-
.collect()[0]
217-
.column(0)
212+
empty_df.select(second(empty_first(column("uuid_col")))).collect()[0].column(0)
218213
)
219214

220215
expected_empty = uuid_type.wrap_array(empty_storage)

0 commit comments

Comments
 (0)