Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/socketio/msgpack_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,28 @@

class MsgPackPacket(packet.Packet):
uses_binary_events = False
dumps_default = None
ext_hook = msgpack.ExtType

@classmethod
def configure(cls, dumps_default=None, ext_hook=msgpack.ExtType):
class CustomMsgPackPacket(MsgPackPacket):
dumps_default = None
ext_hook = None

CustomMsgPackPacket.dumps_default = dumps_default
CustomMsgPackPacket.ext_hook = ext_hook
return CustomMsgPackPacket

def encode(self):
"""Encode the packet for transmission."""
return msgpack.dumps(self._to_dict())
return msgpack.dumps(self._to_dict(),
default=self.__class__.dumps_default)

def decode(self, encoded_packet):
"""Decode a transmitted package."""
decoded = msgpack.loads(encoded_packet)
decoded = msgpack.loads(encoded_packet,
ext_hook=self.__class__.ext_hook)
self.packet_type = decoded['type']
self.data = decoded.get('data')
self.id = decoded.get('id')
Expand Down
20 changes: 20 additions & 0 deletions tests/async/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from unittest import mock
from datetime import datetime, timezone, timedelta

import pytest

Expand All @@ -8,6 +9,7 @@
from engineio import exceptions as engineio_exceptions
from socketio import exceptions
from socketio import packet
from socketio.msgpack_packet import MsgPackPacket


class TestAsyncClient:
Expand Down Expand Up @@ -1242,3 +1244,21 @@ async def test_eio_disconnect_no_reconnect(self):
assert c.sid is None
assert not c.connected
c.start_background_task.assert_not_called()

def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")

data = {"current": datetime.now(timezone(timedelta(0)))}
c = async_client.AsyncClient(
serializer=MsgPackPacket.configure(dumps_default=default))
p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())

assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
20 changes: 20 additions & 0 deletions tests/async/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
from unittest import mock
from datetime import datetime, timezone, timedelta

from engineio import json
from engineio import packet as eio_packet
Expand All @@ -11,6 +12,7 @@
from socketio import exceptions
from socketio import namespace
from socketio import packet
from socketio.msgpack_packet import MsgPackPacket


@mock.patch('socketio.server.engineio.AsyncServer', **{
Expand Down Expand Up @@ -1089,3 +1091,21 @@ async def test_sleep(self, eio):
s = async_server.AsyncServer()
await s.sleep(1.23)
s.eio.sleep.assert_awaited_once_with(1.23)

def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")

data = {"current": datetime.now(timezone(timedelta(0)))}
s = async_server.AsyncServer(
serializer=MsgPackPacket.configure(dumps_default=default))
p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())

assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
20 changes: 20 additions & 0 deletions tests/common/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import time
from unittest import mock
from datetime import datetime, timezone, timedelta

from engineio import exceptions as engineio_exceptions
from engineio import json
Expand All @@ -13,6 +14,7 @@
from socketio import msgpack_packet
from socketio import namespace
from socketio import packet
from socketio.msgpack_packet import MsgPackPacket


class TestClient:
Expand Down Expand Up @@ -1386,3 +1388,21 @@ def test_eio_disconnect_no_reconnect(self):
assert c.sid is None
assert not c.connected
c.start_background_task.assert_not_called()

def test_serializer_args_with_msgpack(self):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")

data = {"current": datetime.now(timezone(timedelta(0)))}
c = client.Client(
serializer=MsgPackPacket.configure(dumps_default=default))
p = c.packet_class(data=data)
p2 = c.packet_class(encoded_packet=p.encode())

assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
104 changes: 104 additions & 0 deletions tests/common/test_msgpack_packet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from datetime import datetime, timedelta, timezone

import pytest
import msgpack

from socketio import msgpack_packet
from socketio import packet

Expand Down Expand Up @@ -32,3 +37,102 @@ def test_encode_binary_ack_packet(self):
assert p.packet_type == packet.ACK
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p2.data == {'foo': b'bar'}

def test_encode_with_dumps_default(self):
def default(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError('Unknown type')

data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data

assert isinstance(p2.data, dict)
assert 'current' in p2.data
assert isinstance(p2.data['current'], str)
assert default(data['current']) == p2.data['current']

data.pop('current')
p2_data_without_current = p2.data.copy()
p2_data_without_current.pop('current')
assert data == p2_data_without_current

def test_encode_without_dumps_default(self):
data = {
'current': datetime.now(tz=timezone(timedelta(0))),
'key': 'value',
}
p_without_default = msgpack_packet.MsgPackPacket(data=data)
with pytest.raises(TypeError):
p_without_default.encode()

def test_encode_decode_with_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value

def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value

def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')

def ext_hook(code, data):
if code == 1:
return Custom(data)
raise TypeError('Unknown ext type')

data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket.configure(ext_hook=ext_hook)(
encoded_packet=p.encode()
)
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.data == p2.data
assert p.namespace == p2.namespace

def test_encode_decode_without_ext_hook(self):
class Custom:
def __init__(self, value):
self.value = value

def __eq__(self, value: object) -> bool:
return isinstance(value, Custom) and self.value == value.value

def default(obj):
if isinstance(obj, Custom):
return msgpack.ExtType(1, obj.value)
raise TypeError('Unknown type')

data = {'custom': Custom(b'custom_data'), 'key': 'value'}
p = msgpack_packet.MsgPackPacket.configure(dumps_default=default)(
data=data)
p2 = msgpack_packet.MsgPackPacket(encoded_packet=p.encode())
assert p.packet_type == p2.packet_type
assert p.id == p2.id
assert p.namespace == p2.namespace
assert p.data != p2.data

assert isinstance(p2.data, dict)
assert 'custom' in p2.data
assert isinstance(p2.data['custom'], msgpack.ExtType)
assert p2.data['custom'].code == 1
assert p2.data['custom'].data == b'custom_data'

data.pop('custom')
p2_data_without_custom = p2.data.copy()
p2_data_without_custom.pop('custom')
assert data == p2_data_without_custom
20 changes: 20 additions & 0 deletions tests/common/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from unittest import mock
from datetime import datetime, timezone, timedelta

from engineio import json
from engineio import packet as eio_packet
Expand All @@ -10,6 +11,7 @@
from socketio import namespace
from socketio import packet
from socketio import server
from socketio.msgpack_packet import MsgPackPacket


@mock.patch('socketio.server.engineio.Server', **{
Expand Down Expand Up @@ -1032,3 +1034,21 @@ def test_sleep(self, eio):
s = server.Server()
s.sleep(1.23)
s.eio.sleep.assert_called_once_with(1.23)

def test_serializer_args_with_msgpack(self, eio):
def default(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Unknown type")

data = {"current": datetime.now(timezone(timedelta(0)))}
s = server.Server(
serializer=MsgPackPacket.configure(dumps_default=default))
p = s.packet_class(data=data)
p2 = s.packet_class(encoded_packet=p.encode())

assert p.data != p2.data
assert isinstance(p2.data, dict)
assert "current" in p2.data
assert isinstance(p2.data["current"], str)
assert default(data["current"]) == p2.data["current"]
Loading