3333 Neo4jRelationship ,
3434)
3535from neo4j_graphrag .experimental .pipeline .component import Component , DataModel
36- from neo4j_graphrag .neo4j_queries import UPSERT_NODE_QUERY , UPSERT_RELATIONSHIP_QUERY
36+ from neo4j_graphrag .neo4j_queries import (
37+ UPSERT_NODE_QUERY ,
38+ UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE ,
39+ UPSERT_RELATIONSHIP_QUERY ,
40+ UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE ,
41+ )
3742
3843logger = logging .getLogger (__name__ )
3944
@@ -113,6 +118,7 @@ def __init__(
113118 self .neo4j_database = neo4j_database
114119 self .batch_size = batch_size
115120 self .max_concurrency = max_concurrency
121+ self .is_version_5_23_or_above = self ._check_if_version_5_23_or_above ()
116122
117123 def _db_setup (self ) -> None :
118124 # create index on __Entity__.id
@@ -147,7 +153,12 @@ def _upsert_nodes(self, nodes: list[Neo4jNode]) -> None:
147153 nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
148154 """
149155 parameters = {"rows" : self ._nodes_to_rows (nodes )}
150- self .driver .execute_query (UPSERT_NODE_QUERY , parameters_ = parameters )
156+ if self .is_version_5_23_or_above :
157+ self .driver .execute_query (
158+ UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE , parameters_ = parameters
159+ )
160+ else :
161+ self .driver .execute_query (UPSERT_NODE_QUERY , parameters_ = parameters )
151162
152163 async def _async_upsert_nodes (
153164 self ,
@@ -161,7 +172,32 @@ async def _async_upsert_nodes(
161172 """
162173 async with sem :
163174 parameters = {"rows" : self ._nodes_to_rows (nodes )}
164- await self .driver .execute_query (UPSERT_NODE_QUERY , parameters_ = parameters )
175+ await self .driver .execute_query (
176+ UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE , parameters_ = parameters
177+ )
178+
179+ def _get_version (self ) -> tuple [int , ...]:
180+ records , _ , _ = self .driver .execute_query (
181+ "CALL dbms.components()" , database_ = self .neo4j_database
182+ )
183+ version = records [0 ]["versions" ][0 ]
184+ # Drop everything after the '-' first
185+ version_main , * _ = version .split ("-" )
186+ # Convert each number between '.' into int
187+ version_tuple = tuple (map (int , version_main .split ("." )))
188+ # If no patch version, consider it's 0
189+ if len (version_tuple ) < 3 :
190+ version_tuple = (* version_tuple , 0 )
191+ return version_tuple
192+
193+ def _check_if_version_5_23_or_above (self ) -> bool :
194+ """
195+ Check if the connected Neo4j database version supports the required features.
196+
197+ Sets a flag if the connected Neo4j version is 5.23 or above.
198+ """
199+ version_tuple = self ._get_version ()
200+ return version_tuple >= (5 , 23 , 0 )
165201
166202 def _upsert_relationships (self , rels : list [Neo4jRelationship ]) -> None :
167203 """Upserts a single relationship into the Neo4j database.
@@ -170,7 +206,12 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
170206 rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
171207 """
172208 parameters = {"rows" : [rel .model_dump () for rel in rels ]}
173- self .driver .execute_query (UPSERT_RELATIONSHIP_QUERY , parameters_ = parameters )
209+ if self .is_version_5_23_or_above :
210+ self .driver .execute_query (
211+ UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE , parameters_ = parameters
212+ )
213+ else :
214+ self .driver .execute_query (UPSERT_RELATIONSHIP_QUERY , parameters_ = parameters )
174215
175216 async def _async_upsert_relationships (
176217 self , rels : list [Neo4jRelationship ], sem : asyncio .Semaphore
@@ -182,9 +223,15 @@ async def _async_upsert_relationships(
182223 """
183224 async with sem :
184225 parameters = {"rows" : [rel .model_dump () for rel in rels ]}
185- await self .driver .execute_query (
186- UPSERT_RELATIONSHIP_QUERY , parameters_ = parameters
187- )
226+ if self .is_version_5_23_or_above :
227+ await self .driver .execute_query (
228+ UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE ,
229+ parameters_ = parameters ,
230+ )
231+ else :
232+ await self .driver .execute_query (
233+ UPSERT_RELATIONSHIP_QUERY , parameters_ = parameters
234+ )
188235
189236 @validate_call
190237 async def run (self , graph : Neo4jGraph ) -> KGWriterModel :
@@ -193,12 +240,6 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
193240 Args:
194241 graph (Neo4jGraph): The knowledge graph to upsert into the database.
195242 """
196- # we disable the notification logger to get rid of the deprecation
197- # warning about Cypher subqueries. Once the queries are updated
198- # for Neo4j 5.23, we can remove this line and the 'finally' block
199- notification_logger = logging .getLogger ("neo4j.notifications" )
200- notification_level = notification_logger .level
201- notification_logger .setLevel (logging .ERROR )
202243 try :
203244 if inspect .iscoroutinefunction (self .driver .execute_query ):
204245 await self ._async_db_setup ()
@@ -233,5 +274,3 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
233274 except neo4j .exceptions .ClientError as e :
234275 logger .exception (e )
235276 return KGWriterModel (status = "FAILURE" , metadata = {"error" : str (e )})
236- finally :
237- notification_logger .setLevel (notification_level )
0 commit comments