77 cast ,
88 Dict ,
99 List ,
10+ TYPE_CHECKING ,
1011)
1112
1213from cytoolz .itertoolz import partition_all
3839
3940from p2p import eth
4041from p2p import protocol
42+ from p2p .chain import lookup_headers
4143from p2p .cancel_token import CancelToken
4244from p2p .exceptions import OperationCancelled
4345from p2p .peer import ETHPeer , PeerPool , PeerPoolSubscriber
4446from p2p .service import BaseService
4547from p2p .utils import get_process_pool_executor
4648
4749
50+ if TYPE_CHECKING :
51+ from trinity .db .chain import AsyncChainDB # noqa: F401
52+
53+
4854class StateDownloader (BaseService , PeerPoolSubscriber ):
4955 _pending_nodes : Dict [Any , float ] = {}
5056 _total_processed_nodes = 0
@@ -54,11 +60,13 @@ class StateDownloader(BaseService, PeerPoolSubscriber):
5460 _total_timeouts = 0
5561
5662 def __init__ (self ,
63+ chaindb : 'AsyncChainDB' ,
5764 account_db : BaseDB ,
5865 root_hash : bytes ,
5966 peer_pool : PeerPool ,
6067 token : CancelToken = None ) -> None :
6168 super ().__init__ (token )
69+ self .chaindb = chaindb
6270 self .peer_pool = peer_pool
6371 self .root_hash = root_hash
6472 self .scheduler = StateSync (root_hash , account_db )
@@ -74,8 +82,7 @@ def idle_peers(self) -> List[ETHPeer]:
7482
7583 async def get_idle_peer (self ) -> ETHPeer :
7684 while not self .idle_peers :
77- self .logger .debug ("Waiting for an idle peer..." )
78- await self .wait_first (asyncio .sleep (0.02 ))
85+ await self .wait (asyncio .sleep (0.02 ))
7986 return secrets .choice (self .idle_peers )
8087
8188 async def _handle_msg_loop (self ) -> None :
@@ -93,10 +100,15 @@ async def _handle_msg_loop(self) -> None:
93100
94101 async def _handle_msg (
95102 self , peer : ETHPeer , cmd : protocol .Command , msg : protocol ._DecodedMsgType ) -> None :
96- loop = asyncio .get_event_loop ()
97- if isinstance (cmd , eth .NodeData ):
103+ # Throughout the whole state sync our chain head is fixed, so it makes sense to ignore
104+ # messages related to new blocks/transactions, but we must handle requests for data from
105+ # other peers or else they will disconnect from us.
106+ ignored_commands = (eth .Transactions , eth .NewBlock , eth .NewBlockHashes )
107+ if isinstance (cmd , ignored_commands ):
108+ pass
109+ elif isinstance (cmd , eth .NodeData ):
98110 self .logger .debug ("Got %d NodeData entries from %s" , len (msg ), peer )
99-
111+ loop = asyncio . get_event_loop ()
100112 # Check before we remove because sometimes a reply may come after our timeout and in
101113 # that case we won't be expecting it anymore.
102114 if peer in self ._peers_with_pending_requests :
@@ -113,9 +125,16 @@ async def _handle_msg(
113125 pass
114126 # A node may be received more than once, so pop() with a default value.
115127 self ._pending_nodes .pop (node_key , None )
128+ elif isinstance (cmd , eth .GetBlockHeaders ):
129+ await self ._handle_get_block_headers (peer , cast (Dict [str , Any ], msg ))
116130 else :
117- # We ignore everything that is not a NodeData when doing a StateSync.
118- self .logger .debug ("Ignoring %s msg while doing a StateSync" , cmd )
131+ self .logger .warn ("%s not handled during StateSync, must be implemented" , cmd )
132+
133+ async def _handle_get_block_headers (self , peer : ETHPeer , msg : Dict [str , Any ]) -> None :
134+ headers = await lookup_headers (
135+ self .chaindb , msg ['block_number_or_hash' ], msg ['max_headers' ],
136+ msg ['skip' ], msg ['reverse' ], self .logger , self .cancel_token )
137+ peer .sub_proto .send_block_headers (headers )
119138
120139 async def _cleanup (self ) -> None :
121140 # We don't need to cancel() anything, but we yield control just so that the coroutines we
@@ -261,7 +280,7 @@ def _test() -> None:
261280 asyncio .ensure_future (connect_to_peers_loop (peer_pool , nodes ))
262281
263282 head = chaindb .get_canonical_head ()
264- downloader = StateDownloader (db , head .state_root , peer_pool )
283+ downloader = StateDownloader (chaindb , db , head .state_root , peer_pool )
265284 loop = asyncio .get_event_loop ()
266285
267286 sigint_received = asyncio .Event ()
@@ -274,9 +293,14 @@ async def exit_on_sigint() -> None:
274293 await downloader .cancel ()
275294 loop .stop ()
276295
296+ async def run () -> None :
297+ await downloader .run ()
298+ downloader .logger .info ("run() finished, exiting" )
299+ sigint_received .set ()
300+
277301 loop .set_debug (True )
278302 asyncio .ensure_future (exit_on_sigint ())
279- asyncio .ensure_future (downloader . run ())
303+ asyncio .ensure_future (run ())
280304 loop .run_forever ()
281305 loop .close ()
282306
0 commit comments