5555from zoneinfo import ZoneInfo
5656
5757import requests
58+ from requests import Response
59+ from requests import Session
60+ from requests .structures import CaseInsensitiveDict
5861from tzlocal import get_localzone_name # type: ignore
5962
6063import trino .logging
6164from trino import constants
6265from trino import exceptions
6366from trino ._version import __version__
67+ from trino .auth import Authentication
68+ from trino .exceptions import TrinoExternalError
69+ from trino .exceptions import TrinoQueryError
70+ from trino .exceptions import TrinoUserError
6471from trino .mapper import RowMapper
6572from trino .mapper import RowMapperFactory
6673
@@ -271,27 +278,27 @@ def __setstate__(self, state):
271278 self ._object_lock = threading .Lock ()
272279
273280
274- def get_header_values (headers , header ) :
281+ def get_header_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ str ] :
275282 return [val .strip () for val in headers [header ].split ("," )]
276283
277284
278- def get_session_property_values (headers , header ) :
285+ def get_session_property_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ Tuple [ str , str ]] :
279286 kvs = get_header_values (headers , header )
280287 return [
281288 (k .strip (), urllib .parse .unquote_plus (v .strip ()))
282289 for k , v in (kv .split ("=" , 1 ) for kv in kvs if kv )
283290 ]
284291
285292
286- def get_prepared_statement_values (headers , header ) :
293+ def get_prepared_statement_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ Tuple [ str , str ]] :
287294 kvs = get_header_values (headers , header )
288295 return [
289296 (k .strip (), urllib .parse .unquote_plus (v .strip ()))
290297 for k , v in (kv .split ("=" , 1 ) for kv in kvs if kv )
291298 ]
292299
293300
294- def get_roles_values (headers , header ) :
301+ def get_roles_values (headers : CaseInsensitiveDict [ str ] , header : str ) -> List [ Tuple [ str , str ]] :
295302 kvs = get_header_values (headers , header )
296303 return [
297304 (k .strip (), urllib .parse .unquote_plus (v .strip ()))
@@ -414,9 +421,9 @@ def __init__(
414421 host : str ,
415422 port : int ,
416423 client_session : ClientSession ,
417- http_session : Any = None ,
418- http_scheme : str = None ,
419- auth : Optional [Any ] = constants .DEFAULT_AUTH ,
424+ http_session : Optional [ Session ] = None ,
425+ http_scheme : Optional [ str ] = None ,
426+ auth : Optional [Authentication ] = constants .DEFAULT_AUTH ,
420427 max_attempts : int = MAX_ATTEMPTS ,
421428 request_timeout : Union [float , Tuple [float , float ]] = constants .DEFAULT_REQUEST_TIMEOUT ,
422429 handle_retry = _RetryWithExponentialBackoff (),
@@ -454,16 +461,16 @@ def __init__(
454461 self .max_attempts = max_attempts
455462
456463 @property
457- def transaction_id (self ):
464+ def transaction_id (self ) -> Optional [ str ] :
458465 return self ._client_session .transaction_id
459466
460467 @transaction_id .setter
461- def transaction_id (self , value ) :
468+ def transaction_id (self , value : Optional [ str ]) -> None :
462469 self ._client_session .transaction_id = value
463470
464471 @property
465- def http_headers (self ) -> Dict [ str , str ]:
466- headers = requests . structures . CaseInsensitiveDict ()
472+ def http_headers (self ) -> CaseInsensitiveDict [ str ]:
473+ headers : CaseInsensitiveDict [ str ] = CaseInsensitiveDict ()
467474
468475 headers [constants .HEADER_CATALOG ] = self ._client_session .catalog
469476 headers [constants .HEADER_SCHEMA ] = self ._client_session .schema
@@ -525,7 +532,7 @@ def max_attempts(self) -> int:
525532 return self ._max_attempts
526533
527534 @max_attempts .setter
528- def max_attempts (self , value ) -> None :
535+ def max_attempts (self , value : int ) -> None :
529536 self ._max_attempts = value
530537 if value == 1 : # No retry
531538 self ._get = self ._http_session .get
@@ -547,7 +554,7 @@ def max_attempts(self, value) -> None:
547554 self ._post = with_retry (self ._http_session .post )
548555 self ._delete = with_retry (self ._http_session .delete )
549556
550- def get_url (self , path ) -> str :
557+ def get_url (self , path : str ) -> str :
551558 return "{protocol}://{host}:{port}{path}" .format (
552559 protocol = self ._http_scheme , host = self ._host , port = self ._port , path = path
553560 )
@@ -560,7 +567,7 @@ def statement_url(self) -> str:
560567 def next_uri (self ) -> Optional [str ]:
561568 return self ._next_uri
562569
563- def post (self , sql : str , additional_http_headers : Optional [Dict [str , Any ]] = None ):
570+ def post (self , sql : str , additional_http_headers : Optional [Dict [str , Any ]] = None ) -> Response :
564571 data = sql .encode ("utf-8" )
565572 # Deep copy of the http_headers dict since they may be modified for this
566573 # request by the provided additional_http_headers
@@ -578,18 +585,19 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
578585 )
579586 return http_response
580587
581- def get (self , url : str ):
588+ def get (self , url : str ) -> Response :
582589 return self ._get (
583590 url ,
584591 headers = self .http_headers ,
585592 timeout = self ._request_timeout ,
586593 proxies = PROXIES ,
587594 )
588595
589- def delete (self , url ) :
596+ def delete (self , url : str ) -> Response :
590597 return self ._delete (url , timeout = self ._request_timeout , proxies = PROXIES )
591598
592- def _process_error (self , error , query_id ):
599+ @staticmethod
600+ def _process_error (error , query_id : Optional [str ]) -> Union [TrinoExternalError , TrinoQueryError , TrinoUserError ]:
593601 error_type = error ["errorType" ]
594602 if error_type == "EXTERNAL" :
595603 raise exceptions .TrinoExternalError (error , query_id )
@@ -598,7 +606,8 @@ def _process_error(self, error, query_id):
598606
599607 return exceptions .TrinoQueryError (error , query_id )
600608
601- def raise_response_error (self , http_response ):
609+ @staticmethod
610+ def raise_response_error (http_response : Response ) -> None :
602611 if http_response .status_code == 502 :
603612 raise exceptions .Http502Error ("error 502: bad gateway" )
604613
@@ -615,7 +624,7 @@ def raise_response_error(self, http_response):
615624 )
616625 )
617626
618- def process (self , http_response ) -> TrinoStatus :
627+ def process (self , http_response : Response ) -> TrinoStatus :
619628 if not http_response .ok :
620629 self .raise_response_error (http_response )
621630
@@ -682,7 +691,8 @@ def process(self, http_response) -> TrinoStatus:
682691 columns = response .get ("columns" ),
683692 )
684693
685- def _verify_extra_credential (self , header ):
694+ @staticmethod
695+ def _verify_extra_credential (header : Tuple [str , str ]) -> None :
686696 """
687697 Verifies that key has ASCII only and non-whitespace characters.
688698 """
0 commit comments