Skip to content

Commit b9d6468

Browse files
committed
fix(serial, transport): streamline transport request handling and improve error management
1 parent 93c29b0 commit b9d6468

File tree

2 files changed

+44
-31
lines changed

2 files changed

+44
-31
lines changed

src/wokwi_client/serial.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ async def monitor_lines(transport: Transport) -> AsyncGenerator[bytes, None]:
1313
"""
1414
Monitor the serial output lines.
1515
"""
16-
# Create the queue/listener before enabling the monitor to catch all data
16+
await transport.request("serial-monitor:listen", {})
1717
with EventQueue(transport, "serial-monitor:data") as queue:
18-
await transport.request("serial-monitor:listen", {})
1918
while True:
2019
event_msg = await queue.get()
2120
yield bytes(event_msg["payload"]["bytes"])

src/wokwi_client/transport.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MIT
44

55
import asyncio
6+
import contextlib
67
import json
78
import os
89
import warnings
@@ -43,36 +44,33 @@ async def connect(self, throw_error: bool = True) -> dict[str, Any]:
4344
"User-Agent": f"wokwi-client-py/{get_version()}",
4445
},
4546
)
47+
# Handshake: read the hello BEFORE starting the background loop.
4648
hello: IncomingMessage = await self._recv()
4749
if hello["type"] != MSG_TYPE_HELLO or hello.get("protocolVersion") != PROTOCOL_VERSION:
4850
raise ProtocolError(f"Unsupported protocol handshake: {hello}")
4951
hello_msg = cast(HelloMessage, hello)
5052
self._closed = False
51-
# Start background message processor
53+
# Start background message processor AFTER successful hello.
5254
self._recv_task = asyncio.create_task(self._background_recv(throw_error))
5355
return {"version": hello_msg["appVersion"]}
5456

5557
async def close(self) -> None:
5658
self._closed = True
5759
if self._recv_task:
5860
self._recv_task.cancel()
59-
try:
61+
with contextlib.suppress(asyncio.CancelledError):
6062
await self._recv_task
61-
except asyncio.CancelledError:
62-
pass
6363
if self._ws:
6464
await self._ws.close()
6565

6666
def add_event_listener(self, event_type: str, listener: Callable[[EventMessage], Any]) -> None:
67-
"""Register a listener for a specific event type."""
6867
if event_type not in self._event_listeners:
6968
self._event_listeners[event_type] = []
7069
self._event_listeners[event_type].append(listener)
7170

7271
def remove_event_listener(
7372
self, event_type: str, listener: Callable[[EventMessage], Any]
7473
) -> None:
75-
"""Remove a previously registered listener for a specific event type."""
7674
if event_type in self._event_listeners:
7775
self._event_listeners[event_type] = [
7876
registered_listener
@@ -90,52 +88,72 @@ async def _dispatch_event(self, event_msg: EventMessage) -> None:
9088
await result
9189

9290
async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage:
93-
msg_id = str(self._next_id)
94-
self._next_id += 1
9591
if self._ws is None:
9692
raise WokwiError("Not connected")
93+
msg_id = str(self._next_id)
94+
self._next_id += 1
95+
9796
loop = asyncio.get_running_loop()
9897
future: asyncio.Future[ResponseMessage] = loop.create_future()
9998
self._response_futures[msg_id] = future
99+
100100
await self._ws.send(
101101
json.dumps({"type": "command", "command": command, "params": params, "id": msg_id})
102102
)
103103
try:
104104
resp_msg_resp = await future
105105
if resp_msg_resp.get("error"):
106-
result = resp_msg_resp["result"]
107-
raise ServerError(result["message"])
106+
result = resp_msg_resp.get("result", {})
107+
raise ServerError(result.get("message", "Unknown server error"))
108108
return resp_msg_resp
109109
finally:
110-
del self._response_futures[msg_id]
110+
# Remove future mapping if still present (be defensive)
111+
self._response_futures.pop(msg_id, None)
111112

112-
async def _background_recv(self, throw_error: bool = True) -> None:
113+
async def _background_recv(self, throw_error: bool = True) -> None: # noqa: PLR0912
113114
try:
114115
while not self._closed and self._ws is not None:
115116
msg: IncomingMessage = await self._recv()
116117
if msg["type"] == MSG_TYPE_EVENT:
117-
resp_msg_event = cast(EventMessage, msg)
118-
await self._dispatch_event(resp_msg_event)
118+
await self._dispatch_event(cast(EventMessage, msg))
119119
elif msg["type"] == MSG_TYPE_RESPONSE:
120120
resp_msg_resp = cast(ResponseMessage, msg)
121-
future = self._response_futures.get(resp_msg_resp["id"])
121+
resp_id = str(resp_msg_resp.get("id"))
122+
future = self._response_futures.get(resp_id)
122123
if future is None or future.done():
123124
continue
124125
future.set_result(resp_msg_resp)
125-
except (websockets.ConnectionClosed, asyncio.CancelledError):
126-
pass
126+
except asyncio.CancelledError:
127+
# Expected during shutdown via close()
128+
raise
129+
except websockets.ConnectionClosed as e:
130+
# Mark closed and fail pending futures to avoid hangs.
131+
self._closed = True
132+
for fut in list(self._response_futures.values()):
133+
if not fut.done():
134+
fut.set_exception(e)
135+
with contextlib.suppress(Exception):
136+
if self._ws:
137+
await self._ws.close()
138+
if throw_error:
139+
raise
127140
except Exception as e:
128141
warnings.warn(f"Background recv error: {e}", RuntimeWarning)
129-
130142
if throw_error:
131143
self._closed = True
132-
# Cancel all pending response futures
133-
for future in self._response_futures.values():
134-
if not future.done():
135-
future.set_exception(e)
136-
if self._ws:
137-
await self._ws.close()
144+
for fut in list(self._response_futures.values()):
145+
if not fut.done():
146+
fut.set_exception(e)
147+
with contextlib.suppress(Exception):
148+
if self._ws:
149+
await self._ws.close()
138150
raise
151+
finally:
152+
# If we’re exiting the loop and marked closed, ensure no future hangs.
153+
if self._closed:
154+
for fut in list(self._response_futures.values()):
155+
if not fut.done():
156+
fut.set_exception(RuntimeError("Transport receive loop exited"))
139157

140158
async def _recv(self) -> IncomingMessage:
141159
if self._ws is None:
@@ -153,10 +171,6 @@ async def _recv(self) -> IncomingMessage:
153171
if message["type"] == "error":
154172
raise WokwiError(f"Server error: {message['message']}")
155173
if message["type"] == "response" and message.get("error"):
156-
result = (
157-
message["result"]
158-
if "result" in message
159-
else {"code": -1, "message": "Unknown error"}
160-
)
174+
result = message.get("result", {"code": -1, "message": "Unknown error"})
161175
raise WokwiError(f"Server error {result['code']}: {result['message']}")
162176
return cast(IncomingMessage, message)

0 commit comments

Comments
 (0)