1414# limitations under the License.
1515from __future__ import annotations
1616
17+ import asyncio
1718import logging
1819from abc import abstractmethod
19- from typing import Literal , Optional
20+ from typing import Any , Dict , Literal , Optional , Tuple
2021
2122import neo4j
2223from pydantic import validate_call
2728 Neo4jRelationship ,
2829)
2930from neo4j_genai .experimental .pipeline .component import Component , DataModel
30- from neo4j_genai .indexes import upsert_vector , upsert_vector_on_relationship
31+ from neo4j_genai .indexes import (
32+ async_upsert_vector ,
33+ async_upsert_vector_on_relationship ,
34+ upsert_vector ,
35+ upsert_vector_on_relationship ,
36+ )
3137from neo4j_genai .neo4j_queries import UPSERT_NODE_QUERY , UPSERT_RELATIONSHIP_QUERY
3238
3339logger = logging .getLogger (__name__ )
@@ -64,20 +70,21 @@ class Neo4jWriter(KGWriter):
6470 Args:
6571 driver (neo4j.driver): The Neo4j driver to connect to the database.
6672 neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
73+ max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
6774
6875 Example:
6976
7077 .. code-block:: python
7178
72- from neo4j import GraphDatabase
79+ from neo4j import AsyncGraphDatabase
7380 from neo4j_genai.experimental.components.kg_writer import Neo4jWriter
7481 from neo4j_genai.experimental.pipeline import Pipeline
7582
7683 URI = "neo4j://localhost:7687"
7784 AUTH = ("neo4j", "password")
7885 DATABASE = "neo4j"
7986
80- driver = GraphDatabase .driver(URI, auth=AUTH, database=DATABASE)
87+ driver = AsyncGraphDatabase .driver(URI, auth=AUTH, database=DATABASE)
8188 writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
8289
8390 pipeline = Pipeline()
@@ -89,16 +96,13 @@ def __init__(
8996 self ,
9097 driver : neo4j .driver ,
9198 neo4j_database : Optional [str ] = None ,
99+ max_concurrency : int = 5 ,
92100 ):
93101 self .driver = driver
94102 self .neo4j_database = neo4j_database
103+ self .max_concurrency = max_concurrency
95104
96- def _upsert_node (self , node : Neo4jNode ) -> None :
97- """Upserts a single node into the Neo4j database."
98-
99- Args:
100- node (Neo4jNode): The node to upsert into the database.
101- """
105+ def _get_node_query (self , node : Neo4jNode ) -> Tuple [str , Dict [str , Any ]]:
102106 # Create the initial node
103107 parameters = {"id" : node .id }
104108 if node .properties :
@@ -107,6 +111,15 @@ def _upsert_node(self, node: Neo4jNode) -> None:
107111 "{" + ", " .join (f"{ key } : ${ key } " for key in parameters .keys ()) + "}"
108112 )
109113 query = UPSERT_NODE_QUERY .format (label = node .label , properties = properties )
114+ return query , parameters
115+
116+ def _upsert_node (self , node : Neo4jNode ) -> None :
117+ """Upserts a single node into the Neo4j database."
118+
119+ Args:
120+ node (Neo4jNode): The node to upsert into the database.
121+ """
122+ query , parameters = self ._get_node_query (node )
110123 result = self .driver .execute_query (query , parameters_ = parameters )
111124 node_id = result .records [0 ]["elementID(n)" ]
112125 # Add the embedding properties to the node
@@ -120,12 +133,32 @@ def _upsert_node(self, node: Neo4jNode) -> None:
120133 neo4j_database = self .neo4j_database ,
121134 )
122135
123- def _upsert_relationship (self , rel : Neo4jRelationship ) -> None :
124- """Upserts a single relationship into the Neo4j database.
136+ async def _async_upsert_node (
137+ self ,
138+ node : Neo4jNode ,
139+ sem : asyncio .Semaphore ,
140+ ) -> None :
141+ """Asynchronously upserts a single node into the Neo4j database."
125142
126143 Args:
127- rel (Neo4jRelationship ): The relationship to upsert into the database.
144+ node (Neo4jNode ): The node to upsert into the database.
128145 """
146+ async with sem :
147+ query , parameters = self ._get_node_query (node )
148+ result = await self .driver .execute_query (query , parameters_ = parameters )
149+ node_id = result .records [0 ]["elementID(n)" ]
150+ # Add the embedding properties to the node
151+ if node .embedding_properties :
152+ for prop , vector in node .embedding_properties .items ():
153+ await async_upsert_vector (
154+ driver = self .driver ,
155+ node_id = node_id ,
156+ embedding_property = prop ,
157+ vector = vector ,
158+ neo4j_database = self .neo4j_database ,
159+ )
160+
161+ def _get_rel_query (self , rel : Neo4jRelationship ) -> Tuple [str , Dict [str , Any ]]:
129162 # Create the initial relationship
130163 parameters = {
131164 "start_node_id" : rel .start_node_id ,
@@ -142,6 +175,15 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
142175 type = rel .type ,
143176 properties = properties ,
144177 )
178+ return query , parameters
179+
180+ def _upsert_relationship (self , rel : Neo4jRelationship ) -> None :
181+ """Upserts a single relationship into the Neo4j database.
182+
183+ Args:
184+ rel (Neo4jRelationship): The relationship to upsert into the database.
185+ """
186+ query , parameters = self ._get_rel_query (rel )
145187 result = self .driver .execute_query (query , parameters_ = parameters )
146188 rel_id = result .records [0 ]["elementID(r)" ]
147189 # Add the embedding properties to the relationship
@@ -155,6 +197,29 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
155197 neo4j_database = self .neo4j_database ,
156198 )
157199
200+ async def _async_upsert_relationship (
201+ self , rel : Neo4jRelationship , sem : asyncio .Semaphore
202+ ) -> None :
203+ """Asynchronously upserts a single relationship into the Neo4j database.
204+
205+ Args:
206+ rel (Neo4jRelationship): The relationship to upsert into the database.
207+ """
208+ async with sem :
209+ query , parameters = self ._get_rel_query (rel )
210+ result = await self .driver .execute_query (query , parameters_ = parameters )
211+ rel_id = result .records [0 ]["elementID(r)" ]
212+ # Add the embedding properties to the relationship
213+ if rel .embedding_properties :
214+ for prop , vector in rel .embedding_properties .items ():
215+ await async_upsert_vector_on_relationship (
216+ driver = self .driver ,
217+ rel_id = rel_id ,
218+ embedding_property = prop ,
219+ vector = vector ,
220+ neo4j_database = self .neo4j_database ,
221+ )
222+
158223 @validate_call
159224 async def run (self , graph : Neo4jGraph ) -> KGWriterModel :
160225 """Upserts a knowledge graph into a Neo4j database.
@@ -163,11 +228,24 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
163228 graph (Neo4jGraph): The knowledge graph to upsert into the database.
164229 """
165230 try :
166- for node in graph .nodes :
167- self ._upsert_node (node )
168-
169- for rel in graph .relationships :
170- self ._upsert_relationship (rel )
231+ if isinstance (self .driver , neo4j .AsyncDriver ):
232+ sem = asyncio .Semaphore (self .max_concurrency )
233+ node_tasks = [
234+ self ._async_upsert_node (node , sem ) for node in graph .nodes
235+ ]
236+ await asyncio .gather (* node_tasks )
237+
238+ rel_tasks = [
239+ self ._async_upsert_relationship (rel , sem )
240+ for rel in graph .relationships
241+ ]
242+ await asyncio .gather (* rel_tasks )
243+ else :
244+ for node in graph .nodes :
245+ self ._upsert_node (node )
246+
247+ for rel in graph .relationships :
248+ self ._upsert_relationship (rel )
171249
172250 return KGWriterModel (status = "SUCCESS" )
173251 except neo4j .exceptions .ClientError as e :
0 commit comments