@@ -526,6 +526,7 @@ def __init__(
526526 options : Optional [dict ] = None ,
527527 _log_raw_websockets : bool = False ,
528528 retry_timeout : float = 60.0 ,
529+ max_retries : int = 5 ,
529530 ):
530531 """
531532 Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -536,6 +537,10 @@ def __init__(
536537 max_subscriptions: Maximum number of subscriptions per websocket connection
537538 max_connections: Maximum number of connections total
538539 shutdown_timer: Number of seconds to shut down websocket connection after last use
540+ options: Options to pass to the websocket connection
541+ _log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
542+ retry_timeout: Timeout in seconds to retry websocket connection
543+ max_retries: Maximum number of retries following a timeout
539544 """
540545 # TODO allow setting max concurrent connections and rpc subscriptions per connection
541546 self .ws_url = ws_url
@@ -555,6 +560,7 @@ def __init__(
555560 self ._options = options if options else {}
556561 self ._log_raw_websockets = _log_raw_websockets
557562 self ._in_use_ids = set ()
563+ self ._max_retries = max_retries
558564
559565 @property
560566 def state (self ):
@@ -615,19 +621,28 @@ async def _handler(self, ws: ClientConnection) -> None:
615621 )
616622 loop = asyncio .get_running_loop ()
617623 should_reconnect = False
624+ is_retry = False
618625 for task in pending :
619626 task .cancel ()
620627 for task in done :
628+ task_res = task .result ()
621629 if isinstance (
622- task . result () , (asyncio .TimeoutError , ConnectionClosed , TimeoutError )
630+ task_res , (asyncio .TimeoutError , ConnectionClosed , TimeoutError )
623631 ):
624632 should_reconnect = True
633+ if isinstance (task_res , (asyncio .TimeoutError , TimeoutError )):
634+ self ._attempts += 1
635+ is_retry = True
625636 if should_reconnect is True :
626637 for original_id , payload in list (self ._inflight .items ()):
627638 self ._received [original_id ] = loop .create_future ()
628639 to_send = json .loads (payload )
629640 await self ._sending .put (to_send )
630- logger .info ("Timeout occurred. Reconnecting." )
641+ if is_retry :
642+ # Otherwise the connection was just closed due to no activity, which should not count against retries
643+ logger .info (
644+ f"Timeout occurred. Reconnecting. Attempt { self ._attempts } of { self ._max_retries } "
645+ )
631646 await self .connect (True )
632647 await self ._handler (ws = self .ws )
633648 elif isinstance (e := recv_task .result (), Exception ):
@@ -690,6 +705,8 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
690705 recd = await asyncio .wait_for (
691706 ws .recv (decode = False ), timeout = self .retry_timeout
692707 )
708+ # reset the counter once we successfully receive something back
709+ self ._attempts = 0
693710 await self ._recv (recd )
694711 except Exception as e :
695712 if isinstance (e , ssl .SSLError ):
@@ -873,6 +890,7 @@ def __init__(
873890 },
874891 shutdown_timer = ws_shutdown_timer ,
875892 retry_timeout = self .retry_timeout ,
893+ max_retries = max_retries ,
876894 )
877895 else :
878896 self .ws = AsyncMock (spec = Websocket )
0 commit comments