@@ -501,8 +501,13 @@ def connection_made(self, transport: BaseTransport) -> None:
501501 """
502502 self .transport = transport # type: ignore[assignment]
503503
504- async def read (self , bytes_needed : int ) -> bytes :
504+ async def read (self , bytes_needed : int , first = False ) -> bytes :
505505 """Read the requested bytes from this connection."""
506+ if self ._bytes_ready >= bytes_needed or (self ._bytes_ready > 0 and first ):
507+ # Wait for other listeners first.
508+ if len (self ._pending_listeners ):
509+ await asyncio .gather (* self ._pending_listeners )
510+ return self ._read (bytes_needed )
506511 if self .transport :
507512 try :
508513 self .transport .resume_reading ()
@@ -511,9 +516,7 @@ async def read(self, bytes_needed: int) -> bytes:
511516 raise OSError ("connection is already closed" ) from None
512517 if self .transport and self .transport .is_closing ():
513518 raise OSError ("connection is already closed" )
514- if self ._bytes_ready >= bytes_needed :
515- return self ._read (bytes_needed )
516- self ._pending_reads .append (bytes_needed )
519+ self ._pending_reads .append ((bytes_needed , first ))
517520 read_waiter = asyncio .get_running_loop ().create_future ()
518521 self ._pending_listeners .append (read_waiter )
519522 return await read_waiter
@@ -543,18 +546,22 @@ def buffer_updated(self, nbytes: int) -> None:
543546
544547 # Bail we don't have the current requested number of bytes.
545548 bytes_needed = self ._bytes_requested
549+ first = False
546550 if bytes_needed == 0 and self ._pending_reads :
547- bytes_needed = self ._pending_reads .popleft ()
548- if bytes_needed == 0 or self ._bytes_ready < bytes_needed :
551+ bytes_needed , first = self ._pending_reads .popleft ()
552+ read_first = first and self ._bytes_ready > 0
553+ if not read_first and (bytes_needed == 0 or self ._bytes_ready < bytes_needed ):
549554 return
550555
551- data = self ._read (bytes_needed )
556+ data = self ._read (bytes_needed , first )
552557 waiter = self ._pending_listeners .popleft ()
553558 waiter .set_result (data )
554559
555- def _read (self , bytes_needed ):
560+ def _read (self , bytes_needed , first = False ):
556561 """Read bytes from the buffer."""
557562 # Send the bytes to the listener.
563+ if first and self ._bytes_ready < bytes_needed :
564+ bytes_needed = self ._bytes_ready
558565 self ._bytes_ready -= bytes_needed
559566 self ._bytes_requested = 0
560567
@@ -591,13 +598,13 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
591598 raise socket .timeout ("timed out" ) from exc
592599
593600
594- async def async_receive_kms (conn : AsyncBaseConnection , bytes_needed : int ) -> bytes :
601+ async def async_receive_kms (conn : AsyncBaseConnection , bytes_needed : int , first = False ) -> bytes :
595602 """Receive raw bytes from the kms connection."""
596603
597604 def callback (result : Any ) -> bytes :
598605 return result
599606
600- return await _async_receive_data (conn , callback , bytes_needed )
607+ return await _async_receive_data (conn , callback , bytes_needed , first )
601608
602609
603610async def _async_receive_data (
0 commit comments