@@ -144,14 +144,33 @@ async def get_account(self, block_hash: Hash32, address: Address) -> Account:
144144 return rlp .decode (rlp_account , sedes = Account )
145145
146146 @alru_cache (maxsize = 1024 , cache_exceptions = False )
147- async def get_contract_code (self , block_hash : Hash32 , key : bytes ) -> bytes :
147+ async def get_contract_code (self , block_hash : Hash32 , address : Address ) -> bytes :
148+ """
149+ :param block_hash: find code as of the block with block_hash
150+ :param address: which contract to look up
151+
152+ :return: bytecode of the contract, ``b''`` if no code is set
153+ """
148154 peer = cast (LESPeer , self .peer_pool .highest_td_peer )
149155 request_id = gen_request_id ()
150- peer .sub_proto .send_get_contract_code (block_hash , key , request_id )
156+ peer .sub_proto .send_get_contract_code (block_hash , keccak ( address ) , request_id )
151157 reply = await self ._wait_for_reply (request_id )
158+
152159 if not reply ['codes' ]:
153- return b''
154- return reply ['codes' ][0 ]
160+ bytecode = b''
161+ else :
162+ bytecode = reply ['codes' ][0 ]
163+
164+ # validate bytecode against a proven account
165+ account = await self .get_account (block_hash , address )
166+
167+ if account .code_hash == keccak (bytecode ):
168+ return bytecode
169+ else :
170+ # disconnect from this bad peer
171+ await self .disconnect_peer (peer , DisconnectReason .subprotocol_error )
172+ # try again with another peer
173+ return await self .get_contract_code (block_hash , address )
155174
156175 async def _get_block_header_by_hash (self , peer : LESPeer , block_hash : Hash32 ) -> BlockHeader :
157176 self .logger .debug ("Fetching header %s from %s" , encode_hex (block_hash ), peer )
0 commit comments