|
7 | 7 | from typing import Any, Callable, Union |
8 | 8 |
|
9 | 9 | import cbor2 |
| 10 | +import pandas |
10 | 11 | import pyarrow |
11 | 12 | import websockets.exceptions |
12 | 13 | import websockets.protocol |
@@ -120,61 +121,70 @@ def __listen(self) -> None: |
120 | 121 | ) |
121 | 122 | return |
122 | 123 |
|
123 | | - if kind == EventKind.STATE_UPDATED: |
| 124 | + # Incoming state transitions are handled here. |
| 125 | + if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT: |
124 | 126 | try: |
125 | 127 | query.state = ExecutionState[message["state"].upper()] |
126 | 128 | logging.info("Query %s is now %s.", execution_id, query.state) |
127 | 129 | except KeyError: |
128 | 130 | logging.warning("Invalid state update message for %s", execution_id) |
129 | 131 | return |
130 | 132 |
|
131 | | - # Incoming state transitions are handled here. |
132 | 133 | if query.state == ExecutionState.SUCCEEDED: |
133 | | - self.__request_results(execution_id) |
| 134 | + # On a state_updated event telling us the query succeeded, |
| 135 | + # ask for results. |
| 136 | + if kind == EventKind.STATE_UPDATED: |
| 137 | + self.__request_results(execution_id) |
| 138 | + return |
| 139 | + |
| 140 | + # Otherwise, process the results from the execution_result event. |
| 141 | + results = message.get("results") |
| 142 | + if not results or not isinstance(results, dict): |
| 143 | + logging.warning("Got no results back from %s.", execution_id) |
| 144 | + return |
| 145 | + |
| 146 | + query.state = ExecutionState.COMPLETED |
| 147 | + query.handler(self._handle_results(execution_id, results)) |
134 | 148 | elif query.state == ExecutionState.CANCELLED: |
135 | | - logging.info("Query %s has been cancelled.", execution_id) |
| 149 | + logging.info( |
| 150 | + "Query %s has been cancelled; returning empty results.", |
| 151 | + execution_id, |
| 152 | + ) |
| 153 | + query.handler(pandas.DataFrame()) |
136 | 154 | self.__queries.pop(execution_id) |
137 | 155 | elif query.state == ExecutionState.FAILED: |
138 | 156 | # Don't do anything here; the ERROR event is coming with more |
139 | 157 | # details. |
140 | 158 | pass |
141 | | - |
142 | | - elif kind == EventKind.EXECUTION_RESULT: |
143 | | - results = message.get("results") |
144 | | - if not results or not isinstance(results, dict): |
145 | | - logging.warning("Got no results back from %s.", execution_id) |
146 | | - return |
147 | | - |
148 | | - result_bytes = results.get("result_bytes") |
149 | | - result_format = results.get("format") |
150 | | - result_compression = results.get("compression") |
151 | | - logging.info( |
152 | | - "Received %d bytes of %s-compressed %s results from %s.", |
153 | | - len(result_bytes), |
154 | | - result_compression, |
155 | | - result_format, |
156 | | - execution_id, |
157 | | - ) |
158 | | - |
159 | | - query.state = ExecutionState.COMPLETED |
160 | | - if result_format == ResultsFormat.JSON: |
161 | | - query.handler(json.loads(result_bytes.decode("utf-8"))) |
162 | | - elif result_format == ResultsFormat.ARROW: |
163 | | - buffer = pyarrow.py_buffer(result_bytes) |
164 | | - stream = pyarrow.input_stream(buffer, result_compression) |
165 | | - with pyarrow.ipc.open_stream(stream) as reader: |
166 | | - query.handler(reader.read_pandas()) |
167 | | - else: |
168 | | - query.handler( |
169 | | - OperationalError(f"Unsupported results format {result_format}") |
170 | | - ) |
171 | 159 | elif kind == EventKind.ERROR: |
172 | 160 | query.state = ExecutionState.FAILED |
173 | 161 | error = message.get("message") |
174 | 162 | query.handler(OperationalError(error)) |
175 | 163 | else: |
176 | 164 | logging.warning("Received unknown %s event!", kind) |
177 | 165 |
|
def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any:
    """Decode the payload of an execution_result event.

    Args:
        execution_id: Identifier of the query; used for log and error
            messages only.
        results: The ``results`` mapping taken from the event. The keys
            read here are ``result_bytes``, ``format`` and ``compression``.

    Returns:
        The decoded JSON value for JSON results, the object produced by
        ``reader.read_pandas()`` for Arrow results, or an
        ``OperationalError`` instance (returned, not raised) when the
        payload carries no bytes or uses an unsupported format.
    """
    result_bytes = results.get("result_bytes")
    result_format = results.get("format")
    result_compression = results.get("compression")

    # Without this guard, len(None) below would raise TypeError out of the
    # caller instead of reporting the problem through the returned value.
    if result_bytes is None:
        logging.warning("Got no result bytes back from %s.", execution_id)
        return OperationalError(f"No result bytes received for {execution_id}")

    logging.info(
        "Received %d bytes of %s-compressed %s results from %s.",
        len(result_bytes),
        result_compression,
        result_format,
        execution_id,
    )

    if result_format == ResultsFormat.JSON:
        return json.loads(result_bytes.decode("utf-8"))

    if result_format == ResultsFormat.ARROW:
        # Wrap the raw bytes without copying, then let pyarrow handle the
        # (optional) decompression while reading the IPC stream.
        buffer = pyarrow.py_buffer(result_bytes)
        stream = pyarrow.input_stream(buffer, result_compression)
        with pyarrow.ipc.open_stream(stream) as reader:
            return reader.read_pandas()

    # Errors are handed back as values so the caller can pass them on
    # the same way it passes on successful results.
    return OperationalError(f"Unsupported results format {result_format}")
| 187 | + |
178 | 188 | def __send(self, message: dict[str, Any]) -> None: |
179 | 189 | request = json.dumps(message) |
180 | 190 | logging.debug("Request: %s", request) |
|
0 commit comments