2222import time as _time
2323from errno import EINTR as _EINTR
2424from ipaddress import ip_address as _ip_address
25+ from typing import TYPE_CHECKING , Any , Callable , List , Optional , TypeVar , Union
2526
2627from cryptography .x509 import load_der_x509_certificate as _load_der_x509_certificate
2728from OpenSSL import SSL as _SSL
3940from pymongo .socket_checker import _errno_from_exception
4041from pymongo .write_concern import validate_boolean
4142
43+ if TYPE_CHECKING :
44+ import socket
45+ from ssl import VerifyMode
46+
47+ from cryptography .x509 import Certificate
48+
49+ _T = TypeVar ("_T" )
50+
4251try :
4352 import certifi
4453
7382
7483# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
7584# not permitted for SNI hostname.
76- def _is_ip_address (address ) :
85+ def _is_ip_address (address : Any ) -> bool :
7786 try :
7887 _ip_address (address )
7988 return True
@@ -86,7 +95,7 @@ def _is_ip_address(address):
8695BLOCKING_IO_ERRORS = (_SSL .WantReadError , _SSL .WantWriteError , _SSL .WantX509LookupError )
8796
8897
89- def _ragged_eof (exc ) :
98+ def _ragged_eof (exc : BaseException ) -> bool :
9099 """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
91100 return exc .args == (- 1 , "Unexpected EOF" )
92101
@@ -95,12 +104,14 @@ def _ragged_eof(exc):
95104# https://github.com/pyca/pyopenssl/issues/176
96105# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
97106class _sslConn (_SSL .Connection ):
98- def __init__ (self , ctx , sock , suppress_ragged_eofs ):
107+ def __init__ (
108+ self , ctx : _SSL .Context , sock : Optional [socket .socket ], suppress_ragged_eofs : bool
109+ ):
99110 self .socket_checker = _SocketChecker ()
100111 self .suppress_ragged_eofs = suppress_ragged_eofs
101112 super ().__init__ (ctx , sock )
102113
103- def _call (self , call , * args , ** kwargs ) :
114+ def _call (self , call : Callable [..., _T ], * args : Any , ** kwargs : Any ) -> _T :
104115 timeout = self .gettimeout ()
105116 if timeout :
106117 start = _time .monotonic ()
@@ -127,10 +138,10 @@ def _call(self, call, *args, **kwargs):
127138 raise _socket .timeout ("timed out" )
128139 continue
129140
130- def do_handshake (self , * args , ** kwargs ) :
141+ def do_handshake (self , * args : Any , ** kwargs : Any ) -> None :
131142 return self ._call (super ().do_handshake , * args , ** kwargs )
132143
133- def recv (self , * args , ** kwargs ) :
144+ def recv (self , * args : Any , ** kwargs : Any ) -> bytes :
134145 try :
135146 return self ._call (super ().recv , * args , ** kwargs )
136147 except _SSL .SysCallError as exc :
@@ -139,7 +150,7 @@ def recv(self, *args, **kwargs):
139150 return b""
140151 raise
141152
142- def recv_into (self , * args , ** kwargs ) :
153+ def recv_into (self , * args : Any , ** kwargs : Any ) -> int :
143154 try :
144155 return self ._call (super ().recv_into , * args , ** kwargs )
145156 except _SSL .SysCallError as exc :
@@ -148,7 +159,7 @@ def recv_into(self, *args, **kwargs):
148159 return 0
149160 raise
150161
151- def sendall (self , buf , flags = 0 ):
162+ def sendall (self , buf : bytes , flags : int = 0 ) -> None : # type: ignore[override]
152163 view = memoryview (buf )
153164 total_length = len (buf )
154165 total_sent = 0
@@ -172,9 +183,9 @@ def sendall(self, buf, flags=0):
172183class _CallbackData :
173184 """Data class which is passed to the OCSP callback."""
174185
175- def __init__ (self ):
176- self .trusted_ca_certs = None
177- self .check_ocsp_endpoint = None
186+ def __init__ (self ) -> None :
187+ self .trusted_ca_certs : Optional [ List [ Certificate ]] = None
188+ self .check_ocsp_endpoint : Optional [ bool ] = None
178189 self .ocsp_response_cache = _OCSPCache ()
179190
180191
@@ -185,7 +196,7 @@ class SSLContext:
185196
186197 __slots__ = ("_protocol" , "_ctx" , "_callback_data" , "_check_hostname" )
187198
188- def __init__ (self , protocol ):
199+ def __init__ (self , protocol : int ):
189200 self ._protocol = protocol
190201 self ._ctx = _SSL .Context (self ._protocol )
191202 self ._callback_data = _CallbackData ()
@@ -198,66 +209,80 @@ def __init__(self, protocol):
198209 self ._ctx .set_ocsp_client_callback (callback = _ocsp_callback , data = self ._callback_data )
199210
200211 @property
201- def protocol (self ):
212+ def protocol (self ) -> int :
202213 """The protocol version chosen when constructing the context.
203214 This attribute is read-only.
204215 """
205216 return self ._protocol
206217
207- def __get_verify_mode (self ):
218+ def __get_verify_mode (self ) -> VerifyMode :
208219 """Whether to try to verify other peers' certificates and how to
209220 behave if verification fails. This attribute must be one of
210221 ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
211222 """
212223 return _REVERSE_VERIFY_MAP [self ._ctx .get_verify_mode ()]
213224
214- def __set_verify_mode (self , value ) :
225+ def __set_verify_mode (self , value : VerifyMode ) -> None :
215226 """Setter for verify_mode."""
216227
217- def _cb (connobj , x509obj , errnum , errdepth , retcode ):
228+ def _cb (
229+ connobj : _SSL .Connection ,
230+ x509obj : _crypto .X509 ,
231+ errnum : int ,
232+ errdepth : int ,
233+ retcode : int ,
234+ ) -> bool :
218235 # It seems we don't need to do anything here. Twisted doesn't,
219236 # and OpenSSL's SSL_CTX_set_verify let's you pass NULL
220237 # for the callback option. It's weird that PyOpenSSL requires
221238 # this.
222- return retcode
239+ # This is optional in pyopenssl >= 20 and can be removed once minimum
240+ # supported version is bumped
241+ # See: pyopenssl.org/en/latest/changelog.html#id47
242+ return bool (retcode )
223243
224244 self ._ctx .set_verify (_VERIFY_MAP [value ], _cb )
225245
226246 verify_mode = property (__get_verify_mode , __set_verify_mode )
227247
228- def __get_check_hostname (self ):
248+ def __get_check_hostname (self ) -> bool :
229249 return self ._check_hostname
230250
231- def __set_check_hostname (self , value ) :
251+ def __set_check_hostname (self , value : Any ) -> None :
232252 validate_boolean ("check_hostname" , value )
233253 self ._check_hostname = value
234254
235255 check_hostname = property (__get_check_hostname , __set_check_hostname )
236256
237- def __get_check_ocsp_endpoint (self ):
257+ def __get_check_ocsp_endpoint (self ) -> Optional [ bool ] :
238258 return self ._callback_data .check_ocsp_endpoint
239259
240- def __set_check_ocsp_endpoint (self , value ) :
260+ def __set_check_ocsp_endpoint (self , value : bool ) -> None :
241261 validate_boolean ("check_ocsp" , value )
242262 self ._callback_data .check_ocsp_endpoint = value
243263
244264 check_ocsp_endpoint = property (__get_check_ocsp_endpoint , __set_check_ocsp_endpoint )
245265
246- def __get_options (self ):
266+ def __get_options (self ) -> None :
247267 # Calling set_options adds the option to the existing bitmask and
248268 # returns the new bitmask.
249269 # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
250270 return self ._ctx .set_options (0 )
251271
252- def __set_options (self , value ) :
272+ def __set_options (self , value : int ) -> None :
253273 # Explcitly convert to int, since newer CPython versions
254274 # use enum.IntFlag for options. The values are the same
255275 # regardless of implementation.
256276 self ._ctx .set_options (int (value ))
257277
258278 options = property (__get_options , __set_options )
259279
260- def load_cert_chain (self , certfile , keyfile = None , password = None ):
280+ def load_cert_chain (
281+ self ,
282+ certfile : Union [str , bytes ],
283+ keyfile : Union [str , bytes , None ] = None ,
284+ password : Optional [str ] = None ,
285+ ) -> None :
261286 """Load a private key and the corresponding certificate. The certfile
262287 string must be the path to a single file in PEM format containing the
263288 certificate as well as any number of CA certificates needed to
@@ -270,28 +295,32 @@ def load_cert_chain(self, certfile, keyfile=None, password=None):
270295 # Password callback MUST be set first or it will be ignored.
271296 if password :
272297
273- def _pwcb (max_length , prompt_twice , user_data ) :
298+ def _pwcb (max_length : int , prompt_twice : bool , user_data : bytes ) -> bytes :
274299 # XXX:We could check the password length against what OpenSSL
275300 # tells us is the max, but we can't raise an exception, so...
276301 # warn?
302+ assert password is not None
277303 return password .encode ("utf-8" )
278304
279305 self ._ctx .set_passwd_cb (_pwcb )
280306 self ._ctx .use_certificate_chain_file (certfile )
281307 self ._ctx .use_privatekey_file (keyfile or certfile )
282308 self ._ctx .check_privatekey ()
283309
284- def load_verify_locations (self , cafile = None , capath = None ):
310+ def load_verify_locations (
311+ self , cafile : Optional [str ] = None , capath : Optional [str ] = None
312+ ) -> None :
285313 """Load a set of "certification authority"(CA) certificates used to
286314 validate other peers' certificates when `~verify_mode` is other than
287315 ssl.CERT_NONE.
288316 """
289317 self ._ctx .load_verify_locations (cafile , capath )
290318 # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
291319 if not hasattr (_SSL .Connection , "get_verified_chain" ):
320+ assert cafile is not None
292321 self ._callback_data .trusted_ca_certs = _load_trusted_ca_certs (cafile )
293322
294- def _load_certifi (self ):
323+ def _load_certifi (self ) -> None :
295324 """Attempt to load CA certs from certifi."""
296325 if _HAVE_CERTIFI :
297326 self .load_verify_locations (certifi .where ())
@@ -303,7 +332,7 @@ def _load_certifi(self):
303332 "the tlsCAFile option"
304333 )
305334
306- def _load_wincerts (self , store ) :
335+ def _load_wincerts (self , store : str ) -> None :
307336 """Attempt to load CA certs from Windows trust store."""
308337 cert_store = self ._ctx .get_cert_store ()
309338 oid = _stdlibssl .Purpose .SERVER_AUTH .oid
@@ -314,7 +343,7 @@ def _load_wincerts(self, store):
314343 _crypto .X509 .from_cryptography (_load_der_x509_certificate (cert ))
315344 )
316345
317- def load_default_certs (self ):
346+ def load_default_certs (self ) -> None :
318347 """A PyOpenSSL version of load_default_certs from CPython."""
319348 # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
320349 # incapable on macOS.
@@ -330,7 +359,7 @@ def load_default_certs(self):
330359 self ._load_certifi ()
331360 self ._ctx .set_default_verify_paths ()
332361
333- def set_default_verify_paths (self ):
362+ def set_default_verify_paths (self ) -> None :
334363 """Specify that the platform provided CA certificates are to be used
335364 for verification purposes.
336365 """
@@ -340,13 +369,13 @@ def set_default_verify_paths(self):
340369
341370 def wrap_socket (
342371 self ,
343- sock ,
344- server_side = False ,
345- do_handshake_on_connect = True ,
346- suppress_ragged_eofs = True ,
347- server_hostname = None ,
348- session = None ,
349- ):
372+ sock : socket . socket ,
373+ server_side : bool = False ,
374+ do_handshake_on_connect : bool = True ,
375+ suppress_ragged_eofs : bool = True ,
376+ server_hostname : Optional [ str ] = None ,
377+ session : Optional [ _SSL . Session ] = None ,
378+ ) -> _sslConn :
350379 """Wrap an existing Python socket connection and return a TLS socket
351380 object.
352381 """
0 commit comments