@@ -276,7 +276,7 @@ async def do_p2p_handshake(self) -> None:
276276 # Peers sometimes send a disconnect msg before they send the initial P2P handshake.
277277 raise HandshakeFailure ("{} disconnected before completing handshake: {}" .format (
278278 self , msg ['reason_name' ]))
279- self .process_p2p_handshake (cmd , msg )
279+ await self .process_p2p_handshake (cmd , msg )
280280
281281 @property
282282 async def genesis (self ) -> BlockHeader :
@@ -393,16 +393,17 @@ def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> N
393393 else :
394394 self .handle_sub_proto_msg (cmd , msg )
395395
396- def process_p2p_handshake (self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
396+ async def process_p2p_handshake (
397+ self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
397398 msg = cast (Dict [str , Any ], msg )
398399 if not isinstance (cmd , Hello ):
399- self .disconnect (DisconnectReason .bad_protocol )
400+ await self .disconnect (DisconnectReason .bad_protocol )
400401 raise HandshakeFailure ("Expected a Hello msg, got {}, disconnecting" .format (cmd ))
401402 remote_capabilities = msg ['capabilities' ]
402403 try :
403404 self .sub_proto = self .select_sub_protocol (remote_capabilities )
404405 except NoMatchingPeerCapabilities :
405- self .disconnect (DisconnectReason .useless_peer )
406+ await self .disconnect (DisconnectReason .useless_peer )
406407 raise HandshakeFailure (
407408 "No matching capabilities between us ({}) and {} ({}), disconnecting" .format (
408409 self .capabilities , self .remote , remote_capabilities ))
@@ -474,9 +475,11 @@ def send(self, header: bytes, body: bytes) -> None:
474475 self .logger .trace ("Sending msg with cmd_id: %s" , cmd_id )
475476 self .writer .write (self .encrypt (header , body ))
476477
477- def disconnect (self , reason : DisconnectReason ) -> None :
478+ async def disconnect (self , reason : DisconnectReason ) -> None :
478479 """Send a disconnect msg to the remote node and stop this Peer.
479480
481+ Also awaits for self.cancel() to ensure any pending tasks are cleaned up.
482+
480483 :param reason: An item from the DisconnectReason enum.
481484 """
482485 if not isinstance (reason , DisconnectReason ):
@@ -485,6 +488,8 @@ def disconnect(self, reason: DisconnectReason) -> None:
485488 self .logger .debug ("Disconnecting from remote peer; reason: %s" , reason .name )
486489 self .base_protocol .send_disconnect (reason .value )
487490 self .close ()
491+ if self .is_running :
492+ await self .cancel ()
488493
489494 def select_sub_protocol (self , remote_capabilities : List [Tuple [bytes , int ]]
490495 ) -> protocol .Protocol :
@@ -537,18 +542,18 @@ async def send_sub_proto_handshake(self) -> None:
537542 async def process_sub_proto_handshake (
538543 self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
539544 if not isinstance (cmd , (les .Status , les .StatusV2 )):
540- self .disconnect (DisconnectReason .subprotocol_error )
545+ await self .disconnect (DisconnectReason .subprotocol_error )
541546 raise HandshakeFailure (
542547 "Expected a LES Status msg, got {}, disconnecting" .format (cmd ))
543548 msg = cast (Dict [str , Any ], msg )
544549 if msg ['networkId' ] != self .network_id :
545- self .disconnect (DisconnectReason .useless_peer )
550+ await self .disconnect (DisconnectReason .useless_peer )
546551 raise HandshakeFailure (
547552 "{} network ({}) does not match ours ({}), disconnecting" .format (
548553 self , msg ['networkId' ], self .network_id ))
549554 genesis = await self .genesis
550555 if msg ['genesisHash' ] != genesis .hash :
551- self .disconnect (DisconnectReason .useless_peer )
556+ await self .disconnect (DisconnectReason .useless_peer )
552557 raise HandshakeFailure (
553558 "{} genesis ({}) does not match ours ({}), disconnecting" .format (
554559 self , encode_hex (msg ['genesisHash' ]), genesis .hex_hash ))
@@ -628,18 +633,18 @@ async def send_sub_proto_handshake(self) -> None:
628633 async def process_sub_proto_handshake (
629634 self , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
630635 if not isinstance (cmd , eth .Status ):
631- self .disconnect (DisconnectReason .subprotocol_error )
636+ await self .disconnect (DisconnectReason .subprotocol_error )
632637 raise HandshakeFailure (
633638 "Expected a ETH Status msg, got {}, disconnecting" .format (cmd ))
634639 msg = cast (Dict [str , Any ], msg )
635640 if msg ['network_id' ] != self .network_id :
636- self .disconnect (DisconnectReason .useless_peer )
641+ await self .disconnect (DisconnectReason .useless_peer )
637642 raise HandshakeFailure (
638643 "{} network ({}) does not match ours ({}), disconnecting" .format (
639644 self , msg ['network_id' ], self .network_id ))
640645 genesis = await self .genesis
641646 if msg ['genesis_hash' ] != genesis .hash :
642- self .disconnect (DisconnectReason .useless_peer )
647+ await self .disconnect (DisconnectReason .useless_peer )
643648 raise HandshakeFailure (
644649 "{} genesis ({}) does not match ours ({}), disconnecting" .format (
645650 self , encode_hex (msg ['genesis_hash' ]), genesis .hex_hash ))
@@ -770,12 +775,8 @@ async def _run(self) -> None:
770775
771776 async def stop_all_peers (self ) -> None :
772777 self .logger .info ("Stopping all peers ..." )
773-
774778 peers = self .connected_nodes .values ()
775- for peer in peers :
776- peer .disconnect (DisconnectReason .client_quitting )
777-
778- await asyncio .gather (* [peer .cancel () for peer in peers ])
779+ await asyncio .gather (* [peer .disconnect (DisconnectReason .client_quitting ) for peer in peers ])
779780
780781 async def _cleanup (self ) -> None :
781782 await self .stop_all_peers ()
0 commit comments