22#
33# SPDX-License-Identifier: MIT
44
5+ import asyncio
56import json
67import os
78import warnings
8- from typing import Any , Optional , cast
9+ from typing import Any , Callable , Optional , cast
910
1011import websockets
1112
1213from .__version__ import get_version
1314from .constants import (
1415 DEFAULT_WS_URL ,
16+ MSG_TYPE_EVENT ,
1517 MSG_TYPE_HELLO ,
1618 MSG_TYPE_RESPONSE ,
1719 PROTOCOL_VERSION ,
1820)
1921from .exceptions import ProtocolError , ServerError , WokwiError
20- from .protocol_types import HelloMessage , IncomingMessage , ResponseMessage
22+ from .protocol_types import EventMessage , HelloMessage , IncomingMessage , ResponseMessage
2123
2224TRANSPORT_DEFAULT_WS_URL = os .getenv ("WOKWI_CLI_SERVER" , DEFAULT_WS_URL )
2325
@@ -28,6 +30,10 @@ def __init__(self, token: str, url: str = TRANSPORT_DEFAULT_WS_URL):
2830 self ._url = url
2931 self ._next_id = 1
3032 self ._ws : Optional [websockets .WebSocketClientProtocol ] = None
33+ self ._event_listeners : dict [str , list [Callable [[EventMessage ], Any ]]] = {}
34+ self ._response_futures : dict [str , asyncio .Future [ResponseMessage ]] = {}
35+ self ._recv_task : Optional [asyncio .Task [None ]] = None
36+ self ._closed = False
3137
3238 async def connect (self ) -> dict [str , Any ]:
3339 self ._ws = await websockets .connect (
@@ -41,28 +47,85 @@ async def connect(self) -> dict[str, Any]:
4147 if hello ["type" ] != MSG_TYPE_HELLO or hello .get ("protocolVersion" ) != PROTOCOL_VERSION :
4248 raise ProtocolError (f"Unsupported protocol handshake: { hello } " )
4349 hello_msg = cast (HelloMessage , hello )
50+ self ._closed = False
51+ # Start background message processor
52+ self ._recv_task = asyncio .create_task (self ._background_recv ())
4453 return {"version" : hello_msg ["appVersion" ]}
4554
4655 async def close (self ) -> None :
56+ self ._closed = True
57+ if self ._recv_task :
58+ self ._recv_task .cancel ()
59+ try :
60+ await self ._recv_task
61+ except asyncio .CancelledError :
62+ pass
4763 if self ._ws :
4864 await self ._ws .close ()
4965
66+ def add_event_listener (self , event_type : str , listener : Callable [[EventMessage ], Any ]) -> None :
67+ """Register a listener for a specific event type."""
68+ if event_type not in self ._event_listeners :
69+ self ._event_listeners [event_type ] = []
70+ self ._event_listeners [event_type ].append (listener )
71+
72+ def remove_event_listener (
73+ self , event_type : str , listener : Callable [[EventMessage ], Any ]
74+ ) -> None :
75+ """Remove a previously registered listener for a specific event type."""
76+ if event_type in self ._event_listeners :
77+ self ._event_listeners [event_type ] = [
78+ registered_listener
79+ for registered_listener in self ._event_listeners [event_type ]
80+ if registered_listener != listener
81+ ]
82+ if not self ._event_listeners [event_type ]:
83+ del self ._event_listeners [event_type ]
84+
85+ async def _dispatch_event (self , event_msg : EventMessage ) -> None :
86+ listeners = self ._event_listeners .get (event_msg ["event" ], [])
87+ for listener in listeners :
88+ result = listener (event_msg )
89+ if hasattr (result , "__await__" ):
90+ await result
91+
5092 async def request (self , command : str , params : dict [str , Any ]) -> ResponseMessage :
5193 msg_id = str (self ._next_id )
5294 self ._next_id += 1
5395 if self ._ws is None :
5496 raise WokwiError ("Not connected" )
97+ loop = asyncio .get_running_loop ()
98+ future : asyncio .Future [ResponseMessage ] = loop .create_future ()
99+ self ._response_futures [msg_id ] = future
55100 await self ._ws .send (
56101 json .dumps ({"type" : "command" , "command" : command , "params" : params , "id" : msg_id })
57102 )
58- while True :
59- msg : IncomingMessage = await self ._recv ()
60- if msg ["type" ] == MSG_TYPE_RESPONSE and msg .get ("id" ) == msg_id :
61- resp_msg = cast (ResponseMessage , msg )
62- if resp_msg .get ("error" ):
63- result = resp_msg ["result" ]
64- raise ServerError (result ["message" ])
65- return resp_msg
103+ try :
104+ resp_msg_resp = await future
105+ if resp_msg_resp .get ("error" ):
106+ result = resp_msg_resp ["result" ]
107+ raise ServerError (result ["message" ])
108+ return resp_msg_resp
109+ finally :
110+ del self ._response_futures [msg_id ]
111+
112+ async def _background_recv (self ) -> None :
113+ try :
114+ while not self ._closed and self ._ws is not None :
115+ msg : IncomingMessage = await self ._recv ()
116+ if msg ["type" ] == MSG_TYPE_EVENT :
117+ resp_msg_event = cast (EventMessage , msg )
118+ await self ._dispatch_event (resp_msg_event )
119+ elif msg ["type" ] == MSG_TYPE_RESPONSE :
120+ resp_msg_resp = cast (ResponseMessage , msg )
121+ future = self ._response_futures .get (resp_msg_resp ["id" ])
122+ if future is None or future .done ():
123+ continue
124+ future .set_result (resp_msg_resp )
125+ except (websockets .ConnectionClosed , asyncio .CancelledError ):
126+ pass
127+ except Exception as e :
128+ warnings .warn (f"Background recv error: { e } " , RuntimeWarning )
66129
67130 async def _recv (self ) -> IncomingMessage :
68131 if self ._ws is None :
@@ -87,6 +150,3 @@ async def _recv(self) -> IncomingMessage:
87150 )
88151 raise WokwiError (f"Server error { result ['code' ]} : { result ['message' ]} " )
89152 return cast (IncomingMessage , message )
90-
91- async def recv (self ) -> IncomingMessage :
92- return await self ._recv ()
0 commit comments