1515# See the License for the specific language governing permissions and
1616# limitations under the License.
1717#
18-
18+ import abc
1919import json
2020import logging
2121import random
3030from enum import Enum
3131from threading import Lock
3232from typing import List , Dict , Type , TypeVar , \
33- cast , Optional , Union , Any , Tuple
33+ cast , Optional , Union , Any , Tuple , Callable
3434
3535from cachetools import TTLCache , LRUCache
3636from httpx import Response
@@ -62,18 +62,50 @@ def _urlencode(value: str) -> str:
6262VALID_AUTH_PROVIDERS = ['URL' , 'USER_INFO' ]
6363
6464
65- class _OAuthClient :
66- def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str ,
67- max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
65+ class _BearerFieldProvider (metaclass = abc .ABCMeta ):
66+ @abc .abstractmethod
67+ def get_bearer_fields (self ) -> dict :
68+ raise NotImplementedError
69+
70+
71+ class _StaticFieldProvider (_BearerFieldProvider ):
72+ def __init__ (self , token : str , logical_cluster : str , identity_pool : str ):
73+ self .token = token
74+ self .logical_cluster = logical_cluster
75+ self .identity_pool = identity_pool
76+
77+ def get_bearer_fields (self ) -> dict :
78+ return {'bearer.auth.token' : self .token , 'bearer.auth.logical.cluster' : self .logical_cluster ,
79+ 'bearer.auth.identity.pool.id' : self .identity_pool }
80+
81+
82+ class _CustomOAuthClient (_BearerFieldProvider ):
83+ def __init__ (self , custom_function : Callable [[Dict ], Dict ], custom_config : dict ):
84+ self .custom_function = custom_function
85+ self .custom_config = custom_config
86+
87+ def get_bearer_fields (self ) -> dict :
88+ return self .custom_function (self .custom_config )
89+
90+
91+ class _OAuthClient (_BearerFieldProvider ):
92+ def __init__ (self , client_id : str , client_secret : str , scope : str , token_endpoint : str , logical_cluster : str ,
93+ identity_pool : str , max_retries : int , retries_wait_ms : int , retries_max_wait_ms : int ):
6894 self .token = None
95+ self .logical_cluster = logical_cluster
96+ self .identity_pool = identity_pool
6997 self .client = OAuth2Client (client_id = client_id , client_secret = client_secret , scope = scope )
7098 self .token_endpoint = token_endpoint
7199 self .max_retries = max_retries
72100 self .retries_wait_ms = retries_wait_ms
73101 self .retries_max_wait_ms = retries_max_wait_ms
74102 self .token_expiry_threshold = 0.8
75103
76- def token_expired (self ):
104+ def get_bearer_fields (self ) -> dict :
105+ return {'bearer.auth.token' : self .get_access_token (), 'bearer.auth.logical.cluster' : self .logical_cluster ,
106+ 'bearer.auth.identity.pool.id' : self .identity_pool }
107+
108+ def token_expired (self ) -> bool :
77109 expiry_window = self .token ['expires_in' ] * self .token_expiry_threshold
78110
79111 return self .token ['expires_at' ] < time .time () + expiry_window
@@ -84,7 +116,7 @@ def get_access_token(self) -> str:
84116
85117 return self .token ['access_token' ]
86118
87- def generate_access_token (self ):
119+ def generate_access_token (self ) -> None :
88120 for i in range (self .max_retries + 1 ):
89121 try :
90122 self .token = self .client .fetch_token (url = self .token_endpoint , grant_type = 'client_credentials' )
@@ -206,23 +238,27 @@ def __init__(self, conf: dict):
206238 + str (type (retries_max_wait_ms )))
207239 self .retries_max_wait_ms = retries_max_wait_ms
208240
209- self .oauth_client = None
241+ self .bearer_field_provider = None
242+ logical_cluster = None
243+ identity_pool = None
210244 self .bearer_auth_credentials_source = conf_copy .pop ('bearer.auth.credentials.source' , None )
211245 if self .bearer_auth_credentials_source is not None :
212246 self .auth = None
213- headers = ['bearer.auth.logical.cluster' , 'bearer.auth.identity.pool.id' ]
214- missing_headers = [header for header in headers if header not in conf_copy ]
215- if missing_headers :
216- raise ValueError ("Missing required bearer configuration properties: {}"
217- .format (", " .join (missing_headers )))
218247
219- self .logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
220- if not isinstance (self .logical_cluster , str ):
221- raise TypeError ("logical cluster must be a str, not " + str (type (self .logical_cluster )))
248+ if self .bearer_auth_credentials_source in {'OAUTHBEARER' , 'STATIC_TOKEN' }:
249+ headers = ['bearer.auth.logical.cluster' , 'bearer.auth.identity.pool.id' ]
250+ missing_headers = [header for header in headers if header not in conf_copy ]
251+ if missing_headers :
252+ raise ValueError ("Missing required bearer configuration properties: {}"
253+ .format (", " .join (missing_headers )))
222254
223- self .identity_pool_id = conf_copy .pop ('bearer.auth.identity.pool.id' )
224- if not isinstance (self .identity_pool_id , str ):
225- raise TypeError ("identity pool id must be a str, not " + str (type (self .identity_pool_id )))
255+ logical_cluster = conf_copy .pop ('bearer.auth.logical.cluster' )
256+ if not isinstance (logical_cluster , str ):
257+ raise TypeError ("logical cluster must be a str, not " + str (type (logical_cluster )))
258+
259+ identity_pool = conf_copy .pop ('bearer.auth.identity.pool.id' )
260+ if not isinstance (identity_pool , str ):
261+ raise TypeError ("identity pool id must be a str, not " + str (type (identity_pool )))
226262
227263 if self .bearer_auth_credentials_source == 'OAUTHBEARER' :
228264 properties_list = ['bearer.auth.client.id' , 'bearer.auth.client.secret' , 'bearer.auth.scope' ,
@@ -249,15 +285,38 @@ def __init__(self, conf: dict):
249285 raise TypeError ("bearer.auth.issuer.endpoint.url must be a str, not "
250286 + str (type (self .token_endpoint )))
251287
252- self .oauth_client = _OAuthClient (self .client_id , self .client_secret , self .scope , self .token_endpoint ,
253- self .max_retries , self .retries_wait_ms , self .retries_max_wait_ms )
254-
288+ self .bearer_field_provider = _OAuthClient (self .client_id , self .client_secret , self .scope ,
289+ self .token_endpoint , logical_cluster , identity_pool ,
290+ self .max_retries , self .retries_wait_ms ,
291+ self .retries_max_wait_ms )
255292 elif self .bearer_auth_credentials_source == 'STATIC_TOKEN' :
256293 if 'bearer.auth.token' not in conf_copy :
257294 raise ValueError ("Missing bearer.auth.token" )
258- self .bearer_token = conf_copy .pop ('bearer.auth.token' )
259- if not isinstance (self .bearer_token , string_type ):
260- raise TypeError ("bearer.auth.token must be a str, not " + str (type (self .bearer_token )))
295+ static_token = conf_copy .pop ('bearer.auth.token' )
296+ self .bearer_field_provider = _StaticFieldProvider (static_token , logical_cluster , identity_pool )
297+ if not isinstance (static_token , string_type ):
298+ raise TypeError ("bearer.auth.token must be a str, not " + str (type (static_token )))
299+ elif self .bearer_auth_credentials_source == 'CUSTOM' :
300+ custom_bearer_properties = ['bearer.auth.custom.provider.function' ,
301+ 'bearer.auth.custom.provider.config' ]
302+ missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy ]
303+ if missing_custom_properties :
304+ raise ValueError ("Missing required custom OAuth configuration properties: {}" .
305+ format (", " .join (missing_custom_properties )))
306+
307+ custom_function = conf_copy .pop ('bearer.auth.custom.provider.function' )
308+ if not callable (custom_function ):
309+ raise TypeError ("bearer.auth.custom.provider.function must be a callable, not "
310+ + str (type (custom_function )))
311+
312+ custom_config = conf_copy .pop ('bearer.auth.custom.provider.config' )
313+ if not isinstance (custom_config , dict ):
314+ raise TypeError ("bearer.auth.custom.provider.config must be a dict, not "
315+ + str (type (custom_config )))
316+
317+ self .bearer_field_provider = _CustomOAuthClient (custom_function , custom_config )
318+ else :
319+ raise ValueError ('Unrecognized bearer.auth.credentials.source' )
261320
262321 # Any leftover keys are unknown to _RestClient
263322 if len (conf_copy ) > 0 :
@@ -298,13 +357,22 @@ def __init__(self, conf: dict):
298357 timeout = self .timeout
299358 )
300359
301- def handle_bearer_auth (self , headers : dict ):
302- token = self .bearer_token
303- if self .oauth_client :
304- token = self .oauth_client .get_access_token ()
305- headers ["Authorization" ] = "Bearer {}" .format (token )
306- headers ['Confluent-Identity-Pool-Id' ] = self .identity_pool_id
307- headers ['target-sr-cluster' ] = self .logical_cluster
360+ def handle_bearer_auth (self , headers : dict ) -> None :
361+ bearer_fields = self .bearer_field_provider .get_bearer_fields ()
362+ required_fields = ['bearer.auth.token' , 'bearer.auth.identity.pool.id' , 'bearer.auth.logical.cluster' ]
363+
364+ missing_fields = []
365+ for field in required_fields :
366+ if field not in bearer_fields :
367+ missing_fields .append (field )
368+
369+ if missing_fields :
370+ raise ValueError ("Missing required bearer auth fields, needs to be set in config or custom function: {}"
371+ .format (", " .join (missing_fields )))
372+
373+ headers ["Authorization" ] = "Bearer {}" .format (bearer_fields ['bearer.auth.token' ])
374+ headers ['Confluent-Identity-Pool-Id' ] = bearer_fields ['bearer.auth.identity.pool.id' ]
375+ headers ['target-sr-cluster' ] = bearer_fields ['bearer.auth.logical.cluster' ]
308376
309377 def get (self , url : str , query : Optional [dict ] = None ) -> Any :
310378 return self .send_request (url , method = 'GET' , query = query )
0 commit comments