11import asyncio
2+ from functools import (
3+ partial ,
4+ )
25from typing import (
36 Any ,
47 Callable ,
2427)
2528
2629from trie import HexaryTrie
30+ from trie .exceptions import BadTrieProof
2731
2832from evm .exceptions import (
2933 BlockNotFound ,
@@ -145,11 +149,26 @@ async def get_receipts(self, block_hash: Hash32) -> List[Receipt]:
145149
146150 @alru_cache (maxsize = 1024 , cache_exceptions = False )
147151 async def get_account (self , block_hash : Hash32 , address : Address ) -> Account :
148- peer = cast (LESPeer , self .peer_pool .highest_td_peer )
152+ return await self ._reattempt_on_bad_response (
153+ partial (self ._get_account_from_peer , block_hash , address )
154+ )
155+
156+ async def _get_account_from_peer (
157+ self ,
158+ block_hash : Hash32 ,
159+ address : Address ,
160+ peer : LESPeer ) -> Account :
149161 key = keccak (address )
150162 proof = await self ._get_proof (peer , block_hash , account_key = b'' , key = key )
151163 header = await self ._get_block_header_by_hash (peer , block_hash )
152- rlp_account = HexaryTrie .get_from_proof (header .state_root , key , proof )
164+ try :
165+ rlp_account = HexaryTrie .get_from_proof (header .state_root , key , proof )
166+ except BadTrieProof as exc :
167+ raise BadLESResponse ("Peer %s returned an invalid proof for account %s at block %s" % (
168+ peer ,
169+ encode_hex (address ),
170+ encode_hex (block_hash ),
171+ )) from exc
153172 return rlp .decode (rlp_account , sedes = Account )
154173
155174 @alru_cache (maxsize = 1024 , cache_exceptions = False )
@@ -173,23 +192,16 @@ async def get_contract_code(self, block_hash: Hash32, address: Address) -> bytes
173192
174193 code_hash = account .code_hash
175194
176- for _ in range (MAX_REQUEST_ATTEMPTS ):
177- peer = cast (LESPeer , self .peer_pool .highest_td_peer )
178- try :
179- return await self ._get_contract_code_from_peer (block_hash , address , peer , code_hash )
180- except BadLESResponse as exc :
181- self .logger .warn ("Disconnecting from peer, because: %s" , exc )
182- await self .disconnect_peer (peer , DisconnectReason .subprotocol_error )
183- # reattempt after removing this peer from our pool
184-
185- raise TimeoutError ("Could not get contract code within %d attempts" % MAX_REQUEST_ATTEMPTS )
195+ return await self ._reattempt_on_bad_response (
196+ partial (self ._get_contract_code_from_peer , block_hash , address , code_hash )
197+ )
186198
187199 async def _get_contract_code_from_peer (
188200 self ,
189201 block_hash : Hash32 ,
190202 address : Address ,
191- peer : LESPeer ,
192- code_hash : Hash32 ) -> bytes :
203+ code_hash : Hash32 ,
204+ peer : LESPeer ) -> bytes :
193205 """
194206 A single attempt to get the contract code from the given peer
195207
@@ -247,3 +259,15 @@ async def _get_proof(self,
247259 peer .sub_proto .send_get_proof (block_hash , account_key , key , from_level , request_id )
248260 reply = await self ._wait_for_reply (request_id )
249261 return reply ['proof' ]
262+
263+ async def _reattempt_on_bad_response (self , make_request_to_peer ):
264+ for _ in range (MAX_REQUEST_ATTEMPTS ):
265+ peer = cast (LESPeer , self .peer_pool .highest_td_peer )
266+ try :
267+ return await make_request_to_peer (peer )
268+ except BadLESResponse as exc :
269+ self .logger .warn ("Disconnecting from peer, because: %s" , exc )
270+ await self .disconnect_peer (peer , DisconnectReason .subprotocol_error )
271+ # reattempt after removing this peer from our pool
272+
273+ raise TimeoutError ("Could not complete peer request in %d attempts" % MAX_REQUEST_ATTEMPTS )
0 commit comments