LangChain Compatibility Restored #442
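The diff below updates the GraphRAG generation class so that, alongside native `LLMInterface` and `LLMInterfaceV2` implementations, a LangChain chat model can again be passed directly as `llm`. A minimal usage sketch of what that enables (not taken from the PR; the connection details, the `my_index` vector index, and the `gpt-4o` model name are placeholder assumptions):

```python
# Hypothetical example of the restored compatibility: a LangChain chat model
# is handed to GraphRAG as-is and detected internally by its "langchain" module.
# Requires OPENAI_API_KEY in the environment for both the embedder and the model.
import neo4j
from langchain_openai import ChatOpenAI

from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.retrievers import VectorRetriever

driver = neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687", auth=("neo4j", "password")  # placeholder connection
)

retriever = VectorRetriever(
    driver,
    index_name="my_index",        # assumed existing vector index
    embedder=OpenAIEmbeddings(),  # must match the embeddings used to build the index
)

rag = GraphRAG(retriever=retriever, llm=ChatOpenAI(model="gpt-4o"))
result = rag.search(query_text="Who directed The Matrix?", return_context=True)
print(result.answer)
```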
```diff
@@ -12,29 +12,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations

+# built-in dependencies
+from __future__ import annotations
 import logging
 import warnings
 from typing import Any, List, Optional, Union

+# 3rd party dependencies
 from pydantic import ValidationError

+# project dependencies
 from neo4j_graphrag.exceptions import (
     RagInitializationError,
     SearchValidationError,
 )
 from neo4j_graphrag.generation.prompts import RagTemplate
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
-from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm import LLMInterface, LLMInterfaceV2
+from neo4j_graphrag.llm.utils import legacy_inputs_to_messages
 from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import LLMMessage, RetrieverResult
 from neo4j_graphrag.utils.logging import prettify

+# Set up logger
 logger = logging.getLogger(__name__)


+# pylint: disable=raise-missing-from
 class GraphRAG:
     """Performs a GraphRAG search using a specific retriever
     and LLM.
```
```diff
@@ -57,8 +63,10 @@ class GraphRAG:

     Args:
         retriever (Retriever): The retriever used to find relevant context to pass to the LLM.
-        llm (LLMInterface): The LLM used to generate the answer.
-        prompt_template (RagTemplate): The prompt template that will be formatted with context and user question and passed to the LLM.
+        llm (LLMInterface, LLMInterfaceV2 or LangChain Chat Model): The LLM used to generate
+            the answer.
+        prompt_template (RagTemplate): The prompt template that will be formatted with context and
+            user question and passed to the LLM.

     Raises:
         RagInitializationError: If validation of the input arguments fail.
```
```diff
@@ -67,7 +75,7 @@
     def __init__(
         self,
         retriever: Retriever,
-        llm: LLMInterface,
+        llm: Union[LLMInterface, LLMInterfaceV2],
        prompt_template: RagTemplate = RagTemplate(),
     ):
         try:
```
```diff
@@ -93,7 +101,8 @@ def search(
     ) -> RagResultModel:
         """
         .. warning::
-            The default value of 'return_context' will change from 'False' to 'True' in a future version.
+            The default value of 'return_context' will change from 'False'
+            to 'True' in a future version.


         This method performs a full RAG search:
```
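The hunk above deprecates relying on the implicit `return_context` default, so callers should pass it explicitly. A short sketch continuing the earlier example, reusing its `rag` object (the `retriever_result` attribute name on the result model is an assumption based on the imported `RetrieverResult` type):

```python
# Illustrative only: passing return_context explicitly avoids the
# DeprecationWarning about the default flipping from False to True.
result = rag.search(
    query_text="Who directed The Matrix?",
    retriever_config={"top_k": 5},  # forwarded to the retriever's search method
    return_context=True,
)
print(result.answer)
# Assumed field name; populated only when return_context=True.
print(result.retriever_result)
```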
```diff
@@ -104,24 +113,30 @@ def search(

         Args:
             query_text (str): The user question.
-            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
-                with each message having a specific role assigned.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection
+                previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
-            return_context (bool): Whether to append the retriever result to the final result (default: False).
-            response_fallback (Optional[str]): If not null, will return this message instead of calling the LLM if context comes back empty.
+            return_context (bool): Whether to append the retriever result to the final result
+                (default: False).
+            response_fallback (Optional[str]): If not null, will return this message instead
+                of calling the LLM if context comes back empty.

         Returns:
             RagResultModel: The LLM-generated answer.

         """
         if return_context is None:
-            warnings.warn(
-                "The default value of 'return_context' will change from 'False' to 'True' in a future version.",
-                DeprecationWarning,
-            )
-            return_context = False
+            if isinstance(self.llm, LLMInterface):
+                warnings.warn(
+                    "The default value of 'return_context' will change from 'False'"
```
Contributor: These changes seem to be related to line length only; are they made on purpose? (Asking because I know Nathalie had some issues with this in the past.)

Contributor (Author): True, pylint was warning that the line is too long. The auto-linter in my IDE split it into several lines. I can revert them if you think it may cause any issue.

Contributor: It's just that the next person pushing to the repo will have to make the opposite change, so we should agree on a convention. Is this something we should add to a config file in the repo?
```diff
+                    " to 'True' in a future version.",
+                    DeprecationWarning,
+                )
+                return_context = False
+            else:
+                return_context = True
         try:
             validated_data = RagSearchModel(
                 query_text=query_text,
```
```diff
@@ -145,13 +160,32 @@ def search(
         prompt = self.prompt_template.format(
             query_text=query_text, context=context, examples=validated_data.examples
         )
-        logger.debug(f"RAG: retriever_result={prettify(retriever_result)}")
-        logger.debug(f"RAG: prompt={prompt}")
-        llm_response = self.llm.invoke(
-            prompt,
-            message_history,
-            system_instruction=self.prompt_template.system_instructions,
-        )
+
+        logger.debug("RAG: retriever_result=%s", prettify(retriever_result))
+        logger.debug("RAG: prompt=%s", prompt)
+
+        if isinstance(self.llm, LLMInterfaceV2) or self.llm.__module__.startswith(
+            "langchain"
+        ):
+            messages = legacy_inputs_to_messages(
+                prompt=prompt,
+                message_history=message_history,
+                system_instruction=self.prompt_template.system_instructions,
+            )
+
+            # langchain chat model compatible invoke
+            llm_response = self.llm.invoke(
+                input=messages,
+            )
+        elif isinstance(self.llm, LLMInterface):
+            # may have custom LLMs inherited from V1, keep it for backward compatibility
+            llm_response = self.llm.invoke(
+                input=prompt,
+                message_history=message_history,
+                system_instruction=self.prompt_template.system_instructions,
+            )
+        else:
+            raise ValueError(f"Type {type(self.llm)} of LLM is not supported.")
         answer = llm_response.content
         result: dict[str, Any] = {"answer": answer}
         if return_context:
```
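The hunk above routes `search` by LLM type: `LLMInterfaceV2` implementations and anything defined in a `langchain*` module receive one message list built by `legacy_inputs_to_messages`, while custom `LLMInterface` (V1) subclasses keep the separate prompt, history, and system-instruction arguments. Below is a standalone sketch of the same detection rule, written as an illustrative helper rather than code from the PR:

```python
from typing import Any


def takes_message_list(llm: Any) -> bool:
    """Mirror of the branch condition in the diff above (illustrative only):
    True when the LLM should be invoked with a single list of messages."""
    # LangChain chat models are recognised purely by their module prefix,
    # so no LangChain import is needed for the check itself.
    if type(llm).__module__.startswith("langchain"):
        return True
    # Fall back to the explicit V2 interface check used in the PR; the import
    # is deferred so this sketch does not hard-depend on neo4j_graphrag.
    try:
        from neo4j_graphrag.llm import LLMInterfaceV2
    except ImportError:
        return False
    return isinstance(llm, LLMInterfaceV2)
```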
```diff
@@ -163,15 +197,33 @@ def _build_query(
         query_text: str,
         message_history: Optional[List[LLMMessage]] = None,
     ) -> str:
-        summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
+        """Builds the final query text, incorporating message history if provided."""
+        summary_system_message = (
+            "You are a summarization assistant. "
+            "Summarize the given text in no more than 300 words."
+        )
         if message_history:
             summarization_prompt = self._chat_summary_prompt(
                 message_history=message_history
             )
-            summary = self.llm.invoke(
-                input=summarization_prompt,
+            messages = legacy_inputs_to_messages(
+                summarization_prompt,
                 system_instruction=summary_system_message,
-            ).content
+            )
+            if isinstance(self.llm, LLMInterfaceV2) or self.llm.__module__.startswith(
+                "langchain"
+            ):
+                summary = self.llm.invoke(
+                    input=messages,
+                ).content
+            elif isinstance(self.llm, LLMInterface):
+                summary = self.llm.invoke(
+                    input=summarization_prompt,
+                    system_instruction=summary_system_message,
+                ).content
+            else:
+                raise ValueError(f"Type {type(self.llm)} of LLM is not supported.")
+
             return self.conversation_prompt(summary=summary, current_query=query_text)
         return query_text

```
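The `_build_query` hunk applies the same branching when summarizing chat history: if a `message_history` is supplied, the LLM first produces a summary of at most 300 words, which is folded into the query passed to the retriever. A hedged usage sketch continuing the earlier example (the conversation content is invented):

```python
# Hypothetical prior conversation; each LLMMessage-shaped entry carries a role and content.
history = [
    {"role": "user", "content": "Tell me about the Matrix film series."},
    {"role": "assistant", "content": "It is a science-fiction series that began in 1999..."},
]

# With a history, GraphRAG summarizes it via the LLM and folds the summary
# into the query text that the retriever sees.
result = rag.search(
    query_text="Who directed the first film?",
    message_history=history,
)
print(result.answer)
```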