Skip to content

Commit 2372123

Browse files
committed
Add support for listing models
1 parent 847dbc1 commit 2372123

File tree

4 files changed

+107
-2
lines changed

4 files changed

+107
-2
lines changed

lib/ruby_llm/providers/azure_openai.rb

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,13 @@ module AzureOpenAI
88
extend OpenAI
99
extend AzureOpenAI::Chat
1010
extend AzureOpenAI::Streaming
11+
extend AzureOpenAI::Models
1112

1213
module_function
1314

1415
def api_base(config)
1516
# https://<ENDPOINT>/openai/deployments/<MODEL>/chat/completions?api-version=<APIVERSION>
16-
"#{config.azure_openai_api_base}/openai/deployments"
17+
"#{config.azure_openai_api_base}/openai"
1718
end
1819

1920
def headers(config)

lib/ruby_llm/providers/azure_openai/chat.rb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ def sync_response(connection, payload)
1717

1818
# Builds the deployment-scoped chat completions path for Azure OpenAI.
#
# Relative to api_base, the resolved URL becomes:
#   https://<ENDPOINT>/openai/deployments/<MODEL>/chat/completions?api-version=<APIVERSION>
#
# @return [String] relative URL including the api-version query parameter
def completion_url
  path = format('deployments/%s/chat/completions', @model_id)
  "#{path}?api-version=#{@config.azure_openai_api_version}"
end
2222

2323
def render_payload(messages, tools:, temperature:, model:, stream: false)
Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
# frozen_string_literal: true

module RubyLLM
  module Providers
    module AzureOpenAI
      # Models methods of the OpenAI API integration
      module Models
        extend OpenAI::Models

        module_function

        # Relative URL used to list the models available on the Azure
        # OpenAI endpoint.
        #
        # Prefers the configured api-version so model listing stays
        # consistent with the chat endpoints (Chat#completion_url builds
        # its URL from @config.azure_openai_api_version); falls back to
        # the 2024-10-21 GA version when no configuration is present.
        #
        # @return [String] relative URL including the api-version parameter
        def models_url
          version = @config&.azure_openai_api_version || '2024-10-21'
          "models?api-version=#{version}"
        end
      end
    end
  end
end

lib/tasks/models_update.rake

Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,86 @@
# frozen_string_literal: true

require 'dotenv/load'
require 'ruby_llm'

# Running `rake` with no arguments refreshes the model registry.
task default: ['models:update']

namespace :models do
  desc 'Update available models from providers (API keys needed)'
  task :update do
    puts 'Configuring RubyLLM...'
    configure_from_env

    refresh_models
    display_model_stats
  end
end
19+
# Populates the RubyLLM configuration from environment variables.
# Missing variables resolve to nil, leaving that provider unconfigured.
def configure_from_env
  simple_keys = {
    openai_api_key: 'OPENAI_API_KEY',
    anthropic_api_key: 'ANTHROPIC_API_KEY',
    gemini_api_key: 'GEMINI_API_KEY',
    deepseek_api_key: 'DEEPSEEK_API_KEY',
    openrouter_api_key: 'OPENROUTER_API_KEY'
  }

  RubyLLM.configure do |config|
    simple_keys.each do |attribute, env_var|
      config.public_send("#{attribute}=", ENV.fetch(env_var, nil))
    end
    configure_bedrock(config)
    configure_azure_openai(config)
    config.request_timeout = 30
  end
end
32+
# Reads Azure OpenAI settings (endpoint, key, api-version) from the
# environment into the given configuration object.
#
# @param config [RubyLLM::Configuration] configuration being populated
def configure_azure_openai(config)
  {
    'AZURE_OPENAI_ENDPOINT' => :azure_openai_api_base=,
    'AZURE_OPENAI_API_KEY' => :azure_openai_api_key=,
    'AZURE_OPENAI_API_VER' => :azure_openai_api_version=
  }.each do |env_var, setter|
    config.public_send(setter, ENV.fetch(env_var, nil))
  end
end
38+
# Reads AWS Bedrock credentials and region from the environment into the
# given configuration object.
#
# @param config [RubyLLM::Configuration] configuration being populated
def configure_bedrock(config)
  {
    'AWS_ACCESS_KEY_ID' => :bedrock_api_key=,
    'AWS_SECRET_ACCESS_KEY' => :bedrock_secret_key=,
    'AWS_REGION' => :bedrock_region=,
    'AWS_SESSION_TOKEN' => :bedrock_session_token=
  }.each do |env_var, setter|
    config.public_send(setter, ENV.fetch(env_var, nil))
  end
end
44+
45+
def refresh_models
46+
initial_count = RubyLLM.models.all.size
47+
puts "Refreshing models (#{initial_count} cached)..."
48+
49+
models = RubyLLM.models.refresh!
50+
51+
if models.all.empty? && initial_count.zero?
52+
puts 'Error: Failed to fetch models.'
53+
exit(1)
54+
elsif models.all.size == initial_count && initial_count.positive?
55+
puts 'Warning: Model list unchanged.'
56+
else
57+
puts "Saving models.json (#{models.all.size} models)"
58+
models.save_models
59+
end
60+
61+
@models = models
62+
end
63+
64+
def display_model_stats
65+
puts "\nModel count:"
66+
provider_counts = @models.all.group_by(&:provider).transform_values(&:count)
67+
68+
RubyLLM::Provider.providers.each_key do |sym|
69+
name = sym.to_s.capitalize
70+
count = provider_counts[sym.to_s] || 0
71+
status = status(sym)
72+
puts " #{name}: #{count} models #{status}"
73+
end
74+
75+
puts 'Refresh complete.'
76+
end
77+
78+
def status(provider_sym)
79+
if RubyLLM::Provider.providers[provider_sym].local?
80+
' (LOCAL - SKIP)'
81+
elsif RubyLLM::Provider.providers[provider_sym].configured?
82+
' (OK)'
83+
else
84+
' (NOT CONFIGURED)'
85+
end
86+
end

0 commit comments

Comments
 (0)