1717import re
1818import threading
1919import webbrowser
20- from typing import Callable , List , Optional
20+ from typing import Any , Callable , Dict , List , Optional , Tuple
2121from urllib .parse import urlparse
2222
23- from requests import Request
23+ from requests import PreparedRequest , Request , Response , Session
2424from requests .auth import AuthBase , extract_cookies_to_jar
2525from requests .utils import parse_dict_header
2626
3232
3333class Authentication (metaclass = abc .ABCMeta ):
3434 @abc .abstractmethod
35- def set_http_session (self , http_session ) :
35+ def set_http_session (self , http_session : Session ) -> Session :
3636 pass
3737
38- def get_exceptions (self ):
38+ def get_exceptions (self ) -> Tuple [ Any , ...] :
3939 return tuple ()
4040
4141
4242class KerberosAuthentication (Authentication ):
4343 def __init__ (
4444 self ,
4545 config : Optional [str ] = None ,
46- service_name : str = None ,
46+ service_name : Optional [ str ] = None ,
4747 mutual_authentication : bool = False ,
4848 force_preemptive : bool = False ,
4949 hostname_override : Optional [str ] = None ,
@@ -62,7 +62,7 @@ def __init__(
6262 self ._delegate = delegate
6363 self ._ca_bundle = ca_bundle
6464
65- def set_http_session (self , http_session ) :
65+ def set_http_session (self , http_session : Session ) -> Session :
6666 try :
6767 import requests_kerberos
6868 except ImportError :
@@ -84,15 +84,15 @@ def set_http_session(self, http_session):
8484 http_session .verify = self ._ca_bundle
8585 return http_session
8686
87- def get_exceptions (self ):
87+ def get_exceptions (self ) -> Tuple [ Any , ...] :
8888 try :
8989 from requests_kerberos .exceptions import KerberosExchangeError
9090
91- return ( KerberosExchangeError ,)
91+ return KerberosExchangeError ,
9292 except ImportError :
9393 raise RuntimeError ("unable to import requests_kerberos" )
9494
95- def __eq__ (self , other ) :
95+ def __eq__ (self , other : object ) -> bool :
9696 if not isinstance (other , KerberosAuthentication ):
9797 return False
9898 return (self ._config == other ._config
@@ -107,11 +107,11 @@ def __eq__(self, other):
107107
108108
109109class BasicAuthentication (Authentication ):
110- def __init__ (self , username , password ):
110+ def __init__ (self , username : str , password : str ):
111111 self ._username = username
112112 self ._password = password
113113
114- def set_http_session (self , http_session ) :
114+ def set_http_session (self , http_session : Session ) -> Session :
115115 try :
116116 import requests .auth
117117 except ImportError :
@@ -120,10 +120,10 @@ def set_http_session(self, http_session):
120120 http_session .auth = requests .auth .HTTPBasicAuth (self ._username , self ._password )
121121 return http_session
122122
123- def get_exceptions (self ):
123+ def get_exceptions (self ) -> Tuple [ Any , ...] :
124124 return ()
125125
126- def __eq__ (self , other ) :
126+ def __eq__ (self , other : object ) -> bool :
127127 if not isinstance (other , BasicAuthentication ):
128128 return False
129129 return self ._username == other ._username and self ._password == other ._password
@@ -134,27 +134,27 @@ class _BearerAuth(AuthBase):
134134 Custom implementation of Authentication class for bearer token
135135 """
136136
137- def __init__ (self , token ):
137+ def __init__ (self , token : str ):
138138 self .token = token
139139
140- def __call__ (self , r ) :
140+ def __call__ (self , r : PreparedRequest ) -> PreparedRequest :
141141 r .headers ["Authorization" ] = "Bearer " + self .token
142142 return r
143143
144144
145145class JWTAuthentication (Authentication ):
146146
147- def __init__ (self , token ):
147+ def __init__ (self , token : str ):
148148 self .token = token
149149
150- def set_http_session (self , http_session ) :
150+ def set_http_session (self , http_session : Session ) -> Session :
151151 http_session .auth = _BearerAuth (self .token )
152152 return http_session
153153
154- def get_exceptions (self ):
154+ def get_exceptions (self ) -> Tuple [ Any , ...] :
155155 return ()
156156
157- def __eq__ (self , other ) :
157+ def __eq__ (self , other : object ) -> bool :
158158 if not isinstance (other , JWTAuthentication ):
159159 return False
160160 return self .token == other .token
@@ -197,7 +197,7 @@ class CompositeRedirectHandler(RedirectHandler):
197197 def __init__ (self , handlers : List [Callable [[str ], None ]]):
198198 self .handlers = handlers
199199
200- def __call__ (self , url : str ):
200+ def __call__ (self , url : str ) -> None :
201201 for handler in self .handlers :
202202 handler (url )
203203
@@ -208,11 +208,11 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta):
208208 """
209209
210210 @abc .abstractmethod
211- def get_token_from_cache (self , host : str ) -> Optional [str ]:
211+ def get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
212212 pass
213213
214214 @abc .abstractmethod
215- def store_token_to_cache (self , host : str , token : str ) -> None :
215+ def store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
216216 pass
217217
218218
@@ -221,13 +221,13 @@ class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
221221 In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
222222 """
223223
224- def __init__ (self ):
225- self ._cache = {}
224+ def __init__ (self ) -> None :
225+ self ._cache : Dict [ Optional [ str ], str ] = {}
226226
227- def get_token_from_cache (self , host : str ) -> Optional [str ]:
227+ def get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
228228 return self ._cache .get (host )
229229
230- def store_token_to_cache (self , host : str , token : str ) -> None :
230+ def store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
231231 self ._cache [host ] = token
232232
233233
@@ -236,26 +236,26 @@ class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
236236 Keyring Token Cache implementation
237237 """
238238
239- def __init__ (self ):
239+ def __init__ (self ) -> None :
240240 super ().__init__ ()
241241 try :
242242 self ._keyring = importlib .import_module ("keyring" )
243243 except ImportError :
244- self ._keyring = None
244+ self ._keyring = None # type: ignore
245245 logger .info ("keyring module not found. OAuth2 token will not be stored in keyring." )
246246
247247 def is_keyring_available (self ) -> bool :
248248 return self ._keyring is not None
249249
250- def get_token_from_cache (self , host : str ) -> Optional [str ]:
250+ def get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
251251 try :
252252 return self ._keyring .get_password (host , "token" )
253253 except self ._keyring .errors .NoKeyringError as e :
254254 raise trino .exceptions .NotSupportedError ("Although keyring module is installed no backend has been "
255255 "detected, check https://pypi.org/project/keyring/ for more "
256256 "information." ) from e
257257
258- def store_token_to_cache (self , host : str , token : str ) -> None :
258+ def store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
259259 try :
260260 # keyring is installed, so we can store the token for reuse within multiple threads
261261 self ._keyring .set_password (host , "token" , token )
@@ -280,18 +280,18 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
280280 self ._inside_oauth_attempt_lock = threading .Lock ()
281281 self ._inside_oauth_attempt_blocker = threading .Event ()
282282
283- def __call__ (self , r ) :
283+ def __call__ (self , r : PreparedRequest ) -> PreparedRequest :
284284 host = self ._determine_host (r .url )
285285 token = self ._get_token_from_cache (host )
286286
287287 if token is not None :
288288 r .headers ['Authorization' ] = "Bearer " + token
289289
290- r .register_hook ('response' , self ._authenticate )
290+ r .register_hook ('response' , self ._authenticate ) # type: ignore
291291
292292 return r
293293
294- def _authenticate (self , response , ** kwargs ) :
294+ def _authenticate (self , response : Response , ** kwargs : Any ) -> Optional [ Response ] :
295295 if not 400 <= response .status_code < 500 :
296296 return response
297297
@@ -310,7 +310,7 @@ def _authenticate(self, response, **kwargs):
310310
311311 return self ._retry_request (response , ** kwargs )
312312
313- def _attempt_oauth (self , response , ** kwargs ) :
313+ def _attempt_oauth (self , response : Response , ** kwargs : Any ) -> None :
314314 # we have to handle the authentication, may be token the token expired, or it wasn't there at all
315315 auth_info = response .headers .get ('WWW-Authenticate' )
316316 if not auth_info :
@@ -319,7 +319,8 @@ def _attempt_oauth(self, response, **kwargs):
319319 if not _OAuth2TokenBearer ._BEARER_PREFIX .search (auth_info ):
320320 raise exceptions .TrinoAuthError (f"Error: header info didn't match { auth_info } " )
321321
322- auth_info_headers = parse_dict_header (_OAuth2TokenBearer ._BEARER_PREFIX .sub ("" , auth_info , count = 1 ))
322+ auth_info_headers = parse_dict_header (
323+ _OAuth2TokenBearer ._BEARER_PREFIX .sub ("" , auth_info , count = 1 )) # type: ignore
323324
324325 auth_server = auth_info_headers .get ('x_redirect_server' )
325326 token_server = auth_info_headers .get ('x_token_server' )
@@ -341,23 +342,26 @@ def _attempt_oauth(self, response, **kwargs):
341342 host = self ._determine_host (request .url )
342343 self ._store_token_to_cache (host , token )
343344
344- def _retry_request (self , response , ** kwargs ) :
345+ def _retry_request (self , response : Response , ** kwargs : Any ) -> Optional [ Response ] :
345346 request = response .request .copy ()
346- extract_cookies_to_jar (request ._cookies , response .request , response .raw )
347- request .prepare_cookies (request ._cookies )
347+ extract_cookies_to_jar (request ._cookies , response .request , response .raw ) # type: ignore
348+ request .prepare_cookies (request ._cookies ) # type: ignore
348349
349350 host = self ._determine_host (response .request .url )
350- request .headers ['Authorization' ] = "Bearer " + self ._get_token_from_cache (host )
351- retry_response = response .connection .send (request , ** kwargs )
351+ token = self ._get_token_from_cache (host )
352+ if token is not None :
353+ request .headers ['Authorization' ] = "Bearer " + token
354+ retry_response = response .connection .send (request , ** kwargs ) # type: ignore
352355 retry_response .history .append (response )
353356 retry_response .request = request
354357 return retry_response
355358
356- def _get_token (self , token_server , response , ** kwargs ) :
359+ def _get_token (self , token_server : str , response : Response , ** kwargs : Any ) -> str :
357360 attempts = 0
358361 while attempts < self .MAX_OAUTH_ATTEMPTS :
359362 attempts += 1
360- with response .connection .send (Request (method = 'GET' , url = token_server ).prepare (), ** kwargs ) as response :
363+ with response .connection .send (Request ( # type: ignore
364+ method = 'GET' , url = token_server ).prepare (), ** kwargs ) as response :
361365 if response .status_code == 200 :
362366 token_response = json .loads (response .text )
363367 token = token_response .get ('token' )
@@ -377,53 +381,53 @@ def _get_token(self, token_server, response, **kwargs):
377381
378382 raise exceptions .TrinoAuthError ("Exceeded max attempts while getting the token" )
379383
380- def _get_token_from_cache (self , host : str ) -> Optional [str ]:
384+ def _get_token_from_cache (self , host : Optional [ str ] ) -> Optional [str ]:
381385 with self ._token_lock :
382386 return self ._token_cache .get_token_from_cache (host )
383387
384- def _store_token_to_cache (self , host : str , token : str ) -> None :
388+ def _store_token_to_cache (self , host : Optional [ str ] , token : str ) -> None :
385389 with self ._token_lock :
386390 self ._token_cache .store_token_to_cache (host , token )
387391
388392 @staticmethod
389- def _determine_host (url ) -> Optional [str ]:
393+ def _determine_host (url : Optional [str ]) -> Any :
390394 return urlparse (url ).hostname
391395
392396
393397class OAuth2Authentication (Authentication ):
394- def __init__ (self , redirect_auth_url_handler = CompositeRedirectHandler ([
398+ def __init__ (self , redirect_auth_url_handler : CompositeRedirectHandler = CompositeRedirectHandler ([
395399 WebBrowserRedirectHandler (),
396400 ConsoleRedirectHandler ()
397401 ])):
398402 self ._redirect_auth_url = redirect_auth_url_handler
399403 self ._bearer = _OAuth2TokenBearer (self ._redirect_auth_url )
400404
401- def set_http_session (self , http_session ) :
405+ def set_http_session (self , http_session : Session ) -> Session :
402406 http_session .auth = self ._bearer
403407 return http_session
404408
405- def get_exceptions (self ):
409+ def get_exceptions (self ) -> Tuple [ Any , ...] :
406410 return ()
407411
408- def __eq__ (self , other ) :
412+ def __eq__ (self , other : object ) -> bool :
409413 if not isinstance (other , OAuth2Authentication ):
410414 return False
411415 return self ._redirect_auth_url == other ._redirect_auth_url
412416
413417
414418class CertificateAuthentication (Authentication ):
415- def __init__ (self , cert , key ):
419+ def __init__ (self , cert : str , key : str ):
416420 self ._cert = cert
417421 self ._key = key
418422
419- def set_http_session (self , http_session ) :
423+ def set_http_session (self , http_session : Session ) -> Session :
420424 http_session .cert = (self ._cert , self ._key )
421425 return http_session
422426
423- def get_exceptions (self ):
427+ def get_exceptions (self ) -> Tuple [ Any , ...] :
424428 return ()
425429
426- def __eq__ (self , other ) :
430+ def __eq__ (self , other : object ) -> bool :
427431 if not isinstance (other , CertificateAuthentication ):
428432 return False
429433 return self ._cert == other ._cert and self ._key == other ._key
0 commit comments