2222import struct
2323import sys
2424import time
25- from asyncio import BaseTransport , BufferedProtocol , Future , Transport
26- from dataclasses import dataclass
25+ from asyncio import BaseTransport , BufferedProtocol , Future , Protocol , Transport
2726from typing import (
2827 TYPE_CHECKING ,
2928 Any ,
@@ -251,7 +250,7 @@ def recv_into(self, buffer: bytes) -> int:
251250 return self .conn .recv_into (buffer )
252251
253252
254- class PyMongoBaseProtocol (BufferedProtocol ):
253+ class PyMongoBaseProtocol (Protocol ):
255254 def __init__ (self , timeout : Optional [float ] = None ):
256255 self .transport : Transport = None # type: ignore[assignment]
257256 self ._timeout = timeout
@@ -293,7 +292,7 @@ async def read(self, *args: Any) -> Any:
293292 raise NotImplementedError
294293
295294
296- class PyMongoProtocol (PyMongoBaseProtocol ):
295+ class PyMongoProtocol (PyMongoBaseProtocol , BufferedProtocol ):
297296 def __init__ (self , timeout : Optional [float ] = None ):
298297 super ().__init__ (timeout )
299298 # Each message is reader in 2-3 parts: header, compression header, and message body
@@ -477,17 +476,10 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
477476 self ._done_messages .append (msg )
478477
479478
480- @dataclass
481- class KMSBuffer :
482- buffer : memoryview
483- start_index : int
484- end_index : int
485-
486-
487479class PyMongoKMSProtocol (PyMongoBaseProtocol ):
488480 def __init__ (self , timeout : Optional [float ] = None ):
489481 super ().__init__ (timeout )
490- self ._buffers : collections .deque [KMSBuffer ] = collections .deque ()
482+ self ._buffers : collections .deque [memoryview [ bytes ] ] = collections .deque ()
491483 self ._bytes_ready = 0
492484 self ._pending_reads : collections .deque [int ] = collections .deque ()
493485 self ._pending_listeners : collections .deque [Future [Any ]] = collections .deque ()
@@ -498,6 +490,24 @@ def connection_made(self, transport: BaseTransport) -> None:
498490 """
499491 self .transport = transport # type: ignore[assignment]
500492
493+ def data_received (self , data : bytes ) -> None :
494+ if self ._connection_lost :
495+ return
496+
497+ self ._bytes_ready += len (data )
498+ self ._buffers .append (memoryview [data ])
499+
500+ if not len (self ._pending_reads ):
501+ return
502+
503+ bytes_needed = self ._pending_reads .popleft ()
504+ data = self ._read (bytes_needed )
505+ waiter = self ._pending_listeners .popleft ()
506+ waiter .set_result (data )
507+
508+ def eof_received (self ):
509+ self .close (OSError ("connection closed" ))
510+
501511 async def read (self , bytes_needed : int ) -> bytes :
502512 """Read up to the requested bytes from this connection."""
503513 # Note: all reads are "up-to" bytes_needed because we don't know if the kms_context
@@ -521,51 +531,13 @@ async def read(self, bytes_needed: int) -> bytes:
521531 self ._pending_listeners .append (read_waiter )
522532 return await read_waiter
523533
524- def get_buffer (self , sizehint : int ) -> memoryview :
525- """Called to allocate a new receive buffer.
526- The asyncio loop calls this method expecting to receive a non-empty buffer to fill with data.
527- If any data does not fit into the returned buffer, this method will be called again until
528- either no data remains or an empty buffer is returned.
529- """
530- # Reuse the active buffer if it has space.
531- # Allocate a bit more than the max response size for an AWS KMS response.
532- sizehint = max (sizehint , 16384 )
533- if len (self ._buffers ):
534- buffer = self ._buffers [- 1 ]
535- if len (buffer .buffer ) - buffer .end_index > sizehint :
536- return buffer .buffer [buffer .end_index :]
537- buffer = KMSBuffer (memoryview (bytearray (sizehint )), 0 , 0 )
538- self ._buffers .append (buffer )
539- return buffer .buffer
540-
541534 def _resolve_pending (self , exc : Optional [Exception ] = None ) -> None :
542535 while self ._pending_listeners :
543536 fut = self ._pending_listeners .popleft ()
544537 fut .set_result (b"" )
545538
546- def buffer_updated (self , nbytes : int ) -> None :
547- """Called when the buffer was updated with the received data"""
548- # Wrote 0 bytes into a non-empty buffer, signal connection closed
549- if nbytes == 0 :
550- self .close (OSError ("connection closed" ))
551- return
552- if self ._connection_lost :
553- return
554- self ._bytes_ready += nbytes
555-
556- # Update the length of the current buffer.
557- self ._buffers [- 1 ].end_index += nbytes
558-
559- if not len (self ._pending_reads ):
560- return
561-
562- bytes_needed = self ._pending_reads .popleft ()
563- data = self ._read (bytes_needed )
564- waiter = self ._pending_listeners .popleft ()
565- waiter .set_result (data )
566-
567539 def _read (self , bytes_needed : int ) -> memoryview :
568- """Read bytes from the buffer ."""
540+ """Read bytes."""
569541 # Send the bytes to the listener.
570542 if self ._bytes_ready < bytes_needed :
571543 bytes_needed = self ._bytes_ready
@@ -576,26 +548,17 @@ def _read(self, bytes_needed: int) -> memoryview:
576548 out_index = 0
577549 while n_remaining > 0 :
578550 buffer = self ._buffers .popleft ()
579- buffer_remaining = buffer . end_index - buffer . start_index
551+ buf_size = len ( buffer )
580552 # if we didn't exhaust the buffer, read the partial data and return the buffer.
581- if buffer_remaining > n_remaining :
582- output_buf [out_index : n_remaining + out_index ] = buffer .buffer [
583- buffer .start_index : buffer .start_index + n_remaining
584- ]
585- buffer .start_index += n_remaining
553+ if buf_size > n_remaining :
554+ output_buf [out_index : n_remaining + out_index ] = buffer [:n_remaining ]
586555 n_remaining = 0
587- self ._buffers .appendleft (buffer )
556+ self ._buffers .appendleft (buffer [ n_remaining :] )
588557 # otherwise exhaust the buffer.
589558 else :
590- output_buf [out_index : out_index + buffer_remaining ] = buffer .buffer [
591- buffer .start_index : buffer .end_index
592- ]
593- out_index += buffer_remaining
594- n_remaining -= buffer_remaining
595- # if this is the only buffer, add it back to the queue.
596- if not len (self ._buffers ):
597- buffer .start_index = buffer .end_index
598- self ._buffers .appendleft (buffer )
559+ output_buf [out_index : out_index + buf_size ] = buffer [:]
560+ out_index += buf_size
561+ n_remaining -= buf_size
599562 return memoryview (output_buf )
600563
601564
0 commit comments