diff --git a/starknet_py/cairo/felt.py b/starknet_py/cairo/felt.py index 51c7d6087..1cba0d505 100644 --- a/starknet_py/cairo/felt.py +++ b/starknet_py/cairo/felt.py @@ -8,6 +8,9 @@ MAX_UINT256 = (1 << 256) - 1 MIN_UINT256 = 0 +MAX_UINT512 = (1 << 512) - 1 +MIN_UINT512 = 0 + def uint256_range_check(value: int): if not MIN_UINT256 <= value <= MAX_UINT256: @@ -16,6 +19,13 @@ def uint256_range_check(value: int): ) +def uint512_range_check(value: int): + if not MIN_UINT512 <= value <= MAX_UINT512: + raise ValueError( + f"Uint512 is expected to be in range [0;2**512), got: {value}." + ) + + MIN_FELT = -FIELD_PRIME // 2 MAX_FELT = FIELD_PRIME // 2 diff --git a/starknet_py/serialization/data_serializers/__init__.py b/starknet_py/serialization/data_serializers/__init__.py index 3d67c45fe..b97dfc03d 100644 --- a/starknet_py/serialization/data_serializers/__init__.py +++ b/starknet_py/serialization/data_serializers/__init__.py @@ -8,4 +8,5 @@ from .struct_serializer import StructSerializer from .tuple_serializer import TupleSerializer from .uint256_serializer import Uint256Serializer +from .uint512_serializer import Uint512Serializer from .uint_serializer import UintSerializer diff --git a/starknet_py/serialization/data_serializers/uint512_serializer.py b/starknet_py/serialization/data_serializers/uint512_serializer.py new file mode 100644 index 000000000..1d2abd38d --- /dev/null +++ b/starknet_py/serialization/data_serializers/uint512_serializer.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from typing import Generator, TypedDict, Union + +from starknet_py.cairo.felt import uint512_range_check +from starknet_py.serialization._context import ( + Context, + DeserializationContext, + SerializationContext, +) +from starknet_py.serialization.data_serializers.cairo_data_serializer import ( + CairoDataSerializer, +) + +U128_UPPER_BOUND = 2**128 + + +class Uint512Dict(TypedDict): + limb0: int + limb1: int + limb2: int + limb3: int + + +@dataclass +class Uint512Serializer(CairoDataSerializer[Union[int, Uint512Dict], int]): + """ + Serializer of Uint512. In Cairo it is represented by structure {limb0: Uint128, limb1: Uint128, limb2: Uint128, limb3: Uint128}. + Can serialize an int. + Deserializes data to an int. + + Examples: + 0 => [0,0,0,0] + 1 => [1,0,0,0] + 2**128 => [0,1,0,0] + 2**256 => [0,0,1,0] + 2**384 => [0,0,0,1] + 3 + 2**128 => [3,1,0,0] + """ + + def deserialize_with_context(self, context: DeserializationContext) -> int: + [limb0, limb1, limb2, limb3] = context.reader.read(4) + + # Checking if resulting value is in [0, 2**512) range is not enough. Uint512 should be made of four uint128. + with context.push_entity("limb0"): + self._ensure_valid_uint128(limb0, context) + with context.push_entity("limb1"): + self._ensure_valid_uint128(limb1, context) + with context.push_entity("limb2"): + self._ensure_valid_uint128(limb2, context) + with context.push_entity("limb3"): + self._ensure_valid_uint128(limb3, context) + + return (limb3 << 384) + (limb2 << 256) + (limb1 << 128) + limb0 + + def serialize_with_context( + self, context: SerializationContext, value: Union[int, Uint512Dict] + ) -> Generator[int, None, None]: + context.ensure_valid_type(value, isinstance(value, (int, dict)), "int or dict") + if isinstance(value, int): + yield from self._serialize_from_int(value) + else: + yield from self._serialize_from_dict(context, value) + + @staticmethod + def _serialize_from_int(value: int) -> Generator[int, None, None]: + uint512_range_check(value) + limb0 = value % (2**128) + limb1 = (value >> 128) % (2**128) + limb2 = (value >> 256) % (2**128) + limb3 = (value >> 384) % (2**128) + result = (limb0, limb1, limb2, limb3) + yield from result + + def _serialize_from_dict( + self, context: SerializationContext, value: Uint512Dict + ) -> Generator[int, None, None]: + with context.push_entity("limb0"): + self._ensure_valid_uint128(value["limb0"], context) + yield value["limb0"] + with context.push_entity("limb1"): + self._ensure_valid_uint128(value["limb1"], context) + yield value["limb1"] + with context.push_entity("limb2"): + self._ensure_valid_uint128(value["limb2"], context) + yield value["limb2"] + with context.push_entity("limb3"): + self._ensure_valid_uint128(value["limb3"], context) + yield value["limb3"] + + @staticmethod + def _ensure_valid_uint128(value: int, context: Context): + context.ensure_valid_value( + 0 <= value < U128_UPPER_BOUND, "expected value in range [0;2**128)" + ) \ No newline at end of file diff --git a/starknet_py/tests/unit/serialization/data_serializers/uint512_serializer_test.py b/starknet_py/tests/unit/serialization/data_serializers/uint512_serializer_test.py new file mode 100644 index 000000000..45cf06dfd --- /dev/null +++ b/starknet_py/tests/unit/serialization/data_serializers/uint512_serializer_test.py @@ -0,0 +1,133 @@ +import re + +import pytest + +from starknet_py.serialization.data_serializers.uint512_serializer import ( + Uint512Serializer, +) +from starknet_py.serialization.errors import InvalidTypeException, InvalidValueException + +serializer = Uint512Serializer() +SHIFT_128 = 2**128 +SHIFT_256 = 2**256 +SHIFT_384 = 2**384 +MAX_U128 = SHIFT_128 - 1 + + +@pytest.mark.parametrize( + "value, serialized_value", + [ + (123 + 456 * SHIFT_128 + 789 * SHIFT_256 + 101 * SHIFT_384, [123, 456, 789, 101]), + ( + 21323213211421424142 + 347932774343 * SHIFT_128 + 987654321 * SHIFT_256 + 123456789 * SHIFT_384, + [21323213211421424142, 347932774343, 987654321, 123456789], + ), + (0, [0, 0, 0, 0]), + (MAX_U128, [MAX_U128, 0, 0, 0]), + (MAX_U128 * SHIFT_128, [0, MAX_U128, 0, 0]), + (MAX_U128 * SHIFT_256, [0, 0, MAX_U128, 0]), + (MAX_U128 * SHIFT_384, [0, 0, 0, MAX_U128]), + (MAX_U128 + MAX_U128 * SHIFT_128 + MAX_U128 * SHIFT_256 + MAX_U128 * SHIFT_384, [MAX_U128, MAX_U128, MAX_U128, MAX_U128]), + (1, [1, 0, 0, 0]), + (SHIFT_128, [0, 1, 0, 0]), + (SHIFT_256, [0, 0, 1, 0]), + (SHIFT_384, [0, 0, 0, 1]), + ], +) +def test_valid_values(value, serialized_value): + deserialized = serializer.deserialize(serialized_value) + assert deserialized == value + + serialized = serializer.serialize(value) + assert serialized == serialized_value + + assert serialized_value == serializer.serialize( + {"limb0": serialized_value[0], "limb1": serialized_value[1], "limb2": serialized_value[2], "limb3": serialized_value[3]} + ) + + +def test_deserialize_invalid_values(): + # We need to escape braces + limb0_error_message = re.escape( + "Error at path 'limb0': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb0_error_message): + serializer.deserialize([MAX_U128 + 1, 0, 0, 0]) + with pytest.raises(InvalidValueException, match=limb0_error_message): + serializer.deserialize([MAX_U128 + 1, MAX_U128 + 1, MAX_U128 + 1, MAX_U128 + 1]) + with pytest.raises(InvalidValueException, match=limb0_error_message): + serializer.deserialize([-1, 0, 0, 0]) + + limb1_error_message = re.escape( + "Error at path 'limb1': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb1_error_message): + serializer.deserialize([0, MAX_U128 + 1, 0, 0]) + with pytest.raises(InvalidValueException, match=limb1_error_message): + serializer.deserialize([0, -1, 0, 0]) + + limb2_error_message = re.escape( + "Error at path 'limb2': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb2_error_message): + serializer.deserialize([0, 0, MAX_U128 + 1, 0]) + with pytest.raises(InvalidValueException, match=limb2_error_message): + serializer.deserialize([0, 0, -1, 0]) + + limb3_error_message = re.escape( + "Error at path 'limb3': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb3_error_message): + serializer.deserialize([0, 0, 0, MAX_U128 + 1]) + with pytest.raises(InvalidValueException, match=limb3_error_message): + serializer.deserialize([0, 0, 0, -1]) + + +def test_serialize_invalid_int_value(): + error_message = re.escape("Error: Uint512 is expected to be in range [0;2**512)") + with pytest.raises(InvalidValueException, match=error_message): + serializer.serialize(2**512) + with pytest.raises(InvalidValueException, match=error_message): + serializer.serialize(-1) + + +def test_serialize_invalid_dict_values(): + limb0_error_message = re.escape( + "Error at path 'limb0': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb0_error_message): + serializer.serialize({"limb0": -1, "limb1": 12324, "limb2": 456, "limb3": 789}) + with pytest.raises(InvalidValueException, match=limb0_error_message): + serializer.serialize({"limb0": MAX_U128 + 1, "limb1": 4543535, "limb2": 456, "limb3": 789}) + + limb1_error_message = re.escape( + "Error at path 'limb1': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb1_error_message): + serializer.serialize({"limb0": 652432, "limb1": -1, "limb2": 456, "limb3": 789}) + with pytest.raises(InvalidValueException, match=limb1_error_message): + serializer.serialize({"limb0": 0, "limb1": MAX_U128 + 1, "limb2": 456, "limb3": 789}) + + limb2_error_message = re.escape( + "Error at path 'limb2': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb2_error_message): + serializer.serialize({"limb0": 652432, "limb1": 123, "limb2": -1, "limb3": 789}) + with pytest.raises(InvalidValueException, match=limb2_error_message): + serializer.serialize({"limb0": 0, "limb1": 123, "limb2": MAX_U128 + 1, "limb3": 789}) + + limb3_error_message = re.escape( + "Error at path 'limb3': expected value in range [0;2**128)" + ) + with pytest.raises(InvalidValueException, match=limb3_error_message): + serializer.serialize({"limb0": 652432, "limb1": 123, "limb2": 456, "limb3": -1}) + with pytest.raises(InvalidValueException, match=limb3_error_message): + serializer.serialize({"limb0": 0, "limb1": 123, "limb2": 456, "limb3": MAX_U128 + 1}) + + +def test_invalid_type(): + error_message = re.escape( + "Error: expected int or dict, received 'wololoo' of type ''." + ) + with pytest.raises(InvalidTypeException, match=error_message): + serializer.serialize("wololoo") # type: ignore \ No newline at end of file