Skip to content

Commit 552d894

Browse files
committed
Add Cohere Capabilities and Models modules
1 parent 9cef065 commit 552d894

File tree

3 files changed

+348
-0
lines changed

3 files changed

+348
-0
lines changed

lib/ruby_llm/providers/cohere.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ module Providers
1313
# See https://docs.cohere.com/docs/compatibility-api for more information.
1414
module Cohere
1515
extend Provider
16+
extend Cohere::Models
1617
module_function
1718

1819
def api_base(_config)
@@ -27,6 +28,7 @@ def headers(config)
2728
end
2829

2930
# Returns the module that answers capability questions for Cohere models.
#
# @return [Module] the Cohere::Capabilities module
def capabilities
  Cohere::Capabilities
end
3133

3234
def slug
@@ -36,6 +38,11 @@ def slug
3638
# Returns the configuration keys this provider reports as its requirements.
#
# @return [Array<Symbol>] a one-element list containing :cohere_api_key
def configuration_requirements
  [:cohere_api_key]
end
41+
42+
# Extracts the error message from a failed Cohere API response.
#
# The emptiness guard reads +response.body+ while the payload is read from
# +response.response.body+, exactly as before.
# NOTE(review): the two access paths differ; confirm both refer to the same
# body on the object the error middleware passes in.
#
# @param response [#body, #response] response-like object
# @return [String, Object, nil] nil when the body is empty; the 'message'
#   entry when the payload is (or parses to) a Hash; otherwise the raw body
def parse_error(response)
  return if response.body.empty?

  body = response.response.body
  # The body may already have been decoded upstream; only parse raw strings.
  parsed = body.is_a?(String) ? JSON.parse(body) : body
  parsed.is_a?(Hash) ? parsed['message'] : body
rescue JSON::ParserError
  # A non-JSON error body (for example an HTML page from a proxy) is passed
  # through verbatim so a parse exception does not mask the original failure.
  body
end
4047
end
4148
end
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
  module Providers
    module Cohere
      # Determines capabilities and constraints for Cohere models.
      #
      # Family names returned by model_family are the hyphenated keys of
      # MODEL_PATTERNS (for example 'command-r7b'). PRICES is keyed by the same
      # names with '-' replaced by '_' and converted to symbols; price lookups
      # go through family_prices, which performs that translation.
      module Capabilities
        module_function

        # Model family patterns for Cohere models.
        # Order matters: model_family returns the first matching entry, so the
        # specific families are listed before their fallbacks.
        MODEL_PATTERNS = {
          'command-a' => /^c4ai-command-a-|^command-a-/,
          'command-r' => /^command-r(?!7b)/, # Command R and Command R+ but not R7B
          'command-r7b' => /^command-r7b/,
          'command-light' => /^command-light/,
          'command-nightly' => /^command-nightly/,
          'command' => /^command(?!-r|-light|-nightly)/, # Regular command models
          # The "c4ai-" prefix is optional for every Aya family, mirroring 'command-a'.
          'aya-expanse' => /^(?:c4ai-)?aya-expanse/,
          'aya-vision' => /^(?:c4ai-)?aya-vision/,
          'aya' => /^(?:c4ai-)?aya(?!-vision|-expanse)/,
          'embed-v4' => /^embed-v4/,
          'embed-english-v3' => /^embed-english-v3/,
          'embed-multilingual-v3' => /^embed-multilingual-v3/,
          'embed-v3' => /^embed-v3/, # Fallback for other v3 models
          'embed-english' => /^embed-english/,
          'embed-multilingual' => /^embed-multilingual/,
          'embed' => /^embed/, # Fallback for other embed models
          'rerank-v3-5' => /^rerank-v3\.5/,
          'rerank-english' => /^rerank-english/,
          'rerank-multilingual' => /^rerank-multilingual/,
          'rerank' => /^rerank/ # Fallback for other rerank models
        }.freeze

        # Pricing information for Cohere models (per million tokens)
        # Source: https://cohere.com/pricing (as of 2025)
        #
        # Keys are the MODEL_PATTERNS family names with '-' replaced by '_'.
        # An entry has either :input/:output, :text_price (optionally with
        # :image_price), or a single :price.
        PRICES = {
          command_a: { input: 2.50, output: 10.00 },
          command_r: { input: 0.15, output: 0.60 },
          command_r7b: { input: 0.0375, output: 0.15 },
          command_light: { input: 0.025, output: 0.10 }, # Estimated light model pricing
          command_nightly: { input: 0.15, output: 0.60 }, # Same as command-r for nightly
          command: { input: 0.05, output: 0.20 }, # Regular command model pricing
          aya_expanse: { input: 0.50, output: 1.50 },
          aya_vision: { input: 0.50, output: 1.50 },
          aya: { input: 0.50, output: 1.50 },
          embed_v4: { text_price: 0.12, image_price: 0.47 },
          embed_english_v3: { price: 0.12 },
          embed_multilingual_v3: { price: 0.12 },
          embed_v3: { price: 0.12 },
          embed_english: { price: 0.10 }, # Light models slightly cheaper
          embed_multilingual: { price: 0.10 },
          embed: { price: 0.12 },
          # NOTE(review): rerank is billed per search, not per token. The value
          # 2.0 is the price per 1K searches, yet pricing_for reports it as
          # price_per_million. Confirm the intended unit.
          rerank_v3_5: { price: 2.0 },
          rerank_english: { price: 1.50 },
          rerank_multilingual: { price: 2.0 },
          rerank: { price: 2.0 }
        }.freeze

        # Normalizes temperature values for Cohere models
        # @param temperature [Numeric, nil] the temperature value to normalize
        # @param _model [String] the model identifier (unused but kept for API consistency)
        # @return [Numeric, nil] the value clamped to 0.0..2.0, or nil if temperature is nil
        def normalize_temperature(temperature, _model)
          return nil if temperature.nil?

          # Cohere accepts temperature values between 0.0 and 2.0
          temperature.clamp(0.0, 2.0)
        end

        # Determines the context window size for a given model
        # @param model_id [String] the model identifier
        # @return [Integer] the context window size in tokens
        def context_window_for(model_id)
          # The `when` values must match the hyphenated names model_family returns.
          case model_family(model_id)
          when 'command-a' then 256_000 # Command A has 256k context window
          when 'command-r', 'command-r7b' then 128_000 # Command R models have 128k
          when 'aya-expanse', 'aya-vision' then 16_000 # Aya models have 16k
          when 'aya' then 8_192
          else 4_096 # Every other family falls back to a 4k context window
          end
        end

        # Determines the maximum output tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Integer] the maximum output tokens
        def max_tokens_for_model(model_id)
          case model_family(model_id)
          when 'command-a' then 8_192 # Command A has 8k max output
          when 'aya-expanse', 'aya-vision' then 4_000 # Aya models have 4k
          when 'aya' then 2_048
          else 4_096 # Every other family falls back to 4k max output
          end
        end

        # Determines if a model supports streaming responses
        # @param model [String] the model identifier
        # @return [Boolean] true unless the identifier contains 'embed' or 'rerank'
        def supports_streaming?(model)
          !model.to_s.match?(/embed|rerank/)
        end

        # Determines if a model supports tool/function calling
        # @param model [String] the model identifier
        # @return [Boolean] true for Command, Aya Vision and Aya Expanse identifiers
        def supports_tools?(model)
          model.to_s.match?(/command|aya-vision|aya-expanse/)
        end

        # Determines if a model supports image processing
        # @param model [String] the model identifier
        # @return [Boolean] true if the model supports images
        def supports_images?(model)
          # Aya Vision models support images in chat and Embed v3+ supports image embeddings
          model.to_s.match?(/aya-vision|^embed.*v[34]/)
        end

        # Determines if a model supports embedding generation
        # @param model [String] the model identifier
        # @return [Boolean] true if the identifier contains 'embed'
        def supports_embeddings?(model)
          model.to_s.include?('embed')
        end

        # Determines if a model supports reranking
        # @param model [String] the model identifier
        # @return [Boolean] true if the identifier contains 'rerank'
        def supports_reranking?(model)
          model.to_s.include?('rerank')
        end

        # Determines if a model supports vision capabilities (alias of supports_images?)
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports vision
        def supports_vision?(model_id)
          supports_images?(model_id)
        end

        # Determines if a model supports function calling (alias of supports_tools?)
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports functions
        def supports_functions?(model_id)
          supports_tools?(model_id)
        end

        # Determines if a model supports JSON mode
        # @param model_id [String] the model identifier
        # @return [Boolean] true for Command, Aya Vision and Aya Expanse identifiers
        def supports_json_mode?(model_id)
          model_id.to_s.match?(/command|aya-vision|aya-expanse/)
        end

        # Determines if a model supports structured output (alias of supports_json_mode?)
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports structured output
        def supports_structured_output?(model_id)
          supports_json_mode?(model_id)
        end

        # Determines the model family for a given model ID
        # @param model_id [String] the model identifier
        # @return [String] the first MODEL_PATTERNS key whose pattern matches, or 'other'
        def model_family(model_id)
          id = model_id.to_s
          matched = MODEL_PATTERNS.find { |_family, pattern| id.match?(pattern) }
          matched ? matched.first : 'other'
        end

        # Returns the model type
        # @param model_id [String] the model identifier
        # @return [String] the model type ('chat', 'embedding', or 'rerank')
        def model_type(model_id)
          case model_family(model_id)
          when /embed/ then 'embedding'
          when /rerank/ then 'rerank'
          else 'chat'
          end
        end

        # Gets the input price per million tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Float] the :input, :text_price or :price of the family, else the default
        def input_price_for(model_id)
          prices = family_prices(model_id) || {}
          prices[:input] || prices[:text_price] || prices[:price] || default_input_price
        end

        # Gets the output price per million tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Float] the :output, :text_price or :price of the family, else the default
        def output_price_for(model_id)
          prices = family_prices(model_id) || {}
          prices[:output] || prices[:text_price] || prices[:price] || default_output_price
        end

        # Gets the image price per million tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Float, nil] the family's :image_price, or nil when it has none
        def image_price_for(model_id)
          (family_prices(model_id) || {})[:image_price]
        end

        # Default input price if model not found in PRICES
        # @return [Float] default price per million tokens for input
        def default_input_price
          1.0
        end

        # Default output price if model not found in PRICES
        # @return [Float] default price per million tokens for output
        def default_output_price
          2.0
        end

        # Returns the supported modalities for a given model
        # @param model_id [String] the model identifier
        # @return [Hash] hash containing input and output modalities
        def modalities_for(model_id)
          modalities = { input: ['text'], output: ['text'] }

          modalities[:input] << 'image' if supports_images?(model_id)
          # Embedding and rerank models replace the text output modality.
          modalities[:output] = ['embeddings'] if supports_embeddings?(model_id)
          modalities[:output] = ['rerank'] if supports_reranking?(model_id)

          modalities
        end

        # Returns the capabilities of a given model
        # @param model_id [String] the model identifier
        # @return [Array<String>] array of capability strings
        def capabilities_for(model_id)
          id = model_id.to_s
          capabilities = []

          capabilities << 'streaming' if supports_streaming?(id)
          capabilities << 'reranking' if supports_reranking?(id)
          capabilities << 'function_calling' if supports_functions?(id)
          capabilities << 'structured_output' if supports_structured_output?(id)
          capabilities << 'multilingual' if id.match?(/aya|multilingual/)
          capabilities << 'citations' if id.match?(/command-a|aya-vision|aya-expanse/)

          capabilities
        end

        # Returns the pricing structure for a given model
        # @param model_id [String] the model identifier
        # @return [Hash] hash containing pricing information
        def pricing_for(model_id)
          prices = family_prices(model_id) ||
                   { input: default_input_price, output: default_output_price }

          if prices[:price]
            # For models with single pricing (like older embeddings and rerank)
            { usage_tokens: { standard: { price_per_million: prices[:price] } } }
          elsif prices[:text_price]
            # For models with text/image pricing (like embed-v4)
            structure = { text_tokens: { standard: { price_per_million: prices[:text_price] } } }
            if prices[:image_price]
              structure[:image_tokens] = { standard: { price_per_million: prices[:image_price] } }
            end
            structure
          else
            # For models with input/output pricing
            {
              text_tokens: {
                standard: {
                  input_per_million: prices[:input],
                  output_per_million: prices[:output]
                }
              }
            }
          end
        end

        # Formats a model ID for display purposes
        # @param model_id [String] the model identifier
        # @return [String] the id without a leading "c4ai-", dashes turned into
        #   spaces and each word capitalized
        def format_display_name(model_id)
          model_id.to_s
                  .sub(/^c4ai-/, '')
                  .tr('-', ' ')
                  .split
                  .map(&:capitalize)
                  .join(' ')
        end

        # Looks up the PRICES entry for the family of +model_id+.
        # Translates the hyphenated family name into the underscore symbol used
        # as PRICES key (for example 'rerank-v3-5' becomes :rerank_v3_5).
        # @param model_id [String] the model identifier
        # @return [Hash, nil] the price entry, or nil when the family has none
        def family_prices(model_id)
          PRICES[model_family(model_id).tr('-', '_').to_sym]
        end
        private_class_method :family_prices
      end
    end
  end
end
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
  module Providers
    module Cohere
      # Model listing support for the Cohere API
      # https://docs.cohere.com/reference/list-models
      module Models
        module_function

        # Path of the model listing endpoint.
        # @return [String] 'v1/models'
        def models_url
          'v1/models'
        end

        # Converts a list-models response into Model::Info records.
        # @param response [#body] response whose body is indexed with the 'models' key
        # @param slug [String] provider identifier stored on every record
        # @param capabilities [#format_display_name, #model_family, #modalities_for,
        #   #capabilities_for, #pricing_for] capability lookup object
        # @return [Array<Model::Info>] one record per listed model; empty when the
        #   body is empty or has no 'models' entry
        def parse_list_models_response(response, slug, capabilities)
          payload = response.body
          return [] if payload.empty?

          entries = payload['models']
          return [] if entries.nil?

          entries.map do |entry|
            identifier = entry['name']

            Model::Info.new(
              id: identifier,
              name: capabilities.format_display_name(identifier),
              provider: slug,
              family: capabilities.model_family(identifier),
              modalities: capabilities.modalities_for(identifier),
              capabilities: capabilities.capabilities_for(identifier),
              pricing: capabilities.pricing_for(identifier)
            )
          end
        end
      end
    end
  end
end

0 commit comments

Comments
 (0)