Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,10 @@ async def __aenter__(self):
await self.initialize()
return self

async def initialize(self):
async def initialize(self) -> None:
await self._initialize()

async def _initialize(self) -> None:
"""
Initialize the connection to the chain.
"""
Expand All @@ -1117,7 +1120,7 @@ async def initialize(self):
self._initializing = False

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.ws.shutdown()
await self.close()

@property
def metadata(self):
Expand Down Expand Up @@ -2339,7 +2342,6 @@ async def get_block_metadata(
"MetadataVersioned", data=ScaleBytes(result)
)
metadata_decoder.decode()

return metadata_decoder
else:
return result
Expand Down Expand Up @@ -4171,17 +4173,21 @@ class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface):
Experimental new class that uses disk-caching in addition to memory-caching for the cached methods
"""

async def initialize(self) -> None:
await self.runtime_cache.load_from_disk(self.url)
await self._initialize()

async def close(self):
"""
Closes the substrate connection, and the websocket connection.
"""
try:
await self.runtime_cache.dump_to_disk(self.url)
await self.ws.shutdown()
except AttributeError:
pass
db_conn = AsyncSqliteDB(self.url)
if db_conn._db is not None:
await db_conn._db.close()
db = AsyncSqliteDB(self.url)
await db.close()

@async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
async def get_parent_block_hash(self, block_hash):
Expand Down
108 changes: 84 additions & 24 deletions async_substrate_interface/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC
from collections import defaultdict, deque
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Union, Any
Expand All @@ -16,6 +17,7 @@

from .const import SS58_FORMAT
from .utils import json
from .utils.cache import AsyncSqliteDB

logger = logging.getLogger("async_substrate_interface")

Expand All @@ -34,8 +36,8 @@ class RuntimeCache:
is important you are utilizing the correct version.
"""

blocks: dict[int, "Runtime"]
block_hashes: dict[str, "Runtime"]
blocks: dict[int, str]
block_hashes: dict[str, int]
versions: dict[int, "Runtime"]
last_used: Optional["Runtime"]

Expand All @@ -56,10 +58,10 @@ def add_item(
Adds a Runtime object to the cache mapped to its version, block number, and/or block hash.
"""
self.last_used = runtime
if block is not None:
self.blocks[block] = runtime
if block_hash is not None:
self.block_hashes[block_hash] = runtime
if block is not None and block_hash is not None:
self.blocks[block] = block_hash
if block_hash is not None and runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
if runtime_version is not None:
self.versions[runtime_version] = runtime

Expand All @@ -73,33 +75,52 @@ def retrieve(
Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version.
Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`.
"""
runtime = None
if block is not None:
runtime = self.blocks.get(block)
if runtime is not None:
if block_hash is not None:
# if lookup occurs for block_hash and block, but only block matches, also map to block_hash
self.add_item(runtime, block_hash=block_hash)
if block_hash is not None:
self.blocks[block] = block_hash
if runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
with suppress(KeyError):
runtime = self.versions[self.block_hashes[self.blocks[block]]]
self.last_used = runtime
return runtime
if block_hash is not None:
runtime = self.block_hashes.get(block_hash)
if runtime is not None:
if block is not None:
# if lookup occurs for block_hash and block, but only block_hash matches, also map to block
self.add_item(runtime, block=block)
if runtime_version is not None:
self.block_hashes[block_hash] = runtime_version
with suppress(KeyError):
runtime = self.versions[self.block_hashes[block_hash]]
self.last_used = runtime
return runtime
if runtime_version is not None:
runtime = self.versions.get(runtime_version)
if runtime is not None:
# if runtime_version matches, also map to block and block_hash (if supplied)
if block is not None:
self.add_item(runtime, block=block)
if block_hash is not None:
self.add_item(runtime, block_hash=block_hash)
with suppress(KeyError):
runtime = self.versions[runtime_version]
self.last_used = runtime
return runtime
return None
return runtime

async def load_from_disk(self, chain_endpoint: str):
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
(
block_mapping,
block_hash_mapping,
runtime_version_mapping,
) = await db.load_runtime_cache(chain_endpoint)
if not any([block_mapping, block_hash_mapping, runtime_version_mapping]):
logger.debug("No runtime mappings in disk cache")
else:
logger.debug("Found runtime mappings in disk cache")
self.blocks = block_mapping
self.block_hashes = block_hash_mapping
self.versions = {
x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items()
}

async def dump_to_disk(self, chain_endpoint: str):
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
await db.dump_runtime_cache(
chain_endpoint, self.blocks, self.block_hashes, self.versions
)


class Runtime:
Expand Down Expand Up @@ -149,6 +170,45 @@ def __init__(
if registry is not None:
self.load_registry_type_map()

def serialize(self):
metadata_value = self.metadata.data.data
return {
"chain": self.chain,
"type_registry": self.type_registry,
"metadata_value": metadata_value,
"metadata_v15": self.metadata_v15.encode_to_metadata_option(),
"runtime_info": {
"specVersion": self.runtime_version,
"transactionVersion": self.transaction_version,
},
"registry": self.registry.registry if self.registry is not None else None,
"ss58_format": self.ss58_format,
}

@classmethod
def deserialize(cls, serialized: dict) -> "Runtime":
ss58_format = serialized["ss58_format"]
runtime_config = RuntimeConfigurationObject(ss58_format=ss58_format)
runtime_config.clear_type_registry()
runtime_config.update_type_registry(load_type_registry_preset(name="core"))
metadata = runtime_config.create_scale_object(
"MetadataVersioned", data=ScaleBytes(serialized["metadata_value"])
)
metadata.decode()
registry = PortableRegistry.from_json(serialized["registry"])
return cls(
chain=serialized["chain"],
metadata=metadata,
type_registry=serialized["type_registry"],
runtime_config=runtime_config,
metadata_v15=MetadataV15.decode_from_metadata_option(
serialized["metadata_v15"]
),
registry=registry,
ss58_format=ss58_format,
runtime_info=serialized["runtime_info"],
)

def load_runtime(self):
"""
Initial loading of the runtime's type registry information.
Expand Down
Loading
Loading