Skip to content

Commit e675db6

Browse files
committed
refactor session reconnect
1 parent 9ce8fb2 commit e675db6

File tree

5 files changed

+255
-135
lines changed

5 files changed

+255
-135
lines changed

pywebio/platform/adaptor/__init__.py

Whitespace-only changes.

pywebio/platform/adaptor/ws.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import asyncio
2+
import json
3+
import logging
4+
import time
5+
import typing
6+
from typing import Dict
7+
import abc
8+
from ..utils import deserialize_binary_event
9+
from ...session import CoroutineBasedSession, ThreadBasedSession, Session
10+
from ...utils import iscoroutinefunction, isgeneratorfunction, \
11+
random_str, LRUDict
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class _state:
17+
# only used in reconnect enabled
18+
# used to clean up session
19+
detached_sessions = LRUDict() # session_id -> detached timestamp. In increasing order of the time
20+
21+
# unclosed and unexpired session
22+
# only used in reconnect enabled
23+
# used to clean up session
24+
# used to retrieve session by id when new connection
25+
unclosed_sessions: Dict[str, Session] = {} # session_id -> session
26+
27+
# the messages that can't deliver to browser when session close due to connection lost
28+
undelivered_messages: Dict[str, list] = {} # session_id -> unhandled message list
29+
30+
# used to get the active conn in session's callbacks
31+
active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler
32+
33+
expire_second = 10
34+
35+
36+
def set_expire_second(sec):
37+
_state.expire_second = max(_state.expire_second, sec)
38+
39+
40+
def clean_expired_sessions():
41+
while _state.detached_sessions:
42+
session_id, detached_ts = _state.detached_sessions.popitem(last=False) # 弹出最早过期的session
43+
44+
if time.time() < detached_ts + _state.expire_second:
45+
# this session is not expired
46+
_state.detached_sessions[session_id] = detached_ts # restore
47+
_state.detached_sessions.move_to_end(session_id, last=False) # move to head
48+
break
49+
50+
# clean this session
51+
logger.debug("session %s expired" % session_id)
52+
_state.active_connections.pop(session_id, None)
53+
_state.undelivered_messages.pop(session_id, None)
54+
session = _state.unclosed_sessions.pop(session_id, None)
55+
if session:
56+
session.close(nonblock=True)
57+
58+
59+
async def session_clean_task():
60+
logger.debug("Start session cleaning task")
61+
while True:
62+
try:
63+
clean_expired_sessions()
64+
except Exception:
65+
logger.exception("Error when clean expired sessions")
66+
67+
await asyncio.sleep(_state.expire_second // 2)
68+
69+
70+
class WebSocketConnection(abc.ABC):
71+
@abc.abstractmethod
72+
def get_query_argument(self, name) -> typing.Optional[str]:
73+
pass
74+
75+
@abc.abstractmethod
76+
def make_session_info(self) -> dict:
77+
pass
78+
79+
@abc.abstractmethod
80+
def write_message(self, message: dict):
81+
pass
82+
83+
@abc.abstractmethod
84+
def closed(self) -> bool:
85+
return False
86+
87+
@abc.abstractmethod
88+
def close(self):
89+
pass
90+
91+
92+
class WebSocketHandler:
93+
"""
94+
hold by one connection,
95+
share one session with multiple connection in session lifetime, but one conn at a time
96+
"""
97+
98+
session_id: str = None
99+
session: Session = None # the session that current connection attaches
100+
connection: WebSocketConnection
101+
reconnectable: bool
102+
103+
def __init__(self, connection: WebSocketConnection, application, reconnectable: bool):
104+
logger.debug("WebSocket opened")
105+
self.connection = connection
106+
self.reconnectable = reconnectable
107+
self.session_id = connection.get_query_argument('session')
108+
109+
if self.session_id in ('NEW', None): # 初始请求,创建新 Session
110+
self._init_session(application)
111+
if reconnectable:
112+
# set session id to client, so the client can send it back to server to recover a session when it
113+
# resumes form a connection lost
114+
connection.write_message(dict(command='set_session_id', spec=self.session_id))
115+
elif self.session_id not in _state.unclosed_sessions: # session is expired
116+
bye_msg = dict(command='close_session')
117+
for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
118+
try:
119+
connection.write_message(m)
120+
except Exception:
121+
logger.exception("Error in sending message via websocket")
122+
else:
123+
self.session = _state.unclosed_sessions[self.session_id]
124+
_state.detached_sessions.pop(self.session_id, None)
125+
_state.active_connections[self.session_id] = connection
126+
# send the latest messages to client
127+
self._send_msg_to_client(self.session)
128+
129+
logger.debug('session id: %s' % self.session_id)
130+
131+
def _init_session(self, application):
132+
session_info = self.connection.make_session_info()
133+
self.session_id = random_str(24)
134+
# todo: only set item when reconnection enabled
135+
_state.active_connections[self.session_id] = self.connection
136+
137+
if iscoroutinefunction(application) or isgeneratorfunction(application):
138+
self.session = CoroutineBasedSession(
139+
application, session_info=session_info,
140+
on_task_command=self._send_msg_to_client,
141+
on_session_close=self._close_from_session)
142+
else:
143+
self.session = ThreadBasedSession(
144+
application, session_info=session_info,
145+
on_task_command=self._send_msg_to_client,
146+
on_session_close=self._close_from_session,
147+
loop=asyncio.get_event_loop())
148+
_state.unclosed_sessions[self.session_id] = self.session
149+
150+
def _send_msg_to_client(self, session):
151+
# self.connection may not be active,
152+
# here we need the active connection for this session
153+
conn = _state.active_connections.get(self.session_id)
154+
155+
if not conn or conn.closed():
156+
return
157+
158+
for msg in session.get_task_commands():
159+
try:
160+
conn.write_message(msg)
161+
except TypeError as e:
162+
logger.exception('Data serialization error: %s\n'
163+
'This may be because you pass the wrong type of parameter to the function'
164+
' of PyWebIO.\nData content: %s', e, msg)
165+
except Exception:
166+
logger.exception("Error in sending message via websocket")
167+
168+
def _close_from_session(self):
169+
session = _state.unclosed_sessions[self.session_id]
170+
if self.session_id in _state.active_connections:
171+
# send the undelivered messages to client
172+
self._send_msg_to_client(session=session)
173+
else:
174+
_state.undelivered_messages[self.session_id] = session.get_task_commands()
175+
176+
conn = _state.active_connections.pop(self.session_id, None)
177+
_state.unclosed_sessions.pop(self.session_id, None)
178+
if conn and not conn.closed():
179+
conn.close()
180+
181+
def send_client_data(self, data):
182+
if isinstance(data, bytes):
183+
event = deserialize_binary_event(data)
184+
else:
185+
event = json.loads(data)
186+
if event is None:
187+
return
188+
self.session.send_client_event(event)
189+
190+
def notify_connection_lost(self):
191+
_state.active_connections.pop(self.session_id, None)
192+
if not self.reconnectable:
193+
# when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
194+
# because the `session.close()` is reentrant
195+
self.session.close(nonblock=True)
196+
else:
197+
if self.session_id in _state.unclosed_sessions:
198+
_state.detached_sessions[self.session_id] = time.time()
199+
logger.debug("WebSocket closed")

0 commit comments

Comments
 (0)