Skip to content

Commit 02016e1

Browse files
Merge pull request #559 from software-artisan/main
add 'k' as argument to FaissRM.forward()
2 parents ab65929 + d73efde commit 02016e1

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

dspy/retrieve/faiss_rm.py

100644 → 100755
Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import logging
6-
from typing import Union
6+
from typing import Optional, Union
77

88
import numpy as np
99

@@ -107,8 +107,8 @@ def _dump_raw_results(self, queries, index_list, distance_list) -> None:
107107
logging.debug(f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}")
108108
return
109109

110-
def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
111-
"""Search the faiss index for self.k top passages for query.
110+
def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
111+
"""Search the faiss index for k or self.k top passages for query.
112112
113113
Args:
114114
query_or_queries (Union[str, List[str]]): The query or queries to search for.
@@ -122,21 +122,20 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
122122
emb_npa = np.array(embeddings)
123123
# For single query, just look up the top k passages
124124
if len(queries) == 1:
125-
distance_list, index_list = self._faiss_index.search(emb_npa, self.k)
125+
distance_list, index_list = self._faiss_index.search(emb_npa, k or self.k)
126126
# self._dump_raw_results(queries, index_list, distance_list)
127127
passages = [(self._document_chunks[ind], ind) for ind in index_list[0]]
128-
passages = [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages]
129-
return dspy.Prediction(passages=passages)
128+
return [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages]
130129

131-
distance_list, index_list = self._faiss_index.search(emb_npa, self.k * 3)
130+
distance_list, index_list = self._faiss_index.search(emb_npa, (k or self.k) * 3)
132131
# self._dump_raw_results(queries, index_list, distance_list)
133132
passage_scores = {}
134133
for emb in range(len(embeddings)):
135134
indices = index_list[emb] # indices of neighbors for embeddings[emb] - this is an array of k*3 integers
136135
distances = distance_list[
137136
emb
138137
] # distances of neighbors for embeddings[emb] - this is an array of k*3 floating point numbers
139-
for res in range(self.k * 3):
138+
for res in range((k or self.k) * 3):
140139
neighbor = indices[res]
141140
distance = distances[res]
142141
if neighbor in passage_scores:
@@ -147,10 +146,5 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
147146
# first degree sort: number of queries that got a hit with any particular document chunk. More
148147
# is a better match. This is len(queries)-len(x[1])
149148
# second degree sort: sum of the distances of each hit returned by faiss. Smaller distance is a better match
150-
sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[: self.k]
151-
return dspy.Prediction(
152-
passages=[
153-
dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index})
154-
for passage_index, _ in sorted_passages
155-
],
156-
)
149+
sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[: k or self.k]
150+
return [ dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index}) for passage_index, _ in sorted_passages ]

0 commit comments

Comments
 (0)