|
9 | 9 |
|
10 | 10 | from typing import Any, Dict, List, Optional, Union |
11 | 11 |
|
12 | | -from pymongo_search_utils import combine_pipelines # noqa: F401 |
| 12 | +from pymongo_search_utils import ( |
| 13 | + combine_pipelines, # noqa: F401 |
| 14 | + final_hybrid_stage, # noqa: F401 |
| 15 | +) |
13 | 16 |
|
14 | 17 |
|
15 | 18 | def text_search_stage( |
@@ -129,26 +132,3 @@ def reciprocal_rank_stage( |
129 | 132 | }, |
130 | 133 | {"$replaceRoot": {"newRoot": "$docs"}}, |
131 | 134 | ] |
132 | | - |
133 | | - |
134 | | -def final_hybrid_stage( |
135 | | - scores_fields: List[str], limit: int, **kwargs: Any |
136 | | -) -> List[Dict[str, Any]]: |
137 | | - """Sum weighted scores, sort, and apply limit. |
138 | | -
|
139 | | - Args: |
140 | | - scores_fields: List of fields given to scores of vector and text searches |
141 | | - limit: Number of documents to return |
142 | | -
|
143 | | - Returns: |
144 | | - Final aggregation stages |
145 | | - """ |
146 | | - |
147 | | - return [ |
148 | | - {"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}}, |
149 | | - {"$replaceRoot": {"newRoot": "$docs"}}, |
150 | | - {"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}}, |
151 | | - {"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}}, |
152 | | - {"$sort": {"score": -1}}, |
153 | | - {"$limit": limit}, |
154 | | - ] |
0 commit comments