1414# limitations under the License.
1515from __future__ import annotations
1616
17- import asyncio
18- import inspect
1917import logging
2018from abc import abstractmethod
2119from typing import Any , Generator , Literal , Optional
@@ -87,21 +85,21 @@ class Neo4jWriter(KGWriter):
8785 Args:
8886 driver (neo4j.driver): The Neo4j driver to connect to the database.
8987 neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
90- max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM .
88+ batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000 .
9189
9290 Example:
9391
9492 .. code-block:: python
9593
96- from neo4j import AsyncGraphDatabase
94+ from neo4j import GraphDatabase
9795 from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
9896 from neo4j_graphrag.experimental.pipeline import Pipeline
9997
10098 URI = "neo4j://localhost:7687"
10199 AUTH = ("neo4j", "password")
102100 DATABASE = "neo4j"
103101
104- driver = AsyncGraphDatabase .driver(URI, auth=AUTH, database=DATABASE)
102+ driver = GraphDatabase .driver(URI, auth=AUTH, database=DATABASE)
105103 writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
106104
107105 pipeline = Pipeline()
@@ -111,15 +109,13 @@ class Neo4jWriter(KGWriter):
111109
112110 def __init__ (
113111 self ,
114- driver : neo4j .driver ,
112+ driver : neo4j .Driver ,
115113 neo4j_database : Optional [str ] = None ,
116114 batch_size : int = 1000 ,
117- max_concurrency : int = 5 ,
118115 ):
119116 self .driver = driver
120117 self .neo4j_database = neo4j_database
121118 self .batch_size = batch_size
122- self .max_concurrency = max_concurrency
123119 self .is_version_5_23_or_above = self ._check_if_version_5_23_or_above ()
124120
125121 def _db_setup (self ) -> None :
@@ -129,13 +125,6 @@ def _db_setup(self) -> None:
129125 "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
130126 )
131127
132- async def _async_db_setup (self ) -> None :
133- # create index on __Entity__.id
134- # used when creating the relationships
135- await self .driver .execute_query (
136- "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
137- )
138-
139128 @staticmethod
140129 def _nodes_to_rows (
141130 nodes : list [Neo4jNode ], lexical_graph_config : LexicalGraphConfig
@@ -166,23 +155,6 @@ def _upsert_nodes(
166155 else :
167156 self .driver .execute_query (UPSERT_NODE_QUERY , parameters_ = parameters )
168157
169- async def _async_upsert_nodes (
170- self ,
171- nodes : list [Neo4jNode ],
172- lexical_graph_config : LexicalGraphConfig ,
173- sem : asyncio .Semaphore ,
174- ) -> None :
175- """Asynchronously upserts a single node into the Neo4j database."
176-
177- Args:
178- nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
179- """
180- async with sem :
181- parameters = {"rows" : self ._nodes_to_rows (nodes , lexical_graph_config )}
182- await self .driver .execute_query (
183- UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE , parameters_ = parameters
184- )
185-
186158 def _get_version (self ) -> tuple [int , ...]:
187159 records , _ , _ = self .driver .execute_query (
188160 "CALL dbms.components()" , database_ = self .neo4j_database
@@ -220,26 +192,6 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
220192 else :
221193 self .driver .execute_query (UPSERT_RELATIONSHIP_QUERY , parameters_ = parameters )
222194
223- async def _async_upsert_relationships (
224- self , rels : list [Neo4jRelationship ], sem : asyncio .Semaphore
225- ) -> None :
226- """Asynchronously upserts a single relationship into the Neo4j database.
227-
228- Args:
229- rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
230- """
231- async with sem :
232- parameters = {"rows" : [rel .model_dump () for rel in rels ]}
233- if self .is_version_5_23_or_above :
234- await self .driver .execute_query (
235- UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE ,
236- parameters_ = parameters ,
237- )
238- else :
239- await self .driver .execute_query (
240- UPSERT_RELATIONSHIP_QUERY , parameters_ = parameters
241- )
242-
243195 @validate_call
244196 async def run (
245197 self ,
@@ -253,28 +205,13 @@ async def run(
253205 lexical_graph_config (LexicalGraphConfig):
254206 """
255207 try :
256- if inspect .iscoroutinefunction (self .driver .execute_query ):
257- await self ._async_db_setup ()
258- sem = asyncio .Semaphore (self .max_concurrency )
259- node_tasks = [
260- self ._async_upsert_nodes (batch , lexical_graph_config , sem )
261- for batch in batched (graph .nodes , self .batch_size )
262- ]
263- await asyncio .gather (* node_tasks )
264-
265- rel_tasks = [
266- self ._async_upsert_relationships (batch , sem )
267- for batch in batched (graph .relationships , self .batch_size )
268- ]
269- await asyncio .gather (* rel_tasks )
270- else :
271- self ._db_setup ()
272-
273- for batch in batched (graph .nodes , self .batch_size ):
274- self ._upsert_nodes (batch , lexical_graph_config )
275-
276- for batch in batched (graph .relationships , self .batch_size ):
277- self ._upsert_relationships (batch )
208+ self ._db_setup ()
209+
210+ for batch in batched (graph .nodes , self .batch_size ):
211+ self ._upsert_nodes (batch , lexical_graph_config )
212+
213+ for batch in batched (graph .relationships , self .batch_size ):
214+ self ._upsert_relationships (batch )
278215
279216 return KGWriterModel (
280217 status = "SUCCESS" ,
0 commit comments