Skip to content

Commit 84a9f98

Browse files
committed
add flask render server
also factor out assumption of asyncio from base render server
1 parent 69e88c9 commit 84a9f98

File tree

7 files changed

+398
-137
lines changed

7 files changed

+398
-137
lines changed

idom/server/base.py

Lines changed: 31 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
import abc
2-
from asyncio import AbstractEventLoop, new_event_loop, set_event_loop, get_event_loop
3-
from typing import TypeVar, Dict, Any, Tuple, Type, Optional, Generic, TypeVar
2+
from typing import TypeVar, Dict, Any, Tuple, Optional, Generic, TypeVar
43
from threading import Thread, Event
54

65
from idom.core.element import ElementConstructor
7-
from idom.core.layout import Layout, Layout
8-
from idom.core.dispatcher import (
9-
AbstractDispatcher,
10-
SendCoroutine,
11-
RecvCoroutine,
12-
)
136

147

158
_App = TypeVar("_App", bound=Any)
@@ -30,26 +23,16 @@ class AbstractRenderServer(Generic[_App, _Config]):
3023
:meth:`AbstractServerExtension.register`
3124
"""
3225

33-
_loop: AbstractEventLoop
34-
_dispatcher_type: Type[AbstractDispatcher]
35-
_layout_type: Type[Layout] = Layout
36-
_daemon_server_did_start: Event
37-
3826
def __init__(
3927
self,
4028
constructor: ElementConstructor,
4129
config: Optional[_Config] = None,
4230
) -> None:
4331
self._app: Optional[_App] = None
44-
self._make_root_element = constructor
32+
self._root_element_constructor = constructor
4533
self._daemonized = False
46-
self._config = self._init_config()
47-
if config is not None:
48-
self._config = self._update_config(self._config, config)
49-
50-
@property
51-
def loop(self) -> AbstractEventLoop:
52-
return self._loop
34+
self._config = self._create_config(config)
35+
self._server_did_start = Event()
5336

5437
@property
5538
def application(self) -> _App:
@@ -59,45 +42,45 @@ def application(self) -> _App:
5942

6043
def run(self, *args: Any, **kwargs: Any) -> None:
6144
"""Run as a standalone application."""
62-
self._loop = get_event_loop()
6345
if self._app is None:
6446
app = self._default_application(self._config)
6547
self.register(app)
6648
else:
6749
app = self._app
68-
return self._run_application(app, self._config, args, kwargs)
50+
if not self._daemonized:
51+
return self._run_application(app, self._config, args, kwargs)
52+
else:
53+
return self._run_application_in_thread(app, self._config, args, kwargs)
6954

7055
def daemon(self, *args: Any, **kwargs: Any) -> Thread:
7156
"""Run the standalone application in a seperate thread."""
7257
self._daemonized = True
7358

74-
def run_in_thread() -> None:
75-
set_event_loop(new_event_loop())
76-
return self.run(*args, **kwargs)
77-
78-
thread = Thread(target=run_in_thread, daemon=True)
59+
thread = Thread(target=lambda: self.run(*args, **kwargs), daemon=True)
7960
thread.start()
8061

81-
self._wait_until_daemon_server_start()
62+
self.wait_until_server_start()
8263

8364
return thread
8465

8566
def register(self: _Self, app: Optional[_App]) -> _Self:
8667
"""Register this as an extension."""
8768
self._setup_application(app, self._config)
69+
self._setup_application_did_start_event(app, self._server_did_start)
8870
self._app = app
8971
return self
9072

91-
def stop(self) -> None:
92-
"""Stop the running application"""
93-
self.loop.call_soon_threadsafe(self._stop)
73+
def server_started(self) -> bool:
74+
"""Whether the underlying application has started"""
75+
return self._server_did_start.set()
9476

95-
@abc.abstractmethod
96-
def _stop(self) -> None:
97-
raise NotImplementedError()
77+
def wait_until_server_start(self, timeout: float = 3.0):
78+
"""Block until the underlying application has started"""
79+
if not self._server_did_start.wait(timeout=timeout):
80+
raise RuntimeError(f"Server did not start within {timeout} seconds")
9881

9982
@abc.abstractmethod
100-
def _init_config(self) -> _Config:
83+
def _create_config(self, config: Optional[_Config]) -> _Config:
10184
"""Return the default configuration options."""
10285

10386
@abc.abstractmethod
@@ -107,49 +90,24 @@ def _default_application(self, config: _Config) -> _App:
10790

10891
@abc.abstractmethod
10992
def _setup_application(self, app: _App, config: _Config) -> None:
110-
...
93+
"""General application setup - add routes, templates, static resource, etc."""
94+
raise NotImplementedError()
95+
96+
@abc.abstractmethod
97+
def _setup_application_did_start_event(self, app: _App, event: Event) -> None:
98+
"""Register a callback to the app indicating whether the server has started"""
99+
raise NotImplementedError()
111100

112101
@abc.abstractmethod
113102
def _run_application(
114103
self, app: _App, config: _Config, args: Tuple[Any, ...], kwargs: Dict[str, Any]
115104
) -> None:
105+
"""Run the application in the main thread"""
116106
raise NotImplementedError()
117107

118108
@abc.abstractmethod
119-
def _update_config(self, old: _Config, new: _Config) -> _Config: # pragma: no cover
120-
"""Return the new configuration options
121-
122-
Parameters:
123-
old: The existing configuration options
124-
new: The new configuration options
125-
"""
126-
raise NotImplementedError()
127-
128-
async def _run_dispatcher(
129-
self,
130-
send: SendCoroutine,
131-
recv: RecvCoroutine,
132-
params: Dict[str, Any],
109+
def _run_application_in_thread(
110+
self, app: _App, config: _Config, args: Tuple[Any, ...], kwargs: Dict[str, Any]
133111
) -> None:
134-
async with self._make_dispatcher(params) as dispatcher:
135-
await dispatcher.run(send, recv, None)
136-
137-
def _make_dispatcher(
138-
self,
139-
params: Dict[str, Any],
140-
) -> AbstractDispatcher:
141-
return self._dispatcher_type(self._make_layout(params))
142-
143-
def _make_layout(
144-
self,
145-
params: Dict[str, Any],
146-
) -> Layout:
147-
return self._layout_type(self._make_root_element(**params))
148-
149-
def _wait_until_daemon_server_start(self):
150-
try:
151-
self._daemon_server_did_start.wait(timeout=5)
152-
except AttributeError: # pragma: no cover
153-
raise NotImplementedError(
154-
f"Server implementation {self} did not define a server started thread event"
155-
)
112+
"""This function has been called inside a daemon thread to run the application"""
113+
raise NotImplementedError()

idom/server/flask.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import json
2+
import asyncio
3+
import logging
4+
from urllib.parse import urljoin
5+
from asyncio import Queue as AsyncQueue
6+
from threading import Event as ThreadEvent, Thread
7+
from queue import Queue as ThreadQueue
8+
from typing import Union, Tuple, Dict, Any, Optional, Callable, NamedTuple
9+
10+
from typing_extensions import TypedDict
11+
from flask import Flask, Blueprint, send_from_directory, redirect, url_for
12+
from flask_cors import CORS
13+
from flask_sockets import Sockets
14+
from geventwebsocket.websocket import WebSocket
15+
from gevent import pywsgi
16+
from geventwebsocket.handler import WebSocketHandler
17+
18+
import idom
19+
from idom.client.manage import BUILD_DIR
20+
from idom.core.layout import LayoutEvent, Layout
21+
from idom.core.dispatcher import AbstractDispatcher, SingleViewDispatcher
22+
23+
from .base import AbstractRenderServer
24+
25+
26+
class Config(TypedDict, total=False):
27+
import_name: str
28+
url_prefix: str
29+
cors: Union[bool, Dict[str, Any]]
30+
serve_static_files: bool
31+
redirect_root_to_index: bool
32+
33+
34+
class FlaskRenderServer(AbstractRenderServer[Flask, Config]):
35+
"""Base class for render servers which use Flask"""
36+
37+
_dispatcher_type: AbstractDispatcher
38+
39+
def _create_config(self, config: Optional[Config]) -> Config:
40+
return Config(
41+
{
42+
"import_name": __name__,
43+
"url_prefix": "",
44+
"cors": False,
45+
"serve_static_files": True,
46+
"redirect_root_to_index": True,
47+
**(config or {}),
48+
}
49+
)
50+
51+
def _default_application(self, config: Config) -> Flask:
52+
return Flask(config["import_name"])
53+
54+
def _setup_application(self, app: Flask, config: Config) -> None:
55+
bp = Blueprint("idom", __name__, url_prefix=config["url_prefix"])
56+
57+
self._setup_blueprint_routes(bp, config)
58+
59+
cors_config = config["cors"]
60+
if cors_config:
61+
cors_params = cors_config if isinstance(cors_config, dict) else {}
62+
CORS(bp, **cors_params)
63+
64+
app.register_blueprint(bp)
65+
66+
sockets = Sockets(app)
67+
68+
@sockets.route(urljoin(config["url_prefix"], "/stream"))
69+
def model_stream(ws: WebSocket) -> None:
70+
def send(value: Any) -> None:
71+
ws.send(json.dumps(value))
72+
73+
def recv() -> LayoutEvent:
74+
event = ws.receive()
75+
if event is not None:
76+
return LayoutEvent(**json.loads(event))
77+
else:
78+
return None
79+
80+
run_dispatcher_in_thread(
81+
lambda: self._dispatcher_type(Layout(self._root_element_constructor())),
82+
send,
83+
recv,
84+
None,
85+
)
86+
87+
def _setup_blueprint_routes(self, blueprint: Blueprint, config: Config) -> None:
88+
if config["serve_static_files"]:
89+
90+
@blueprint.route("/client/<path:path>")
91+
def send_build_dir(path):
92+
return send_from_directory(str(BUILD_DIR), path)
93+
94+
if config["redirect_root_to_index"]:
95+
96+
@blueprint.route("/")
97+
def redirect_to_index():
98+
return redirect(url_for("idom.send_build_dir", path="index.html"))
99+
100+
def _setup_application_did_start_event(
101+
self, app: Flask, event: ThreadEvent
102+
) -> None:
103+
@app.before_first_request
104+
def server_did_start():
105+
event.set()
106+
107+
def _run_application(
108+
self, app: Flask, config: Config, args: Tuple[Any, ...], kwargs: Dict[str, Any]
109+
) -> None:
110+
self._generic_run_application(app, *args, **kwargs)
111+
112+
def _run_application_in_thread(
113+
self, app: Flask, config: Config, args: Tuple[Any, ...], kwargs: Dict[str, Any]
114+
) -> None:
115+
self._generic_run_application(app, *args, **kwargs)
116+
117+
def _generic_run_application(
118+
self,
119+
app: Flask,
120+
host: str = "",
121+
port: int = 5000,
122+
debug: bool = False,
123+
*args: Any,
124+
**kwargs
125+
):
126+
if debug:
127+
logging.basicConfig(level=logging.DEBUG)
128+
logging.debug("Starting server...")
129+
_StartCallbackWSGIServer(
130+
self._server_did_start.set,
131+
(host, port),
132+
app,
133+
*args,
134+
handler_class=WebSocketHandler,
135+
**kwargs,
136+
).serve_forever()
137+
138+
139+
class PerClientStateServer(FlaskRenderServer):
140+
_dispatcher_type = SingleViewDispatcher
141+
142+
143+
def run_dispatcher_in_thread(
144+
make_dispatcher: Callable[[], AbstractDispatcher],
145+
send: Callable[[Any], None],
146+
recv: Callable[[], Optional[LayoutEvent]],
147+
context: Optional[Any],
148+
) -> None:
149+
dispatch_thread_info_created = ThreadEvent()
150+
dispatch_thread_info_ref: idom.Ref[Optional[_DispatcherThreadInfo]] = idom.Ref(None)
151+
152+
def run_dispatcher():
153+
loop = asyncio.new_event_loop()
154+
asyncio.set_event_loop(loop)
155+
156+
thread_send_queue = ThreadQueue()
157+
async_recv_queue = AsyncQueue()
158+
159+
async def send_coro(value: Any) -> None:
160+
thread_send_queue.put(value)
161+
162+
async def recv_coro() -> Any:
163+
return await async_recv_queue.get()
164+
165+
async def main():
166+
async with make_dispatcher() as dispatcher:
167+
await dispatcher.run(send_coro, recv_coro, context)
168+
169+
main_future = asyncio.ensure_future(main())
170+
171+
dispatch_thread_info_ref.current = _DispatcherThreadInfo(
172+
dispatch_loop=loop,
173+
dispatch_future=main_future,
174+
thread_send_queue=thread_send_queue,
175+
async_recv_queue=async_recv_queue,
176+
)
177+
dispatch_thread_info_created.set()
178+
179+
loop.run_until_complete(main_future)
180+
181+
Thread(target=run_dispatcher, daemon=True).start()
182+
dispatch_thread_info_created.wait()
183+
184+
dispatch_thread_info = dispatch_thread_info_ref.current
185+
assert dispatch_thread_info is not None
186+
187+
stop = ThreadEvent()
188+
189+
def run_send():
190+
while not stop.is_set():
191+
send(dispatch_thread_info.thread_send_queue.get())
192+
193+
Thread(target=run_send, daemon=True).start()
194+
195+
try:
196+
while True:
197+
value = recv()
198+
if value is None:
199+
stop.set()
200+
break
201+
dispatch_thread_info.dispatch_loop.call_soon_threadsafe(
202+
dispatch_thread_info.async_recv_queue.put_nowait, value
203+
)
204+
finally:
205+
dispatch_thread_info.dispatch_loop.call_soon_threadsafe(
206+
dispatch_thread_info.dispatch_future.cancel
207+
)
208+
209+
return None
210+
211+
212+
class _DispatcherThreadInfo(NamedTuple):
213+
dispatch_loop: asyncio.AbstractEventLoop
214+
dispatch_future: asyncio.Future
215+
thread_send_queue: ThreadQueue
216+
async_recv_queue: AsyncQueue
217+
218+
219+
class _StartCallbackWSGIServer(pywsgi.WSGIServer):
220+
def __init__(self, before_first_request: Callable[[], None], *args, **kwargs):
221+
self._before_first_request_callback = before_first_request
222+
super().__init__(*args, **kwargs)
223+
224+
def init_socket(self):
225+
self._before_first_request_callback()
226+
return super().init_socket()

0 commit comments

Comments
 (0)