|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import threading |
3 | | -import time |
| 4 | +import traceback |
| 5 | +from typing import Any, Dict |
4 | 6 |
|
5 | | -from core import OpenInterpreter |
| 7 | +from .core import OpenInterpreter |
6 | 8 |
|
| 9 | +try: |
| 10 | + import janus |
| 11 | + import uvicorn |
| 12 | + from fastapi import APIRouter, FastAPI, WebSocket |
| 13 | +except: |
| 14 | + # Server dependencies are not required by the main package. |
| 15 | + pass |
7 | 16 |
|
8 | | -class AsyncOpenInterpreter(OpenInterpreter): |
| 17 | + |
| 18 | +class AsyncInterpreter(OpenInterpreter): |
9 | 19 | def __init__(self, *args, **kwargs): |
10 | 20 | super().__init__(*args, **kwargs) |
11 | | - self.async_thread = None |
12 | | - self.input_queue |
13 | | - self.output_queue |
| 21 | + |
| 22 | + self.respond_thread = None |
| 23 | + self.stop_event = threading.Event() |
| 24 | + self.output_queue = None |
| 25 | + |
| 26 | + self.server = Server(self) |
14 | 27 |
|
15 | 28 | async def input(self, chunk): |
16 | 29 | """ |
17 | | - Expects a chunk in streaming LMC format. |
| 30 | + Accumulates LMC chunks onto interpreter.messages. |
| 31 | + When it hits an "end" flag, calls interpreter.respond(). |
18 | 32 | """ |
19 | | - try: |
20 | | - chunk = json.loads(chunk) |
21 | | - except: |
22 | | - pass |
23 | 33 |
|
24 | 34 | if "start" in chunk: |
25 | | - self.async_thread.join() |
| 35 | + # If the user is starting something, the interpreter should stop. |
| 36 | + if self.respond_thread is not None and self.respond_thread.is_alive(): |
| 37 | + self.stop_event.set() |
| 38 | + self.respond_thread.join() |
| 39 | + self.accumulate(chunk) |
| 40 | + elif "content" in chunk: |
| 41 | + self.accumulate(chunk) |
26 | 42 | elif "end" in chunk: |
27 | | - if self.async_thread is None or not self.async_thread.is_alive(): |
28 | | - self.async_thread = threading.Thread(target=self.complete) |
29 | | - self.async_thread.start() |
30 | | - else: |
31 | | - await self._add_to_queue(self._input_queue, chunk) |
32 | | - |
33 | | - async def output(self, *args, **kwargs): |
34 | | - # Your async output code here |
35 | | - pass |
| 43 | + # If the user is done talking, the interpreter should respond. |
| 44 | + self.stop_event.clear() |
| 45 | + print("Responding.") |
| 46 | + self.respond_thread = threading.Thread(target=self.respond) |
| 47 | + self.respond_thread.start() |
| 48 | + |
| 49 | + async def output(self): |
| 50 | + if self.output_queue == None: |
| 51 | + self.output_queue = janus.Queue() |
| 52 | + return await self.output_queue.async_q.get() |
| 53 | + |
| 54 | + def respond(self): |
| 55 | + for chunk in self._respond_and_store(): |
| 56 | + print(chunk.get("content", ""), end="") |
| 57 | + if self.stop_event.is_set(): |
| 58 | + return |
| 59 | + self.output_queue.sync_q.put(chunk) |
| 60 | + |
| 61 | + self.output_queue.sync_q.put( |
| 62 | + {"role": "server", "type": "status", "content": "complete"} |
| 63 | + ) |
| 64 | + |
| 65 | + def accumulate(self, chunk): |
| 66 | + """ |
| 67 | + Accumulates LMC chunks onto interpreter.messages. |
| 68 | + """ |
| 69 | + if type(chunk) == dict: |
| 70 | + if chunk.get("format") == "active_line": |
| 71 | + # We don't do anything with these. |
| 72 | + pass |
| 73 | + |
| 74 | + elif "start" in chunk: |
| 75 | + chunk_copy = ( |
| 76 | + chunk.copy() |
| 77 | + ) # So we don't modify the original chunk, which feels wrong. |
| 78 | + chunk_copy.pop("start") |
| 79 | + chunk_copy["content"] = "" |
| 80 | + self.messages.append(chunk_copy) |
| 81 | + |
| 82 | + elif "content" in chunk: |
| 83 | + self.messages[-1]["content"] += chunk["content"] |
| 84 | + |
| 85 | + elif type(chunk) == bytes: |
| 86 | + if self.messages[-1]["content"] == "": # We initialize as an empty string ^ |
| 87 | + self.messages[-1]["content"] = b"" # But it actually should be bytes |
| 88 | + self.messages[-1]["content"] += chunk |
| 89 | + |
| 90 | + |
| 91 | +def create_router(async_interpreter): |
| 92 | + router = APIRouter() |
| 93 | + |
| 94 | + @router.get("/heartbeat") |
| 95 | + async def heartbeat(): |
| 96 | + return {"status": "alive"} |
| 97 | + |
| 98 | + @router.websocket("/") |
| 99 | + async def websocket_endpoint(websocket: WebSocket): |
| 100 | + await websocket.accept() |
| 101 | + try: |
| 102 | + |
| 103 | + async def receive_input(): |
| 104 | + while True: |
| 105 | + try: |
| 106 | + data = await websocket.receive() |
| 107 | + |
| 108 | + if data.get("type") == "websocket.receive" and "text" in data: |
| 109 | + data = json.loads(data["text"]) |
| 110 | + await async_interpreter.input(data) |
| 111 | + elif ( |
| 112 | + data.get("type") == "websocket.disconnect" |
| 113 | + and data.get("code") == 1000 |
| 114 | + ): |
| 115 | + print("Disconnecting.") |
| 116 | + return |
| 117 | + else: |
| 118 | + print("Invalid data:", data) |
| 119 | + continue |
| 120 | + |
| 121 | + except Exception as e: |
| 122 | + error_message = { |
| 123 | + "role": "server", |
| 124 | + "type": "error", |
| 125 | + "content": traceback.format_exc() + "\n" + str(e), |
| 126 | + } |
| 127 | + await websocket.send_text(json.dumps(error_message)) |
| 128 | + |
| 129 | + async def send_output(): |
| 130 | + while True: |
| 131 | + try: |
| 132 | + output = await async_interpreter.output() |
| 133 | + |
| 134 | + if isinstance(output, bytes): |
| 135 | + await websocket.send_bytes(output) |
| 136 | + else: |
| 137 | + await websocket.send_text(json.dumps(output)) |
| 138 | + except Exception as e: |
| 139 | + traceback.print_exc() |
| 140 | + error_message = { |
| 141 | + "role": "server", |
| 142 | + "type": "error", |
| 143 | + "content": traceback.format_exc() + "\n" + str(e), |
| 144 | + } |
| 145 | + await websocket.send_text(json.dumps(error_message)) |
| 146 | + |
| 147 | + await asyncio.gather(receive_input(), send_output()) |
| 148 | + except Exception as e: |
| 149 | + traceback.print_exc() |
| 150 | + try: |
| 151 | + error_message = { |
| 152 | + "role": "server", |
| 153 | + "type": "error", |
| 154 | + "content": traceback.format_exc() + "\n" + str(e), |
| 155 | + } |
| 156 | + await websocket.send_text(json.dumps(error_message)) |
| 157 | + except: |
| 158 | + # If we can't send it, that's fine. |
| 159 | + pass |
| 160 | + finally: |
| 161 | + await websocket.close() |
| 162 | + |
| 163 | + @router.post("/settings") |
| 164 | + async def settings(payload: Dict[str, Any]): |
| 165 | + for key, value in payload.items(): |
| 166 | + print(f"Updating settings: {key} = {value}") |
| 167 | + if key in ["llm", "computer"] and isinstance(value, dict): |
| 168 | + for sub_key, sub_value in value.items(): |
| 169 | + setattr(getattr(async_interpreter, key), sub_key, sub_value) |
| 170 | + else: |
| 171 | + setattr(async_interpreter, key, value) |
| 172 | + |
| 173 | + return {"status": "success"} |
| 174 | + |
| 175 | + return router |
| 176 | + |
| 177 | + |
| 178 | +class Server: |
| 179 | + def __init__(self, async_interpreter, host="0.0.0.0", port=8000): |
| 180 | + self.app = FastAPI() |
| 181 | + router = create_router(async_interpreter) |
| 182 | + self.app.include_router(router) |
| 183 | + self.host = host |
| 184 | + self.port = port |
| 185 | + self.uvicorn_server = uvicorn.Server( |
| 186 | + config=uvicorn.Config(app=self.app, host=self.host, port=self.port) |
| 187 | + ) |
| 188 | + |
| 189 | + def run(self): |
| 190 | + uvicorn.run(self.app, host=self.host, port=self.port) |
0 commit comments