1- import os
21import asyncio
3- import json
42import logging
3+ import os
4+ import typing
55from functools import partial
66
77import uvicorn
88from starlette .applications import Starlette
99from starlette .requests import Request
1010from starlette .responses import HTMLResponse
1111from starlette .routing import Route , WebSocketRoute , Mount
12- from starlette .websockets import WebSocket
12+ from starlette .websockets import WebSocket , WebSocketState
1313from starlette .websockets import WebSocketDisconnect
1414
15+ from .page import make_applications , render_page
1516from .remote_access import start_remote_access_service
1617from .tornado import open_webbrowser_on_server_started
17- from .page import make_applications , render_page
18- from .utils import cdn_validation , OriginChecker , deserialize_binary_event , print_listen_address
19- from ..session import CoroutineBasedSession , ThreadBasedSession , register_session_implement_for_target , Session
18+ from .utils import cdn_validation , OriginChecker , print_listen_address
19+ from ..session import register_session_implement_for_target , Session
2020from ..session .base import get_session_info_from_headers
21- from ..utils import get_free_port , STATIC_PATH , iscoroutinefunction , isgeneratorfunction , strip_space
21+ from ..utils import get_free_port , STATIC_PATH , strip_space
2222
2323logger = logging .getLogger (__name__ )
24+ from .adaptor import ws as ws_adaptor
25+
26+
27+ class WebSocketConnection (ws_adaptor .WebSocketConnection ):
28+
29+ def __init__ (self , websocket : WebSocket , ioloop ):
30+ self .ws = websocket
31+ self .ioloop = ioloop
32+
33+ def get_query_argument (self , name ) -> typing .Optional [str ]:
34+ return self .ws .query_params .get (name , None )
35+
36+ def make_session_info (self ) -> dict :
37+ session_info = get_session_info_from_headers (self .ws .headers )
38+ session_info ['user_ip' ] = self .ws .client .host or ''
39+ session_info ['request' ] = self .ws
40+ session_info ['backend' ] = 'starlette'
41+ session_info ['protocol' ] = 'websocket'
42+ return session_info
2443
44+ def write_message (self , message : dict ):
45+ self .ioloop .create_task (self .ws .send_json (message ))
2546
26- def _webio_routes (applications , cdn , check_origin_func ):
47+ def closed (self ) -> bool :
48+ return self .ws .application_state == WebSocketState .DISCONNECTED
49+
50+ def close (self ):
51+ self .ioloop .create_task (self .ws .close ())
52+
53+
54+ def _webio_routes (applications , cdn , check_origin_func , reconnect_timeout ):
2755 """
2856 :param dict applications: dict of `name -> task function`
2957 :param bool/str cdn: Whether to load front-end static resources from CDN
@@ -49,64 +77,35 @@ async def websocket_endpoint(websocket: WebSocket):
4977 ioloop = asyncio .get_event_loop ()
5078 await websocket .accept ()
5179
52- close_from_session_tag = False # session close causes websocket close
53-
54- def send_msg_to_client (session : Session ):
55- for msg in session .get_task_commands ():
56- ioloop .create_task (websocket .send_json (msg ))
57-
58- def close_from_session ():
59- nonlocal close_from_session_tag
60- close_from_session_tag = True
61- ioloop .create_task (websocket .close ())
62- logger .debug ("WebSocket closed from session" )
63-
64- session_info = get_session_info_from_headers (websocket .headers )
65- session_info ['user_ip' ] = websocket .client .host or ''
66- session_info ['request' ] = websocket
67- session_info ['backend' ] = 'starlette'
68- session_info ['protocol' ] = 'websocket'
69-
7080 app_name = websocket .query_params .get ('app' , 'index' )
7181 application = applications .get (app_name ) or applications ['index' ]
7282
73- if iscoroutinefunction (application ) or isgeneratorfunction (application ):
74- session = CoroutineBasedSession (application , session_info = session_info ,
75- on_task_command = send_msg_to_client ,
76- on_session_close = close_from_session )
77- else :
78- session = ThreadBasedSession (application , session_info = session_info ,
79- on_task_command = send_msg_to_client ,
80- on_session_close = close_from_session , loop = ioloop )
83+ conn = WebSocketConnection (websocket , ioloop )
84+ handler = ws_adaptor .WebSocketHandler (
85+ connection = conn , application = application , reconnectable = bool (reconnect_timeout ), ioloop = ioloop
86+ )
8187
8288 while True :
8389 try :
8490 msg = await websocket .receive ()
8591 if msg ["type" ] == "websocket.disconnect" :
8692 raise WebSocketDisconnect (msg ["code" ])
8793 text , binary = msg .get ('text' ), msg .get ('bytes' )
88- event = None
8994 if text :
90- event = json . loads (text )
91- elif binary :
92- event = deserialize_binary_event ( binary )
95+ handler . send_client_data (text )
96+ if binary :
97+ handler . send_client_data ( text )
9398 except WebSocketDisconnect :
94- if not close_from_session_tag :
95- # close session because client disconnected to server
96- session .close (nonblock = True )
97- logger .debug ("WebSocket closed from client" )
99+ handler .notify_connection_lost ()
98100 break
99101
100- if event is not None :
101- session .send_client_event (event )
102-
103102 return [
104103 Route ("/" , http_endpoint ),
105104 WebSocketRoute ("/" , websocket_endpoint )
106105 ]
107106
108107
109- def webio_routes (applications , cdn = True , allowed_origins = None , check_origin = None ):
108+ def webio_routes (applications , cdn = True , reconnect_timeout = 0 , allowed_origins = None , check_origin = None ):
110109 """Get the FastAPI/Starlette routes for running PyWebIO applications.
111110
112111 The API communicates with the browser using WebSocket protocol.
@@ -137,10 +136,11 @@ def webio_routes(applications, cdn=True, allowed_origins=None, check_origin=None
137136 else :
138137 check_origin_func = lambda origin , host : OriginChecker .is_same_site (origin , host ) or check_origin (origin )
139138
140- return _webio_routes (applications = applications , cdn = cdn , check_origin_func = check_origin_func )
139+ return _webio_routes (applications = applications , cdn = cdn , check_origin_func = check_origin_func ,
140+ reconnect_timeout = reconnect_timeout )
141141
142142
143- def start_server (applications , port = 0 , host = '' , cdn = True ,
143+ def start_server (applications , port = 0 , host = '' , cdn = True , reconnect_timeout = 0 ,
144144 static_dir = None , remote_access = False , debug = False ,
145145 allowed_origins = None , check_origin = None ,
146146 auto_open_webbrowser = False ,
@@ -156,7 +156,8 @@ def start_server(applications, port=0, host='', cdn=True,
156156 .. versionadded:: 1.3
157157 """
158158
159- app = asgi_app (applications , cdn = cdn , static_dir = static_dir , debug = debug ,
159+ app = asgi_app (applications , cdn = cdn , reconnect_timeout = reconnect_timeout ,
160+ static_dir = static_dir , debug = debug ,
160161 allowed_origins = allowed_origins , check_origin = check_origin )
161162
162163 if auto_open_webbrowser :
@@ -176,7 +177,8 @@ def start_server(applications, port=0, host='', cdn=True,
176177 uvicorn .run (app , host = host , port = port , ** uvicorn_settings )
177178
178179
179- def asgi_app (applications , cdn = True , static_dir = None , debug = False , allowed_origins = None , check_origin = None ):
180+ def asgi_app (applications , cdn = True , reconnect_timeout = 0 , static_dir = None , debug = False , allowed_origins = None ,
181+ check_origin = None ):
180182 """Get the starlette/Fastapi ASGI app for running PyWebIO applications.
181183
182184 Use :func:`pywebio.platform.fastapi.webio_routes` if you prefer handling static files yourself.
@@ -210,7 +212,8 @@ def asgi_app(applications, cdn=True, static_dir=None, debug=False, allowed_origi
210212 cdn = cdn_validation (cdn , 'warn' )
211213 if cdn is False :
212214 cdn = 'pywebio_static'
213- routes = webio_routes (applications , cdn = cdn , allowed_origins = allowed_origins , check_origin = check_origin )
215+ routes = webio_routes (applications , cdn = cdn , reconnect_timeout = reconnect_timeout ,
216+ allowed_origins = allowed_origins , check_origin = check_origin )
214217 if static_dir :
215218 routes .append (Mount ('/static' , app = StaticFiles (directory = static_dir ), name = "static" ))
216219 routes .append (Mount ('/pywebio_static' , app = StaticFiles (directory = STATIC_PATH ), name = "pywebio_static" ))
0 commit comments