From db7fcdaaf7adaf76585f5350f3a52db6960b8397 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 15:57:37 +0900 Subject: [PATCH 01/10] feat: pass dumps_default, ext_hook --- src/socketio/msgpack_packet.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 27462634..6e86a72a 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -5,13 +5,33 @@ class MsgPackPacket(packet.Packet): uses_binary_events = False + def __init__( + self, + packet_type=packet.EVENT, + data=None, + namespace=None, + id=None, + binary=None, + encoded_packet=None, + dumps_default=None, + ext_hook=None, + ): + super().__init__( + packet_type, data, namespace, id, binary, encoded_packet + ) + self.dumps_default = dumps_default + self.ext_hook = ext_hook + def encode(self): """Encode the packet for transmission.""" - return msgpack.dumps(self._to_dict()) + return msgpack.dumps(self._to_dict(), default=self.dumps_default) def decode(self, encoded_packet): """Decode a transmitted package.""" - decoded = msgpack.loads(encoded_packet) + if self.ext_hook is None: + decoded = msgpack.loads(encoded_packet) + else: + decoded = msgpack.loads(encoded_packet, ext_hook=self.ext_hook) self.packet_type = decoded['type'] self.data = decoded.get('data') self.id = decoded.get('id') From abb2c8c4bf267d0f0aea438ed9390f6475096392 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 16:20:49 +0900 Subject: [PATCH 02/10] test: msgpack packet tests --- src/socketio/msgpack_packet.py | 4 +- tests/common/test_msgpack_packet.py | 109 +++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 6e86a72a..846db2f3 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -16,11 +16,11 @@ def __init__( dumps_default=None, ext_hook=None, ): + self.dumps_default = dumps_default + self.ext_hook = ext_hook super().__init__( packet_type, data, namespace, id, binary, encoded_packet ) - self.dumps_default = dumps_default - self.ext_hook = ext_hook def encode(self): """Encode the packet for transmission.""" diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index e0197a27..1079e018 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -1,3 +1,8 @@ +from datetime import datetime, timedelta, timezone + +import pytest +import msgpack + from socketio import msgpack_packet from socketio import packet @@ -5,7 +10,8 @@ class TestMsgPackPacket: def test_encode_decode(self): p = msgpack_packet.MsgPackPacket( - packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo') + packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' + ) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -14,7 +20,8 @@ def test_encode_decode(self): def test_encode_decode_with_id(self): p = msgpack_packet.MsgPackPacket( - packet.EVENT, data=['ev', 42], id=123, namespace='/foo') + packet.EVENT, data=['ev', 42], id=123, namespace='/foo' + ) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -32,3 +39,101 @@ def test_encode_binary_ack_packet(self): assert p.packet_type == packet.ACK p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p2.data == {'foo': b'bar'} + + def test_encode_with_dumps_default(self): + def default(obj): + if isinstance(obj, datetime): + return obj.isoformat() + raise TypeError('Unknown type') + + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'current' in p2.data + assert isinstance(p2.data['current'], str) + assert default(data['current']) == p2.data['current'] + + data.pop('current') + p2_data_without_current = p2.data.copy() + p2_data_without_current.pop('current') + assert data == p2_data_without_current + + def test_encode_without_dumps_default(self): + data = { + 'current': datetime.now(tz=timezone(timedelta(0))), + 'key': 'value', + } + p_without_default = msgpack_packet.MsgPackPacket(data=data) + with pytest.raises( + TypeError, match="can not serialize 'datetime.datetime' object" + ): + p_without_default.encode() + + def test_encode_decode_with_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + def ext_hook(code, data): + if code == 1: + return Custom(data) + raise TypeError('Unknown ext type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket( + encoded_packet=p.encode(), ext_hook=ext_hook + ) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.data == p2.data + assert p.namespace == p2.namespace + + def test_encode_decode_without_ext_hook(self): + class Custom: + def __init__(self, value): + self.value = value + + def __eq__(self, value: object) -> bool: + return isinstance(value, Custom) and self.value == value.value + + def default(obj): + if isinstance(obj, Custom): + return msgpack.ExtType(1, obj.value) + raise TypeError('Unknown type') + + data = {'custom': Custom(b'custom_data'), 'key': 'value'} + p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) + assert p.packet_type == p2.packet_type + assert p.id == p2.id + assert p.namespace == p2.namespace + assert p.data != p2.data + + assert isinstance(p2.data, dict) + assert 'custom' in p2.data + assert isinstance(p2.data['custom'], msgpack.ExtType) + assert p2.data['custom'].code == 1 + assert p2.data['custom'].data == b'custom_data' + + data.pop('custom') + p2_data_without_custom = p2.data.copy() + p2_data_without_custom.pop('custom') + assert data == p2_data_without_custom From 5d9e3a7e6eda266496783b27684045a8f7a98c12 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 16:32:26 +0900 Subject: [PATCH 03/10] fix: pypy tests --- tests/common/test_msgpack_packet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index 1079e018..8a3befd5 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -73,9 +73,7 @@ def test_encode_without_dumps_default(self): 'key': 'value', } p_without_default = msgpack_packet.MsgPackPacket(data=data) - with pytest.raises( - TypeError, match="can not serialize 'datetime.datetime' object" - ): + with pytest.raises(TypeError): p_without_default.encode() def test_encode_decode_with_ext_hook(self): From c86a3fdc1ec23ffd05b693f3a7613dd961af562d Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:04:03 +0900 Subject: [PATCH 04/10] feat: serializer_args, _create_packet --- src/socketio/base_client.py | 7 ++++++- src/socketio/base_server.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 0232dca7..2bcaafcf 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -38,7 +38,8 @@ class BaseClient: def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, randomization_factor=0.5, logger=False, serializer='default', - json=None, handle_sigint=True, **kwargs): + json=None, handle_sigint=True, serializer_args=None, + **kwargs): global original_signal_handler if handle_sigint and original_signal_handler is None and \ threading.current_thread() == threading.main_thread(): @@ -63,6 +64,7 @@ def __init__(self, reconnection=True, reconnection_attempts=0, self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer + self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -283,6 +285,9 @@ def _generate_ack_id(self, namespace, callback): self.callbacks[namespace][id] = callback return id + def _create_packet(self, *args, **kwargs): + return self.packet_class(*args, **kwargs, **self.packet_class_args) + def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index d134eba1..873e969f 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -15,7 +15,7 @@ class BaseServer: def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, - namespaces=None, **kwargs): + namespaces=None, serializer_args=None, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -27,6 +27,7 @@ def __init__(self, client_manager=None, logger=False, serializer='default', self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer + self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -252,6 +253,10 @@ def _get_namespace_handler(self, namespace, args): handler = self.namespace_handlers['*'] args = (namespace, *args) return handler, args + + def _create_packet(self, *args, **kwargs): + return self.packet_class(*args, **kwargs, + **self.packet_class_args) def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() From 005953da1d11ff4f81e59d54466fa389f0e8924d Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:07:49 +0900 Subject: [PATCH 05/10] fix: apply _create_packet --- src/socketio/async_client.py | 10 +++++----- src/socketio/async_server.py | 16 ++++++++-------- src/socketio/client.py | 10 +++++----- src/socketio/server.py | 16 ++++++++-------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index 678743a2..c1ec14f0 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -243,7 +243,7 @@ async def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - await self._send_packet(self.packet_class( + await self._send_packet(self._create_packet( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def send(self, data, namespace=None, callback=None): @@ -325,7 +325,7 @@ async def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - await self._send_packet(self.packet_class(packet.DISCONNECT, + await self._send_packet(self._create_packet(packet.DISCONNECT, namespace=n)) await self.eio.disconnect() @@ -422,7 +422,7 @@ async def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - await self._send_packet(self.packet_class( + await self._send_packet(self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, namespace, id, data): @@ -555,7 +555,7 @@ async def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = await self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - await self._send_packet(self.packet_class( + await self._send_packet(self._create_packet( packet.CONNECT, data=real_auth, namespace=n)) async def _handle_eio_message(self, data): @@ -569,7 +569,7 @@ async def _handle_eio_message(self, data): else: await self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 6c9e3ca3..5b896bf5 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -425,7 +425,7 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -538,13 +538,13 @@ async def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = await self.manager.connect(eio_sid, namespace) if sid is None: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -568,15 +568,15 @@ async def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) await self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - await self._send_packet(eio_sid, self.packet_class( + await self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) async def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -622,7 +622,7 @@ async def _handle_event_internal(self, server, sid, eio_sid, data, data = list(r) else: data = [r] - await server._send_packet(eio_sid, self.packet_class( + await server._send_packet(eio_sid, self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, eio_sid, namespace, id, data): @@ -686,7 +686,7 @@ async def _handle_eio_message(self, eio_sid, data): await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/client.py b/src/socketio/client.py index 5282e0a1..296e4dc4 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -234,7 +234,7 @@ def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - self._send_packet(self.packet_class(packet.EVENT, namespace=namespace, + self._send_packet(self._create_packet(packet.EVENT, namespace=namespace, data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): @@ -307,7 +307,7 @@ def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - self._send_packet(self.packet_class( + self._send_packet(self._create_packet( packet.DISCONNECT, namespace=n)) self.eio.disconnect() @@ -402,7 +402,7 @@ def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - self._send_packet(self.packet_class( + self._send_packet(self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, namespace, id, data): @@ -506,7 +506,7 @@ def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - self._send_packet(self.packet_class( + self._send_packet(self._create_packet( packet.CONNECT, data=real_auth, namespace=n)) def _handle_eio_message(self, data): @@ -520,7 +520,7 @@ def _handle_eio_message(self, data): else: self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/server.py b/src/socketio/server.py index f3257081..21d6afeb 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -401,7 +401,7 @@ def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -520,13 +520,13 @@ def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = self.manager.connect(eio_sid, namespace) if sid is None: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -550,15 +550,15 @@ def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - self._send_packet(eio_sid, self.packet_class( + self._send_packet(eio_sid, self._create_packet( packet.CONNECT, {'sid': sid}, namespace=namespace)) def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -601,7 +601,7 @@ def _handle_event_internal(self, server, sid, eio_sid, data, namespace, data = list(r) else: data = [r] - server._send_packet(eio_sid, self.packet_class( + server._send_packet(eio_sid, self._create_packet( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, eio_sid, namespace, id, data): @@ -650,7 +650,7 @@ def _handle_eio_message(self, eio_sid, data): else: self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self.packet_class(encoded_packet=data) + pkt = self._create_packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: From 81e96142f9d92f1203985a07d3a5a12b63609972 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:10:48 +0900 Subject: [PATCH 06/10] docs: add serializer_args --- src/socketio/async_client.py | 3 +++ src/socketio/async_server.py | 3 +++ src/socketio/client.py | 3 +++ src/socketio/server.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index c1ec14f0..fc7ce3fa 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -45,6 +45,9 @@ class AsyncClient(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. The Engine.IO configuration supports the following settings: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 5b896bf5..9bdfca4f 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -50,6 +50,9 @@ class AsyncServer(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: diff --git a/src/socketio/client.py b/src/socketio/client.py index 296e4dc4..7be92ccb 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -48,6 +48,9 @@ class Client(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. The Engine.IO configuration supports the following settings: diff --git a/src/socketio/server.py b/src/socketio/server.py index 21d6afeb..1658fa5a 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -53,6 +53,9 @@ class Server(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. + :param serializer_args: A mapping of additional parameters to pass to + the serializer. The content of this dictionary + depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: From 5bb11b6964e99b4874fa41dff4c0032435987f8b Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:12:04 +0900 Subject: [PATCH 07/10] fix: lint error --- src/socketio/async_client.py | 2 +- src/socketio/async_server.py | 2 +- src/socketio/base_server.py | 2 +- src/socketio/client.py | 7 ++++--- src/socketio/server.py | 2 +- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index fc7ce3fa..c19c8459 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -45,7 +45,7 @@ class AsyncClient(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index 9bdfca4f..fa22393e 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -50,7 +50,7 @@ class AsyncServer(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index 873e969f..488ffe1d 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -253,7 +253,7 @@ def _get_namespace_handler(self, namespace, args): handler = self.namespace_handlers['*'] args = (namespace, *args) return handler, args - + def _create_packet(self, *args, **kwargs): return self.packet_class(*args, **kwargs, **self.packet_class_args) diff --git a/src/socketio/client.py b/src/socketio/client.py index 7be92ccb..29c1f25c 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -48,7 +48,7 @@ class Client(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. @@ -237,8 +237,9 @@ def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - self._send_packet(self._create_packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id)) + self._send_packet( + self._create_packet(packet.EVENT, namespace=namespace, + data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): """Send a message to the server. diff --git a/src/socketio/server.py b/src/socketio/server.py index 1658fa5a..7312506c 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -53,7 +53,7 @@ class Server(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to + :param serializer_args: A mapping of additional parameters to pass to the serializer. The content of this dictionary depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. From 6089dd0b9e74e9d7fd00002a0934c7a9216ed369 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:25:31 +0900 Subject: [PATCH 08/10] test: serializer_args tests --- tests/async/test_client.py | 29 +++++++++++++++++++++++++++++ tests/async/test_server.py | 29 +++++++++++++++++++++++++++++ tests/common/test_client.py | 29 +++++++++++++++++++++++++++++ tests/common/test_server.py | 29 +++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+) diff --git a/tests/async/test_client.py b/tests/async/test_client.py index b4b0c6c5..6b25c7ba 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1,5 +1,6 @@ import asyncio from unittest import mock +from datetime import datetime, timezone, timedelta import pytest @@ -1242,3 +1243,31 @@ async def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() + + def test_serializer_args(self): + args = {"foo": "bar"} + c = async_client.AsyncClient(serializer_args=args) + assert c.packet_class_args == args + + def test_serializer_args_with_msgpack(self): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + p = c._create_packet(data=data) + p2 = c._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self): + args = {"invalid_arg": 123} + c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 575f2097..6bc75b9c 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -1,6 +1,7 @@ import asyncio import logging from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -1089,3 +1090,31 @@ async def test_sleep(self, eio): s = async_server.AsyncServer() await s.sleep(1.23) s.eio.sleep.assert_awaited_once_with(1.23) + + def test_serializer_args(self, eio): + args = {"foo": "bar"} + s = async_server.AsyncServer(serializer_args=args) + assert s.packet_class_args == args + + def test_serializer_args_with_msgpack(self, eio): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + p = s._create_packet(data=data) + p2 = s._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self, eio): + args = {"invalid_arg": 123} + s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file diff --git a/tests/common/test_client.py b/tests/common/test_client.py index cbda3f1f..fd90512c 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1,6 +1,7 @@ import logging import time from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import exceptions as engineio_exceptions from engineio import json @@ -1386,3 +1387,31 @@ def test_eio_disconnect_no_reconnect(self): assert c.sid is None assert not c.connected c.start_background_task.assert_not_called() + + def test_serializer_args(self): + args = {"foo": "bar"} + c = client.Client(serializer_args=args) + assert c.packet_class_args == args + + def test_serializer_args_with_msgpack(self): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + c = client.Client(serializer='msgpack', serializer_args=args) + p = c._create_packet(data=data) + p2 = c._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self): + args = {"invalid_arg": 123} + c = client.Client(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file diff --git a/tests/common/test_server.py b/tests/common/test_server.py index bdbbfe07..a8df8c8a 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -1,5 +1,6 @@ import logging from unittest import mock +from datetime import datetime, timezone, timedelta from engineio import json from engineio import packet as eio_packet @@ -1032,3 +1033,31 @@ def test_sleep(self, eio): s = server.Server() s.sleep(1.23) s.eio.sleep.assert_called_once_with(1.23) + + def test_serializer_args(self, eio): + args = {"foo": "bar"} + s = server.Server(serializer_args=args) + assert s.packet_class_args == args + + def test_serializer_args_with_msgpack(self, eio): + def default(o): + if isinstance(o, datetime): + return o.isoformat() + raise TypeError("Unknown type") + args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} + s = server.Server(serializer='msgpack', serializer_args=args) + p = s._create_packet(data=data) + p2 = s._create_packet(encoded_packet=p.encode()) + + assert p.data != p2.data + assert isinstance(p2.data, dict) + assert "current" in p2.data + assert isinstance(p2.data["current"], str) + assert default(data["current"]) == p2.data["current"] + + def test_invalid_serializer_args(self, eio): + args = {"invalid_arg": 123} + s = server.Server(serializer='msgpack', serializer_args=args) + with pytest.raises(TypeError): + s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file From 6c899d74259776c0cad724c0eedc220c817e8dc1 Mon Sep 17 00:00:00 2001 From: phi Date: Sun, 2 Nov 2025 22:26:53 +0900 Subject: [PATCH 09/10] fix: lint errors --- tests/async/test_client.py | 12 +++++++----- tests/async/test_server.py | 12 +++++++----- tests/common/test_client.py | 6 +++--- tests/common/test_server.py | 6 +++--- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 6b25c7ba..58e2ac75 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1248,7 +1248,7 @@ def test_serializer_args(self): args = {"foo": "bar"} c = async_client.AsyncClient(serializer_args=args) assert c.packet_class_args == args - + def test_serializer_args_with_msgpack(self): def default(o): if isinstance(o, datetime): @@ -1256,7 +1256,8 @@ def default(o): raise TypeError("Unknown type") args = {"dumps_default": default} data = {"current": datetime.now(timezone(timedelta(0)))} - c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + c = async_client.AsyncClient(serializer='msgpack', + serializer_args=args) p = c._create_packet(data=data) p2 = c._create_packet(encoded_packet=p.encode()) @@ -1265,9 +1266,10 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self): args = {"invalid_arg": 123} - c = async_client.AsyncClient(serializer='msgpack', serializer_args=args) + c = async_client.AsyncClient(serializer='msgpack', + serializer_args=args) with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 6bc75b9c..793192f2 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -1095,7 +1095,7 @@ def test_serializer_args(self, eio): args = {"foo": "bar"} s = async_server.AsyncServer(serializer_args=args) assert s.packet_class_args == args - + def test_serializer_args_with_msgpack(self, eio): def default(o): if isinstance(o, datetime): @@ -1103,7 +1103,8 @@ def default(o): raise TypeError("Unknown type") args = {"dumps_default": default} data = {"current": datetime.now(timezone(timedelta(0)))} - s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + s = async_server.AsyncServer(serializer='msgpack', + serializer_args=args) p = s._create_packet(data=data) p2 = s._create_packet(encoded_packet=p.encode()) @@ -1112,9 +1113,10 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self, eio): args = {"invalid_arg": 123} - s = async_server.AsyncServer(serializer='msgpack', serializer_args=args) + s = async_server.AsyncServer(serializer='msgpack', + serializer_args=args) with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + s._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index fd90512c..90ab5dfd 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1392,7 +1392,7 @@ def test_serializer_args(self): args = {"foo": "bar"} c = client.Client(serializer_args=args) assert c.packet_class_args == args - + def test_serializer_args_with_msgpack(self): def default(o): if isinstance(o, datetime): @@ -1409,9 +1409,9 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self): args = {"invalid_arg": 123} c = client.Client(serializer='msgpack', serializer_args=args) with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_server.py b/tests/common/test_server.py index a8df8c8a..6bbe7c4c 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -1038,7 +1038,7 @@ def test_serializer_args(self, eio): args = {"foo": "bar"} s = server.Server(serializer_args=args) assert s.packet_class_args == args - + def test_serializer_args_with_msgpack(self, eio): def default(o): if isinstance(o, datetime): @@ -1055,9 +1055,9 @@ def default(o): assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - + def test_invalid_serializer_args(self, eio): args = {"invalid_arg": 123} s = server.Server(serializer='msgpack', serializer_args=args) with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() \ No newline at end of file + s._create_packet(data={"foo": "bar"}).encode() From c7b872915c24e06a0a15b37497d877605826bc58 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Thu, 6 Nov 2025 19:33:24 +0000 Subject: [PATCH 10/10] MsgPackPacket.configure method --- src/socketio/async_client.py | 13 ++++------- src/socketio/async_server.py | 19 +++++++-------- src/socketio/base_client.py | 7 +----- src/socketio/base_server.py | 7 +----- src/socketio/client.py | 16 +++++-------- src/socketio/msgpack_packet.py | 36 ++++++++++++----------------- src/socketio/server.py | 19 +++++++-------- tests/async/test_client.py | 23 +++++------------- tests/async/test_server.py | 23 +++++------------- tests/common/test_client.py | 21 +++++------------ tests/common/test_msgpack_packet.py | 19 +++++++-------- tests/common/test_server.py | 21 +++++------------ 12 files changed, 78 insertions(+), 146 deletions(-) diff --git a/src/socketio/async_client.py b/src/socketio/async_client.py index c19c8459..678743a2 100644 --- a/src/socketio/async_client.py +++ b/src/socketio/async_client.py @@ -45,9 +45,6 @@ class AsyncClient(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. The Engine.IO configuration supports the following settings: @@ -246,7 +243,7 @@ async def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - await self._send_packet(self._create_packet( + await self._send_packet(self.packet_class( packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def send(self, data, namespace=None, callback=None): @@ -328,7 +325,7 @@ async def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - await self._send_packet(self._create_packet(packet.DISCONNECT, + await self._send_packet(self.packet_class(packet.DISCONNECT, namespace=n)) await self.eio.disconnect() @@ -425,7 +422,7 @@ async def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - await self._send_packet(self._create_packet( + await self._send_packet(self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, namespace, id, data): @@ -558,7 +555,7 @@ async def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = await self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - await self._send_packet(self._create_packet( + await self._send_packet(self.packet_class( packet.CONNECT, data=real_auth, namespace=n)) async def _handle_eio_message(self, data): @@ -572,7 +569,7 @@ async def _handle_eio_message(self, data): else: await self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/async_server.py b/src/socketio/async_server.py index fa22393e..6c9e3ca3 100644 --- a/src/socketio/async_server.py +++ b/src/socketio/async_server.py @@ -50,9 +50,6 @@ class AsyncServer(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -428,7 +425,7 @@ async def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -541,13 +538,13 @@ async def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = await self.manager.connect(eio_sid, namespace) if sid is None: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -571,15 +568,15 @@ async def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) await self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - await self._send_packet(eio_sid, self._create_packet( + await self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) async def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -625,7 +622,7 @@ async def _handle_event_internal(self, server, sid, eio_sid, data, data = list(r) else: data = [r] - await server._send_packet(eio_sid, self._create_packet( + await server._send_packet(eio_sid, self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, eio_sid, namespace, id, data): @@ -689,7 +686,7 @@ async def _handle_eio_message(self, eio_sid, data): await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/base_client.py b/src/socketio/base_client.py index 2bcaafcf..0232dca7 100644 --- a/src/socketio/base_client.py +++ b/src/socketio/base_client.py @@ -38,8 +38,7 @@ class BaseClient: def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, randomization_factor=0.5, logger=False, serializer='default', - json=None, handle_sigint=True, serializer_args=None, - **kwargs): + json=None, handle_sigint=True, **kwargs): global original_signal_handler if handle_sigint and original_signal_handler is None and \ threading.current_thread() == threading.main_thread(): @@ -64,7 +63,6 @@ def __init__(self, reconnection=True, reconnection_attempts=0, self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer - self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -285,9 +283,6 @@ def _generate_ack_id(self, namespace, callback): self.callbacks[namespace][id] = callback return id - def _create_packet(self, *args, **kwargs): - return self.packet_class(*args, **kwargs, **self.packet_class_args) - def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/base_server.py b/src/socketio/base_server.py index 488ffe1d..d134eba1 100644 --- a/src/socketio/base_server.py +++ b/src/socketio/base_server.py @@ -15,7 +15,7 @@ class BaseServer: def __init__(self, client_manager=None, logger=False, serializer='default', json=None, async_handlers=True, always_connect=False, - namespaces=None, serializer_args=None, **kwargs): + namespaces=None, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -27,7 +27,6 @@ def __init__(self, client_manager=None, logger=False, serializer='default', self.packet_class = msgpack_packet.MsgPackPacket else: self.packet_class = serializer - self.packet_class_args = serializer_args or {} if json is not None: self.packet_class.json = json engineio_options['json'] = json @@ -254,10 +253,6 @@ def _get_namespace_handler(self, namespace, args): args = (namespace, *args) return handler, args - def _create_packet(self, *args, **kwargs): - return self.packet_class(*args, **kwargs, - **self.packet_class_args) - def _handle_eio_connect(self): # pragma: no cover raise NotImplementedError() diff --git a/src/socketio/client.py b/src/socketio/client.py index 29c1f25c..5282e0a1 100644 --- a/src/socketio/client.py +++ b/src/socketio/client.py @@ -48,9 +48,6 @@ class Client(base_client.BaseClient): leave interrupt handling to the calling application. Interrupt handling can only be enabled when the client instance is created in the main thread. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. The Engine.IO configuration supports the following settings: @@ -237,9 +234,8 @@ def emit(self, event, data=None, namespace=None, callback=None): data = [data] else: data = [] - self._send_packet( - self._create_packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id)) + self._send_packet(self.packet_class(packet.EVENT, namespace=namespace, + data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): """Send a message to the server. @@ -311,7 +307,7 @@ def disconnect(self): # here we just request the disconnection # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: - self._send_packet(self._create_packet( + self._send_packet(self.packet_class( packet.DISCONNECT, namespace=n)) self.eio.disconnect() @@ -406,7 +402,7 @@ def _handle_event(self, namespace, id, data): data = list(r) else: data = [r] - self._send_packet(self._create_packet( + self._send_packet(self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, namespace, id, data): @@ -510,7 +506,7 @@ def _handle_eio_connect(self): self.sid = self.eio.sid real_auth = self._get_real_value(self.connection_auth) or {} for n in self.connection_namespaces: - self._send_packet(self._create_packet( + self._send_packet(self.packet_class( packet.CONNECT, data=real_auth, namespace=n)) def _handle_eio_message(self, data): @@ -524,7 +520,7 @@ def _handle_eio_message(self, data): else: self._handle_ack(pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/src/socketio/msgpack_packet.py b/src/socketio/msgpack_packet.py index 846db2f3..9622dd26 100644 --- a/src/socketio/msgpack_packet.py +++ b/src/socketio/msgpack_packet.py @@ -4,34 +4,28 @@ class MsgPackPacket(packet.Packet): uses_binary_events = False + dumps_default = None + ext_hook = msgpack.ExtType - def __init__( - self, - packet_type=packet.EVENT, - data=None, - namespace=None, - id=None, - binary=None, - encoded_packet=None, - dumps_default=None, - ext_hook=None, - ): - self.dumps_default = dumps_default - self.ext_hook = ext_hook - super().__init__( - packet_type, data, namespace, id, binary, encoded_packet - ) + @classmethod + def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType): + class CustomMsgPackPacket(MsgPackPacket): + dumps_default = None + ext_hook = None + + CustomMsgPackPacket.dumps_default = dumps_default + CustomMsgPackPacket.ext_hook = ext_hook + return CustomMsgPackPacket def encode(self): """Encode the packet for transmission.""" - return msgpack.dumps(self._to_dict(), default=self.dumps_default) + return msgpack.dumps(self._to_dict(), + default=self.__class__.dumps_default) def decode(self, encoded_packet): """Decode a transmitted package.""" - if self.ext_hook is None: - decoded = msgpack.loads(encoded_packet) - else: - decoded = msgpack.loads(encoded_packet, ext_hook=self.ext_hook) + decoded = msgpack.loads(encoded_packet, + ext_hook=self.__class__.ext_hook) self.packet_type = decoded['type'] self.data = decoded.get('data') self.id = decoded.get('id') diff --git a/src/socketio/server.py b/src/socketio/server.py index 7312506c..f3257081 100644 --- a/src/socketio/server.py +++ b/src/socketio/server.py @@ -53,9 +53,6 @@ class Server(base_server.BaseServer): default is `['/']`, which always accepts connections to the default namespace. Set to `'*'` to accept all namespaces. - :param serializer_args: A mapping of additional parameters to pass to - the serializer. The content of this dictionary - depends on the selected serialization method. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: @@ -404,7 +401,7 @@ def disconnect(self, sid, namespace=None, ignore_queue=False): if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid, self.reason.SERVER_DISCONNECT) @@ -523,13 +520,13 @@ def _handle_connect(self, eio_sid, namespace, data): or self.namespaces == '*' or namespace in self.namespaces: sid = self.manager.connect(eio_sid, namespace) if sid is None: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data='Unable to connect', namespace=namespace)) return if self.always_connect: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) fail_reason = exceptions.ConnectionRefusedError().error_args try: @@ -553,15 +550,15 @@ def _handle_connect(self, eio_sid, namespace, data): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.DISCONNECT, data=fail_reason, namespace=namespace)) else: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT_ERROR, data=fail_reason, namespace=namespace)) self.manager.disconnect(sid, namespace, ignore_queue=True) elif not self.always_connect: - self._send_packet(eio_sid, self._create_packet( + self._send_packet(eio_sid, self.packet_class( packet.CONNECT, {'sid': sid}, namespace=namespace)) def _handle_disconnect(self, eio_sid, namespace, reason=None): @@ -604,7 +601,7 @@ def _handle_event_internal(self, server, sid, eio_sid, data, namespace, data = list(r) else: data = [r] - server._send_packet(eio_sid, self._create_packet( + server._send_packet(eio_sid, self.packet_class( packet.ACK, namespace=namespace, id=id, data=data)) def _handle_ack(self, eio_sid, namespace, id, data): @@ -653,7 +650,7 @@ def _handle_eio_message(self, eio_sid, data): else: self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - pkt = self._create_packet(encoded_packet=data) + pkt = self.packet_class(encoded_packet=data) if pkt.packet_type == packet.CONNECT: self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 58e2ac75..7a7bfa7c 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -9,6 +9,7 @@ from engineio import exceptions as engineio_exceptions from socketio import exceptions from socketio import packet +from socketio.msgpack_packet import MsgPackPacket class TestAsyncClient: @@ -1244,32 +1245,20 @@ async def test_eio_disconnect_no_reconnect(self): assert not c.connected c.start_background_task.assert_not_called() - def test_serializer_args(self): - args = {"foo": "bar"} - c = async_client.AsyncClient(serializer_args=args) - assert c.packet_class_args == args - def test_serializer_args_with_msgpack(self): def default(o): if isinstance(o, datetime): return o.isoformat() raise TypeError("Unknown type") - args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} - c = async_client.AsyncClient(serializer='msgpack', - serializer_args=args) - p = c._create_packet(data=data) - p2 = c._create_packet(encoded_packet=p.encode()) + c = async_client.AsyncClient( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = c.packet_class(data=data) + p2 = c.packet_class(encoded_packet=p.encode()) assert p.data != p2.data assert isinstance(p2.data, dict) assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self): - args = {"invalid_arg": 123} - c = async_client.AsyncClient(serializer='msgpack', - serializer_args=args) - with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/async/test_server.py b/tests/async/test_server.py index 793192f2..10d7ba14 100644 --- a/tests/async/test_server.py +++ b/tests/async/test_server.py @@ -12,6 +12,7 @@ from socketio import exceptions from socketio import namespace from socketio import packet +from socketio.msgpack_packet import MsgPackPacket @mock.patch('socketio.server.engineio.AsyncServer', **{ @@ -1091,32 +1092,20 @@ async def test_sleep(self, eio): await s.sleep(1.23) s.eio.sleep.assert_awaited_once_with(1.23) - def test_serializer_args(self, eio): - args = {"foo": "bar"} - s = async_server.AsyncServer(serializer_args=args) - assert s.packet_class_args == args - def test_serializer_args_with_msgpack(self, eio): def default(o): if isinstance(o, datetime): return o.isoformat() raise TypeError("Unknown type") - args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} - s = async_server.AsyncServer(serializer='msgpack', - serializer_args=args) - p = s._create_packet(data=data) - p2 = s._create_packet(encoded_packet=p.encode()) + s = async_server.AsyncServer( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = s.packet_class(data=data) + p2 = s.packet_class(encoded_packet=p.encode()) assert p.data != p2.data assert isinstance(p2.data, dict) assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self, eio): - args = {"invalid_arg": 123} - s = async_server.AsyncServer(serializer='msgpack', - serializer_args=args) - with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_client.py b/tests/common/test_client.py index 90ab5dfd..d386a9c3 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -14,6 +14,7 @@ from socketio import msgpack_packet from socketio import namespace from socketio import packet +from socketio.msgpack_packet import MsgPackPacket class TestClient: @@ -1388,30 +1389,20 @@ def test_eio_disconnect_no_reconnect(self): assert not c.connected c.start_background_task.assert_not_called() - def test_serializer_args(self): - args = {"foo": "bar"} - c = client.Client(serializer_args=args) - assert c.packet_class_args == args - def test_serializer_args_with_msgpack(self): def default(o): if isinstance(o, datetime): return o.isoformat() raise TypeError("Unknown type") - args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} - c = client.Client(serializer='msgpack', serializer_args=args) - p = c._create_packet(data=data) - p2 = c._create_packet(encoded_packet=p.encode()) + c = client.Client( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = c.packet_class(data=data) + p2 = c.packet_class(encoded_packet=p.encode()) assert p.data != p2.data assert isinstance(p2.data, dict) assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self): - args = {"invalid_arg": 123} - c = client.Client(serializer='msgpack', serializer_args=args) - with pytest.raises(TypeError): - c._create_packet(data={"foo": "bar"}).encode() diff --git a/tests/common/test_msgpack_packet.py b/tests/common/test_msgpack_packet.py index 8a3befd5..0fad0292 100644 --- a/tests/common/test_msgpack_packet.py +++ b/tests/common/test_msgpack_packet.py @@ -10,8 +10,7 @@ class TestMsgPackPacket: def test_encode_decode(self): p = msgpack_packet.MsgPackPacket( - packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo' - ) + packet.CONNECT, data={'auth': {'token': '123'}}, namespace='/foo') p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -20,8 +19,7 @@ def test_encode_decode(self): def test_encode_decode_with_id(self): p = msgpack_packet.MsgPackPacket( - packet.EVENT, data=['ev', 42], id=123, namespace='/foo' - ) + packet.EVENT, data=['ev', 42], id=123, namespace='/foo') p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.data == p2.data @@ -50,7 +48,8 @@ def default(obj): 'current': datetime.now(tz=timezone(timedelta(0))), 'key': 'value', } - p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)( + data=data) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.id == p2.id @@ -95,9 +94,10 @@ def ext_hook(code, data): raise TypeError('Unknown ext type') data = {'custom': Custom(b'custom_data'), 'key': 'value'} - p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) - p2 = msgpack_packet.MsgPackPacket( - encoded_packet=p.encode(), ext_hook=ext_hook + p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)( + data=data) + p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)( + encoded_packet=p.encode() ) assert p.packet_type == p2.packet_type assert p.id == p2.id @@ -118,7 +118,8 @@ def default(obj): raise TypeError('Unknown type') data = {'custom': Custom(b'custom_data'), 'key': 'value'} - p = msgpack_packet.MsgPackPacket(data=data, dumps_default=default) + p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)( + data=data) p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode()) assert p.packet_type == p2.packet_type assert p.id == p2.id diff --git a/tests/common/test_server.py b/tests/common/test_server.py index 6bbe7c4c..4c2c8071 100644 --- a/tests/common/test_server.py +++ b/tests/common/test_server.py @@ -11,6 +11,7 @@ from socketio import namespace from socketio import packet from socketio import server +from socketio.msgpack_packet import MsgPackPacket @mock.patch('socketio.server.engineio.Server', **{ @@ -1034,30 +1035,20 @@ def test_sleep(self, eio): s.sleep(1.23) s.eio.sleep.assert_called_once_with(1.23) - def test_serializer_args(self, eio): - args = {"foo": "bar"} - s = server.Server(serializer_args=args) - assert s.packet_class_args == args - def test_serializer_args_with_msgpack(self, eio): def default(o): if isinstance(o, datetime): return o.isoformat() raise TypeError("Unknown type") - args = {"dumps_default": default} + data = {"current": datetime.now(timezone(timedelta(0)))} - s = server.Server(serializer='msgpack', serializer_args=args) - p = s._create_packet(data=data) - p2 = s._create_packet(encoded_packet=p.encode()) + s = server.Server( + serializer=MsgPackPacket.configure(dumps_default=default)) + p = s.packet_class(data=data) + p2 = s.packet_class(encoded_packet=p.encode()) assert p.data != p2.data assert isinstance(p2.data, dict) assert "current" in p2.data assert isinstance(p2.data["current"], str) assert default(data["current"]) == p2.data["current"] - - def test_invalid_serializer_args(self, eio): - args = {"invalid_arg": 123} - s = server.Server(serializer='msgpack', serializer_args=args) - with pytest.raises(TypeError): - s._create_packet(data={"foo": "bar"}).encode()