Skip to content

Commit 953f836

Browse files
committed
add session reconnection to aiohttp
1 parent e675db6 commit 953f836

File tree

2 files changed

+50
-46
lines changed

2 files changed

+50
-46
lines changed

pywebio/platform/adaptor/ws.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,12 @@ class WebSocketHandler:
100100
connection: WebSocketConnection
101101
reconnectable: bool
102102

103-
def __init__(self, connection: WebSocketConnection, application, reconnectable: bool):
103+
def __init__(self, connection: WebSocketConnection, application, reconnectable: bool, ioloop=None):
104104
logger.debug("WebSocket opened")
105105
self.connection = connection
106106
self.reconnectable = reconnectable
107107
self.session_id = connection.get_query_argument('session')
108+
self.ioloop = ioloop or asyncio.get_event_loop()
108109

109110
if self.session_id in ('NEW', None): # 初始请求,创建新 Session
110111
self._init_session(application)
@@ -144,7 +145,7 @@ def _init_session(self, application):
144145
application, session_info=session_info,
145146
on_task_command=self._send_msg_to_client,
146147
on_session_close=self._close_from_session,
147-
loop=asyncio.get_event_loop())
148+
loop=self.ioloop)
148149
_state.unclosed_sessions[self.session_id] = self.session
149150

150151
def _send_msg_to_client(self, session):

pywebio/platform/aiohttp.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33
import json
44
import logging
55
import os
6+
import typing
67
from functools import partial
78
from urllib.parse import urlparse
89

910
from aiohttp import web
1011

12+
from .adaptor import ws as ws_adaptor
1113
from .page import make_applications, render_page
1214
from .remote_access import start_remote_access_service
1315
from .tornado import open_webbrowser_on_server_started
14-
from .utils import cdn_validation, deserialize_binary_event, print_listen_address
15-
from ..session import CoroutineBasedSession, ThreadBasedSession, register_session_implement_for_target, Session
16+
from .utils import cdn_validation, print_listen_address
17+
from ..session import register_session_implement_for_target, Session
1618
from ..session.base import get_session_info_from_headers
17-
from ..utils import get_free_port, STATIC_PATH, iscoroutinefunction, isgeneratorfunction
19+
from ..utils import get_free_port, STATIC_PATH
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -39,7 +41,36 @@ def _is_same_site(origin, host):
3941
return origin == host
4042

4143

42-
def _webio_handler(applications, cdn, websocket_settings, check_origin_func=_is_same_site):
44+
class WebSocketConnection(ws_adaptor.WebSocketConnection):
45+
46+
def __init__(self, ws: web.WebSocketResponse, http: web.Request, ioloop):
47+
self.ws = ws
48+
self.http = http
49+
self.ioloop = ioloop
50+
51+
def get_query_argument(self, name) -> typing.Optional[str]:
52+
return self.http.query.getone(name, None)
53+
54+
def make_session_info(self) -> dict:
55+
session_info = get_session_info_from_headers(self.http.headers)
56+
session_info['user_ip'] = self.http.remote
57+
session_info['request'] = self.http
58+
session_info['backend'] = 'aiohttp'
59+
session_info['protocol'] = 'websocket'
60+
return session_info
61+
62+
def write_message(self, message: dict):
63+
msg_str = json.dumps(message)
64+
self.ioloop.create_task(self.ws.send_str(msg_str))
65+
66+
def closed(self) -> bool:
67+
return self.ws.closed
68+
69+
def close(self):
70+
self.ioloop.create_task(self.ws.close())
71+
72+
73+
def _webio_handler(applications, cdn, websocket_settings, reconnect_timeout=0, check_origin_func=_is_same_site):
4374
"""
4475
:param dict applications: dict of `name -> task function`
4576
:param bool/str cdn: Whether to load front-end static resources from CDN
@@ -68,61 +99,31 @@ async def wshandle(request: web.Request):
6899
ws = web.WebSocketResponse(**websocket_settings)
69100
await ws.prepare(request)
70101

71-
close_from_session_tag = False # 是否由session主动关闭连接
72-
73-
def send_msg_to_client(session: Session):
74-
for msg in session.get_task_commands():
75-
msg_str = json.dumps(msg)
76-
ioloop.create_task(ws.send_str(msg_str))
77-
78-
def close_from_session():
79-
nonlocal close_from_session_tag
80-
close_from_session_tag = True
81-
ioloop.create_task(ws.close())
82-
logger.debug("WebSocket closed from session")
83-
84-
session_info = get_session_info_from_headers(request.headers)
85-
session_info['user_ip'] = request.remote
86-
session_info['request'] = request
87-
session_info['backend'] = 'aiohttp'
88-
session_info['protocol'] = 'websocket'
89-
90102
app_name = request.query.getone('app', 'index')
91103
application = applications.get(app_name) or applications['index']
92104

93-
if iscoroutinefunction(application) or isgeneratorfunction(application):
94-
session = CoroutineBasedSession(application, session_info=session_info,
95-
on_task_command=send_msg_to_client,
96-
on_session_close=close_from_session)
97-
else:
98-
session = ThreadBasedSession(application, session_info=session_info,
99-
on_task_command=send_msg_to_client,
100-
on_session_close=close_from_session, loop=ioloop)
105+
conn = WebSocketConnection(ws, request, ioloop)
106+
handler = ws_adaptor.WebSocketHandler(
107+
connection=conn, application=application, reconnectable=bool(reconnect_timeout), ioloop=ioloop
108+
)
101109

102110
# see: https://github.com/aio-libs/aiohttp/issues/1768
103111
try:
104112
async for msg in ws:
105-
if msg.type == web.WSMsgType.text:
106-
data = msg.json()
107-
elif msg.type == web.WSMsgType.binary:
108-
data = deserialize_binary_event(msg.data)
113+
if msg.type in (web.WSMsgType.text, web.WSMsgType.binary):
114+
handler.send_client_data(msg.data)
109115
elif msg.type == web.WSMsgType.close:
110116
raise asyncio.CancelledError()
111-
112-
if data is not None:
113-
session.send_client_event(data)
114117
finally:
115-
if not close_from_session_tag:
116-
# close session because client disconnected to server
117-
session.close(nonblock=True)
118-
logger.debug("WebSocket closed from client")
118+
handler.notify_connection_lost()
119119

120120
return ws
121121

122122
return wshandle
123123

124124

125-
def webio_handler(applications, cdn=True, allowed_origins=None, check_origin=None, websocket_settings=None):
125+
def webio_handler(applications, cdn=True, reconnect_timeout=0, allowed_origins=None, check_origin=None,
126+
websocket_settings=None):
126127
"""Get the `Request Handler <https://docs.aiohttp.org/en/stable/web_quickstart.html#aiohttp-web-handler>`_ coroutine for running PyWebIO applications in aiohttp.
127128
The handler communicates with the browser by WebSocket protocol.
128129
@@ -145,6 +146,7 @@ def webio_handler(applications, cdn=True, allowed_origins=None, check_origin=Non
145146

146147
return _webio_handler(applications=applications, cdn=cdn,
147148
check_origin_func=check_origin_func,
149+
reconnect_timeout=reconnect_timeout,
148150
websocket_settings=websocket_settings)
149151

150152

@@ -168,6 +170,7 @@ async def index(request):
168170

169171
def start_server(applications, port=0, host='', debug=False,
170172
cdn=True, static_dir=None, remote_access=False,
173+
reconnect_timeout=0,
171174
allowed_origins=None, check_origin=None,
172175
auto_open_webbrowser=False,
173176
websocket_settings=None,
@@ -191,7 +194,7 @@ def start_server(applications, port=0, host='', debug=False,
191194

192195
cdn = cdn_validation(cdn, 'warn')
193196

194-
handler = webio_handler(applications, cdn=cdn, allowed_origins=allowed_origins,
197+
handler = webio_handler(applications, cdn=cdn, allowed_origins=allowed_origins, reconnect_timeout=reconnect_timeout,
195198
check_origin=check_origin, websocket_settings=websocket_settings)
196199

197200
app = web.Application(**aiohttp_settings)

0 commit comments

Comments
 (0)