Skip to content

Commit b4194a8

Browse files
committed
feat: enhance type hinting across client and control modules for improved clarity
1 parent cc325df commit b4194a8

File tree

6 files changed

+76
-59
lines changed

6 files changed

+76
-59
lines changed

src/wokwi_client/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
import typing
56
from pathlib import Path
67
from typing import Any, Optional
78

@@ -181,7 +182,7 @@ async def serial_monitor_cat(self) -> None:
181182
async for line in monitor_lines(self._transport):
182183
print(line.decode("utf-8"), end="", flush=True)
183184

184-
async def serial_write(self, data: bytes | str | list[int]) -> None:
185+
async def serial_write(self, data: typing.Union[bytes, str, list[int]]) -> None:
185186
"""Write data to the simulation serial monitor interface."""
186187
await write_serial(self._transport, data)
187188

@@ -210,7 +211,9 @@ async def listen_pin(self, part: str, pin: str, listen: bool = True) -> Response
210211
"""
211212
return await pin_listen(self._transport, part=part, pin=pin, listen=listen)
212213

213-
async def set_control(self, part: str, control: str, value: int | bool | float) -> ResponseMessage:
214+
async def set_control(
215+
self, part: str, control: str, value: typing.Union[int, bool, float]
216+
) -> ResponseMessage:
214217
"""Set a control value (e.g. simulate button press).
215218
216219
Args:

src/wokwi_client/client_sync.py

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,67 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
from __future__ import annotations
6+
57
import asyncio
68
import logging
79
import threading
810
import typing as t
11+
from concurrent.futures import Future
912
from pathlib import Path
1013

1114
from wokwi_client import WokwiClient
1215
from wokwi_client.serial import monitor_lines as monitor_serial_lines
1316

17+
if t.TYPE_CHECKING:
18+
from collections.abc import Iterable
19+
1420

1521
class WokwiClientSync:
1622
"""Synchronous wrapper around the async WokwiClient."""
1723

18-
def __init__(self, token: str, server: t.Optional[str] = None):
24+
token: str
25+
server: t.Optional[str]
26+
_loop: t.Optional[asyncio.AbstractEventLoop]
27+
_loop_thread: t.Optional[threading.Thread]
28+
_client: t.Optional[WokwiClient]
29+
_monitor_task: t.Optional[Future[t.Any]]
30+
_connected: bool
31+
32+
def __init__(self, token: str, server: t.Optional[str] = None) -> None:
1933
self.token = token
2034
self.server = server
2135
self._loop = None
2236
self._loop_thread = None
2337
self._client = None
38+
self._monitor_task = None
2439
self._connected = False
2540

26-
def _ensure_loop(self):
27-
"""Ensure the async event loop is running."""
41+
def _ensure_loop(self) -> None:
2842
if self._loop is None:
2943
self._loop = asyncio.new_event_loop()
3044
self._loop_thread = threading.Thread(target=self._loop.run_forever, daemon=True)
3145
self._loop_thread.start()
3246

33-
def _run_async(self, coro, timeout=30):
34-
"""Run an async coroutine synchronously."""
47+
def _run_async(self, coro: t.Coroutine[t.Any, t.Any, t.Any], timeout: float = 30) -> t.Any:
3548
self._ensure_loop()
49+
assert self._loop is not None
3650
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
3751
return future.result(timeout=timeout)
3852

39-
def connect(self):
40-
"""Connect to Wokwi server."""
53+
def connect(self) -> t.Dict[str, t.Any]:
4154
if not self._connected:
4255
self._client = WokwiClient(self.token, self.server)
43-
result = self._run_async(self._client.connect())
56+
result: t.Dict[str, t.Any] = t.cast(t.Dict[str, t.Any], self._run_async(self._client.connect()))
4457
self._connected = True
4558
return result
4659
return {}
4760

48-
def disconnect(self):
49-
"""Disconnect from Wokwi server."""
61+
def disconnect(self) -> None:
5062
if self._connected and self._client:
5163
try:
52-
# Stop any ongoing monitor task
53-
if hasattr(self, '_monitor_task') and self._monitor_task:
64+
if self._monitor_task:
5465
self._monitor_task.cancel()
55-
56-
# Disconnect the client
5766
self._run_async(self._client.disconnect(), timeout=5)
5867
except Exception as e:
5968
logging.debug(f"Error during disconnect: {e}")
@@ -70,79 +79,85 @@ def disconnect(self):
7079
self._loop = None
7180
self._loop_thread = None
7281

73-
def upload(self, name: str, content: bytes):
74-
"""Upload a file to the simulator from bytes content."""
82+
def upload(self, name: str, content: bytes) -> t.Any:
7583
if not self._connected:
7684
raise RuntimeError("Client not connected")
85+
assert self._client is not None
7786
return self._run_async(self._client.upload(name, content))
7887

79-
def upload_file(self, filename: str, local_path: t.Optional[Path] = None):
80-
"""Upload a file to the simulator."""
88+
def upload_file(self, filename: str, local_path: t.Optional[Path] = None) -> t.Any:
8189
if not self._connected:
8290
raise RuntimeError("Client not connected")
91+
assert self._client is not None
8392
return self._run_async(self._client.upload_file(filename, local_path))
8493

85-
def start_simulation(self, firmware: str, elf: t.Optional[str] = None, pause: bool = False, chips: list[str] = []):
86-
"""Start a simulation."""
94+
def start_simulation(
95+
self,
96+
firmware: str,
97+
elf: t.Optional[str] = None,
98+
pause: bool = False,
99+
chips: t.Optional[t.List[str]] = None,
100+
) -> t.Any:
87101
if not self._connected:
88102
raise RuntimeError("Client not connected")
89-
return self._run_async(self._client.start_simulation(firmware, elf, pause, chips))
103+
assert self._client is not None
104+
return self._run_async(self._client.start_simulation(firmware, elf, pause, chips or []))
90105

91-
def pause_simulation(self):
92-
"""Pause the running simulation."""
106+
def pause_simulation(self) -> t.Any:
93107
if not self._connected:
94108
raise RuntimeError("Client not connected")
109+
assert self._client is not None
95110
return self._run_async(self._client.pause_simulation())
96111

97-
def resume_simulation(self, pause_after: t.Optional[int] = None):
98-
"""Resume the simulation, optionally pausing after a given number of nanoseconds."""
112+
def resume_simulation(self, pause_after: t.Optional[int] = None) -> t.Any:
99113
if not self._connected:
100114
raise RuntimeError("Client not connected")
115+
assert self._client is not None
101116
return self._run_async(self._client.resume_simulation(pause_after))
102117

103-
def wait_until_simulation_time(self, seconds: float):
104-
"""Pause and resume the simulation until the given simulation time (in seconds) is reached."""
118+
def wait_until_simulation_time(self, seconds: float) -> t.Any:
105119
if not self._connected:
106120
raise RuntimeError("Client not connected")
121+
assert self._client is not None
107122
return self._run_async(self._client.wait_until_simulation_time(seconds))
108123

109-
def restart_simulation(self, pause: bool = False):
110-
"""Restart the simulation, optionally starting paused."""
124+
def restart_simulation(self, pause: bool = False) -> t.Any:
111125
if not self._connected:
112126
raise RuntimeError("Client not connected")
127+
assert self._client is not None
113128
return self._run_async(self._client.restart_simulation(pause))
114129

115-
def serial_monitor_cat(self):
116-
"""Print serial monitor output to stdout as it is received from the simulation."""
130+
def serial_monitor_cat(self) -> t.Any:
117131
if not self._connected:
118132
raise RuntimeError("Client not connected")
133+
assert self._client is not None
119134
return self._run_async(self._client.serial_monitor_cat())
120135

121-
def write_serial(self, data: t.Union[bytes, str, list[int]]):
122-
"""Write data to serial."""
136+
def write_serial(self, data: t.Union[bytes, str, t.List[int]]) -> t.Any:
123137
if not self._connected:
124138
raise RuntimeError("Client not connected")
139+
assert self._client is not None
125140
return self._run_async(self._client.serial_write(data))
126141

127-
def read_pin(self, part: str, pin: str):
128-
"""Read the current state of a pin."""
142+
def read_pin(self, part: str, pin: str) -> t.Any:
129143
if not self._connected:
130144
raise RuntimeError("Client not connected")
145+
assert self._client is not None
131146
return self._run_async(self._client.read_pin(part, pin))
132147

133-
def listen_pin(self, part: str, pin: str, listen: bool = True):
134-
"""Start or stop listening for changes on a pin."""
148+
def listen_pin(self, part: str, pin: str, listen: bool = True) -> t.Any:
135149
if not self._connected:
136150
raise RuntimeError("Client not connected")
151+
assert self._client is not None
137152
return self._run_async(self._client.listen_pin(part, pin, listen))
138153

139-
def monitor_serial(self, callback):
140-
"""Start monitoring serial output with a callback."""
154+
def monitor_serial(self, callback: t.Callable[[bytes], None]) -> None:
141155
if not self._connected:
142156
raise RuntimeError("Client not connected")
143157

144-
async def _monitor():
158+
async def _monitor() -> None:
145159
try:
160+
assert self._client is not None
146161
async for line in monitor_serial_lines(self._client._transport):
147162
if not self._connected:
148163
break
@@ -154,27 +169,24 @@ async def _monitor():
154169
except Exception as e:
155170
logging.error(f"Error in serial monitor: {e}")
156171

157-
# Start monitoring in background
172+
assert self._loop is not None
158173
self._monitor_task = asyncio.run_coroutine_threadsafe(_monitor(), self._loop)
159174

160-
def set_control(self, part: str, control: str, value: t.Union[int, bool, float]):
161-
"""Set control value for a part."""
175+
def set_control(self, part: str, control: str, value: t.Union[int, bool, float]) -> t.Any:
162176
if not self._connected:
163177
raise RuntimeError("Client not connected")
178+
assert self._client is not None
164179
return self._run_async(self._client.set_control(part, control, value))
165180

166181
@property
167-
def version(self):
168-
"""Get client version."""
182+
def version(self) -> str:
169183
if self._client:
170184
return self._client.version
171-
# Return a default version if client not initialized yet
172185
client = WokwiClient(self.token, self.server)
173186
return client.version
174187

175188
@property
176-
def last_pause_nanos(self):
177-
"""Get the last pause time in nanoseconds."""
189+
def last_pause_nanos(self) -> int:
178190
if self._client:
179191
return self._client.last_pause_nanos
180192
return 0

src/wokwi_client/control.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111
#
1212
# SPDX-License-Identifier: MIT
1313

14+
import typing
15+
1416
from .protocol_types import ResponseMessage
1517
from .transport import Transport
1618

1719

1820
async def set_control(
19-
transport: Transport, *, part: str, control: str, value: int | bool | float
21+
transport: Transport, *, part: str, control: str, value: typing.Union[int, bool, float]
2022
) -> ResponseMessage:
2123
"""Set a control value on a part (e.g. simulate button press/release).
2224

src/wokwi_client/pins.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
from .transport import Transport
1616

1717

18-
async def pin_read(
19-
transport: Transport, *, part: str, pin: str
20-
) -> ResponseMessage:
18+
async def pin_read(transport: Transport, *, part: str, pin: str) -> ResponseMessage:
2119
"""Read the state of a pin.
2220
2321
Args:
@@ -44,6 +42,4 @@ async def pin_listen(
4442
listen: True to start listening, False to stop.
4543
"""
4644

47-
return await transport.request(
48-
"pin:listen", {"part": part, "pin": pin, "listen": listen}
49-
)
45+
return await transport.request("pin:listen", {"part": part, "pin": pin, "listen": listen})

src/wokwi_client/serial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5+
import typing
56
from collections.abc import AsyncGenerator, Iterable
67

78
from .event_queue import EventQueue
@@ -16,7 +17,7 @@ async def monitor_lines(transport: Transport) -> AsyncGenerator[bytes, None]:
1617
yield bytes(event_msg["payload"]["bytes"])
1718

1819

19-
async def write_serial(transport: Transport, data: bytes | str | Iterable[int]) -> None:
20+
async def write_serial(transport: Transport, data: typing.Union[bytes, str, Iterable[int]]) -> None:
2021
"""Write data to the serial monitor.
2122
2223
Accepts bytes, str (encoded as utf-8), or an iterable of integer byte values.

tests/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
import subprocess
1212
import sys
1313
from collections.abc import Mapping
14+
from subprocess import CompletedProcess
1415

1516

16-
def run_example_module(module: str, *, sleep_time: str = "1", extra_env: Mapping[str, str] | None = None) -> subprocess.CompletedProcess:
17+
def run_example_module(
18+
module: str, *, sleep_time: str = "1", extra_env: Mapping[str, str] | None = None
19+
) -> CompletedProcess[str]:
1720
"""Run an example module with a short simulation time.
1821
1922
Requires WOKWI_CLI_TOKEN to be set in the environment.

0 commit comments

Comments
 (0)