Skip to content

Commit 8d02f01

Browse files
authored
Ignoring unknown keys in RPC responses by default (#1312)
* Allow ignoring unknown keys in RPC responses by set env * remove unknown=excluded from node_client * set execution_resources as optional for L1_HANDLER_TXN_TRACE schema
1 parent a2686e2 commit 8d02f01

File tree

4 files changed

+51
-36
lines changed

4 files changed

+51
-36
lines changed

starknet_py/net/client_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""
22
Dataclasses representing responses from Starknet.
33
They need to stay backwards compatible for old transactions/blocks to be fetchable.
4+
5+
If you encounter a ValidationError in the context of an RPC response, it is possible to disable validation.
6+
This can be achieved by setting the environment variable, STARKNET_PY_MARSHMALLOW_UKNOWN_EXCLUDE,
7+
to true. Consequently, any unknown fields in response will be excluded.
48
"""
59

610
import json
@@ -985,6 +989,7 @@ class L1HandlerTransactionTrace:
985989
Dataclass representing a transaction trace of an L1_HANDLER transaction.
986990
"""
987991

992+
execution_resources: Optional[ExecutionResources]
988993
function_invocation: FunctionInvocation
989994
state_diff: Optional[StateDiff] = None
990995

starknet_py/net/full_node_client.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ async def get_block_with_txs(
126126
) -> Union[StarknetBlock, PendingStarknetBlock]:
127127
return await self.get_block(block_hash=block_hash, block_number=block_number)
128128

129+
# TODO (#1323): remove unknown=EXCLUDE after devnet response fix
129130
async def get_block_with_tx_hashes(
130131
self,
131132
block_hash: Optional[Union[Hash, Tag]] = None,
@@ -147,7 +148,7 @@ async def get_block_with_tx_hashes(
147148
)
148149
return cast(
149150
StarknetBlockWithTxHashes,
150-
StarknetBlockWithTxHashesSchema().load(res, unknown=EXCLUDE),
151+
StarknetBlockWithTxHashesSchema().load(res),
151152
)
152153

153154
async def get_block_with_receipts(
@@ -167,11 +168,11 @@ async def get_block_with_receipts(
167168
if block_identifier == {"block_id": "pending"}:
168169
return cast(
169170
PendingStarknetBlockWithReceipts,
170-
PendingStarknetBlockWithReceiptsSchema().load(res, unknown=EXCLUDE),
171+
PendingStarknetBlockWithReceiptsSchema().load(res),
171172
)
172173
return cast(
173174
StarknetBlockWithReceipts,
174-
StarknetBlockWithReceiptsSchema().load(res, unknown=EXCLUDE),
175+
StarknetBlockWithReceiptsSchema().load(res),
175176
)
176177

177178
# TODO (#809): add tests with multiple emitted keys
@@ -281,6 +282,7 @@ async def _get_events_chunk(
281282
return res["events"], res["continuation_token"]
282283
return res["events"], None
283284

285+
# TODO (#1323): remove unknown=EXCLUDE after devnet fix response
284286
async def get_state_update(
285287
self,
286288
block_hash: Optional[Union[Hash, Tag]] = None,
@@ -300,9 +302,7 @@ async def get_state_update(
300302
PendingBlockStateUpdate,
301303
PendingBlockStateUpdateSchema().load(res, unknown=EXCLUDE),
302304
)
303-
return cast(
304-
BlockStateUpdate, BlockStateUpdateSchema().load(res, unknown=EXCLUDE)
305-
)
305+
return cast(BlockStateUpdate, BlockStateUpdateSchema().load(res))
306306

307307
async def get_storage_at(
308308
self,
@@ -337,7 +337,7 @@ async def get_transaction(
337337
)
338338
except ClientError as ex:
339339
raise TransactionNotReceivedError() from ex
340-
return cast(Transaction, TypesOfTransactionsSchema().load(res, unknown=EXCLUDE))
340+
return cast(Transaction, TypesOfTransactionsSchema().load(res))
341341

342342
async def get_l1_message_hash(self, tx_hash: Hash) -> Hash:
343343
"""
@@ -358,9 +358,7 @@ async def get_transaction_receipt(self, tx_hash: Hash) -> TransactionReceipt:
358358
method_name="getTransactionReceipt",
359359
params={"transaction_hash": _to_rpc_felt(tx_hash)},
360360
)
361-
return cast(
362-
TransactionReceipt, TransactionReceiptSchema().load(res, unknown=EXCLUDE)
363-
)
361+
return cast(TransactionReceipt, TransactionReceiptSchema().load(res))
364362

365363
async def estimate_fee(
366364
self,
@@ -392,9 +390,7 @@ async def estimate_fee(
392390

393391
return cast(
394392
EstimatedFee,
395-
EstimatedFeeSchema().load(
396-
res, unknown=EXCLUDE, many=(not single_transaction)
397-
),
393+
EstimatedFeeSchema().load(res, many=not single_transaction),
398394
)
399395

400396
async def estimate_message_fee(
@@ -440,7 +436,7 @@ async def estimate_message_fee(
440436
**block_identifier,
441437
},
442438
)
443-
return cast(EstimatedFee, EstimatedFeeSchema().load(res, unknown=EXCLUDE))
439+
return cast(EstimatedFee, EstimatedFeeSchema().load(res))
444440
except ClientError as err:
445441
if err.code == RPC_CONTRACT_ERROR:
446442
raise ClientError(
@@ -499,9 +495,7 @@ async def send_transaction(self, transaction: Invoke) -> SentTransactionResponse
499495
params={"invoke_transaction": params},
500496
)
501497

502-
return cast(
503-
SentTransactionResponse, SentTransactionSchema().load(res, unknown=EXCLUDE)
504-
)
498+
return cast(SentTransactionResponse, SentTransactionSchema().load(res))
505499

506500
async def deploy_account(
507501
self, transaction: DeployAccount
@@ -515,7 +509,7 @@ async def deploy_account(
515509

516510
return cast(
517511
DeployAccountTransactionResponse,
518-
DeployAccountTransactionResponseSchema().load(res, unknown=EXCLUDE),
512+
DeployAccountTransactionResponseSchema().load(res),
519513
)
520514

521515
async def declare(self, transaction: Declare) -> DeclareTransactionResponse:
@@ -528,7 +522,7 @@ async def declare(self, transaction: Declare) -> DeclareTransactionResponse:
528522

529523
return cast(
530524
DeclareTransactionResponse,
531-
DeclareTransactionResponseSchema().load(res, unknown=EXCLUDE),
525+
DeclareTransactionResponseSchema().load(res),
532526
)
533527

534528
async def get_class_hash_at(
@@ -571,9 +565,9 @@ async def get_class_by_hash(
571565
if "sierra_program" in res:
572566
return cast(
573567
SierraContractClass,
574-
SierraContractClassSchema().load(res, unknown=EXCLUDE),
568+
SierraContractClassSchema().load(res),
575569
)
576-
return cast(ContractClass, ContractClassSchema().load(res, unknown=EXCLUDE))
570+
return cast(ContractClass, ContractClassSchema().load(res))
577571

578572
# Only RPC methods
579573

@@ -602,7 +596,7 @@ async def get_transaction_by_block_id(
602596
"index": index,
603597
},
604598
)
605-
return cast(Transaction, TypesOfTransactionsSchema().load(res, unknown=EXCLUDE))
599+
return cast(Transaction, TypesOfTransactionsSchema().load(res))
606600

607601
async def get_block_transaction_count(
608602
self,
@@ -656,9 +650,9 @@ async def get_class_at(
656650
if "sierra_program" in res:
657651
return cast(
658652
SierraContractClass,
659-
SierraContractClassSchema().load(res, unknown=EXCLUDE),
653+
SierraContractClassSchema().load(res),
660654
)
661-
return cast(ContractClass, ContractClassSchema().load(res, unknown=EXCLUDE))
655+
return cast(ContractClass, ContractClassSchema().load(res))
662656

663657
async def get_contract_nonce(
664658
self,
@@ -698,7 +692,7 @@ async def get_transaction_status(self, tx_hash: Hash) -> TransactionStatusRespon
698692
)
699693
return cast(
700694
TransactionStatusResponse,
701-
TransactionStatusResponseSchema().load(res, unknown=EXCLUDE),
695+
TransactionStatusResponseSchema().load(res),
702696
)
703697

704698
# ------------------------------- Trace API -------------------------------
@@ -719,9 +713,7 @@ async def trace_transaction(
719713
"transaction_hash": _to_rpc_felt(tx_hash),
720714
},
721715
)
722-
return cast(
723-
TransactionTrace, TransactionTraceSchema().load(res, unknown=EXCLUDE)
724-
)
716+
return cast(TransactionTrace, TransactionTraceSchema().load(res))
725717

726718
async def simulate_transactions(
727719
self,
@@ -772,7 +764,7 @@ async def simulate_transactions(
772764
)
773765
return cast(
774766
List[SimulatedTransaction],
775-
SimulatedTransactionSchema().load(res, unknown=EXCLUDE, many=True),
767+
SimulatedTransactionSchema().load(res, many=True),
776768
)
777769

778770
async def trace_block_transactions(
@@ -799,7 +791,7 @@ async def trace_block_transactions(
799791
)
800792
return cast(
801793
List[BlockTransactionTrace],
802-
BlockTransactionTraceSchema().load(res, unknown=EXCLUDE, many=True),
794+
BlockTransactionTraceSchema().load(res, many=True),
803795
)
804796

805797

starknet_py/net/schemas/rpc.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pylint: disable=too-many-lines
22

3-
from marshmallow import EXCLUDE, Schema, fields, post_load
3+
from marshmallow import EXCLUDE, fields, post_load
44
from marshmallow_oneofschema import OneOfSchema
55

66
from starknet_py.abi.v0.schemas import ContractAbiEntrySchema
@@ -87,6 +87,7 @@
8787
Uint128,
8888
)
8989
from starknet_py.net.schemas.utils import _extract_tx_version
90+
from starknet_py.utils.schema import Schema
9091

9192
# pylint: disable=unused-argument, no-self-use
9293

@@ -506,7 +507,7 @@ class BlockHeaderSchema(Schema):
506507

507508
class PendingStarknetBlockSchema(PendingBlockHeaderSchema):
508509
transactions = fields.List(
509-
fields.Nested(TypesOfTransactionsSchema(unknown=EXCLUDE)),
510+
fields.Nested(TypesOfTransactionsSchema()),
510511
data_key="transactions",
511512
required=True,
512513
)
@@ -519,7 +520,7 @@ def make_dataclass(self, data, **kwargs) -> PendingStarknetBlock:
519520
class StarknetBlockSchema(BlockHeaderSchema):
520521
status = BlockStatusField(data_key="status", required=True)
521522
transactions = fields.List(
522-
fields.Nested(TypesOfTransactionsSchema(unknown=EXCLUDE)),
523+
fields.Nested(TypesOfTransactionsSchema()),
523524
data_key="transactions",
524525
required=True,
525526
)
@@ -964,6 +965,11 @@ def make_dataclass(self, data, **kwargs) -> DeployAccountTransactionTrace:
964965

965966

966967
class L1HandlerTransactionTraceSchema(Schema):
968+
# TODO (#1323): Explain with starknet, because spec doesn't contain execution_resources
969+
execution_resources = fields.Nested(
970+
ExecutionResourcesSchema(), data_key="execution_resources", load_default=None
971+
)
972+
967973
function_invocation = fields.Nested(
968974
FunctionInvocationSchema(), data_key="function_invocation", required=True
969975
)
@@ -977,16 +983,15 @@ def make_dataclass(self, data, **kwargs) -> L1HandlerTransactionTrace:
977983

978984

979985
class TransactionTraceSchema(OneOfSchema):
986+
type_field = "type"
987+
980988
type_schemas = {
981989
"INVOKE": InvokeTransactionTraceSchema(),
982990
"DECLARE": DeclareTransactionTraceSchema(),
983991
"DEPLOY_ACCOUNT": DeployAccountTransactionTraceSchema(),
984992
"L1_HANDLER": L1HandlerTransactionTraceSchema(),
985993
}
986994

987-
def get_data_type(self, data):
988-
return data["type"]
989-
990995

991996
class SimulatedTransactionSchema(Schema):
992997
# `unknown=EXCLUDE` in order to skip `type=...` field we don't want

starknet_py/utils/schema.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import os
2+
3+
from marshmallow import EXCLUDE, RAISE
4+
from marshmallow import Schema as MarshmallowSchema
5+
6+
MARSHMALLOW_UKNOWN_EXCLUDE = os.environ.get("STARKNET_PY_MARSHMALLOW_UKNOWN_EXCLUDE")
7+
8+
9+
class Schema(MarshmallowSchema):
10+
class Meta:
11+
unknown = (
12+
EXCLUDE if (MARSHMALLOW_UKNOWN_EXCLUDE or "").lower() == "true" else RAISE
13+
)

0 commit comments

Comments
 (0)