Skip to content

Commit 122e4c7

Browse files
authored
fix(reranker): reproduce ignoring top_n (#7025)
* fix(reranker): reproduce ignoring top_n Signed-off-by: Mikhail Khludnev <mkhl@apache.org> * fix(reranker): ignoring top_n Signed-off-by: Mikhail Khludnev <mkhl@apache.org> --------- Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
1 parent 2573102 commit 122e4c7

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

backend/python/rerankers/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def LoadModel(self, request, context):
6161
if request.PipelineType != "": # Reuse the PipelineType field for language
6262
kwargs['lang'] = request.PipelineType
6363
self.model_name = model_name
64-
self.model = Reranker(model_name, **kwargs)
64+
self.model = Reranker(model_name, **kwargs)
6565
except Exception as err:
6666
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
6767

@@ -80,7 +80,7 @@ def Rerank(self, request, context):
8080
index=res.doc_id,
8181
text=res.text,
8282
relevance_score=res.score
83-
) for res in ranked_results.results
83+
) for res in ranked_results.top_k(request.top_n)
8484
]
8585

8686
# Calculate the usage and total tokens

backend/python/rerankers/test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,33 @@ def test_rerank(self):
8686
except Exception as err:
8787
print(err)
8888
self.fail("Reranker service failed")
89+
finally:
90+
self.tearDown()
91+
92+
def test_rerank_crop(self):
93+
"""
94+
This method tests if the embeddings are generated successfully
95+
"""
96+
try:
97+
self.setUp()
98+
with grpc.insecure_channel("localhost:50051") as channel:
99+
stub = backend_pb2_grpc.BackendStub(channel)
100+
request = backend_pb2.RerankRequest(
101+
query="I love you",
102+
documents=["I hate you", "I really like you", "I hate ignoring top_n"],
103+
top_n=2
104+
)
105+
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
106+
self.assertTrue(response.success)
107+
108+
rerank_response = stub.Rerank(request)
109+
print(rerank_response.results[0])
110+
self.assertIsNotNone(rerank_response.results)
111+
self.assertEqual(len(rerank_response.results), 2)
112+
self.assertEqual(rerank_response.results[0].text, "I really like you")
113+
self.assertEqual(rerank_response.results[1].text, "I hate you")
114+
except Exception as err:
115+
print(err)
116+
self.fail("Reranker service failed")
89117
finally:
90118
self.tearDown()

0 commit comments

Comments
 (0)