@@ -150,8 +150,10 @@ def __init__(
150150 def __run_cypher_simplified_for_query_progress_logger (self , query : str , database : Optional [str ]) -> DataFrame :
151151 # progress logging should not retry a lot as it perodically fetches the latest progress anyway
152152 connectivity_retry_config = Neo4jQueryRunner .ConnectivityRetriesConfig (max_retries = 2 )
153+ # not using execute_query as failing is okay
153154 return self .run_cypher (query = query , database = database , connectivity_retry_config = connectivity_retry_config )
154155
156+ # only use for user defined queries
155157 def run_cypher (
156158 self ,
157159 query : str ,
@@ -195,12 +197,47 @@ def run_cypher(
195197
196198 return df
197199
198- def call_function (self , endpoint : str , params : Optional [CallParameters ] = None ) -> Any :
200+ # better retry mechanism than run_cypher. The neo4j driver handles retryable errors internally
201+ def run_retryable_cypher (
202+ self ,
203+ query : str ,
204+ params : Optional [dict [str , Any ]] = None ,
205+ database : Optional [str ] = None ,
206+ custom_error : bool = True ,
207+ routing : Optional [neo4j .RoutingControl ] = None ,
208+ connectivity_retry_config : Optional [ConnectivityRetriesConfig ] = None ,
209+ ) -> DataFrame :
210+ if not database :
211+ database = self ._database
212+
213+ if self ._NEO4J_DRIVER_VERSION < SemanticVersion (5 , 5 , 0 ):
214+ return self .run_cypher (query , params , database , custom_error , connectivity_retry_config )
215+
216+ if not routing :
217+ routing = neo4j .RoutingControl .READ
218+
219+ try :
220+ return self ._driver .execute_query (
221+ query_ = query ,
222+ parameters_ = params ,
223+ database = database ,
224+ result_transformer_ = neo4j .Result .to_df ,
225+ routing_ = routing ,
226+ )
227+ except Exception as e :
228+ if custom_error :
229+ Neo4jQueryRunner .handle_driver_exception (self ._driver , e )
230+ raise e
231+ else :
232+ raise e
233+
234+ def call_function (self , endpoint : str , params : Optional [CallParameters ] = None , custom_error : bool = True ) -> Any :
199235 if params is None :
200236 params = CallParameters ()
201237 query = f"RETURN { endpoint } ({ params .placeholder_str ()} )"
202238
203- return self .run_cypher (query , params ).squeeze ()
239+ # we can use retryable cypher as we expect all gds functions to be idempotent
240+ return self .run_retryable_cypher (query , params , custom_error = custom_error ).squeeze ()
204241
205242 def call_procedure (
206243 self ,
@@ -209,6 +246,7 @@ def call_procedure(
209246 yields : Optional [list [str ]] = None ,
210247 database : Optional [str ] = None ,
211248 logging : bool = False ,
249+ retryable : bool = False ,
212250 custom_error : bool = True ,
213251 ) -> DataFrame :
214252 if params is None :
@@ -218,7 +256,11 @@ def call_procedure(
218256 query = f"CALL { endpoint } ({ params .placeholder_str ()} ){ yields_clause } "
219257
220258 def run_cypher_query () -> DataFrame :
221- return self .run_cypher (query , params , database , custom_error )
259+ if retryable :
260+ routing = neo4j .RoutingControl .WRITE if "write" in endpoint else neo4j .RoutingControl .READ
261+ return self .run_retryable_cypher (query , params , database , custom_error , routing = routing )
262+ else :
263+ return self .run_cypher (query , params , database , custom_error )
222264
223265 job_id = None if not params else params .get_job_id ()
224266 if self ._resolve_show_progress (logging ) and job_id :
@@ -234,7 +276,7 @@ def server_version(self) -> ServerVersion:
234276 return self ._server_version
235277
236278 try :
237- server_version_string = self .run_cypher ( "RETURN gds.version() " , custom_error = False ). squeeze ( )
279+ server_version_string = self .call_function ( " gds.version" , custom_error = False )
238280 server_version = ServerVersion .from_string (server_version_string )
239281 self ._server_version = server_version
240282 return server_version
@@ -325,7 +367,7 @@ def clone(self, host: str, port: int) -> QueryRunner:
325367 )
326368
327369 @staticmethod
328- def handle_driver_exception (session : neo4j .Session , e : Exception ) -> None :
370+ def handle_driver_exception (cypher_executor : Union [ neo4j .Session , neo4j . Driver ] , e : Exception ) -> None :
329371 reg_gds_hit = re .search (
330372 r"There is no procedure with the name `(gds(?:\.\w+)+)` registered for this database instance" ,
331373 str (e ),
@@ -335,8 +377,16 @@ def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
335377
336378 requested_endpoint = reg_gds_hit .group (1 )
337379
338- list_result = session .run ("CALL gds.list() YIELD name" )
339- all_endpoints = list_result .to_df ()["name" ].tolist ()
380+ if isinstance (cypher_executor , neo4j .Session ):
381+ list_result = cypher_executor .run ("CALL gds.list() YIELD name" )
382+ all_endpoints = list_result .to_df ()["name" ].tolist ()
383+ elif isinstance (cypher_executor , neo4j .Driver ):
384+ result = cypher_executor .execute_query ("CALL gds.list() YIELD name" , result_transformer_ = neo4j .Result .to_df )
385+ all_endpoints = result ["name" ].tolist ()
386+ else :
387+ raise TypeError (
388+ f"Expected cypher_executor to be a neo4j.Session or neo4j.Driver, got { type (cypher_executor )} "
389+ )
340390
341391 raise SyntaxError (generate_suggestive_error_message (requested_endpoint , all_endpoints )) from e
342392
0 commit comments