|
| 1 | +# frozen_string_literal: true |
| 2 | + |
| 3 | +require 'dotenv/load' |
| 4 | +require 'ruby_llm' |
| 5 | + |
| 6 | +task default: ['models:update'] |
| 7 | + |
| 8 | +namespace :models do |
| 9 | + desc 'Update available models from providers (API keys needed)' |
| 10 | + task :update do |
| 11 | + puts 'Configuring RubyLLM...' |
| 12 | + configure_from_env |
| 13 | + |
| 14 | + refresh_models |
| 15 | + display_model_stats |
| 16 | + end |
| 17 | +end |
| 18 | + |
| 19 | +def configure_from_env |
| 20 | + RubyLLM.configure do |config| |
| 21 | + config.openai_api_key = ENV.fetch('OPENAI_API_KEY', nil) |
| 22 | + config.anthropic_api_key = ENV.fetch('ANTHROPIC_API_KEY', nil) |
| 23 | + config.gemini_api_key = ENV.fetch('GEMINI_API_KEY', nil) |
| 24 | + config.deepseek_api_key = ENV.fetch('DEEPSEEK_API_KEY', nil) |
| 25 | + config.openrouter_api_key = ENV.fetch('OPENROUTER_API_KEY', nil) |
| 26 | + configure_bedrock(config) |
| 27 | + configure_azure_openai(config) |
| 28 | + config.request_timeout = 30 |
| 29 | + end |
| 30 | +end |
| 31 | + |
| 32 | +def configure_azure_openai(config) |
| 33 | + config.azure_openai_api_base = ENV.fetch('AZURE_OPENAI_ENDPOINT', nil) |
| 34 | + config.azure_openai_api_key = ENV.fetch('AZURE_OPENAI_API_KEY', nil) |
| 35 | + config.azure_openai_api_version = ENV.fetch('AZURE_OPENAI_API_VER', nil) |
| 36 | +end |
| 37 | + |
| 38 | +def configure_bedrock(config) |
| 39 | + config.bedrock_api_key = ENV.fetch('AWS_ACCESS_KEY_ID', nil) |
| 40 | + config.bedrock_secret_key = ENV.fetch('AWS_SECRET_ACCESS_KEY', nil) |
| 41 | + config.bedrock_region = ENV.fetch('AWS_REGION', nil) |
| 42 | + config.bedrock_session_token = ENV.fetch('AWS_SESSION_TOKEN', nil) |
| 43 | +end |
| 44 | + |
| 45 | +def refresh_models |
| 46 | + initial_count = RubyLLM.models.all.size |
| 47 | + puts "Refreshing models (#{initial_count} cached)..." |
| 48 | + |
| 49 | + models = RubyLLM.models.refresh! |
| 50 | + |
| 51 | + if models.all.empty? && initial_count.zero? |
| 52 | + puts 'Error: Failed to fetch models.' |
| 53 | + exit(1) |
| 54 | + elsif models.all.size == initial_count && initial_count.positive? |
| 55 | + puts 'Warning: Model list unchanged.' |
| 56 | + else |
| 57 | + puts "Saving models.json (#{models.all.size} models)" |
| 58 | + models.save_models |
| 59 | + end |
| 60 | + |
| 61 | + @models = models |
| 62 | +end |
| 63 | + |
| 64 | +def display_model_stats |
| 65 | + puts "\nModel count:" |
| 66 | + provider_counts = @models.all.group_by(&:provider).transform_values(&:count) |
| 67 | + |
| 68 | + RubyLLM::Provider.providers.each_key do |sym| |
| 69 | + name = sym.to_s.capitalize |
| 70 | + count = provider_counts[sym.to_s] || 0 |
| 71 | + status = status(sym) |
| 72 | + puts " #{name}: #{count} models #{status}" |
| 73 | + end |
| 74 | + |
| 75 | + puts 'Refresh complete.' |
| 76 | +end |
| 77 | + |
| 78 | +def status(provider_sym) |
| 79 | + if RubyLLM::Provider.providers[provider_sym].local? |
| 80 | + ' (LOCAL - SKIP)' |
| 81 | + elsif RubyLLM::Provider.providers[provider_sym].configured? |
| 82 | + ' (OK)' |
| 83 | + else |
| 84 | + ' (NOT CONFIGURED)' |
| 85 | + end |
| 86 | +end |
0 commit comments