1414import threading
1515import time
1616from contextlib import contextmanager
17- from typing import Dict , Optional
17+ from typing import Dict , Optional , List
18+ from collections import deque
1819
1920from ..page import make_applications , render_page
2021from ..utils import deserialize_binary_event
2122from ...session import CoroutineBasedSession , ThreadBasedSession , register_session_implement_for_target
22- from ...session .base import get_session_info_from_headers
23+ from ...session .base import get_session_info_from_headers , Session
2324from ...utils import random_str , LRUDict , isgeneratorfunction , iscoroutinefunction , check_webio_js
2425
2526
@@ -35,7 +36,7 @@ def request_obj(self):
3536 Return the current request object"""
3637 pass
3738
38- def request_method (self ):
39+ def request_method (self ) -> str :
3940 """返回当前请求的方法,大写
4041 Return the HTTP method of the current request, uppercase"""
4142 pass
@@ -45,29 +46,19 @@ def request_headers(self) -> Dict:
4546 Return the header dictionary of the current request"""
4647 pass
4748
48- def request_url_parameter (self , name , default = None ):
49+ def request_url_parameter (self , name , default = None ) -> str :
4950 """返回当前请求的URL参数
5051 Returns the value of the given URL parameter of the current request"""
5152 pass
5253
53- def request_body (self ):
54+ def request_body (self ) -> bytes :
5455 """返回当前请求的body数据
5556 Returns the data of the current request body
5657
5758 :return: bytes/bytearray
5859 """
5960 return b''
6061
61- def request_json (self ) -> Optional [Dict ]:
62- """返回当前请求的json反序列化后的内容,若请求数据不为json格式,返回None
63- Return the data (json deserialization) of the currently requested, if the data is not in json format, return None"""
64- try :
65- if self .request_headers ().get ('content-type' ) == 'application/octet-stream' :
66- return deserialize_binary_event (self .request_body ())
67- return json .loads (self .request_body ())
68- except Exception :
69- return None
70-
7162 def set_header (self , name , value ):
7263 """为当前响应设置header
7364 Set a header for the current response"""
@@ -92,7 +83,7 @@ def get_response(self):
9283 Get the current response object"""
9384 pass
9485
95- def get_client_ip (self ):
86+ def get_client_ip (self ) -> str :
9687 """获取用户的ip
9788 Get the user's ip"""
9889 pass
@@ -102,6 +93,56 @@ def get_client_ip(self):
10293_event_loop = None
10394
10495
96+ class ReliableTransport :
97+ def __init__ (self , session : Session , message_window : int = 4 ):
98+ self .session = session
99+ self .messages = deque ()
100+ self .window_size = message_window
101+ self .min_msg_id = 0 # the id of the first message in the window
102+ self .finished_event_id = - 1 # the id of the last finished event
103+
104+ @staticmethod
105+ def close_message (ack ):
106+ return dict (
107+ commands = [[dict (command = 'close_session' )]],
108+ seq = ack + 1
109+ )
110+
111+ def push_event (self , events : List [Dict ], seq : int ) -> int :
112+ """Send client events to the session and return the success message count"""
113+ if not events :
114+ return 0
115+
116+ submit_cnt = 0
117+ for eid , event in enumerate (events , start = seq ):
118+ if eid > self .finished_event_id :
119+ self .finished_event_id = eid # todo: use lock for check and set operation
120+ self .session .send_client_event (event )
121+ submit_cnt += 1
122+
123+ return submit_cnt
124+
125+ def get_response (self , ack = 0 ):
126+ """
127+ ack num is the number of messages that the client has received.
128+ response is a list of messages that the client should receive, along with their min id `seq`.
129+ """
130+ while ack >= self .min_msg_id and self .messages :
131+ self .messages .popleft ()
132+ self .min_msg_id += 1
133+
134+ if len (self .messages ) < self .window_size :
135+ msgs = self .session .get_task_commands ()
136+ if msgs :
137+ self .messages .append (msgs )
138+
139+ return dict (
140+ commands = list (self .messages ),
141+ seq = self .min_msg_id ,
142+ ack = self .finished_event_id
143+ )
144+
145+
105146# todo: use lock to avoid thread race condition
106147class HttpHandler :
107148 """基于HTTP的后端Handler实现
@@ -112,7 +153,7 @@ class HttpHandler:
112153
113154 """
114155 _webio_sessions = {} # WebIOSessionID -> WebIOSession()
115- _webio_last_commands = {} # WebIOSessionID -> (last commands, commands sequence id)
156+ _webio_transports = {} # WebIOSessionID -> ReliableTransport(), type: Dict[str, ReliableTransport]
116157 _webio_expire = LRUDict () # WebIOSessionID -> last active timestamp. In increasing order of last active time
117158 _webio_expire_lock = threading .Lock ()
118159
@@ -143,23 +184,13 @@ def _remove_expired_sessions(cls, session_expire_seconds):
143184 if session :
144185 session .close (nonblock = True )
145186 del cls ._webio_sessions [sid ]
187+ del cls ._webio_transports [sid ]
146188
147189 @classmethod
148190 def _remove_webio_session (cls , sid ):
149191 cls ._webio_sessions .pop (sid , None )
150192 cls ._webio_expire .pop (sid , None )
151193
152- @classmethod
153- def get_response (cls , sid , ack = 0 ):
154- commands , seq = cls ._webio_last_commands .get (sid , ([], 0 ))
155- if ack == seq :
156- webio_session = cls ._webio_sessions [sid ]
157- commands = webio_session .get_task_commands ()
158- seq += 1
159- cls ._webio_last_commands [sid ] = (commands , seq )
160-
161- return {'commands' : commands , 'seq' : seq }
162-
163194 def _process_cors (self , context : HttpContext ):
164195 """Handling cross-domain requests: check the source of the request and set headers"""
165196 origin = context .request_headers ().get ('Origin' , '' )
@@ -209,6 +240,14 @@ def get_cdn(self, context):
209240 return False
210241 return self .cdn
211242
243+ def read_event_data (self , context : HttpContext ) -> List [Dict ]:
244+ try :
245+ if context .request_headers ().get ('content-type' ) == 'application/octet-stream' :
246+ return [deserialize_binary_event (context .request_body ())]
247+ return json .loads (context .request_body ())
248+ except Exception :
249+ return []
250+
212251 @contextmanager
213252 def handle_request_context (self , context : HttpContext ):
214253 """called when every http request"""
@@ -240,16 +279,18 @@ def handle_request_context(self, context: HttpContext):
240279 context .set_content (html )
241280 return context .get_response ()
242281
243- webio_session_id = None
282+ ack = int (context .request_url_parameter ('ack' , 0 ))
283+ webio_session_id = request_headers ['webio-session-id' ]
284+ new_request = False
285+ if webio_session_id .startswith ('NEW-' ):
286+ new_request = True
287+ webio_session_id = webio_session_id [4 :]
244288
245- # 初始请求,创建新 Session
246- if not request_headers ['webio-session-id' ] or request_headers ['webio-session-id' ] == 'NEW' :
289+ if new_request and webio_session_id not in cls ._webio_sessions : # 初始请求,创建新 Session
247290 if context .request_method () == 'POST' : # 不能在POST请求中创建Session,防止CSRF攻击
248291 context .set_status (403 )
249292 return context .get_response ()
250293
251- webio_session_id = random_str (24 )
252- context .set_header ('webio-session-id' , webio_session_id )
253294 session_info = get_session_info_from_headers (context .request_headers ())
254295 session_info ['user_ip' ] = context .get_client_ip ()
255296 session_info ['request' ] = context .request_obj ()
@@ -264,17 +305,23 @@ def handle_request_context(self, context: HttpContext):
264305 session_cls = ThreadBasedSession
265306 webio_session = session_cls (application , session_info = session_info )
266307 cls ._webio_sessions [webio_session_id ] = webio_session
267- yield type (self ).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
268- elif request_headers ['webio-session-id' ] not in cls ._webio_sessions : # WebIOSession deleted
269- context .set_content ([dict (command = 'close_session' )], json_type = True )
308+ cls ._webio_transports [webio_session_id ] = ReliableTransport (webio_session )
309+ yield cls .WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
310+ elif webio_session_id not in cls ._webio_sessions : # WebIOSession deleted
311+ close_msg = ReliableTransport .close_message (ack )
312+ context .set_content (close_msg , json_type = True )
270313 return context .get_response ()
271314 else :
272- webio_session_id = request_headers ['webio-session-id' ]
315+ # in this case, the request_headers['webio-session-id'] may also startswith NEW,
316+ # this is because the response for the previous new session request has not been received by the client,
317+ # and the client has sent a new request with the same session id.
273318 webio_session = cls ._webio_sessions [webio_session_id ]
274319
275320 if context .request_method () == 'POST' : # client push event
276- if context .request_json () is not None :
277- webio_session .send_client_event (context .request_json ())
321+ seq = int (context .request_url_parameter ('seq' , 0 ))
322+ event_data = self .read_event_data (context )
323+ submit_cnt = cls ._webio_transports [webio_session_id ].push_event (event_data , seq )
324+ if submit_cnt > 0 :
278325 yield type (self ).WAIT_MS_ON_POST / 1000.0 # <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <--- <---
279326 elif context .request_method () == 'GET' : # client pull messages
280327 pass
@@ -283,8 +330,8 @@ def handle_request_context(self, context: HttpContext):
283330
284331 self .interval_cleaning ()
285332
286- ack = int ( context . request_url_parameter ( ' ack' , 0 ) )
287- context .set_content (type ( self ). get_response ( webio_session_id , ack = ack ) , json_type = True )
333+ resp = cls . _webio_transports [ webio_session_id ]. get_response ( ack )
334+ context .set_content (resp , json_type = True )
288335
289336 if webio_session .closed ():
290337 self ._remove_webio_session (webio_session_id )
0 commit comments