diff --git a/src/wokwi_client/transport.py b/src/wokwi_client/transport.py index 763a387..98ce2fe 100644 --- a/src/wokwi_client/transport.py +++ b/src/wokwi_client/transport.py @@ -35,7 +35,7 @@ def __init__(self, token: str, url: str = TRANSPORT_DEFAULT_WS_URL): self._recv_task: Optional[asyncio.Task[None]] = None self._closed = False - async def connect(self) -> dict[str, Any]: + async def connect(self, throw_error: bool = True) -> dict[str, Any]: self._ws = await websockets.connect( self._url, extra_headers={ @@ -49,7 +49,7 @@ async def connect(self) -> dict[str, Any]: hello_msg = cast(HelloMessage, hello) self._closed = False # Start background message processor - self._recv_task = asyncio.create_task(self._background_recv()) + self._recv_task = asyncio.create_task(self._background_recv(throw_error)) return {"version": hello_msg["appVersion"]} async def close(self) -> None: @@ -109,7 +109,7 @@ async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage finally: del self._response_futures[msg_id] - async def _background_recv(self) -> None: + async def _background_recv(self, throw_error: bool = True) -> None: try: while not self._closed and self._ws is not None: msg: IncomingMessage = await self._recv() @@ -127,6 +127,16 @@ async def _background_recv(self) -> None: except Exception as e: warnings.warn(f"Background recv error: {e}", RuntimeWarning) + if throw_error: + self._closed = True + # Cancel all pending response futures + for future in self._response_futures.values(): + if not future.done(): + future.set_exception(e) + if self._ws: + await self._ws.close() + raise + async def _recv(self) -> IncomingMessage: if self._ws is None: raise WokwiError("Not connected")