2525)
2626from neo4j_graphrag .generation .prompts import Text2CypherTemplate
2727from neo4j_graphrag .llm import LLMResponse
28- from neo4j_graphrag .retrievers import Text2CypherRetriever
28+ from neo4j_graphrag .retrievers . text2cypher import Text2CypherRetriever , extract_cypher
2929from neo4j_graphrag .types import RetrieverResult , RetrieverResultItem
3030
3131
@@ -204,9 +204,11 @@ def test_t2c_retriever_with_result_format_function(
204204 )
205205
206206
207+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
207208@patch ("neo4j_graphrag.retrievers.base.get_version" )
208209def test_t2c_retriever_initialization_with_custom_prompt (
209210 mock_get_version : MagicMock ,
211+ mock_extract_cypher : MagicMock ,
210212 driver : MagicMock ,
211213 llm : MagicMock ,
212214 neo4j_record : MagicMock ,
@@ -224,9 +226,11 @@ def test_t2c_retriever_initialization_with_custom_prompt(
224226 llm .invoke .assert_called_once_with ("This is a custom prompt. test" )
225227
226228
229+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
227230@patch ("neo4j_graphrag.retrievers.base.get_version" )
228231def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples (
229232 mock_get_version : MagicMock ,
233+ mock_extract_cypher : MagicMock ,
230234 driver : MagicMock ,
231235 llm : MagicMock ,
232236 neo4j_record : MagicMock ,
@@ -254,9 +258,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
254258 llm .invoke .assert_called_once_with ("This is a custom prompt. test" )
255259
256260
261+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
257262@patch ("neo4j_graphrag.retrievers.base.get_version" )
258263def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples_for_prompt_params (
259264 mock_get_version : MagicMock ,
265+ mock_extract_cypher : MagicMock ,
260266 driver : MagicMock ,
261267 llm : MagicMock ,
262268 neo4j_record : MagicMock ,
@@ -286,9 +292,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
286292 )
287293
288294
295+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
289296@patch ("neo4j_graphrag.retrievers.base.get_version" )
290297def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_examples (
291298 mock_get_version : MagicMock ,
299+ mock_extract_cypher : MagicMock ,
292300 driver : MagicMock ,
293301 llm : MagicMock ,
294302 neo4j_record : MagicMock ,
@@ -321,9 +329,13 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_e
321329 )
322330
323331
332+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
324333@patch ("neo4j_graphrag.retrievers.base.get_version" )
325334def test_t2c_retriever_invalid_custom_prompt_type (
326- mock_get_version : MagicMock , driver : MagicMock , llm : MagicMock
335+ mock_get_version : MagicMock ,
336+ mock_extract_cypher : MagicMock ,
337+ driver : MagicMock ,
338+ llm : MagicMock ,
327339) -> None :
328340 mock_get_version .return_value = ((5 , 23 , 0 ), False , False )
329341 with pytest .raises (RetrieverInitializationError ) as exc_info :
@@ -336,9 +348,11 @@ def test_t2c_retriever_invalid_custom_prompt_type(
336348 assert "Input should be a valid string" in str (exc_info .value )
337349
338350
351+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
339352@patch ("neo4j_graphrag.retrievers.base.get_version" )
340353def test_t2c_retriever_with_custom_prompt_prompt_params (
341354 mock_get_version : MagicMock ,
355+ mock_extract_cypher : MagicMock ,
342356 driver : MagicMock ,
343357 llm : MagicMock ,
344358 neo4j_record : MagicMock ,
@@ -361,9 +375,11 @@ def test_t2c_retriever_with_custom_prompt_prompt_params(
361375 )
362376
363377
378+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
364379@patch ("neo4j_graphrag.retrievers.base.get_version" )
365380def test_t2c_retriever_with_custom_prompt_bad_prompt_params (
366381 mock_get_version : MagicMock ,
382+ mock_extract_cypher : MagicMock ,
367383 driver : MagicMock ,
368384 llm : MagicMock ,
369385 neo4j_record : MagicMock ,
@@ -392,11 +408,13 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
392408 )
393409
394410
411+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
395412@patch ("neo4j_graphrag.retrievers.base.get_version" )
396413@patch ("neo4j_graphrag.retrievers.text2cypher.get_schema" )
397414def test_t2c_retriever_with_custom_prompt_and_schema (
398415 get_schema_mock : MagicMock ,
399416 mock_get_version : MagicMock ,
417+ mock_extract_cypher : MagicMock ,
400418 driver : MagicMock ,
401419 llm : MagicMock ,
402420 neo4j_record : MagicMock ,
@@ -419,3 +437,67 @@ def test_t2c_retriever_with_custom_prompt_and_schema(
419437
420438 get_schema_mock .assert_not_called ()
421439 llm .invoke .assert_called_once_with ("""This is a custom prompt. test """ )
440+
441+
442+ @pytest .mark .parametrize (
443+ "description, cypher_query, expected_output" ,
444+ [
445+ ("No changes" , "MATCH (n) RETURN n;" , "MATCH (n) RETURN n;" ),
446+ (
447+ "Surrounded by backticks" ,
448+ "Cypher query: ```MATCH (n) RETURN n;```" ,
449+ "MATCH (n) RETURN n;" ,
450+ ),
451+ (
452+ "Spaces in label" ,
453+ "Cypher query: ```MATCH (n: Label With Spaces ) RETURN n;```" ,
454+ "MATCH (n:`Label With Spaces`) RETURN n;" ,
455+ ),
456+ (
457+ "No spaces in label" ,
458+ "Cypher query: ```MATCH (n: LabelWithNoSpaces ) RETURN n;```" ,
459+ "MATCH (n: LabelWithNoSpaces ) RETURN n;" ,
460+ ),
461+ (
462+ "Backticks in label" ,
463+ "Cypher query: ```MATCH (n: `LabelWithBackticks` ) RETURN n;```" ,
464+ "MATCH (n: `LabelWithBackticks` ) RETURN n;" ,
465+ ),
466+ (
467+ "Spaces in property key" ,
468+ "Cypher query: ```MATCH (n: { prop 1: 1, prop 2: 2 }) RETURN n;```" ,
469+ "MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;" ,
470+ ),
471+ (
472+ "No spaces in property key" ,
473+ "Cypher query: ```MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;```" ,
474+ "MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;" ,
475+ ),
476+ (
477+ "Backticks in property key" ,
478+ "Cypher query: ```MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;```" ,
479+ "MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;" ,
480+ ),
481+ (
482+ "Spaces in relationship type" ,
483+ "Cypher query: ```MATCH (n)-[: Relationship With Spaces ]->(m) RETURN n, m;```" ,
484+ "MATCH (n)-[:`Relationship With Spaces`]->(m) RETURN n, m;" ,
485+ ),
486+ (
487+ "No spaces in relationship type" ,
488+ "Cypher query: ```MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;```" ,
489+ "MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;" ,
490+ ),
491+ (
492+ "Backticks in relationship type" ,
493+ "Cypher query: ```MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;```" ,
494+ "MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;" ,
495+ ),
496+ ],
497+ )
498+ def test_extract_cypher (
499+ description : str , cypher_query : str , expected_output : str
500+ ) -> None :
501+ assert (
502+ extract_cypher (cypher_query ) == expected_output
503+ ), f"Failed test case: { description } "
0 commit comments