Skip to content

Commit 95fff37

Browse files
committed
fix #545: memory leak after close session
1 parent f0b7beb commit 95fff37

File tree

3 files changed

+71
-52
lines changed

3 files changed

+71
-52
lines changed

pywebio/platform/adaptor/ws.py

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,18 @@
1313
logger = logging.getLogger(__name__)
1414

1515

16-
class _state:
17-
# only used in reconnect enabled
16+
# used to store global state when reconnect enabled
17+
class _reconnect_state:
1818
# used to clean up session
1919
detached_sessions = LRUDict() # session_id -> detached timestamp. In increasing order of the time
2020

2121
# unclosed and unexpired session
22-
# only used in reconnect enabled
2322
# used to clean up session
2423
# used to retrieve session by id when new connection
2524
unclosed_sessions: Dict[str, Session] = {} # session_id -> session
2625

2726
# the messages that can't deliver to browser when session close due to connection lost
28-
undelivered_messages: Dict[str, list] = {} # session_id -> unhandled message list
27+
session_will_messages: Dict[str, list] = {} # session_id -> unhandled message list
2928

3029
# used to get the active conn in session's callbacks
3130
active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler
@@ -34,24 +33,24 @@ class _state:
3433

3534

3635
def set_expire_second(sec):
37-
_state.expire_second = max(_state.expire_second, sec)
36+
_reconnect_state.expire_second = max(_reconnect_state.expire_second, sec)
3837

3938

4039
def clean_expired_sessions():
41-
while _state.detached_sessions:
42-
session_id, detached_ts = _state.detached_sessions.popitem(last=False) # 弹出最早过期的session
40+
while _reconnect_state.detached_sessions:
41+
session_id, detached_ts = _reconnect_state.detached_sessions.popitem(last=False) # 弹出最早过期的session
4342

44-
if time.time() < detached_ts + _state.expire_second:
43+
if time.time() < detached_ts + _reconnect_state.expire_second:
4544
# this session is not expired
46-
_state.detached_sessions[session_id] = detached_ts # restore
47-
_state.detached_sessions.move_to_end(session_id, last=False) # move to head
45+
_reconnect_state.detached_sessions[session_id] = detached_ts # restore
46+
_reconnect_state.detached_sessions.move_to_end(session_id, last=False) # move to head
4847
break
4948

5049
# clean this session
5150
logger.debug("session %s expired" % session_id)
52-
_state.active_connections.pop(session_id, None)
53-
_state.undelivered_messages.pop(session_id, None)
54-
session = _state.unclosed_sessions.pop(session_id, None)
51+
_reconnect_state.active_connections.pop(session_id, None)
52+
_reconnect_state.session_will_messages.pop(session_id, None)
53+
session = _reconnect_state.unclosed_sessions.pop(session_id, None)
5554
if session:
5655
session.close(nonblock=True)
5756

@@ -61,7 +60,7 @@ def clean_expired_sessions():
6160

6261
async def session_clean_task():
6362
global _session_clean_task_started
64-
if _session_clean_task_started or not _state.expire_second:
63+
if _session_clean_task_started or not _reconnect_state.expire_second:
6564
return
6665

6766
_session_clean_task_started = True
@@ -72,7 +71,7 @@ async def session_clean_task():
7271
except Exception:
7372
logger.exception("Error when clean expired sessions")
7473

75-
await asyncio.sleep(_state.expire_second // 2)
74+
await asyncio.sleep(_reconnect_state.expire_second // 2)
7675

7776

7877
class WebSocketConnection(abc.ABC):
@@ -118,30 +117,30 @@ def __init__(self, connection: WebSocketConnection, application, reconnectable:
118117
if self.session_id in ('NEW', None): # 初始请求,创建新 Session
119118
self._init_session(application)
120119
if reconnectable:
120+
_reconnect_state.active_connections[self.session_id] = self.connection
121+
_reconnect_state.unclosed_sessions[self.session_id] = self.session
121122
# set session id to client, so the client can send it back to server to recover a session when it
122123
# resumes form a connection lost
123124
connection.write_message(dict(command='set_session_id', spec=self.session_id))
124-
elif self.session_id not in _state.unclosed_sessions: # session is expired
125+
elif self.session_id not in _reconnect_state.unclosed_sessions: # session is expired
125126
bye_msg = dict(command='close_session')
126-
for m in _state.undelivered_messages.get(self.session_id, [bye_msg]):
127+
for m in _reconnect_state.session_will_messages.get(self.session_id, [bye_msg]):
127128
try:
128129
connection.write_message(m)
129130
except Exception:
130131
logger.exception("Error in sending message via websocket")
131-
else:
132-
self.session = _state.unclosed_sessions[self.session_id]
133-
_state.detached_sessions.pop(self.session_id, None)
134-
_state.active_connections[self.session_id] = connection
132+
else: # resumes form a connection lost
133+
self.session = _reconnect_state.unclosed_sessions[self.session_id]
134+
_reconnect_state.detached_sessions.pop(self.session_id, None)
135+
_reconnect_state.active_connections[self.session_id] = connection
135136
# send the latest messages to client
136-
self._send_msg_to_client(self.session)
137+
self._send_msg_to_client()
137138

138139
logger.debug('session id: %s' % self.session_id)
139140

140141
def _init_session(self, application):
141142
session_info = self.connection.make_session_info()
142143
self.session_id = random_str(24)
143-
# todo: only set item when reconnection enabled
144-
_state.active_connections[self.session_id] = self.connection
145144

146145
if iscoroutinefunction(application) or isgeneratorfunction(application):
147146
self.session = CoroutineBasedSession(
@@ -154,12 +153,20 @@ def _init_session(self, application):
154153
on_task_command=self._send_msg_to_client,
155154
on_session_close=self._close_from_session,
156155
loop=self.ioloop)
157-
_state.unclosed_sessions[self.session_id] = self.session
158156

159-
def _send_msg_to_client(self, session):
160-
# self.connection may not be active,
161-
# here we need the active connection for this session
162-
conn = _state.active_connections.get(self.session_id)
157+
def _get_active_connection(self) -> Optional[WebSocketConnection]:
158+
# when reconnect enabled, the active connection for this session is in _reconnect_state.active_connections,
159+
# otherwise, it's self.connection.
160+
if self.reconnectable:
161+
conn = _reconnect_state.active_connections.get(self.session_id)
162+
else:
163+
conn = self.connection
164+
165+
return conn
166+
167+
def _send_msg_to_client(self, session: Session = None):
168+
conn = self._get_active_connection()
169+
session = session or self.session
163170

164171
if not conn or conn.closed():
165172
return
@@ -175,17 +182,13 @@ def _send_msg_to_client(self, session):
175182
logger.exception("Error in sending message via websocket")
176183

177184
def _close_from_session(self):
178-
session = _state.unclosed_sessions[self.session_id]
179-
if self.session_id in _state.active_connections:
180-
# send the undelivered messages to client
181-
self._send_msg_to_client(session=session)
182-
else:
183-
_state.undelivered_messages[self.session_id] = session.get_task_commands()
184-
185-
conn = _state.active_connections.pop(self.session_id, None)
186-
_state.unclosed_sessions.pop(self.session_id, None)
185+
conn = self._get_active_connection()
187186
if conn and not conn.closed():
187+
self._send_msg_to_client()
188188
conn.close()
189+
elif self.reconnectable: # no active connection, and reconnect is enabled
190+
_reconnect_state.session_will_messages[self.session_id] = self.session.get_task_commands()
191+
self.session = None
189192

190193
def send_client_data(self, data):
191194
if isinstance(data, bytes):
@@ -197,12 +200,14 @@ def send_client_data(self, data):
197200
self.session.send_client_event(event)
198201

199202
def notify_connection_lost(self):
200-
_state.active_connections.pop(self.session_id, None)
201-
if not self.reconnectable:
203+
logger.debug("WebSocket closed")
204+
if not self.reconnectable and self.session:
202205
# when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
203206
# because the `session.close()` is reentrant
204207
self.session.close(nonblock=True)
205-
else:
206-
if self.session_id in _state.unclosed_sessions:
207-
_state.detached_sessions[self.session_id] = time.time()
208-
logger.debug("WebSocket closed")
208+
self.session = None # reset the reference
209+
return
210+
211+
_reconnect_state.active_connections.pop(self.session_id, None)
212+
if self.session_id in _reconnect_state.unclosed_sessions:
213+
_reconnect_state.detached_sessions[self.session_id] = time.time()

pywebio/session/coroutinebased.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,14 @@ def get_task_commands(self):
153153
def _cleanup(self):
154154
for t in list(self.coros.values()): # t.close() may cause self.coros changed size
155155
t.step(SessionClosedException, throw_exp=True)
156+
# in case that the task catch the SessionClosedException, we need to close it manually
156157
t.close()
157158
self.coros = {} # delete session tasks
158159

160+
# reset the reference, to avoid circular reference
161+
self._on_session_close = None
162+
self._on_task_command = None
163+
159164
def close(self, nonblock=False):
160165
"""关闭当前Session。由Backend调用"""
161166
if self.closed():
@@ -295,7 +300,6 @@ def __init__(self, coro, session: CoroutineBasedSession, on_coro_stop=None):
295300
"""
296301
self.session = session
297302
self.coro = coro
298-
self.coro_id = None
299303
self.result = None
300304
self.task_closed = False # 任务完毕/取消
301305
self.on_coro_stop = on_coro_stop or (lambda _: None)
@@ -322,21 +326,23 @@ def step(self, result=None, throw_exp=False):
322326
except StopIteration as e:
323327
if len(e.args) == 1:
324328
self.result = e.args[0]
325-
self.task_closed = True
329+
self.close()
326330
logger.debug('Task[%s] finished', self.coro_id)
327-
self.on_coro_stop(self)
328331
except Exception as e:
329332
if not isinstance(e, SessionException):
330333
self.session.on_task_exception()
331-
self.task_closed = True
332-
self.on_coro_stop(self)
334+
self.close()
335+
336+
if coro_yield is None:
337+
return
333338

334339
future = None
335340
if isinstance(coro_yield, WebIOFuture):
336341
if coro_yield.coro:
337342
future = asyncio.run_coroutine_threadsafe(coro_yield.coro, asyncio.get_event_loop())
338-
elif coro_yield is not None:
343+
else:
339344
future = coro_yield
345+
340346
if not self.session.closed() and hasattr(future, 'add_done_callback'):
341347
future.add_done_callback(self._wakeup)
342348
self.pending_futures[id(future)] = future
@@ -350,14 +356,18 @@ def close(self):
350356
if self.task_closed:
351357
return
352358

353-
logger.debug('Task[%s] closed', self.coro_id)
359+
self.task_closed = True
360+
354361
self.coro.close()
355362
while self.pending_futures:
356363
_, f = self.pending_futures.popitem()
357364
f.cancel()
358365

359-
self.task_closed = True
360366
self.on_coro_stop(self)
367+
self.on_coro_stop = None # avoid circular reference
368+
self.session = None
369+
370+
logger.debug('Task[%s] closed', self.coro_id)
361371

362372
def __del__(self):
363373
if not self.task_closed:

pywebio/session/threadbased.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ def _trigger_close_event(self):
173173
self._on_session_close()
174174

175175
def _cleanup(self, nonblock=False):
176+
# reset the reference, to avoid circular reference
177+
self._on_session_close = None
178+
self._on_task_command = None
179+
176180
cls = type(self)
177181
if not nonblock:
178182
self.unhandled_task_msgs.wait_empty(8)

0 commit comments

Comments
 (0)