Skip to content

Commit b585f24

Browse files
Add support for custom OAuth functions (#1925)
* Add support for custom OAuth functions
* Update token, logical cluster, pool into BearerFieldProvider
1 parent 497ef2a commit b585f24

File tree

4 files changed

+267
-91
lines changed

4 files changed

+267
-91
lines changed

src/confluent_kafka/schema_registry/schema_registry_client.py

Lines changed: 100 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#
18-
18+
import abc
1919
import json
2020
import logging
2121
import random
@@ -30,7 +30,7 @@
3030
from enum import Enum
3131
from threading import Lock
3232
from typing import List, Dict, Type, TypeVar, \
33-
cast, Optional, Union, Any, Tuple
33+
cast, Optional, Union, Any, Tuple, Callable
3434

3535
from cachetools import TTLCache, LRUCache
3636
from httpx import Response
@@ -62,18 +62,50 @@ def _urlencode(value: str) -> str:
6262
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO']
6363

6464

65-
class _OAuthClient:
66-
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str,
67-
max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
65+
class _BearerFieldProvider(metaclass=abc.ABCMeta):
    """Abstract source of the bearer-auth fields attached to a request.

    Concrete providers return a dict of fields; ``handle_bearer_auth`` reads
    the keys ``'bearer.auth.token'``, ``'bearer.auth.identity.pool.id'`` and
    ``'bearer.auth.logical.cluster'`` from it and raises ``ValueError`` when
    any of them is absent.
    """

    @abc.abstractmethod
    def get_bearer_fields(self) -> dict:
        """Return the bearer fields as a dict keyed by config property name."""
        raise NotImplementedError
69+
70+
71+
class _StaticFieldProvider(_BearerFieldProvider):
    """Bearer field provider that returns fixed, pre-configured values."""

    def __init__(self, token: str, logical_cluster: str, identity_pool: str):
        """Store the three values exactly as given; no validation happens here."""
        self.token = token
        self.logical_cluster = logical_cluster
        self.identity_pool = identity_pool

    def get_bearer_fields(self) -> dict:
        """Return the stored token, logical cluster and identity pool id.

        A new dict is built on every call, keyed by the corresponding
        ``bearer.auth.*`` property names.
        """
        return {
            'bearer.auth.token': self.token,
            'bearer.auth.logical.cluster': self.logical_cluster,
            'bearer.auth.identity.pool.id': self.identity_pool,
        }
80+
81+
82+
class _CustomOAuthClient(_BearerFieldProvider):
    """Bearer field provider that delegates to a user-supplied callable."""

    def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict):
        """Keep the callable and the config dict that will be passed to it."""
        self.custom_function = custom_function
        self.custom_config = custom_config

    def get_bearer_fields(self) -> dict:
        """Call ``custom_function(custom_config)`` and return its result unchanged.

        The callable is invoked on every call; nothing is cached or validated
        at this level.
        """
        return self.custom_function(self.custom_config)
89+
90+
91+
class _OAuthClient(_BearerFieldProvider):
92+
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
93+
identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
6894
self.token = None
95+
self.logical_cluster = logical_cluster
96+
self.identity_pool = identity_pool
6997
self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
7098
self.token_endpoint = token_endpoint
7199
self.max_retries = max_retries
72100
self.retries_wait_ms = retries_wait_ms
73101
self.retries_max_wait_ms = retries_max_wait_ms
74102
self.token_expiry_threshold = 0.8
75103

76-
def token_expired(self):
104+
def get_bearer_fields(self) -> dict:
105+
return {'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster,
106+
'bearer.auth.identity.pool.id': self.identity_pool}
107+
108+
def token_expired(self) -> bool:
77109
expiry_window = self.token['expires_in'] * self.token_expiry_threshold
78110

79111
return self.token['expires_at'] < time.time() + expiry_window
@@ -84,7 +116,7 @@ def get_access_token(self) -> str:
84116

85117
return self.token['access_token']
86118

87-
def generate_access_token(self):
119+
def generate_access_token(self) -> None:
88120
for i in range(self.max_retries + 1):
89121
try:
90122
self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
@@ -206,23 +238,27 @@ def __init__(self, conf: dict):
206238
+ str(type(retries_max_wait_ms)))
207239
self.retries_max_wait_ms = retries_max_wait_ms
208240

209-
self.oauth_client = None
241+
self.bearer_field_provider = None
242+
logical_cluster = None
243+
identity_pool = None
210244
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
211245
if self.bearer_auth_credentials_source is not None:
212246
self.auth = None
213-
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
214-
missing_headers = [header for header in headers if header not in conf_copy]
215-
if missing_headers:
216-
raise ValueError("Missing required bearer configuration properties: {}"
217-
.format(", ".join(missing_headers)))
218247

219-
self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
220-
if not isinstance(self.logical_cluster, str):
221-
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))
248+
if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}:
249+
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
250+
missing_headers = [header for header in headers if header not in conf_copy]
251+
if missing_headers:
252+
raise ValueError("Missing required bearer configuration properties: {}"
253+
.format(", ".join(missing_headers)))
222254

223-
self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
224-
if not isinstance(self.identity_pool_id, str):
225-
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))
255+
logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
256+
if not isinstance(logical_cluster, str):
257+
raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster)))
258+
259+
identity_pool = conf_copy.pop('bearer.auth.identity.pool.id')
260+
if not isinstance(identity_pool, str):
261+
raise TypeError("identity pool id must be a str, not " + str(type(identity_pool)))
226262

227263
if self.bearer_auth_credentials_source == 'OAUTHBEARER':
228264
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
@@ -249,15 +285,38 @@ def __init__(self, conf: dict):
249285
raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
250286
+ str(type(self.token_endpoint)))
251287

252-
self.oauth_client = _OAuthClient(self.client_id, self.client_secret, self.scope, self.token_endpoint,
253-
self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms)
254-
288+
self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope,
289+
self.token_endpoint, logical_cluster, identity_pool,
290+
self.max_retries, self.retries_wait_ms,
291+
self.retries_max_wait_ms)
255292
elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
256293
if 'bearer.auth.token' not in conf_copy:
257294
raise ValueError("Missing bearer.auth.token")
258-
self.bearer_token = conf_copy.pop('bearer.auth.token')
259-
if not isinstance(self.bearer_token, string_type):
260-
raise TypeError("bearer.auth.token must be a str, not " + str(type(self.bearer_token)))
295+
static_token = conf_copy.pop('bearer.auth.token')
296+
self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool)
297+
if not isinstance(static_token, string_type):
298+
raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token)))
299+
elif self.bearer_auth_credentials_source == 'CUSTOM':
300+
custom_bearer_properties = ['bearer.auth.custom.provider.function',
301+
'bearer.auth.custom.provider.config']
302+
missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy]
303+
if missing_custom_properties:
304+
raise ValueError("Missing required custom OAuth configuration properties: {}".
305+
format(", ".join(missing_custom_properties)))
306+
307+
custom_function = conf_copy.pop('bearer.auth.custom.provider.function')
308+
if not callable(custom_function):
309+
raise TypeError("bearer.auth.custom.provider.function must be a callable, not "
310+
+ str(type(custom_function)))
311+
312+
custom_config = conf_copy.pop('bearer.auth.custom.provider.config')
313+
if not isinstance(custom_config, dict):
314+
raise TypeError("bearer.auth.custom.provider.config must be a dict, not "
315+
+ str(type(custom_config)))
316+
317+
self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config)
318+
else:
319+
raise ValueError('Unrecognized bearer.auth.credentials.source')
261320

262321
# Any leftover keys are unknown to _RestClient
263322
if len(conf_copy) > 0:
@@ -298,13 +357,22 @@ def __init__(self, conf: dict):
298357
timeout=self.timeout
299358
)
300359

301-
def handle_bearer_auth(self, headers: dict):
302-
token = self.bearer_token
303-
if self.oauth_client:
304-
token = self.oauth_client.get_access_token()
305-
headers["Authorization"] = "Bearer {}".format(token)
306-
headers['Confluent-Identity-Pool-Id'] = self.identity_pool_id
307-
headers['target-sr-cluster'] = self.logical_cluster
360+
def handle_bearer_auth(self, headers: dict) -> None:
    """Add the bearer-auth headers to ``headers`` in place.

    The fields are fetched from ``self.bearer_field_provider`` on every call.

    Args:
        headers: Request headers dict; mutated to include ``Authorization``,
            ``Confluent-Identity-Pool-Id`` and ``target-sr-cluster``.

    Raises:
        ValueError: If the provider's result lacks any of the required
            fields. ``headers`` is left untouched in that case.
    """
    bearer_fields = self.bearer_field_provider.get_bearer_fields()
    required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster']

    # Collect every missing key first so a single error reports all of them.
    missing_fields = [field for field in required_fields if field not in bearer_fields]
    if missing_fields:
        raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}"
                         .format(", ".join(missing_fields)))

    headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token'])
    headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id']
    headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster']
308376

309377
def get(self, url: str, query: Optional[dict] = None) -> Any:
310378
return self.send_request(url, method='GET', query=query)
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2025 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
import pytest
19+
import time
20+
from unittest.mock import Mock, patch
21+
22+
from confluent_kafka.schema_registry.schema_registry_client import (_OAuthClient, _StaticFieldProvider,
23+
_CustomOAuthClient, SchemaRegistryClient)
24+
from confluent_kafka.schema_registry.error import OAuthTokenError
25+
26+
"""
27+
Tests to ensure OAuth client is set up correctly.
28+
29+
"""
30+
31+
32+
def custom_oauth_function(config: dict) -> dict:
    """Test helper: return ``config`` itself as the bearer fields."""
    return config
34+
35+
36+
TEST_TOKEN = 'token123'
37+
TEST_CLUSTER = 'lsrc-cluster'
38+
TEST_POOL = 'pool-id'
39+
TEST_FUNCTION = custom_oauth_function
40+
TEST_CONFIG = {'bearer.auth.token': TEST_TOKEN, 'bearer.auth.logical.cluster': TEST_CLUSTER,
41+
'bearer.auth.identity.pool.id': TEST_POOL}
42+
TEST_URL = 'http://SchemaRegistry:65534'
43+
44+
45+
def test_expiry():
    """token_expired() flips from False to True as the clock nears expires_at.

    The clock read by the client module is patched instead of sleeping, so the
    test is instantaneous and does not depend on scheduler timing.
    """
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)
    now = time.time()
    oauth_client.token = {'expires_at': now + 2, 'expires_in': 1}
    clock = "confluent_kafka.schema_registry.schema_registry_client.time.time"

    # Two seconds of lifetime left: not yet inside the expiry window.
    with patch(clock, return_value=now):
        assert not oauth_client.token_expired()

    # 1.5 s later only 0.5 s remain, which the client treats as expired.
    with patch(clock, return_value=now + 1.5):
        assert oauth_client.token_expired()
51+
52+
53+
def test_get_token():
    """get_access_token() regenerates the token only when it has to."""
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)

    def update_token1():
        # expires_at of 0 is in the past, so this token is already stale.
        oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}

    def update_token2():
        oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'}

    # No token yet: exactly one generation is expected.
    oauth_client.generate_access_token = Mock(side_effect=update_token1)
    oauth_client.get_access_token()
    assert oauth_client.generate_access_token.call_count == 1
    assert oauth_client.token['access_token'] == '123'

    # The stale token must be replaced on the next access.
    oauth_client.generate_access_token = Mock(side_effect=update_token2)
    oauth_client.get_access_token()
    # Call count resets to 1 after reassigning generate_access_token
    assert oauth_client.generate_access_token.call_count == 1
    assert oauth_client.token['access_token'] == '1234'

    # The fresh token is reused: no further generation.
    oauth_client.get_access_token()
    assert oauth_client.generate_access_token.call_count == 1
75+
76+
77+
def test_generate_token_retry_logic():
    """generate_access_token() backs off once per retry, then raises.

    With max_retries=5 the test expects five sleep/jitter calls before
    OAuthTokenError is raised. Sleep and jitter are patched so no real
    waiting happens.
    """
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000)

    # Backslash continuation instead of a parenthesized ``with`` keeps the
    # file importable on interpreters older than Python 3.10.
    with patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep, \
            patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter:
        with pytest.raises(OAuthTokenError):
            oauth_client.generate_access_token()

    assert mock_sleep.call_count == 5
    assert mock_jitter.call_count == 5
88+
89+
90+
def test_static_field_provider():
    """_StaticFieldProvider returns exactly the values it was built with."""
    static_field_provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, TEST_POOL)
    bearer_fields = static_field_provider.get_bearer_fields()

    assert bearer_fields == TEST_CONFIG
95+
96+
97+
def test_custom_oauth_client():
    """_CustomOAuthClient returns whatever the custom function produces.

    TEST_FUNCTION returns its config argument, so the expected fields are
    TEST_CONFIG. Comparing the result with itself could never fail.
    """
    custom_oauth_client = _CustomOAuthClient(TEST_FUNCTION, TEST_CONFIG)

    assert custom_oauth_client.get_bearer_fields() == TEST_CONFIG
101+
102+
103+
def test_bearer_field_headers_missing():
    """handle_bearer_auth raises ValueError when the custom function returns no fields."""
    def empty_custom(config):
        # Ignores the config on purpose so every required field is missing.
        return {}

    conf = {'url': TEST_URL,
            'bearer.auth.credentials.source': 'CUSTOM',
            'bearer.auth.custom.provider.function': empty_custom,
            'bearer.auth.custom.provider.config': TEST_CONFIG}

    headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
                         " application/vnd.schemaregistry+json,"
                         " application/json"}

    client = SchemaRegistryClient(conf)

    with pytest.raises(ValueError, match=r"Missing required bearer auth fields, "
                                         r"needs to be set in config or custom function: (.*)"):
        client._rest_client.handle_bearer_auth(headers)
121+
122+
123+
def test_bearer_field_headers_valid():
    """handle_bearer_auth copies the custom function's fields into the headers."""
    conf = {'url': TEST_URL,
            'bearer.auth.credentials.source': 'CUSTOM',
            'bearer.auth.custom.provider.function': TEST_FUNCTION,
            'bearer.auth.custom.provider.config': TEST_CONFIG}

    client = SchemaRegistryClient(conf)

    headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
                         " application/vnd.schemaregistry+json,"
                         " application/json"}

    client._rest_client.handle_bearer_auth(headers)

    # Presence is checked first so a missing header fails with a clear message
    # rather than a KeyError in the value comparisons below.
    assert 'Authorization' in headers
    assert 'Confluent-Identity-Pool-Id' in headers
    assert 'target-sr-cluster' in headers
    assert headers['Authorization'] == "Bearer {}".format(TEST_CONFIG['bearer.auth.token'])
    assert headers['Confluent-Identity-Pool-Id'] == TEST_CONFIG['bearer.auth.identity.pool.id']
    assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster']

tests/schema_registry/test_config.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,6 @@ def test_oauth_bearer_config_valid():
212212

213213
client = SchemaRegistryClient(conf)
214214

215-
assert client._rest_client.logical_cluster == TEST_CLUSTER
216-
assert client._rest_client.identity_pool_id == TEST_POOL
217215
assert client._rest_client.client_id == TEST_USERNAME
218216
assert client._rest_client.client_secret == TEST_USER_PASSWORD
219217
assert client._rest_client.scope == TEST_SCOPE
@@ -230,6 +228,31 @@ def test_static_bearer_config():
230228
SchemaRegistryClient(conf)
231229

232230

231+
def test_custom_bearer_config():
    """A CUSTOM credentials source without function/config properties is rejected."""
    conf = {'url': TEST_URL,
            'bearer.auth.credentials.source': 'CUSTOM'}

    with pytest.raises(ValueError, match='Missing required custom OAuth configuration properties:'):
        SchemaRegistryClient(conf)
237+
238+
239+
def test_custom_bearer_config_valid():
    """A complete CUSTOM configuration stores the function and config on the provider."""
    def custom_function(config: dict):
        return {}

    custom_config = {}

    conf = {'url': TEST_URL,
            'bearer.auth.credentials.source': 'CUSTOM',
            'bearer.auth.custom.provider.function': custom_function,
            'bearer.auth.custom.provider.config': custom_config}

    client = SchemaRegistryClient(conf)

    assert client._rest_client.bearer_field_provider.custom_function == custom_function
    assert client._rest_client.bearer_field_provider.custom_config == custom_config
254+
255+
233256
def test_config_unknown_prop():
234257
conf = {'url': TEST_URL,
235258
'basic.auth.credentials.source': 'SASL_INHERIT',

0 commit comments

Comments
 (0)