|
12 | 12 | from pymongo_search_utils import ( |
13 | 13 | combine_pipelines, # noqa: F401 |
14 | 14 | final_hybrid_stage, # noqa: F401 |
| 15 | + reciprocal_rank_stage, # noqa: F401 |
15 | 16 | ) |
16 | 17 |
|
17 | 18 |
|
@@ -94,41 +95,3 @@ def vector_search_stage( |
94 | 95 | if filter: |
95 | 96 | stage["filter"] = filter |
96 | 97 | return {"$vectorSearch": stage} |
97 | | - |
98 | | - |
99 | | -def reciprocal_rank_stage( |
100 | | - score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any |
101 | | -) -> List[Dict[str, Any]]: |
102 | | - """ |
103 | | - Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring. |
104 | | -
|
105 | | - First, it groups documents into an array, assigns rank by array index, |
106 | | - and then computes a weighted RRF score. |
107 | | -
|
108 | | - Args: |
109 | | - score_field: A unique string to identify the search being ranked. |
110 | | - penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator. |
111 | | - weight: A float multiplier for this source's importance. |
112 | | - **kwargs: Ignored; allows future extensions or passthrough args. |
113 | | -
|
114 | | - Returns: |
115 | | - Aggregation pipeline stage for weighted RRF scoring. |
116 | | - """ |
117 | | - |
118 | | - return [ |
119 | | - {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, |
120 | | - {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, |
121 | | - { |
122 | | - "$addFields": { |
123 | | - f"docs.{score_field}": { |
124 | | - "$multiply": [ |
125 | | - weight, |
126 | | - {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}, |
127 | | - ] |
128 | | - }, |
129 | | - "docs.rank": "$rank", |
130 | | - "_id": "$docs._id", |
131 | | - } |
132 | | - }, |
133 | | - {"$replaceRoot": {"newRoot": "$docs"}}, |
134 | | - ] |
0 commit comments