Skip to content

Commit eb1c74c

Browse files
committed
add session reconnection to fastapi
1 parent 953f836 commit eb1c74c

File tree

1 file changed

+54
-51
lines changed

1 file changed

+54
-51
lines changed

pywebio/platform/fastapi.py

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,57 @@
1-
import os
21
import asyncio
3-
import json
42
import logging
3+
import os
4+
import typing
55
from functools import partial
66

77
import uvicorn
88
from starlette.applications import Starlette
99
from starlette.requests import Request
1010
from starlette.responses import HTMLResponse
1111
from starlette.routing import Route, WebSocketRoute, Mount
12-
from starlette.websockets import WebSocket
12+
from starlette.websockets import WebSocket, WebSocketState
1313
from starlette.websockets import WebSocketDisconnect
1414

15+
from .page import make_applications, render_page
1516
from .remote_access import start_remote_access_service
1617
from .tornado import open_webbrowser_on_server_started
17-
from .page import make_applications, render_page
18-
from .utils import cdn_validation, OriginChecker, deserialize_binary_event, print_listen_address
19-
from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
18+
from .utils import cdn_validation, OriginChecker, print_listen_address
19+
from ..session import register_session_implement_for_target, Session
2020
from ..session.base import get_session_info_from_headers
21-
from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction, strip_space
21+
from ..utils import get_free_port, STATIC_PATH, strip_space
2222

2323
logger = logging.getLogger(__name__)
24+
from .adaptor import ws as ws_adaptor
25+
26+
27+
class WebSocketConnection(ws_adaptor.WebSocketConnection):
28+
29+
def __init__(self, websocket: WebSocket, ioloop):
30+
self.ws = websocket
31+
self.ioloop = ioloop
32+
33+
def get_query_argument(self, name) -> typing.Optional[str]:
34+
return self.ws.query_params.get(name, None)
35+
36+
def make_session_info(self) -> dict:
37+
session_info = get_session_info_from_headers(self.ws.headers)
38+
session_info['user_ip'] = self.ws.client.host or ''
39+
session_info['request'] = self.ws
40+
session_info['backend'] = 'starlette'
41+
session_info['protocol'] = 'websocket'
42+
return session_info
2443

44+
def write_message(self, message: dict):
45+
self.ioloop.create_task(self.ws.send_json(message))
2546

26-
def _webio_routes(applications, cdn, check_origin_func):
47+
def closed(self) -> bool:
48+
return self.ws.application_state == WebSocketState.DISCONNECTED
49+
50+
def close(self):
51+
self.ioloop.create_task(self.ws.close())
52+
53+
54+
def _webio_routes(applications, cdn, check_origin_func, reconnect_timeout):
2755
"""
2856
:param dict applications: dict of `name -> task function`
2957
:param bool/str cdn: Whether to load front-end static resources from CDN
@@ -49,64 +77,35 @@ async def websocket_endpoint(websocket: WebSocket):
4977
ioloop = asyncio.get_event_loop()
5078
await websocket.accept()
5179

52-
close_from_session_tag = False # session close causes websocket close
53-
54-
def send_msg_to_client(session: Session):
55-
for msg in session.get_task_commands():
56-
ioloop.create_task(websocket.send_json(msg))
57-
58-
def close_from_session():
59-
nonlocal close_from_session_tag
60-
close_from_session_tag = True
61-
ioloop.create_task(websocket.close())
62-
logger.debug("WebSocket closed from session")
63-
64-
session_info = get_session_info_from_headers(websocket.headers)
65-
session_info['user_ip'] = websocket.client.host or ''
66-
session_info['request'] = websocket
67-
session_info['backend'] = 'starlette'
68-
session_info['protocol'] = 'websocket'
69-
7080
app_name = websocket.query_params.get('app', 'index')
7181
application = applications.get(app_name) or applications['index']
7282

73-
if iscoroutinefunction(application) or isgeneratorfunction(application):
74-
session = CoroutineBasedSession(application, session_info=session_info,
75-
on_task_command=send_msg_to_client,
76-
on_session_close=close_from_session)
77-
else:
78-
session = ThreadBasedSession(application, session_info=session_info,
79-
on_task_command=send_msg_to_client,
80-
on_session_close=close_from_session, loop=ioloop)
83+
conn = WebSocketConnection(websocket, ioloop)
84+
handler = ws_adaptor.WebSocketHandler(
85+
connection=conn, application=application, reconnectable=bool(reconnect_timeout), ioloop=ioloop
86+
)
8187

8288
while True:
8389
try:
8490
msg = await websocket.receive()
8591
if msg["type"] == "websocket.disconnect":
8692
raise WebSocketDisconnect(msg["code"])
8793
text, binary = msg.get('text'), msg.get('bytes')
88-
event = None
8994
if text:
90-
event = json.loads(text)
91-
elif binary:
92-
event = deserialize_binary_event(binary)
95+
handler.send_client_data(text)
96+
if binary:
97+
handler.send_client_data(text)
9398
except WebSocketDisconnect:
94-
if not close_from_session_tag:
95-
# close session because client disconnected to server
96-
session.close(nonblock=True)
97-
logger.debug("WebSocket closed from client")
99+
handler.notify_connection_lost()
98100
break
99101

100-
if event is not None:
101-
session.send_client_event(event)
102-
103102
return [
104103
Route("/", http_endpoint),
105104
WebSocketRoute("/", websocket_endpoint)
106105
]
107106

108107

109-
def webio_routes(applications, cdn=True, allowed_origins=None, check_origin=None):
108+
def webio_routes(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None):
110109
"""Get the FastAPI/Starlette routes for running PyWebIO applications.
111110
112111
The API communicates with the browser using WebSocket protocol.
@@ -137,10 +136,11 @@ def webio_routes(applications, cdn=True, allowed_origins=None, check_origin=None
137136
else:
138137
check_origin_func = lambda origin, host: OriginChecker.is_same_site(origin, host) or check_origin(origin)
139138

140-
return _webio_routes(applications=applications, cdn=cdn, check_origin_func=check_origin_func)
139+
return _webio_routes(applications=applications, cdn=cdn, check_origin_func=check_origin_func,
140+
reconnect_timeout=reconnect_timeout)
141141

142142

143-
def start_server(applications, port=0, host='', cdn=True,
143+
def start_server(applications, port=0, host='', cdn=True, reconnect_timeout=0,
144144
static_dir=None, remote_access=False, debug=False,
145145
allowed_origins=None, check_origin=None,
146146
auto_open_webbrowser=False,
@@ -156,7 +156,8 @@ def start_server(applications, port=0, host='', cdn=True,
156156
.. versionadded:: 1.3
157157
"""
158158

159-
app = asgi_app(applications, cdn=cdn, static_dir=static_dir, debug=debug,
159+
app = asgi_app(applications, cdn=cdn, reconnect_timeout=reconnect_timeout,
160+
static_dir=static_dir, debug=debug,
160161
allowed_origins=allowed_origins, check_origin=check_origin)
161162

162163
if auto_open_webbrowser:
@@ -176,7 +177,8 @@ def start_server(applications, port=0, host='', cdn=True,
176177
uvicorn.run(app, host=host, port=port, **uvicorn_settings)
177178

178179

179-
def asgi_app(applications, cdn=True, static_dir=None, debug=False, allowed_origins=None, check_origin=None):
180+
def asgi_app(applications, cdn=True, reconnect_timeout=0, static_dir=None, debug=False, allowed_origins=None,
181+
check_origin=None):
180182
"""Get the starlette/Fastapi ASGI app for running PyWebIO applications.
181183
182184
Use :func:`pywebio.platform.fastapi.webio_routes` if you prefer handling static files yourself.
@@ -210,7 +212,8 @@ def asgi_app(applications, cdn=True, static_dir=None, debug=False, allowed_origi
210212
cdn = cdn_validation(cdn, 'warn')
211213
if cdn is False:
212214
cdn = 'pywebio_static'
213-
routes = webio_routes(applications, cdn=cdn, allowed_origins=allowed_origins, check_origin=check_origin)
215+
routes = webio_routes(applications, cdn=cdn, reconnect_timeout=reconnect_timeout,
216+
allowed_origins=allowed_origins, check_origin=check_origin)
214217
if static_dir:
215218
routes.append(Mount('/static', app=StaticFiles(directory=static_dir), name="static"))
216219
routes.append(Mount('/pywebio_static', app=StaticFiles(directory=STATIC_PATH), name="pywebio_static"))

0 commit comments

Comments
 (0)