|
| 1 | +# frozen_string_literal: true |
| 2 | + |
module RubyLLM
  module Providers
    module Cohere
      # Determines capabilities and constraints for Cohere models.
      #
      # Family names returned by {model_family} are the hyphenated keys of
      # MODEL_PATTERNS (e.g. 'command-a'). PRICES is keyed by the underscored
      # symbol form of the same names (e.g. :command_a); the conversion between
      # the two lives in a single private helper so the tables cannot drift.
      module Capabilities
        module_function

        # Model family patterns for Cohere models.
        # Order matters: the first matching pattern wins, so specific families
        # are listed before their generic fallbacks.
        MODEL_PATTERNS = {
          'command-a' => /^c4ai-command-a-|^command-a-/,
          'command-r' => /^command-r(?!7b)/, # Command R and Command R+ but not R7B
          'command-r7b' => /^command-r7b/,
          'command-light' => /^command-light/,
          'command-nightly' => /^command-nightly/,
          'command' => /^command(?!-r|-light|-nightly)/, # Regular command models
          'aya-expanse' => /^c4ai-aya-expanse/,
          'aya-vision' => /^c4ai-aya-vision/,
          'aya' => /^aya(?!-vision|-expanse)/,
          'embed-v4' => /^embed-v4/,
          'embed-english-v3' => /^embed-english-v3/,
          'embed-multilingual-v3' => /^embed-multilingual-v3/,
          'embed-v3' => /^embed-v3/, # Fallback for other v3 models
          'embed-english' => /^embed-english/,
          'embed-multilingual' => /^embed-multilingual/,
          'embed' => /^embed/, # Fallback for other embed models
          'rerank-v3-5' => /^rerank-v3\.5/,
          'rerank-english' => /^rerank-english/,
          'rerank-multilingual' => /^rerank-multilingual/,
          'rerank' => /^rerank/ # Fallback for other rerank models
        }.freeze

        # Normalizes temperature values for Cohere models
        # @param temperature [Float, nil] the temperature value to normalize
        # @param _model [String] the model identifier (unused but kept for API consistency)
        # @return [Float, nil] the temperature clamped to 0.0..2.0, or nil if temperature is nil
        def normalize_temperature(temperature, _model)
          return nil if temperature.nil?

          # Values outside the 0.0..2.0 range are clamped rather than rejected.
          temperature.clamp(0.0, 2.0)
        end

        # Determines the context window size for a given model
        # @param model_id [String] the model identifier
        # @return [Integer] the context window size in tokens
        def context_window_for(model_id)
          # Family names are the hyphenated MODEL_PATTERNS keys.
          case model_family(model_id)
          when 'command-a' then 256_000 # Command A has 256k context window
          when 'command-r', 'command-r7b' then 128_000 # Command R models have 128k
          when 'aya-expanse', 'aya-vision' then 16_000 # Aya models have 16k
          when 'aya' then 8_192
          else 4_096 # Every other family (including unknown models) falls back to 4k
          end
        end

        # Determines the maximum output tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Integer] the maximum output tokens
        def max_tokens_for_model(model_id)
          # Family names are the hyphenated MODEL_PATTERNS keys.
          case model_family(model_id)
          when 'command-a' then 8_192 # Command A has 8k max output
          when 'aya-expanse', 'aya-vision' then 4_000 # Aya models have 4k
          when 'aya' then 2_048
          else 4_096 # Every other family falls back to 4k max output
          end
        end

        # Determines if a model supports streaming responses
        # @param model [String] the model identifier
        # @return [Boolean] true if the model supports streaming
        def supports_streaming?(model)
          # Everything except embedding and rerank models is treated as streamable
          !model.to_s.match?(/embed|rerank/)
        end

        # Determines if a model supports tool/function calling
        # @param model [String] the model identifier
        # @return [Boolean] true if the model supports tools
        def supports_tools?(model)
          # Command, Aya Vision and Aya Expanse models are treated as tool-capable
          model.to_s.match?(/command|aya-vision|aya-expanse/)
        end

        # Determines if a model supports image processing
        # @param model [String] the model identifier
        # @return [Boolean] true if the model supports images
        def supports_images?(model)
          # Aya Vision models support images in chat and Embed v3+ supports image embeddings
          model.to_s.match?(/aya-vision|^embed.*v[34]/)
        end

        # Determines if a model supports embedding generation
        # @param model [String] the model identifier
        # @return [Boolean] true if the model supports embeddings
        def supports_embeddings?(model)
          model.to_s.include?('embed')
        end

        # Determines if a model supports reranking
        # @param model [String] the model identifier
        # @return [Boolean] true if the model supports reranking
        def supports_reranking?(model)
          model.to_s.include?('rerank')
        end

        # Determines if a model supports vision capabilities (alias of {supports_images?})
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports vision
        def supports_vision?(model_id)
          supports_images?(model_id)
        end

        # Determines if a model supports function calling (alias of {supports_tools?})
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports functions
        def supports_functions?(model_id)
          supports_tools?(model_id)
        end

        # Determines if a model supports JSON mode
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports JSON mode
        def supports_json_mode?(model_id)
          # Command, Aya Vision and Aya Expanse models support structured output.
          # to_s keeps nil/Symbol ids from raising, matching the other predicates.
          model_id.to_s.match?(/command|aya-vision|aya-expanse/)
        end

        # Determines if a model supports structured output (alias of {supports_json_mode?})
        # @param model_id [String] the model identifier
        # @return [Boolean] true if the model supports structured output
        def supports_structured_output?(model_id)
          supports_json_mode?(model_id)
        end

        # Determines the model family for a given model ID
        # @param model_id [String] the model identifier
        # @return [String] the hyphenated family name (a MODEL_PATTERNS key), or
        #   'other' when no pattern matches
        def model_family(model_id)
          id = model_id.to_s
          # Hash iteration follows insertion order, so the first listed pattern wins.
          MODEL_PATTERNS.each do |family, pattern|
            return family.to_s if id.match?(pattern)
          end
          'other'
        end

        # Returns the model type
        # @param model_id [String] the model identifier
        # @return [String] the model type ('chat', 'embedding', or 'rerank')
        def model_type(model_id)
          case model_family(model_id)
          when /embed/ then 'embedding'
          when /rerank/ then 'rerank'
          else 'chat'
          end
        end

        # Pricing information for Cohere models (per million tokens)
        # Source: https://cohere.com/pricing (as of 2025)
        #
        # Keys are the MODEL_PATTERNS family names with '-' replaced by '_'.
        # Entries use one of three shapes: { input:, output: }, { price: } or
        # { text_price:, image_price: }.
        PRICES = {
          command_a: { input: 2.50, output: 10.00 },
          command_r: { input: 0.15, output: 0.60 },
          command_r7b: { input: 0.0375, output: 0.15 },
          command_light: { input: 0.025, output: 0.10 }, # Estimated light model pricing
          command_nightly: { input: 0.15, output: 0.60 }, # Same as command-r for nightly
          command: { input: 0.05, output: 0.20 }, # Regular command model pricing
          aya_expanse: { input: 0.50, output: 1.50 },
          aya_vision: { input: 0.50, output: 1.50 },
          aya: { input: 0.50, output: 1.50 },
          embed_v4: { text_price: 0.12, image_price: 0.47 },
          embed_english_v3: { price: 0.12 },
          embed_multilingual_v3: { price: 0.12 },
          embed_v3: { price: 0.12 },
          embed_english: { price: 0.10 }, # Light models slightly cheaper
          embed_multilingual: { price: 0.10 },
          embed: { price: 0.12 },
          rerank_v3_5: { price: 2.0 }, # key derived from the 'rerank-v3-5' family
          rerank_english: { price: 1.50 },
          rerank_multilingual: { price: 2.0 },
          # NOTE(review): the original comment described rerank as "$2.00 per 1K
          # searches", which is not a per-million-token unit - confirm the
          # intended unit for the rerank entries.
          rerank: { price: 2.0 }
        }.freeze

        # Maps a model ID to the PRICES key of its family.
        # @param model_id [String] the model identifier
        # @return [Symbol] the underscored family key (e.g. :command_a)
        def price_key_for(model_id)
          model_family(model_id).tr('-', '_').to_sym
        end
        private_class_method :price_key_for

        # Gets the input price per million tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Float] the price per million tokens for input
        def input_price_for(model_id)
          prices = PRICES.fetch(price_key_for(model_id), { input: default_input_price })
          # Single-price and text/image entries have no :input key; use their flat price.
          prices[:input] || prices[:text_price] || prices[:price] || default_input_price
        end

        # Gets the output price per million tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Float] the price per million tokens for output
        def output_price_for(model_id)
          prices = PRICES.fetch(price_key_for(model_id), { output: default_output_price })
          # Single-price and text/image entries have no :output key; use their flat price.
          prices[:output] || prices[:text_price] || prices[:price] || default_output_price
        end

        # Gets the image price per million tokens for a given model
        # @param model_id [String] the model identifier
        # @return [Float, nil] the price per million tokens for image processing or nil if not supported
        def image_price_for(model_id)
          PRICES.fetch(price_key_for(model_id), {})[:image_price]
        end

        # Default input price if model not found in PRICES
        # @return [Float] default price per million tokens for input
        def default_input_price
          1.0
        end

        # Default output price if model not found in PRICES
        # @return [Float] default price per million tokens for output
        def default_output_price
          2.0
        end

        # Returns the supported modalities for a given model
        # @param model_id [String] the model identifier
        # @return [Hash] hash containing input and output modalities
        def modalities_for(model_id)
          modalities = {
            input: ['text'],
            output: ['text']
          }

          modalities[:input] << 'image' if supports_images?(model_id)
          # Embedding and rerank models replace (not extend) the text output.
          modalities[:output] = ['embeddings'] if supports_embeddings?(model_id)
          modalities[:output] = ['rerank'] if supports_reranking?(model_id)

          modalities
        end

        # Returns the capabilities of a given model
        # @param model_id [String] the model identifier
        # @return [Array<String>] array of capability strings
        def capabilities_for(model_id)
          id = model_id.to_s
          capabilities = []

          capabilities << 'streaming' if supports_streaming?(id)
          capabilities << 'reranking' if supports_reranking?(id)
          capabilities << 'function_calling' if supports_functions?(id)
          capabilities << 'structured_output' if supports_structured_output?(id)
          capabilities << 'multilingual' if id.match?(/aya|multilingual/)
          capabilities << 'citations' if id.match?(/command-a|aya-vision|aya-expanse/)

          capabilities
        end

        # Returns the pricing structure for a given model
        # @param model_id [String] the model identifier
        # @return [Hash] hash containing pricing information; the shape depends
        #   on which kind of PRICES entry the model's family has
        def pricing_for(model_id)
          prices = PRICES.fetch(price_key_for(model_id),
                                { input: default_input_price, output: default_output_price })

          if prices[:price]
            # For models with single pricing (like older embeddings and rerank)
            { usage_tokens: { standard: { price_per_million: prices[:price] } } }
          elsif prices[:text_price]
            # For models with text/image pricing (like embed-v4)
            pricing_structure = {
              text_tokens: { standard: { price_per_million: prices[:text_price] } }
            }

            # Add image pricing if available
            if prices[:image_price]
              pricing_structure[:image_tokens] = {
                standard: { price_per_million: prices[:image_price] }
              }
            end

            pricing_structure
          else
            # For models with input/output pricing
            {
              text_tokens: {
                standard: {
                  input_per_million: prices[:input],
                  output_per_million: prices[:output]
                }
              }
            }
          end
        end

        # Formats a model ID for display purposes
        # @param model_id [String] the model identifier
        # @return [String] the formatted display name: 'c4ai-' prefix removed,
        #   hyphens turned into spaces, each word capitalized
        def format_display_name(model_id)
          model_id.to_s
                  .gsub(/^c4ai-/, '')
                  .tr('-', ' ')
                  .split
                  .map(&:capitalize)
                  .join(' ')
        end
      end
    end
  end
end
0 commit comments