|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | from __future__ import annotations |
16 | | -from typing import Optional, Any, Callable |
| 16 | + |
| 17 | +import logging |
| 18 | +from typing import Any, Callable, Optional |
17 | 19 |
|
18 | 20 | import neo4j |
19 | 21 | from pydantic import ValidationError |
20 | 22 |
|
21 | 23 | from neo4j_genai.embedder import Embedder |
22 | 24 | from neo4j_genai.exceptions import ( |
| 25 | + EmbeddingRequiredError, |
23 | 26 | RetrieverInitializationError, |
24 | 27 | SearchValidationError, |
25 | | - EmbeddingRequiredError, |
26 | 28 | ) |
| 29 | +from neo4j_genai.neo4j_queries import get_search_query |
27 | 30 | from neo4j_genai.retrievers.base import Retriever |
28 | 31 | from neo4j_genai.types import ( |
29 | | - HybridSearchModel, |
30 | | - SearchType, |
31 | | - HybridCypherSearchModel, |
32 | | - Neo4jDriverModel, |
33 | 32 | EmbedderModel, |
34 | | - HybridRetrieverModel, |
35 | 33 | HybridCypherRetrieverModel, |
| 34 | + HybridCypherSearchModel, |
| 35 | + HybridRetrieverModel, |
| 36 | + HybridSearchModel, |
| 37 | + Neo4jDriverModel, |
36 | 38 | RawSearchResult, |
37 | 39 | RetrieverResultItem, |
| 40 | + SearchType, |
38 | 41 | ) |
39 | | -from neo4j_genai.neo4j_queries import get_search_query |
40 | | -import logging |
41 | 42 |
|
42 | 43 | logger = logging.getLogger(__name__) |
43 | 44 |
|
@@ -146,16 +147,16 @@ def get_search_results( |
146 | 147 | """ |
147 | 148 | try: |
148 | 149 | validated_data = HybridSearchModel( |
149 | | - vector_index_name=self.vector_index_name, |
150 | | - fulltext_index_name=self.fulltext_index_name, |
151 | | - top_k=top_k, |
152 | 150 | query_vector=query_vector, |
153 | 151 | query_text=query_text, |
| 152 | + top_k=top_k, |
154 | 153 | ) |
155 | 154 | except ValidationError as e: |
156 | 155 | raise SearchValidationError(e.errors()) from e |
157 | 156 |
|
158 | 157 | parameters = validated_data.model_dump(exclude_none=True) |
| 158 | + parameters["vector_index_name"] = self.vector_index_name |
| 159 | + parameters["fulltext_index_name"] = self.fulltext_index_name |
159 | 160 |
|
160 | 161 | if query_text and not query_vector: |
161 | 162 | if not self.embedder: |
@@ -276,17 +277,17 @@ def get_search_results( |
276 | 277 | """ |
277 | 278 | try: |
278 | 279 | validated_data = HybridCypherSearchModel( |
279 | | - vector_index_name=self.vector_index_name, |
280 | | - fulltext_index_name=self.fulltext_index_name, |
281 | | - top_k=top_k, |
282 | 280 | query_vector=query_vector, |
283 | 281 | query_text=query_text, |
| 282 | + top_k=top_k, |
284 | 283 | query_params=query_params, |
285 | 284 | ) |
286 | 285 | except ValidationError as e: |
287 | 286 | raise SearchValidationError(e.errors()) from e |
288 | 287 |
|
289 | 288 | parameters = validated_data.model_dump(exclude_none=True) |
| 289 | + parameters["vector_index_name"] = self.vector_index_name |
| 290 | + parameters["fulltext_index_name"] = self.fulltext_index_name |
290 | 291 |
|
291 | 292 | if query_text and not query_vector: |
292 | 293 | if not self.embedder: |
|
0 commit comments