Skip to content

Commit c69144d

Browse files
committed
Use execute query for retrying on retryable Neo4j Exceptions
1 parent 816164b commit c69144d

18 files changed

+172
-40
lines changed

changelog.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
# Changes in 1.16
22

3-
43
## Breaking changes
54

65
## New features
76

87
## Bug fixes
98

10-
* Fixed a bug where remote projections would fail when the database is clustered
9+
- Fixed a bug where remote projections would fail when the database is clustered
1110

1211
## Improvements
1312

14-
* Allow creating sessions of size `512GB`.
15-
* Allow passing additional parameters for the Neo4j driver connection to `GdsSessions.get_or_create(neo4j_driver_config={..})`
16-
* Add helper functions to create config objects from environment variables
17-
* `AuraApiCredentials::from_env`
18-
* `DbmsConnectionInfo::from_env`
19-
13+
- Allow creating sessions of size `512GB`.
14+
- Allow passing additional parameters for the Neo4j driver connection to `GdsSessions.get_or_create(neo4j_driver_config={..})`
15+
- Add helper functions to create config objects from environment variables
16+
- `AuraApiCredentials::from_env`
17+
- `DbmsConnectionInfo::from_env`
18+
- Retry internal functions known to be idempotent. Reduces issues such as `SessionExpiredError`.
2019

2120
## Other changes

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ def project(
4545

4646
GraphCypherRunner._verify_query_ends_with_return_clause(self._namespace, query)
4747

48-
result: Optional[dict[str, Any]] = self._query_runner.run_cypher(query, params, database, False).squeeze()
48+
result: Optional[dict[str, Any]] = self._query_runner.run_retryable_cypher(
49+
query, params, database, False
50+
).squeeze()
4951

5052
if not result:
5153
raise ValueError("Projected graph cannot be empty.")

graphdatascience/graph/graph_entity_ops_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def _process_result(
162162
)
163163

164164
unique_node_ids = result["nodeId"].drop_duplicates().tolist()
165-
db_properties_df = query_runner.run_cypher(
165+
166+
db_properties_df = query_runner.run_retryable_cypher(
166167
GraphNodePropertiesRunner._build_query(db_node_properties), {"ids": unique_node_ids}
167168
)
168169

graphdatascience/graph_data_science.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def run_cypher(
218218
if isinstance(self._query_runner, ArrowQueryRunner):
219219
qr = self._query_runner.fallback_query_runner()
220220

221+
# not using qr.execute_query as we don't know whether a user-supplied query can safely be retried
221222
return qr.run_cypher(query, params, database, False)
222223

223224
def driver_config(self) -> dict[str, Any]:

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ def run_cypher(
6868
) -> DataFrame:
6969
return self._fallback_query_runner.run_cypher(query, params, database, custom_error)
7070

71+
def run_retryable_cypher(
    self,
    query: str,
    params: Optional[dict[str, Any]] = None,
    database: Optional[str] = None,
    custom_error: bool = True,
) -> DataFrame:
    """Run a retryable Cypher query through the fallback query runner.

    This runner adds nothing of its own here: every argument is handed on
    unchanged to the fallback runner's ``run_retryable_cypher``, with
    ``custom_error`` passed by keyword.

    Args:
        query: The Cypher query text.
        params: Optional query parameters.
        database: Optional database name.
        custom_error: Forwarded as-is to the fallback runner.

    Returns:
        Whatever the fallback runner returns for the query.
    """
    delegate = self._fallback_query_runner
    return delegate.run_retryable_cypher(query, params, database, custom_error=custom_error)
79+
7180
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
7281
return self._fallback_query_runner.call_function(endpoint, params)
7382

@@ -78,6 +87,7 @@ def call_procedure(
7887
yields: Optional[list[str]] = None,
7988
database: Optional[str] = None,
8089
logging: bool = False,
90+
retryable: bool = False,
8191
custom_error: bool = True,
8292
) -> DataFrame:
8393
if params is None:
@@ -171,7 +181,9 @@ def call_procedure(
171181
graph_name, self._database_or_throw(), relationship_types, concurrency
172182
)
173183

174-
return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
184+
return self._fallback_query_runner.call_procedure(
185+
endpoint, params, yields, database, logging=logging, retryable=retryable, custom_error=custom_error
186+
)
175187

176188
def server_version(self) -> ServerVersion:
177189
return self._fallback_query_runner.server_version()

graphdatascience/query_runner/cypher_graph_constructor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def graph_construct_error_multidf(element: str) -> str:
104104

105105
def _should_warn_about_arrow_missing(self) -> bool:
106106
try:
107+
# TODO use execute_query
107108
license: str = self._query_runner.run_cypher(
108109
"CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value", custom_error=False
109110
).squeeze()
@@ -210,6 +211,7 @@ def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> N
210211
"undirectedRelationshipTypes": self._undirected_relationship_types,
211212
}
212213

214+
# not using retryable here as gds.graph.project adds a graph to the gds graph catalog
213215
self._query_runner.run_cypher(
214216
query,
215217
{

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,10 @@ def __init__(
150150
def __run_cypher_simplified_for_query_progress_logger(self, query: str, database: Optional[str]) -> DataFrame:
151151
# progress logging should not retry a lot, as it periodically fetches the latest progress anyway
152152
connectivity_retry_config = Neo4jQueryRunner.ConnectivityRetriesConfig(max_retries=2)
153+
# not using execute_query as failing is okay
153154
return self.run_cypher(query=query, database=database, connectivity_retry_config=connectivity_retry_config)
154155

156+
# only use for user-defined queries
155157
def run_cypher(
156158
self,
157159
query: str,
@@ -195,12 +197,47 @@ def run_cypher(
195197

196198
return df
197199

198-
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
200+
# better retry mechanism than run_cypher. The neo4j driver handles retryable errors internally
201+
def run_retryable_cypher(
    self,
    query: str,
    params: Optional[dict[str, Any]] = None,
    database: Optional[str] = None,
    custom_error: bool = True,
    routing: Optional[neo4j.RoutingControl] = None,
    connectivity_retry_config: Optional[ConnectivityRetriesConfig] = None,
) -> DataFrame:
    """Run a query via ``driver.execute_query`` and return the result as a DataFrame.

    Args:
        query: The Cypher query text.
        params: Optional query parameters.
        database: Database to run against; falls back to ``self._database``
            when empty or ``None``.
        custom_error: When true, a failure is first passed to
            ``Neo4jQueryRunner.handle_driver_exception`` before the original
            exception is re-raised.
        routing: Routing mode for the driver; defaults to ``READ`` when unset.
        connectivity_retry_config: Only used on the ``run_cypher`` fallback
            path; the ``execute_query`` path does not read it.

    Returns:
        The query result converted with ``neo4j.Result.to_df``.
    """
    target_database = database or self._database

    # Drivers older than 5.5.0 take the plain run_cypher path instead.
    if self._NEO4J_DRIVER_VERSION < SemanticVersion(5, 5, 0):
        return self.run_cypher(query, params, target_database, custom_error, connectivity_retry_config)

    routing_mode = routing or neo4j.RoutingControl.READ

    try:
        return self._driver.execute_query(
            query_=query,
            parameters_=params,
            database=target_database,
            result_transformer_=neo4j.Result.to_df,
            routing_=routing_mode,
        )
    except Exception as e:
        if custom_error:
            # May raise a more helpful error itself; otherwise falls through.
            Neo4jQueryRunner.handle_driver_exception(self._driver, e)
        raise e
233+
234+
def call_function(self, endpoint: str, params: Optional[CallParameters] = None, custom_error: bool = True) -> Any:
    """Evaluate ``RETURN <endpoint>(<params>)`` and return the squeezed result.

    Args:
        endpoint: Name of the function to call.
        params: Call parameters; a fresh ``CallParameters()`` is used when omitted.
        custom_error: Forwarded to ``run_retryable_cypher``.

    Returns:
        The result of ``.squeeze()`` on the returned DataFrame.
    """
    call_params = CallParameters() if params is None else params
    cypher = f"RETURN {endpoint}({call_params.placeholder_str()})"

    # All gds functions are expected to be idempotent, so the retryable path is used.
    outcome = self.run_retryable_cypher(cypher, call_params, custom_error=custom_error)
    return outcome.squeeze()
204241

205242
def call_procedure(
206243
self,
@@ -209,6 +246,7 @@ def call_procedure(
209246
yields: Optional[list[str]] = None,
210247
database: Optional[str] = None,
211248
logging: bool = False,
249+
retryable: bool = False,
212250
custom_error: bool = True,
213251
) -> DataFrame:
214252
if params is None:
@@ -218,7 +256,11 @@ def call_procedure(
218256
query = f"CALL {endpoint}({params.placeholder_str()}){yields_clause}"
219257

220258
def run_cypher_query() -> DataFrame:
221-
return self.run_cypher(query, params, database, custom_error)
259+
if retryable:
260+
routing = neo4j.RoutingControl.WRITE if "write" in endpoint else neo4j.RoutingControl.READ
261+
return self.run_retryable_cypher(query, params, database, custom_error, routing=routing)
262+
else:
263+
return self.run_cypher(query, params, database, custom_error)
222264

223265
job_id = None if not params else params.get_job_id()
224266
if self._resolve_show_progress(logging) and job_id:
@@ -234,7 +276,7 @@ def server_version(self) -> ServerVersion:
234276
return self._server_version
235277

236278
try:
237-
server_version_string = self.run_cypher("RETURN gds.version()", custom_error=False).squeeze()
279+
server_version_string = self.call_function("gds.version", custom_error=False)
238280
server_version = ServerVersion.from_string(server_version_string)
239281
self._server_version = server_version
240282
return server_version
@@ -325,7 +367,7 @@ def clone(self, host: str, port: int) -> QueryRunner:
325367
)
326368

327369
@staticmethod
328-
def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
370+
def handle_driver_exception(cypher_executor: Union[neo4j.Session, neo4j.Driver], e: Exception) -> None:
329371
reg_gds_hit = re.search(
330372
r"There is no procedure with the name `(gds(?:\.\w+)+)` registered for this database instance",
331373
str(e),
@@ -335,8 +377,16 @@ def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
335377

336378
requested_endpoint = reg_gds_hit.group(1)
337379

338-
list_result = session.run("CALL gds.list() YIELD name")
339-
all_endpoints = list_result.to_df()["name"].tolist()
380+
if isinstance(cypher_executor, neo4j.Session):
381+
list_result = cypher_executor.run("CALL gds.list() YIELD name")
382+
all_endpoints = list_result.to_df()["name"].tolist()
383+
elif isinstance(cypher_executor, neo4j.Driver):
384+
result = cypher_executor.execute_query("CALL gds.list() YIELD name", result_transformer_=neo4j.Result.to_df)
385+
all_endpoints = result["name"].tolist()
386+
else:
387+
raise TypeError(
388+
f"Expected cypher_executor to be a neo4j.Session or neo4j.Driver, got {type(cypher_executor)}"
389+
)
340390

341391
raise SyntaxError(generate_suggestive_error_message(requested_endpoint, all_endpoints)) from e
342392

graphdatascience/query_runner/protocol/project_protocols.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def run_projection(
6868
logging: bool = False,
6969
) -> DataFrame:
7070
versioned_endpoint = ProtocolVersion.V1.versioned_procedure_name(endpoint)
71-
return query_runner.call_procedure(versioned_endpoint, params, yields, database, logging, False)
71+
return query_runner.call_procedure(
72+
versioned_endpoint, params, yields, database=database, logging=logging, retryable=True, custom_error=False
73+
)
7274

7375

7476
class ProjectProtocolV2(ProjectProtocol):
@@ -97,7 +99,9 @@ def run_projection(
9799
logging: bool = False,
98100
) -> DataFrame:
99101
versioned_endpoint = ProtocolVersion.V2.versioned_procedure_name(endpoint)
100-
return query_runner.call_procedure(versioned_endpoint, params, yields, database, logging, False)
102+
return query_runner.call_procedure(
103+
versioned_endpoint, params, yields, database=database, logging=logging, retryable=True, custom_error=False
104+
)
101105

102106

103107
class ProjectProtocolV3(ProjectProtocol):
@@ -149,7 +153,14 @@ def is_not_done(result: DataFrame) -> bool:
149153
def project_fn() -> DataFrame:
    """Issue one call to the V3 projection procedure and return its result.

    The termination flag is checked before the call is made. The procedure is
    invoked as retryable and with custom error handling disabled.

    ``database`` is passed exactly once, by keyword. Passing it both
    positionally and by keyword makes the call fail with
    ``TypeError: got multiple values for argument 'database'``.
    """
    termination_flag.assert_running()
    return projection_query_runner.call_procedure(
        ProtocolVersion.V3.versioned_procedure_name(endpoint),
        params,
        yields,
        database=database,
        logging=logging,
        retryable=True,
        custom_error=False,
    )
154165

155166
projection_result = project_fn()

graphdatascience/query_runner/protocol/write_protocols.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from pandas import DataFrame
66
from tenacity import retry, retry_if_result, wait_incrementing
77

8-
from graphdatascience import QueryRunner
98
from graphdatascience.call_parameters import CallParameters
109
from graphdatascience.query_runner.protocol.status import Status
10+
from graphdatascience.query_runner.query_runner import QueryRunner
1111
from graphdatascience.query_runner.termination_flag import TerminationFlag
1212
from graphdatascience.retry_utils.retry_utils import before_log
1313
from graphdatascience.session.dbms.protocol_version import ProtocolVersion
@@ -73,9 +73,10 @@ def run_write_back(
7373
ProtocolVersion.V1.versioned_procedure_name("gds.arrow.write"),
7474
parameters,
7575
yields,
76-
None,
77-
False,
78-
False,
76+
retryable=True,
77+
database=None,
78+
logging=False,
79+
custom_error=False,
7980
)
8081

8182

@@ -111,9 +112,10 @@ def run_write_back(
111112
ProtocolVersion.V2.versioned_procedure_name("gds.arrow.write"),
112113
parameters,
113114
yields,
114-
None,
115-
False,
116-
False,
115+
retryable=True,
116+
database=None,
117+
logging=False,
118+
custom_error=False,
117119
)
118120

119121

@@ -157,9 +159,10 @@ def write_fn() -> DataFrame:
157159
ProtocolVersion.V3.versioned_procedure_name("gds.arrow.write"),
158160
parameters,
159161
yields,
160-
None,
161-
False,
162-
False,
162+
retryable=True,
163+
database=None,
164+
logging=False,
165+
custom_error=False,
163166
)
164167

165168
return write_fn()

graphdatascience/query_runner/query_runner.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def call_procedure(
1717
yields: Optional[list[str]] = None,
1818
database: Optional[str] = None,
1919
logging: bool = False,
20+
retryable: bool = False,
2021
custom_error: bool = True,
2122
) -> DataFrame:
2223
pass
@@ -35,6 +36,16 @@ def run_cypher(
3536
) -> DataFrame:
3637
pass
3738

39+
@abstractmethod
def run_retryable_cypher(
    self,
    query: str,
    params: Optional[dict[str, Any]] = None,
    database: Optional[str] = None,
    custom_error: bool = True,
) -> DataFrame:
    """Run a Cypher query that may be retried; implemented by concrete runners.

    Args:
        query: The Cypher query text.
        params: Optional query parameters.
        database: Optional database name.
        custom_error: Flag that implementations receive alongside the query.

    Returns:
        The query result as a DataFrame.
    """
48+
3849
@abstractmethod
3950
def server_version(self) -> ServerVersion:
4051
pass

0 commit comments

Comments
 (0)