@@ -48,7 +48,7 @@ class Text2CypherRetriever(Retriever):
4848 """
4949 Allows for the retrieval of records from a Neo4j database using natural language.
5050 Converts a user's natural language query to a Cypher query using an LLM,
51- then retrieves records from a Neo4j database using the generated Cypher query
51+ then retrieves records from a Neo4j database using the generated Cypher query.
5252
5353 Args:
5454 driver (neo4j.Driver): The Neo4j Python driver.
@@ -98,23 +98,23 @@ def __init__(
9898 self .examples = validated_data .examples
9999 self .result_formatter = validated_data .result_formatter
100100 self .custom_prompt = validated_data .custom_prompt
101- try :
101+ if validated_data .custom_prompt :
102+ neo4j_schema = ""
103+ else :
102104 if (
103- not validated_data .custom_prompt
104- ): # don't need schema for a custom prompt
105- self .neo4j_schema = (
106- validated_data .neo4j_schema_model .neo4j_schema
107- if validated_data .neo4j_schema_model
108- else get_schema (validated_data .driver_model .driver )
109- )
105+ validated_data .neo4j_schema_model
106+ and validated_data .neo4j_schema_model .neo4j_schema
107+ ):
108+ neo4j_schema = validated_data .neo4j_schema_model .neo4j_schema
110109 else :
111- self .neo4j_schema = ""
112-
113- except (Neo4jError , DriverError ) as e :
114- error_message = getattr (e , "message" , str (e ))
115- raise SchemaFetchError (
116- f"Failed to fetch schema for Text2CypherRetriever: { error_message } "
117- ) from e
110+ try :
111+ neo4j_schema = get_schema (validated_data .driver_model .driver )
112+ except (Neo4jError , DriverError ) as e :
113+ error_message = getattr (e , "message" , str (e ))
114+ raise SchemaFetchError (
115+ f"Failed to fetch schema for Text2CypherRetriever: { error_message } "
116+ ) from e
117+ self .neo4j_schema = neo4j_schema
118118
119119 def get_search_results (
120120 self , query_text : str , prompt_params : Optional [Dict [str , Any ]] = None
@@ -142,12 +142,10 @@ def get_search_results(
142142
143143 if prompt_params is not None :
144144 # parse the schema and examples inputs
145- examples_to_use = prompt_params .get ("examples" ) or (
145+ examples_to_use = prompt_params .pop ("examples" , None ) or (
146146 "\n " .join (self .examples ) if self .examples else ""
147147 )
148- schema_to_use = prompt_params .get ("schema" ) or self .neo4j_schema
149- prompt_params .pop ("examples" , None )
150- prompt_params .pop ("schema" , None )
148+ schema_to_use = prompt_params .pop ("schema" , None ) or self .neo4j_schema
151149 else :
152150 examples_to_use = "\n " .join (self .examples ) if self .examples else ""
153151 schema_to_use = self .neo4j_schema
0 commit comments