diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py
index b9c2429..f584e8a 100644
--- a/async_substrate_interface/async_substrate.py
+++ b/async_substrate_interface/async_substrate.py
@@ -1092,7 +1092,10 @@ async def __aenter__(self):
         await self.initialize()
         return self
 
-    async def initialize(self):
+    async def initialize(self) -> None:
+        await self._initialize()
+
+    async def _initialize(self) -> None:
         """
         Initialize the connection to the chain.
         """
@@ -1117,7 +1120,7 @@ async def initialize(self):
             self._initializing = False
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self.ws.shutdown()
+        await self.close()
 
     @property
     def metadata(self):
@@ -2339,7 +2342,6 @@ async def get_block_metadata(
                 "MetadataVersioned", data=ScaleBytes(result)
             )
             metadata_decoder.decode()
-            return metadata_decoder
         else:
             return result
 
@@ -4171,17 +4173,21 @@ class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface):
     Experimental new class that uses disk-caching in addition to memory-caching for the cached methods
     """
 
+    async def initialize(self) -> None:
+        await self.runtime_cache.load_from_disk(self.url)
+        await self._initialize()
+
     async def close(self):
         """
         Closes the substrate connection, and the websocket connection.
         """
         try:
+            await self.runtime_cache.dump_to_disk(self.url)
             await self.ws.shutdown()
         except AttributeError:
             pass
-        db_conn = AsyncSqliteDB(self.url)
-        if db_conn._db is not None:
-            await db_conn._db.close()
+        db = AsyncSqliteDB(self.url)
+        await db.close()
 
     @async_sql_lru_cache(maxsize=SUBSTRATE_CACHE_METHOD_SIZE)
     async def get_parent_block_hash(self, block_hash):
diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py
index a627132..e212515 100644
--- a/async_substrate_interface/types.py
+++ b/async_substrate_interface/types.py
@@ -2,6 +2,7 @@
 from abc import ABC
 from collections import defaultdict, deque
 from collections.abc import Iterable
+from contextlib import suppress
 from dataclasses import dataclass
 from datetime import datetime
 from typing import Optional, Union, Any
@@ -16,6 +17,7 @@
 from .const import SS58_FORMAT
 from .utils import json
+from .utils.cache import AsyncSqliteDB
 
 logger = logging.getLogger("async_substrate_interface")
 
 
@@ -34,8 +36,8 @@ class RuntimeCache:
     is important you are utilizing the correct version.
     """
 
-    blocks: dict[int, "Runtime"]
-    block_hashes: dict[str, "Runtime"]
+    blocks: dict[int, str]
+    block_hashes: dict[str, int]
     versions: dict[int, "Runtime"]
     last_used: Optional["Runtime"]
 
@@ -56,10 +58,10 @@ def add_item(
         Adds a Runtime object to the cache mapped to its version, block number, and/or block hash.
         """
         self.last_used = runtime
-        if block is not None:
-            self.blocks[block] = runtime
-        if block_hash is not None:
-            self.block_hashes[block_hash] = runtime
+        if block is not None and block_hash is not None:
+            self.blocks[block] = block_hash
+        if block_hash is not None and runtime_version is not None:
+            self.block_hashes[block_hash] = runtime_version
         if runtime_version is not None:
             self.versions[runtime_version] = runtime
 
@@ -73,33 +75,52 @@ def retrieve(
         Retrieves a Runtime object from the cache, using the key of its block number, block hash, or runtime version.
         Retrieval happens in this order. If no Runtime is found mapped to any of your supplied keys, returns `None`.
         """
+        runtime = None
         if block is not None:
-            runtime = self.blocks.get(block)
-            if runtime is not None:
-                if block_hash is not None:
-                    # if lookup occurs for block_hash and block, but only block matches, also map to block_hash
-                    self.add_item(runtime, block_hash=block_hash)
+            if block_hash is not None:
+                self.blocks[block] = block_hash
+                if runtime_version is not None:
+                    self.block_hashes[block_hash] = runtime_version
+            with suppress(KeyError):
+                runtime = self.versions[self.block_hashes[self.blocks[block]]]
                 self.last_used = runtime
                 return runtime
         if block_hash is not None:
-            runtime = self.block_hashes.get(block_hash)
-            if runtime is not None:
-                if block is not None:
-                    # if lookup occurs for block_hash and block, but only block_hash matches, also map to block
-                    self.add_item(runtime, block=block)
+            if runtime_version is not None:
+                self.block_hashes[block_hash] = runtime_version
+            with suppress(KeyError):
+                runtime = self.versions[self.block_hashes[block_hash]]
                 self.last_used = runtime
                 return runtime
         if runtime_version is not None:
-            runtime = self.versions.get(runtime_version)
-            if runtime is not None:
-                # if runtime_version matches, also map to block and block_hash (if supplied)
-                if block is not None:
-                    self.add_item(runtime, block=block)
-                if block_hash is not None:
-                    self.add_item(runtime, block_hash=block_hash)
+            with suppress(KeyError):
+                runtime = self.versions[runtime_version]
                 self.last_used = runtime
                 return runtime
-        return None
+        return runtime
+
+    async def load_from_disk(self, chain_endpoint: str):
+        db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
+        (
+            block_mapping,
+            block_hash_mapping,
+            runtime_version_mapping,
+        ) = await db.load_runtime_cache(chain_endpoint)
+        if not any([block_mapping, block_hash_mapping, runtime_version_mapping]):
+            logger.debug("No runtime mappings in disk cache")
+        else:
+            logger.debug("Found runtime mappings in disk cache")
+            self.blocks = block_mapping
+            self.block_hashes = block_hash_mapping
+            self.versions = {
+                x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items()
+            }
+
+    async def dump_to_disk(self, chain_endpoint: str):
+        db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
+        await db.dump_runtime_cache(
+            chain_endpoint, self.blocks, self.block_hashes, self.versions
+        )
 
 
 class Runtime:
@@ -149,6 +170,45 @@ def __init__(
         if registry is not None:
             self.load_registry_type_map()
 
+    def serialize(self):
+        metadata_value = self.metadata.data.data
+        return {
+            "chain": self.chain,
+            "type_registry": self.type_registry,
+            "metadata_value": metadata_value,
+            "metadata_v15": self.metadata_v15.encode_to_metadata_option(),
+            "runtime_info": {
+                "specVersion": self.runtime_version,
+                "transactionVersion": self.transaction_version,
+            },
+            "registry": self.registry.registry if self.registry is not None else None,
+            "ss58_format": self.ss58_format,
+        }
+
+    @classmethod
+    def deserialize(cls, serialized: dict) -> "Runtime":
+        ss58_format = serialized["ss58_format"]
+        runtime_config = RuntimeConfigurationObject(ss58_format=ss58_format)
+        runtime_config.clear_type_registry()
+        runtime_config.update_type_registry(load_type_registry_preset(name="core"))
+        metadata = runtime_config.create_scale_object(
+            "MetadataVersioned", data=ScaleBytes(serialized["metadata_value"])
+        )
+        metadata.decode()
+        registry = PortableRegistry.from_json(serialized["registry"])
+        return cls(
+            chain=serialized["chain"],
+            metadata=metadata,
+            type_registry=serialized["type_registry"],
+            runtime_config=runtime_config,
+            metadata_v15=MetadataV15.decode_from_metadata_option(
+                serialized["metadata_v15"]
+            ),
+            registry=registry,
+            ss58_format=ss58_format,
+            runtime_info=serialized["runtime_info"],
+        )
+
     def load_runtime(self):
         """
         Initial loading of the runtime's type registry information.
diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py
index 5cf1fe4..44521a5 100644
--- a/async_substrate_interface/utils/cache.py
+++ b/async_substrate_interface/utils/cache.py
@@ -20,6 +20,7 @@
     if USE_CACHE
     else ":memory:"
 )
+SUBSTRATE_CACHE_METHOD_SIZE = int(os.getenv("SUBSTRATE_CACHE_METHOD_SIZE", "512"))
 
 logger = logging.getLogger("async_substrate_interface")
 
@@ -38,13 +39,13 @@ def __new__(cls, chain_endpoint: str):
             cls._instances[chain_endpoint] = instance
         return instance
 
-    async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]:
+    async def close(self):
         async with self._lock:
-            if not self._db:
-                _ensure_dir()
-                self._db = await aiosqlite.connect(CACHE_LOCATION)
-            table_name = _get_table_name(func)
-            key = None
+            if self._db:
+                await self._db.close()
+                self._db = None
+
+    async def _create_if_not_exists(self, chain: str, table_name: str):
         if not (local_chain := _check_if_local(chain)) or not USE_CACHE:
             await self._db.execute(
                 f"""
@@ -54,7 +55,8 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
                     key BLOB,
                     value BLOB,
                     chain TEXT,
-                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    UNIQUE(key, chain)
                 );
                 """
             )
@@ -66,25 +68,34 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
                 WHERE rowid IN (
                     SELECT rowid FROM {table_name}
                     ORDER BY created_at DESC
-                    LIMIT -1 OFFSET 500
+                    LIMIT -1 OFFSET {SUBSTRATE_CACHE_METHOD_SIZE}
                 );
                 END;
                 """
             )
             await self._db.commit()
-        key = pickle.dumps((args, kwargs or None))
-        try:
-            cursor: aiosqlite.Cursor = await self._db.execute(
-                f"SELECT value FROM {table_name} WHERE key=? AND chain=?",
-                (key, chain),
-            )
-            result = await cursor.fetchone()
-            await cursor.close()
-            if result is not None:
-                return pickle.loads(result[0])
-        except (pickle.PickleError, sqlite3.Error) as e:
-            logger.exception("Cache error", exc_info=e)
-            pass
+        return local_chain
+
+    async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]:
+        async with self._lock:
+            if not self._db:
+                _ensure_dir()
+                self._db = await aiosqlite.connect(CACHE_LOCATION)
+            table_name = _get_table_name(func)
+            local_chain = await self._create_if_not_exists(chain, table_name)
+            key = pickle.dumps((args, kwargs or None))
+            try:
+                cursor: aiosqlite.Cursor = await self._db.execute(
+                    f"SELECT value FROM {table_name} WHERE key=? AND chain=?",
+                    (key, chain),
+                )
+                result = await cursor.fetchone()
+                await cursor.close()
+                if result is not None:
+                    return pickle.loads(result[0])
+            except (pickle.PickleError, sqlite3.Error) as e:
+                logger.exception("Cache error", exc_info=e)
+                pass
         result = await func(other_self, *args, **kwargs)
         if not local_chain or not USE_CACHE:
             # TODO use a task here
@@ -95,6 +106,85 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
             await self._db.commit()
         return result
 
+    async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
+        async with self._lock:
+            if not self._db:
+                _ensure_dir()
+                self._db = await aiosqlite.connect(CACHE_LOCATION)
+        block_mapping = {}
+        block_hash_mapping = {}
+        version_mapping = {}
+        tables = {
+            "RuntimeCache_blocks": block_mapping,
+            "RuntimeCache_block_hashes": block_hash_mapping,
+            "RuntimeCache_versions": version_mapping,
+        }
+        for table in tables.keys():
+            async with self._lock:
+                local_chain = await self._create_if_not_exists(chain, table)
+            if local_chain:
+                return {}, {}, {}
+        for table_name, mapping in tables.items():
+            try:
+                async with self._lock:
+                    cursor: aiosqlite.Cursor = await self._db.execute(
+                        f"SELECT key, value FROM {table_name} WHERE chain=?",
+                        (chain,),
+                    )
+                    results = await cursor.fetchall()
+                    await cursor.close()
+                if results is None:
+                    continue
+                for row in results:
+                    key, value = row
+                    runtime = pickle.loads(value)
+                    mapping[key] = runtime
+            except (pickle.PickleError, sqlite3.Error) as e:
+                logger.exception("Cache error", exc_info=e)
+                return {}, {}, {}
+        return block_mapping, block_hash_mapping, version_mapping
+
+    async def dump_runtime_cache(
+        self,
+        chain: str,
+        block_mapping: dict,
+        block_hash_mapping: dict,
+        version_mapping: dict,
+    ) -> None:
+        async with self._lock:
+            if not self._db:
+                _ensure_dir()
+                self._db = await aiosqlite.connect(CACHE_LOCATION)
+
+            tables = {
+                "RuntimeCache_blocks": block_mapping,
+                "RuntimeCache_block_hashes": block_hash_mapping,
+                "RuntimeCache_versions": version_mapping,
+            }
+            for table, mapping in tables.items():
+                local_chain = await self._create_if_not_exists(chain, table)
+                if local_chain:
+                    return None
+                serialized_mapping = {}
+                for key, value in mapping.items():
+                    if not isinstance(value, (str, int)):
+                        serialized_value = pickle.dumps(value.serialize())
+                    else:
+                        serialized_value = pickle.dumps(value)
+                    serialized_mapping[key] = serialized_value
+
+                await self._db.executemany(
+                    f"INSERT OR REPLACE INTO {table} (key, value, chain) VALUES (?,?,?)",
+                    [
+                        (key, serialized_value_, chain)
+                        for key, serialized_value_ in serialized_mapping.items()
+                    ],
+                )
+
+            await self._db.commit()
+
+            return None
+
 
 def _ensure_dir():
     path = Path(CACHE_LOCATION).parent
@@ -119,7 +209,8 @@ def _create_table(c, conn, table_name):
             key BLOB,
             value BLOB,
             chain TEXT,
-            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+            UNIQUE(key, chain)
         );
         """
     )
@@ -130,7 +221,7 @@ def _create_table(c, conn, table_name):
         WHERE rowid IN (
             SELECT rowid FROM {table_name}
            ORDER BY created_at DESC
-            LIMIT -1 OFFSET 500
+            LIMIT -1 OFFSET {SUBSTRATE_CACHE_METHOD_SIZE}
         );
         END;"""
     )
@@ -205,7 +296,7 @@ def inner(self, *args, **kwargs):
 
 def async_sql_lru_cache(maxsize: Optional[int] = None):
     def decorator(func):
-        @cached_fetcher(max_size=maxsize)
+        @cached_fetcher(max_size=maxsize, cache_key_index=None)
         async def inner(self, *args, **kwargs):
             async_sql_db = AsyncSqliteDB(self.url)
             result = await async_sql_db(self.url, self, func, args, kwargs)
@@ -353,7 +444,7 @@ def __get__(self, instance, owner):
         return self._instances[instance]
 
 
-def cached_fetcher(max_size: Optional[int] = None, cache_key_index: int = 0):
+def cached_fetcher(max_size: Optional[int] = None, cache_key_index: Optional[int] = 0):
     """Wrapper for CachedFetcher. See example in CachedFetcher docstring."""
 
     def wrapper(method):
diff --git a/pyproject.toml b/pyproject.toml
index 94bac0d..14a1b31 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,11 +8,11 @@ keywords = ["substrate", "development", "bittensor"]
 
 dependencies = [
     "wheel",
+    "aiosqlite>=0.21.0,<1.0.0",
     "bt-decode==v0.8.0",
     "scalecodec~=1.2.11",
     "websockets>=14.1",
     "xxhash",
-    "aiosqlite>=0.21.0,<1.0.0"
 ]
 
 requires-python = ">=3.9,<3.15"
diff --git a/tests/unit_tests/test_types.py b/tests/unit_tests/test_types.py
index 7292177..f2e13b4 100644
--- a/tests/unit_tests/test_types.py
+++ b/tests/unit_tests/test_types.py
@@ -1,4 +1,12 @@
 from async_substrate_interface.types import ScaleObj, Runtime, RuntimeCache
+from async_substrate_interface.async_substrate import DiskCachedAsyncSubstrateInterface
+from async_substrate_interface.utils import cache
+
+import sqlite3
+import os
+import pickle
+import pytest
+from unittest.mock import patch
 
 
 def test_scale_object():
@@ -72,13 +80,83 @@ def test_runtime_cache():
     # cache does not yet know that new_fake_block has the same runtime
     assert runtime_cache.retrieve(new_fake_block) is None
     assert (
-        runtime_cache.retrieve(new_fake_block, runtime_version=fake_version) is not None
+        runtime_cache.retrieve(
+            new_fake_block, new_fake_hash, runtime_version=fake_version
+        )
+        is not None
     )
     # after checking the runtime with the new block, it now knows this runtime should also map to this block
     assert runtime_cache.retrieve(new_fake_block) is not None
     assert runtime_cache.retrieve(newer_fake_block) is None
     assert runtime_cache.retrieve(newer_fake_block, fake_hash) is not None
     assert runtime_cache.retrieve(newer_fake_block) is not None
-    assert runtime_cache.retrieve(block_hash=new_fake_hash) is None
     assert runtime_cache.retrieve(fake_block, block_hash=new_fake_hash) is not None
     assert runtime_cache.retrieve(block_hash=new_fake_hash) is not None
+
+
+@pytest.mark.asyncio
+async def test_runtime_cache_from_disk():
+    test_db_location = "/tmp/async-substrate-interface-test-cache"
+    fake_chain = "ws://fake.com"
+    fake_block = 1
+    fake_hash = "0xignore"
+    new_fake_block = 2
+    new_fake_hash = "0xnewfakehash"
+
+    if os.path.exists(test_db_location):
+        os.remove(test_db_location)
+    with patch.object(cache, "CACHE_LOCATION", test_db_location):
+        substrate = DiskCachedAsyncSubstrateInterface(fake_chain, _mock=True)
+        # Needed to avoid trying to initialize on the network during `substrate.initialize()`
+        substrate.initialized = True
+
+        # runtime cache should be completely empty
+        assert substrate.runtime_cache.block_hashes == {}
+        assert substrate.runtime_cache.blocks == {}
+        assert substrate.runtime_cache.versions == {}
+        await substrate.initialize()
+
+        # after initialization, runtime cache should still be completely empty
+        assert substrate.runtime_cache.block_hashes == {}
+        assert substrate.runtime_cache.blocks == {}
+        assert substrate.runtime_cache.versions == {}
+        await substrate.close()
+
+        # ensure we have created the SQLite DB during initialize()
+        assert os.path.exists(test_db_location)
+
+        # insert some fake data into our DB
+        conn = sqlite3.connect(test_db_location)
+        conn.execute(
+            "INSERT INTO RuntimeCache_blocks (key, value, chain) VALUES (?, ?, ?)",
+            (fake_block, pickle.dumps(fake_hash), fake_chain),
+        )
+        conn.commit()
+        conn.close()
+
+        substrate.initialized = True
+        await substrate.initialize()
+        assert substrate.runtime_cache.blocks == {fake_block: fake_hash}
+        # add an item to the cache
+        substrate.runtime_cache.add_item(
+            runtime=None, block_hash=new_fake_hash, block=new_fake_block
+        )
+        await substrate.close()
+
+        # verify that our added item is now in the DB
+        conn = sqlite3.connect(test_db_location)
+        cursor = conn.cursor()
+        cursor.execute("SELECT key, value, chain FROM RuntimeCache_blocks")
+        query = cursor.fetchall()
+        cursor.close()
+        conn.close()
+
+        first_row = query[0]
+        assert first_row[0] == fake_block
+        assert pickle.loads(first_row[1]) == fake_hash
+        assert first_row[2] == fake_chain
+
+        second_row = query[1]
+        assert second_row[0] == new_fake_block
+        assert pickle.loads(second_row[1]) == new_fake_hash
+        assert second_row[2] == fake_chain
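
Note (reviewer sketch, not part of the patch): the reworked RuntimeCache no longer maps blocks and block hashes directly to Runtime objects; it chains block -> block hash -> runtime version -> Runtime, which is what keeps the blocks and block_hashes mappings trivially picklable for the disk dump. A minimal illustration of that lookup chain, assuming RuntimeCache() starts with empty mappings (as the existing test_runtime_cache() test relies on) and using made-up block, hash, and version values:

    from async_substrate_interface.types import RuntimeCache

    runtime_cache = RuntimeCache()
    fake_runtime = object()  # stand-in for a real Runtime instance

    # add_item() now records block -> block_hash, block_hash -> runtime_version,
    # and runtime_version -> Runtime, instead of mapping every key to the Runtime.
    runtime_cache.add_item(
        runtime=fake_runtime, block=100, block_hash="0xabc", runtime_version=271
    )
    assert runtime_cache.blocks == {100: "0xabc"}
    assert runtime_cache.block_hashes == {"0xabc": 271}

    # retrieve() walks the same chain: blocks[100] -> "0xabc" -> 271 -> Runtime.
    assert runtime_cache.retrieve(block=100) is fake_runtime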
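
Note (reviewer sketch, not part of the patch): the intended lifecycle of DiskCachedAsyncSubstrateInterface after this change, written against the async-context-manager usage; the endpoint URL is a placeholder and get_chain_head() is just an arbitrary call. initialize() now restores the persisted runtime mappings via RuntimeCache.load_from_disk() before the usual _initialize() connection setup, and close() (also reached through __aexit__) dumps them back with dump_to_disk() and closes the shared AsyncSqliteDB connection. The new SUBSTRATE_CACHE_METHOD_SIZE environment variable (default 512) caps how many rows each cached-method table keeps on disk.

    import asyncio

    from async_substrate_interface.async_substrate import (
        DiskCachedAsyncSubstrateInterface,
    )


    async def main():
        # __aenter__ -> initialize(): the runtime cache is loaded from disk first.
        async with DiskCachedAsyncSubstrateInterface("wss://example-node:443") as substrate:
            print(await substrate.get_chain_head())
        # __aexit__ -> close(): the runtime cache is dumped to disk, the websocket
        # is shut down, and the per-endpoint sqlite connection is closed.


    asyncio.run(main())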