Skip to content

Commit 3fe47f4

Browse files
committed
add session clean task in aiohttp & fastapi
1 parent eb1c74c commit 3fe47f4

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

pywebio/platform/adaptor/ws.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class _state:
3030
# used to get the active conn in session's callbacks
3131
active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler
3232

33-
expire_second = 10
33+
expire_second = 0
3434

3535

3636
def set_expire_second(sec):
@@ -56,7 +56,14 @@ def clean_expired_sessions():
5656
session.close(nonblock=True)
5757

5858

59+
_session_clean_task_started = False
60+
61+
5962
async def session_clean_task():
63+
global _session_clean_task_started
64+
if _session_clean_task_started or not _state.expire_second:
65+
return
66+
_session_clean_task_started = True
6067
logger.debug("Start session cleaning task")
6168
while True:
6269
try:

pywebio/platform/aiohttp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def _webio_handler(applications, cdn, websocket_settings, reconnect_timeout=0, c
7777
:param callable check_origin_func: check_origin_func(origin, host) -> bool
7878
:return: aiohttp Request Handler
7979
"""
80+
ws_adaptor.set_expire_second(reconnect_timeout)
81+
asyncio.get_event_loop().create_task(ws_adaptor.session_clean_task())
8082

8183
async def wshandle(request: web.Request):
8284
ioloop = asyncio.get_event_loop()

pywebio/platform/fastapi.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def _webio_routes(applications, cdn, check_origin_func, reconnect_timeout):
5858
:param callable check_origin_func: check_origin_func(origin, host) -> bool
5959
"""
6060

61+
ws_adaptor.set_expire_second(reconnect_timeout)
62+
asyncio.get_event_loop().create_task(ws_adaptor.session_clean_task())
63+
6164
async def http_endpoint(request: Request):
6265
origin = request.headers.get('origin')
6366
if origin and not check_origin_func(origin=origin, host=request.headers.get('host')):

0 commit comments

Comments
 (0)