|
1 | 1 | from typing import Any, Dict, List, Optional, Tuple, Union |
2 | 2 |
|
3 | | -from sentence_transformers import CrossEncoder |
| 3 | +from pydantic.v1 import PrivateAttr |
4 | 4 |
|
5 | 5 | from redisvl.utils.rerank.base import BaseReranker |
6 | 6 |
|
@@ -31,25 +31,44 @@ class HFCrossEncoderReranker(BaseReranker): |
31 | 31 | ) |
32 | 32 | """ |
33 | 33 |
|
| 34 | + _client: Any = PrivateAttr() |
| 35 | + |
34 | 36 | def __init__( |
35 | 37 | self, |
36 | | - model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", |
| 38 | + model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", |
37 | 39 | limit: int = 3, |
38 | 40 | return_score: bool = True, |
| 41 | + **kwargs, |
39 | 42 | ) -> None: |
40 | 43 | """ |
41 | 44 | Initialize the HFCrossEncoderReranker with a specified model and ranking criteria. |
42 | 45 |
|
43 | 46 | Parameters: |
44 | | - model_name (str): The name or path of the cross-encoder model to use for reranking. |
| 47 | + model (str): The name or path of the cross-encoder model to use for reranking. |
45 | 48 | Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'. |
46 | 49 | limit (int): The maximum number of results to return after reranking. Must be a positive integer. |
47 | 50 | return_score (bool): Whether to return scores alongside the reranked results. |
48 | 51 | """ |
| 52 | + model = model or kwargs.pop("model_name", None) |
49 | 53 | super().__init__( |
50 | | - model=model_name, rank_by=None, limit=limit, return_score=return_score |
| 54 | + model=model, rank_by=None, limit=limit, return_score=return_score |
51 | 55 | ) |
52 | | - self.model: CrossEncoder = CrossEncoder(model_name) |
| 56 | + self._initialize_client(**kwargs) |
| 57 | + |
| 58 | + def _initialize_client(self, **kwargs): |
| 59 | + """ |
| 60 | + Setup the huggingface cross-encoder client using optional kwargs. |
| 61 | + """ |
| 62 | + # Dynamic import of the sentence-transformers module |
| 63 | + try: |
| 64 | + from sentence_transformers import CrossEncoder |
| 65 | + except ImportError: |
| 66 | + raise ImportError( |
| 67 | + "HFCrossEncoder reranker requires the sentence-transformers library. \ |
| 68 | + Please install with `pip install sentence-transformers`" |
| 69 | + ) |
| 70 | + |
| 71 | + self._client = CrossEncoder(self.model, **kwargs) |
53 | 72 |
|
54 | 73 | def rank( |
55 | 74 | self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs |
@@ -97,7 +116,7 @@ def rank( |
97 | 116 | texts = [str(doc) for doc in docs] |
98 | 117 | doc_subset = [{"content": doc} for doc in docs] |
99 | 118 |
|
100 | | - scores = self.model.predict([(query, text) for text in texts]) |
| 119 | + scores = self._client.predict([(query, text) for text in texts]) |
101 | 120 | scores = [float(score) for score in scores] |
102 | 121 | docs_with_scores = list(zip(doc_subset, scores)) |
103 | 122 | docs_with_scores.sort(key=lambda x: x[1], reverse=True) |
|
0 commit comments