@@ -308,7 +308,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
308308 return server
309309
310310
311- def loopback (server_factory = None , client_factory = None ):
311+ def loopback (server_factory = None , client_factory = None , blocking = True ):
312312 """
313313 Create a connected socket pair and force two connected SSL sockets
314314 to talk to each other via memory BIOs.
@@ -324,8 +324,8 @@ def loopback(server_factory=None, client_factory=None):
324324
325325 handshake (client , server )
326326
327- server .setblocking (True )
328- client .setblocking (True )
327+ server .setblocking (blocking )
328+ client .setblocking (blocking )
329329 return server , client
330330
331331
@@ -3131,11 +3131,131 @@ def test_memoryview_really_doesnt_overfill(self):
31313131 self ._doesnt_overfill_test (_make_memoryview )
31323132
31333133
3134+ @pytest .fixture
3135+ def nonblocking_tls_connections_pair ():
3136+ """Return a non-blocking TLS loopback connections pair."""
3137+ return loopback (blocking = False )
3138+
3139+
3140+ @pytest .fixture
3141+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3142+ """Return a non-blocking TLS server socket connected to loopback."""
3143+ return nonblocking_tls_connections_pair [0 ]
3144+
3145+
3146+ @pytest .fixture
3147+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3148+ """Return a non-blocking TLS client socket connected to loopback."""
3149+ return nonblocking_tls_connections_pair [1 ]
3150+
3151+
31343152class TestConnectionSendall (object ):
31353153 """
31363154 Tests for `Connection.sendall`.
31373155 """
31383156
3157+ def test_want_write (
3158+ self ,
3159+ monkeypatch ,
3160+ nonblocking_tls_server_connection ,
3161+ nonblocking_tls_client_connection ,
3162+ ):
3163+ msg = b"x"
3164+ garbage_size = 1024 * 1024 * 64
3165+ large_payload = b"p" * garbage_size * 2
3166+ payload_size = len (large_payload )
3167+
3168+ sent_garbage_size = 0
3169+ try :
3170+ sent_garbage_size += nonblocking_tls_client_connection .send (
3171+ msg * garbage_size ,
3172+ )
3173+ except WantWriteError :
3174+ pass
3175+ for i in range (garbage_size ):
3176+ try :
3177+ sent_garbage_size += nonblocking_tls_client_connection .send (
3178+ msg ,
3179+ )
3180+ except WantWriteError :
3181+ break
3182+ else :
3183+ pytest .fail (
3184+ "Failed to fill socket buffer, cannot test "
3185+ "'want write' in `sendall()`"
3186+ )
3187+ garbage_payload = sent_garbage_size * msg
3188+
3189+
3190+ def consume_garbage (conn ):
3191+ assert patched_ssl_write .want_write_counter >= 1
3192+ assert not consume_garbage .garbage_consumed
3193+
3194+ while len (consume_garbage .consumed ) < sent_garbage_size :
3195+ try :
3196+ consume_garbage .consumed += conn .recv (
3197+ sent_garbage_size - len (consume_garbage .consumed ),
3198+ )
3199+ except WantReadError :
3200+ pass
3201+
3202+ assert consume_garbage .consumed == garbage_payload
3203+
3204+ consume_garbage .garbage_consumed = True
3205+
3206+ consume_garbage .garbage_consumed = False
3207+ consume_garbage .consumed = b""
3208+
3209+ def consume_payload (conn ):
3210+ try :
3211+ consume_payload .consumed += conn .recv (payload_size )
3212+ except WantReadError :
3213+ pass
3214+ consume_payload .consumed = b""
3215+
3216+ original_ssl_write = _lib .SSL_write
3217+ def patched_ssl_write (ctx , data , size ):
3218+ write_result = original_ssl_write (ctx , data , size )
3219+ try :
3220+ nonblocking_tls_client_connection ._raise_ssl_error (
3221+ ctx , write_result ,
3222+ )
3223+ except WantWriteError :
3224+ patched_ssl_write .want_write_counter += 1
3225+ consume_data_on_server = (
3226+ consume_payload if consume_garbage .garbage_consumed
3227+ else consume_garbage
3228+ )
3229+
3230+ consume_data_on_server (nonblocking_tls_server_connection )
3231+ # NOTE: We don't re-raise it as the calling code will do
3232+ # NOTE: the same after the call.
3233+ return write_result
3234+
3235+ patched_ssl_write .want_write_counter = 0
3236+
3237+ # NOTE: Make the client think it needs a handshake so that it'll
3238+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3239+ # NOTE: that originates from `sendall()`:
3240+ nonblocking_tls_client_connection .set_connect_state ()
3241+ try :
3242+ nonblocking_tls_client_connection .do_handshake ()
3243+ except WantWriteError :
3244+ assert True # Sanity check
3245+ except :
3246+ assert False # This should never happen (see the note above)
3247+
3248+ with monkeypatch .context () as mp_ctx :
3249+ mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3250+ nonblocking_tls_client_connection .sendall (large_payload )
3251+
3252+ assert consume_garbage .garbage_consumed
3253+
3254+ # NOTE: Read the leftover data from the very last `SSL_write()`
3255+ consume_payload (nonblocking_tls_server_connection )
3256+
3257+ assert consume_payload .consumed == large_payload
3258+
31393259 def test_wrong_args (self ):
31403260 """
31413261 When called with arguments other than a string argument for its first
0 commit comments