@@ -524,6 +524,7 @@ def __init__(
524524 shutdown_timer = 5 ,
525525 options : Optional [dict ] = None ,
526526 _log_raw_websockets : bool = False ,
527+ retry_timeout : float = 60.0
527528 ):
528529 """
529530 Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -542,10 +543,12 @@ def __init__(
542543 self .max_subscriptions = asyncio .Semaphore (max_subscriptions )
543544 self .max_connections = max_connections
544545 self .shutdown_timer = shutdown_timer
546+ self .retry_timeout = retry_timeout
545547 self ._received : dict [str , asyncio .Future ] = {}
546548 self ._received_subscriptions : dict [str , asyncio .Queue ] = {}
547549 self ._sending = asyncio .Queue ()
548- self ._receiving_task = None # TODO rename, as this now does send/recv
550+ self ._send_recv_task = None
551+ self ._inflight : dict [str , str ] = {}
549552 self ._attempts = 0
550553 self ._initialized = False # TODO remove
551554 self ._lock = asyncio .Lock ()
@@ -586,8 +589,8 @@ async def loop_time() -> float:
586589
587590 async def _cancel (self ):
588591 try :
589- self ._receiving_task .cancel ()
590- await self ._receiving_task
592+ self ._send_recv_task .cancel ()
593+ await self ._send_recv_task
591594 await self .ws .close ()
592595 except (
593596 AttributeError ,
@@ -601,13 +604,14 @@ async def _cancel(self):
601604 )
602605
603606 async def connect (self , force = False ):
607+ # TODO after connecting, move from _inflight to the queue
604608 now = await self .loop_time ()
605609 self .last_received = now
606610 self .last_sent = now
607611 async with self ._lock :
608612 if self ._exit_task :
609613 self ._exit_task .cancel ()
610- if self .state not in (State .OPEN , State .CONNECTING ):
614+ if self .state not in (State .OPEN , State .CONNECTING ) or force :
611615 if not self ._initialized or force :
612616 try :
613617 await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
@@ -616,21 +620,34 @@ async def connect(self, force=False):
616620 self .ws = await asyncio .wait_for (
617621 connect (self .ws_url , ** self ._options ), timeout = 10.0
618622 )
619- if self ._receiving_task is None or self ._receiving_task .done ():
620- self ._receiving_task = asyncio .get_running_loop ().create_task (
623+ if self ._send_recv_task is None or self ._send_recv_task .done ():
624+ self ._send_recv_task = asyncio .get_running_loop ().create_task (
621625 self ._handler (self .ws )
622626 )
623627 self ._initialized = True
624628
625- async def _handler (self , ws : ClientConnection ):
626- consumer_task = asyncio .create_task (self ._start_receiving (ws ))
627- producer_task = asyncio .create_task (self ._start_sending (ws ))
629+ async def _handler (self , ws : ClientConnection ) -> None :
630+ recv_task = asyncio .create_task (self ._start_receiving (ws ))
631+ send_task = asyncio .create_task (self ._start_sending (ws ))
628632 done , pending = await asyncio .wait (
629- [consumer_task , producer_task ],
633+ [recv_task , send_task ],
630634 return_when = asyncio .FIRST_COMPLETED ,
631635 )
636+ loop = asyncio .get_running_loop ()
637+ should_reconnect = False
632638 for task in pending :
633639 task .cancel ()
640+ if isinstance (task .exception (), asyncio .TimeoutError ):
641+ should_reconnect = True
642+ if should_reconnect is True :
643+ for original_id , payload in list (self ._inflight .items ()):
644+ self ._received [original_id ] = loop .create_future ()
645+ to_send = json .loads (payload )
646+ await self ._sending .put (to_send )
647+ logger .info ("Timeout occurred. Reconnecting." )
648+ await self .connect (True )
649+ await self ._handler (ws = ws )
650+
634651
635652 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
636653 if not self .state != State .CONNECTING :
@@ -662,7 +679,7 @@ async def shutdown(self):
662679 pass
663680 self .ws = None
664681 self ._initialized = False
665- self ._receiving_task = None
682+ self ._send_recv_task = None
666683 self ._is_closing = False
667684
668685 async def _recv (self , recd : bytes ) -> None :
@@ -671,9 +688,12 @@ async def _recv(self, recd: bytes) -> None:
671688 response = json .loads (recd )
672689 self .last_received = await self .loop_time ()
673690 if "id" in response :
691+ async with self ._lock :
692+ self ._inflight .pop (response ["id" ])
674693 self ._received [response ["id" ]].set_result (response )
675694 self ._in_use_ids .remove (response ["id" ])
676695 elif "params" in response :
696+ # TODO self._inflight won't work with subscriptions
677697 sub_id = response ["params" ]["subscription" ]
678698 await self ._received_subscriptions [sub_id ].put (response )
679699 else :
@@ -682,7 +702,9 @@ async def _recv(self, recd: bytes) -> None:
682702 async def _start_receiving (self , ws : ClientConnection ) -> Exception :
683703 try :
684704 while True :
685- await self ._recv (await ws .recv (decode = False ))
705+ if self ._inflight :
706+ recd = await asyncio .wait_for (ws .recv (decode = False ), timeout = self .retry_timeout )
707+ await self ._recv (recd )
686708 except Exception as e :
687709 if isinstance (e , ssl .SSLError ):
688710 e = ConnectionClosed
@@ -696,13 +718,14 @@ async def _start_sending(self, ws) -> Exception:
696718 to_send = None
697719 try :
698720 while True :
699- # TODO possibly when these are pulled from the Queue, they should also go into a dict or set, with the
700- # TODO done_callback assigned to remove them when complete. This could allow easier resending in cases
701- # TODO such as a timeout.
702- to_send = await self ._sending .get ()
721+ to_send_ = await self ._sending .get ()
722+ send_id = to_send_ ["id" ]
723+ to_send = json .dumps (to_send_ )
724+ async with self ._lock :
725+ self ._inflight [send_id ] = to_send
703726 if self ._log_raw_websockets :
704727 raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
705- await ws .send (json . dumps ( to_send ) )
728+ await ws .send (to_send )
706729 self .last_sent = await self .loop_time ()
707730 except Exception as e :
708731 if to_send is not None :
@@ -824,6 +847,7 @@ def __init__(
824847 "write_limit" : 2 ** 16 ,
825848 },
826849 shutdown_timer = ws_shutdown_timer ,
850+ retry_timeout = self .retry_timeout ,
827851 )
828852 else :
829853 self .ws = AsyncMock (spec = Websocket )
0 commit comments