1616
1717import logging
1818import warnings
19- from typing import Any , Optional
19+ from typing import Any , List , Optional , Union
2020
2121from pydantic import ValidationError
2222
2828from neo4j_graphrag .generation .types import RagInitModel , RagResultModel , RagSearchModel
2929from neo4j_graphrag .llm import LLMInterface
3030from neo4j_graphrag .llm .types import LLMMessage
31+ from neo4j_graphrag .message_history import MessageHistory
3132from neo4j_graphrag .retrievers .base import Retriever
3233from neo4j_graphrag .types import RetrieverResult
3334
@@ -84,7 +85,7 @@ def __init__(
8485 def search (
8586 self ,
8687 query_text : str = "" ,
87- message_history : Optional [list [ LLMMessage ]] = None ,
88+ message_history : Optional [Union [ List [ LLMMessage ], MessageHistory ]] = None ,
8889 examples : str = "" ,
8990 retriever_config : Optional [dict [str , Any ]] = None ,
9091 return_context : bool | None = None ,
@@ -102,7 +103,8 @@ def search(
102103
103104 Args:
104105 query_text (str): The user question.
105- message_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.
106+ message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
107+ with each message having a specific role assigned.
106108 examples (str): Examples added to the LLM prompt.
107109 retriever_config (Optional[dict]): Parameters passed to the retriever.
108110 search method; e.g.: top_k
@@ -127,7 +129,9 @@ def search(
127129 )
128130 except ValidationError as e :
129131 raise SearchValidationError (e .errors ())
130- query = self .build_query (validated_data .query_text , message_history )
132+ if isinstance (message_history , MessageHistory ):
133+ message_history = message_history .messages
134+ query = self ._build_query (validated_data .query_text , message_history )
131135 retriever_result : RetrieverResult = self .retriever .search (
132136 query_text = query , ** validated_data .retriever_config
133137 )
@@ -147,12 +151,14 @@ def search(
147151 result ["retriever_result" ] = retriever_result
148152 return RagResultModel (** result )
149153
150- def build_query (
151- self , query_text : str , message_history : Optional [list [LLMMessage ]] = None
154+ def _build_query (
155+ self ,
156+ query_text : str ,
157+ message_history : Optional [List [LLMMessage ]] = None ,
152158 ) -> str :
153159 summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
154160 if message_history :
155- summarization_prompt = self .chat_summary_prompt (
161+ summarization_prompt = self ._chat_summary_prompt (
156162 message_history = message_history
157163 )
158164 summary = self .llm .invoke (
@@ -162,10 +168,9 @@ def build_query(
162168 return self .conversation_prompt (summary = summary , current_query = query_text )
163169 return query_text
164170
165- def chat_summary_prompt (self , message_history : list [LLMMessage ]) -> str :
171+ def _chat_summary_prompt (self , message_history : List [LLMMessage ]) -> str :
166172 message_list = [
167- ": " .join ([f"{ value } " for _ , value in message .items ()])
168- for message in message_history
173+ f"{ message ['role' ]} : { message ['content' ]} " for message in message_history
169174 ]
170175 history = "\n " .join (message_list )
171176 return f"""
0 commit comments