Commit af1c07b

kosiew, ntjohnson1, timsaucer authored
Freeze PyO3 wrappers & introduce interior mutability to avoid PyO3 borrow errors (#1253)
* Refactor schema, config, dataframe, and expression classes to use RwLock and Mutex for interior mutability
* Add error handling to CaseBuilder methods to preserve builder state
* Refactor to use parking_lot for interior mutability in schema, config, dataframe, and conditional expression modules
* Add concurrency tests for SqlSchema, Config, and DataFrame
* Add tests for CaseBuilder to ensure builder state is preserved on success
* Add test for independent handles in CaseBuilder to verify behavior
* Fix CaseBuilder to preserve state correctly in when() method
* Refactor to use named constant for boolean literals in test_expr.py
* fix ruff errors
* Refactor to introduce type aliases for cached batches in dataframe.rs
* Cherry pick from #1252
* Add most expr - cherry pick from #1252
* Add source root - cherry pick #1252
* Fix license comment formatting in config.rs
* Refactor caching logic to use a local variable for IPython environment check
* Add test for ensuring exposed pyclasses default to frozen
* Add PyO3 class mutability guidelines reference to contributor guide
* Mark boolean expression classes as frozen for immutability
* Refactor PyCaseBuilder methods to eliminate redundant take/store logic
* Refactor PyConfig methods to improve readability by encapsulating configuration reads
* Resolve patch apply conflicts for CaseBuilder concurrency improvements
  - Added CaseBuilderHandle guard that keeps the underlying CaseBuilder alive while holding the mutex and restores it on drop
  - Updated when, otherwise, and end methods to operate through the guard and consume the builder explicitly
  - This prevents transient None states during concurrent access and improves thread safety
* Resolve Config optimization conflicts for improved read/write concurrency
  - Released Config read guard before converting values to Python objects in get and get_all
  - Ensures locks are held only while collecting scalar entries, not during expensive Python object conversion
  - Added regression test that runs Config.get_all and Config.set concurrently to guard against read/write contention regressions
  - Improves overall performance by reducing lock contention in multi-threaded scenarios
* Refactor PyConfig get methods for improved readability and performance
* Refactor test_expr.py to replace positional boolean literals with named constants for improved linting compliance
* fix ruff errors
* Add license header to test_pyclass_frozen.py for compliance
* Alternate approach to case expression
* Replace case builder with keeping the expressions and then applying as required
* Update unit tests
* Refactor case and when functions to utilize PyCaseBuilder for improved clarity and functionality
* Update src/expr/conditional_expr.rs

---------

Co-authored-by: ntjohnson1 <24689722+ntjohnson1@users.noreply.github.com>
Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent e75addf commit af1c07b

86 files changed, +761 -209 lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ futures = "0.3"
 object_store = { version = "0.12.3", features = ["aws", "gcp", "azure", "http"] }
 url = "2"
 log = "0.4.27"
+parking_lot = "0.12"

 [build-dependencies]
 prost-types = "0.13.1" # keep in line with `datafusion-substrait`

docs/source/contributor-guide/ffi.rst

Lines changed: 61 additions & 0 deletions
@@ -137,6 +137,67 @@ and you want to create a sharable FFI counterpart, you could write:
     let my_provider = MyTableProvider::default();
     let ffi_provider = FFI_TableProvider::new(Arc::new(my_provider), false, None);

+.. _ffi_pyclass_mutability:
+
+PyO3 class mutability guidelines
+--------------------------------
+
+PyO3 bindings should present immutable wrappers whenever a struct stores shared or
+interior-mutable state. In practice this means that any ``#[pyclass]`` containing an
+``Arc<RwLock<_>>`` or similar synchronized primitive must opt into ``#[pyclass(frozen)]``
+unless there is a compelling reason not to.
+
+The :mod:`datafusion` configuration helpers illustrate the preferred pattern. The
+``PyConfig`` class in :file:`src/config.rs` stores an ``Arc<RwLock<ConfigOptions>>`` and is
+explicitly frozen so callers interact with configuration state through provided methods
+instead of mutating the container directly:
+
+.. code-block:: rust
+
+    #[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
+    #[derive(Clone)]
+    pub(crate) struct PyConfig {
+        config: Arc<RwLock<ConfigOptions>>,
+    }
+
+The same approach applies to execution contexts. ``PySessionContext`` in
+:file:`src/context.rs` stays frozen even though it shares mutable state internally via
+``SessionContext``. This ensures PyO3 tracks borrows correctly while Python-facing APIs
+clone the inner ``SessionContext`` or return new wrappers instead of mutating the
+existing instance in place:
+
+.. code-block:: rust
+
+    #[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
+    #[derive(Clone)]
+    pub struct PySessionContext {
+        pub ctx: SessionContext,
+    }
+
+Occasionally a type must remain mutable—for example when PyO3 attribute setters need to
+update fields directly. In these rare cases add an inline justification so reviewers and
+future contributors understand why ``frozen`` is unsafe to enable. ``DataTypeMap`` in
+:file:`src/common/data_type.rs` includes such a comment because PyO3 still needs to track
+field updates:
+
+.. code-block:: rust
+
+    // TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now
+    #[derive(Debug, Clone)]
+    #[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)]
+    pub struct DataTypeMap {
+        #[pyo3(get, set)]
+        pub arrow_type: PyDataType,
+        #[pyo3(get, set)]
+        pub python_type: PythonType,
+        #[pyo3(get, set)]
+        pub sql_type: SqlType,
+    }
+
+When reviewers encounter a mutable ``#[pyclass]`` without a comment, they should request
+an explanation or ask that ``frozen`` be added. Keeping these wrappers frozen by default
+helps avoid subtle bugs stemming from PyO3's interior mutability tracking.
+
 If you were interfacing with a library that provided the above ``FFI_TableProvider`` and
 you needed to turn it back into an ``TableProvider``, you can turn it into a
 ``ForeignTableProvider`` with implements the ``TableProvider`` trait.
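
Note on the guideline added in this hunk: the idea of a frozen wrapper around lock-protected state can be illustrated without PyO3. The sketch below is a Python-only analogy and is not code from this commit. `FrozenConfig` and its fields are hypothetical names, and a plain `threading.Lock` stands in for the `RwLock` used on the Rust side. It shows that the wrapper's own fields cannot be rebound while the shared state is still updated through methods.

```python
import threading
from dataclasses import FrozenInstanceError, dataclass, field
from typing import Dict, Optional


@dataclass(frozen=True)
class FrozenConfig:
    """Hypothetical stand-in for a frozen wrapper around shared state."""

    # The wrapper's fields can never be rebound, but the objects they
    # reference are shared and guarded by a lock (interior mutability).
    _lock: threading.Lock = field(default_factory=threading.Lock)
    _options: Dict[str, str] = field(default_factory=dict)

    def set(self, key: str, value: str) -> None:
        # All mutation goes through a method that takes the lock.
        with self._lock:
            self._options[key] = value

    def get(self, key: str) -> Optional[str]:
        with self._lock:
            return self._options.get(key)


config = FrozenConfig()
config.set("datafusion.execution.batch_size", "1024")

try:
    config._options = {}  # rebinding a field on the frozen wrapper fails
except FrozenInstanceError:
    rebinding_blocked = True
else:
    rebinding_blocked = False
```

The design choice mirrors the text: callers go through `set` and `get`, so there is a single place where locking is enforced.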

docs/source/contributor-guide/introduction.rst

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@ We welcome and encourage contributions of all kinds, such as:
 In addition to submitting new PRs, we have a healthy tradition of community members reviewing each other’s PRs.
 Doing so is a great way to help the community as well as get more familiar with Rust and the relevant codebases.

+Before opening a pull request that touches PyO3 bindings, please review the
+:ref:`PyO3 class mutability guidelines <ffi_pyclass_mutability>` so you can flag missing
+``#[pyclass(frozen)]`` annotations during development and review.
+
 How to develop
 --------------

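
Note: flagging missing `frozen` annotations can be partly automated. The commit message mentions a test that checks exposed pyclasses default to frozen, but its contents are not shown here, so the following is only a minimal sketch of one possible check. The function name `find_unfrozen_pyclasses`, the regular expression, and the sample source are all assumptions made for illustration.

```python
import re
from typing import List

# Matches `#[pyclass]` and `#[pyclass(...)]` attributes written on one line.
PYCLASS_ATTR = re.compile(r"#\[pyclass(?:\((?P<args>[^\]]*)\))?\]")


def find_unfrozen_pyclasses(rust_source: str) -> List[str]:
    """Return the pyclass attributes that do not list ``frozen``."""
    unfrozen = []
    for match in PYCLASS_ATTR.finditer(rust_source):
        args = match.group("args") or ""
        # Naive split: assumes no option value contains a comma.
        options = [part.strip() for part in args.split(",")]
        if "frozen" not in options:
            unfrozen.append(match.group(0))
    return unfrozen


SAMPLE = '''
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
pub(crate) struct PyConfig {}

#[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)]
pub struct DataTypeMap {}
'''

unfrozen = find_unfrozen_pyclasses(SAMPLE)
```

A real check would also need an allow-list for types that are intentionally left mutable with an inline justification, such as `DataTypeMap`.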

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,9 @@ convention = "google"
 [tool.ruff.lint.pycodestyle]
 max-doc-length = 88

+[tool.ruff.lint.flake8-boolean-trap]
+extend-allowed-calls = ["lit", "datafusion.lit"]
+
 # Disable docstring checking for these directories
 [tool.ruff.lint.per-file-ignores]
 "python/tests/*" = [
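
Note: the `extend-allowed-calls` setting exempts the listed callables from Ruff's positional-boolean check (FBT003). The sketch below shows the two call shapes involved. `lit` and `when` here are hypothetical stand-ins defined locally so the snippet runs on its own; they are not the datafusion functions.

```python
def lit(value):
    """Hypothetical stand-in for ``datafusion.lit``; wraps a Python value."""
    return ("literal", value)


def when(condition, result):
    """Hypothetical helper that is not listed in ``extend-allowed-calls``."""
    return ("when", condition, result)


# Accepted by the linter once "lit" is listed in extend-allowed-calls.
allowed = lit(True)

# A bare positional boolean passed to any other call is what FBT003 reports.
# Binding the value to a named constant first avoids the warning.
ALWAYS = True
rewritten = when(ALWAYS, lit(1))
```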

python/tests/test_concurrency.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from concurrent.futures import ThreadPoolExecutor
+
+import pyarrow as pa
+from datafusion import Config, SessionContext, col, lit
+from datafusion import functions as f
+from datafusion.common import SqlSchema
+
+
+def _run_in_threads(fn, count: int = 8) -> None:
+    with ThreadPoolExecutor(max_workers=count) as executor:
+        futures = [executor.submit(fn, i) for i in range(count)]
+        for future in futures:
+            # Propagate any exception raised in the worker thread.
+            future.result()
+
+
+def test_concurrent_access_to_shared_structures() -> None:
+    """Exercise SqlSchema, Config, and DataFrame concurrently."""
+
+    schema = SqlSchema("concurrency")
+    config = Config()
+    ctx = SessionContext()
+
+    batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int32())], names=["value"])
+    df = ctx.create_dataframe([[batch]])
+
+    config_key = "datafusion.execution.batch_size"
+    expected_rows = batch.num_rows
+
+    def worker(index: int) -> None:
+        schema.name = f"concurrency-{index}"
+        assert schema.name.startswith("concurrency-")
+        # Exercise getters that use internal locks.
+        assert isinstance(schema.tables, list)
+        assert isinstance(schema.views, list)
+        assert isinstance(schema.functions, list)
+
+        config.set(config_key, str(1024 + index))
+        assert config.get(config_key) is not None
+        # Access the full config map to stress lock usage.
+        assert config_key in config.get_all()
+
+        batches = df.collect()
+        assert sum(batch.num_rows for batch in batches) == expected_rows
+
+    _run_in_threads(worker, count=12)
+
+
+def test_config_set_during_get_all() -> None:
+    """Ensure config writes proceed while another thread reads all entries."""
+
+    config = Config()
+    key = "datafusion.execution.batch_size"
+
+    def reader() -> None:
+        for _ in range(200):
+            # get_all should not hold the lock while converting to Python objects
+            config.get_all()
+
+    def writer() -> None:
+        for index in range(200):
+            config.set(key, str(1024 + index))
+
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        reader_future = executor.submit(reader)
+        writer_future = executor.submit(writer)
+        reader_future.result(timeout=10)
+        writer_future.result(timeout=10)
+
+    assert config.get(key) is not None
+
+
+def test_case_builder_reuse_from_multiple_threads() -> None:
+    """Ensure the case builder can be safely reused across threads."""
+
+    ctx = SessionContext()
+    values = pa.array([0, 1, 2, 3, 4], type=pa.int32())
+    df = ctx.create_dataframe([[pa.record_batch([values], names=["value"])]])
+
+    base_builder = f.case(col("value"))
+
+    def add_case(i: int) -> None:
+        nonlocal base_builder
+        base_builder = base_builder.when(lit(i), lit(f"value-{i}"))
+
+    _run_in_threads(add_case, count=8)
+
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        otherwise_future = executor.submit(base_builder.otherwise, lit("default"))
+        case_expr = otherwise_future.result()
+
+    result = df.select(case_expr.alias("label")).collect()
+    assert sum(batch.num_rows for batch in result) == len(values)
+
+    predicate_builder = f.when(col("value") == lit(0), lit("zero"))
+
+    def add_predicate(i: int) -> None:
+        predicate_builder.when(col("value") == lit(i + 1), lit(f"value-{i + 1}"))
+
+    _run_in_threads(add_predicate, count=4)
+
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        end_future = executor.submit(predicate_builder.end)
+        predicate_expr = end_future.result()
+
+    result = df.select(predicate_expr.alias("label")).collect()
+    assert sum(batch.num_rows for batch in result) == len(values)
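
Note: `test_config_set_during_get_all` guards the behaviour described in the commit message, where the read guard is released before values are converted to Python objects. The pattern itself can be sketched in pure Python. `SharedConfig` is a hypothetical class written for this note, a `threading.Lock` replaces the Rust `RwLock`, and `str(value)` stands in for the more expensive conversion.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class SharedConfig:
    """Hypothetical stand-in for a lock-protected configuration map."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._options = {"datafusion.execution.batch_size": "8192"}

    def set(self, key: str, value: str) -> None:
        with self._lock:
            self._options[key] = value

    def get_all(self) -> dict:
        # Hold the lock only long enough to copy the raw entries.
        with self._lock:
            snapshot = list(self._options.items())
        # The conversion step runs after the lock has been released,
        # so writers are not blocked while it executes.
        return {key: str(value) for key, value in snapshot}


config = SharedConfig()
key = "datafusion.execution.batch_size"


def reader() -> None:
    for _ in range(200):
        config.get_all()


def writer() -> None:
    for index in range(200):
        config.set(key, str(1024 + index))


with ThreadPoolExecutor(max_workers=2) as executor:
    reader_future = executor.submit(reader)
    writer_future = executor.submit(writer)
    reader_future.result(timeout=10)
    writer_future.result(timeout=10)

final_value = config.get_all()[key]
```

Copying under the lock and converting afterwards trades a small allocation for a shorter critical section, which is the trade-off the commit message describes.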

python/tests/test_expr.py

Lines changed: 93 additions & 0 deletions
@@ -16,6 +16,7 @@
 # under the License.

 import re
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timezone

 import pyarrow as pa
@@ -200,6 +201,98 @@ def traverse_logical_plan(plan):
     assert not variant.negated()


+def test_case_builder_error_preserves_builder_state():
+    case_builder = functions.when(lit(True), lit(1))
+
+    with pytest.raises(Exception) as exc_info:
+        _ = case_builder.otherwise(lit("bad"))
+
+    err_msg = str(exc_info.value)
+    assert "multiple data types" in err_msg
+    assert "CaseBuilder has already been consumed" not in err_msg
+
+    _ = case_builder.end()
+
+    err_msg = str(exc_info.value)
+    assert "multiple data types" in err_msg
+    assert "CaseBuilder has already been consumed" not in err_msg
+
+
+def test_case_builder_success_preserves_builder_state():
+    ctx = SessionContext()
+    df = ctx.from_pydict({"flag": [False]}, name="tbl")
+
+    case_builder = functions.when(col("flag"), lit("true"))
+
+    expr_default_one = case_builder.otherwise(lit("default-1")).alias("result")
+    result_one = df.select(expr_default_one).collect()
+    assert result_one[0].column(0).to_pylist() == ["default-1"]
+
+    expr_default_two = case_builder.otherwise(lit("default-2")).alias("result")
+    result_two = df.select(expr_default_two).collect()
+    assert result_two[0].column(0).to_pylist() == ["default-2"]
+
+    expr_end_one = case_builder.end().alias("result")
+    end_one = df.select(expr_end_one).collect()
+    assert end_one[0].column(0).to_pylist() == [None]
+
+
+def test_case_builder_when_handles_are_independent():
+    ctx = SessionContext()
+    df = ctx.from_pydict(
+        {
+            "flag": [True, False, False, False],
+            "value": [1, 15, 25, 5],
+        },
+        name="tbl",
+    )
+
+    base_builder = functions.when(col("flag"), lit("flag-true"))
+
+    first_builder = base_builder.when(col("value") > lit(10), lit("gt10"))
+    second_builder = base_builder.when(col("value") > lit(20), lit("gt20"))
+
+    first_builder = first_builder.when(lit(True), lit("final-one"))
+
+    expr_first = first_builder.otherwise(lit("fallback-one")).alias("first")
+    expr_second = second_builder.otherwise(lit("fallback-two")).alias("second")
+
+    result = df.select(expr_first, expr_second).collect()[0]
+
+    assert result.column(0).to_pylist() == [
+        "flag-true",
+        "gt10",
+        "gt10",
+        "final-one",
+    ]
+    assert result.column(1).to_pylist() == [
+        "flag-true",
+        "fallback-two",
+        "gt20",
+        "fallback-two",
+    ]
+
+
+def test_case_builder_when_thread_safe():
+    case_builder = functions.when(lit(True), lit(1))
+
+    def build_expr(value: int) -> bool:
+        builder = case_builder.when(lit(True), lit(value))
+        builder.otherwise(lit(value))
+        return True
+
+    with ThreadPoolExecutor(max_workers=8) as executor:
+        futures = [executor.submit(build_expr, idx) for idx in range(16)]
+        results = [future.result() for future in futures]
+
+    assert all(results)
+
+    # Ensure the shared builder remains usable after concurrent `when` calls.
+    follow_up_builder = case_builder.when(lit(True), lit(42))
+    assert isinstance(follow_up_builder, type(case_builder))
+    follow_up_builder.otherwise(lit(7))
+
+
 def test_expr_getitem() -> None:
     ctx = SessionContext()
     data = {
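
Note: the CaseBuilder tests in this hunk rely on the approach named in the commit message, where the builder keeps its expressions and applies them only when `otherwise` or `end` is called. The Rust implementation is not shown on this page, so the sketch below is only a pure-Python model of that behaviour. `CaseSketch` is a hypothetical class, and predicates are plain callables over a row dict rather than datafusion expressions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple

Predicate = Callable[[dict], bool]


@dataclass(frozen=True)
class CaseSketch:
    """Hypothetical immutable builder: stores branches, builds on demand."""

    branches: Tuple[Tuple[Predicate, Any], ...] = ()

    def when(self, predicate: Predicate, value: Any) -> "CaseSketch":
        # Return a new handle; the receiver is left untouched.
        return CaseSketch(self.branches + ((predicate, value),))

    def otherwise(self, default: Any) -> Callable[[dict], Any]:
        branches = self.branches  # read-only snapshot; nothing is consumed

        def evaluate(row: dict) -> Any:
            for predicate, value in branches:
                if predicate(row):
                    return value
            return default

        return evaluate

    def end(self) -> Callable[[dict], Any]:
        # No default branch: unmatched rows evaluate to None.
        return self.otherwise(None)


rows = [
    {"flag": True, "value": 1},
    {"flag": False, "value": 15},
    {"flag": False, "value": 25},
    {"flag": False, "value": 5},
]

base = CaseSketch().when(lambda r: r["flag"], "flag-true")
first = base.when(lambda r: r["value"] > 10, "gt10").when(lambda r: True, "final-one")
second = base.when(lambda r: r["value"] > 20, "gt20")

first_labels = [first.otherwise("fallback-one")(row) for row in rows]
second_labels = [second.otherwise("fallback-two")(row) for row in rows]
base_labels = [base.end()(row) for row in rows]
```

Because every `when` call returns a fresh handle and `otherwise` only reads the stored branches, `base` can be shared across threads and reused after `first` and `second` are derived from it, which is the property the tests above assert.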
