Skip to content

Commit 5bef4ba

Browse files
committed
Add Cohere reranking support to Cohere Provider
1 parent 03019fc commit 5bef4ba

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

lib/ruby_llm/providers/cohere.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module Cohere
1515
extend Provider
1616
extend Cohere::Chat
1717
extend Cohere::Embeddings
18+
extend Cohere::Reranking
1819
extend Cohere::Models
1920
extend Cohere::Streaming
2021
extend Cohere::Tools
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
4+
module Providers
5+
module Cohere
6+
# Reranking methods for the Cohere API integration
7+
# - https://docs.cohere.com/reference/rerank
8+
# - https://docs.cohere.com/docs/rerank-overview
9+
# - https://docs.cohere.com/docs/reranking-best-practices
10+
module Reranking
11+
module_function
12+
13+
def rerank_url(...)
14+
'v2/rerank'
15+
end
16+
17+
def render_rerank_payload(query, documents, model:, top_n:, max_tokens_per_doc:)
18+
@documents = documents
19+
20+
{
21+
model: model,
22+
query: query,
23+
documents: @documents,
24+
top_n: top_n || documents.count,
25+
max_tokens_per_doc: max_tokens_per_doc || 4_096
26+
}
27+
end
28+
29+
def parse_rerank_response(response, model:)
30+
data = response.body
31+
raise Error.new(response, data['message']) if data['message'] && response.status != 200
32+
33+
results = data['results'] || []
34+
results = results.map do |r|
35+
RerankResult.new(
36+
index: r['index'],
37+
relevance_score: r['relevance_score'],
38+
document: @documents[r['index']]
39+
)
40+
end
41+
search_units = data.dig('meta', 'billed_units', 'search_units') || 0
42+
43+
Rerank.new(results:, model:, search_units:)
44+
end
45+
end
46+
end
47+
end
48+
end

0 commit comments

Comments
 (0)