From aeaef01e1327e3a96e99ebb5168d2f5931283335 Mon Sep 17 00:00:00 2001 From: Erik Friese Date: Sat, 5 Aug 2023 18:28:57 +0200 Subject: [PATCH 1/2] Class ScalarArray introduced This increases (de-)serialization speed of repeated scalar fields (of fixed length) drastically in the case they are used as numpy arrays. --- src/betterproto/__init__.py | 73 +++++++++++++++++--------- src/betterproto/plugin/models.py | 8 +-- src/betterproto/scalar_array.py | 89 ++++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 29 deletions(-) create mode 100644 src/betterproto/scalar_array.py diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 99a578478..d21f480ce 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -304,6 +304,17 @@ def map_field( number, TYPE_MAP, map_types=(key_type, value_type), group=group ) +def _is_sequence(x: Any) -> bool: + return not isinstance(x, str) and not isinstance(x, bytes) and isinstance(x, typing.Sequence) + +def _is_empty_sequence(x: Any) -> bool: + return _is_sequence(x) and len(x) == 0 + +def _is_nonempty_sequence(x: Any) -> bool: + return _is_sequence(x) and len(x) != 0 + +def _is_sequence_type(t: Any) -> bool: + return getattr(t, '_name', None) in ['List', 'Sequence'] class Enum(enum.IntEnum): """ @@ -808,7 +819,7 @@ def __bytes__(self) -> bytes: field_name=field_name, meta=meta ) - if value == self._get_field_default(field_name) and not ( + if (_is_empty_sequence(value) or value == self._get_field_default(field_name)) and not ( selected_in_group or serialize_empty or include_default_value_for_oneof ): # Default (zero) values are not serialized. Two exceptions are @@ -816,8 +827,12 @@ def __bytes__(self) -> bytes: # serialize an empty message (i.e. zero value was explicitly # set by the user). continue - - if isinstance(value, list): + + if isinstance(value, ScalarArray) and meta.proto_type in FIXED_TYPES: + if value._ScalarArray__proto_type != meta.proto_type: + raise ValueError("Scalar array has incompatible type") + output += _serialize_single(meta.number, TYPE_BYTES, bytes(value)) + elif _is_sequence(value): if meta.proto_type in PACKED_TYPES: # Packed lists look like a length-delimited field. First, # preprocess/encode each value into a buffer and then @@ -908,6 +923,8 @@ def _get_field_default(self, field_name: str) -> Any: with warnings.catch_warnings(): # ignore warnings when initialising deprecated field defaults warnings.filterwarnings("ignore", category=DeprecationWarning) + if _is_sequence_type(self._betterproto.default_gen[field_name]): + return [] return self._betterproto.default_gen[field_name]() @classmethod @@ -918,7 +935,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: if t.__origin__ is dict: # This is some kind of map (dict in Python). return dict - elif t.__origin__ is list: + elif _is_sequence_type(t.__origin__): # This is some kind of list (repeated) field. return list elif t.__origin__ is Union and t.__args__[1] is type(None): @@ -1016,22 +1033,25 @@ def parse(self: T, data: bytes) -> T: value: Any if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES: # This is a packed repeated field. - pos = 0 - value = [] - while pos < len(parsed.value): - if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32): - decoded, pos = parsed.value[pos : pos + 4], pos + 4 - wire_type = WIRE_FIXED_32 - elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64): - decoded, pos = parsed.value[pos : pos + 8], pos + 8 - wire_type = WIRE_FIXED_64 - else: - decoded, pos = decode_varint(parsed.value, pos) - wire_type = WIRE_VARINT - decoded = self._postprocess_single( - wire_type, meta, field_name, decoded - ) - value.append(decoded) + if meta.proto_type in FIXED_TYPES: + value = ScalarArray(parsed.value, meta.proto_type) + else: + pos = 0 + value = [] + while pos < len(parsed.value): + if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32): + decoded, pos = parsed.value[pos : pos + 4], pos + 4 + wire_type = WIRE_FIXED_32 + elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64): + decoded, pos = parsed.value[pos : pos + 8], pos + 8 + wire_type = WIRE_FIXED_64 + else: + decoded, pos = decode_varint(parsed.value, pos) + wire_type = WIRE_VARINT + decoded = self._postprocess_single( + wire_type, meta, field_name, decoded + ) + value.append(decoded) else: value = self._postprocess_single( parsed.wire_type, meta, field_name, parsed.value @@ -1046,7 +1066,7 @@ def parse(self: T, data: bytes) -> T: if meta.proto_type == TYPE_MAP: # Value represents a single key/value pair entry in the map. current[value.key] = value.value - elif isinstance(current, list) and not isinstance(value, list): + elif _is_sequence(current) and not _is_sequence(value): current.append(value) else: setattr(self, field_name, value) @@ -1102,7 +1122,7 @@ def to_dict( field_types = self._type_hints() defaults = self._betterproto.default_gen for field_name, meta in self._betterproto.meta_by_field_name.items(): - field_is_repeated = defaults[field_name] is list + field_is_repeated = _is_sequence_type(defaults[field_name]) try: value = getattr(self, field_name) except AttributeError: @@ -1163,7 +1183,7 @@ def to_dict( if value or include_default_values: output[cased_name] = output_map elif ( - value != self._get_field_default(field_name) + (_is_nonempty_sequence(value) or value != self._get_field_default(field_name)) or include_default_values or self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -1332,6 +1352,7 @@ def to_json( return json.dumps( self.to_dict(include_default_values=include_default_values, casing=casing), indent=indent, + default=lambda x: x.__json__() ) def from_json(self: T, value: Union[str, bytes]) -> T: @@ -1379,7 +1400,7 @@ def to_pydict( output: Dict[str, Any] = {} defaults = self._betterproto.default_gen for field_name, meta in self._betterproto.meta_by_field_name.items(): - field_is_repeated = defaults[field_name] is list + field_is_repeated = _is_sequence_type(defaults[field_name]) value = getattr(self, field_name) cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == TYPE_MESSAGE: @@ -1428,7 +1449,7 @@ def to_pydict( if value or include_default_values: output[cased_name] = value elif ( - value != self._get_field_default(field_name) + (_is_nonempty_sequence(value) or value != self._get_field_default(field_name)) or include_default_values or self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -1584,6 +1605,8 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]] UInt64Value, ) +from .scalar_array import ScalarArray + class _Duration(Duration): @classmethod diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index ea819d44d..c7c2cb854 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -322,7 +322,7 @@ def py_name(self) -> str: @property def annotation(self) -> str: if self.repeated: - return f"List[{self.py_name}]" + return f"Sequence[{self.py_name}]" return self.py_name @property @@ -440,8 +440,8 @@ def typing_imports(self) -> Set[str]: annotation = self.annotation if "Optional[" in annotation: imports.add("Optional") - if "List[" in annotation: - imports.add("List") + if "Sequence[" in annotation: + imports.add("Sequence") if "Dict[" in annotation: imports.add("Dict") return imports @@ -572,7 +572,7 @@ def annotation(self) -> str: if self.use_builtins: py_type = f"builtins.{py_type}" if self.repeated: - return f"List[{py_type}]" + return f"Sequence[{py_type}]" if self.optional: return f"Optional[{py_type}]" return py_type diff --git a/src/betterproto/scalar_array.py b/src/betterproto/scalar_array.py new file mode 100644 index 000000000..e84e37efe --- /dev/null +++ b/src/betterproto/scalar_array.py @@ -0,0 +1,89 @@ +from typing import TypeVar, Any +from collections.abc import Sequence +from . import TYPE_DOUBLE, TYPE_FLOAT, TYPE_SFIXED32, TYPE_FIXED32, TYPE_FIXED64, TYPE_SFIXED64, _pack_fmt +import struct + +NP_DOUBLE = 'float64' +NP_FLOAT = 'float32' +NP_SFIXED32 = 'int32' +NP_FIXED32 = 'uint32' +NP_SFIXED64 = 'int64' +NP_FIXED64 = 'uint64' + +def _convert_types_np2proto(np_type: str) -> str: + return { + NP_DOUBLE: TYPE_DOUBLE, + NP_FLOAT: TYPE_FLOAT, + NP_SFIXED32: TYPE_SFIXED32, + NP_FIXED32: TYPE_FIXED32, + NP_SFIXED64: TYPE_SFIXED64, + NP_FIXED64: TYPE_FIXED64 + }[np_type] + +def _convert_types_proto2np(proto_type: str) -> str: + return { + TYPE_DOUBLE: NP_DOUBLE, + TYPE_FLOAT: NP_FLOAT, + TYPE_SFIXED32: NP_SFIXED32, + TYPE_FIXED32: NP_FIXED32, + TYPE_SFIXED64: NP_SFIXED64, + TYPE_FIXED64: NP_FIXED64 + }[proto_type] + +def _item_size(proto_type: str) -> int: + return { + TYPE_DOUBLE: 8, + TYPE_FLOAT: 4, + TYPE_SFIXED32: 4, + TYPE_FIXED32: 4, + TYPE_SFIXED64: 8, + TYPE_FIXED64: 8 + }[proto_type] + + +T = TypeVar('T', covariant=True) + +class ScalarArray(Sequence[T]): + __data: bytes + __item_size: int + __proto_type: str + + def __init__(self, data: bytes, proto_type: str) -> None: + self.__data = data + self.__item_size = _item_size(proto_type) + self.__proto_type = proto_type + + def __len__(self) -> int: + return len(self.__data) // self.__item_size + + def __getitem__(self, i: int) -> T: + if i < 0: + i += len(self) + if i < 0 or i >= len(self): + raise IndexError + + value = self.__data[i*self.__item_size:(i+1)*self.__item_size] + value = struct.unpack(_pack_fmt(self.__proto_type), value)[0] + return value + + def __bytes__(self) -> bytes: + return self.__data + + def __repr__(self) -> str: + return str(list(self)) + + def __array__(self): + import numpy as np + return np.frombuffer(self.__data, dtype=_convert_types_proto2np(self.__proto_type)) + + def __json__(self): + return list(self) + + def __eq__(self, other): + if isinstance(other, ScalarArray): + return self.__data == other.__data and self.__item_size == other.__item_size and self.__proto_type == other.__proto_type + return isinstance(other, Sequence) and list(self) == list(other) + + @staticmethod + def from_numpy(ar) -> 'ScalarArray[Any]': + return ScalarArray(bytes(ar), _convert_types_np2proto(str(ar.dtype))) From 126505ab20bffddd83ca570004eeed3349d2a02d Mon Sep 17 00:00:00 2001 From: Erik Friese Date: Sat, 12 Aug 2023 14:21:05 +0200 Subject: [PATCH 2/2] Clean Up unreachable code removed generic type parameter removed from ScalarArray for compatibility with Python 3.7 and 3.8 code (auto-)reformatted --- src/betterproto/__init__.py | 43 ++++++++++++++-------- src/betterproto/scalar_array.py | 65 ++++++++++++++++++++------------- 2 files changed, 67 insertions(+), 41 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index d21f480ce..a68695fb1 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -304,17 +304,26 @@ def map_field( number, TYPE_MAP, map_types=(key_type, value_type), group=group ) + def _is_sequence(x: Any) -> bool: - return not isinstance(x, str) and not isinstance(x, bytes) and isinstance(x, typing.Sequence) + return ( + not isinstance(x, str) + and not isinstance(x, bytes) + and isinstance(x, typing.Sequence) + ) + def _is_empty_sequence(x: Any) -> bool: return _is_sequence(x) and len(x) == 0 + def _is_nonempty_sequence(x: Any) -> bool: return _is_sequence(x) and len(x) != 0 + def _is_sequence_type(t: Any) -> bool: - return getattr(t, '_name', None) in ['List', 'Sequence'] + return getattr(t, "_name", None) in ["List", "Sequence"] + class Enum(enum.IntEnum): """ @@ -819,7 +828,10 @@ def __bytes__(self) -> bytes: field_name=field_name, meta=meta ) - if (_is_empty_sequence(value) or value == self._get_field_default(field_name)) and not ( + if ( + _is_empty_sequence(value) + or value == self._get_field_default(field_name) + ) and not ( selected_in_group or serialize_empty or include_default_value_for_oneof ): # Default (zero) values are not serialized. Two exceptions are @@ -827,7 +839,7 @@ def __bytes__(self) -> bytes: # serialize an empty message (i.e. zero value was explicitly # set by the user). continue - + if isinstance(value, ScalarArray) and meta.proto_type in FIXED_TYPES: if value._ScalarArray__proto_type != meta.proto_type: raise ValueError("Scalar array has incompatible type") @@ -1039,15 +1051,8 @@ def parse(self: T, data: bytes) -> T: pos = 0 value = [] while pos < len(parsed.value): - if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32): - decoded, pos = parsed.value[pos : pos + 4], pos + 4 - wire_type = WIRE_FIXED_32 - elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64): - decoded, pos = parsed.value[pos : pos + 8], pos + 8 - wire_type = WIRE_FIXED_64 - else: - decoded, pos = decode_varint(parsed.value, pos) - wire_type = WIRE_VARINT + decoded, pos = decode_varint(parsed.value, pos) + wire_type = WIRE_VARINT decoded = self._postprocess_single( wire_type, meta, field_name, decoded ) @@ -1183,7 +1188,10 @@ def to_dict( if value or include_default_values: output[cased_name] = output_map elif ( - (_is_nonempty_sequence(value) or value != self._get_field_default(field_name)) + ( + _is_nonempty_sequence(value) + or value != self._get_field_default(field_name) + ) or include_default_values or self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -1352,7 +1360,7 @@ def to_json( return json.dumps( self.to_dict(include_default_values=include_default_values, casing=casing), indent=indent, - default=lambda x: x.__json__() + default=lambda x: x.__json__(), ) def from_json(self: T, value: Union[str, bytes]) -> T: @@ -1449,7 +1457,10 @@ def to_pydict( if value or include_default_values: output[cased_name] = value elif ( - (_is_nonempty_sequence(value) or value != self._get_field_default(field_name)) + ( + _is_nonempty_sequence(value) + or value != self._get_field_default(field_name) + ) or include_default_values or self._include_default_value_for_oneof( field_name=field_name, meta=meta diff --git a/src/betterproto/scalar_array.py b/src/betterproto/scalar_array.py index e84e37efe..18553f9ba 100644 --- a/src/betterproto/scalar_array.py +++ b/src/betterproto/scalar_array.py @@ -1,14 +1,22 @@ -from typing import TypeVar, Any -from collections.abc import Sequence -from . import TYPE_DOUBLE, TYPE_FLOAT, TYPE_SFIXED32, TYPE_FIXED32, TYPE_FIXED64, TYPE_SFIXED64, _pack_fmt import struct +from collections.abc import Sequence +from . import ( + TYPE_DOUBLE, + TYPE_FIXED32, + TYPE_FIXED64, + TYPE_FLOAT, + TYPE_SFIXED32, + TYPE_SFIXED64, + _pack_fmt, +) + +NP_DOUBLE = "float64" +NP_FLOAT = "float32" +NP_SFIXED32 = "int32" +NP_FIXED32 = "uint32" +NP_SFIXED64 = "int64" +NP_FIXED64 = "uint64" -NP_DOUBLE = 'float64' -NP_FLOAT = 'float32' -NP_SFIXED32 = 'int32' -NP_FIXED32 = 'uint32' -NP_SFIXED64 = 'int64' -NP_FIXED64 = 'uint64' def _convert_types_np2proto(np_type: str) -> str: return { @@ -17,9 +25,10 @@ def _convert_types_np2proto(np_type: str) -> str: NP_SFIXED32: TYPE_SFIXED32, NP_FIXED32: TYPE_FIXED32, NP_SFIXED64: TYPE_SFIXED64, - NP_FIXED64: TYPE_FIXED64 + NP_FIXED64: TYPE_FIXED64, }[np_type] + def _convert_types_proto2np(proto_type: str) -> str: return { TYPE_DOUBLE: NP_DOUBLE, @@ -27,9 +36,10 @@ def _convert_types_proto2np(proto_type: str) -> str: TYPE_SFIXED32: NP_SFIXED32, TYPE_FIXED32: NP_FIXED32, TYPE_SFIXED64: NP_SFIXED64, - TYPE_FIXED64: NP_FIXED64 + TYPE_FIXED64: NP_FIXED64, }[proto_type] + def _item_size(proto_type: str) -> int: return { TYPE_DOUBLE: 8, @@ -37,13 +47,11 @@ def _item_size(proto_type: str) -> int: TYPE_SFIXED32: 4, TYPE_FIXED32: 4, TYPE_SFIXED64: 8, - TYPE_FIXED64: 8 + TYPE_FIXED64: 8, }[proto_type] -T = TypeVar('T', covariant=True) - -class ScalarArray(Sequence[T]): +class ScalarArray(Sequence): __data: bytes __item_size: int __proto_type: str @@ -55,35 +63,42 @@ def __init__(self, data: bytes, proto_type: str) -> None: def __len__(self) -> int: return len(self.__data) // self.__item_size - - def __getitem__(self, i: int) -> T: + + def __getitem__(self, i: int): if i < 0: i += len(self) if i < 0 or i >= len(self): raise IndexError - value = self.__data[i*self.__item_size:(i+1)*self.__item_size] + value = self.__data[i * self.__item_size : (i + 1) * self.__item_size] value = struct.unpack(_pack_fmt(self.__proto_type), value)[0] return value - + def __bytes__(self) -> bytes: return self.__data - + def __repr__(self) -> str: return str(list(self)) def __array__(self): import numpy as np - return np.frombuffer(self.__data, dtype=_convert_types_proto2np(self.__proto_type)) + + return np.frombuffer( + self.__data, dtype=_convert_types_proto2np(self.__proto_type) + ) def __json__(self): return list(self) - + def __eq__(self, other): if isinstance(other, ScalarArray): - return self.__data == other.__data and self.__item_size == other.__item_size and self.__proto_type == other.__proto_type + return ( + self.__data == other.__data + and self.__item_size == other.__item_size + and self.__proto_type == other.__proto_type + ) return isinstance(other, Sequence) and list(self) == list(other) - + @staticmethod - def from_numpy(ar) -> 'ScalarArray[Any]': + def from_numpy(ar) -> "ScalarArray": return ScalarArray(bytes(ar), _convert_types_np2proto(str(ar.dtype)))