Skip to content

Commit d130b9c

Browse files
authored
Optional check for concurrent usage errors (#989)
Optional check for concurrent usage errors Some driver objects (e.g, Sessions, Transactions, Result streams) are not safe for concurrent use. By default, it will cause hard to interpret errors or, in the worst case, wrong behavior. To aid finding such bugs, the driver now detects if the Python interpreter is running development mode and enables extra locking around those objects. If they are used concurrently, an error will be raised. The way this is implemented, it will only cause a one-time overhead when loading the driver's modules if the checks are disabled. Obviously, those checks are somewhat expensive as they entail locks (less so in the async driver). Therefore, the checks are only happening if either * Python is started in development mode (`python -X dev ...`) or * The environment variable `PYTHONNEO4JDEBUG` is set (to anything non-empty) at the time the driver's modules is loaded.
1 parent f722f65 commit d130b9c

File tree

15 files changed

+467
-24
lines changed

15 files changed

+467
-24
lines changed

docs/source/index.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,25 @@ To deactivate the current active virtual environment, use:
9999
deactivate
100100
101101
102+
Development Environment
103+
=======================
104+
105+
For development, we recommend to run Python in `development mode`_ (``python -X dev ...``).
106+
Specifically for this driver, this will:
107+
108+
* enable :class:`ResourceWarning`, which the driver emits if resources (e.g., Sessions) aren't properly closed.
109+
* enable :class:`DeprecationWarning`, which the driver emits if deprecated APIs are used.
110+
* enable the driver's debug mode (this can also be achieved by setting the environment variable ``PYTHONNEO4JDEBUG``):
111+
112+
* **This is experimental**.
113+
It might be changed or removed any time even without prior notice.
114+
* the driver will raise an exception if non-concurrency-safe methods are used concurrently.
115+
116+
.. versionadded:: 5.15
117+
118+
.. _development mode: https://docs.python.org/3/library/devmode.html
119+
120+
102121
*************
103122
Quick Example
104123
*************

docs/source/themes/neo4j/static/css/neo4j.css_t

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -503,25 +503,16 @@ dl.field-list > dd > ol {
503503
margin-left: 0;
504504
}
505505

506-
ol.simple p, ul.simple p {
507-
margin-bottom: 0;
508-
}
509-
510-
ol.simple > li:not(:first-child) > p,
511-
ul.simple > li:not(:first-child) > p,
512-
:not(li) > ol > li:first-child > :first-child,
513-
:not(li) > ul > li:first-child > :first-child {
506+
.content ol li > p:first-of-type,
507+
.content ul li > p:first-of-type {
514508
margin-top: 0;
515509
}
516510

517-
518-
li > p:last-child {
519-
margin-top: 10px;
511+
.content ol li > p:last-of-type,
512+
.content ul li > p:last-of-type {
513+
margin-bottom: 0;
520514
}
521515

522-
li > p:first-child {
523-
margin-top: 10px;
524-
}
525516

526517
table.docutils {
527518
margin-top: 10px;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from ._concurrency_check import AsyncNonConcurrentMethodChecker
18+
19+
20+
__all__ = ["AsyncNonConcurrentMethodChecker"]
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
from __future__ import annotations
18+
19+
import inspect
20+
import os
21+
import sys
22+
import traceback
23+
import typing as t
24+
from copy import deepcopy
25+
from functools import wraps
26+
27+
from ..._async_compat.concurrency import (
28+
AsyncLock,
29+
AsyncRLock,
30+
)
31+
from ..._async_compat.util import AsyncUtil
32+
from ..._meta import copy_signature
33+
34+
35+
_TWrapped = t.TypeVar("_TWrapped", bound=t.Callable[..., t.Awaitable[t.Any]])
36+
_TWrappedIter = t.TypeVar("_TWrappedIter",
37+
bound=t.Callable[..., t.AsyncIterator])
38+
39+
40+
ENABLED = sys.flags.dev_mode or bool(os.getenv("PYTHONNEO4JDEBUG"))
41+
42+
43+
class NonConcurrentMethodError(RuntimeError):
44+
pass
45+
46+
47+
class AsyncNonConcurrentMethodChecker:
48+
if ENABLED:
49+
50+
def __init__(self):
51+
self.__lock = AsyncRLock()
52+
self.__tracebacks_lock = AsyncLock()
53+
self.__tracebacks = []
54+
55+
def __make_error(self, tbs):
56+
msg = (f"Methods of {self.__class__} are not concurrency "
57+
"safe, but were invoked concurrently.")
58+
if tbs:
59+
msg += ("\n\nOther invocation site:\n\n"
60+
f"{''.join(traceback.format_list(tbs[0]))}")
61+
return NonConcurrentMethodError(msg)
62+
63+
@classmethod
64+
def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped:
65+
if AsyncUtil.is_async_code:
66+
if not inspect.iscoroutinefunction(f):
67+
raise TypeError(
68+
"cannot decorate non-coroutine function with "
69+
"AsyncNonConcurrentMethodChecked.non_concurrent_method"
70+
)
71+
else:
72+
if not callable(f):
73+
raise TypeError(
74+
"cannot decorate non-callable object with "
75+
"NonConcurrentMethodChecked.non_concurrent_method"
76+
)
77+
78+
@wraps(f)
79+
@copy_signature(f)
80+
async def inner(*args, **kwargs):
81+
self = args[0]
82+
assert isinstance(self, cls)
83+
84+
async with self.__tracebacks_lock:
85+
acquired = await self.__lock.acquire(blocking=False)
86+
if acquired:
87+
self.__tracebacks.append(AsyncUtil.extract_stack())
88+
else:
89+
tbs = deepcopy(self.__tracebacks)
90+
if acquired:
91+
try:
92+
return await f(*args, **kwargs)
93+
finally:
94+
async with self.__tracebacks_lock:
95+
self.__tracebacks.pop()
96+
self.__lock.release()
97+
else:
98+
raise self.__make_error(tbs)
99+
100+
return inner
101+
102+
@classmethod
103+
def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter:
104+
if AsyncUtil.is_async_code:
105+
if not inspect.isasyncgenfunction(f):
106+
raise TypeError(
107+
"cannot decorate non-async-generator function with "
108+
"AsyncNonConcurrentMethodChecked.non_concurrent_iter"
109+
)
110+
else:
111+
if not inspect.isgeneratorfunction(f):
112+
raise TypeError(
113+
"cannot decorate non-generator function with "
114+
"NonConcurrentMethodChecked.non_concurrent_iter"
115+
)
116+
117+
@wraps(f)
118+
@copy_signature(f)
119+
async def inner(*args, **kwargs):
120+
self = args[0]
121+
assert isinstance(self, cls)
122+
123+
iter_ = f(*args, **kwargs)
124+
while True:
125+
async with self.__tracebacks_lock:
126+
acquired = await self.__lock.acquire(blocking=False)
127+
if acquired:
128+
self.__tracebacks.append(AsyncUtil.extract_stack())
129+
else:
130+
tbs = deepcopy(self.__tracebacks)
131+
if acquired:
132+
try:
133+
item = await iter_.__anext__()
134+
finally:
135+
async with self.__tracebacks_lock:
136+
self.__tracebacks.pop()
137+
self.__lock.release()
138+
yield item
139+
else:
140+
raise self.__make_error(tbs)
141+
142+
return inner
143+
144+
else:
145+
146+
@classmethod
147+
def non_concurrent_method(cls, f: _TWrapped) -> _TWrapped:
148+
return f
149+
150+
@classmethod
151+
def non_concurrent_iter(cls, f: _TWrappedIter) -> _TWrappedIter:
152+
return f

src/neo4j/_async/work/result.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
Date,
4444
DateTime,
4545
)
46+
from .._debug import AsyncNonConcurrentMethodChecker
4647
from ..io import ConnectionErrorHandler
4748

4849

@@ -71,7 +72,7 @@
7172
)
7273

7374

74-
class AsyncResult:
75+
class AsyncResult(AsyncNonConcurrentMethodChecker):
7576
"""Handler for the result of Cypher query execution.
7677
7778
Instances of this class are typically constructed and returned by
@@ -109,6 +110,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error):
109110
self._out_of_scope = False
110111
# exception shared across all results of a transaction
111112
self._exception = None
113+
super().__init__()
112114

113115
async def _connection_error_handler(self, exc):
114116
self._exception = exc
@@ -251,11 +253,15 @@ def on_success(summary_metadata):
251253
)
252254
self._streaming = True
253255

256+
@AsyncNonConcurrentMethodChecker.non_concurrent_iter
254257
async def __aiter__(self) -> t.AsyncIterator[Record]:
255258
"""Iterator returning Records.
256259
257-
:returns: Record, it is an immutable ordered collection of key-value pairs.
258-
:rtype: :class:`neo4j.Record`
260+
Advancing the iterator advances the underlying result stream.
261+
So even when creating multiple iterators from the same result, each
262+
Record will only be returned once.
263+
264+
:returns: Iterator over the result stream's records.
259265
"""
260266
while self._record_buffer or self._attached:
261267
if self._record_buffer:
@@ -278,7 +284,9 @@ async def __aiter__(self) -> t.AsyncIterator[Record]:
278284
if self._consumed:
279285
raise ResultConsumedError(self, _RESULT_CONSUMED_ERROR)
280286

287+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
281288
async def __anext__(self) -> Record:
289+
"""Advance the result stream and return the record."""
282290
return await self.__aiter__().__anext__()
283291

284292
async def _attach(self):
@@ -367,6 +375,7 @@ def _tx_failure(self, exc):
367375
self._attached = False
368376
self._exception = exc
369377

378+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
370379
async def consume(self) -> ResultSummary:
371380
"""Consume the remainder of this result and return a :class:`neo4j.ResultSummary`.
372381
@@ -434,6 +443,7 @@ async def single(
434443
async def single(self, strict: te.Literal[True]) -> Record:
435444
...
436445

446+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
437447
async def single(self, strict: bool = False) -> t.Optional[Record]:
438448
"""Obtain the next and only remaining record or None.
439449
@@ -495,6 +505,7 @@ async def single(self, strict: bool = False) -> t.Optional[Record]:
495505
)
496506
return buffer.popleft()
497507

508+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
498509
async def fetch(self, n: int) -> t.List[Record]:
499510
"""Obtain up to n records from this result.
500511
@@ -517,6 +528,7 @@ async def fetch(self, n: int) -> t.List[Record]:
517528
for _ in range(min(n, len(self._record_buffer)))
518529
]
519530

531+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
520532
async def peek(self) -> t.Optional[Record]:
521533
"""Obtain the next record from this result without consuming it.
522534
@@ -537,6 +549,7 @@ async def peek(self) -> t.Optional[Record]:
537549
return self._record_buffer[0]
538550
return None
539551

552+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
540553
async def graph(self) -> Graph:
541554
"""Turn the result into a :class:`neo4j.Graph`.
542555
@@ -559,6 +572,7 @@ async def graph(self) -> Graph:
559572
await self._buffer_all()
560573
return self._hydration_scope.get_graph()
561574

575+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
562576
async def value(
563577
self, key: _TResultKey = 0, default: t.Optional[object] = None
564578
) -> t.List[t.Any]:
@@ -580,6 +594,7 @@ async def value(
580594
"""
581595
return [record.value(key, default) async for record in self]
582596

597+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
583598
async def values(
584599
self, *keys: _TResultKey
585600
) -> t.List[t.List[t.Any]]:
@@ -600,6 +615,7 @@ async def values(
600615
"""
601616
return [record.values(*keys) async for record in self]
602617

618+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
603619
async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]:
604620
"""Return the remainder of the result as a list of dictionaries.
605621
@@ -626,6 +642,7 @@ async def data(self, *keys: _TResultKey) -> t.List[t.Dict[str, t.Any]]:
626642
"""
627643
return [record.data(*keys) async for record in self]
628644

645+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
629646
async def to_eager_result(self) -> EagerResult:
630647
"""Convert this result to an :class:`.EagerResult`.
631648
@@ -650,6 +667,7 @@ async def to_eager_result(self) -> EagerResult:
650667
summary=await self.consume()
651668
)
652669

670+
@AsyncNonConcurrentMethodChecker.non_concurrent_method
653671
async def to_df(
654672
self,
655673
expand: bool = False,

0 commit comments

Comments
 (0)