@@ -148,8 +148,8 @@ class Neo4jMessageHistory(MessageHistory):
148148 Args:
149149 session_id (Union[str, int]): Unique identifier for the chat session.
150150 driver (neo4j.Driver): Neo4j driver instance.
151- node_label (str, optional): Label used for session nodes in Neo4j. Defaults to "Session".
152151 window (Optional[PositiveInt], optional): Number of previous messages to return when retrieving messages.
152+ database (Optional[str], optional): Neo4j database name.
153153
154154 """
155155
@@ -158,28 +158,33 @@ def __init__(
158158 session_id : Union [str , int ],
159159 driver : neo4j .Driver ,
160160 window : Optional [PositiveInt ] = None ,
161+ database : Optional [str ] = None ,
161162 ) -> None :
162163 validated_data = Neo4jMessageHistoryModel (
163164 session_id = session_id ,
164165 driver_model = Neo4jDriverModel (driver = driver ),
165166 window = window ,
167+ database = database ,
166168 )
167169 self ._driver = validated_data .driver_model .driver
168170 self ._session_id = validated_data .session_id
169171 self ._window = (
170172 "" if validated_data .window is None else validated_data .window - 1
171173 )
174+ self ._database = validated_data .database
172175 # Create session node
173176 self ._driver .execute_query (
174177 query_ = CREATE_SESSION_NODE_QUERY .format (node_label = "Session" ),
175178 parameters_ = {"session_id" : self ._session_id },
179+ database_ = self ._database ,
176180 )
177181
178182 @property
179183 def messages (self ) -> List [LLMMessage ]:
180184 result = self ._driver .execute_query (
181185 query_ = GET_MESSAGES_QUERY .format (node_label = "Session" , window = self ._window ),
182186 parameters_ = {"session_id" : self ._session_id },
187+ database_ = self ._database ,
183188 )
184189 messages = [
185190 LLMMessage (
@@ -210,6 +215,7 @@ def add_message(self, message: LLMMessage) -> None:
210215 "content" : message ["content" ],
211216 "session_id" : self ._session_id ,
212217 },
218+ database_ = self ._database ,
213219 )
214220
215221 def clear (self , delete_session_node : bool = False ) -> None :
@@ -222,9 +228,11 @@ def clear(self, delete_session_node: bool = False) -> None:
222228 self ._driver .execute_query (
223229 query_ = DELETE_SESSION_AND_MESSAGES_QUERY .format (node_label = "Session" ),
224230 parameters_ = {"session_id" : self ._session_id },
231+ database_ = self ._database ,
225232 )
226233 else :
227234 self ._driver .execute_query (
228235 query_ = DELETE_MESSAGES_QUERY .format (node_label = "Session" ),
229236 parameters_ = {"session_id" : self ._session_id },
237+ database_ = self ._database ,
230238 )
0 commit comments