Skip to content

Commit 792200a

Browse files
committed
Add Replicate provider
1 parent a0efaa4 commit 792200a

File tree

12 files changed

+168
-5
lines changed

12 files changed

+168
-5
lines changed

bin/console

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ RubyLLM.configure do |config|
2323
config.openai_api_key = ENV.fetch('OPENAI_API_KEY', nil)
2424
config.openrouter_api_key = ENV.fetch('OPENROUTER_API_KEY', nil)
2525
config.perplexity_api_key = ENV.fetch('PERPLEXITY_API_KEY', nil)
26+
config.replicate_api_key = ENV.fetch('REPLICATE_API_KEY', nil)
27+
config.replicate_webhook_url = ENV.fetch('REPLICATE_WEBHOOK_URL', nil)
2628
config.vertexai_location = ENV.fetch('GOOGLE_CLOUD_LOCATION', nil)
2729
config.vertexai_project_id = ENV.fetch('GOOGLE_CLOUD_PROJECT', nil)
2830
end

docs/_getting_started/configuration.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ RubyLLM.configure do |config|
5959
config.mistral_api_key = ENV['MISTRAL_API_KEY']
6060
config.perplexity_api_key = ENV['PERPLEXITY_API_KEY']
6161
config.openrouter_api_key = ENV['OPENROUTER_API_KEY']
62+
config.replicate_api_key = ENV['REPLICATE_API_KEY']
63+
config.replicate_webhook_url = ENV['REPLICATE_WEBHOOK_URL']
6264

6365
# Local providers
6466
config.ollama_api_base = 'http://localhost:11434/v1'
@@ -363,4 +365,4 @@ Now that you've configured RubyLLM, you're ready to:
363365

364366
- [Start chatting with AI models]({% link _core_features/chat.md %})
365367
- [Work with different providers and models]({% link _advanced/models.md %})
366-
- [Set up Rails integration]({% link _advanced/rails.md %})
368+
- [Set up Rails integration]({% link _advanced/rails.md %})

lib/ruby_llm.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def logger
9393
RubyLLM::Provider.register :openrouter, RubyLLM::Providers::OpenRouter
9494
RubyLLM::Provider.register :perplexity, RubyLLM::Providers::Perplexity
9595
RubyLLM::Provider.register :vertexai, RubyLLM::Providers::VertexAI
96+
RubyLLM::Provider.register :replicate, RubyLLM::Providers::Replicate
9697

9798
if defined?(Rails::Railtie)
9899
require 'ruby_llm/railtie'

lib/ruby_llm/configuration.rb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class Configuration
2323
:gpustack_api_base,
2424
:gpustack_api_key,
2525
:mistral_api_key,
26+
:replicate_api_key,
27+
:replicate_webhook_url,
28+
:replicate_webhook_events_filter,
2629
# Default models
2730
:default_model,
2831
:default_embedding_model,

lib/ruby_llm/image.rb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@ def self.paint(prompt, # rubocop:disable Metrics/ParameterLists
3636
provider: nil,
3737
assume_model_exists: false,
3838
size: '1024x1024',
39-
context: nil)
39+
context: nil,
40+
model_params: {})
4041
config = context&.config || RubyLLM.config
4142
model ||= config.default_image_model
4243
model, provider_instance = Models.resolve(model, provider: provider, assume_exists: assume_model_exists,
4344
config: config)
4445
model_id = model.id
4546

46-
provider_instance.paint(prompt, model: model_id, size:)
47+
provider_instance.paint(prompt, model: model_id, size:, **model_params)
4748
end
4849
end
4950
end

lib/ruby_llm/provider.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def moderate(input, model:)
7676
parse_moderation_response(response, model:)
7777
end
7878

79-
def paint(prompt, model:, size:)
80-
payload = render_image_payload(prompt, model:, size:)
79+
def paint(prompt, model:, size:, **params)
80+
payload = render_image_payload(prompt, model:, size:, **params)
8181
response = @connection.post images_url, payload
8282
parse_image_response(response, model:)
8383
end
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
4+
module Providers
5+
# Replicate API integration
6+
class Replicate < Provider
7+
include Replicate::Capabilities
8+
include Replicate::Images
9+
include Replicate::Models
10+
11+
def api_base
12+
'https://api.replicate.com'
13+
end
14+
15+
def headers
16+
{
17+
'Authorization' => "Bearer #{@config.replicate_api_key}",
18+
'Content-Type' => 'application/json'
19+
}
20+
end
21+
22+
class << self
23+
def capabilities
24+
Replicate::Capabilities
25+
end
26+
27+
def configuration_requirements
28+
%i[replicate_api_key replicate_webhook_url]
29+
end
30+
end
31+
end
32+
end
33+
end
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
4+
module Providers
5+
class Replicate
6+
# Determines capabilities for Replicate models
7+
module Capabilities
8+
module_function
9+
10+
def name_from(model_data)
11+
"replicate/#{model_data['owner']}/#{model_data['name']}"
12+
end
13+
14+
def metadata_from(model_data)
15+
input_schema = model_data.dig('latest_version', 'openapi_schema', 'components', 'schemas', 'Input')
16+
17+
{
18+
url: model_data['url'],
19+
description: model_data['description'],
20+
license_url: model_data['license_url'],
21+
is_official: model_data['is_official'],
22+
supported_parameters: input_schema['properties'].keys - ['prompt'],
23+
latest_version_created_at: model_data.dig('latest_version', 'created_at')
24+
}
25+
end
26+
end
27+
end
28+
end
29+
end
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
4+
module Providers
5+
class Replicate
6+
# Image generation methods for the Replicate API implementation
7+
module Images
8+
attr_reader :model
9+
10+
def images_url
11+
if official_model?
12+
"/v1/models/#{canonical_model_name}/predictions"
13+
else
14+
'/v1/predictions'
15+
end
16+
end
17+
18+
def render_image_payload(prompt, model:, size:, **params)
19+
RubyLLM.logger.debug "Use `model_params` for model-dependent sizing instead of size #{size}." if size
20+
21+
self.model_id = model
22+
23+
{}.tap do |payload|
24+
payload[:webhook] = @config.replicate_webhook_url
25+
payload[:version] = model.id unless official_model?
26+
payload[:input] = { prompt: prompt }.merge(params)
27+
28+
if @config.replicate_webhook_events_filter
29+
payload[:webhook_events_filter] = @config.replicate_webhook_events_filter
30+
end
31+
end
32+
end
33+
34+
def parse_image_response(response, **)
35+
response
36+
end
37+
38+
private
39+
40+
def model_id=(id)
41+
@model_id = id
42+
@model = Models.find(@model_id, 'replicate')
43+
end
44+
45+
def official_model?
46+
model.metadata['is_official'] == true
47+
end
48+
49+
def canonical_model_name
50+
model.name.split('/')[1..].join('/')
51+
end
52+
end
53+
end
54+
end
55+
end
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# frozen_string_literal: true
2+
3+
module RubyLLM
4+
module Providers
5+
class Replicate
6+
# Models methods of the Replicate API integration
7+
module Models
8+
def list_models(**)
9+
response = @connection.get models_url
10+
parse_list_models_response response
11+
end
12+
13+
def models_url
14+
'/v1/collections/text-to-image'
15+
end
16+
17+
def parse_list_models_response(response)
18+
Array(response.body['models']).map do |model_data|
19+
Model::Info.new(
20+
id: model_data.dig('latest_version', 'id'),
21+
name: capabilities.name_from(model_data),
22+
provider: 'replicate',
23+
created_at: model_data['created_at'],
24+
modalities: { input: ['text'], output: ['text'] },
25+
capabilities: ['image_generation'],
26+
metadata: capabilities.metadata_from(model_data)
27+
)
28+
end
29+
end
30+
end
31+
end
32+
end
33+
end

0 commit comments

Comments
 (0)