1212)
1313from pydantic import StrictStr
1414from websockets .client import WebSocketClientProtocol , connect
15+ from websockets .exceptions import (
16+ ConnectionClosedError ,
17+ WebSocketException ,
18+ )
1519
1620from api .models .error import Error
1721from api .models .logs import Stdout , Stderr
2731
2832logger = logging .getLogger (__name__ )
2933
# Number of reconnect-and-resend attempts made after the initial attempt to
# send an execution request fails (the send loop runs 1 + this many times).
34+ MAX_RECONNECT_RETRIES = 3
# Passed as `ping_timeout` when opening the WebSocket connection.
# NOTE(review): the websockets library documents this value as seconds.
35+ PING_TIMEOUT = 30
36+
3037
3138class Execution :
3239 def __init__ (self , in_background : bool = False ):
@@ -61,6 +68,15 @@ def __init__(self, context_id: str, session_id: str, language: str, cwd: str):
6168 self ._executions : Dict [str , Execution ] = {}
6269 self ._lock = asyncio .Lock ()
6370
async def reconnect(self):
    """
    Drop the current WebSocket connection, if any, and open a new one.

    The existing socket (``self._ws``) is closed with the reason
    "Reconnecting", the existing receive task (``self._receive_task``) is
    awaited so that it has finished before a new connection is made, and
    finally ``self.connect()`` is awaited.

    Tearing down the old connection is best-effort: an exception raised
    while closing the socket or while awaiting the receive task is logged
    as a warning and does not prevent the new connection attempt. Errors
    raised by ``self.connect()`` are propagated to the caller, and
    cancellation is never intercepted.
    """
    if self._ws is not None:
        try:
            await self._ws.close(reason="Reconnecting")
        except Exception as e:
            # The old connection is being discarded anyway; a failure to
            # close it cleanly must not block establishing the new one.
            logger.warning(
                f"Failed to close the WebSocket before reconnecting: {str(e)}"
            )

    if self._receive_task is not None:
        try:
            await self._receive_task
        except Exception as e:
            # Awaiting a task that ended with an error re-raises that error
            # here; it belongs to the old connection, so only record it.
            logger.warning(
                f"Receive task ended with an error before reconnecting: {str(e)}"
            )

    await self.connect()
79+
6480 async def connect (self ):
6581 logger .debug (f"WebSocket connecting to { self .url } " )
6682
@@ -69,6 +85,7 @@ async def connect(self):
6985
7086 self ._ws = await connect (
7187 self .url ,
88+ ping_timeout = PING_TIMEOUT ,
7289 max_size = None ,
7390 max_queue = None ,
7491 logger = ws_logger ,
@@ -274,9 +291,6 @@ async def execute(
274291 env_vars : Dict [StrictStr , str ],
275292 access_token : str ,
276293 ):
277- message_id = str (uuid .uuid4 ())
278- self ._executions [message_id ] = Execution ()
279-
280294 if self ._ws is None :
281295 raise Exception ("WebSocket not connected" )
282296
@@ -313,13 +327,40 @@ async def execute(
313327 )
314328 complete_code = f"{ indented_env_code } \n { complete_code } "
315329
316- logger .info (
317- f"Sending code for the execution ({ message_id } ): { complete_code } "
318- )
319- request = self ._get_execute_request (message_id , complete_code , False )
330+ message_id = str (uuid .uuid4 ())
331+ execution = Execution ()
332+ self ._executions [message_id ] = execution
320333
321334 # Send the code for execution
322- await self ._ws .send (request )
335+ # Initial request and retries
336+ for i in range (1 + MAX_RECONNECT_RETRIES ):
337+ try :
338+ logger .info (
339+ f"Sending code for the execution ({ message_id } ): { complete_code } "
340+ )
341+ request = self ._get_execute_request (
342+ message_id , complete_code , False
343+ )
344+ await self ._ws .send (request )
345+ break
346+ except (ConnectionClosedError , WebSocketException ) as e :
347+ # Keep the last result, even if error
348+ if i < MAX_RECONNECT_RETRIES :
349+ logger .warning (
350+ f"WebSocket connection lost while sending execution request, { i + 1 } . reconnecting...: { str (e )} "
351+ )
352+ await self .reconnect ()
353+ else :
354+ # The retry didn't help, request wasn't sent successfully
355+ logger .error ("Failed to send execution request" )
356+ await execution .queue .put (
357+ Error (
358+ name = "WebSocketError" ,
359+ value = "Failed to send execution request" ,
360+ traceback = "" ,
361+ )
362+ )
363+ await execution .queue .put (UnexpectedEndOfExecution ())
323364
324365 # Stream the results
325366 async for item in self ._wait_for_result (message_id ):
@@ -343,6 +384,18 @@ async def _receive_message(self):
343384 await self ._process_message (json .loads (message ))
344385 except Exception as e :
345386 logger .error (f"WebSocket received error while receiving messages: { str (e )} " )
387+ finally :
388+ # To prevent an infinite hang, we need to cancel all ongoing executions, as we could lose results during the reconnect
389+ # Thanks to the locking, there can be either no ongoing execution or just one.
390+ for key , execution in self ._executions .items ():
391+ await execution .queue .put (
392+ Error (
393+ name = "WebSocketError" ,
394+ value = "The connections was lost, rerun the code to get the results" ,
395+ traceback = "" ,
396+ )
397+ )
398+ await execution .queue .put (UnexpectedEndOfExecution ())
346399
347400 async def _process_message (self , data : dict ):
348401 """
0 commit comments