Skip to content

Commit 2089253

Browse files
Support ext_type in the MsgPackPacket class (#1521)
1 parent 6c9b997 commit 2089253

File tree

6 files changed

+212
-2
lines changed

6 files changed

+212
-2
lines changed

src/socketio/msgpack_packet.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,40 @@
44

55
class MsgPackPacket(packet.Packet):
66
uses_binary_events = False
7+
dumps_default = None
8+
ext_hook = msgpack.ExtType
9+
10+
@classmethod
11+
def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType):
12+
"""Change the default options for msgpack encoding and decoding.
13+
14+
:param dumps_default: a function called for objects that cannot be
15+
serialized by default msgpack. The function
16+
receives one argument, the object to serialize.
17+
It should return a serializable object or a
18+
``msgpack.ExtType`` instance.
19+
:param ext_hook: a function called when a ``msgpack.ExtType`` object is
20+
seen during decoding. The function receives two
21+
arguments, the code and the data. It should return the
22+
decoded object.
23+
"""
24+
class CustomMsgPackPacket(MsgPackPacket):
25+
dumps_default = None
26+
ext_hook = None
27+
28+
CustomMsgPackPacket.dumps_default = dumps_default
29+
CustomMsgPackPacket.ext_hook = ext_hook
30+
return CustomMsgPackPacket
731

832
def encode(self):
933
"""Encode the packet for transmission."""
10-
return msgpack.dumps(self._to_dict())
34+
return msgpack.dumps(self._to_dict(),
35+
default=self.__class__.dumps_default)
1136

1237
def decode(self, encoded_packet):
1338
"""Decode a transmitted package."""
14-
decoded = msgpack.loads(encoded_packet)
39+
decoded = msgpack.loads(encoded_packet,
40+
ext_hook=self.__class__.ext_hook)
1541
self.packet_type = decoded['type']
1642
self.data = decoded.get('data')
1743
self.id = decoded.get('id')

tests/async/test_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from unittest import mock
3+
from datetime import datetime, timezone, timedelta
34

45
import pytest
56

@@ -8,6 +9,7 @@
89
from engineio import exceptions as engineio_exceptions
910
from socketio import exceptions
1011
from socketio import packet
12+
from socketio.msgpack_packet import MsgPackPacket
1113

1214

1315
class TestAsyncClient:
@@ -1242,3 +1244,21 @@ async def test_eio_disconnect_no_reconnect(self):
12421244
assert c.sid is None
12431245
assert not c.connected
12441246
c.start_background_task.assert_not_called()
1247+
1248+
def test_serializer_args_with_msgpack(self):
1249+
def default(o):
1250+
if isinstance(o, datetime):
1251+
return o.isoformat()
1252+
raise TypeError("Unknown type")
1253+
1254+
data = {"current": datetime.now(timezone(timedelta(0)))}
1255+
c = async_client.AsyncClient(
1256+
serializer=MsgPackPacket.configure(dumps_default=default))
1257+
p = c.packet_class(data=data)
1258+
p2 = c.packet_class(encoded_packet=p.encode())
1259+
1260+
assert p.data != p2.data
1261+
assert isinstance(p2.data, dict)
1262+
assert "current" in p2.data
1263+
assert isinstance(p2.data["current"], str)
1264+
assert default(data["current"]) == p2.data["current"]

tests/async/test_server.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33
from unittest import mock
4+
from datetime import datetime, timezone, timedelta
45

56
from engineio import json
67
from engineio import packet as eio_packet
@@ -11,6 +12,7 @@
1112
from socketio import exceptions
1213
from socketio import namespace
1314
from socketio import packet
15+
from socketio.msgpack_packet import MsgPackPacket
1416

1517

1618
@mock.patch('socketio.server.engineio.AsyncServer', **{
@@ -1089,3 +1091,21 @@ async def test_sleep(self, eio):
10891091
s = async_server.AsyncServer()
10901092
await s.sleep(1.23)
10911093
s.eio.sleep.assert_awaited_once_with(1.23)
1094+
1095+
def test_serializer_args_with_msgpack(self, eio):
1096+
def default(o):
1097+
if isinstance(o, datetime):
1098+
return o.isoformat()
1099+
raise TypeError("Unknown type")
1100+
1101+
data = {"current": datetime.now(timezone(timedelta(0)))}
1102+
s = async_server.AsyncServer(
1103+
serializer=MsgPackPacket.configure(dumps_default=default))
1104+
p = s.packet_class(data=data)
1105+
p2 = s.packet_class(encoded_packet=p.encode())
1106+
1107+
assert p.data != p2.data
1108+
assert isinstance(p2.data, dict)
1109+
assert "current" in p2.data
1110+
assert isinstance(p2.data["current"], str)
1111+
assert default(data["current"]) == p2.data["current"]

tests/common/test_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import time
33
from unittest import mock
4+
from datetime import datetime, timezone, timedelta
45

56
from engineio import exceptions as engineio_exceptions
67
from engineio import json
@@ -13,6 +14,7 @@
1314
from socketio import msgpack_packet
1415
from socketio import namespace
1516
from socketio import packet
17+
from socketio.msgpack_packet import MsgPackPacket
1618

1719

1820
class TestClient:
@@ -1386,3 +1388,21 @@ def test_eio_disconnect_no_reconnect(self):
13861388
assert c.sid is None
13871389
assert not c.connected
13881390
c.start_background_task.assert_not_called()
1391+
1392+
def test_serializer_args_with_msgpack(self):
1393+
def default(o):
1394+
if isinstance(o, datetime):
1395+
return o.isoformat()
1396+
raise TypeError("Unknown type")
1397+
1398+
data = {"current": datetime.now(timezone(timedelta(0)))}
1399+
c = client.Client(
1400+
serializer=MsgPackPacket.configure(dumps_default=default))
1401+
p = c.packet_class(data=data)
1402+
p2 = c.packet_class(encoded_packet=p.encode())
1403+
1404+
assert p.data != p2.data
1405+
assert isinstance(p2.data, dict)
1406+
assert "current" in p2.data
1407+
assert isinstance(p2.data["current"], str)
1408+
assert default(data["current"]) == p2.data["current"]

tests/common/test_msgpack_packet.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from datetime import datetime, timedelta, timezone
2+
3+
import pytest
4+
import msgpack
5+
16
from socketio import msgpack_packet
27
from socketio import packet
38

@@ -32,3 +37,102 @@ def test_encode_binary_ack_packet(self):
3237
assert p.packet_type == packet.ACK
3338
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
3439
assert p2.data == {'foo': b'bar'}
40+
41+
def test_encode_with_dumps_default(self):
42+
def default(obj):
43+
if isinstance(obj, datetime):
44+
return obj.isoformat()
45+
raise TypeError('Unknown type')
46+
47+
data = {
48+
'current': datetime.now(tz=timezone(timedelta(0))),
49+
'key': 'value',
50+
}
51+
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
52+
data=data)
53+
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
54+
assert p.packet_type == p2.packet_type
55+
assert p.id == p2.id
56+
assert p.namespace == p2.namespace
57+
assert p.data != p2.data
58+
59+
assert isinstance(p2.data, dict)
60+
assert 'current' in p2.data
61+
assert isinstance(p2.data['current'], str)
62+
assert default(data['current']) == p2.data['current']
63+
64+
data.pop('current')
65+
p2_data_without_current = p2.data.copy()
66+
p2_data_without_current.pop('current')
67+
assert data == p2_data_without_current
68+
69+
def test_encode_without_dumps_default(self):
70+
data = {
71+
'current': datetime.now(tz=timezone(timedelta(0))),
72+
'key': 'value',
73+
}
74+
p_without_default = msgpack_packet.MsgPackPacket(data=data)
75+
with pytest.raises(TypeError):
76+
p_without_default.encode()
77+
78+
def test_encode_decode_with_ext_hook(self):
79+
class Custom:
80+
def __init__(self, value):
81+
self.value = value
82+
83+
def __eq__(self, value: object) -> bool:
84+
return isinstance(value, Custom) and self.value == value.value
85+
86+
def default(obj):
87+
if isinstance(obj, Custom):
88+
return msgpack.ExtType(1, obj.value)
89+
raise TypeError('Unknown type')
90+
91+
def ext_hook(code, data):
92+
if code == 1:
93+
return Custom(data)
94+
raise TypeError('Unknown ext type')
95+
96+
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
97+
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
98+
data=data)
99+
p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)(
100+
encoded_packet=p.encode()
101+
)
102+
assert p.packet_type == p2.packet_type
103+
assert p.id == p2.id
104+
assert p.data == p2.data
105+
assert p.namespace == p2.namespace
106+
107+
def test_encode_decode_without_ext_hook(self):
108+
class Custom:
109+
def __init__(self, value):
110+
self.value = value
111+
112+
def __eq__(self, value: object) -> bool:
113+
return isinstance(value, Custom) and self.value == value.value
114+
115+
def default(obj):
116+
if isinstance(obj, Custom):
117+
return msgpack.ExtType(1, obj.value)
118+
raise TypeError('Unknown type')
119+
120+
data = {'custom': Custom(b'custom_data'), 'key': 'value'}
121+
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
122+
data=data)
123+
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
124+
assert p.packet_type == p2.packet_type
125+
assert p.id == p2.id
126+
assert p.namespace == p2.namespace
127+
assert p.data != p2.data
128+
129+
assert isinstance(p2.data, dict)
130+
assert 'custom' in p2.data
131+
assert isinstance(p2.data['custom'], msgpack.ExtType)
132+
assert p2.data['custom'].code == 1
133+
assert p2.data['custom'].data == b'custom_data'
134+
135+
data.pop('custom')
136+
p2_data_without_custom = p2.data.copy()
137+
p2_data_without_custom.pop('custom')
138+
assert data == p2_data_without_custom

tests/common/test_server.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from unittest import mock
3+
from datetime import datetime, timezone, timedelta
34

45
from engineio import json
56
from engineio import packet as eio_packet
@@ -10,6 +11,7 @@
1011
from socketio import namespace
1112
from socketio import packet
1213
from socketio import server
14+
from socketio.msgpack_packet import MsgPackPacket
1315

1416

1517
@mock.patch('socketio.server.engineio.Server', **{
@@ -1032,3 +1034,21 @@ def test_sleep(self, eio):
10321034
s = server.Server()
10331035
s.sleep(1.23)
10341036
s.eio.sleep.assert_called_once_with(1.23)
1037+
1038+
def test_serializer_args_with_msgpack(self, eio):
1039+
def default(o):
1040+
if isinstance(o, datetime):
1041+
return o.isoformat()
1042+
raise TypeError("Unknown type")
1043+
1044+
data = {"current": datetime.now(timezone(timedelta(0)))}
1045+
s = server.Server(
1046+
serializer=MsgPackPacket.configure(dumps_default=default))
1047+
p = s.packet_class(data=data)
1048+
p2 = s.packet_class(encoded_packet=p.encode())
1049+
1050+
assert p.data != p2.data
1051+
assert isinstance(p2.data, dict)
1052+
assert "current" in p2.data
1053+
assert isinstance(p2.data["current"], str)
1054+
assert default(data["current"]) == p2.data["current"]

0 commit comments

Comments
 (0)