Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions pymongo_search_utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,42 @@ def combine_pipelines(


def reciprocal_rank_stage(
    score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
) -> list[dict[str, Any]]:
    """Build aggregation stages that add a Weighted Reciprocal Rank Fusion score.

    The returned stages collect every incoming document into a single array,
    unwind that array while recording each document's array index as its
    rank, and then store on each document the score::

        weight * 1 / (rank + penalty + 1)

    Each document is finally promoted back to the root, now carrying two
    additional fields: ``score_field`` (the weighted score) and ``rank``.

    Args:
        score_field: A unique string to identify the search being ranked;
            used as the name of the field that receives the score.
        penalty: A non-negative float (e.g., 60 for RRF-60) added to the
            denominator. It is not validated here.
        weight: A float multiplier for this source's importance. The default
            of 1 leaves the plain reciprocal-rank score unscaled.
        **kwargs: Ignored; allows future extensions or passthrough args.

    Returns:
        A list of aggregation pipeline stages for weighted RRF scoring.
    """
    return [
        # Gather all documents from the previous stage into one array so
        # that array position can serve as the rank.
        {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
        # Emit one document per array element; "rank" receives the index.
        {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
        {
            "$addFields": {
                # The "+ 1" keeps the denominator positive for the first
                # element when penalty is 0 (assuming a zero-based index).
                f"docs.{score_field}": {
                    "$multiply": [
                        weight,
                        {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
                    ]
                },
                "docs.rank": "$rank",
                "_id": "$docs._id",
            }
        },
        # Restore the original document shape, now with score and rank.
        {"$replaceRoot": {"newRoot": "$docs"}},
    ]


def final_hybrid_stage(scores_fields: list[str], limit: int, **kwargs: Any) -> list[dict[str, Any]]:
"""Sum weighted scores, sort, and apply limit.
Expand Down
Loading