1616
1717import copy
1818import logging
19- from typing import Any , Callable , Optional
19+ from typing import Any , Callable , Optional , Union
2020
2121import neo4j
2222from pydantic import ValidationError
3939 RawSearchResult ,
4040 RetrieverResultItem ,
4141 SearchType ,
42+ HybridSearchRanker ,
4243)
4344
4445logger = logging .getLogger (__name__ )
@@ -142,6 +143,8 @@ def get_search_results(
142143 query_vector : Optional [list [float ]] = None ,
143144 top_k : int = 5 ,
144145 effective_search_ratio : int = 1 ,
146+ ranker : Union [str , HybridSearchRanker ] = HybridSearchRanker .NAIVE ,
147+ alpha : Optional [float ] = None ,
145148 ) -> RawSearchResult :
146149 """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
147150 Both query_vector and query_text can be provided.
@@ -162,6 +165,10 @@ def get_search_results(
162165 top_k (int, optional): The number of neighbors to return. Defaults to 5.
163166 effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
164167 accuracy and performance. Defaults to 1.
168+ ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
169+ alpha (Optional[float]): Weight for the vector score when using the linear ranker.
170+ The fulltext index score is multiplied by (1 - alpha).
171+ **Required** when using the linear ranker; must be between 0 and 1.
165172
166173 Raises:
167174 SearchValidationError: If validation of the input arguments fail.
@@ -176,6 +183,8 @@ def get_search_results(
176183 query_text = query_text ,
177184 top_k = top_k ,
178185 effective_search_ratio = effective_search_ratio ,
186+ ranker = ranker ,
187+ alpha = alpha ,
179188 )
180189 except ValidationError as e :
181190 raise SearchValidationError (e .errors ()) from e
@@ -191,13 +200,18 @@ def get_search_results(
191200 )
192201 query_vector = self .embedder .embed_query (query_text )
193202 parameters ["query_vector" ] = query_vector
194-
195203 search_query , _ = get_search_query (
196204 search_type = SearchType .HYBRID ,
197205 return_properties = self .return_properties ,
198206 embedding_node_property = self ._embedding_node_property ,
199207 neo4j_version_is_5_23_or_above = self .neo4j_version_is_5_23_or_above ,
208+ ranker = validated_data .ranker ,
209+ alpha = validated_data .alpha ,
200210 )
211+
212+ if "ranker" in parameters :
213+ del parameters ["ranker" ]
214+
201215 sanitized_parameters = copy .deepcopy (parameters )
202216 if "query_vector" in sanitized_parameters :
203217 sanitized_parameters ["query_vector" ] = "..."
@@ -301,6 +315,8 @@ def get_search_results(
301315 top_k : int = 5 ,
302316 effective_search_ratio : int = 1 ,
303317 query_params : Optional [dict [str , Any ]] = None ,
318+ ranker : Union [str , HybridSearchRanker ] = HybridSearchRanker .NAIVE ,
319+ alpha : Optional [float ] = None ,
304320 ) -> RawSearchResult :
305321 """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
306322 Both query_vector and query_text can be provided.
@@ -320,7 +336,10 @@ def get_search_results(
320336 effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
321337 accuracy and performance. Defaults to 1.
322338 query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
323-
339+ ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
340+ alpha (Optional[float]): Weight for the vector score when using the linear ranker.
341+ The fulltext index score is multiplied by (1 - alpha).
342+ **Required** when using the linear ranker; must be between 0 and 1.
324343 Raises:
325344 SearchValidationError: If validation of the input arguments fail.
326345 EmbeddingRequiredError: If no embedder is provided.
@@ -334,6 +353,8 @@ def get_search_results(
334353 query_text = query_text ,
335354 top_k = top_k ,
336355 effective_search_ratio = effective_search_ratio ,
356+ ranker = ranker ,
357+ alpha = alpha ,
337358 query_params = query_params ,
338359 )
339360 except ValidationError as e :
@@ -361,7 +382,13 @@ def get_search_results(
361382 search_type = SearchType .HYBRID ,
362383 retrieval_query = self .retrieval_query ,
363384 neo4j_version_is_5_23_or_above = self .neo4j_version_is_5_23_or_above ,
385+ ranker = validated_data .ranker ,
386+ alpha = validated_data .alpha ,
364387 )
388+
389+ if "ranker" in parameters :
390+ del parameters ["ranker" ]
391+
365392 sanitized_parameters = copy .deepcopy (parameters )
366393 if "query_vector" in sanitized_parameters :
367394 sanitized_parameters ["query_vector" ] = "..."
0 commit comments