diff --git a/src/wokwi_client/client_sync.py b/src/wokwi_client/client_sync.py index daca415..e98f77a 100644 --- a/src/wokwi_client/client_sync.py +++ b/src/wokwi_client/client_sync.py @@ -39,34 +39,31 @@ class WokwiClientSync: tracked, so we can cancel & drain them on `disconnect()`. """ - # Public attributes mirrored for convenience - version: str - last_pause_nanos: int # this proxy resolves via __getattr__ - def __init__(self, token: str, server: str | None = None): - # Create a fresh event loop + thread (daemon so it won't prevent process exit). + # Create a new event loop for the background thread self._loop = asyncio.new_event_loop() + # Event to signal that the event loop is running + self._loop_started_event = threading.Event() + # Start background thread running the event loop self._thread = threading.Thread( target=self._run_loop, args=(self._loop,), daemon=True, name="wokwi-sync-loop" ) self._thread.start() - - # Underlying async client + # **Wait until loop is fully started before proceeding** (prevents race conditions) + if not self._loop_started_event.wait(timeout=8.0): # timeout to avoid deadlock + raise RuntimeError("WokwiClientSync event loop failed to start") + # Initialize underlying async client on the running loop self._async_client = WokwiClient(token, server) - - # Mirror library version for quick access - self.version = self._async_client.version - - # Track background tasks created via run_coroutine_threadsafe (serial monitors) + # Track background monitor tasks (futures) for cancellation on exit self._bg_futures: set[Future[Any]] = set() - - # Idempotent disconnect guard + # Flag to avoid double-closing self._closed = False - @staticmethod - def _run_loop(loop: asyncio.AbstractEventLoop) -> None: - """Background thread loop runner.""" + def _run_loop(self, loop: asyncio.AbstractEventLoop) -> None: + """Target function for the background thread: runs the asyncio event loop.""" asyncio.set_event_loop(loop) + # Signal that the loop is now running and ready to accept tasks + loop.call_soon(self._loop_started_event.set) loop.run_forever() # ----- Internal helpers ------------------------------------------------- @@ -75,8 +72,11 @@ def _submit(self, coro: Coroutine[Any, Any, T]) -> Future[T]: return asyncio.run_coroutine_threadsafe(coro, self._loop) def _call(self, coro: Coroutine[Any, Any, T]) -> T: - """Submit a coroutine to the loop and block until it completes (or raises).""" - return self._submit(coro).result() + """Submit a coroutine to the background loop and wait for result.""" + if self._closed: + raise RuntimeError("Cannot call methods on a closed WokwiClientSync") + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() # Block until the coroutine completes or raises def _add_bg_future(self, fut: Future[Any]) -> None: """Track a background future so we can cancel & drain on shutdown.""" @@ -96,30 +96,21 @@ def connect(self) -> dict[str, Any]: return self._call(self._async_client.connect()) def disconnect(self) -> None: - """Disconnect and stop the background loop. - - Order matters: - 1) Cancel and drain background serial-monitor futures. - 2) Disconnect the underlying transport. - 3) Stop the loop and join the thread. - Safe to call multiple times. - """ if self._closed: return - self._closed = True # (1) Cancel + drain monitors for fut in list(self._bg_futures): fut.cancel() for fut in list(self._bg_futures): with contextlib.suppress(FutureTimeoutError, Exception): - # Give each monitor a short window to handle cancellation cleanly. fut.result(timeout=1.0) self._bg_futures.discard(fut) # (2) Disconnect transport with contextlib.suppress(Exception): - self._call(self._async_client._transport.close()) + fut = asyncio.run_coroutine_threadsafe(self._async_client.disconnect(), self._loop) + fut.result(timeout=2.0) # (3) Stop loop / join thread if self._loop.is_running(): @@ -127,6 +118,13 @@ def disconnect(self) -> None: if self._thread.is_alive(): self._thread.join(timeout=5.0) + # (4) Close loop + with contextlib.suppress(Exception): + self._loop.close() + + # (5) Mark closed at the very end + self._closed = True + # ----- Serial monitoring ------------------------------------------------ def serial_monitor(self, callback: Callable[[bytes], Any]) -> None: """ @@ -138,17 +136,25 @@ def serial_monitor(self, callback: Callable[[bytes], Any]) -> None: """ async def _runner() -> None: - async for line in monitor_lines(self._async_client._transport): - try: - maybe_awaitable = callback(line) - if inspect.isawaitable(maybe_awaitable): - await maybe_awaitable - except Exception: - # Keep the monitor alive even if the callback throws. - pass - - fut = self._submit(_runner()) - self._add_bg_future(fut) + try: + # **Prepare to receive serial events before enabling monitor** + # (monitor_lines will subscribe to serial events internally) + async for line in monitor_lines(self._async_client._transport): + try: + result = callback(line) # invoke callback with the raw bytes line + if inspect.isawaitable(result): + await result # await if callback is async + except Exception: + # Swallow exceptions from callback to keep monitor alive + pass + finally: + # Remove this task’s future from the set when done + self._bg_futures.discard(task_future) + + # Schedule the serial monitor runner on the event loop: + task_future = asyncio.run_coroutine_threadsafe(_runner(), self._loop) + self._bg_futures.add(task_future) + # (No return value; monitoring happens in background) def serial_monitor_cat(self, decode_utf8: bool = True, errors: str = "replace") -> None: """ @@ -160,34 +166,32 @@ def serial_monitor_cat(self, decode_utf8: bool = True, errors: str = "replace") """ async def _runner() -> None: - async for line in monitor_lines(self._async_client._transport): - try: - if decode_utf8: - try: - print(line.decode("utf-8", errors=errors), end="", flush=True) - except UnicodeDecodeError: + try: + # **Subscribe to serial events before reading output** + async for line in monitor_lines(self._async_client._transport): + try: + if decode_utf8: + # Decode bytes to string (handle errors per parameter) + text = line.decode("utf-8", errors=errors) + print(text, end="", flush=True) + else: + # Print raw bytes print(line, end="", flush=True) - else: - print(line, end="", flush=True) - except Exception: - # Keep the monitor alive even if printing raises intermittently. - pass + except Exception: + # Swallow print errors to keep stream alive + pass + finally: + self._bg_futures.discard(task_future) - fut = self._submit(_runner()) - self._add_bg_future(fut) + task_future = asyncio.run_coroutine_threadsafe(_runner(), self._loop) + self._bg_futures.add(task_future) + # (No return; printing continues in background) def stop_serial_monitors(self) -> None: - """ - Cancel and drain all running serial monitors without disconnecting. - - Useful if you want to stop printing but keep the connection alive. - """ + """Stop all active serial monitor background tasks.""" for fut in list(self._bg_futures): fut.cancel() - for fut in list(self._bg_futures): - with contextlib.suppress(FutureTimeoutError, Exception): - fut.result(timeout=1.0) - self._bg_futures.discard(fut) + self._bg_futures.clear() # ----- Dynamic method wrapping ----------------------------------------- def __getattr__(self, name: str) -> Any: @@ -197,16 +201,17 @@ def __getattr__(self, name: str) -> Any: If the attribute on `WokwiClient` is a coroutine function, return a sync wrapper that blocks until the coroutine completes. """ - # Explicit methods above (serial monitors) take precedence. + # Explicit methods (like serial_monitor functions above) take precedence over __getattr__ attr = getattr(self._async_client, name) if callable(attr): + # Get the function object from WokwiClient class (unbound) to check if coroutine func = getattr(WokwiClient, name, None) if func is not None and inspect.iscoroutinefunction(func): - + # Wrap coroutine method to run in background loop def sync_wrapper(*args: Any, **kwargs: Any) -> Any: return self._call(attr(*args, **kwargs)) sync_wrapper.__name__ = name - sync_wrapper.__doc__ = func.__doc__ + sync_wrapper.__doc__ = getattr(func, "__doc__", "") return sync_wrapper return attr diff --git a/src/wokwi_client/serial.py b/src/wokwi_client/serial.py index 7c751d3..ebaca08 100644 --- a/src/wokwi_client/serial.py +++ b/src/wokwi_client/serial.py @@ -10,6 +10,9 @@ async def monitor_lines(transport: Transport) -> AsyncGenerator[bytes, None]: + """ + Monitor the serial output lines. + """ await transport.request("serial-monitor:listen", {}) with EventQueue(transport, "serial-monitor:data") as queue: while True: diff --git a/src/wokwi_client/transport.py b/src/wokwi_client/transport.py index 98ce2fe..74775a3 100644 --- a/src/wokwi_client/transport.py +++ b/src/wokwi_client/transport.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import asyncio +import contextlib import json import os import warnings @@ -48,7 +49,7 @@ async def connect(self, throw_error: bool = True) -> dict[str, Any]: raise ProtocolError(f"Unsupported protocol handshake: {hello}") hello_msg = cast(HelloMessage, hello) self._closed = False - # Start background message processor + # Start background message processor after successful hello. self._recv_task = asyncio.create_task(self._background_recv(throw_error)) return {"version": hello_msg["appVersion"]} @@ -56,15 +57,12 @@ async def close(self) -> None: self._closed = True if self._recv_task: self._recv_task.cancel() - try: + with contextlib.suppress(asyncio.CancelledError): await self._recv_task - except asyncio.CancelledError: - pass if self._ws: await self._ws.close() def add_event_listener(self, event_type: str, listener: Callable[[EventMessage], Any]) -> None: - """Register a listener for a specific event type.""" if event_type not in self._event_listeners: self._event_listeners[event_type] = [] self._event_listeners[event_type].append(listener) @@ -72,7 +70,6 @@ def add_event_listener(self, event_type: str, listener: Callable[[EventMessage], def remove_event_listener( self, event_type: str, listener: Callable[[EventMessage], Any] ) -> None: - """Remove a previously registered listener for a specific event type.""" if event_type in self._event_listeners: self._event_listeners[event_type] = [ registered_listener @@ -90,52 +87,72 @@ async def _dispatch_event(self, event_msg: EventMessage) -> None: await result async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage: - msg_id = str(self._next_id) - self._next_id += 1 if self._ws is None: raise WokwiError("Not connected") + msg_id = str(self._next_id) + self._next_id += 1 + loop = asyncio.get_running_loop() future: asyncio.Future[ResponseMessage] = loop.create_future() self._response_futures[msg_id] = future + await self._ws.send( json.dumps({"type": "command", "command": command, "params": params, "id": msg_id}) ) try: resp_msg_resp = await future if resp_msg_resp.get("error"): - result = resp_msg_resp["result"] - raise ServerError(result["message"]) + result = resp_msg_resp.get("result", {}) + raise ServerError(result.get("message", "Unknown server error")) return resp_msg_resp finally: - del self._response_futures[msg_id] + # Remove future mapping if still present (be defensive) + self._response_futures.pop(msg_id, None) - async def _background_recv(self, throw_error: bool = True) -> None: + async def _background_recv(self, throw_error: bool = True) -> None: # noqa: PLR0912 try: while not self._closed and self._ws is not None: msg: IncomingMessage = await self._recv() if msg["type"] == MSG_TYPE_EVENT: - resp_msg_event = cast(EventMessage, msg) - await self._dispatch_event(resp_msg_event) + await self._dispatch_event(cast(EventMessage, msg)) elif msg["type"] == MSG_TYPE_RESPONSE: resp_msg_resp = cast(ResponseMessage, msg) - future = self._response_futures.get(resp_msg_resp["id"]) + resp_id = str(resp_msg_resp.get("id")) + future = self._response_futures.get(resp_id) if future is None or future.done(): continue future.set_result(resp_msg_resp) - except (websockets.ConnectionClosed, asyncio.CancelledError): - pass + except asyncio.CancelledError: + # Expected during shutdown via close() + raise + except websockets.ConnectionClosed as e: + # Mark closed and fail pending futures to avoid hangs. + self._closed = True + for fut in list(self._response_futures.values()): + if not fut.done(): + fut.set_exception(e) + with contextlib.suppress(Exception): + if self._ws: + await self._ws.close() + if throw_error: + raise except Exception as e: warnings.warn(f"Background recv error: {e}", RuntimeWarning) - if throw_error: self._closed = True - # Cancel all pending response futures - for future in self._response_futures.values(): - if not future.done(): - future.set_exception(e) - if self._ws: - await self._ws.close() + for fut in list(self._response_futures.values()): + if not fut.done(): + fut.set_exception(e) + with contextlib.suppress(Exception): + if self._ws: + await self._ws.close() raise + finally: + # If we’re exiting the loop and marked closed, ensure no future hangs. + if self._closed: + for fut in list(self._response_futures.values()): + if not fut.done(): + fut.set_exception(RuntimeError("Transport receive loop exited")) async def _recv(self) -> IncomingMessage: if self._ws is None: @@ -153,10 +170,6 @@ async def _recv(self) -> IncomingMessage: if message["type"] == "error": raise WokwiError(f"Server error: {message['message']}") if message["type"] == "response" and message.get("error"): - result = ( - message["result"] - if "result" in message - else {"code": -1, "message": "Unknown error"} - ) + result = message.get("result", {"code": -1, "message": "Unknown error"}) raise WokwiError(f"Server error {result['code']}: {result['message']}") return cast(IncomingMessage, message)