Skip to content

Commit 59a7606

Browse files
committed
add support for Azure token providers
1 parent 7666e30 commit 59a7606

File tree

7 files changed

+250
-9
lines changed

7 files changed

+250
-9
lines changed

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,21 @@ To use the [Azure OpenAI Service](https://learn.microsoft.com/en-us/azure/cognit
186186

187187
where `AZURE_OPENAI_URI` is e.g. `https://custom-domain.openai.azure.com/openai/deployments/gpt-35-turbo`
188188

189+
##### Azure with Azure AD tokens
190+
191+
To use Azure AD tokens you can configure the gem with a proc like this:
192+
193+
```ruby
194+
OpenAI.configure do |config|
195+
config.azure_token_provider = ->() { your_code_caches_or_refreshes_token }
196+
config.uri_base = ENV.fetch("AZURE_OPENAI_URI")
197+
config.api_type = :azure
198+
config.api_version = "2023-03-15-preview"
199+
end
200+
```
201+
202+
The azure_token_provider will be called on every request. This allows tokens to be cached and periodically refreshed by your custom code.
203+
189204
### Counting Tokens
190205

191206
OpenAI parses prompt text into [tokens](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them), which are words or portions of words. (These tokens are unrelated to your API access_token.) Counting tokens can help you estimate your [costs](https://openai.com/pricing). It can also help you ensure your prompt text size is within the max-token limits of your model's context window, and choose an appropriate [`max_tokens`](https://platform.openai.com/docs/api-reference/chat/create#chat/create-max_tokens) completion parameter so your response will fit as well.

lib/openai.rb

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def call(env)
3636
end
3737

3838
class Configuration
39-
attr_writer :access_token
39+
attr_reader :azure_token_provider
4040
attr_accessor :api_type, :api_version, :organization_id, :uri_base, :request_timeout,
41-
:extra_headers
41+
:extra_headers, :access_token
4242

4343
DEFAULT_API_VERSION = "v1".freeze
4444
DEFAULT_URI_BASE = "https://api.openai.com/".freeze
@@ -52,13 +52,16 @@ def initialize
5252
@uri_base = DEFAULT_URI_BASE
5353
@request_timeout = DEFAULT_REQUEST_TIMEOUT
5454
@extra_headers = {}
55+
@azure_token_provider = nil
5556
end
5657

57-
def access_token
58-
return @access_token if @access_token
58+
def azure_token_provider=(provider)
59+
unless provider.nil? || (provider.is_a?(Proc) && provider.arity.zero?)
60+
raise ConfigurationError,
61+
"OpenAI Azure AD token provider must be a Proc that takes no arguments"
62+
end
5963

60-
error_text = "OpenAI access token missing! See https://github.com/alexrudall/ruby-openai#usage"
61-
raise ConfigurationError, error_text
64+
@azure_token_provider = provider
6265
end
6366
end
6467

lib/openai/client.rb

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,19 @@ class Client
1010
uri_base
1111
request_timeout
1212
extra_headers
13+
azure_token_provider
1314
].freeze
1415
attr_reader *CONFIG_KEYS, :faraday_middleware
1516

1617
def initialize(config = {}, &faraday_middleware)
1718
CONFIG_KEYS.each do |key|
1819
# Set instance variables like api_type & access_token. Fall back to global config
1920
# if not present.
20-
instance_variable_set("@#{key}", config[key] || OpenAI.configuration.send(key))
21+
instance_variable_set("@#{key}",
22+
config.key?(key) ? config[key] : OpenAI.configuration.send(key))
2123
end
2224
@faraday_middleware = faraday_middleware
25+
validate_credential_config!
2326
end
2427

2528
def chat(parameters: {})
@@ -87,5 +90,19 @@ def beta(apis)
8790
client.add_headers("OpenAI-Beta": apis.map { |k, v| "#{k}=#{v}" }.join(";"))
8891
end
8992
end
93+
94+
private
95+
96+
def validate_credential_config!
97+
if @access_token && @azure_token_provider
98+
raise ConfigurationError,
99+
"Only one of OpenAI access token or Azure token provider can be set! See https://github.com/alexrudall/ruby-openai#usage"
100+
end
101+
102+
return if @access_token || @azure_token_provider
103+
104+
raise ConfigurationError,
105+
"OpenAI access token or Azure token provider missing! See https://github.com/alexrudall/ruby-openai#usage"
106+
end
90107
end
91108
end

lib/openai/http_headers.rb

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,20 @@ def openai_headers
2525
def azure_headers
2626
{
2727
"Content-Type" => "application/json",
28-
"api-key" => @access_token
28+
**azure_auth_headers
2929
}
3030
end
3131

32+
def azure_auth_headers
33+
if @access_token
34+
{ "api-key" => @access_token }
35+
elsif @azure_token_provider
36+
{ "Authorization" => "Bearer #{@azure_token_provider.call}" }
37+
else
38+
raise ConfigurationError, "access_token or azure_token_provider must be set."
39+
end
40+
end
41+
3242
def extra_headers
3343
@extra_headers ||= {}
3444
end

spec/fixtures/cassettes/http_json_post_with_azure_token_provider.yml

Lines changed: 125 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

spec/openai/client/client_spec.rb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
let!(:c2) do
3131
OpenAI::Client.new(
3232
access_token: "access_token2",
33-
organization_id: nil,
3433
request_timeout: 1,
3534
uri_base: "https://example.com/"
3635
)

spec/openai/client/http_spec.rb

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,56 @@
120120
end
121121
end
122122

123+
describe ".json_post" do
124+
context "with azure_token_provider" do
125+
let(:token_provider) do
126+
counter = 0
127+
lambda do
128+
counter += 1
129+
"some dynamic token #{counter}"
130+
end
131+
end
132+
133+
let(:client) do
134+
OpenAI::Client.new(
135+
access_token: nil,
136+
azure_token_provider: token_provider,
137+
api_type: :azure,
138+
uri_base: "https://custom-domain.openai.azure.com/openai/deployments/gpt-35-turbo",
139+
api_version: "2024-02-01"
140+
)
141+
end
142+
143+
let(:cassette) { "http json post with azure token provider" }
144+
145+
it "calls the token provider on every request" do
146+
expect(token_provider).to receive(:call).twice.and_call_original
147+
VCR.use_cassette(cassette, record: :none) do
148+
client.chat(
149+
parameters: {
150+
messages: [
151+
{
152+
"role" => "user",
153+
"content" => "Hello world!"
154+
}
155+
]
156+
}
157+
)
158+
client.chat(
159+
parameters: {
160+
messages: [
161+
{
162+
"role" => "user",
163+
"content" => "Who were the founders of Microsoft?"
164+
}
165+
]
166+
}
167+
)
168+
end
169+
end
170+
end
171+
end
172+
123173
describe ".to_json_stream" do
124174
context "with a proc" do
125175
let(:user_proc) { proc { |x| x } }
@@ -269,6 +319,28 @@
269319
expect(headers).to eq({ "Content-Type" => "application/json",
270320
"api-key" => OpenAI.configuration.access_token })
271321
}
322+
323+
context "with azure_token_provider" do
324+
let(:token) { "some dynamic token" }
325+
let(:token_provider) { -> { token } }
326+
327+
around do |ex|
328+
old_access_token = OpenAI.configuration.access_token
329+
OpenAI.configuration.access_token = nil
330+
OpenAI.configuration.azure_token_provider = token_provider
331+
332+
ex.call
333+
ensure
334+
OpenAI.configuration.azure_token_provider = nil
335+
OpenAI.configuration.access_token = old_access_token
336+
end
337+
338+
it {
339+
expect(token_provider).to receive(:call).once.and_call_original
340+
expect(headers).to eq({ "Content-Type" => "application/json",
341+
"Authorization" => "Bearer #{token}" })
342+
}
343+
end
272344
end
273345
end
274346
end

0 commit comments

Comments
 (0)