Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions starknet_py/cairo/felt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
MAX_UINT256 = (1 << 256) - 1
MIN_UINT256 = 0

MAX_UINT512 = (1 << 512) - 1
MIN_UINT512 = 0


def uint256_range_check(value: int):
if not MIN_UINT256 <= value <= MAX_UINT256:
Expand All @@ -16,6 +19,13 @@ def uint256_range_check(value: int):
)


def uint512_range_check(value: int):
if not MIN_UINT512 <= value <= MAX_UINT512:
raise ValueError(
f"Uint512 is expected to be in range [0;2**512), got: {value}."
)


MIN_FELT = -FIELD_PRIME // 2
MAX_FELT = FIELD_PRIME // 2

Expand Down
1 change: 1 addition & 0 deletions starknet_py/serialization/data_serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .struct_serializer import StructSerializer
from .tuple_serializer import TupleSerializer
from .uint256_serializer import Uint256Serializer
from .uint512_serializer import Uint512Serializer
from .uint_serializer import UintSerializer
94 changes: 94 additions & 0 deletions starknet_py/serialization/data_serializers/uint512_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from dataclasses import dataclass
from typing import Generator, TypedDict, Union

from starknet_py.cairo.felt import uint512_range_check
from starknet_py.serialization._context import (
Context,
DeserializationContext,
SerializationContext,
)
from starknet_py.serialization.data_serializers.cairo_data_serializer import (
CairoDataSerializer,
)

U128_UPPER_BOUND = 2**128


class Uint512Dict(TypedDict):
limb0: int
limb1: int
limb2: int
limb3: int


@dataclass
class Uint512Serializer(CairoDataSerializer[Union[int, Uint512Dict], int]):
"""
Serializer of Uint512. In Cairo it is represented by structure {limb0: Uint128, limb1: Uint128, limb2: Uint128, limb3: Uint128}.
Can serialize an int.
Deserializes data to an int.

Examples:
0 => [0,0,0,0]
1 => [1,0,0,0]
2**128 => [0,1,0,0]
2**256 => [0,0,1,0]
2**384 => [0,0,0,1]
3 + 2**128 => [3,1,0,0]
"""

def deserialize_with_context(self, context: DeserializationContext) -> int:
[limb0, limb1, limb2, limb3] = context.reader.read(4)

# Checking if resulting value is in [0, 2**512) range is not enough. Uint512 should be made of four uint128.
with context.push_entity("limb0"):
self._ensure_valid_uint128(limb0, context)
with context.push_entity("limb1"):
self._ensure_valid_uint128(limb1, context)
with context.push_entity("limb2"):
self._ensure_valid_uint128(limb2, context)
with context.push_entity("limb3"):
self._ensure_valid_uint128(limb3, context)

return (limb3 << 384) + (limb2 << 256) + (limb1 << 128) + limb0

def serialize_with_context(
self, context: SerializationContext, value: Union[int, Uint512Dict]
) -> Generator[int, None, None]:
context.ensure_valid_type(value, isinstance(value, (int, dict)), "int or dict")
if isinstance(value, int):
yield from self._serialize_from_int(value)
else:
yield from self._serialize_from_dict(context, value)

@staticmethod
def _serialize_from_int(value: int) -> Generator[int, None, None]:
uint512_range_check(value)
limb0 = value % (2**128)
limb1 = (value >> 128) % (2**128)
limb2 = (value >> 256) % (2**128)
limb3 = (value >> 384) % (2**128)
result = (limb0, limb1, limb2, limb3)
yield from result

def _serialize_from_dict(
self, context: SerializationContext, value: Uint512Dict
) -> Generator[int, None, None]:
with context.push_entity("limb0"):
self._ensure_valid_uint128(value["limb0"], context)
yield value["limb0"]
with context.push_entity("limb1"):
self._ensure_valid_uint128(value["limb1"], context)
yield value["limb1"]
with context.push_entity("limb2"):
self._ensure_valid_uint128(value["limb2"], context)
yield value["limb2"]
with context.push_entity("limb3"):
self._ensure_valid_uint128(value["limb3"], context)
yield value["limb3"]

@staticmethod
def _ensure_valid_uint128(value: int, context: Context):
context.ensure_valid_value(
0 <= value < U128_UPPER_BOUND, "expected value in range [0;2**128)"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import re

import pytest

from starknet_py.serialization.data_serializers.uint512_serializer import (
Uint512Serializer,
)
from starknet_py.serialization.errors import InvalidTypeException, InvalidValueException

serializer = Uint512Serializer()
SHIFT_128 = 2**128
SHIFT_256 = 2**256
SHIFT_384 = 2**384
MAX_U128 = SHIFT_128 - 1


@pytest.mark.parametrize(
"value, serialized_value",
[
(123 + 456 * SHIFT_128 + 789 * SHIFT_256 + 101 * SHIFT_384, [123, 456, 789, 101]),
(
21323213211421424142 + 347932774343 * SHIFT_128 + 987654321 * SHIFT_256 + 123456789 * SHIFT_384,
[21323213211421424142, 347932774343, 987654321, 123456789],
),
(0, [0, 0, 0, 0]),
(MAX_U128, [MAX_U128, 0, 0, 0]),
(MAX_U128 * SHIFT_128, [0, MAX_U128, 0, 0]),
(MAX_U128 * SHIFT_256, [0, 0, MAX_U128, 0]),
(MAX_U128 * SHIFT_384, [0, 0, 0, MAX_U128]),
(MAX_U128 + MAX_U128 * SHIFT_128 + MAX_U128 * SHIFT_256 + MAX_U128 * SHIFT_384, [MAX_U128, MAX_U128, MAX_U128, MAX_U128]),
(1, [1, 0, 0, 0]),
(SHIFT_128, [0, 1, 0, 0]),
(SHIFT_256, [0, 0, 1, 0]),
(SHIFT_384, [0, 0, 0, 1]),
],
)
def test_valid_values(value, serialized_value):
deserialized = serializer.deserialize(serialized_value)
assert deserialized == value

serialized = serializer.serialize(value)
assert serialized == serialized_value

assert serialized_value == serializer.serialize(
{"limb0": serialized_value[0], "limb1": serialized_value[1], "limb2": serialized_value[2], "limb3": serialized_value[3]}
)


def test_deserialize_invalid_values():
# We need to escape braces
limb0_error_message = re.escape(
"Error at path 'limb0': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb0_error_message):
serializer.deserialize([MAX_U128 + 1, 0, 0, 0])
with pytest.raises(InvalidValueException, match=limb0_error_message):
serializer.deserialize([MAX_U128 + 1, MAX_U128 + 1, MAX_U128 + 1, MAX_U128 + 1])
with pytest.raises(InvalidValueException, match=limb0_error_message):
serializer.deserialize([-1, 0, 0, 0])

limb1_error_message = re.escape(
"Error at path 'limb1': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb1_error_message):
serializer.deserialize([0, MAX_U128 + 1, 0, 0])
with pytest.raises(InvalidValueException, match=limb1_error_message):
serializer.deserialize([0, -1, 0, 0])

limb2_error_message = re.escape(
"Error at path 'limb2': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb2_error_message):
serializer.deserialize([0, 0, MAX_U128 + 1, 0])
with pytest.raises(InvalidValueException, match=limb2_error_message):
serializer.deserialize([0, 0, -1, 0])

limb3_error_message = re.escape(
"Error at path 'limb3': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb3_error_message):
serializer.deserialize([0, 0, 0, MAX_U128 + 1])
with pytest.raises(InvalidValueException, match=limb3_error_message):
serializer.deserialize([0, 0, 0, -1])


def test_serialize_invalid_int_value():
error_message = re.escape("Error: Uint512 is expected to be in range [0;2**512)")
with pytest.raises(InvalidValueException, match=error_message):
serializer.serialize(2**512)
with pytest.raises(InvalidValueException, match=error_message):
serializer.serialize(-1)


def test_serialize_invalid_dict_values():
limb0_error_message = re.escape(
"Error at path 'limb0': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb0_error_message):
serializer.serialize({"limb0": -1, "limb1": 12324, "limb2": 456, "limb3": 789})
with pytest.raises(InvalidValueException, match=limb0_error_message):
serializer.serialize({"limb0": MAX_U128 + 1, "limb1": 4543535, "limb2": 456, "limb3": 789})

limb1_error_message = re.escape(
"Error at path 'limb1': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb1_error_message):
serializer.serialize({"limb0": 652432, "limb1": -1, "limb2": 456, "limb3": 789})
with pytest.raises(InvalidValueException, match=limb1_error_message):
serializer.serialize({"limb0": 0, "limb1": MAX_U128 + 1, "limb2": 456, "limb3": 789})

limb2_error_message = re.escape(
"Error at path 'limb2': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb2_error_message):
serializer.serialize({"limb0": 652432, "limb1": 123, "limb2": -1, "limb3": 789})
with pytest.raises(InvalidValueException, match=limb2_error_message):
serializer.serialize({"limb0": 0, "limb1": 123, "limb2": MAX_U128 + 1, "limb3": 789})

limb3_error_message = re.escape(
"Error at path 'limb3': expected value in range [0;2**128)"
)
with pytest.raises(InvalidValueException, match=limb3_error_message):
serializer.serialize({"limb0": 652432, "limb1": 123, "limb2": 456, "limb3": -1})
with pytest.raises(InvalidValueException, match=limb3_error_message):
serializer.serialize({"limb0": 0, "limb1": 123, "limb2": 456, "limb3": MAX_U128 + 1})


def test_invalid_type():
error_message = re.escape(
"Error: expected int or dict, received 'wololoo' of type '<class 'str'>'."
)
with pytest.raises(InvalidTypeException, match=error_message):
serializer.serialize("wololoo") # type: ignore
Loading