1313logger = logging .getLogger (__name__ )
1414
1515
16- class _state :
17- # only used in reconnect enabled
16+ # used to store global state when reconnect enabled
17+ class _reconnect_state :
1818 # used to clean up session
1919 detached_sessions = LRUDict () # session_id -> detached timestamp. In increasing order of the time
2020
2121 # unclosed and unexpired session
22- # only used in reconnect enabled
2322 # used to clean up session
2423 # used to retrieve session by id when new connection
2524 unclosed_sessions : Dict [str , Session ] = {} # session_id -> session
2625
2726 # the messages that can't deliver to browser when session close due to connection lost
28- undelivered_messages : Dict [str , list ] = {} # session_id -> unhandled message list
27+ session_will_messages : Dict [str , list ] = {} # session_id -> unhandled message list
2928
3029 # used to get the active conn in session's callbacks
3130 active_connections : Dict [str , 'WebSocketConnection' ] = {} # session_id -> WSHandler
@@ -34,24 +33,24 @@ class _state:
3433
3534
3635def set_expire_second (sec ):
37- _state .expire_second = max (_state .expire_second , sec )
36+ _reconnect_state .expire_second = max (_reconnect_state .expire_second , sec )
3837
3938
4039def clean_expired_sessions ():
41- while _state .detached_sessions :
42- session_id , detached_ts = _state .detached_sessions .popitem (last = False ) # 弹出最早过期的session
40+ while _reconnect_state .detached_sessions :
41+ session_id , detached_ts = _reconnect_state .detached_sessions .popitem (last = False ) # 弹出最早过期的session
4342
44- if time .time () < detached_ts + _state .expire_second :
43+ if time .time () < detached_ts + _reconnect_state .expire_second :
4544 # this session is not expired
46- _state .detached_sessions [session_id ] = detached_ts # restore
47- _state .detached_sessions .move_to_end (session_id , last = False ) # move to head
45+ _reconnect_state .detached_sessions [session_id ] = detached_ts # restore
46+ _reconnect_state .detached_sessions .move_to_end (session_id , last = False ) # move to head
4847 break
4948
5049 # clean this session
5150 logger .debug ("session %s expired" % session_id )
52- _state .active_connections .pop (session_id , None )
53- _state . undelivered_messages .pop (session_id , None )
54- session = _state .unclosed_sessions .pop (session_id , None )
51+ _reconnect_state .active_connections .pop (session_id , None )
52+ _reconnect_state . session_will_messages .pop (session_id , None )
53+ session = _reconnect_state .unclosed_sessions .pop (session_id , None )
5554 if session :
5655 session .close (nonblock = True )
5756
@@ -61,7 +60,7 @@ def clean_expired_sessions():
6160
6261async def session_clean_task ():
6362 global _session_clean_task_started
64- if _session_clean_task_started or not _state .expire_second :
63+ if _session_clean_task_started or not _reconnect_state .expire_second :
6564 return
6665
6766 _session_clean_task_started = True
@@ -72,7 +71,7 @@ async def session_clean_task():
7271 except Exception :
7372 logger .exception ("Error when clean expired sessions" )
7473
75- await asyncio .sleep (_state .expire_second // 2 )
74+ await asyncio .sleep (_reconnect_state .expire_second // 2 )
7675
7776
7877class WebSocketConnection (abc .ABC ):
@@ -118,30 +117,30 @@ def __init__(self, connection: WebSocketConnection, application, reconnectable:
118117 if self .session_id in ('NEW' , None ): # 初始请求,创建新 Session
119118 self ._init_session (application )
120119 if reconnectable :
120+ _reconnect_state .active_connections [self .session_id ] = self .connection
121+ _reconnect_state .unclosed_sessions [self .session_id ] = self .session
121122 # set session id to client, so the client can send it back to server to recover a session when it
122123 # resumes form a connection lost
123124 connection .write_message (dict (command = 'set_session_id' , spec = self .session_id ))
124- elif self .session_id not in _state .unclosed_sessions : # session is expired
125+ elif self .session_id not in _reconnect_state .unclosed_sessions : # session is expired
125126 bye_msg = dict (command = 'close_session' )
126- for m in _state . undelivered_messages .get (self .session_id , [bye_msg ]):
127+ for m in _reconnect_state . session_will_messages .get (self .session_id , [bye_msg ]):
127128 try :
128129 connection .write_message (m )
129130 except Exception :
130131 logger .exception ("Error in sending message via websocket" )
131- else :
132- self .session = _state .unclosed_sessions [self .session_id ]
133- _state .detached_sessions .pop (self .session_id , None )
134- _state .active_connections [self .session_id ] = connection
132+ else : # resumes form a connection lost
133+ self .session = _reconnect_state .unclosed_sessions [self .session_id ]
134+ _reconnect_state .detached_sessions .pop (self .session_id , None )
135+ _reconnect_state .active_connections [self .session_id ] = connection
135136 # send the latest messages to client
136- self ._send_msg_to_client (self . session )
137+ self ._send_msg_to_client ()
137138
138139 logger .debug ('session id: %s' % self .session_id )
139140
140141 def _init_session (self , application ):
141142 session_info = self .connection .make_session_info ()
142143 self .session_id = random_str (24 )
143- # todo: only set item when reconnection enabled
144- _state .active_connections [self .session_id ] = self .connection
145144
146145 if iscoroutinefunction (application ) or isgeneratorfunction (application ):
147146 self .session = CoroutineBasedSession (
@@ -154,12 +153,20 @@ def _init_session(self, application):
154153 on_task_command = self ._send_msg_to_client ,
155154 on_session_close = self ._close_from_session ,
156155 loop = self .ioloop )
157- _state .unclosed_sessions [self .session_id ] = self .session
158156
159- def _send_msg_to_client (self , session ):
160- # self.connection may not be active,
161- # here we need the active connection for this session
162- conn = _state .active_connections .get (self .session_id )
157+ def _get_active_connection (self ) -> Optional [WebSocketConnection ]:
158+ # when reconnect enabled, the active connection for this session is in _reconnect_state.active_connections,
159+ # otherwise, it's self.connection.
160+ if self .reconnectable :
161+ conn = _reconnect_state .active_connections .get (self .session_id )
162+ else :
163+ conn = self .connection
164+
165+ return conn
166+
167+ def _send_msg_to_client (self , session : Session = None ):
168+ conn = self ._get_active_connection ()
169+ session = session or self .session
163170
164171 if not conn or conn .closed ():
165172 return
@@ -175,17 +182,13 @@ def _send_msg_to_client(self, session):
175182 logger .exception ("Error in sending message via websocket" )
176183
177184 def _close_from_session (self ):
178- session = _state .unclosed_sessions [self .session_id ]
179- if self .session_id in _state .active_connections :
180- # send the undelivered messages to client
181- self ._send_msg_to_client (session = session )
182- else :
183- _state .undelivered_messages [self .session_id ] = session .get_task_commands ()
184-
185- conn = _state .active_connections .pop (self .session_id , None )
186- _state .unclosed_sessions .pop (self .session_id , None )
185+ conn = self ._get_active_connection ()
187186 if conn and not conn .closed ():
187+ self ._send_msg_to_client ()
188188 conn .close ()
189+ elif self .reconnectable : # no active connection, and reconnect is enabled
190+ _reconnect_state .session_will_messages [self .session_id ] = self .session .get_task_commands ()
191+ self .session = None
189192
190193 def send_client_data (self , data ):
191194 if isinstance (data , bytes ):
@@ -197,12 +200,14 @@ def send_client_data(self, data):
197200 self .session .send_client_event (event )
198201
199202 def notify_connection_lost (self ):
200- _state . active_connections . pop ( self . session_id , None )
201- if not self .reconnectable :
203+ logger . debug ( "WebSocket closed" )
204+ if not self .reconnectable and self . session :
202205 # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though.
203206 # because the `session.close()` is reentrant
204207 self .session .close (nonblock = True )
205- else :
206- if self .session_id in _state .unclosed_sessions :
207- _state .detached_sessions [self .session_id ] = time .time ()
208- logger .debug ("WebSocket closed" )
208+ self .session = None # reset the reference
209+ return
210+
211+ _reconnect_state .active_connections .pop (self .session_id , None )
212+ if self .session_id in _reconnect_state .unclosed_sessions :
213+ _reconnect_state .detached_sessions [self .session_id ] = time .time ()
0 commit comments