diff --git a/jupyter_client/session.py b/jupyter_client/session.py
index c58067ad..8cfd892d 100644
--- a/jupyter_client/session.py
+++ b/jupyter_client/session.py
@@ -13,6 +13,7 @@
 # Distributed under the terms of the Modified BSD License.
 from __future__ import annotations
 
+import functools
 import hashlib
 import hmac
 import json
@@ -33,6 +34,7 @@
 from traitlets import (
     Any,
     Bool,
+    Callable,
     CBytes,
     CUnicode,
     Dict,
@@ -125,6 +127,40 @@ def json_unpacker(s: str | bytes) -> t.Any:
     return json.loads(s)
 
 
+try:
+    import orjson  # type:ignore[import-not-found]
+except ModuleNotFoundError:
+    orjson = None
+    orjson_packer, orjson_unpacker = json_packer, json_unpacker
+else:
+
+    def orjson_packer(obj, *, options=orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z) -> bytes:
+        """Convert a json object to a bytes using orjson with fallback to json_packer."""
+        try:
+            return orjson.dumps(obj, default=json_default, options=options)
+        except Exception:
+            pass
+        return json_packer(obj)
+
+    def orjson_unpacker(s: str | bytes) -> t.Any:
+        """Convert a json bytes or string to an object using orjson with fallback to json_unpacker."""
+        try:
+            return orjson.loads(s)
+        except Exception:
+            pass
+        return json_unpacker(s)
+
+
+try:
+    import msgpack  # type:ignore[import-not-found]
+
+except ModuleNotFoundError:
+    msgpack = None
+else:
+    msgpack_packer = functools.partial(msgpack.packb, default=json_default)
+    msgpack_unpacker = msgpack.unpackb
+
+
 def pickle_packer(o: t.Any) -> bytes:
     """Pack an object using the pickle module."""
     return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
@@ -132,8 +168,6 @@ def pickle_packer(o: t.Any) -> bytes:
 
 pickle_unpacker = pickle.loads
 
-default_packer = json_packer
-default_unpacker = json_unpacker
 
 DELIM = b""
 # singleton dummy tracker, which will always report as done
@@ -316,7 +350,7 @@ class Session(Configurable):
 
     debug : bool
         whether to trigger extra debugging statements
-    packer/unpacker : str : 'json', 'pickle' or import_string
+    packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
         importstrings for methods to serialize message parts. If just
-        'json' or 'pickle', predefined JSON and pickle packers will be used.
+        one of the predefined names, the matching built-in packers will be used.
         Otherwise, the entire importstring must be used.
@@ -351,48 +385,40 @@ class Session(Configurable):
         """,
     )
 
+    # serialization traits:
     packer = DottedObjectName(
-        "json",
+        "orjson" if orjson else "json",
         config=True,
         help="""The name of the packer for serializing messages.
-            Should be one of 'json', 'pickle', or an import name
+            Should be one of 'orjson', 'json', 'pickle', 'msgpack', or an import name
             for a custom callable serializer.""",
     )
-
-    @observe("packer")
-    def _packer_changed(self, change: t.Any) -> None:
-        new = change["new"]
-        if new.lower() == "json":
-            self.pack = json_packer
-            self.unpack = json_unpacker
-            self.unpacker = new
-        elif new.lower() == "pickle":
-            self.pack = pickle_packer
-            self.unpack = pickle_unpacker
-            self.unpacker = new
-        else:
-            self.pack = import_item(str(new))
-
     unpacker = DottedObjectName(
-        "json",
+        "orjson" if orjson else "json",
         config=True,
         help="""The name of the unpacker for unserializing messages.
             Only used with custom functions for `packer`.""",
     )
-
-    @observe("unpacker")
-    def _unpacker_changed(self, change: t.Any) -> None:
-        new = change["new"]
-        if new.lower() == "json":
-            self.pack = json_packer
-            self.unpack = json_unpacker
-            self.packer = new
-        elif new.lower() == "pickle":
-            self.pack = pickle_packer
-            self.unpack = pickle_unpacker
-            self.packer = new
+    pack = Callable(orjson_packer if orjson else json_packer)  # the actual packer function
+    unpack = Callable(orjson_unpacker if orjson else json_unpacker)  # the actual unpacker function
+
+    @observe("packer", "unpacker")
+    def _packer_unpacker_changed(self, change: t.Any) -> None:
+        new = change["new"].lower()
+        if new == "orjson" and orjson:
+            self.pack, self.unpack = orjson_packer, orjson_unpacker
+        elif new == "json" or new == "orjson":
+            self.pack, self.unpack = json_packer, json_unpacker
+        elif new == "pickle":
+            self.pack, self.unpack = pickle_packer, pickle_unpacker
+        elif new == "msgpack" and msgpack:
+            self.pack, self.unpack = msgpack_packer, msgpack_unpacker
         else:
-            self.unpack = import_item(str(new))
+            obj = import_item(str(change["new"]))
+            name = "pack" if change["name"] == "packer" else "unpack"
+            self.set_trait(name, obj)
+            return
+        self.packer = self.unpacker = change["new"]
 
     session = CUnicode("", config=True, help="""The UUID identifying this session.""")
 
@@ -417,8 +443,7 @@ def _session_changed(self, change: t.Any) -> None:
     metadata = Dict(
         {},
         config=True,
-        help="Metadata dictionary, which serves as the default top-level metadata dict for each "
-        "message.",
+        help="Metadata dictionary, which serves as the default top-level metadata dict for each message.",
     )
 
     # if 0, no adapting to do.
@@ -487,25 +512,6 @@ def _keyfile_changed(self, change: t.Any) -> None:
     # for protecting against sends from forks
     pid = Integer()
 
-    # serialization traits:
-
-    pack = Any(default_packer)  # the actual packer function
-
-    @observe("pack")
-    def _pack_changed(self, change: t.Any) -> None:
-        new = change["new"]
-        if not callable(new):
-            raise TypeError("packer must be callable, not %s" % type(new))
-
-    unpack = Any(default_unpacker)  # the actual packer function
-
-    @observe("unpack")
-    def _unpack_changed(self, change: t.Any) -> None:
-        # unpacker is not checked - it is assumed to be
-        new = change["new"]
-        if not callable(new):
-            raise TypeError("unpacker must be callable, not %s" % type(new))
-
     # thresholds:
     copy_threshold = Integer(
         2**16,
@@ -515,8 +521,7 @@ def _unpack_changed(self, change: t.Any) -> None:
     buffer_threshold = Integer(
         MAX_BYTES,
         config=True,
-        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
-        "pickling.",
+        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.",
     )
     item_threshold = Integer(
         MAX_ITEMS,
@@ -534,7 +539,7 @@ def __init__(self, **kwargs: t.Any) -> None:
 
         debug : bool
             whether to trigger extra debugging statements
-        packer/unpacker : str : 'json', 'pickle' or import_string
+        packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string
             importstrings for methods to serialize message parts. If just
-            'json' or 'pickle', predefined JSON and pickle packers will be used.
+            one of the predefined names, the matching built-in packers will be used.
             Otherwise, the entire importstring must be used.
@@ -626,10 +631,7 @@ def _check_packers(self) -> None:
             unpacked = unpack(packed)
             assert unpacked == msg_list
         except Exception as e:
-            msg = (
-                f"unpacker '{self.unpacker}' could not handle output from packer"
-                f" '{self.packer}': {e}"
-            )
+            msg = f"unpacker {self.unpacker!r} could not handle output from packer {self.packer!r}: {e}"
             raise ValueError(msg) from e
 
         # check datetime support
diff --git a/pyproject.toml b/pyproject.toml
index 67ce34b4..cf4378f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "jupyter_core>=5.1",
+    "orjson>=3.10.18; implementation_name == 'cpython'",
     "python-dateutil>=2.8.2",
     "pyzmq>=25.0",
     "tornado>=6.4.1",
@@ -55,6 +56,7 @@ test = [
     "pytest-jupyter[client]>=0.6.2",
     "pytest-cov",
     "pytest-timeout",
+    "msgpack",
 ]
 docs = [
     "ipykernel",
diff --git a/tests/test_session.py b/tests/test_session.py
index de817423..b7a136b9 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -9,12 +9,14 @@
 import uuid
 import warnings
 from datetime import datetime
+from pickle import PicklingError
 from unittest import mock
 
 import pytest
 import zmq
 from dateutil.tz import tzlocal
 from tornado import ioloop
+from traitlets import TraitError
 from zmq.eventloop.zmqstream import ZMQStream
 
 from jupyter_client import jsonutil
@@ -41,6 +43,16 @@ def session():
     return ss.Session()
 
 
+serializers = [
+    ("json", ss.json_packer, ss.json_unpacker),
+    ("pickle", ss.pickle_packer, ss.pickle_unpacker),
+]
+if ss.orjson:
+    serializers.append(("orjson", ss.orjson_packer, ss.orjson_unpacker))
+if ss.msgpack:
+    serializers.append(("msgpack", ss.msgpack_packer, ss.msgpack_unpacker))
+
+
 @pytest.mark.usefixtures("no_copy_threshold")
 class TestSession:
     def assertEqual(self, a, b):
@@ -64,7 +76,11 @@ def test_msg(self, session):
         self.assertEqual(msg["header"]["msg_type"], "execute")
         self.assertEqual(msg["msg_type"], "execute")
 
-    def test_serialize(self, session):
+    @pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
+    def test_serialize(self, session, packer, pack, unpack):
+        session.packer = packer
+        assert session.pack is pack
+        assert session.unpack is unpack
         msg = session.msg("execute", content=dict(a=10, b=1.1))
         msg_list = session.serialize(msg, ident=b"foo")
         ident, msg_list = session.feed_identities(msg_list)
@@ -234,16 +250,16 @@ async def test_send(self, session):
     def test_args(self, session):
         """initialization arguments for Session"""
         s = session
-        self.assertTrue(s.pack is ss.default_packer)
-        self.assertTrue(s.unpack is ss.default_unpacker)
+        assert s.pack is (ss.orjson_packer if ss.orjson else ss.json_packer)
+        assert s.unpack is (ss.orjson_unpacker if ss.orjson else ss.json_unpacker)
         self.assertEqual(s.username, os.environ.get("USER", "username"))
 
         s = ss.Session()
         self.assertEqual(s.username, os.environ.get("USER", "username"))
 
-        with pytest.raises(TypeError):
+        with pytest.raises(TraitError):
             ss.Session(pack="hi")
-        with pytest.raises(TypeError):
+        with pytest.raises(TraitError):
             ss.Session(unpack="hi")
         u = str(uuid.uuid4())
         s = ss.Session(username="carrot", session=u)
@@ -491,11 +507,6 @@ async def test_send_raw(self, session):
         B.close()
         ctx.term()
 
-    def test_set_packer(self, session):
-        s = session
-        s.packer = "json"
-        s.unpacker = "json"
-
     def test_clone(self, session):
         s = session
         s._add_digest("initial")
@@ -515,14 +526,45 @@ def test_squash_unicode():
     assert ss.squash_unicode("hi") == b"hi"
 
 
-def test_json_packer():
-    ss.json_packer(dict(a=1))
-    with pytest.raises(ValueError):
-        ss.json_packer(dict(a=ss.Session()))
-    ss.json_packer(dict(a=datetime(2021, 4, 1, 12, tzinfo=tzlocal())))
+@pytest.mark.parametrize(
+    ["description", "data"],
+    [
+        ("dict", [{"a": 1}, [{"a": 1}]]),
+        ("infinite", [math.inf, ["inf", None]]),
+        ("datetime", [datetime(2021, 4, 1, 12, tzinfo=tzlocal()), []]),
+    ],
+)
+@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
+def test_serialize_objects(packer, pack, unpack, description, data):
+    data_in, data_out_options = data
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        ss.json_packer(dict(a=math.inf))
+        value = pack(data_in)
+        unpacked = unpack(value)
+    if (description == "infinite") and (packer in ["pickle", "msgpack"]):
+        assert math.isinf(unpacked)
+    elif description == "datetime":
+        assert data_in == jsonutil.parse_date(unpacked)
+    else:
+        assert unpacked in data_out_options
+
+
+@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
+def test_cannot_serialize(session, packer, pack, unpack):
+    data = {"a": session}
+    with pytest.raises((TypeError, ValueError, PicklingError)):
+        pack(data)
+
+
+@pytest.mark.parametrize("mode", ["packer", "unpacker"])
+@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
+def test_pack_unpack(session, packer, pack, unpack, mode):
+    s: ss.Session = session
+    s.set_trait(mode, packer)
+    assert s.pack is pack
+    assert s.unpack is unpack
+    mode_reverse = "unpacker" if mode == "packer" else "packer"
+    assert getattr(s, mode_reverse) == packer
 
 
 def test_message_cls():