3232 MultiAccountId ,
3333)
3434from websockets .asyncio .client import connect
35- from websockets .exceptions import ConnectionClosed
35+ from websockets .exceptions import ConnectionClosed , WebSocketException
3636
3737from async_substrate_interface .const import SS58_FORMAT
3838from async_substrate_interface .errors import (
7575ResultHandler = Callable [[dict , Any ], Awaitable [tuple [dict , bool ]]]
7676
7777logger = logging .getLogger ("async_substrate_interface" )
78+ raw_websocket_logger = logging .getLogger ("raw_websocket" )
7879
7980
8081class AsyncExtrinsicReceipt :
@@ -505,6 +506,7 @@ def __init__(
505506 max_connections = 100 ,
506507 shutdown_timer = 5 ,
507508 options : Optional [dict ] = None ,
509+ _log_raw_websockets : bool = False ,
508510 ):
509511 """
510512 Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -532,6 +534,10 @@ def __init__(
532534 self ._exit_task = None
533535 self ._open_subscriptions = 0
534536 self ._options = options if options else {}
537+ self ._log_raw_websockets = _log_raw_websockets
538+ self ._is_connecting = False
539+ self ._is_closing = False
540+
535541 try :
536542 now = asyncio .get_running_loop ().time ()
537543 except RuntimeError :
@@ -556,38 +562,63 @@ async def __aenter__(self):
556562 async def loop_time () -> float :
557563 return asyncio .get_running_loop ().time ()
558564
565+ async def _cancel (self ):
566+ try :
567+ self ._receiving_task .cancel ()
568+ await self ._receiving_task
569+ await self .ws .close ()
570+ except (
571+ AttributeError ,
572+ asyncio .CancelledError ,
573+ WebSocketException ,
574+ ):
575+ pass
576+ except Exception as e :
577+ logger .warning (
578+ f"{ e } encountered while trying to close websocket connection."
579+ )
580+
559581 async def connect (self , force = False ):
560- now = await self .loop_time ()
561- self .last_received = now
562- self .last_sent = now
563- if self ._exit_task :
564- self ._exit_task .cancel ()
565- async with self ._lock :
566- if not self ._initialized or force :
567- try :
568- self ._receiving_task .cancel ()
569- await self ._receiving_task
570- await self .ws .close ()
571- except (AttributeError , asyncio .CancelledError ):
572- pass
573- self .ws = await asyncio .wait_for (
574- connect (self .ws_url , ** self ._options ), timeout = 10
575- )
576- self ._receiving_task = asyncio .create_task (self ._start_receiving ())
577- self ._initialized = True
582+ self ._is_connecting = True
583+ try :
584+ now = await self .loop_time ()
585+ self .last_received = now
586+ self .last_sent = now
587+ if self ._exit_task :
588+ self ._exit_task .cancel ()
589+ if not self ._is_closing :
590+ if not self ._initialized or force :
591+ try :
592+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
593+ except asyncio .TimeoutError :
594+ pass
595+
596+ self .ws = await asyncio .wait_for (
597+ connect (self .ws_url , ** self ._options ), timeout = 10.0
598+ )
599+ self ._receiving_task = asyncio .get_running_loop ().create_task (
600+ self ._start_receiving ()
601+ )
602+ self ._initialized = True
603+ finally :
604+ self ._is_connecting = False
578605
579606 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
580- async with self ._lock : # TODO is this actually what I want to happen?
581- self ._in_use -= 1
582- if self ._exit_task is not None :
583- self ._exit_task .cancel ()
584- try :
585- await self ._exit_task
586- except asyncio .CancelledError :
587- pass
588- if self ._in_use == 0 and self .ws is not None :
589- self ._open_subscriptions = 0
590- self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
607+ self ._is_closing = True
608+ try :
609+ if not self ._is_connecting :
610+ self ._in_use -= 1
611+ if self ._exit_task is not None :
612+ self ._exit_task .cancel ()
613+ try :
614+ await self ._exit_task
615+ except asyncio .CancelledError :
616+ pass
617+ if self ._in_use == 0 and self .ws is not None :
618+ self ._open_subscriptions = 0
619+ self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
620+ finally :
621+ self ._is_closing = False
591622
592623 async def _exit_with_timer (self ):
593624 """
@@ -601,26 +632,24 @@ async def _exit_with_timer(self):
601632 pass
602633
603634 async def shutdown (self ):
604- async with self ._lock :
605- try :
606- self ._receiving_task .cancel ()
607- await self ._receiving_task
608- await self .ws .close ()
609- except (AttributeError , asyncio .CancelledError ):
610- pass
611- self .ws = None
612- self ._initialized = False
613- self ._receiving_task = None
635+ self ._is_closing = True
636+ try :
637+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
638+ except asyncio .TimeoutError :
639+ pass
640+ self .ws = None
641+ self ._initialized = False
642+ self ._receiving_task = None
643+ self ._is_closing = False
614644
615645 async def _recv (self ) -> None :
616646 try :
617647 # TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic
618- response = json .loads (await self .ws .recv (decode = False ))
648+ recd = await self .ws .recv (decode = False )
649+ if self ._log_raw_websockets :
650+ raw_websocket_logger .debug (f"WEBSOCKET_RECEIVE> { recd .decode ()} " )
651+ response = json .loads (recd )
619652 self .last_received = await self .loop_time ()
620- async with self ._lock :
621- # note that these 'subscriptions' are all waiting sent messages which have not received
622- # responses, and thus are not the same as RPC 'subscriptions', which are unique
623- self ._open_subscriptions -= 1
624653 if "id" in response :
625654 self ._received [response ["id" ]] = response
626655 self ._in_use_ids .remove (response ["id" ])
@@ -640,8 +669,7 @@ async def _start_receiving(self):
640669 except asyncio .CancelledError :
641670 pass
642671 except ConnectionClosed :
643- async with self ._lock :
644- await self .connect (force = True )
672+ await self .connect (force = True )
645673
646674 async def send (self , payload : dict ) -> int :
647675 """
@@ -660,12 +688,14 @@ async def send(self, payload: dict) -> int:
660688 # self._open_subscriptions += 1
661689 await self .max_subscriptions .acquire ()
662690 try :
663- await self .ws .send (json .dumps ({** payload , ** {"id" : original_id }}))
691+ to_send = {** payload , ** {"id" : original_id }}
692+ if self ._log_raw_websockets :
693+ raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
694+ await self .ws .send (json .dumps (to_send ))
664695 self .last_sent = await self .loop_time ()
665696 return original_id
666697 except (ConnectionClosed , ssl .SSLError , EOFError ):
667- async with self ._lock :
668- await self .connect (force = True )
698+ await self .connect (force = True )
669699
670700 async def retrieve (self , item_id : int ) -> Optional [dict ]:
671701 """
@@ -699,6 +729,8 @@ def __init__(
699729 max_retries : int = 5 ,
700730 retry_timeout : float = 60.0 ,
701731 _mock : bool = False ,
732+ _log_raw_websockets : bool = False ,
733+ ws_shutdown_timer : float = 5.0 ,
702734 ):
703735 """
704736 The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
@@ -716,20 +748,25 @@ def __init__(
716748 max_retries: number of times to retry RPC requests before giving up
717749 retry_timeout: how to long wait since the last ping to retry the RPC request
718750 _mock: whether to use mock version of the subtensor interface
751+ _log_raw_websockets: whether to log raw websocket requests during RPC requests
752+ ws_shutdown_timer: how long after the last connection your websocket should close
719753
720754 """
721755 self .max_retries = max_retries
722756 self .retry_timeout = retry_timeout
723757 self .chain_endpoint = url
724758 self .url = url
725759 self ._chain = chain_name
760+ self ._log_raw_websockets = _log_raw_websockets
726761 if not _mock :
727762 self .ws = Websocket (
728763 url ,
764+ _log_raw_websockets = _log_raw_websockets ,
729765 options = {
730766 "max_size" : self .ws_max_size ,
731767 "write_limit" : 2 ** 16 ,
732768 },
769+ shutdown_timer = ws_shutdown_timer ,
733770 )
734771 else :
735772 self .ws = AsyncMock (spec = Websocket )
0 commit comments