Skip to content

Commit 656c796

Browse files
authored
Accept transaction config for execute_query (#991)
`Driver.execute_query` now accepts a `Query` object to specify transaction config like metadata and transaction timeout. Example: ```python from neo4j import ( GraphDatabase, Query, ) with GraphDatabase.driver(...) as driver: driver.execute_query( Query( "MATCH (n) RETURN n", # metadata to be logged with the transaction metadata={"foo": "bar"}, # give the transaction 5 seconds to complete on the DBMS timeout=5, ), # all the other configuration options as before database_="neo4j", # ... ) ```
1 parent 17c6097 commit 656c796

File tree

9 files changed

+178
-49
lines changed

9 files changed

+178
-49
lines changed

docs/source/api.rst

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,9 @@ Closing a driver will immediately shut down all connections in the pool.
187187
query_, parameters_, routing_, database_, impersonated_user_,
188188
bookmark_manager_, auth_, result_transformer_, **kwargs
189189
):
190+
@unit_of_work(query_.metadata, query_.timeout)
190191
def work(tx):
191-
result = tx.run(query_, parameters_, **kwargs)
192+
result = tx.run(query_.text, parameters_, **kwargs)
192193
return result_transformer_(result)
193194

194195
with driver.session(
@@ -245,16 +246,19 @@ Closing a driver will immediately shut down all connections in the pool.
245246
assert isinstance(count, int)
246247
return count
247248

248-
:param query_: cypher query to execute
249-
:type query_: typing.LiteralString
249+
:param query_:
250+
Cypher query to execute.
251+
Use a :class:`.Query` object to pass a query with additional
252+
transaction configuration.
253+
:type query_: typing.LiteralString | Query
250254
:param parameters_: parameters to use in the query
251255
:type parameters_: typing.Dict[str, typing.Any] | None
252256
:param routing_:
253-
whether to route the query to a reader (follower/read replica) or
257+
Whether to route the query to a reader (follower/read replica) or
254258
a writer (leader) in the cluster. Default is to route to a writer.
255259
:type routing_: RoutingControl
256260
:param database_:
257-
database to execute the query against.
261+
Database to execute the query against.
258262

259263
None (default) uses the database configured on the server side.
260264

@@ -375,6 +379,10 @@ Closing a driver will immediately shut down all connections in the pool.
375379
.. versionchanged:: 5.14
376380
Stabilized ``auth_`` parameter from preview.
377381

382+
.. versionchanged:: 5.15
383+
The ``query_`` parameter now also accepts a :class:`.Query` object
384+
instead of only :class:`str`.
385+
378386

379387
.. _driver-configuration-ref:
380388

docs/source/async_api.rst

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ Closing a driver will immediately shut down all connections in the pool.
174174
query_, parameters_, routing_, database_, impersonated_user_,
175175
bookmark_manager_, auth_, result_transformer_, **kwargs
176176
):
177+
@unit_of_work(query_.metadata, query_.timeout)
177178
async def work(tx):
178-
result = await tx.run(query_, parameters_, **kwargs)
179+
result = await tx.run(query_.text, parameters_, **kwargs)
179180
return await result_transformer_(result)
180181

181182
async with driver.session(
@@ -232,16 +233,19 @@ Closing a driver will immediately shut down all connections in the pool.
232233
assert isinstance(count, int)
233234
return count
234235

235-
:param query_: cypher query to execute
236-
:type query_: typing.LiteralString
236+
:param query_:
237+
Cypher query to execute.
238+
Use a :class:`.Query` object to pass a query with additional
239+
transaction configuration.
240+
:type query_: typing.LiteralString | Query
237241
:param parameters_: parameters to use in the query
238242
:type parameters_: typing.Dict[str, typing.Any] | None
239243
:param routing_:
240-
whether to route the query to a reader (follower/read replica) or
244+
Whether to route the query to a reader (follower/read replica) or
241245
a writer (leader) in the cluster. Default is to route to a writer.
242246
:type routing_: RoutingControl
243247
:param database_:
244-
database to execute the query against.
248+
Database to execute the query against.
245249

246250
None (default) uses the database configured on the server side.
247251

@@ -362,6 +366,10 @@ Closing a driver will immediately shut down all connections in the pool.
362366
.. versionchanged:: 5.14
363367
Stabilized ``auth_`` parameter from preview.
364368

369+
.. versionchanged:: 5.15
370+
The ``query_`` parameter now also accepts a :class:`.Query` object
371+
instead of only :class:`str`.
372+
365373

366374
.. _async-driver-configuration-ref:
367375

src/neo4j/_async/driver.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@
4747
experimental_warn,
4848
unclosed_resource_warn,
4949
)
50-
from .._work import EagerResult
50+
from .._work import (
51+
EagerResult,
52+
Query,
53+
unit_of_work,
54+
)
5155
from ..addressing import Address
5256
from ..api import (
5357
AsyncBookmarkManager,
@@ -581,7 +585,7 @@ async def close(self) -> None:
581585
@t.overload
582586
async def execute_query(
583587
self,
584-
query_: te.LiteralString,
588+
query_: t.Union[te.LiteralString, Query],
585589
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
586590
routing_: T_RoutingControl = RoutingControl.WRITE,
587591
database_: t.Optional[str] = None,
@@ -600,7 +604,7 @@ async def execute_query(
600604
@t.overload
601605
async def execute_query(
602606
self,
603-
query_: te.LiteralString,
607+
query_: t.Union[te.LiteralString, Query],
604608
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
605609
routing_: T_RoutingControl = RoutingControl.WRITE,
606610
database_: t.Optional[str] = None,
@@ -618,7 +622,7 @@ async def execute_query(
618622

619623
async def execute_query(
620624
self,
621-
query_: te.LiteralString,
625+
query_: t.Union[te.LiteralString, Query],
622626
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
623627
routing_: T_RoutingControl = RoutingControl.WRITE,
624628
database_: t.Optional[str] = None,
@@ -651,8 +655,9 @@ async def execute_query(
651655
query_, parameters_, routing_, database_, impersonated_user_,
652656
bookmark_manager_, auth_, result_transformer_, **kwargs
653657
):
658+
@unit_of_work(query_.metadata, query_.timeout)
654659
async def work(tx):
655-
result = await tx.run(query_, parameters_, **kwargs)
660+
result = await tx.run(query_.text, parameters_, **kwargs)
656661
return await result_transformer_(result)
657662
658663
async with driver.session(
@@ -709,16 +714,19 @@ async def example(driver: neo4j.AsyncDriver) -> int:
709714
assert isinstance(count, int)
710715
return count
711716
712-
:param query_: cypher query to execute
713-
:type query_: typing.LiteralString
717+
:param query_:
718+
Cypher query to execute.
719+
Use a :class:`.Query` object to pass a query with additional
720+
transaction configuration.
721+
:type query_: typing.LiteralString | Query
714722
:param parameters_: parameters to use in the query
715723
:type parameters_: typing.Optional[typing.Dict[str, typing.Any]]
716724
:param routing_:
717-
whether to route the query to a reader (follower/read replica) or
725+
Whether to route the query to a reader (follower/read replica) or
718726
a writer (leader) in the cluster. Default is to route to a writer.
719727
:type routing_: RoutingControl
720728
:param database_:
721-
database to execute the query against.
729+
Database to execute the query against.
722730
723731
None (default) uses the database configured on the server side.
724732
@@ -838,6 +846,10 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
838846
839847
.. versionchanged:: 5.14
840848
Stabilized ``auth_`` parameter from preview.
849+
850+
.. versionchanged:: 5.15
851+
The ``query_`` parameter now also accepts a :class:`.Query` object
852+
instead of only :class:`str`.
841853
"""
842854
self._check_state()
843855
invalid_kwargs = [k for k in kwargs if
@@ -850,6 +862,14 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
850862
"latter case, use the `parameters_` dictionary instead."
851863
% invalid_kwargs
852864
)
865+
if isinstance(query_, Query):
866+
timeout = query_.timeout
867+
metadata = query_.metadata
868+
query_str = query_.text
869+
work = unit_of_work(metadata, timeout)(_work)
870+
else:
871+
query_str = query_
872+
work = _work
853873
parameters = dict(parameters_ or {}, **kwargs)
854874

855875
if bookmark_manager_ is _default:
@@ -876,7 +896,7 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record::
876896
with session._pipelined_begin:
877897
return await session._run_transaction(
878898
access_mode, TelemetryAPI.DRIVER,
879-
_work, (query_, parameters, result_transformer_), {}
899+
work, (query_str, parameters, result_transformer_), {}
880900
)
881901

882902
@property
@@ -1195,7 +1215,7 @@ async def _get_server_info(self, session_config) -> ServerInfo:
11951215

11961216
async def _work(
11971217
tx: AsyncManagedTransaction,
1198-
query: str,
1218+
query: te.LiteralString,
11991219
parameters: t.Dict[str, t.Any],
12001220
transformer: t.Callable[[AsyncResult], t.Awaitable[_T]]
12011221
) -> _T:

src/neo4j/_sync/driver.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,11 @@
4747
experimental_warn,
4848
unclosed_resource_warn,
4949
)
50-
from .._work import EagerResult
50+
from .._work import (
51+
EagerResult,
52+
Query,
53+
unit_of_work,
54+
)
5155
from ..addressing import Address
5256
from ..api import (
5357
Auth,
@@ -580,7 +584,7 @@ def close(self) -> None:
580584
@t.overload
581585
def execute_query(
582586
self,
583-
query_: te.LiteralString,
587+
query_: t.Union[te.LiteralString, Query],
584588
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
585589
routing_: T_RoutingControl = RoutingControl.WRITE,
586590
database_: t.Optional[str] = None,
@@ -599,7 +603,7 @@ def execute_query(
599603
@t.overload
600604
def execute_query(
601605
self,
602-
query_: te.LiteralString,
606+
query_: t.Union[te.LiteralString, Query],
603607
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
604608
routing_: T_RoutingControl = RoutingControl.WRITE,
605609
database_: t.Optional[str] = None,
@@ -617,7 +621,7 @@ def execute_query(
617621

618622
def execute_query(
619623
self,
620-
query_: te.LiteralString,
624+
query_: t.Union[te.LiteralString, Query],
621625
parameters_: t.Optional[t.Dict[str, t.Any]] = None,
622626
routing_: T_RoutingControl = RoutingControl.WRITE,
623627
database_: t.Optional[str] = None,
@@ -650,8 +654,9 @@ def execute_query(
650654
query_, parameters_, routing_, database_, impersonated_user_,
651655
bookmark_manager_, auth_, result_transformer_, **kwargs
652656
):
657+
@unit_of_work(query_.metadata, query_.timeout)
653658
def work(tx):
654-
result = tx.run(query_, parameters_, **kwargs)
659+
result = tx.run(query_.text, parameters_, **kwargs)
655660
return result_transformer_(result)
656661
657662
with driver.session(
@@ -708,16 +713,19 @@ def example(driver: neo4j.Driver) -> int:
708713
assert isinstance(count, int)
709714
return count
710715
711-
:param query_: cypher query to execute
712-
:type query_: typing.LiteralString
716+
:param query_:
717+
Cypher query to execute.
718+
Use a :class:`.Query` object to pass a query with additional
719+
transaction configuration.
720+
:type query_: typing.LiteralString | Query
713721
:param parameters_: parameters to use in the query
714722
:type parameters_: typing.Optional[typing.Dict[str, typing.Any]]
715723
:param routing_:
716-
whether to route the query to a reader (follower/read replica) or
724+
Whether to route the query to a reader (follower/read replica) or
717725
a writer (leader) in the cluster. Default is to route to a writer.
718726
:type routing_: RoutingControl
719727
:param database_:
720-
database to execute the query against.
728+
Database to execute the query against.
721729
722730
None (default) uses the database configured on the server side.
723731
@@ -837,6 +845,10 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
837845
838846
.. versionchanged:: 5.14
839847
Stabilized ``auth_`` parameter from preview.
848+
849+
.. versionchanged:: 5.15
850+
The ``query_`` parameter now also accepts a :class:`.Query` object
851+
instead of only :class:`str`.
840852
"""
841853
self._check_state()
842854
invalid_kwargs = [k for k in kwargs if
@@ -849,6 +861,14 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
849861
"latter case, use the `parameters_` dictionary instead."
850862
% invalid_kwargs
851863
)
864+
if isinstance(query_, Query):
865+
timeout = query_.timeout
866+
metadata = query_.metadata
867+
query_str = query_.text
868+
work = unit_of_work(metadata, timeout)(_work)
869+
else:
870+
query_str = query_
871+
work = _work
852872
parameters = dict(parameters_ or {}, **kwargs)
853873

854874
if bookmark_manager_ is _default:
@@ -875,7 +895,7 @@ def example(driver: neo4j.Driver) -> neo4j.Record::
875895
with session._pipelined_begin:
876896
return session._run_transaction(
877897
access_mode, TelemetryAPI.DRIVER,
878-
_work, (query_, parameters, result_transformer_), {}
898+
work, (query_str, parameters, result_transformer_), {}
879899
)
880900

881901
@property
@@ -1194,7 +1214,7 @@ def _get_server_info(self, session_config) -> ServerInfo:
11941214

11951215
def _work(
11961216
tx: ManagedTransaction,
1197-
query: str,
1217+
query: te.LiteralString,
11981218
parameters: t.Dict[str, t.Any],
11991219
transformer: t.Callable[[Result], t.Union[_T]]
12001220
) -> _T:

src/neo4j/_work/query.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ class Query:
2929
"""A query with attached extra data.
3030
3131
This wrapper class for queries is used to attach extra data to queries
32-
passed to :meth:`.Session.run` and :meth:`.AsyncSession.run`, fulfilling
33-
a similar role as :func:`.unit_of_work` for transactions functions.
32+
passed to :meth:`.Session.run`/:meth:`.AsyncSession.run` and
33+
:meth:`.Driver.execute_query`/:meth:`.AsyncDriver.execute_query`,
34+
fulfilling a similar role as :func:`.unit_of_work` for transactions
35+
functions.
3436
3537
:param text: The query text.
3638
:type text: typing.LiteralString
@@ -74,7 +76,12 @@ def __init__(
7476
self.timeout = timeout
7577

7678
def __str__(self) -> te.LiteralString:
77-
return str(self.text)
79+
# we know that if Query is constructed with a LiteralString,
80+
# str(self.text) will be a LiteralString as well. The conversion isn't
81+
# necessary if the user adheres to the type hints. However, it was
82+
# here before, and we don't want to break backwards compatibility.
83+
text: te.LiteralString = str(self.text) # type: ignore[assignment]
84+
return text
7885

7986

8087
def unit_of_work(

testkitbackend/_async/requests.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,11 @@ async def ExecuteQuery(backend, data):
363363
value = config.get(config_key, None)
364364
if value is not None:
365365
kwargs[kwargs_key] = value
366+
tx_kwargs = fromtestkit.to_tx_kwargs(config)
367+
if tx_kwargs:
368+
query = neo4j.Query(cypher, **tx_kwargs)
369+
else:
370+
query = cypher
366371
bookmark_manager_id = config.get("bookmarkManagerId")
367372
if bookmark_manager_id is not None:
368373
if bookmark_manager_id == -1:
@@ -371,7 +376,7 @@ async def ExecuteQuery(backend, data):
371376
bookmark_manager = backend.bookmark_managers[bookmark_manager_id]
372377
kwargs["bookmark_manager_"] = bookmark_manager
373378

374-
eager_result = await driver.execute_query(cypher, params, **kwargs)
379+
eager_result = await driver.execute_query(query, params, **kwargs)
375380
await backend.send_response("EagerResult", {
376381
"keys": eager_result.keys,
377382
"records": list(map(totestkit.record, eager_result.records)),

testkitbackend/_sync/requests.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,11 @@ def ExecuteQuery(backend, data):
363363
value = config.get(config_key, None)
364364
if value is not None:
365365
kwargs[kwargs_key] = value
366+
tx_kwargs = fromtestkit.to_tx_kwargs(config)
367+
if tx_kwargs:
368+
query = neo4j.Query(cypher, **tx_kwargs)
369+
else:
370+
query = cypher
366371
bookmark_manager_id = config.get("bookmarkManagerId")
367372
if bookmark_manager_id is not None:
368373
if bookmark_manager_id == -1:
@@ -371,7 +376,7 @@ def ExecuteQuery(backend, data):
371376
bookmark_manager = backend.bookmark_managers[bookmark_manager_id]
372377
kwargs["bookmark_manager_"] = bookmark_manager
373378

374-
eager_result = driver.execute_query(cypher, params, **kwargs)
379+
eager_result = driver.execute_query(query, params, **kwargs)
375380
backend.send_response("EagerResult", {
376381
"keys": eager_result.keys,
377382
"records": list(map(totestkit.record, eager_result.records)),

0 commit comments

Comments
 (0)