
Commit 42b18c3

Add initial Cohere Chat and Embeddings module implementations
1 parent 552d894 commit 42b18c3

6 files changed: +394 -0 lines changed

lib/ruby_llm/providers/cohere.rb

Lines changed: 7 additions & 0 deletions
@@ -13,7 +13,14 @@ module Providers
     # See https://docs.cohere.com/docs/compatibility-api for more information.
     module Cohere
       extend Provider
+      extend Cohere::Chat
+      extend Cohere::Embeddings
+      extend Cohere::Reranking
       extend Cohere::Models
+      extend Cohere::Streaming
+      extend Cohere::Tools
+      extend Cohere::Media
+
       module_function

       def api_base(_config)
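
With this wiring in place the provider module picks up the chat, embeddings, streaming, tools and media helpers defined in the new files in this commit. A minimal sketch of what becomes callable, with return values read straight off those definitions:

# Sketch only: return values follow from completion_url, embedding_url and
# format_role as defined in the files added below.
RubyLLM::Providers::Cohere.completion_url            # => "v2/chat"
RubyLLM::Providers::Cohere::Embeddings.embedding_url # => "v2/embed"
RubyLLM::Providers::Cohere::Chat.format_role(:tool)  # => "user"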
lib/ruby_llm/providers/cohere/chat.rb

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# frozen_string_literal: true

module RubyLLM
  module Providers
    module Cohere
      # Chat methods of the Cohere API integration
      # - https://docs.cohere.com/reference/chat
      # - https://docs.cohere.com/docs/chat-api
      module Chat
        def completion_url
          'v2/chat'
        end

        module_function

        def render_payload(messages, tools:, temperature:, model:, stream: false)
          @model_id = model

          {
            model: model,
            messages: format_messages(messages),
            temperature: temperature,
            stream: stream,
            tools: tools.any? ? tools.map { |_, tool| Tools.tool_for(tool) } : nil
          }.compact
        end

        def parse_completion_response(response)
          data = response.body
          return if data.empty?

          raise Error.new(response, data['message']) if data['message'] && response.status != 200

          message_data = data['message']
          return unless message_data

          Message.new(
            role: message_data['role'].to_sym,
            content: message_data.dig('content', 0, 'text'),
            tool_calls: Tools.parse_tool_calls(message_data['tool_calls']),
            input_tokens: data.dig('usage', 'tokens', 'input_tokens'),
            output_tokens: data.dig('usage', 'tokens', 'output_tokens'),
            model_id: @model_id
          )
        end

        def format_messages(messages)
          messages.map { |msg| format_message(msg) }
        end

        def format_message(msg)
          if msg.tool_call?
            Tools.format_tool_call(msg)
          elsif msg.tool_result?
            Tools.format_tool_result(msg)
          else
            format_basic_message(msg)
          end
        end

        def format_basic_message(msg)
          {
            role: format_role(msg.role),
            content: Media.format_content(msg.content)
          }.compact
        end

        def format_role(role)
          case role
          when :system
            'system'
          when :user, :tool
            'user'
          when :assistant
            'assistant'
          else
            role.to_s
          end
        end
      end
    end
  end
end
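
For reference, a minimal sketch of the request hash render_payload builds for a single plain user message; the model name and text are illustrative, only the shape comes from the code above:

# Illustrative only: the model id and message text are made up; the key layout
# is what render_payload and format_basic_message produce for a plain user turn.
{
  model: 'command-r-plus',
  messages: [
    { role: 'user', content: [{ type: 'text', text: 'Hello Cohere!' }] }
  ],
  temperature: 0.7,
  stream: false
  # tools: is dropped by .compact because no tools were registered
}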
lib/ruby_llm/providers/cohere/embeddings.rb

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# frozen_string_literal: true

module RubyLLM
  module Providers
    module Cohere
      # Embeddings methods of the Cohere API integration
      # - https://docs.cohere.com/reference/embed
      # - https://docs.cohere.com/docs/embeddings
      module Embeddings
        module_function

        def embedding_url(...)
          'v2/embed'
        end

        def render_embedding_payload(input, model:, dimensions: nil)
          {
            model: model,
            embedding_types: ['float'],
            texts: Array(input),
            input_type: 'search_document',
            output_dimension: dimensions,
            truncate: 'END' # Handle long texts by truncating at the end
          }
        end

        def parse_embedding_response(response, model:)
          data = response.body
          raise Error.new(response, data['message']) if data['message'] && response.status != 200

          vectors = data.dig('embeddings', 'float') || []
          input_tokens = data.dig('meta', 'billed_units', 'input_tokens') || 0

          # If we only got one embedding, return it as a single vector
          vectors = vectors.first if vectors.length == 1

          Embedding.new(vectors:, model:, input_tokens:)
        end
      end
    end
  end
end
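
A small sketch of the corresponding request body and of the response fields parse_embedding_response digs into; the model name, text and numbers are illustrative, only the key paths come from the code above:

# Hypothetical request/response shapes for the v2/embed call.
request = {
  model: 'embed-english-v3.0',       # assumed model id
  embedding_types: ['float'],
  texts: ['Ruby is a joy to write'],
  input_type: 'search_document',
  output_dimension: nil,             # nil when no dimensions: argument is given
  truncate: 'END'
}

response_body = {
  'embeddings' => { 'float' => [[0.01, -0.02, 0.03]] },
  'meta' => { 'billed_units' => { 'input_tokens' => 6 } }
}
# With a single input, parse_embedding_response unwraps the outer array and
# returns Embedding.new(vectors: [0.01, -0.02, 0.03], model: ..., input_tokens: 6).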
lib/ruby_llm/providers/cohere/media.rb

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# frozen_string_literal: true

module RubyLLM
  module Providers
    module Cohere
      # Handles formatting of media content (images) for Cohere APIs
      # Supports Aya Vision models with multimodal capabilities
      module Media
        module_function

        def format_content(content)
          return [format_text(content)] unless content.is_a?(RubyLLM::Content)

          parts = []
          parts << format_text(content.text) if content.text

          content.attachments.each do |attachment|
            case attachment.type
            when :image
              parts << format_image(attachment)
            when :text
              parts << format_text_file(attachment)
            else
              raise RubyLLM::UnsupportedAttachmentError, attachment.type
            end
          end

          parts
        end

        def format_text(text)
          {
            type: 'text',
            text: text
          }
        end

        def format_image(image)
          if image.url?
            # Use URL directly for Cohere API
            {
              type: 'image_url',
              image_url: {
                url: image.source
              }
            }
          else
            # Use base64 encoding for local images
            {
              type: 'image_url',
              image_url: {
                url: "data:#{image.mime_type};base64,#{image.encoded}"
              }
            }
          end
        end

        def format_text_file(text_file)
          {
            type: 'text',
            text: Utils.format_text_file_for_llm(text_file)
          }
        end
      end
    end
  end
end
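
As a sketch, the parts array format_content returns for a prompt that combines text with one local PNG attachment; the prompt text and base64 payload are illustrative, the data: URL prefix is the one built in format_image's base64 branch:

# Illustrative result of Media.format_content for text plus a local PNG.
[
  { type: 'text', text: 'What is in this picture?' },
  { type: 'image_url', image_url: { url: 'data:image/png;base64,iVBORw0KGgo...' } }
]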
lib/ruby_llm/providers/cohere/streaming.rb

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# frozen_string_literal: true

module RubyLLM
  module Providers
    module Cohere
      # Streaming methods of the Cohere API integration
      # - https://docs.cohere.com/docs/streaming
      module Streaming
        private

        def stream_url
          completion_url
        end

        def build_chunk(data)
          Chunk.new(
            role: :assistant,
            model_id: extract_model_id(data),
            content: extract_content(data),
            input_tokens: extract_input_tokens(data),
            output_tokens: extract_output_tokens(data),
            tool_calls: extract_tool_calls(data)
          )
        end

        def extract_model_id(data)
          data['response_id'] || data['id']
        end

        def extract_content(data)
          case data['type']
          when 'content-delta'
            data.dig('delta', 'message', 'content', 'text')
          when 'message-end'
            # Final message content
            data.dig('delta', 'message', 'content', 0, 'text')
          end
        end

        def extract_input_tokens(data)
          return unless data['type'] == 'message-end'

          data.dig('delta', 'usage', 'tokens', 'input_tokens')
        end

        def extract_output_tokens(data)
          return unless data['type'] == 'message-end'

          data.dig('delta', 'usage', 'tokens', 'output_tokens')
        end

        def extract_tool_calls(data)
          case data['type']
          when 'tool-call-start'
            tool_call_data = data.dig('delta', 'message', 'tool_calls')
            return {} unless tool_call_data

            tool_call = ToolCall.new(
              id: tool_call_data['id'],
              name: tool_call_data.dig('function', 'name'),
              arguments: tool_call_data.dig('function', 'arguments') || ''
            )
            { tool_call.id => tool_call }
          when 'tool-call-delta'
            # Handle streaming tool call arguments
            argument_delta = data.dig('delta', 'message', 'tool_calls', 'function', 'arguments')
            return {} unless argument_delta

            { nil => ToolCall.new(id: nil, name: nil, arguments: argument_delta) }
          when 'message-end'
            tool_calls = data.dig('delta', 'message', 'tool_calls')
            return {} unless tool_calls

            result = {}
            tool_calls.each do |call|
              tool_call = ToolCall.new(
                id: call['id'],
                name: call.dig('function', 'name'),
                arguments: call.dig('function', 'parameters')
              )
              result[tool_call.id] = tool_call
            end
            result
          else
            {}
          end
        end

        def parse_streaming_error(data)
          error_data = JSON.parse(data)
          return unless error_data['type'] == 'error'

          message = error_data.dig('error', 'message') || 'Unknown error'
          [500, message]
        rescue JSON::ParserError
          [500, 'Failed to parse error response']
        end
      end
    end
  end
end
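
A rough sketch of two of the Cohere v2 stream events these extractors handle, written as the parsed hashes build_chunk receives; only the key paths are taken from the code above, the values are invented:

# Hypothetical stream events; 'content-delta' carries incremental text,
# 'message-end' carries the final usage counts.
content_delta = {
  'type' => 'content-delta',
  'delta' => { 'message' => { 'content' => { 'text' => 'Hello' } } }
}

message_end = {
  'type' => 'message-end',
  'delta' => { 'usage' => { 'tokens' => { 'input_tokens' => 12, 'output_tokens' => 34 } } }
}
# build_chunk(content_delta) yields a Chunk whose content is 'Hello';
# build_chunk(message_end) carries the final input/output token counts.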
