33"""
44
55import logging
6- from typing import Union
6+ from typing import Optional , Union
77
88import numpy as np
99
@@ -107,8 +107,8 @@ def _dump_raw_results(self, queries, index_list, distance_list) -> None:
107107 logging .debug (f" Hit { j } = { indices [j ]} /{ distances [j ]} : { self ._document_chunks [indices [j ]]} " )
108108 return
109109
110- def forward (self , query_or_queries : Union [str , list [str ]]) -> dspy .Prediction :
111- """Search the faiss index for self.k top passages for query.
110+ def forward (self , query_or_queries : Union [str , list [str ]], k : Optional [ int ] = None ) -> dspy .Prediction :
111+ """Search the faiss index for the top k passages for the query (self.k is used when k is None).
112112
113113 Args:
114114 query_or_queries (Union[str, List[str]]): The query or queries to search for.
@@ -122,21 +122,20 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
122122 emb_npa = np .array (embeddings )
123123 # For single query, just look up the top k passages
124124 if len (queries ) == 1 :
125- distance_list , index_list = self ._faiss_index .search (emb_npa , self .k )
125+ distance_list , index_list = self ._faiss_index .search (emb_npa , k or self .k )
126126 # self._dump_raw_results(queries, index_list, distance_list)
127127 passages = [(self ._document_chunks [ind ], ind ) for ind in index_list [0 ]]
128- passages = [dotdict ({"long_text" : passage [0 ], "index" : passage [1 ]}) for passage in passages ]
129- return dspy .Prediction (passages = passages )
128+ return [dotdict ({"long_text" : passage [0 ], "index" : passage [1 ]}) for passage in passages ]
130129
131- distance_list , index_list = self ._faiss_index .search (emb_npa , self .k * 3 )
130+ distance_list , index_list = self ._faiss_index .search (emb_npa , ( k or self .k ) * 3 )
132131 # self._dump_raw_results(queries, index_list, distance_list)
133132 passage_scores = {}
134133 for emb in range (len (embeddings )):
135134 indices = index_list [emb ] # indices of neighbors for embeddings[emb] - this is an array of k*3 integers
136135 distances = distance_list [
137136 emb
138137 ] # distances of neighbors for embeddings[emb] - this is an array of k*3 floating point numbers
139- for res in range (self .k * 3 ):
138+ for res in range (( k or self .k ) * 3 ):
140139 neighbor = indices [res ]
141140 distance = distances [res ]
142141 if neighbor in passage_scores :
@@ -147,10 +146,5 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
147146 # first degree sort: number of queries that got a hit with any particular document chunk. More
148147 # is a better match. This is len(queries)-len(x[1])
149148 # second degree sort: sum of the distances of each hit returned by faiss. Smaller distance is a better match
150- sorted_passages = sorted (passage_scores .items (), key = lambda x : (len (queries ) - len (x [1 ]), sum (x [1 ])))[: self .k ]
151- return dspy .Prediction (
152- passages = [
153- dotdict ({"long_text" : self ._document_chunks [passage_index ], "index" : passage_index })
154- for passage_index , _ in sorted_passages
155- ],
156- )
149+ sorted_passages = sorted (passage_scores .items (), key = lambda x : (len (queries ) - len (x [1 ]), sum (x [1 ])))[: k or self .k ]
150+ return [ dotdict ({"long_text" : self ._document_chunks [passage_index ], "index" : passage_index }) for passage_index , _ in sorted_passages ]
0 commit comments