|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | from unittest.mock import MagicMock |
| 16 | +from warnings import catch_warnings |
16 | 17 |
|
17 | 18 | import pytest |
18 | 19 | from neo4j_genai.exceptions import RagInitializationError, SearchValidationError |
|
21 | 22 | from neo4j_genai.generation.types import RagResultModel |
22 | 23 | from neo4j_genai.llm import LLMResponse |
23 | 24 | from neo4j_genai.types import RetrieverResult, RetrieverResultItem |
| 25 | +from pydantic import ValidationError |
24 | 26 |
|
25 | 27 |
|
26 | 28 | def test_graphrag_prompt_template() -> None: |
@@ -99,3 +101,21 @@ def test_graphrag_search_error(retriever_mock: MagicMock, llm: MagicMock) -> Non |
99 | 101 | with pytest.raises(SearchValidationError) as excinfo: |
100 | 102 | rag.search(10) # type: ignore |
101 | 103 | assert "Input should be a valid string" in str(excinfo) |
| 104 | + |
| 105 | + |
| 106 | +def test_graphrag_search_query_deprecation_warning( |
| 107 | + retriever_mock: MagicMock, llm: MagicMock |
| 108 | +) -> None: |
| 109 | + with catch_warnings(record=True) as warn_list: |
| 110 | + rag = GraphRAG( |
| 111 | + retriever=retriever_mock, |
| 112 | + llm=llm, |
| 113 | + ) |
| 114 | + with pytest.raises(ValidationError): |
| 115 | + rag.search(query="Some query text") |
| 116 | + |
| 117 | + assert len(warn_list) == 1 |
| 118 | + assert ( |
| 119 | + str(warn_list[0].message) |
| 120 | + == "'query' is deprecated and will be removed in a future version, please use 'query_text' instead." |
| 121 | + ) |
0 commit comments