|
| 1 | +import asyncio |
| 2 | +import logging |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import TYPE_CHECKING, Optional |
| 5 | + |
| 6 | +from pylabrobot.io.capture import Command, capturer, get_capture_or_validation_active |
| 7 | +from pylabrobot.io.errors import ValidationError |
| 8 | +from pylabrobot.io.io import IOBase |
| 9 | +from pylabrobot.io.validation_utils import LOG_LEVEL_IO |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + from pylabrobot.io.capture import CaptureReader |
| 13 | + |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +@dataclass |
| 19 | +class SocketCommand(Command): |
| 20 | + data: str |
| 21 | + |
| 22 | + def __init__(self, device_id: str, action: str, data: str, module: str = "socket"): |
| 23 | + super().__init__(module=module, device_id=device_id, action=action) |
| 24 | + self.data = data |
| 25 | + |
| 26 | + |
| 27 | +class Socket(IOBase): |
| 28 | + """IO for reading/writing to a TCP socket.""" |
| 29 | + |
| 30 | + def __init__( |
| 31 | + self, |
| 32 | + host: str, |
| 33 | + port: int, |
| 34 | + read_timeout: int = 30, |
| 35 | + write_timeout: int = 30, |
| 36 | + ): |
| 37 | + """Initialize an io.Socket object. |
| 38 | +
|
| 39 | + Args: |
| 40 | + host: The hostname or IP address to connect to. |
| 41 | + port: The port number to connect to. |
| 42 | + read_timeout: The timeout for reading from the socket in seconds. |
| 43 | + write_timeout: The timeout for writing to the socket in seconds. |
| 44 | + """ |
| 45 | + |
| 46 | + super().__init__() |
| 47 | + |
| 48 | + if get_capture_or_validation_active(): |
| 49 | + raise RuntimeError("Cannot create a new Socket object while capture or validation is active") |
| 50 | + |
| 51 | + self.host = host |
| 52 | + self.port = port |
| 53 | + self.read_timeout = read_timeout |
| 54 | + self.write_timeout = write_timeout |
| 55 | + |
| 56 | + self._reader: Optional[asyncio.StreamReader] = None |
| 57 | + self._writer: Optional[asyncio.StreamWriter] = None |
| 58 | + |
| 59 | + # unique id in the logs |
| 60 | + self._unique_id = f"[{self.host}:{self.port}]" |
| 61 | + |
| 62 | + async def write(self, data: str, timeout: Optional[int] = None): |
| 63 | + if self._writer is None: |
| 64 | + raise ConnectionError("Socket not connected. Call setup() first.") |
| 65 | + |
| 66 | + if timeout is None: |
| 67 | + timeout = self.write_timeout |
| 68 | + |
| 69 | + try: |
| 70 | + self._writer.write(data.encode("ascii")) |
| 71 | + await asyncio.wait_for(self._writer.drain(), timeout=timeout) |
| 72 | + |
| 73 | + logger.log(LOG_LEVEL_IO, "%s write: %s", self._unique_id, data.strip()) |
| 74 | + capturer.record(SocketCommand(device_id=self._unique_id, action="write", data=data)) |
| 75 | + except asyncio.TimeoutError as exc: |
| 76 | + raise TimeoutError(f"Timeout while writing to socket after {timeout} seconds") from exc |
| 77 | + |
| 78 | + async def read(self, timeout: Optional[int] = None, read_once=True) -> str: |
| 79 | + """Read data from the socket. |
| 80 | + Args: |
| 81 | + timeout: The timeout for reading from the socket in seconds. If None, uses the default |
| 82 | + read_timeout set during initialization. |
| 83 | + read_once: If True, reads until the first complete message is received. If False, continues |
| 84 | + reading until the connection is closed or a timeout occurs. |
| 85 | + """ |
| 86 | + |
| 87 | + if self._reader is None: |
| 88 | + raise ConnectionError("Socket not connected. Call setup() first.") |
| 89 | + |
| 90 | + if timeout is None: |
| 91 | + timeout = self.read_timeout |
| 92 | + |
| 93 | + try: |
| 94 | + chunks = [] |
| 95 | + while True: |
| 96 | + try: |
| 97 | + data = await asyncio.wait_for(self._reader.read(1024), timeout=timeout) |
| 98 | + if not data: |
| 99 | + # Connection closed |
| 100 | + break |
| 101 | + chunks.append(data) |
| 102 | + if read_once: |
| 103 | + break |
| 104 | + except asyncio.TimeoutError as exc: |
| 105 | + if chunks: |
| 106 | + # We have some data, return it |
| 107 | + break |
| 108 | + raise TimeoutError(f"Timeout while reading from socket after {timeout} seconds") from exc |
| 109 | + |
| 110 | + if len(chunks) == 0: |
| 111 | + raise ConnectionError("Socket connection closed") |
| 112 | + |
| 113 | + response = b"".join(chunks).decode("ascii") |
| 114 | + logger.log(LOG_LEVEL_IO, "%s read: %s", self._unique_id, response.strip()) |
| 115 | + capturer.record(SocketCommand(device_id=self._unique_id, action="read", data=response)) |
| 116 | + return response |
| 117 | + |
| 118 | + except UnicodeDecodeError as e: |
| 119 | + raise ValueError(f"Failed to decode socket response as ASCII: {e}") from e |
| 120 | + |
| 121 | + async def setup(self): |
| 122 | + """Initialize the socket connection.""" |
| 123 | + |
| 124 | + if self._writer is not None: |
| 125 | + # previous setup did not properly finish, |
| 126 | + # or we are re-initializing the connection. |
| 127 | + logger.warning("Socket already connected. Closing previous connection.") |
| 128 | + await self.stop() |
| 129 | + |
| 130 | + logger.info("Connecting to socket %s:%s...", self.host, self.port) |
| 131 | + |
| 132 | + try: |
| 133 | + self._reader, self._writer = await asyncio.open_connection(self.host, self.port) |
| 134 | + logger.info("Connected to socket %s:%s", self.host, self.port) |
| 135 | + except Exception as e: |
| 136 | + raise ConnectionError(f"Failed to connect to {self.host}:{self.port}: {e}") from e |
| 137 | + |
| 138 | + async def stop(self): |
| 139 | + """Close the socket connection.""" |
| 140 | + |
| 141 | + if self._writer is None: |
| 142 | + logger.debug("Socket already disconnected.") |
| 143 | + return |
| 144 | + |
| 145 | + logger.info("Closing connection to socket %s:%s", self.host, self.port) |
| 146 | + |
| 147 | + try: |
| 148 | + self._writer.close() |
| 149 | + await self._writer.wait_closed() |
| 150 | + except OSError as e: |
| 151 | + logger.warning("Error while closing socket connection: %s", e) |
| 152 | + finally: |
| 153 | + self._reader = None |
| 154 | + self._writer = None |
| 155 | + |
| 156 | + def serialize(self) -> dict: |
| 157 | + """Serialize the socket to a dictionary.""" |
| 158 | + |
| 159 | + return { |
| 160 | + **super().serialize(), |
| 161 | + "host": self.host, |
| 162 | + "port": self.port, |
| 163 | + "read_timeout": self.read_timeout, |
| 164 | + "write_timeout": self.write_timeout, |
| 165 | + } |
| 166 | + |
| 167 | + |
| 168 | +class SocketValidator(Socket): |
| 169 | + """Socket validator for testing/validation purposes.""" |
| 170 | + |
| 171 | + def __init__( |
| 172 | + self, |
| 173 | + cr: "CaptureReader", |
| 174 | + host: str, |
| 175 | + port: int, |
| 176 | + read_timeout: int = 30, |
| 177 | + write_timeout: int = 30, |
| 178 | + ): |
| 179 | + super().__init__( |
| 180 | + host=host, |
| 181 | + port=port, |
| 182 | + read_timeout=read_timeout, |
| 183 | + write_timeout=write_timeout, |
| 184 | + ) |
| 185 | + self.cr = cr |
| 186 | + |
| 187 | + async def setup(self): |
| 188 | + """Mock setup for validation.""" |
| 189 | + return |
| 190 | + |
| 191 | + async def write(self, *args, **kwargs): |
| 192 | + """Validate write command against captured data.""" |
| 193 | + if not args: |
| 194 | + raise ValueError("No data provided to write") |
| 195 | + |
| 196 | + data = args[0] |
| 197 | + next_command = SocketCommand(**self.cr.next_command()) |
| 198 | + if not ( |
| 199 | + next_command.module == "socket" |
| 200 | + and next_command.device_id == self._unique_id |
| 201 | + and next_command.action == "write" |
| 202 | + ): |
| 203 | + raise ValidationError( |
| 204 | + f"Expected socket write command to {self._unique_id}, " |
| 205 | + f"got {next_command.module} {next_command.action} to {next_command.device_id}" |
| 206 | + ) |
| 207 | + if not next_command.data == data: |
| 208 | + raise ValidationError( |
| 209 | + f"Socket write data mismatch. Expected:\n{next_command.data}\nGot:\n{data}" |
| 210 | + ) |
| 211 | + |
| 212 | + async def read(self, *args, **kwargs) -> str: |
| 213 | + """Return captured read data for validation.""" |
| 214 | + next_command = SocketCommand(**self.cr.next_command()) |
| 215 | + if not ( |
| 216 | + next_command.module == "socket" |
| 217 | + and next_command.device_id == self._unique_id |
| 218 | + and next_command.action == "read" |
| 219 | + ): |
| 220 | + raise ValidationError( |
| 221 | + f"Expected socket read command from {self._unique_id}, " |
| 222 | + f"got {next_command.module} {next_command.action} from {next_command.device_id}" |
| 223 | + ) |
| 224 | + return next_command.data |
| 225 | + |
| 226 | + async def stop(self): |
| 227 | + """Mock stop for validation.""" |
| 228 | + return |
0 commit comments