Skip to content

Commit 59252c6

Browse files
authored
Merge pull request #911 from FlorentinD/external-run-cypher-retryable
Expose mode and retryable on user-facing run_cypher
2 parents 3551f40 + 4fd2ec0 commit 59252c6

18 files changed

+165
-25
lines changed

changelog.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77
## Bug fixes
88

9-
* Fix reporting error based on http responses from the Aura-API with an invalid JSON body. Earlier the client would report JSONDecodeError instead of showing the actual issue.
9+
- Fix reporting error based on http responses from the Aura-API with an invalid JSON body. Earlier the client would report JSONDecodeError instead of showing the actual issue.
1010

1111
## Improvements
1212

13+
- `GraphDataScience::run_query` now supports setting the `mode` of the query to be used for routing. Previously queries would always route the leader of the cluster, assuming write mode.
14+
- `GraphDataScience::run_query` now support setting `retryable` to enable a retry-mechanism for appropriate errors. This requires `neo4j>=5.5.0`.
15+
16+
1317
## Other changes

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from pandas import Series
88

9+
from graphdatascience.query_runner.query_mode import QueryMode
10+
911
from ..caller_base import CallerBase
1012
from ..query_runner.query_runner import QueryRunner
1113
from ..server_version.server_version import ServerVersion
@@ -46,7 +48,7 @@ def project(
4648
GraphCypherRunner._verify_query_ends_with_return_clause(self._namespace, query)
4749

4850
result: Optional[dict[str, Any]] = self._query_runner.run_retryable_cypher(
49-
query, params, database, custom_error=False
51+
query, params, database, custom_error=False, mode=QueryMode.READ
5052
).squeeze()
5153

5254
if not result:

graphdatascience/graph/graph_entity_ops_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pandas as pd
66
from pandas import DataFrame, Series
77

8+
from graphdatascience.query_runner.query_mode import QueryMode
9+
810
from ..call_parameters import CallParameters
911
from ..error.cypher_warning_handler import (
1012
filter_id_func_deprecation_warning,
@@ -164,7 +166,9 @@ def _process_result(
164166
unique_node_ids = result["nodeId"].drop_duplicates().tolist()
165167

166168
db_properties_df = query_runner.run_retryable_cypher(
167-
GraphNodePropertiesRunner._build_query(db_node_properties), params={"ids": unique_node_ids}
169+
GraphNodePropertiesRunner._build_query(db_node_properties),
170+
params={"ids": unique_node_ids},
171+
mode=QueryMode.READ,
168172
)
169173

170174
if "propertyValue" not in result.keys():

graphdatascience/graph_data_science.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas import DataFrame
1010

1111
from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
12+
from graphdatascience.query_runner.query_mode import QueryMode
1213

1314
from .call_builder import IndirectCallBuilder
1415
from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
@@ -199,7 +200,12 @@ def last_bookmarks(self) -> Optional[Any]:
199200
return self._query_runner.last_bookmarks()
200201

201202
def run_cypher(
202-
self, query: str, params: Optional[dict[str, Any]] = None, database: Optional[str] = None
203+
self,
204+
query: str,
205+
params: Optional[dict[str, Any]] = None,
206+
database: Optional[str] = None,
207+
retryable: bool = False,
208+
mode: QueryMode = QueryMode.WRITE,
203209
) -> DataFrame:
204210
"""
205211
Run a Cypher query
@@ -212,6 +218,10 @@ def run_cypher(
212218
parameters to the query
213219
database: str
214220
the database on which to run the query
221+
retryable: bool
222+
whether the query can be automatically retried. Make sure the query is idempotent if set to True.
223+
mode: QueryMode
224+
the query mode to use (READ or WRITE). Set based on the operation performed in the query.
215225
216226
Returns:
217227
The query result as a DataFrame
@@ -222,8 +232,10 @@ def run_cypher(
222232
if isinstance(self._query_runner, ArrowQueryRunner):
223233
qr = self._query_runner.fallback_query_runner()
224234

225-
# not using qr.run_retryable_cypher as we dont know if it can be retried
226-
return qr.run_cypher(query, params, database, False)
235+
if retryable:
236+
return qr.run_retryable_cypher(query, params, database, custom_error=False, mode=mode)
237+
else:
238+
return qr.run_cypher(query, params, database, custom_error=False, mode=mode)
227239

228240
def driver_config(self) -> dict[str, Any]:
229241
"""

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,22 @@ def run_cypher(
6565
query: str,
6666
params: Optional[dict[str, Any]] = None,
6767
database: Optional[str] = None,
68+
mode: Optional[QueryMode] = None,
6869
custom_error: bool = True,
6970
) -> DataFrame:
70-
return self._fallback_query_runner.run_cypher(query, params, database, custom_error)
71+
return self._fallback_query_runner.run_cypher(query, params, database, mode, custom_error=custom_error)
7172

7273
def run_retryable_cypher(
7374
self,
7475
query: str,
7576
params: Optional[dict[str, Any]] = None,
7677
database: Optional[str] = None,
78+
mode: Optional[QueryMode] = None,
7779
custom_error: bool = True,
7880
) -> DataFrame:
79-
return self._fallback_query_runner.run_retryable_cypher(query, params, database, custom_error=custom_error)
81+
return self._fallback_query_runner.run_retryable_cypher(
82+
query, params, database, mode, custom_error=custom_error
83+
)
8084

8185
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
8286
return self._fallback_query_runner.call_function(endpoint, params)

graphdatascience/query_runner/cypher_graph_constructor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from pandas import DataFrame, concat
99

10+
from graphdatascience.query_runner.query_mode import QueryMode
11+
1012
from ..server_version.server_version import ServerVersion
1113
from .graph_constructor import GraphConstructor
1214
from .query_runner import QueryRunner
@@ -105,7 +107,9 @@ def graph_construct_error_multidf(element: str) -> str:
105107
def _should_warn_about_arrow_missing(self) -> bool:
106108
try:
107109
license: str = self._query_runner.run_retryable_cypher(
108-
"CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value", custom_error=False
110+
"CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value",
111+
custom_error=False,
112+
mode=QueryMode.READ,
109113
).squeeze()
110114
should_warn = license == "Licensed"
111115
except Exception as e:

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,20 +161,28 @@ def run_cypher(
161161
query: str,
162162
params: Optional[dict[str, Any]] = None,
163163
database: Optional[str] = None,
164+
mode: Optional[QueryMode] = None,
164165
custom_error: bool = True,
165166
connectivity_retry_config: Optional[ConnectivityRetriesConfig] = None,
166167
) -> DataFrame:
167168
if params is None:
168169
params = {}
169170

171+
if mode is None:
172+
mode = QueryMode.WRITE
173+
170174
if database is None:
171175
database = self._database
172176

173177
if connectivity_retry_config is None:
174178
connectivity_retry_config = Neo4jQueryRunner.ConnectivityRetriesConfig()
175179
self._verify_connectivity(database=database, retry_config=connectivity_retry_config)
176180

177-
with self._driver.session(database=database, bookmarks=self.bookmarks()) as session:
181+
with self._driver.session(
182+
database=database,
183+
bookmarks=self.bookmarks(),
184+
default_access_mode=mode.neo4j_access_mode(),
185+
) as session:
178186
try:
179187
result = session.run(query, params)
180188
except Exception as e:
@@ -205,18 +213,18 @@ def run_retryable_cypher(
205213
query: str,
206214
params: Optional[dict[str, Any]] = None,
207215
database: Optional[str] = None,
208-
custom_error: bool = True,
209216
mode: Optional[QueryMode] = None,
217+
custom_error: bool = True,
210218
connectivity_retry_config: Optional[ConnectivityRetriesConfig] = None,
211219
) -> DataFrame:
212220
if not database:
213221
database = self._database
214222

215223
if self._NEO4J_DRIVER_VERSION < SemanticVersion(5, 5, 0):
216-
return self.run_cypher(query, params, database, custom_error, connectivity_retry_config)
224+
return self.run_cypher(query, params, database, mode, custom_error, connectivity_retry_config)
217225

218226
if not mode:
219-
routing = neo4j.RoutingControl.READ
227+
routing = neo4j.RoutingControl.WRITE
220228
else:
221229
routing = mode.neo4j_routing()
222230

@@ -263,9 +271,9 @@ def call_procedure(
263271

264272
def run_cypher_query() -> DataFrame:
265273
if retryable:
266-
return self.run_retryable_cypher(query, params, database, custom_error, mode=mode)
274+
return self.run_retryable_cypher(query, params, database, custom_error=custom_error, mode=mode)
267275
else:
268-
return self.run_cypher(query, params, database, custom_error)
276+
return self.run_cypher(query, params, database, custom_error=custom_error)
269277

270278
job_id = None if not params else params.get_job_id()
271279
if self._resolve_show_progress(logging) and job_id:

graphdatascience/query_runner/query_mode.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ def neo4j_routing(self) -> "neo4j.RoutingControl":
1414
return neo4j.RoutingControl.WRITE
1515
else:
1616
raise ValueError(f"Unknown query mode: {self}")
17+
18+
def neo4j_access_mode(self) -> str:
19+
if self == QueryMode.READ:
20+
return neo4j.READ_ACCESS
21+
elif self == QueryMode.WRITE:
22+
return neo4j.WRITE_ACCESS
23+
else:
24+
raise ValueError(f"Unknown query mode: {self}")

graphdatascience/query_runner/query_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def run_cypher(
3535
query: str,
3636
params: Optional[dict[str, Any]] = None,
3737
database: Optional[str] = None,
38+
mode: Optional[QueryMode] = None,
3839
custom_error: bool = True,
3940
) -> DataFrame:
4041
pass
@@ -45,6 +46,7 @@ def run_retryable_cypher(
4546
query: str,
4647
params: Optional[dict[str, Any]] = None,
4748
database: Optional[str] = None,
49+
mode: Optional[QueryMode] = None,
4850
custom_error: bool = True,
4951
) -> DataFrame:
5052
pass

graphdatascience/query_runner/session_query_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,20 @@ def run_cypher(
5454
query: str,
5555
params: Optional[dict[str, Any]] = None,
5656
database: Optional[str] = None,
57+
mode: Optional[QueryMode] = None,
5758
custom_error: bool = True,
5859
) -> DataFrame:
59-
return self._db_query_runner.run_cypher(query, params, database, custom_error)
60+
return self._db_query_runner.run_cypher(query, params, database, mode, custom_error)
6061

6162
def run_retryable_cypher(
6263
self,
6364
query: str,
6465
params: Optional[dict[str, Any]] = None,
6566
database: Optional[str] = None,
67+
mode: Optional[QueryMode] = None,
6668
custom_error: bool = True,
6769
) -> DataFrame:
68-
return self._db_query_runner.run_retryable_cypher(query, params, database, custom_error=custom_error)
70+
return self._db_query_runner.run_retryable_cypher(query, params, database, mode=mode, custom_error=custom_error)
6971

7072
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
7173
return self._gds_query_runner.call_function(endpoint, params)

0 commit comments

Comments
 (0)