Skip to content

Commit c313b2c

Browse files
authored
fix(reranker): tests and top_n check fix #7212 (#7284)
reranker tests and top_n check fix #7212 Signed-off-by: Mikhail Khludnev <mkhl@apache.org>
1 parent 137f163 commit c313b2c

File tree

3 files changed

+88
-34
lines changed

3 files changed

+88
-34
lines changed

core/http/endpoints/jina/rerank.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,22 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
3232
}
3333

3434
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
35-
35+
var requestTopN int32
36+
docs := int32(len(input.Documents))
37+
if input.TopN == nil { // omit top_n to get all
38+
requestTopN = docs
39+
} else {
40+
requestTopN = int32(*input.TopN)
41+
if requestTopN < 1 {
42+
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
43+
}
44+
if requestTopN > docs { // make it more obvious for backends
45+
requestTopN = docs
46+
}
47+
}
3648
request := &proto.RerankRequest{
3749
Query: input.Query,
38-
TopN: int32(input.TopN),
50+
TopN: requestTopN,
3951
Documents: input.Documents,
4052
}
4153

core/schema/jina.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ type JINARerankRequest struct {
55
BasicModelRequest
66
Query string `json:"query"`
77
Documents []string `json:"documents"`
8-
TopN int `json:"top_n"`
8+
TopN *int `json:"top_n,omitempty"`
99
Backend string `json:"backend"`
1010
}
1111

tests/e2e-aio/e2e_test.go

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -286,45 +286,64 @@ var _ = Describe("E2E test", func() {
286286
Context("reranker", func() {
287287
It("correctly", func() {
288288
modelName := "jina-reranker-v1-base-en"
289-
290-
req := schema.JINARerankRequest{
291-
BasicModelRequest: schema.BasicModelRequest{
292-
Model: modelName,
293-
},
294-
Query: "Organic skincare products for sensitive skin",
295-
Documents: []string{
296-
"Eco-friendly kitchenware for modern homes",
297-
"Biodegradable cleaning supplies for eco-conscious consumers",
298-
"Organic cotton baby clothes for sensitive skin",
299-
"Natural organic skincare range for sensitive skin",
300-
"Tech gadgets for smart homes: 2024 edition",
301-
"Sustainable gardening tools and compost solutions",
302-
"Sensitive skin-friendly facial cleansers and toners",
303-
"Organic food wraps and storage solutions",
304-
"All-natural pet food for dogs with allergies",
305-
"Yoga mats made from recycled materials",
306-
},
307-
TopN: 3,
289+
const query = "Organic skincare products for sensitive skin"
290+
var documents = []string{
291+
"Eco-friendly kitchenware for modern homes",
292+
"Biodegradable cleaning supplies for eco-conscious consumers",
293+
"Organic cotton baby clothes for sensitive skin",
294+
"Natural organic skincare range for sensitive skin",
295+
"Tech gadgets for smart homes: 2024 edition",
296+
"Sustainable gardening tools and compost solutions",
297+
"Sensitive skin-friendly facial cleansers and toners",
298+
"Organic food wraps and storage solutions",
299+
"All-natural pet food for dogs with allergies",
300+
"Yoga mats made from recycled materials",
301+
}
302+
// Exceed len or requested results
303+
randomValue := int(GinkgoRandomSeed()) % (len(documents) + 1)
304+
requestResults := randomValue + 1 // at least 1 results
305+
// Cap expectResults by the length of documents
306+
expectResults := min(requestResults, len(documents))
307+
var maybeSkipTopN = &requestResults
308+
if requestResults >= len(documents) && int(GinkgoRandomSeed())%2 == 0 {
309+
maybeSkipTopN = nil
308310
}
309311

310-
serialized, err := json.Marshal(req)
311-
Expect(err).To(BeNil())
312-
Expect(serialized).ToNot(BeNil())
313-
314-
rerankerEndpoint := apiEndpoint + "/rerank"
315-
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
316-
Expect(err).To(BeNil())
317-
Expect(resp).ToNot(BeNil())
318-
body, err := io.ReadAll(resp.Body)
319-
Expect(err).ToNot(HaveOccurred())
312+
resp, body := requestRerank(modelName, query, documents, maybeSkipTopN, apiEndpoint)
320313
Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp))
321314

322315
deserializedResponse := schema.JINARerankResponse{}
323-
err = json.Unmarshal(body, &deserializedResponse)
316+
err := json.Unmarshal(body, &deserializedResponse)
324317
Expect(err).To(BeNil())
325318
Expect(deserializedResponse).ToNot(BeZero())
326319
Expect(deserializedResponse.Model).To(Equal(modelName))
327-
Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
320+
//Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0))
321+
Expect(len(deserializedResponse.Results)).To(Equal(expectResults))
322+
// Assert that relevance scores are in decreasing order
323+
for i := 1; i < len(deserializedResponse.Results); i++ {
324+
Expect(deserializedResponse.Results[i].RelevanceScore).To(
325+
BeNumerically("<=", deserializedResponse.Results[i-1].RelevanceScore),
326+
fmt.Sprintf("Result at index %d should have lower relevance score than previous result.", i),
327+
)
328+
}
329+
// Assert that each result's index points to the correct document
330+
for i, result := range deserializedResponse.Results {
331+
Expect(result.Index).To(
332+
And(
333+
BeNumerically(">=", 0),
334+
BeNumerically("<", len(documents)),
335+
),
336+
fmt.Sprintf("Result at position %d has index %d which should be within bounds [0, %d)", i, result.Index, len(documents)),
337+
)
338+
Expect(result.Document.Text).To(
339+
Equal(documents[result.Index]),
340+
fmt.Sprintf("Result at position %d (index %d) should have document text '%s', but got '%s'",
341+
i, result.Index, documents[result.Index], result.Document.Text),
342+
)
343+
}
344+
zeroOrNeg := int(GinkgoRandomSeed())%2 - 1 // Results in either -1 or 0
345+
resp, body = requestRerank(modelName, query, documents, &zeroOrNeg, apiEndpoint)
346+
Expect(resp.StatusCode).To(Equal(422), fmt.Sprintf("body: %s, response: %+v", body, resp))
328347
})
329348
})
330349
})
@@ -350,3 +369,26 @@ func downloadHttpFile(url string) (string, error) {
350369

351370
return tmpfile.Name(), nil
352371
}
372+
373+
func requestRerank(modelName, query string, documents []string, topN *int, apiEndpoint string) (*http.Response, []byte) {
374+
req := schema.JINARerankRequest{
375+
BasicModelRequest: schema.BasicModelRequest{
376+
Model: modelName,
377+
},
378+
Query: query,
379+
Documents: documents,
380+
TopN: topN,
381+
}
382+
383+
serialized, err := json.Marshal(req)
384+
Expect(err).To(BeNil())
385+
Expect(serialized).ToNot(BeNil())
386+
rerankerEndpoint := apiEndpoint + "/rerank"
387+
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized))
388+
Expect(err).To(BeNil())
389+
Expect(resp).ToNot(BeNil())
390+
body, err := io.ReadAll(resp.Body)
391+
Expect(err).ToNot(HaveOccurred())
392+
393+
return resp, body
394+
}

0 commit comments

Comments
 (0)