33# SPDX-License-Identifier: MIT
44
55import asyncio
6+ import contextlib
67import json
78import os
89import warnings
@@ -43,36 +44,33 @@ async def connect(self, throw_error: bool = True) -> dict[str, Any]:
4344 "User-Agent" : f"wokwi-client-py/{ get_version ()} " ,
4445 },
4546 )
47+ # Handshake: read the hello BEFORE starting the background loop.
4648 hello : IncomingMessage = await self ._recv ()
4749 if hello ["type" ] != MSG_TYPE_HELLO or hello .get ("protocolVersion" ) != PROTOCOL_VERSION :
4850 raise ProtocolError (f"Unsupported protocol handshake: { hello } " )
4951 hello_msg = cast (HelloMessage , hello )
5052 self ._closed = False
51- # Start background message processor
53+ # Start background message processor AFTER successful hello.
5254 self ._recv_task = asyncio .create_task (self ._background_recv (throw_error ))
5355 return {"version" : hello_msg ["appVersion" ]}
5456
5557 async def close (self ) -> None :
5658 self ._closed = True
5759 if self ._recv_task :
5860 self ._recv_task .cancel ()
59- try :
61+ with contextlib . suppress ( asyncio . CancelledError ) :
6062 await self ._recv_task
61- except asyncio .CancelledError :
62- pass
6363 if self ._ws :
6464 await self ._ws .close ()
6565
6666 def add_event_listener (self , event_type : str , listener : Callable [[EventMessage ], Any ]) -> None :
67- """Register a listener for a specific event type."""
6867 if event_type not in self ._event_listeners :
6968 self ._event_listeners [event_type ] = []
7069 self ._event_listeners [event_type ].append (listener )
7170
7271 def remove_event_listener (
7372 self , event_type : str , listener : Callable [[EventMessage ], Any ]
7473 ) -> None :
75- """Remove a previously registered listener for a specific event type."""
7674 if event_type in self ._event_listeners :
7775 self ._event_listeners [event_type ] = [
7876 registered_listener
@@ -90,52 +88,72 @@ async def _dispatch_event(self, event_msg: EventMessage) -> None:
9088 await result
9189
9290 async def request (self , command : str , params : dict [str , Any ]) -> ResponseMessage :
93- msg_id = str (self ._next_id )
94- self ._next_id += 1
9591 if self ._ws is None :
9692 raise WokwiError ("Not connected" )
93+ msg_id = str (self ._next_id )
94+ self ._next_id += 1
95+
9796 loop = asyncio .get_running_loop ()
9897 future : asyncio .Future [ResponseMessage ] = loop .create_future ()
9998 self ._response_futures [msg_id ] = future
99+
100100 await self ._ws .send (
101101 json .dumps ({"type" : "command" , "command" : command , "params" : params , "id" : msg_id })
102102 )
103103 try :
104104 resp_msg_resp = await future
105105 if resp_msg_resp .get ("error" ):
106- result = resp_msg_resp [ "result" ]
107- raise ServerError (result [ "message" ] )
106+ result = resp_msg_resp . get ( "result" , {})
107+ raise ServerError (result . get ( "message" , "Unknown server error" ) )
108108 return resp_msg_resp
109109 finally :
110- del self ._response_futures [msg_id ]
110+ # Remove future mapping if still present (be defensive)
111+ self ._response_futures .pop (msg_id , None )
111112
112- async def _background_recv (self , throw_error : bool = True ) -> None :
113+ async def _background_recv (self , throw_error : bool = True ) -> None : # noqa: PLR0912
113114 try :
114115 while not self ._closed and self ._ws is not None :
115116 msg : IncomingMessage = await self ._recv ()
116117 if msg ["type" ] == MSG_TYPE_EVENT :
117- resp_msg_event = cast (EventMessage , msg )
118- await self ._dispatch_event (resp_msg_event )
118+ await self ._dispatch_event (cast (EventMessage , msg ))
119119 elif msg ["type" ] == MSG_TYPE_RESPONSE :
120120 resp_msg_resp = cast (ResponseMessage , msg )
121- future = self ._response_futures .get (resp_msg_resp ["id" ])
121+ resp_id = str (resp_msg_resp .get ("id" ))
122+ future = self ._response_futures .get (resp_id )
122123 if future is None or future .done ():
123124 continue
124125 future .set_result (resp_msg_resp )
125- except (websockets .ConnectionClosed , asyncio .CancelledError ):
126- pass
126+ except asyncio .CancelledError :
127+ # Expected during shutdown via close()
128+ raise
129+ except websockets .ConnectionClosed as e :
130+ # Mark closed and fail pending futures to avoid hangs.
131+ self ._closed = True
132+ for fut in list (self ._response_futures .values ()):
133+ if not fut .done ():
134+ fut .set_exception (e )
135+ with contextlib .suppress (Exception ):
136+ if self ._ws :
137+ await self ._ws .close ()
138+ if throw_error :
139+ raise
127140 except Exception as e :
128141 warnings .warn (f"Background recv error: { e } " , RuntimeWarning )
129-
130142 if throw_error :
131143 self ._closed = True
132- # Cancel all pending response futures
133- for future in self . _response_futures . values ():
134- if not future . done ():
135- future . set_exception ( e )
136- if self ._ws :
137- await self ._ws .close ()
144+ for fut in list ( self . _response_futures . values ()):
145+ if not fut . done ():
146+ fut . set_exception ( e )
147+ with contextlib . suppress ( Exception ):
148+ if self ._ws :
149+ await self ._ws .close ()
138150 raise
151+ finally :
152+ # If we’re exiting the loop and marked closed, ensure no future hangs.
153+ if self ._closed :
154+ for fut in list (self ._response_futures .values ()):
155+ if not fut .done ():
156+ fut .set_exception (RuntimeError ("Transport receive loop exited" ))
139157
140158 async def _recv (self ) -> IncomingMessage :
141159 if self ._ws is None :
@@ -153,10 +171,6 @@ async def _recv(self) -> IncomingMessage:
153171 if message ["type" ] == "error" :
154172 raise WokwiError (f"Server error: { message ['message' ]} " )
155173 if message ["type" ] == "response" and message .get ("error" ):
156- result = (
157- message ["result" ]
158- if "result" in message
159- else {"code" : - 1 , "message" : "Unknown error" }
160- )
174+ result = message .get ("result" , {"code" : - 1 , "message" : "Unknown error" })
161175 raise WokwiError (f"Server error { result ['code' ]} : { result ['message' ]} " )
162176 return cast (IncomingMessage , message )
0 commit comments