55"""
66
77import asyncio
8+ from concurrent .futures import ProcessPoolExecutor
89import inspect
910import logging
1011import random
5657)
5758from async_substrate_interface .utils .storage import StorageKey
5859from async_substrate_interface .type_registry import _TYPE_REGISTRY
60+ from async_substrate_interface .utils .decoding_attempt import _decode_query_map , _decode_scale_with_runtime
5961
6062if TYPE_CHECKING :
6163 from websockets .asyncio .client import ClientConnection
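Note: `_decode_query_map` is imported above so that a whole page of query-map changes can be decoded outside the event loop, including in a worker process. Its real implementation lives in `async_substrate_interface.utils.decoding_attempt`; the sketch below is only a reconstruction inferred from the inline loop removed later in this diff and from the positional call sites in `query_map`, with `decode_fn` standing in for whatever `_decode_scale_with_runtime` actually does:

    def _concat_hash_len(key_hasher: str) -> int:
        # Bytes of raw key material kept alongside the hash for each supported hasher.
        return {"Blake2_128Concat": 16, "Twox64Concat": 8, "Identity": 0}[key_hasher]

    def _decode_query_map_sketch(
        result_group_changes,
        prefix,
        registry,            # plain registry object (runtime.registry.registry), picklable
        param_types,
        params,
        value_type,
        key_hashers,
        ignore_decoding_errors,
        decode_fn,           # hypothetical stand-in for _decode_scale_with_runtime
    ):
        result = []
        for item in result_group_changes:
            try:
                # Rebuild the composite key type, e.g. "([u8; 16], AccountId)".
                key_type_string = []
                for n in range(len(params), len(param_types)):
                    key_type_string.append(f"[u8; {_concat_hash_len(key_hashers[n])}]")
                    key_type_string.append(param_types[n])
                item_key_obj = decode_fn(
                    f"({', '.join(key_type_string)})",
                    bytes.fromhex(item[0][len(prefix):]),
                    registry,
                )
                # Strip the hasher bytes so only the usable key part(s) remain.
                if len(param_types) - len(params) == 1:
                    item_key = item_key_obj[1]
                else:
                    item_key = tuple(
                        item_key_obj[k + 1]
                        for k in range(len(params), len(param_types) + 1, 2)
                    )
            except Exception:
                if not ignore_decoding_errors:
                    raise
                item_key = None
            try:
                # item[1] is a "0x..."-prefixed hex string.
                item_value = decode_fn(value_type, bytes.fromhex(item[1][2:]), registry)
            except Exception:
                if not ignore_decoding_errors:
                    raise
                item_value = None
            result.append([item_key, item_value])
        return result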
@@ -413,6 +415,7 @@ def __init__(
         last_key: Optional[str] = None,
         max_results: Optional[int] = None,
         ignore_decoding_errors: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None,
     ):
         self.records = records
         self.page_size = page_size
@@ -425,6 +428,7 @@ def __init__(
         self.params = params
         self.ignore_decoding_errors = ignore_decoding_errors
         self.loading_complete = False
+        self.executor = executor
         self._buffer = iter(self.records)  # Initialize the buffer with initial records
 
     async def retrieve_next_page(self, start_key) -> list:
@@ -437,6 +441,7 @@ async def retrieve_next_page(self, start_key) -> list:
             start_key=start_key,
             max_results=self.max_results,
             ignore_decoding_errors=self.ignore_decoding_errors,
+            executor=self.executor,
         )
         if len(result.records) < self.page_size:
             self.loading_complete = True
@@ -862,6 +867,7 @@ async def encode_scale(
             await self._wait_for_registry(_attempt, _retries)
         return self._encode_scale(type_string, value)
 
+
     async def decode_scale(
         self,
         type_string: str,
@@ -898,7 +904,7 @@ async def decode_scale(
         else:
             return obj
 
-    async def load_runtime(self, runtime):
+    def load_runtime(self, runtime):
         self.runtime = runtime
 
         # Update type registry
@@ -954,7 +960,7 @@ async def init_runtime(
             )
 
         if self.runtime and runtime_version == self.runtime.runtime_version:
-            return
+            return self.runtime
 
         runtime = self.runtime_cache.retrieve(runtime_version=runtime_version)
         if not runtime:
@@ -990,7 +996,7 @@ async def init_runtime(
                 runtime_version=runtime_version, runtime=runtime
             )
 
-        await self.load_runtime(runtime)
+        self.load_runtime(runtime)
 
         if self.ss58_format is None:
             # Check and apply runtime constants
@@ -1000,6 +1006,7 @@ async def init_runtime(
 
             if ss58_prefix_constant:
                 self.ss58_format = ss58_prefix_constant
+        return runtime
 
     async def create_storage_key(
         self,
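With `init_runtime` now returning the resolved `Runtime` (including on the early-return path when the version already matches), callers no longer need to reach back through `self.runtime`. A minimal call-site sketch, assuming an initialized `substrate` instance:

    runtime = await substrate.init_runtime(block_hash=block_hash)
    registry = runtime.registry.registry  # plain registry, suitable for handing to workers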
@@ -2858,6 +2865,7 @@ async def query_map(
         page_size: int = 100,
         ignore_decoding_errors: bool = False,
         reuse_block_hash: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None,
     ) -> AsyncQueryMapResult:
         """
         Iterates over all key-pairs located at the given module and storage_function. The storage
@@ -2892,12 +2900,11 @@ async def query_map(
         Returns:
             AsyncQueryMapResult object
         """
-        hex_to_bytes_ = hex_to_bytes
         params = params or []
         block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
         if block_hash:
             self.last_block_hash = block_hash
-        await self.init_runtime(block_hash=block_hash)
+        runtime = await self.init_runtime(block_hash=block_hash)
 
         metadata_pallet = self.runtime.metadata.get_metadata_pallet(module)
         if not metadata_pallet:
@@ -2952,19 +2959,6 @@ async def query_map(
         result = []
         last_key = None
 
-        def concat_hash_len(key_hasher: str) -> int:
-            """
-            Helper function to avoid if statements
-            """
-            if key_hasher == "Blake2_128Concat":
-                return 16
-            elif key_hasher == "Twox64Concat":
-                return 8
-            elif key_hasher == "Identity":
-                return 0
-            else:
-                raise ValueError("Unsupported hash type")
-
         if len(result_keys) > 0:
             last_key = result_keys[-1]
 
@@ -2975,51 +2969,33 @@ def concat_hash_len(key_hasher: str) -> int:
 
         if "error" in response:
             raise SubstrateRequestException(response["error"]["message"])
-
         for result_group in response["result"]:
-            for item in result_group["changes"]:
-                try:
-                    # Determine type string
-                    key_type_string = []
-                    for n in range(len(params), len(param_types)):
-                        key_type_string.append(
-                            f"[u8; {concat_hash_len(key_hashers[n])}]"
-                        )
-                        key_type_string.append(param_types[n])
-
-                    item_key_obj = await self.decode_scale(
-                        type_string=f"({', '.join(key_type_string)})",
-                        scale_bytes=bytes.fromhex(item[0][len(prefix):]),
-                        return_scale_obj=True,
-                    )
-
-                    # strip key_hashers to use as item key
-                    if len(param_types) - len(params) == 1:
-                        item_key = item_key_obj[1]
-                    else:
-                        item_key = tuple(
-                            item_key_obj[key + 1]
-                            for key in range(len(params), len(param_types) + 1, 2)
-                        )
-
-                except Exception as _:
-                    if not ignore_decoding_errors:
-                        raise
-                    item_key = None
-
-                try:
-                    item_bytes = hex_to_bytes_(item[1])
-
-                    item_value = await self.decode_scale(
-                        type_string=value_type,
-                        scale_bytes=item_bytes,
-                        return_scale_obj=True,
-                    )
-                except Exception as _:
-                    if not ignore_decoding_errors:
-                        raise
-                    item_value = None
-                result.append([item_key, item_value])
+            if executor:
+                # Decode this page of results in a separate process; all
+                # arguments crossing the process boundary must be picklable.
+                result = await asyncio.get_running_loop().run_in_executor(
+                    executor,
+                    _decode_query_map,
+                    result_group["changes"],
+                    prefix,
+                    runtime.registry.registry,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                )
+            else:
+                result = _decode_query_map(
+                    result_group["changes"],
+                    prefix,
+                    runtime.registry.registry,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                )
         return AsyncQueryMapResult(
             records=result,
             page_size=page_size,
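A note on the branch above: `ProcessPoolExecutor` pickles every argument it ships to a worker, which is why the call passes the inner `runtime.registry.registry` plus plain lists, strings, and booleans rather than `self` or the `Runtime` wrapper, and why `_decode_query_map` must be a module-level function. When no executor is supplied, the same helper runs synchronously on the event-loop thread, preserving the previous single-process behaviour.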
@@ -3031,6 +3007,7 @@ def concat_hash_len(key_hasher: str) -> int:
             last_key=last_key,
             max_results=max_results,
             ignore_decoding_errors=ignore_decoding_errors,
+            executor=executor,
         )
 
     async def submit_extrinsic(
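Taken together, these changes let callers opt into process-parallel decoding of large storage maps. A minimal usage sketch, assuming the package's top-level `AsyncSubstrateInterface` supports use as an async context manager, with the endpoint, pallet, and storage-function names as placeholders:

    import asyncio
    from concurrent.futures import ProcessPoolExecutor

    from async_substrate_interface import AsyncSubstrateInterface

    async def main():
        async with AsyncSubstrateInterface("wss://example-node:443") as substrate:
            with ProcessPoolExecutor() as executor:
                # Decoding of each result page is dispatched to the pool.
                result = await substrate.query_map(
                    "System", "Account", executor=executor
                )
                for item_key, item_value in result.records:
                    print(item_key, item_value)

    asyncio.run(main())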