1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15- from unittest .mock import patch , MagicMock
15+ from unittest .mock import MagicMock , patch
1616
1717import pytest
1818from neo4j .exceptions import CypherSyntaxError , Neo4jError
1919from neo4j_genai import Text2CypherRetriever
2020from neo4j_genai .exceptions import (
21- SearchValidationError ,
2221 RetrieverInitializationError ,
22+ SearchValidationError ,
2323 Text2CypherRetrievalError ,
2424)
2525from neo4j_genai .prompts import TEXT2CYPHER_PROMPT
@@ -85,14 +85,16 @@ def test_t2c_retriever_invalid_search_query(
8585def test_t2c_retriever_invalid_search_examples (
8686 _verify_version_mock : MagicMock , driver : MagicMock , llm : MagicMock
8787) -> None :
88- with pytest .raises (SearchValidationError ) as exc_info :
89- retriever = Text2CypherRetriever (
90- driver = driver , llm = llm , neo4j_schema = "dummy-text"
88+ with pytest .raises (RetrieverInitializationError ) as exc_info :
89+ Text2CypherRetriever (
90+ driver = driver ,
91+ llm = llm ,
92+ neo4j_schema = "dummy-text" ,
93+ examples = 42 , # type: ignore
9194 )
92- retriever .search (query_text = "dummy-text" , examples = 42 )
9395
9496 assert "examples" in str (exc_info .value )
95- assert "Input should be a valid list " in str (exc_info .value )
97+ assert "Initialization failed " in str (exc_info .value )
9698
9799
98100@patch ("neo4j_genai.Text2CypherRetriever._verify_version" )
@@ -106,7 +108,9 @@ def test_t2c_retriever_happy_path(
106108 query_text = "may thy knife chip and shatter"
107109 neo4j_schema = "dummy-schema"
108110 examples = ["example-1" , "example-2" ]
109- retriever = Text2CypherRetriever (driver = driver , llm = llm , neo4j_schema = neo4j_schema )
111+ retriever = Text2CypherRetriever (
112+ driver = driver , llm = llm , neo4j_schema = neo4j_schema , examples = examples
113+ )
110114 retriever .llm .invoke .return_value = t2c_query
111115 retriever .driver .execute_query .return_value = ( # type: ignore
112116 [neo4j_record ],
@@ -118,7 +122,7 @@ def test_t2c_retriever_happy_path(
118122 examples = "\n " .join (examples ),
119123 input = query_text ,
120124 )
121- retriever .search (query_text = query_text , examples = examples )
125+ retriever .search (query_text = query_text )
122126 retriever .llm .invoke .assert_called_once_with (prompt )
123127 retriever .driver .execute_query .assert_called_once_with (query_ = t2c_query ) # type: ignore
124128
@@ -130,10 +134,12 @@ def test_t2c_retriever_cypher_error(
130134 t2c_query = "this is not a cypher query"
131135 neo4j_schema = "dummy-schema"
132136 examples = ["example-1" , "example-2" ]
133- retriever = Text2CypherRetriever (driver = driver , llm = llm , neo4j_schema = neo4j_schema )
137+ retriever = Text2CypherRetriever (
138+ driver = driver , llm = llm , neo4j_schema = neo4j_schema , examples = examples
139+ )
134140 retriever .llm .invoke .return_value = t2c_query
135141 query_text = "may thy knife chip and shatter"
136142 driver .execute_query .side_effect = CypherSyntaxError
137143 with pytest .raises (Text2CypherRetrievalError ) as e :
138- retriever .search (query_text = query_text , examples = examples )
144+ retriever .search (query_text = query_text )
139145 assert "Failed to get search result" in str (e )
0 commit comments