1212
1313import shortuuid
1414from pydantic import BaseModel
15+ from starlette .websockets import WebSocketState
1516
1617from .core import OpenInterpreter
1718
@@ -387,12 +388,14 @@ async def home():
387388 async def websocket_endpoint (websocket : WebSocket ):
388389 await websocket .accept ()
389390
390- try :
391+ try : # solving it ;)/ # killian super wrote this
391392
392393 async def receive_input ():
393394 authenticated = False
394395 while True :
395396 try :
397+ if websocket .client_state != WebSocketState .CONNECTED :
398+ return
396399 data = await websocket .receive ()
397400
398401 if not authenticated :
@@ -425,7 +428,7 @@ async def receive_input():
425428 data = data ["bytes" ]
426429 await async_interpreter .input (data )
427430 elif data .get ("type" ) == "websocket.disconnect" :
428- print ("Disconnecting ." )
431+ print ("Client wants to disconnect, that's fine. ." )
429432 return
430433 else :
431434 print ("Invalid data:" , data )
@@ -446,6 +449,8 @@ async def receive_input():
446449
447450 async def send_output ():
448451 while True :
452+ if websocket .client_state != WebSocketState .CONNECTED :
453+ return
449454 try :
450455 # First, try to send any unsent messages
451456 while async_interpreter .unsent_messages :
@@ -488,9 +493,12 @@ async def send_message(output):
488493 ):
489494 output ["id" ] = id
490495
491- for attempt in range (100 ):
492- if websocket .client_state == 3 : # 3 represents 'CLOSED' state
496+ for attempt in range (20 ):
497+ # time.sleep(0.5)
498+
499+ if websocket .client_state != WebSocketState .CONNECTED :
493500 break
501+
494502 try :
495503 if isinstance (output , bytes ):
496504 await websocket .send_bytes (output )
@@ -501,7 +509,7 @@ async def send_message(output):
501509
502510 if async_interpreter .require_acknowledge :
503511 acknowledged = False
504- for _ in range (1000 ):
512+ for _ in range (100 ):
505513 if id in async_interpreter .acknowledged_outputs :
506514 async_interpreter .acknowledged_outputs .remove (id )
507515 acknowledged = True
@@ -523,10 +531,13 @@ async def send_message(output):
523531 await asyncio .sleep (0.05 )
524532
525533 # If we've reached this point, we've failed to send after 100 attempts
526- async_interpreter .unsent_messages .append (output )
527- print (
528- f"Added message to unsent_messages queue after failed attempts: { output } "
529- )
534+ if output not in async_interpreter .unsent_messages :
535+ async_interpreter .unsent_messages .append (output )
536+ print (
537+ f"Added message to unsent_messages queue after failed attempts: { output } "
538+ )
539+ else :
540+ print ("Why was this already in unsent_messages?" , output )
530541
531542 await asyncio .gather (receive_input (), send_output ())
532543
@@ -731,6 +742,10 @@ def __init__(self, async_interpreter, host=None, port=None):
731742 # Add authentication middleware
732743 @self .app .middleware ("http" )
733744 async def validate_api_key (request : Request , call_next ):
745+ # Ignore authentication for the /heartbeat route
746+ if request .url .path == "/heartbeat" :
747+ return await call_next (request )
748+
734749 api_key = request .headers .get ("X-API-KEY" )
735750 if self .authenticate (api_key ):
736751 response = await call_next (request )
0 commit comments