|
| 1 | +import asyncio |
| 2 | +import json |
| 3 | +import logging |
| 4 | +import time |
| 5 | +import typing |
| 6 | +from typing import Dict |
| 7 | +import abc |
| 8 | +from ..utils import deserialize_binary_event |
| 9 | +from ...session import CoroutineBasedSession, ThreadBasedSession, Session |
| 10 | +from ...utils import iscoroutinefunction, isgeneratorfunction, \ |
| 11 | + random_str, LRUDict |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +class _state: |
| 17 | + # only used in reconnect enabled |
| 18 | + # used to clean up session |
| 19 | + detached_sessions = LRUDict() # session_id -> detached timestamp. In increasing order of the time |
| 20 | + |
| 21 | + # unclosed and unexpired session |
| 22 | + # only used in reconnect enabled |
| 23 | + # used to clean up session |
| 24 | + # used to retrieve session by id when new connection |
| 25 | + unclosed_sessions: Dict[str, Session] = {} # session_id -> session |
| 26 | + |
| 27 | + # the messages that can't deliver to browser when session close due to connection lost |
| 28 | + undelivered_messages: Dict[str, list] = {} # session_id -> unhandled message list |
| 29 | + |
| 30 | + # used to get the active conn in session's callbacks |
| 31 | + active_connections: Dict[str, 'WebSocketConnection'] = {} # session_id -> WSHandler |
| 32 | + |
| 33 | + expire_second = 10 |
| 34 | + |
| 35 | + |
| 36 | +def set_expire_second(sec): |
| 37 | + _state.expire_second = max(_state.expire_second, sec) |
| 38 | + |
| 39 | + |
| 40 | +def clean_expired_sessions(): |
| 41 | + while _state.detached_sessions: |
| 42 | + session_id, detached_ts = _state.detached_sessions.popitem(last=False) # 弹出最早过期的session |
| 43 | + |
| 44 | + if time.time() < detached_ts + _state.expire_second: |
| 45 | + # this session is not expired |
| 46 | + _state.detached_sessions[session_id] = detached_ts # restore |
| 47 | + _state.detached_sessions.move_to_end(session_id, last=False) # move to head |
| 48 | + break |
| 49 | + |
| 50 | + # clean this session |
| 51 | + logger.debug("session %s expired" % session_id) |
| 52 | + _state.active_connections.pop(session_id, None) |
| 53 | + _state.undelivered_messages.pop(session_id, None) |
| 54 | + session = _state.unclosed_sessions.pop(session_id, None) |
| 55 | + if session: |
| 56 | + session.close(nonblock=True) |
| 57 | + |
| 58 | + |
| 59 | +async def session_clean_task(): |
| 60 | + logger.debug("Start session cleaning task") |
| 61 | + while True: |
| 62 | + try: |
| 63 | + clean_expired_sessions() |
| 64 | + except Exception: |
| 65 | + logger.exception("Error when clean expired sessions") |
| 66 | + |
| 67 | + await asyncio.sleep(_state.expire_second // 2) |
| 68 | + |
| 69 | + |
| 70 | +class WebSocketConnection(abc.ABC): |
| 71 | + @abc.abstractmethod |
| 72 | + def get_query_argument(self, name) -> typing.Optional[str]: |
| 73 | + pass |
| 74 | + |
| 75 | + @abc.abstractmethod |
| 76 | + def make_session_info(self) -> dict: |
| 77 | + pass |
| 78 | + |
| 79 | + @abc.abstractmethod |
| 80 | + def write_message(self, message: dict): |
| 81 | + pass |
| 82 | + |
| 83 | + @abc.abstractmethod |
| 84 | + def closed(self) -> bool: |
| 85 | + return False |
| 86 | + |
| 87 | + @abc.abstractmethod |
| 88 | + def close(self): |
| 89 | + pass |
| 90 | + |
| 91 | + |
| 92 | +class WebSocketHandler: |
| 93 | + """ |
| 94 | + hold by one connection, |
| 95 | + share one session with multiple connection in session lifetime, but one conn at a time |
| 96 | + """ |
| 97 | + |
| 98 | + session_id: str = None |
| 99 | + session: Session = None # the session that current connection attaches |
| 100 | + connection: WebSocketConnection |
| 101 | + reconnectable: bool |
| 102 | + |
| 103 | + def __init__(self, connection: WebSocketConnection, application, reconnectable: bool): |
| 104 | + logger.debug("WebSocket opened") |
| 105 | + self.connection = connection |
| 106 | + self.reconnectable = reconnectable |
| 107 | + self.session_id = connection.get_query_argument('session') |
| 108 | + |
| 109 | + if self.session_id in ('NEW', None): # 初始请求,创建新 Session |
| 110 | + self._init_session(application) |
| 111 | + if reconnectable: |
| 112 | + # set session id to client, so the client can send it back to server to recover a session when it |
| 113 | + # resumes form a connection lost |
| 114 | + connection.write_message(dict(command='set_session_id', spec=self.session_id)) |
| 115 | + elif self.session_id not in _state.unclosed_sessions: # session is expired |
| 116 | + bye_msg = dict(command='close_session') |
| 117 | + for m in _state.undelivered_messages.get(self.session_id, [bye_msg]): |
| 118 | + try: |
| 119 | + connection.write_message(m) |
| 120 | + except Exception: |
| 121 | + logger.exception("Error in sending message via websocket") |
| 122 | + else: |
| 123 | + self.session = _state.unclosed_sessions[self.session_id] |
| 124 | + _state.detached_sessions.pop(self.session_id, None) |
| 125 | + _state.active_connections[self.session_id] = connection |
| 126 | + # send the latest messages to client |
| 127 | + self._send_msg_to_client(self.session) |
| 128 | + |
| 129 | + logger.debug('session id: %s' % self.session_id) |
| 130 | + |
| 131 | + def _init_session(self, application): |
| 132 | + session_info = self.connection.make_session_info() |
| 133 | + self.session_id = random_str(24) |
| 134 | + # todo: only set item when reconnection enabled |
| 135 | + _state.active_connections[self.session_id] = self.connection |
| 136 | + |
| 137 | + if iscoroutinefunction(application) or isgeneratorfunction(application): |
| 138 | + self.session = CoroutineBasedSession( |
| 139 | + application, session_info=session_info, |
| 140 | + on_task_command=self._send_msg_to_client, |
| 141 | + on_session_close=self._close_from_session) |
| 142 | + else: |
| 143 | + self.session = ThreadBasedSession( |
| 144 | + application, session_info=session_info, |
| 145 | + on_task_command=self._send_msg_to_client, |
| 146 | + on_session_close=self._close_from_session, |
| 147 | + loop=asyncio.get_event_loop()) |
| 148 | + _state.unclosed_sessions[self.session_id] = self.session |
| 149 | + |
| 150 | + def _send_msg_to_client(self, session): |
| 151 | + # self.connection may not be active, |
| 152 | + # here we need the active connection for this session |
| 153 | + conn = _state.active_connections.get(self.session_id) |
| 154 | + |
| 155 | + if not conn or conn.closed(): |
| 156 | + return |
| 157 | + |
| 158 | + for msg in session.get_task_commands(): |
| 159 | + try: |
| 160 | + conn.write_message(msg) |
| 161 | + except TypeError as e: |
| 162 | + logger.exception('Data serialization error: %s\n' |
| 163 | + 'This may be because you pass the wrong type of parameter to the function' |
| 164 | + ' of PyWebIO.\nData content: %s', e, msg) |
| 165 | + except Exception: |
| 166 | + logger.exception("Error in sending message via websocket") |
| 167 | + |
| 168 | + def _close_from_session(self): |
| 169 | + session = _state.unclosed_sessions[self.session_id] |
| 170 | + if self.session_id in _state.active_connections: |
| 171 | + # send the undelivered messages to client |
| 172 | + self._send_msg_to_client(session=session) |
| 173 | + else: |
| 174 | + _state.undelivered_messages[self.session_id] = session.get_task_commands() |
| 175 | + |
| 176 | + conn = _state.active_connections.pop(self.session_id, None) |
| 177 | + _state.unclosed_sessions.pop(self.session_id, None) |
| 178 | + if conn and not conn.closed(): |
| 179 | + conn.close() |
| 180 | + |
| 181 | + def send_client_data(self, data): |
| 182 | + if isinstance(data, bytes): |
| 183 | + event = deserialize_binary_event(data) |
| 184 | + else: |
| 185 | + event = json.loads(data) |
| 186 | + if event is None: |
| 187 | + return |
| 188 | + self.session.send_client_event(event) |
| 189 | + |
| 190 | + def notify_connection_lost(self): |
| 191 | + _state.active_connections.pop(self.session_id, None) |
| 192 | + if not self.reconnectable: |
| 193 | + # when the connection lost is caused by `on_session_close()`, it's OK to close the session here though. |
| 194 | + # because the `session.close()` is reentrant |
| 195 | + self.session.close(nonblock=True) |
| 196 | + else: |
| 197 | + if self.session_id in _state.unclosed_sessions: |
| 198 | + _state.detached_sessions[self.session_id] = time.time() |
| 199 | + logger.debug("WebSocket closed") |
0 commit comments