99
1010import datetime
1111import gc
12+ import os
1213import pathlib
1314import select
1415import sys
2526)
2627from gc import collect , get_referrers
2728from os import makedirs
28- from os .path import join
2929from socket import (
3030 AF_INET ,
3131 AF_INET6 ,
124124 WantWriteError ,
125125 ZeroReturnError ,
126126 _make_requires ,
127+ _NoOverlappingProtocols ,
127128)
128129
129130from .test_crypto import (
@@ -166,25 +167,10 @@ def loopback_address(socket: socket) -> str:
166167 return "::1"
167168
168169
169- def join_bytes_or_unicode (prefix , suffix ):
170- """
171- Join two path components of either ``bytes`` or ``unicode``.
172-
173- The return type is the same as the type of ``prefix``.
174- """
175- # If the types are the same, nothing special is necessary.
176- if type (prefix ) is type (suffix ):
177- return join (prefix , suffix )
178-
179- # Otherwise, coerce suffix to the type of prefix.
180- if isinstance (prefix , str ):
181- return join (prefix , suffix .decode (getfilesystemencoding ()))
182- else :
183- return join (prefix , suffix .encode (getfilesystemencoding ()))
184-
185-
186- def verify_cb (conn , cert , errnum , depth , ok ):
187- return ok
170+ def verify_cb (
171+ conn : Connection , cert : X509 , errnum : int , depth : int , ok : int
172+ ) -> bool :
173+ return bool (ok )
188174
189175
190176def socket_pair () -> tuple [socket , socket ]:
@@ -360,7 +346,7 @@ def loopback(
360346
361347def interact_in_memory (
362348 client_conn : Connection , server_conn : Connection
363- ) -> tuple [Connection , bytes ]:
349+ ) -> tuple [Connection , bytes ] | None :
364350 """
365351 Try to read application bytes from each of the two `Connection` objects.
366352 Copy bytes back and forth between their send/receive buffers for as long
@@ -405,6 +391,8 @@ def interact_in_memory(
405391 wrote = True
406392 write .bio_write (dirty )
407393
394+ return None
395+
408396
409397def handshake_in_memory (
410398 client_conn : Connection , server_conn : Connection
@@ -1021,9 +1009,9 @@ def info(conn: Connection, where: int, ret: int) -> None:
10211009 for (conn , where , ret ) in called
10221010 if not isinstance (conn , Connection )
10231011 ]
1024- assert (
1025- [] == notConnections
1026- ), "Some info callback arguments were not Connection instances."
1012+ assert [] == notConnections , (
1013+ "Some info callback arguments were not Connection instances."
1014+ )
10271015
10281016 @pytest .mark .skipif (
10291017 not getattr (_lib , "Cryptography_HAS_KEYLOG" , None ),
@@ -1168,7 +1156,9 @@ def test_load_verify_invalid_file(self, tmpfile: bytes) -> None:
11681156 with pytest .raises (Error ):
11691157 clientContext .load_verify_locations (tmpfile )
11701158
1171- def _load_verify_directory_locations_capath (self , capath : bytes ) -> None :
1159+ def _load_verify_directory_locations_capath (
1160+ self , capath : str | bytes
1161+ ) -> None :
11721162 """
11731163 Verify that if path to a directory containing certificate files is
11741164 passed to ``Context.load_verify_locations`` for the ``capath``
@@ -1180,7 +1170,11 @@ def _load_verify_directory_locations_capath(self, capath: bytes) -> None:
11801170 # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other
11811171 # from OpenSSL 1.0.0.
11821172 for name in [b"c7adac82.0" , b"c3705638.0" ]:
1183- cafile = join_bytes_or_unicode (capath , name )
1173+ cafile : str | bytes
1174+ if isinstance (capath , str ):
1175+ cafile = os .path .join (capath , name .decode ())
1176+ else :
1177+ cafile = os .path .join (capath , name )
11841178 with open (cafile , "w" ) as fObj :
11851179 fObj .write (root_cert_pem .decode ("ascii" ))
11861180
@@ -1209,9 +1203,13 @@ def test_load_verify_directory_capath(
12091203 """
12101204 if pathtype == "unicode_path" :
12111205 tmpfile += NON_ASCII .encode (getfilesystemencoding ())
1206+
12121207 if argtype == "unicode_arg" :
1213- tmpfile = tmpfile .decode (getfilesystemencoding ())
1214- self ._load_verify_directory_locations_capath (tmpfile )
1208+ self ._load_verify_directory_locations_capath (
1209+ tmpfile .decode (getfilesystemencoding ())
1210+ )
1211+ else :
1212+ self ._load_verify_directory_locations_capath (tmpfile )
12151213
12161214 def test_load_verify_locations_wrong_args (self ) -> None :
12171215 """
@@ -1393,7 +1391,14 @@ def test_set_verify_callback_connection_argument(self) -> None:
13931391 serverConnection = Connection (serverContext , None )
13941392
13951393 class VerifyCallback :
1396- def callback (self , connection : Connection , * args ) -> bool :
1394+ def callback (
1395+ self ,
1396+ connection : Connection ,
1397+ cert : X509 ,
1398+ err : int ,
1399+ depth : int ,
1400+ ok : int ,
1401+ ) -> bool :
13971402 self .connection = connection
13981403 return True
13991404
@@ -1452,7 +1457,9 @@ def test_set_verify_callback_exception(self) -> None:
14521457
14531458 clientContext = Context (TLSv1_2_METHOD )
14541459
1455- def verify_callback (* args ):
1460+ def verify_callback (
1461+ conn : Connection , cert : X509 , err : int , depth : int , ok : int
1462+ ) -> bool :
14561463 raise Exception ("silly verify failure" )
14571464
14581465 clientContext .set_verify (VERIFY_PEER , verify_callback )
@@ -1482,7 +1489,7 @@ def test_set_verify_callback_reference(self) -> None:
14821489
14831490 for i in range (5 ):
14841491
1485- def verify_callback (* args ) :
1492+ def verify_callback (* args : object ) -> bool :
14861493 return True
14871494
14881495 serverSocket , clientSocket = socket_pair ()
@@ -1589,8 +1596,14 @@ def _use_certificate_chain_file_test(self, certdir: str | bytes) -> None:
15891596
15901597 makedirs (certdir )
15911598
1592- chainFile = join_bytes_or_unicode (certdir , "chain.pem" )
1593- caFile = join_bytes_or_unicode (certdir , "ca.pem" )
1599+ chainFile : str | bytes
1600+ caFile : str | bytes
1601+ if isinstance (certdir , str ):
1602+ chainFile = os .path .join (certdir , "chain.pem" )
1603+ caFile = os .path .join (certdir , "ca.pem" )
1604+ else :
1605+ chainFile = os .path .join (certdir , b"chain.pem" )
1606+ caFile = os .path .join (certdir , b"ca.pem" )
15941607
15951608 # Write out the chain file.
15961609 with open (chainFile , "wb" ) as fObj :
@@ -1848,9 +1861,9 @@ def replacement(connection: Connection) -> None: # pragma: no cover
18481861 collect ()
18491862 collect ()
18501863
1851- callback = tracker ()
1852- if callback is not None :
1853- referrers = get_referrers (callback )
1864+ callback_ref = tracker ()
1865+ if callback_ref is not None :
1866+ referrers = get_referrers (callback_ref )
18541867 assert len (referrers ) == 1
18551868
18561869 def test_no_servername (self ) -> None :
@@ -2064,7 +2077,9 @@ def test_alpn_no_server_overlap(self) -> None:
20642077 """
20652078 refusal_args = []
20662079
2067- def refusal (conn : Connection , options : list [bytes ]):
2080+ def refusal (
2081+ conn : Connection , options : list [bytes ]
2082+ ) -> _NoOverlappingProtocols :
20682083 refusal_args .append ((conn , options ))
20692084 return NO_OVERLAPPING_PROTOCOLS
20702085
@@ -2218,7 +2233,7 @@ def test_construction(self) -> None:
22182233
22192234
22202235@pytest .fixture (params = ["context" , "connection" ])
2221- def ctx_or_conn (request ) -> Context | Connection :
2236+ def ctx_or_conn (request : pytest . FixtureRequest ) -> Context | Connection :
22222237 ctx = Context (SSLv23_METHOD )
22232238 if request .param == "context" :
22242239 return ctx
@@ -2823,9 +2838,9 @@ def callback(
28232838 )
28242839 collect ()
28252840 collect ()
2826- callback = tracker ()
2827- if callback is not None : # pragma: nocover
2828- referrers = get_referrers (callback )
2841+ callback_ref = tracker ()
2842+ if callback_ref is not None : # pragma: nocover
2843+ referrers = get_referrers (callback_ref )
28292844 assert len (referrers ) == 1
28302845
28312846 def test_get_session_unconnected (self ) -> None :
@@ -3862,7 +3877,9 @@ def test_outgoing_overflow(self) -> None:
38623877 # meaningless.
38633878 assert sent < size
38643879
3865- receiver , received = interact_in_memory (client , server )
3880+ result = interact_in_memory (client , server )
3881+ assert result is not None
3882+ receiver , received = result
38663883 assert receiver is server
38673884
38683885 # We can rely on all of these bytes being received at once because
@@ -4249,7 +4266,7 @@ def test_callbacks_arent_called_by_default(self) -> None:
42494266 called.
42504267 """
42514268
4252- def ocsp_callback (* args , ** kwargs ) : # pragma: nocover
4269+ def ocsp_callback (* args : object ) -> typing . NoReturn : # pragma: nocover
42534270 pytest .fail ("Should not be called" )
42544271
42554272 client = self ._client_connection (
@@ -4284,7 +4301,7 @@ def test_client_receives_servers_data(self) -> None:
42844301 """
42854302 calls = []
42864303
4287- def server_callback (* args , ** kwargs ) :
4304+ def server_callback (* args : object , ** kwargs : object ) -> bytes :
42884305 return self .sample_ocsp_data
42894306
42904307 def client_callback (
@@ -4307,11 +4324,15 @@ def test_callbacks_are_invoked_with_connections(self) -> None:
43074324 client_calls = []
43084325 server_calls = []
43094326
4310- def client_callback (conn , * args , ** kwargs ):
4327+ def client_callback (
4328+ conn : Connection , * args : object , ** kwargs : object
4329+ ) -> bool :
43114330 client_calls .append (conn )
43124331 return True
43134332
4314- def server_callback (conn , * args , ** kwargs ):
4333+ def server_callback (
4334+ conn : Connection , * args : object , ** kwargs : object
4335+ ) -> bytes :
43154336 server_calls .append (conn )
43164337 return self .sample_ocsp_data
43174338
@@ -4331,11 +4352,11 @@ def test_opaque_data_is_passed_through(self) -> None:
43314352 """
43324353 calls = []
43334354
4334- def server_callback (* args ) :
4355+ def server_callback (* args : object ) -> bytes :
43354356 calls .append (args )
43364357 return self .sample_ocsp_data
43374358
4338- def client_callback (* args ) :
4359+ def client_callback (* args : object ) -> bool :
43394360 calls .append (args )
43404361 return True
43414362
@@ -4360,7 +4381,7 @@ def test_server_returns_empty_string(self) -> None:
43604381 """
43614382 client_calls = []
43624383
4363- def server_callback (* args ) :
4384+ def server_callback (* args : object ) -> bytes :
43644385 return b""
43654386
43664387 def client_callback (
@@ -4381,10 +4402,10 @@ def test_client_returns_false_terminates_handshake(self) -> None:
43814402 If the client returns False from its callback, the handshake fails.
43824403 """
43834404
4384- def server_callback (* args ) :
4405+ def server_callback (* args : object ) -> bytes :
43854406 return self .sample_ocsp_data
43864407
4387- def client_callback (* args ) :
4408+ def client_callback (* args : object ) -> bool :
43884409 return False
43894410
43904411 client = self ._client_connection (callback = client_callback , data = None )
@@ -4401,10 +4422,10 @@ def test_exceptions_in_client_bubble_up(self) -> None:
44014422 class SentinelException (Exception ):
44024423 pass
44034424
4404- def server_callback (* args ) :
4425+ def server_callback (* args : object ) -> bytes :
44054426 return self .sample_ocsp_data
44064427
4407- def client_callback (* args ) :
4428+ def client_callback (* args : object ) -> typing . NoReturn :
44084429 raise SentinelException ()
44094430
44104431 client = self ._client_connection (callback = client_callback , data = None )
@@ -4421,10 +4442,12 @@ def test_exceptions_in_server_bubble_up(self) -> None:
44214442 class SentinelException (Exception ):
44224443 pass
44234444
4424- def server_callback (* args ) :
4445+ def server_callback (* args : object ) -> typing . NoReturn :
44254446 raise SentinelException ()
44264447
4427- def client_callback (* args ): # pragma: nocover
4448+ def client_callback (
4449+ * args : object ,
4450+ ) -> typing .NoReturn : # pragma: nocover
44284451 pytest .fail ("Should not be called" )
44294452
44304453 client = self ._client_connection (callback = client_callback , data = None )
@@ -4438,14 +4461,16 @@ def test_server_must_return_bytes(self) -> None:
44384461 The server callback must return a bytestring, or a TypeError is thrown.
44394462 """
44404463
4441- def server_callback (* args ) :
4464+ def server_callback (* args : object ) -> str :
44424465 return self .sample_ocsp_data .decode ("ascii" )
44434466
4444- def client_callback (* args ): # pragma: nocover
4467+ def client_callback (
4468+ * args : object ,
4469+ ) -> typing .NoReturn : # pragma: nocover
44454470 pytest .fail ("Should not be called" )
44464471
44474472 client = self ._client_connection (callback = client_callback , data = None )
4448- server = self ._server_connection (callback = server_callback , data = None )
4473+ server = self ._server_connection (callback = server_callback , data = None ) # type: ignore[arg-type]
44494474
44504475 with pytest .raises (TypeError ):
44514476 handshake_in_memory (client , server )
0 commit comments