diff --git a/Cargo.lock b/Cargo.lock index d242b5922f..04686202fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8297,6 +8297,7 @@ dependencies = [ "pallet-nomination-pools-runtime-api", "pallet-offences", "pallet-preimage", + "pallet-rate-limiting", "pallet-registry", "pallet-safe-mode", "pallet-scheduler", @@ -10257,6 +10258,48 @@ dependencies = [ "sp-runtime", ] +[[package]] +name = "pallet-rate-limiting" +version = "0.1.0" +dependencies = [ + "frame-benchmarking", + "frame-support", + "frame-system", + "parity-scale-codec", + "scale-info", + "serde", + "sp-core", + "sp-io", + "sp-runtime", + "sp-std", + "subtensor-runtime-common", +] + +[[package]] +name = "pallet-rate-limiting-rpc" +version = "0.1.0" +dependencies = [ + "jsonrpsee", + "pallet-rate-limiting-runtime-api", + "sp-api", + "sp-blockchain", + "sp-runtime", + "subtensor-runtime-common", +] + +[[package]] +name = "pallet-rate-limiting-runtime-api" +version = "0.1.0" +dependencies = [ + "pallet-rate-limiting", + "parity-scale-codec", + "scale-info", + "serde", + "sp-api", + "sp-std", + "subtensor-runtime-common", +] + [[package]] name = "pallet-recovery" version = "41.0.0" @@ -10714,6 +10757,7 @@ dependencies = [ "pallet-crowdloan", "pallet-drand", "pallet-preimage", + "pallet-rate-limiting", "pallet-scheduler", "pallet-subtensor-proxy", "pallet-subtensor-swap", diff --git a/Cargo.toml b/Cargo.toml index 6139004914..1faa26f232 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,8 @@ members = [ "common", "node", "pallets/*", + "pallets/rate-limiting/runtime-api", + "pallets/rate-limiting/rpc", "precompiles", "primitives/*", "runtime", @@ -59,6 +61,9 @@ pallet-subtensor = { path = "pallets/subtensor", default-features = false } pallet-subtensor-swap = { path = "pallets/swap", default-features = false } pallet-subtensor-swap-runtime-api = { path = "pallets/swap/runtime-api", default-features = false } pallet-subtensor-swap-rpc = { path = "pallets/swap/rpc", default-features = false } +pallet-rate-limiting = { path = "pallets/rate-limiting", default-features = false } +pallet-rate-limiting-runtime-api = { path = "pallets/rate-limiting/runtime-api", default-features = false } +pallet-rate-limiting-rpc = { path = "pallets/rate-limiting/rpc", default-features = false } procedural-fork = { path = "support/procedural-fork", default-features = false } safe-math = { path = "primitives/safe-math", default-features = false } share-pool = { path = "primitives/share-pool", default-features = false } diff --git a/common/src/lib.rs b/common/src/lib.rs index 28a33c2ae6..b08bed0696 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -15,8 +15,10 @@ use sp_runtime::{ use subtensor_macros::freeze_struct; pub use currency::*; +pub use rate_limiting::{RateLimitScope, RateLimitUsageKey}; mod currency; +mod rate_limiting; /// Balance of an account. pub type Balance = u64; diff --git a/common/src/rate_limiting.rs b/common/src/rate_limiting.rs new file mode 100644 index 0000000000..3c88758943 --- /dev/null +++ b/common/src/rate_limiting.rs @@ -0,0 +1,66 @@ +use codec::{Decode, DecodeWithMemTracking, Encode, MaxEncodedLen}; +use frame_support::pallet_prelude::Parameter; +use scale_info::TypeInfo; +use serde::{Deserialize, Serialize}; + +use crate::{MechId, NetUid}; + +#[derive( + Serialize, + Deserialize, + Encode, + Decode, + DecodeWithMemTracking, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, + TypeInfo, + MaxEncodedLen, +)] +pub enum RateLimitScope { + Subnet(NetUid), + SubnetMechanism { netuid: NetUid, mecid: MechId }, +} + +#[derive( + Serialize, + Deserialize, + Encode, + Decode, + DecodeWithMemTracking, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, + TypeInfo, + MaxEncodedLen, +)] +#[scale_info(skip_type_params(AccountId))] +pub enum RateLimitUsageKey { + Account(AccountId), + Subnet(NetUid), + AccountSubnet { + account: AccountId, + netuid: NetUid, + }, + ColdkeyHotkeySubnet { + coldkey: AccountId, + hotkey: AccountId, + netuid: NetUid, + }, + SubnetNeuron { + netuid: NetUid, + uid: u16, + }, + SubnetMechanismNeuron { + netuid: NetUid, + mecid: MechId, + uid: u16, + }, +} diff --git a/pallets/rate-limiting/Cargo.toml b/pallets/rate-limiting/Cargo.toml new file mode 100644 index 0000000000..67e2710f4b --- /dev/null +++ b/pallets/rate-limiting/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "pallet-rate-limiting" +version = "0.1.0" +edition.workspace = true + +[lints] +workspace = true + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] + +[dependencies] +codec = { workspace = true, features = ["derive"] } +frame-benchmarking = { workspace = true, optional = true } +frame-support.workspace = true +frame-system.workspace = true +scale-info = { workspace = true, features = ["derive"] } +serde = { workspace = true, features = ["derive"] } +sp-std.workspace = true +sp-runtime.workspace = true +subtensor-runtime-common.workspace = true + +[dev-dependencies] +sp-core.workspace = true +sp-io.workspace = true +sp-runtime.workspace = true + +[features] +default = ["std"] +std = [ + "codec/std", + "frame-benchmarking?/std", + "frame-support/std", + "frame-system/std", + "scale-info/std", + "serde/std", + "sp-std/std", + "sp-runtime/std", + "subtensor-runtime-common/std", +] +runtime-benchmarks = [ + "frame-benchmarking", +] +try-runtime = [ + "frame-support/try-runtime", + "frame-system/try-runtime", +] diff --git a/pallets/rate-limiting/rpc/Cargo.toml b/pallets/rate-limiting/rpc/Cargo.toml new file mode 100644 index 0000000000..d5bf689e8b --- /dev/null +++ b/pallets/rate-limiting/rpc/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pallet-rate-limiting-rpc" +version = "0.1.0" +description = "RPC interface for the rate limiting pallet" +edition.workspace = true + +[dependencies] +jsonrpsee = { workspace = true, features = ["client-core", "server", "macros"] } +sp-api.workspace = true +sp-blockchain.workspace = true +sp-runtime.workspace = true +pallet-rate-limiting-runtime-api.workspace = true +subtensor-runtime-common = { workspace = true, default-features = false } + +[features] +default = ["std"] +std = [ + "sp-api/std", + "sp-runtime/std", + "pallet-rate-limiting-runtime-api/std", + "subtensor-runtime-common/std", +] diff --git a/pallets/rate-limiting/rpc/src/lib.rs b/pallets/rate-limiting/rpc/src/lib.rs new file mode 100644 index 0000000000..ca7452a7a0 --- /dev/null +++ b/pallets/rate-limiting/rpc/src/lib.rs @@ -0,0 +1,82 @@ +//! RPC interface for the rate limiting pallet. + +use jsonrpsee::{ + core::RpcResult, + proc_macros::rpc, + types::{ErrorObjectOwned, error::ErrorObject}, +}; +use sp_api::ProvideRuntimeApi; +use sp_blockchain::HeaderBackend; +use sp_runtime::traits::Block as BlockT; +use std::sync::Arc; + +pub use pallet_rate_limiting_runtime_api::{RateLimitRpcResponse, RateLimitingRuntimeApi}; + +#[rpc(client, server)] +pub trait RateLimitingRpcApi { + #[method(name = "rateLimiting_getRateLimit")] + fn get_rate_limit( + &self, + pallet: Vec, + extrinsic: Vec, + at: Option, + ) -> RpcResult>; +} + +/// Error type of this RPC api. +pub enum Error { + /// The call to runtime failed. + RuntimeError(String), +} + +impl From for ErrorObjectOwned { + fn from(e: Error) -> Self { + match e { + Error::RuntimeError(e) => ErrorObject::owned(1, e, None::<()>), + } + } +} + +impl From for i32 { + fn from(e: Error) -> i32 { + match e { + Error::RuntimeError(_) => 1, + } + } +} + +/// RPC implementation for the rate limiting pallet. +pub struct RateLimiting { + client: Arc, + _marker: std::marker::PhantomData, +} + +impl RateLimiting { + /// Creates a new instance of the rate limiting RPC helper. + pub fn new(client: Arc) -> Self { + Self { + client, + _marker: Default::default(), + } + } +} + +impl RateLimitingRpcApiServer<::Hash> for RateLimiting +where + Block: BlockT, + C: ProvideRuntimeApi + HeaderBackend + Send + Sync + 'static, + C::Api: RateLimitingRuntimeApi, +{ + fn get_rate_limit( + &self, + pallet: Vec, + extrinsic: Vec, + at: Option<::Hash>, + ) -> RpcResult> { + let api = self.client.runtime_api(); + let at = at.unwrap_or_else(|| self.client.info().best_hash); + + api.get_rate_limit(at, pallet, extrinsic) + .map_err(|e| Error::RuntimeError(format!("Unable to fetch rate limit: {e:?}")).into()) + } +} diff --git a/pallets/rate-limiting/runtime-api/Cargo.toml b/pallets/rate-limiting/runtime-api/Cargo.toml new file mode 100644 index 0000000000..2847d865dd --- /dev/null +++ b/pallets/rate-limiting/runtime-api/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "pallet-rate-limiting-runtime-api" +version = "0.1.0" +description = "Runtime API for the rate limiting pallet" +edition.workspace = true + +[dependencies] +codec = { workspace = true, features = ["derive"] } +scale-info = { workspace = true, features = ["derive"] } +sp-api.workspace = true +sp-std.workspace = true +pallet-rate-limiting.workspace = true +subtensor-runtime-common = { workspace = true, default-features = false } +serde = { workspace = true, features = ["derive"], optional = true } + +[features] +default = ["std"] +std = [ + "codec/std", + "scale-info/std", + "sp-api/std", + "sp-std/std", + "pallet-rate-limiting/std", + "subtensor-runtime-common/std", + "serde", +] diff --git a/pallets/rate-limiting/runtime-api/src/lib.rs b/pallets/rate-limiting/runtime-api/src/lib.rs new file mode 100644 index 0000000000..98b55e9a26 --- /dev/null +++ b/pallets/rate-limiting/runtime-api/src/lib.rs @@ -0,0 +1,25 @@ +#![cfg_attr(not(feature = "std"), no_std)] + +use codec::{Decode, Encode}; +use pallet_rate_limiting::RateLimitKind; +use scale_info::TypeInfo; +use sp_std::vec::Vec; +use subtensor_runtime_common::BlockNumber; + +#[cfg(feature = "std")] +use serde::{Deserialize, Serialize}; + +#[cfg_attr(feature = "std", derive(Serialize, Deserialize))] +#[derive(Clone, Debug, Decode, Encode, Eq, PartialEq, TypeInfo)] +pub struct RateLimitRpcResponse { + pub global: Option>, + pub contextual: Vec<(Vec, RateLimitKind)>, + pub default_limit: BlockNumber, + pub resolved: Option, +} + +sp_api::decl_runtime_apis! { + pub trait RateLimitingRuntimeApi { + fn get_rate_limit(pallet: Vec, extrinsic: Vec) -> Option; + } +} diff --git a/pallets/rate-limiting/src/benchmarking.rs b/pallets/rate-limiting/src/benchmarking.rs new file mode 100644 index 0000000000..23dfecec85 --- /dev/null +++ b/pallets/rate-limiting/src/benchmarking.rs @@ -0,0 +1,183 @@ +//! Benchmarking setup for pallet-rate-limiting +#![cfg(feature = "runtime-benchmarks")] +#![allow(clippy::arithmetic_side_effects)] + +use codec::Decode; +use frame_benchmarking::v2::*; +use frame_system::{RawOrigin, pallet_prelude::BlockNumberFor}; +use sp_runtime::traits::{One, Saturating}; + +use super::*; + +pub trait BenchmarkHelper { + fn sample_call() -> Call; +} + +impl BenchmarkHelper for () +where + Call: Decode, +{ + fn sample_call() -> Call { + Decode::decode(&mut &[][..]).expect("Provide a call via BenchmarkHelper::sample_call") + } +} + +fn sample_call() -> Box<::RuntimeCall> +where + T::BenchmarkHelper: BenchmarkHelper<::RuntimeCall>, +{ + Box::new(T::BenchmarkHelper::sample_call()) +} + +fn seed_group(name: &[u8], sharing: GroupSharing) -> ::GroupId { + Pallet::::create_group(RawOrigin::Root.into(), name.to_vec(), sharing) + .expect("group created"); + Pallet::::next_group_id().saturating_sub(::GroupId::one()) +} + +fn register_call_with_group( + group: Option<::GroupId>, +) -> TransactionIdentifier { + let call = sample_call::(); + let identifier = TransactionIdentifier::from_call::(call.as_ref()).expect("id"); + Pallet::::register_call(RawOrigin::Root.into(), call, group).expect("registered"); + identifier +} + +#[benchmarks] +mod benchmarks { + use super::*; + use sp_std::vec::Vec; + + #[benchmark] + fn register_call() { + let call = sample_call::(); + let identifier = TransactionIdentifier::from_call::(call.as_ref()).expect("id"); + let target = RateLimitTarget::Transaction(identifier); + + #[extrinsic_call] + _(RawOrigin::Root, call, None); + + assert!(Limits::::contains_key(target)); + } + + #[benchmark] + fn set_rate_limit() { + let call = sample_call::(); + let identifier = TransactionIdentifier::from_call::(call.as_ref()).expect("id"); + let target = RateLimitTarget::Transaction(identifier); + Limits::::insert(target, RateLimit::global(RateLimitKind::Default)); + + let limit = RateLimitKind::>::Exact(BlockNumberFor::::from(10u32)); + + #[extrinsic_call] + _(RawOrigin::Root, target, None, limit); + + let stored = Limits::::get(target).expect("limit stored"); + assert!( + matches!(stored, RateLimit::Global(RateLimitKind::Exact(span)) if span == BlockNumberFor::::from(10u32)) + ); + } + + #[benchmark] + fn assign_call_to_group() { + let group = seed_group::(b"grp", GroupSharing::UsageOnly); + let identifier = register_call_with_group::(None); + + #[extrinsic_call] + _(RawOrigin::Root, identifier, group); + + assert_eq!(CallGroups::::get(identifier), Some(group)); + assert!(GroupMembers::::get(group).contains(&identifier)); + } + + #[benchmark] + fn remove_call_from_group() { + let group = seed_group::(b"team", GroupSharing::ConfigOnly); + let identifier = register_call_with_group::(Some(group)); + + #[extrinsic_call] + _(RawOrigin::Root, identifier); + + assert!(CallGroups::::get(identifier).is_none()); + assert!(!GroupMembers::::get(group).contains(&identifier)); + } + + #[benchmark] + fn create_group() { + let name = b"bench".to_vec(); + let sharing = GroupSharing::ConfigAndUsage; + + #[extrinsic_call] + _(RawOrigin::Root, name.clone(), sharing); + + let group = Pallet::::next_group_id().saturating_sub(::GroupId::one()); + let details = Groups::::get(group).expect("group stored"); + let stored: Vec = details.name.into(); + assert_eq!(stored, name); + assert_eq!(details.sharing, sharing); + } + + #[benchmark] + fn update_group() { + let group = seed_group::(b"old", GroupSharing::UsageOnly); + let new_name = b"new".to_vec(); + let new_sharing = GroupSharing::ConfigAndUsage; + + #[extrinsic_call] + _( + RawOrigin::Root, + group, + Some(new_name.clone()), + Some(new_sharing), + ); + + let details = Groups::::get(group).expect("group exists"); + let stored: Vec = details.name.into(); + assert_eq!(stored, new_name); + assert_eq!(details.sharing, new_sharing); + } + + #[benchmark] + fn delete_group() { + let group = seed_group::(b"delete", GroupSharing::UsageOnly); + + #[extrinsic_call] + _(RawOrigin::Root, group); + + assert!(Groups::::get(group).is_none()); + } + + #[benchmark] + fn deregister_call() { + let group = seed_group::(b"dreg", GroupSharing::ConfigAndUsage); + let identifier = register_call_with_group::(Some(group)); + let target = RateLimitTarget::Transaction(identifier); + let usage_target = Pallet::::usage_target(&identifier).expect("usage target"); + LastSeen::::insert( + usage_target, + None::, + BlockNumberFor::::from(1u32), + ); + + #[extrinsic_call] + _(RawOrigin::Root, identifier, None, true); + + assert!(Limits::::get(target).is_none()); + assert!(LastSeen::::get(usage_target, None::).is_none()); + assert!(CallGroups::::get(identifier).is_none()); + assert!(!GroupMembers::::get(group).contains(&identifier)); + } + + #[benchmark] + fn set_default_rate_limit() { + let block_span = BlockNumberFor::::from(10u32); + + #[extrinsic_call] + _(RawOrigin::Root, block_span); + + assert_eq!(DefaultLimit::::get(), block_span); + } + + impl_benchmark_test_suite!(Pallet, crate::mock::new_test_ext(), crate::mock::Test); +} diff --git a/pallets/rate-limiting/src/lib.rs b/pallets/rate-limiting/src/lib.rs new file mode 100644 index 0000000000..8b7cc1072f --- /dev/null +++ b/pallets/rate-limiting/src/lib.rs @@ -0,0 +1,1202 @@ +#![cfg_attr(not(feature = "std"), no_std)] + +//! Rate limiting for runtime calls with optional contextual restrictions. +//! +//! # Overview +//! +//! `pallet-rate-limiting` lets a runtime restrict how frequently particular calls can execute. +//! Limits are stored on-chain, keyed by explicit [`RateLimitTarget`] values. A target is either a +//! single [`TransactionIdentifier`] (the pallet/extrinsic indices) or a named *group* managed by the +//! admin APIs. Groups provide a way to give multiple calls the same configuration and/or usage +//! tracking without duplicating storage. Each target entry stores either a global span or a set of +//! scoped spans resolved at runtime. The pallet exposes a handful of extrinsics, restricted by +//! [`Config::AdminOrigin`], to manage this data: +//! +//! - [`register_call`](pallet::Pallet::register_call): register a call for rate limiting, seed its +//! initial configuration using [`Config::LimitScopeResolver`], and optionally place it into a +//! group. +//! - [`set_rate_limit`](pallet::Pallet::set_rate_limit): assign or override the limit at a specific +//! target/scope by supplying a [`RateLimitKind`] span. +//! - [`assign_call_to_group`](pallet::Pallet::assign_call_to_group) and +//! [`remove_call_from_group`](pallet::Pallet::remove_call_from_group): manage group membership for +//! registered calls. +//! - [`deregister_call`](pallet::Pallet::deregister_call): remove scoped configuration or wipe the +//! registration entirely. +//! - [`set_default_rate_limit`](pallet::Pallet::set_default_rate_limit): set the global default +//! block span used by `RateLimitKind::Default` entries. +//! +//! The pallet also tracks the last block in which a target was observed, per optional *usage key*. +//! A usage key may refine tracking beyond the limit scope (for example combining a `netuid` with a +//! hyperparameter), so the two concepts are explicitly separated in the configuration. When the +//! admin puts several calls into a group and marks usage as shared, each dispatch still runs the +//! resolver: the group only chooses the storage target, while the resolver output (or `None`) picks +//! the row under that target. Calls that resolve to the same usage key update the same timestamp; +//! calls that resolve to different keys keep isolated timers even when they share a group. The same +//! rule applies to limit scopes—grouping funnels configuration into the same target, but the scope +//! resolver decides whether that entry is global or per-context. +//! +//! Each storage map is namespaced by pallet instance; runtimes can deploy multiple independent +//! instances to manage distinct rate-limiting scopes (in the global sense). +//! +//! # Transaction extension +//! +//! Enforcement happens via [`RateLimitTransactionExtension`], which implements +//! `sp_runtime::traits::TransactionExtension`. The extension consults `Limits`, fetches the current +//! block, and decides whether the call is eligible. If successful, it returns metadata that causes +//! [`LastSeen`](pallet::LastSeen) to update after dispatch. A rejected call yields +//! `InvalidTransaction::Custom(1)`. +//! +//! To enable the extension, add it to your runtime's transaction extension tuple. For example: +//! +//! ```ignore +//! pub type TransactionExtensions = ( +//! // ... other extensions ... +//! pallet_rate_limiting::RateLimitTransactionExtension, +//! ); +//! ``` +//! +//! # Context resolvers +//! +//! The pallet relies on two resolvers: +//! +//! - [`Config::LimitScopeResolver`], which determines how limits are stored (for example by +//! returning a `netuid`). The resolver can also signal that a call should bypass rate limiting or +//! adjust the effective span at validation time. When it returns `None`, the configuration is +//! stored as a global fallback. +//! - [`Config::UsageResolver`], which decides how executions are tracked in +//! [`LastSeen`](pallet::LastSeen). This can refine the limit scope (for example by returning a +//! tuple of `(netuid, hyperparameter)`). +//! +//! Each resolver receives the origin and call and may return `Some(identifier)` when scoping is +//! required, or `None` to use the global entry. Extrinsics such as +//! [`set_rate_limit`](pallet::Pallet::set_rate_limit) automatically consult these resolvers. When a +//! call belongs to a group the pallet still runs the resolver—instead of indexing storage at the +//! transaction-level target, it indexes at the group target. Resolving to different contexts keeps +//! independent limit/usage rows even though the calls share a group; resolving to the same context +//! causes them to share enforcement state. +//! +//! ```ignore +//! pub struct WeightsContextResolver; +//! +//! // Limits are scoped per netuid. +//! pub struct ScopeResolver; +//! impl pallet_rate_limiting::RateLimitScopeResolver< +//! RuntimeOrigin, +//! RuntimeCall, +//! NetUid, +//! BlockNumber, +//! > for ScopeResolver { +//! fn context(origin: &RuntimeOrigin, call: &RuntimeCall) -> Option { +//! match call { +//! RuntimeCall::Subtensor(pallet_subtensor::Call::set_weights { netuid, .. }) => { +//! Some(*netuid) +//! } +//! _ => None, +//! } +//! } +//! +//! fn should_bypass(origin: &RuntimeOrigin, _call: &RuntimeCall) -> bool { +//! matches!(origin, RuntimeOrigin::Root) +//! } +//! +//! fn adjust_span(_origin: &RuntimeOrigin, _call: &RuntimeCall, span: BlockNumber) -> BlockNumber { +//! span +//! } +//! } +//! +//! // Usage tracking distinguishes hyperparameter + netuid. +//! pub struct UsageResolver; +//! impl pallet_rate_limiting::RateLimitUsageResolver< +//! RuntimeOrigin, +//! RuntimeCall, +//! (NetUid, HyperParam), +//! > for UsageResolver { +//! fn context(_origin: &RuntimeOrigin, call: &RuntimeCall) -> Option<(NetUid, HyperParam)> { +//! match call { +//! RuntimeCall::Subtensor(pallet_subtensor::Call::set_hyperparam { +//! netuid, +//! hyper, +//! .. +//! }) => Some((*netuid, *hyper)), +//! _ => None, +//! } +//! } +//! } +//! +//! impl pallet_rate_limiting::Config for Runtime { +//! type RuntimeCall = RuntimeCall; +//! type LimitScope = NetUid; +//! type LimitScopeResolver = ScopeResolver; +//! type UsageKey = (NetUid, HyperParam); +//! type UsageResolver = UsageResolver; +//! type AdminOrigin = frame_system::EnsureRoot; +//! } +//! ``` + +#[cfg(feature = "runtime-benchmarks")] +pub use benchmarking::BenchmarkHelper; +pub use pallet::*; +pub use tx_extension::RateLimitTransactionExtension; +pub use types::{ + GroupSharing, RateLimit, RateLimitGroup, RateLimitKind, RateLimitScopeResolver, + RateLimitTarget, RateLimitUsageResolver, TransactionIdentifier, +}; + +#[cfg(feature = "runtime-benchmarks")] +mod benchmarking; +mod tx_extension; +mod types; + +#[cfg(test)] +mod mock; + +#[cfg(test)] +mod tests; + +#[frame_support::pallet] +pub mod pallet { + use codec::Codec; + use frame_support::{ + BoundedBTreeSet, BoundedVec, + pallet_prelude::*, + traits::{BuildGenesisConfig, EnsureOrigin, GetCallMetadata}, + }; + use frame_system::pallet_prelude::*; + use sp_runtime::traits::{ + AtLeast32BitUnsigned, DispatchOriginOf, Dispatchable, Member, One, Saturating, Zero, + }; + use sp_std::{boxed::Box, convert::TryFrom, marker::PhantomData, vec::Vec}; + + #[cfg(feature = "runtime-benchmarks")] + use crate::benchmarking::BenchmarkHelper as BenchmarkHelperTrait; + use crate::types::{ + GroupSharing, RateLimit, RateLimitGroup, RateLimitKind, RateLimitScopeResolver, + RateLimitTarget, RateLimitUsageResolver, TransactionIdentifier, + }; + + type GroupNameOf = BoundedVec>::MaxGroupNameLength>; + type GroupMembersOf = + BoundedBTreeSet>::MaxGroupMembers>; + type GroupDetailsOf = RateLimitGroup<>::GroupId, GroupNameOf>; + + /// Configuration trait for the rate limiting pallet. + #[pallet::config] + pub trait Config: frame_system::Config + where + BlockNumberFor: MaybeSerializeDeserialize, + <>::RuntimeCall as Dispatchable>::RuntimeOrigin: + From<::RuntimeOrigin>, + { + /// The overarching runtime call type. + type RuntimeCall: Parameter + + Codec + + GetCallMetadata + + Dispatchable + + IsType<::RuntimeCall>; + + /// Origin permitted to configure rate limits. + type AdminOrigin: EnsureOrigin>; + + /// Scope identifier used to namespace stored rate limits. + type LimitScope: Parameter + Clone + PartialEq + Eq + Ord + MaybeSerializeDeserialize; + + /// Resolves the scope for the given runtime call when configuring limits. + type LimitScopeResolver: RateLimitScopeResolver< + DispatchOriginOf<>::RuntimeCall>, + >::RuntimeCall, + Self::LimitScope, + BlockNumberFor, + >; + + /// Usage key tracked in [`LastSeen`] for rate-limited calls. + type UsageKey: Parameter + Clone + PartialEq + Eq + Ord + MaybeSerializeDeserialize; + + /// Resolves the usage key for the given runtime call when enforcing limits. + type UsageResolver: RateLimitUsageResolver< + DispatchOriginOf<>::RuntimeCall>, + >::RuntimeCall, + Self::UsageKey, + >; + + /// Identifier assigned to managed groups. + type GroupId: Parameter + + Member + + Copy + + MaybeSerializeDeserialize + + MaxEncodedLen + + AtLeast32BitUnsigned + + Default; + + /// Maximum number of extrinsics that may belong to a single group. + #[pallet::constant] + type MaxGroupMembers: Get; + + /// Maximum length (in bytes) of a group name. + #[pallet::constant] + type MaxGroupNameLength: Get; + + /// Helper used to construct runtime calls for benchmarking. + #[cfg(feature = "runtime-benchmarks")] + type BenchmarkHelper: BenchmarkHelperTrait<>::RuntimeCall>; + } + + /// Storage mapping from rate limit target to its configured rate limit. + #[pallet::storage] + #[pallet::getter(fn limits)] + pub type Limits, I: 'static = ()> = StorageMap< + _, + Blake2_128Concat, + RateLimitTarget<>::GroupId>, + RateLimit<>::LimitScope, BlockNumberFor>, + OptionQuery, + >; + + /// Tracks when a rate-limited target was last observed per usage key. + #[pallet::storage] + pub type LastSeen, I: 'static = ()> = StorageDoubleMap< + _, + Blake2_128Concat, + RateLimitTarget<>::GroupId>, + Blake2_128Concat, + Option<>::UsageKey>, + BlockNumberFor, + OptionQuery, + >; + + /// Default block span applied when an extrinsic uses the default rate limit. + #[pallet::storage] + #[pallet::getter(fn default_limit)] + pub type DefaultLimit, I: 'static = ()> = + StorageValue<_, BlockNumberFor, ValueQuery>; + + /// Maps a transaction identifier to its assigned group. + #[pallet::storage] + #[pallet::getter(fn call_group)] + pub type CallGroups, I: 'static = ()> = StorageMap< + _, + Blake2_128Concat, + TransactionIdentifier, + >::GroupId, + OptionQuery, + >; + + /// Metadata for each configured group. + #[pallet::storage] + #[pallet::getter(fn groups)] + pub type Groups, I: 'static = ()> = StorageMap< + _, + Blake2_128Concat, + >::GroupId, + GroupDetailsOf, + OptionQuery, + >; + + /// Tracks membership for each group. + #[pallet::storage] + #[pallet::getter(fn group_members)] + pub type GroupMembers, I: 'static = ()> = StorageMap< + _, + Blake2_128Concat, + >::GroupId, + GroupMembersOf, + ValueQuery, + >; + + /// Enforces unique group names. + #[pallet::storage] + #[pallet::getter(fn group_id_by_name)] + pub type GroupNameIndex, I: 'static = ()> = + StorageMap<_, Blake2_128Concat, GroupNameOf, >::GroupId, OptionQuery>; + + /// Identifier used for the next group creation. + #[pallet::storage] + #[pallet::getter(fn next_group_id)] + pub type NextGroupId, I: 'static = ()> = + StorageValue<_, >::GroupId, ValueQuery>; + + /// Events emitted by the rate limiting pallet. + #[pallet::event] + #[pallet::generate_deposit(pub(super) fn deposit_event)] + pub enum Event, I: 'static = ()> { + /// A call was registered for rate limiting. + CallRegistered { + /// Identifier of the registered transaction. + transaction: TransactionIdentifier, + /// Scope seeded during registration (if any). + scope: Option<>::LimitScope>, + /// Optional group assignment applied at registration time. + group: Option<>::GroupId>, + /// Pallet name associated with the transaction. + pallet: Vec, + /// Extrinsic name associated with the transaction. + extrinsic: Vec, + }, + /// A rate limit was set or updated for the specified target. + RateLimitSet { + /// Target whose configuration changed. + target: RateLimitTarget<>::GroupId>, + /// Identifier of the transaction when the target represents a call. + transaction: Option, + /// Limit scope to which the configuration applies, if any. + scope: Option<>::LimitScope>, + /// The rate limit policy applied to the target. + limit: RateLimitKind>, + /// Pallet name associated with the transaction, when available. + pallet: Option>, + /// Extrinsic name associated with the transaction, when available. + extrinsic: Option>, + }, + /// A rate-limited call was deregistered or had a scoped entry cleared. + CallDeregistered { + /// Target whose configuration changed. + target: RateLimitTarget<>::GroupId>, + /// Identifier of the transaction when the target represents a call. + transaction: Option, + /// Limit scope from which the configuration was cleared, if any. + scope: Option<>::LimitScope>, + /// Pallet name associated with the transaction, when available. + pallet: Option>, + /// Extrinsic name associated with the transaction, when available. + extrinsic: Option>, + }, + /// The default rate limit was set or updated. + DefaultRateLimitSet { + /// The new default limit expressed in blocks. + block_span: BlockNumberFor, + }, + /// A group was created. + GroupCreated { + /// Identifier of the new group. + group: >::GroupId, + /// Human readable group name. + name: Vec, + /// Sharing policy configured for the group. + sharing: GroupSharing, + }, + /// A group's metadata or policy changed. + GroupUpdated { + /// Identifier of the group. + group: >::GroupId, + /// Human readable name. + name: Vec, + /// Updated sharing configuration. + sharing: GroupSharing, + }, + /// A group was deleted. + GroupDeleted { + /// Identifier of the removed group. + group: >::GroupId, + }, + /// A transaction was assigned to or removed from a group. + CallGroupUpdated { + /// Identifier of the transaction. + transaction: TransactionIdentifier, + /// Updated group assignment (None when cleared). + group: Option<>::GroupId>, + }, + } + + /// Errors that can occur while configuring rate limits. + #[pallet::error] + pub enum Error { + /// Failed to extract the pallet and extrinsic indices from the call. + InvalidRuntimeCall, + /// Attempted to remove a limit that is not present. + MissingRateLimit, + /// Group metadata was not found. + UnknownGroup, + /// Attempted to create or rename a group to an existing name. + DuplicateGroupName, + /// Group name exceeds the configured maximum length. + GroupNameTooLong, + /// Operation requires the group to have no members. + GroupHasMembers, + /// Adding a member would exceed the configured limit. + GroupMemberLimitExceeded, + /// Call already belongs to the requested group. + CallAlreadyInGroup, + /// Call is not assigned to a group. + CallNotInGroup, + /// Operation requires the call to be registered first. + CallNotRegistered, + /// Attempted to register a call that already exists. + CallAlreadyRegistered, + /// Rate limit for this call must be configured via its group target. + MustTargetGroup, + /// Resolver failed to supply a required context value. + MissingScope, + /// Group cannot be removed because configuration or usage entries remain. + GroupInUse, + } + + #[pallet::genesis_config] + pub struct GenesisConfig, I: 'static = ()> { + pub default_limit: BlockNumberFor, + pub limits: Vec<( + RateLimitTarget<>::GroupId>, + Option<>::LimitScope>, + RateLimitKind>, + )>, + pub groups: Vec<(>::GroupId, Vec, GroupSharing)>, + } + + #[cfg(feature = "std")] + impl, I: 'static> Default for GenesisConfig { + fn default() -> Self { + Self { + default_limit: Zero::zero(), + limits: Vec::new(), + groups: Vec::new(), + } + } + } + + #[pallet::genesis_build] + impl, I: 'static> BuildGenesisConfig for GenesisConfig { + fn build(&self) { + DefaultLimit::::put(self.default_limit); + + // Seed groups first so limit targets can reference them. + let mut max_group: >::GroupId = Zero::zero(); + for (group_id, name, sharing) in &self.groups { + let bounded = GroupNameOf::::try_from(name.clone()) + .expect("Genesis group name exceeds MaxGroupNameLength"); + + assert!( + !Groups::::contains_key(group_id), + "Duplicate group id in genesis config" + ); + assert!( + !GroupNameIndex::::contains_key(&bounded), + "Duplicate group name in genesis config" + ); + + Groups::::insert( + group_id, + RateLimitGroup { + id: *group_id, + name: bounded.clone(), + sharing: *sharing, + }, + ); + GroupNameIndex::::insert(&bounded, *group_id); + GroupMembers::::insert(*group_id, GroupMembersOf::::new()); + if *group_id > max_group { + max_group = *group_id; + } + } + let next = max_group.saturating_add(One::one()); + NextGroupId::::put(next); + + for (identifier, scope, kind) in &self.limits { + if let RateLimitTarget::Group(group) = identifier { + assert!( + Groups::::contains_key(group), + "Genesis limit references unknown group" + ); + } + let target = *identifier; + Limits::::mutate(target, |entry| match scope { + None => { + *entry = Some(RateLimit::global(*kind)); + } + Some(sc) => { + if let Some(config) = entry { + config.upsert_scope(sc.clone(), *kind); + } else { + *entry = Some(RateLimit::scoped_single(sc.clone(), *kind)); + } + } + }); + } + } + } + + #[pallet::pallet] + #[pallet::without_storage_info] + pub struct Pallet(PhantomData<(T, I)>); + + impl, I: 'static> Pallet { + /// Returns `true` when the given transaction identifier passes its configured rate limit + /// within the provided usage scope. + pub fn is_within_limit( + origin: &DispatchOriginOf<>::RuntimeCall>, + call: &>::RuntimeCall, + identifier: &TransactionIdentifier, + scope: &Option<>::LimitScope>, + usage_key: &Option<>::UsageKey>, + ) -> Result { + if >::LimitScopeResolver::should_bypass(origin, call) { + return Ok(true); + } + + let target = Self::config_target(identifier)?; + Self::ensure_scope_available(&target, scope)?; + + let Some(block_span) = Self::effective_span(origin, call, &target, scope) else { + return Ok(true); + }; + + let usage_target = Self::usage_target(identifier)?; + Ok(Self::within_span(&usage_target, usage_key, block_span)) + } + + pub(crate) fn resolved_limit( + target: &RateLimitTarget<>::GroupId>, + scope: &Option<>::LimitScope>, + ) -> Option> { + let config = Limits::::get(target)?; + let kind = config.kind_for(scope.as_ref())?; + Some(match *kind { + RateLimitKind::Default => DefaultLimit::::get(), + RateLimitKind::Exact(block_span) => block_span, + }) + } + + pub(crate) fn effective_span( + origin: &DispatchOriginOf<>::RuntimeCall>, + call: &>::RuntimeCall, + target: &RateLimitTarget<>::GroupId>, + scope: &Option<>::LimitScope>, + ) -> Option> { + let span = Self::resolved_limit(target, scope)?; + Some(>::LimitScopeResolver::adjust_span( + origin, call, span, + )) + } + + pub(crate) fn within_span( + target: &RateLimitTarget<>::GroupId>, + usage_key: &Option<>::UsageKey>, + block_span: BlockNumberFor, + ) -> bool { + if block_span.is_zero() { + return true; + } + + if let Some(last) = LastSeen::::get(target, usage_key.clone()) { + let current = frame_system::Pallet::::block_number(); + let delta = current.saturating_sub(last); + if delta < block_span { + return false; + } + } + + true + } + + /// Inserts or updates the cached usage timestamp for a rate-limited call. + /// + /// This is primarily intended for migrations that need to hydrate the new tracking storage + /// from legacy pallets. + pub fn record_last_seen( + target: RateLimitTarget<>::GroupId>, + usage_key: Option<>::UsageKey>, + block_number: BlockNumberFor, + ) { + LastSeen::::insert(target, usage_key, block_number); + } + + /// Migrates a stored rate limit configuration from one scope to another. + /// + /// Returns `true` when an entry was moved. Passing identical `from`/`to` scopes simply + /// checks that a configuration exists. + pub fn migrate_limit_scope( + target: RateLimitTarget<>::GroupId>, + from: Option<>::LimitScope>, + to: Option<>::LimitScope>, + ) -> bool { + if from == to { + return Limits::::contains_key(target); + } + + let mut migrated = false; + Limits::::mutate(target, |maybe_config| { + if let Some(config) = maybe_config { + match (from.as_ref(), to.as_ref()) { + (None, Some(target)) => { + if let RateLimit::Global(kind) = config { + *config = RateLimit::scoped_single(target.clone(), *kind); + migrated = true; + } + } + (Some(source), Some(target)) => { + if let RateLimit::Scoped(map) = config { + if let Some(kind) = map.remove(source) { + map.insert(target.clone(), kind); + migrated = true; + } + } + } + (Some(source), None) => { + if let RateLimit::Scoped(map) = config { + if map.len() == 1 && map.contains_key(source) { + if let Some(kind) = map.remove(source) { + *config = RateLimit::global(kind); + migrated = true; + } + } + } + } + _ => {} + } + } + }); + + migrated + } + + /// Migrates the cached usage information for a rate-limited call to a new key. + /// + /// Returns `true` when an entry was moved. Passing identical keys simply checks that an + /// entry exists. + pub fn migrate_usage_key( + target: RateLimitTarget<>::GroupId>, + from: Option<>::UsageKey>, + to: Option<>::UsageKey>, + ) -> bool { + if from == to { + return LastSeen::::contains_key(target, to); + } + + let Some(block) = LastSeen::::take(target, from) else { + return false; + }; + + LastSeen::::insert(target, to, block); + true + } + + /// Returns the configured limit for the specified pallet/extrinsic names, if any. + pub fn limit_for_call_names( + pallet_name: &str, + extrinsic_name: &str, + scope: Option<>::LimitScope>, + ) -> Option>> { + let identifier = Self::identifier_for_call_names(pallet_name, extrinsic_name)?; + let target = Self::config_target(&identifier).ok()?; + Limits::::get(target).and_then(|config| config.kind_for(scope.as_ref()).copied()) + } + + /// Returns the resolved block span for the specified pallet/extrinsic names, if any. + pub fn resolved_limit_for_call_names( + pallet_name: &str, + extrinsic_name: &str, + scope: Option<>::LimitScope>, + ) -> Option> { + let identifier = Self::identifier_for_call_names(pallet_name, extrinsic_name)?; + let target = Self::config_target(&identifier).ok()?; + Self::resolved_limit(&target, &scope) + } + + fn identifier_for_call_names( + pallet_name: &str, + extrinsic_name: &str, + ) -> Option { + let modules = >::RuntimeCall::get_module_names(); + let pallet_pos = modules.iter().position(|name| *name == pallet_name)?; + let call_names = >::RuntimeCall::get_call_names(pallet_name); + let extrinsic_pos = call_names.iter().position(|name| *name == extrinsic_name)?; + let pallet_index = u8::try_from(pallet_pos).ok()?; + let extrinsic_index = u8::try_from(extrinsic_pos).ok()?; + Some(TransactionIdentifier::new(pallet_index, extrinsic_index)) + } + + fn ensure_call_registered(identifier: &TransactionIdentifier) -> DispatchResult { + let target = RateLimitTarget::Transaction(*identifier); + ensure!( + Limits::::contains_key(target), + Error::::CallNotRegistered + ); + Ok(()) + } + + fn ensure_call_unregistered(identifier: &TransactionIdentifier) -> DispatchResult { + let target = RateLimitTarget::Transaction(*identifier); + ensure!( + !Limits::::contains_key(target), + Error::::CallAlreadyRegistered + ); + Ok(()) + } + + fn call_metadata( + identifier: &TransactionIdentifier, + ) -> Result<(Vec, Vec), DispatchError> { + let (pallet_name, extrinsic_name) = identifier.names::()?; + Ok(( + Vec::from(pallet_name.as_bytes()), + Vec::from(extrinsic_name.as_bytes()), + )) + } + + pub(crate) fn config_target( + identifier: &TransactionIdentifier, + ) -> Result>::GroupId>, DispatchError> { + Self::target_for(identifier, GroupSharing::config_uses_group) + } + + pub(crate) fn usage_target( + identifier: &TransactionIdentifier, + ) -> Result>::GroupId>, DispatchError> { + Self::target_for(identifier, GroupSharing::usage_uses_group) + } + + fn target_for( + identifier: &TransactionIdentifier, + predicate: impl Fn(GroupSharing) -> bool, + ) -> Result>::GroupId>, DispatchError> { + let group = Self::group_assignment(identifier)?; + Ok(Self::target_from_details( + identifier, + group.as_ref(), + predicate, + )) + } + + fn group_assignment( + identifier: &TransactionIdentifier, + ) -> Result>, DispatchError> { + let Some(group) = CallGroups::::get(identifier) else { + return Ok(None); + }; + let details = Self::ensure_group_details(group)?; + Ok(Some(details)) + } + + fn target_from_details( + identifier: &TransactionIdentifier, + details: Option<&GroupDetailsOf>, + predicate: impl Fn(GroupSharing) -> bool, + ) -> RateLimitTarget<>::GroupId> { + if let Some(details) = details { + if predicate(details.sharing) { + return RateLimitTarget::Group(details.id); + } + } + RateLimitTarget::Transaction(*identifier) + } + + fn ensure_group_details( + group: >::GroupId, + ) -> Result, DispatchError> { + Groups::::get(group).ok_or(Error::::UnknownGroup.into()) + } + + fn ensure_scope_available( + target: &RateLimitTarget<>::GroupId>, + scope: &Option<>::LimitScope>, + ) -> Result<(), DispatchError> { + if scope.is_some() { + return Ok(()); + } + + if let Some(RateLimit::Scoped(map)) = Limits::::get(target) { + if !map.is_empty() { + return Err(Error::::MissingScope.into()); + } + } + + Ok(()) + } + + fn bounded_group_name(name: Vec) -> Result, DispatchError> { + GroupNameOf::::try_from(name).map_err(|_| Error::::GroupNameTooLong.into()) + } + + fn ensure_group_name_available( + name: &GroupNameOf, + current: Option<>::GroupId>, + ) -> DispatchResult { + if let Some(existing) = GroupNameIndex::::get(name) { + ensure!(Some(existing) == current, Error::::DuplicateGroupName); + } + Ok(()) + } + + fn ensure_group_deletable(group: >::GroupId) -> DispatchResult { + ensure!( + GroupMembers::::get(group).is_empty(), + Error::::GroupHasMembers + ); + let target = RateLimitTarget::Group(group); + ensure!( + !Limits::::contains_key(target), + Error::::GroupInUse + ); + ensure!( + LastSeen::::iter_prefix(target).next().is_none(), + Error::::GroupInUse + ); + Ok(()) + } + + fn insert_call_into_group( + identifier: &TransactionIdentifier, + group: >::GroupId, + ) -> DispatchResult { + GroupMembers::::try_mutate(group, |members| -> DispatchResult { + match members.try_insert(*identifier) { + Ok(true) => Ok(()), + Ok(false) => Err(Error::::CallAlreadyInGroup.into()), + Err(_) => Err(Error::::GroupMemberLimitExceeded.into()), + } + })?; + Ok(()) + } + + fn detach_call_from_group( + identifier: &TransactionIdentifier, + group: >::GroupId, + ) -> bool { + GroupMembers::::mutate(group, |members| members.remove(identifier)) + } + } + + #[pallet::call] + impl, I: 'static> Pallet { + /// Registers a call for rate limiting and seeds its initial configuration. + #[pallet::call_index(0)] + #[pallet::weight(T::DbWeight::get().reads_writes(3, 3))] + pub fn register_call( + origin: OriginFor, + call: Box<>::RuntimeCall>, + group: Option<>::GroupId>, + ) -> DispatchResult { + let resolver_origin: DispatchOriginOf<>::RuntimeCall> = + Into::>::RuntimeCall>>::into(origin.clone()); + let scope = + >::LimitScopeResolver::context(&resolver_origin, call.as_ref()); + + T::AdminOrigin::ensure_origin(origin)?; + + let identifier = TransactionIdentifier::from_call::(call.as_ref())?; + Self::ensure_call_unregistered(&identifier)?; + + let target = RateLimitTarget::Transaction(identifier); + + if let Some(ref sc) = scope { + Limits::::insert( + target, + RateLimit::scoped_single(sc.clone(), RateLimitKind::Default), + ); + } else { + Limits::::insert(target, RateLimit::global(RateLimitKind::Default)); + } + + let mut assigned_group = None; + if let Some(group_id) = group { + Self::ensure_group_details(group_id)?; + Self::insert_call_into_group(&identifier, group_id)?; + CallGroups::::insert(&identifier, group_id); + assigned_group = Some(group_id); + } + + let (pallet, extrinsic) = Self::call_metadata(&identifier)?; + Self::deposit_event(Event::CallRegistered { + transaction: identifier, + scope: scope.clone(), + group: assigned_group, + pallet: pallet.clone(), + extrinsic: extrinsic.clone(), + }); + + if let Some(group_id) = assigned_group { + Self::deposit_event(Event::CallGroupUpdated { + transaction: identifier, + group: Some(group_id), + }); + } + + Ok(()) + } + + /// Configures a rate limit for either a transaction or group target. + #[pallet::call_index(1)] + #[pallet::weight(T::DbWeight::get().reads_writes(2, 2))] + pub fn set_rate_limit( + origin: OriginFor, + target: RateLimitTarget<>::GroupId>, + scope: Option<>::LimitScope>, + limit: RateLimitKind>, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + let (transaction, pallet, extrinsic) = match target { + RateLimitTarget::Transaction(identifier) => { + Self::ensure_call_registered(&identifier)?; + if let Some(group) = CallGroups::::get(&identifier) { + let details = Self::ensure_group_details(group)?; + ensure!( + !details.sharing.config_uses_group(), + Error::::MustTargetGroup + ); + } + let (pallet, extrinsic) = Self::call_metadata(&identifier)?; + (Some(identifier), Some(pallet), Some(extrinsic)) + } + RateLimitTarget::Group(group) => { + Self::ensure_group_details(group)?; + (None, None, None) + } + }; + + if let Some(ref scoped) = scope { + Limits::::mutate(target, |slot| match slot { + Some(config) => config.upsert_scope(scoped.clone(), limit), + None => *slot = Some(RateLimit::scoped_single(scoped.clone(), limit)), + }); + } else { + Limits::::insert(target, RateLimit::global(limit)); + } + + Self::deposit_event(Event::RateLimitSet { + target, + transaction, + scope, + limit, + pallet, + extrinsic, + }); + Ok(()) + } + + /// Assigns a registered call to the specified group. + #[pallet::call_index(2)] + #[pallet::weight(T::DbWeight::get().reads_writes(3, 3))] + pub fn assign_call_to_group( + origin: OriginFor, + transaction: TransactionIdentifier, + group: >::GroupId, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + Self::ensure_call_registered(&transaction)?; + Self::ensure_group_details(group)?; + + let current = CallGroups::::get(&transaction); + if current == Some(group) { + return Err(Error::::CallAlreadyInGroup.into()); + } + + Self::insert_call_into_group(&transaction, group)?; + if let Some(existing) = current { + Self::detach_call_from_group(&transaction, existing); + } + CallGroups::::insert(&transaction, group); + + Self::deposit_event(Event::CallGroupUpdated { + transaction, + group: Some(group), + }); + + Ok(()) + } + + /// Removes a registered call from its current group assignment. + #[pallet::call_index(3)] + #[pallet::weight(T::DbWeight::get().reads_writes(2, 2))] + pub fn remove_call_from_group( + origin: OriginFor, + transaction: TransactionIdentifier, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + Self::ensure_call_registered(&transaction)?; + let Some(group) = CallGroups::::take(&transaction) else { + return Err(Error::::CallNotInGroup.into()); + }; + Self::detach_call_from_group(&transaction, group); + + Self::deposit_event(Event::CallGroupUpdated { + transaction, + group: None, + }); + + Ok(()) + } + + /// Sets the default rate limit that applies when an extrinsic uses [`RateLimitKind::Default`]. + #[pallet::call_index(4)] + #[pallet::weight(T::DbWeight::get().writes(1))] + pub fn set_default_rate_limit( + origin: OriginFor, + block_span: BlockNumberFor, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + DefaultLimit::::put(block_span); + Self::deposit_event(Event::DefaultRateLimitSet { block_span }); + Ok(()) + } + + /// Creates a new rate-limiting group with the provided name and sharing configuration. + #[pallet::call_index(5)] + #[pallet::weight(T::DbWeight::get().reads_writes(1, 3))] + pub fn create_group( + origin: OriginFor, + name: Vec, + sharing: GroupSharing, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + let bounded = Self::bounded_group_name(name)?; + Self::ensure_group_name_available(&bounded, None)?; + + let group = NextGroupId::::mutate(|current| { + let next = current.saturating_add(One::one()); + sp_std::mem::replace(current, next) + }); + + Groups::::insert( + group, + RateLimitGroup { + id: group, + name: bounded.clone(), + sharing, + }, + ); + GroupNameIndex::::insert(&bounded, group); + GroupMembers::::insert(group, GroupMembersOf::::new()); + + let name_bytes: Vec = bounded.into(); + Self::deposit_event(Event::GroupCreated { + group, + name: name_bytes, + sharing, + }); + Ok(()) + } + + /// Updates the metadata or sharing configuration of an existing group. + #[pallet::call_index(6)] + #[pallet::weight(T::DbWeight::get().reads_writes(3, 3))] + pub fn update_group( + origin: OriginFor, + group: >::GroupId, + name: Option>, + sharing: Option, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + Groups::::try_mutate(group, |maybe_details| -> DispatchResult { + let details = maybe_details.as_mut().ok_or(Error::::UnknownGroup)?; + + if let Some(new_name) = name { + let bounded = Self::bounded_group_name(new_name)?; + Self::ensure_group_name_available(&bounded, Some(group))?; + GroupNameIndex::::remove(&details.name); + GroupNameIndex::::insert(&bounded, group); + details.name = bounded; + } + + if let Some(new_sharing) = sharing { + details.sharing = new_sharing; + } + + Ok(()) + })?; + + let updated = Self::ensure_group_details(group)?; + let name_bytes: Vec = updated.name.clone().into(); + Self::deposit_event(Event::GroupUpdated { + group, + name: name_bytes, + sharing: updated.sharing, + }); + + Ok(()) + } + + /// Deletes an existing group. The group must be empty and unused. + #[pallet::call_index(7)] + #[pallet::weight(T::DbWeight::get().reads_writes(3, 3))] + pub fn delete_group( + origin: OriginFor, + group: >::GroupId, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + Self::ensure_group_deletable(group)?; + + let details = Groups::::take(group).ok_or(Error::::UnknownGroup)?; + GroupNameIndex::::remove(&details.name); + GroupMembers::::remove(group); + + Self::deposit_event(Event::GroupDeleted { group }); + + Ok(()) + } + + /// Deregisters a call or removes a scoped entry from its configuration. + #[pallet::call_index(8)] + #[pallet::weight(T::DbWeight::get().reads_writes(4, 4))] + pub fn deregister_call( + origin: OriginFor, + transaction: TransactionIdentifier, + scope: Option<>::LimitScope>, + clear_usage: bool, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + Self::ensure_call_registered(&transaction)?; + let target = Self::config_target(&transaction)?; + let tx_target = RateLimitTarget::Transaction(transaction); + let usage_target = Self::usage_target(&transaction)?; + + match &scope { + Some(sc) => { + let mut removed = false; + Limits::::mutate_exists(target, |maybe_config| { + if let Some(RateLimit::Scoped(map)) = maybe_config { + if map.remove(sc).is_some() { + removed = true; + if map.is_empty() { + *maybe_config = None; + } + } + } + }); + ensure!(removed, Error::::MissingRateLimit); + + if let Some(group) = CallGroups::::take(&transaction) { + Self::detach_call_from_group(&transaction, group); + Self::deposit_event(Event::CallGroupUpdated { + transaction, + group: None, + }); + } + } + None => { + Limits::::remove(target); + if target != tx_target { + Limits::::remove(tx_target); + } + + if let Some(group) = CallGroups::::take(&transaction) { + Self::detach_call_from_group(&transaction, group); + Self::deposit_event(Event::CallGroupUpdated { + transaction, + group: None, + }); + } + } + } + + if clear_usage { + let _ = LastSeen::::clear_prefix(&usage_target, u32::MAX, None); + } + + let (pallet, extrinsic) = Self::call_metadata(&transaction)?; + Self::deposit_event(Event::CallDeregistered { + target, + transaction: Some(transaction), + scope, + pallet: Some(pallet), + extrinsic: Some(extrinsic), + }); + + Ok(()) + } + } +} diff --git a/pallets/rate-limiting/src/mock.rs b/pallets/rate-limiting/src/mock.rs new file mode 100644 index 0000000000..b643dec64d --- /dev/null +++ b/pallets/rate-limiting/src/mock.rs @@ -0,0 +1,153 @@ +#![allow(dead_code)] + +use core::convert::TryInto; + +use frame_support::{ + derive_impl, + sp_runtime::{ + BuildStorage, + traits::{BlakeTwo256, IdentityLookup}, + }, + traits::{ConstU16, ConstU32, ConstU64, Everything}, +}; +use frame_system::EnsureRoot; +use sp_core::H256; +use sp_io::TestExternalities; +use sp_std::vec::Vec; + +use crate as pallet_rate_limiting; +use crate::TransactionIdentifier; + +pub type UncheckedExtrinsic = frame_system::mocking::MockUncheckedExtrinsic; +pub type Block = frame_system::mocking::MockBlock; + +frame_support::construct_runtime!( + pub enum Test { + System: frame_system, + RateLimiting: pallet_rate_limiting, + } +); + +#[derive_impl(frame_system::config_preludes::TestDefaultConfig)] +impl frame_system::Config for Test { + type BaseCallFilter = Everything; + type BlockWeights = (); + type BlockLength = (); + type DbWeight = (); + type RuntimeOrigin = RuntimeOrigin; + type RuntimeCall = RuntimeCall; + type Nonce = u64; + type Hash = H256; + type Hashing = BlakeTwo256; + type AccountId = u64; + type Lookup = IdentityLookup; + type RuntimeEvent = RuntimeEvent; + type BlockHashCount = ConstU64<250>; + type Version = (); + type PalletInfo = PalletInfo; + type AccountData = (); + type OnNewAccount = (); + type OnKilledAccount = (); + type SystemWeightInfo = (); + type SS58Prefix = ConstU16<42>; + type OnSetCode = (); + type MaxConsumers = ConstU32<16>; + type Block = Block; +} + +pub type LimitScope = u16; +pub type UsageKey = u16; +pub type GroupId = u32; + +pub struct TestScopeResolver; +pub struct TestUsageResolver; + +impl pallet_rate_limiting::RateLimitScopeResolver + for TestScopeResolver +{ + fn context(_origin: &RuntimeOrigin, call: &RuntimeCall) -> Option { + match call { + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span }) => { + (*block_span).try_into().ok() + } + RuntimeCall::RateLimiting(_) => Some(1), + _ => None, + } + } + + fn should_bypass(_origin: &RuntimeOrigin, call: &RuntimeCall) -> bool { + matches!( + call, + RuntimeCall::RateLimiting(RateLimitingCall::remove_call_from_group { .. }) + ) + } + + fn adjust_span(_origin: &RuntimeOrigin, call: &RuntimeCall, span: u64) -> u64 { + if matches!( + call, + RuntimeCall::RateLimiting(RateLimitingCall::deregister_call { .. }) + ) { + span.saturating_mul(2) + } else { + span + } + } +} + +impl pallet_rate_limiting::RateLimitUsageResolver + for TestUsageResolver +{ + fn context(_origin: &RuntimeOrigin, call: &RuntimeCall) -> Option { + match call { + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span }) => { + (*block_span).try_into().ok() + } + RuntimeCall::RateLimiting(_) => Some(1), + _ => None, + } + } +} + +impl pallet_rate_limiting::Config for Test { + type RuntimeCall = RuntimeCall; + type LimitScope = LimitScope; + type LimitScopeResolver = TestScopeResolver; + type UsageKey = UsageKey; + type UsageResolver = TestUsageResolver; + type AdminOrigin = EnsureRoot; + type GroupId = GroupId; + type MaxGroupMembers = ConstU32<32>; + type MaxGroupNameLength = ConstU32<64>; + #[cfg(feature = "runtime-benchmarks")] + type BenchmarkHelper = BenchHelper; +} + +#[cfg(feature = "runtime-benchmarks")] +pub struct BenchHelper; + +#[cfg(feature = "runtime-benchmarks")] +impl crate::BenchmarkHelper for BenchHelper { + fn sample_call() -> RuntimeCall { + RuntimeCall::System(frame_system::Call::remark { remark: Vec::new() }) + } +} + +pub type RateLimitingCall = crate::Call; + +pub fn new_test_ext() -> TestExternalities { + let storage = frame_system::GenesisConfig::::default() + .build_storage() + .expect("genesis build succeeds"); + + let mut ext = TestExternalities::new(storage); + ext.execute_with(|| System::set_block_number(1)); + ext +} + +pub(crate) fn identifier_for(call: &RuntimeCall) -> TransactionIdentifier { + TransactionIdentifier::from_call::(call).expect("identifier for call") +} + +pub(crate) fn pop_last_event() -> RuntimeEvent { + System::events().pop().expect("event expected").event +} diff --git a/pallets/rate-limiting/src/tests.rs b/pallets/rate-limiting/src/tests.rs new file mode 100644 index 0000000000..5027909b67 --- /dev/null +++ b/pallets/rate-limiting/src/tests.rs @@ -0,0 +1,656 @@ +use frame_support::{assert_noop, assert_ok}; +use sp_std::vec::Vec; + +use crate::{ + CallGroups, Config, GroupMembers, GroupSharing, LastSeen, Limits, RateLimit, RateLimitKind, + RateLimitTarget, TransactionIdentifier, mock::*, pallet::Error, +}; +use frame_support::traits::Get; + +fn target(identifier: TransactionIdentifier) -> RateLimitTarget { + RateLimitTarget::Transaction(identifier) +} + +fn remark_call() -> RuntimeCall { + RuntimeCall::System(frame_system::Call::::remark { remark: Vec::new() }) +} + +fn scoped_call() -> RuntimeCall { + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 1 }) +} + +fn register(call: RuntimeCall, group: Option) -> TransactionIdentifier { + let identifier = identifier_for(&call); + assert_ok!(RateLimiting::register_call( + RuntimeOrigin::root(), + Box::new(call), + group + )); + identifier +} + +fn create_group(name: &[u8], sharing: GroupSharing) -> GroupId { + assert_ok!(RateLimiting::create_group( + RuntimeOrigin::root(), + name.to_vec(), + sharing, + )); + RateLimiting::next_group_id().saturating_sub(1) +} + +fn last_event() -> RuntimeEvent { + pop_last_event() +} + +#[test] +fn register_call_seeds_global_limit() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let tx_target = target(identifier); + let stored = Limits::::get(tx_target).expect("limit"); + assert!(matches!(stored, RateLimit::Global(RateLimitKind::Default))); + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::CallRegistered { transaction, .. }) + if transaction == identifier + )); + }); +} + +#[test] +fn register_call_seeds_scoped_limit() { + new_test_ext().execute_with(|| { + let identifier = register(scoped_call(), None); + let tx_target = target(identifier); + let stored = Limits::::get(tx_target).expect("limit"); + match stored { + RateLimit::Scoped(map) => { + assert_eq!(map.get(&1u16), Some(&RateLimitKind::Default)); + } + _ => panic!("expected scoped entry"), + } + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::CallRegistered { transaction, scope, .. }) + if transaction == identifier && scope == Some(1u16) + )); + }); +} + +#[test] +fn set_rate_limit_updates_transaction_target() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let tx_target = target(identifier); + let limit = RateLimitKind::Exact(9); + assert_ok!(RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + tx_target, + None, + limit, + )); + let stored = Limits::::get(tx_target).expect("limit"); + assert!(matches!(stored, RateLimit::Global(RateLimitKind::Exact(9)))); + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::RateLimitSet { + target: RateLimitTarget::Transaction(t), + limit: RateLimitKind::Exact(9), + .. + }) if t == identifier + )); + }); +} + +#[test] +fn set_rate_limit_requires_registration_and_group_targeting() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let target = target(identifier); + + // Unregistered call. + let unknown = TransactionIdentifier::new(99, 0); + assert_noop!( + RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + RateLimitTarget::Transaction(unknown), + None, + RateLimitKind::Exact(1), + ), + Error::::CallNotRegistered + ); + + // Group requires targeting the group. + let group = create_group(b"cfg", GroupSharing::ConfigAndUsage); + assert_ok!(RateLimiting::assign_call_to_group( + RuntimeOrigin::root(), + identifier, + group, + )); + assert_noop!( + RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + target, + None, + RateLimitKind::Exact(2), + ), + Error::::MustTargetGroup + ); + }); +} + +#[test] +fn set_rate_limit_respects_group_config_sharing() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let group = create_group(b"test", GroupSharing::ConfigAndUsage); + assert_ok!(RateLimiting::assign_call_to_group( + RuntimeOrigin::root(), + identifier, + group, + )); + assert_noop!( + RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + RateLimitTarget::Transaction(identifier), + None, + RateLimitKind::Exact(5), + ), + Error::::MustTargetGroup + ); + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::CallGroupUpdated { + transaction, + group: Some(g), + }) if transaction == identifier && g == group + )); + }); +} + +#[test] +fn assign_and_remove_group_membership() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let group = create_group(b"team", GroupSharing::UsageOnly); + assert_ok!(RateLimiting::assign_call_to_group( + RuntimeOrigin::root(), + identifier, + group, + )); + assert_eq!(CallGroups::::get(identifier), Some(group)); + assert!(GroupMembers::::get(group).contains(&identifier)); + assert_ok!(RateLimiting::remove_call_from_group( + RuntimeOrigin::root(), + identifier, + )); + assert!(CallGroups::::get(identifier).is_none()); + + // Last event should signal removal. + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::CallGroupUpdated { transaction, group: None }) + if transaction == identifier + )); + }); +} + +#[test] +fn set_rate_limit_on_group_updates_storage() { + new_test_ext().execute_with(|| { + let group = create_group(b"grp", GroupSharing::ConfigOnly); + let target = RateLimitTarget::Group(group); + assert_ok!(RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + target, + None, + RateLimitKind::Exact(3), + )); + assert!(matches!( + Limits::::get(target), + Some(RateLimit::Global(RateLimitKind::Exact(3))) + )); + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::RateLimitSet { + target: RateLimitTarget::Group(g), + limit: RateLimitKind::Exact(3), + .. + }) if g == group + )); + }); +} + +#[test] +fn create_and_delete_group_emit_events() { + new_test_ext().execute_with(|| { + assert_ok!(RateLimiting::create_group( + RuntimeOrigin::root(), + b"ev".to_vec(), + GroupSharing::UsageOnly, + )); + let group = RateLimiting::next_group_id().saturating_sub(1); + let created = last_event(); + assert!(matches!( + created, + RuntimeEvent::RateLimiting(crate::Event::GroupCreated { group: g, .. }) if g == group + )); + + assert_ok!(RateLimiting::delete_group(RuntimeOrigin::root(), group)); + let deleted = last_event(); + assert!(matches!( + deleted, + RuntimeEvent::RateLimiting(crate::Event::GroupDeleted { group: g }) if g == group + )); + }); +} + +#[test] +fn deregister_call_scope_removes_entry() { + new_test_ext().execute_with(|| { + let identifier = register(scoped_call(), None); + let tx_target = target(identifier); + assert_ok!(RateLimiting::set_rate_limit( + RuntimeOrigin::root(), + tx_target, + Some(2u16), + RateLimitKind::Exact(4), + )); + LastSeen::::insert(tx_target, Some(9u16), 10); + assert_ok!(RateLimiting::deregister_call( + RuntimeOrigin::root(), + identifier, + Some(2u16), + false, + )); + match Limits::::get(tx_target) { + Some(RateLimit::Scoped(map)) => { + assert!(map.contains_key(&1u16)); + assert!(!map.contains_key(&2u16)); + } + other => panic!("unexpected config: {:?}", other), + } + // usage remains intact when clear_usage is false + assert_eq!(LastSeen::::get(tx_target, Some(9u16)), Some(10)); + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::CallDeregistered { + target, + transaction: Some(t), + scope: Some(sc), + .. + }) if target == tx_target && t == identifier && sc == 2u16 + )); + + // No group assigned in this test. + assert!(CallGroups::::get(identifier).is_none()); + }); +} + +#[test] +fn register_call_rejects_duplicates_and_unknown_group() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + // Duplicate should fail. + assert_noop!( + RateLimiting::register_call(RuntimeOrigin::root(), Box::new(remark_call()), None), + Error::::CallAlreadyRegistered + ); + + // Unknown group should fail. + assert_noop!( + RateLimiting::register_call(RuntimeOrigin::root(), Box::new(scoped_call()), Some(99)), + Error::::UnknownGroup + ); + + assert!(Limits::::contains_key(target(identifier))); + }); +} + +#[test] +fn group_name_limits_and_uniqueness_enforced() { + new_test_ext().execute_with(|| { + // Overlong name. + let max_name = <::MaxGroupNameLength as Get>::get() as usize; + let long_name = vec![0u8; max_name + 1]; + assert_noop!( + RateLimiting::create_group(RuntimeOrigin::root(), long_name, GroupSharing::UsageOnly), + Error::::GroupNameTooLong + ); + + // Duplicate names rejected on create and update. + let first = create_group(b"alpha", GroupSharing::UsageOnly); + let second = create_group(b"beta", GroupSharing::UsageOnly); + + assert_noop!( + RateLimiting::create_group( + RuntimeOrigin::root(), + b"alpha".to_vec(), + GroupSharing::UsageOnly + ), + Error::::DuplicateGroupName + ); + + assert_noop!( + RateLimiting::update_group( + RuntimeOrigin::root(), + second, + Some(b"alpha".to_vec()), + None + ), + Error::::DuplicateGroupName + ); + + // Unknown group update. + assert_noop!( + RateLimiting::update_group(RuntimeOrigin::root(), 99, None, None), + Error::::UnknownGroup + ); + + assert_eq!( + RateLimiting::groups(first).unwrap().name.into_inner(), + b"alpha".to_vec() + ); + + // Updating first group emits event. + assert_ok!(RateLimiting::update_group( + RuntimeOrigin::root(), + first, + Some(b"gamma".to_vec()), + None, + )); + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::GroupUpdated { group, .. }) if group == first + )); + }); +} + +#[test] +fn group_member_limit_and_removal_errors() { + new_test_ext().execute_with(|| { + let group = create_group(b"cap", GroupSharing::UsageOnly); + + let max_members = <::MaxGroupMembers as Get>::get(); + GroupMembers::::mutate(group, |members| { + for i in 0..max_members { + let _ = members.try_insert(TransactionIdentifier::new(0, (i + 1) as u8)); + } + }); + + // Next insert should fail. + let extra = register(remark_call(), None); + assert_noop!( + RateLimiting::assign_call_to_group(RuntimeOrigin::root(), extra, group), + Error::::GroupMemberLimitExceeded + ); + + // Removing a call not in a group errors. + assert_noop!( + RateLimiting::remove_call_from_group(RuntimeOrigin::root(), extra), + Error::::CallNotInGroup + ); + }); +} + +#[test] +fn cannot_delete_group_in_use_or_unknown() { + new_test_ext().execute_with(|| { + let group = create_group(b"busy", GroupSharing::ConfigOnly); + let identifier = register(remark_call(), Some(group)); + let target = RateLimitTarget::Group(group); + Limits::::insert(target, RateLimit::global(RateLimitKind::Exact(1))); + LastSeen::::insert(target, None::, 10); + + // Remove member so only config/usage keep the group in-use. + assert_ok!(RateLimiting::remove_call_from_group( + RuntimeOrigin::root(), + identifier + )); + + // Cannot delete when in use. + assert_noop!( + RateLimiting::delete_group(RuntimeOrigin::root(), group), + Error::::GroupInUse + ); + + // Clear state then delete. + Limits::::remove(target); + let _ = LastSeen::::clear_prefix(&target, u32::MAX, None); + assert_ok!(RateLimiting::delete_group(RuntimeOrigin::root(), group)); + + // Unknown group. + assert_noop!( + RateLimiting::delete_group(RuntimeOrigin::root(), 999), + Error::::UnknownGroup + ); + }); +} + +#[test] +fn deregister_call_clears_registration() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let tx_target = target(identifier); + LastSeen::::insert(tx_target, None::, 5); + assert_ok!(RateLimiting::deregister_call( + RuntimeOrigin::root(), + identifier, + None, + true, + )); + assert!(Limits::::get(tx_target).is_none()); + assert!(LastSeen::::get(tx_target, None::).is_none()); + assert!(CallGroups::::get(identifier).is_none()); + + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::CallDeregistered { + target, + transaction: Some(t), + scope: None, + .. + }) if target == tx_target && t == identifier + )); + }); +} + +#[test] +fn deregister_errors_for_unknown_or_missing_scope() { + new_test_ext().execute_with(|| { + let unknown = TransactionIdentifier::new(10, 1); + assert_noop!( + RateLimiting::deregister_call(RuntimeOrigin::root(), unknown, None, true), + Error::::CallNotRegistered + ); + + let identifier = register(scoped_call(), None); + let tx_target = target(identifier); + // Removing a non-existent scoped entry fails. + assert_noop!( + RateLimiting::deregister_call(RuntimeOrigin::root(), identifier, Some(99u16), false), + Error::::MissingRateLimit + ); + + // Removing the last scoped entry clears Limits and LastSeen. + LastSeen::::insert(tx_target, Some(1u16), 5); + assert_ok!(RateLimiting::deregister_call( + RuntimeOrigin::root(), + identifier, + Some(1u16), + true, + )); + assert!(Limits::::get(tx_target).is_none()); + assert!(LastSeen::::get(tx_target, Some(1u16)).is_none()); + }); +} + +#[test] +fn is_within_limit_detects_rate_limited_scope() { + new_test_ext().execute_with(|| { + let call = scoped_call(); + let identifier = identifier_for(&call); + let tx_target = target(identifier); + Limits::::insert( + tx_target, + RateLimit::scoped_single(7u16, RateLimitKind::Exact(3)), + ); + LastSeen::::insert(tx_target, Some(1u16), 9); + System::set_block_number(11); + let result = RateLimiting::is_within_limit( + &RuntimeOrigin::signed(1), + &call, + &identifier, + &Some(7u16), + &Some(1u16), + ) + .expect("ok"); + assert!(!result); + }); +} + +#[test] +fn migrate_usage_key_tracks_scope() { + new_test_ext().execute_with(|| { + let call = scoped_call(); + let identifier = identifier_for(&call); + let tx_target = target(identifier); + LastSeen::::insert(tx_target, Some(6u16), 10); + assert!(RateLimiting::migrate_usage_key( + tx_target, + Some(6u16), + Some(7u16) + )); + assert_eq!(LastSeen::::get(tx_target, Some(7u16)), Some(10)); + }); +} + +#[test] +fn migrate_limit_scope_covers_transitions() { + new_test_ext().execute_with(|| { + let identifier = register(remark_call(), None); + let tx_target = target(identifier); + + // global -> scoped + assert!(RateLimiting::migrate_limit_scope( + tx_target, + None, + Some(42u16) + )); + match Limits::::get(tx_target) { + Some(RateLimit::Scoped(map)) => { + assert_eq!(map.get(&42u16), Some(&RateLimitKind::Default)) + } + other => panic!("unexpected config: {:?}", other), + } + + // scoped -> scoped + assert!(RateLimiting::migrate_limit_scope( + tx_target, + Some(42u16), + Some(43u16) + )); + match Limits::::get(tx_target) { + Some(RateLimit::Scoped(map)) => { + assert_eq!(map.get(&43u16), Some(&RateLimitKind::Default)) + } + other => panic!("unexpected config: {:?}", other), + } + + // scoped -> global (only entry) + assert!(RateLimiting::migrate_limit_scope( + tx_target, + Some(43u16), + None + )); + assert!(matches!( + Limits::::get(tx_target), + Some(RateLimit::Global(RateLimitKind::Default)) + )); + + // no-op when scopes identical + assert!(RateLimiting::migrate_limit_scope(tx_target, None, None)); + }); +} + +#[test] +fn set_default_limit_updates_span_and_resolves_in_enforcement() { + new_test_ext().execute_with(|| { + assert_eq!(RateLimiting::default_limit(), 0); + assert_ok!(RateLimiting::set_default_rate_limit( + RuntimeOrigin::root(), + 5 + )); + let event = last_event(); + assert!(matches!( + event, + RuntimeEvent::RateLimiting(crate::Event::DefaultRateLimitSet { block_span: 5 }) + )); + assert_eq!(RateLimiting::default_limit(), 5); + + let call = remark_call(); + let identifier = register(call.clone(), None); + let tx_target = target(identifier); + + System::set_block_number(10); + // No last-seen yet, first call passes. + assert!( + RateLimiting::is_within_limit( + &RuntimeOrigin::signed(1), + &call, + &identifier, + &None, + &None, + ) + .unwrap() + ); + + LastSeen::::insert(tx_target, None::, 12); + System::set_block_number(15); + // Span 5 should block when delta < 5. + assert!( + !RateLimiting::is_within_limit( + &RuntimeOrigin::signed(1), + &call, + &identifier, + &None, + &None, + ) + .unwrap() + ); + }); +} + +#[test] +fn limit_for_call_names_prefers_scoped_value() { + new_test_ext().execute_with(|| { + let call = scoped_call(); + let identifier = identifier_for(&call); + Limits::::insert( + target(identifier), + RateLimit::scoped_single(9u16, RateLimitKind::Exact(8)), + ); + let fetched = RateLimiting::limit_for_call_names( + "RateLimiting", + "set_default_rate_limit", + Some(9u16), + ) + .expect("limit"); + assert_eq!(fetched, RateLimitKind::Exact(8)); + }); +} diff --git a/pallets/rate-limiting/src/tx_extension.rs b/pallets/rate-limiting/src/tx_extension.rs new file mode 100644 index 0000000000..41b4add270 --- /dev/null +++ b/pallets/rate-limiting/src/tx_extension.rs @@ -0,0 +1,497 @@ +use codec::{Decode, DecodeWithMemTracking, Encode}; +use frame_support::{ + dispatch::{DispatchInfo, DispatchResult, PostDispatchInfo}, + pallet_prelude::Weight, + sp_runtime::{ + traits::{ + DispatchInfoOf, DispatchOriginOf, Dispatchable, Implication, TransactionExtension, + ValidateResult, Zero, + }, + transaction_validity::{ + InvalidTransaction, TransactionSource, TransactionValidityError, ValidTransaction, + }, + }, +}; +use scale_info::TypeInfo; +use sp_std::{marker::PhantomData, result::Result}; + +use crate::{ + Config, LastSeen, Pallet, + types::{ + RateLimitScopeResolver, RateLimitTarget, RateLimitUsageResolver, TransactionIdentifier, + }, +}; + +/// Identifier returned in the transaction metadata for the rate limiting extension. +const IDENTIFIER: &str = "RateLimitTransactionExtension"; + +/// Custom error code used to signal a rate limit violation. +const RATE_LIMIT_DENIED: u8 = 1; + +/// Transaction extension that enforces pallet rate limiting rules. +#[derive(Default, Encode, Decode, DecodeWithMemTracking, TypeInfo)] +pub struct RateLimitTransactionExtension(PhantomData<(T, I)>) +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo; + +impl Clone for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl PartialEq for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ +} + +impl core::fmt::Debug for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(IDENTIFIER) + } +} + +impl TransactionExtension<>::RuntimeCall> + for RateLimitTransactionExtension +where + T: Config + Send + Sync + TypeInfo, + I: 'static + TypeInfo + Send + Sync, + >::RuntimeCall: Dispatchable, +{ + const IDENTIFIER: &'static str = IDENTIFIER; + + type Implicit = (); + type Val = Option<( + RateLimitTarget<>::GroupId>, + Option<>::UsageKey>, + )>; + type Pre = Option<( + RateLimitTarget<>::GroupId>, + Option<>::UsageKey>, + )>; + + fn weight(&self, _call: &>::RuntimeCall) -> Weight { + Weight::zero() + } + + fn validate( + &self, + origin: DispatchOriginOf<>::RuntimeCall>, + call: &>::RuntimeCall, + _info: &DispatchInfoOf<>::RuntimeCall>, + _len: usize, + _self_implicit: Self::Implicit, + _inherited_implication: &impl Implication, + _source: TransactionSource, + ) -> ValidateResult>::RuntimeCall> { + if >::LimitScopeResolver::should_bypass(&origin, call) { + return Ok((ValidTransaction::default(), None, origin)); + } + + let identifier = match TransactionIdentifier::from_call::(call) { + Ok(identifier) => identifier, + Err(_) => return Err(TransactionValidityError::Invalid(InvalidTransaction::Call)), + }; + + let scope = >::LimitScopeResolver::context(&origin, call); + let usage = >::UsageResolver::context(&origin, call); + + let config_target = Pallet::::config_target(&identifier) + .map_err(|_| TransactionValidityError::Invalid(InvalidTransaction::Call))?; + let usage_target = Pallet::::usage_target(&identifier) + .map_err(|_| TransactionValidityError::Invalid(InvalidTransaction::Call))?; + + let Some(block_span) = + Pallet::::effective_span(&origin, call, &config_target, &scope) + else { + return Ok((ValidTransaction::default(), None, origin)); + }; + + if block_span.is_zero() { + return Ok((ValidTransaction::default(), None, origin)); + } + + let within_limit = Pallet::::within_span(&usage_target, &usage, block_span); + + if !within_limit { + return Err(TransactionValidityError::Invalid( + InvalidTransaction::Custom(RATE_LIMIT_DENIED), + )); + } + + Ok(( + ValidTransaction::default(), + Some((usage_target, usage)), + origin, + )) + } + + fn prepare( + self, + val: Self::Val, + _origin: &DispatchOriginOf<>::RuntimeCall>, + _call: &>::RuntimeCall, + _info: &DispatchInfoOf<>::RuntimeCall>, + _len: usize, + ) -> Result { + Ok(val) + } + + fn post_dispatch( + pre: Self::Pre, + _info: &DispatchInfoOf<>::RuntimeCall>, + _post_info: &mut PostDispatchInfo, + _len: usize, + result: &DispatchResult, + ) -> Result<(), TransactionValidityError> { + if result.is_ok() { + if let Some((target, usage)) = pre { + let block_number = frame_system::Pallet::::block_number(); + LastSeen::::insert(target, usage, block_number); + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use codec::Encode; + use frame_support::{ + assert_ok, + dispatch::{GetDispatchInfo, PostDispatchInfo}, + }; + use sp_runtime::{ + traits::{TransactionExtension, TxBaseImplication}, + transaction_validity::{InvalidTransaction, TransactionSource, TransactionValidityError}, + }; + + use crate::{ + GroupSharing, LastSeen, Limits, + types::{RateLimit, RateLimitKind}, + }; + + use super::*; + use crate::mock::*; + + fn remark_call() -> RuntimeCall { + RuntimeCall::System(frame_system::Call::::remark { remark: Vec::new() }) + } + + fn bypass_call() -> RuntimeCall { + RuntimeCall::RateLimiting(RateLimitingCall::remove_call_from_group { + transaction: TransactionIdentifier::new(0, 0), + }) + } + + fn adjustable_call() -> RuntimeCall { + RuntimeCall::RateLimiting(RateLimitingCall::deregister_call { + transaction: TransactionIdentifier::new(0, 0), + scope: None, + clear_usage: false, + }) + } + + fn new_tx_extension() -> RateLimitTransactionExtension { + RateLimitTransactionExtension(Default::default()) + } + + fn target_for_call(call: &RuntimeCall) -> RateLimitTarget { + RateLimitTarget::Transaction(identifier_for(call)) + } + + fn validate_with_tx_extension( + extension: &RateLimitTransactionExtension, + call: &RuntimeCall, + ) -> Result< + ( + sp_runtime::transaction_validity::ValidTransaction, + Option<(RateLimitTarget, Option)>, + RuntimeOrigin, + ), + TransactionValidityError, + > { + let info = call.get_dispatch_info(); + let len = call.encode().len(); + extension.validate( + RuntimeOrigin::signed(42), + call, + &info, + len, + (), + &TxBaseImplication(()), + TransactionSource::External, + ) + } + + #[test] + fn tx_extension_allows_calls_without_limit() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + + let (_valid, val, _origin) = + validate_with_tx_extension(&extension, &call).expect("valid"); + assert!(val.is_none()); + + let info = call.get_dispatch_info(); + let len = call.encode().len(); + let origin_for_prepare = RuntimeOrigin::signed(42); + let pre = extension + .clone() + .prepare(val.clone(), &origin_for_prepare, &call, &info, len) + .expect("prepare succeeds"); + + let mut post = PostDispatchInfo::default(); + RateLimitTransactionExtension::::post_dispatch( + pre, + &info, + &mut post, + len, + &Ok(()), + ) + .expect("post_dispatch succeeds"); + + let target = target_for_call(&call); + assert_eq!(LastSeen::::get(target, None::), None); + }); + } + + #[test] + fn tx_extension_honors_bypass_signal() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = bypass_call(); + + let (valid, val, _) = + validate_with_tx_extension(&extension, &call).expect("bypass should succeed"); + assert_eq!(valid.priority, 0); + assert!(val.is_none()); + + let identifier = identifier_for(&call); + let target = RateLimitTarget::Transaction(identifier); + Limits::::insert(target, RateLimit::global(RateLimitKind::Exact(3))); + LastSeen::::insert(target, None::, 1); + + let (_valid, post_val, _) = + validate_with_tx_extension(&extension, &call).expect("still bypassed"); + assert!(post_val.is_none()); + }); + } + + #[test] + fn tx_extension_applies_adjusted_span() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = adjustable_call(); + let identifier = identifier_for(&call); + let target = RateLimitTarget::Transaction(identifier); + Limits::::insert(target, RateLimit::global(RateLimitKind::Exact(4))); + LastSeen::::insert(target, Some(1u16), 10); + + System::set_block_number(14); + + // Stored span (4) would allow the call, but adjusted span (8) should block it. + let err = validate_with_tx_extension(&extension, &call) + .expect_err("adjusted span should apply"); + match err { + TransactionValidityError::Invalid(InvalidTransaction::Custom(code)) => { + assert_eq!(code, RATE_LIMIT_DENIED); + } + other => panic!("unexpected error: {:?}", other), + } + }); + } + + #[test] + fn tx_extension_records_last_seen_for_successful_call() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + let identifier = identifier_for(&call); + let target = RateLimitTarget::Transaction(identifier); + Limits::::insert(target, RateLimit::global(RateLimitKind::Exact(5))); + + System::set_block_number(10); + + let (_valid, val, _) = validate_with_tx_extension(&extension, &call).expect("valid"); + assert!(val.is_some()); + + let info = call.get_dispatch_info(); + let len = call.encode().len(); + let origin_for_prepare = RuntimeOrigin::signed(42); + let pre = extension + .clone() + .prepare(val.clone(), &origin_for_prepare, &call, &info, len) + .expect("prepare succeeds"); + + let mut post = PostDispatchInfo::default(); + RateLimitTransactionExtension::::post_dispatch( + pre, + &info, + &mut post, + len, + &Ok(()), + ) + .expect("post_dispatch succeeds"); + + assert_eq!( + LastSeen::::get(target, None::), + Some(10) + ); + }); + } + + #[test] + fn tx_extension_rejects_when_call_occurs_too_soon() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + let identifier = identifier_for(&call); + let target = RateLimitTarget::Transaction(identifier); + Limits::::insert(target, RateLimit::global(RateLimitKind::Exact(5))); + LastSeen::::insert(target, None::, 20); + + System::set_block_number(22); + + let err = + validate_with_tx_extension(&extension, &call).expect_err("should be rate limited"); + match err { + TransactionValidityError::Invalid(InvalidTransaction::Custom(code)) => { + assert_eq!(code, 1); + } + other => panic!("unexpected error: {:?}", other), + } + }); + } + + #[test] + fn tx_extension_skips_last_seen_when_span_zero() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + let call = remark_call(); + let identifier = identifier_for(&call); + let target = RateLimitTarget::Transaction(identifier); + Limits::::insert(target, RateLimit::global(RateLimitKind::Exact(0))); + + System::set_block_number(30); + + let (_valid, val, _) = validate_with_tx_extension(&extension, &call).expect("valid"); + assert!(val.is_none()); + + let info = call.get_dispatch_info(); + let len = call.encode().len(); + let origin_for_prepare = RuntimeOrigin::signed(42); + let pre = extension + .clone() + .prepare(val.clone(), &origin_for_prepare, &call, &info, len) + .expect("prepare succeeds"); + + let mut post = PostDispatchInfo::default(); + RateLimitTransactionExtension::::post_dispatch( + pre, + &info, + &mut post, + len, + &Ok(()), + ) + .expect("post_dispatch succeeds"); + + assert_eq!(LastSeen::::get(target, None::), None); + }); + } + + #[test] + fn tx_extension_respects_usage_group_sharing() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + assert_ok!(RateLimiting::create_group( + RuntimeOrigin::root(), + b"use".to_vec(), + GroupSharing::UsageOnly, + )); + let group = RateLimiting::next_group_id().saturating_sub(1); + + let call = remark_call(); + let identifier = identifier_for(&call); + assert_ok!(RateLimiting::register_call( + RuntimeOrigin::root(), + Box::new(call.clone()), + Some(group), + )); + + let tx_target = RateLimitTarget::Transaction(identifier); + let usage_target = RateLimitTarget::Group(group); + Limits::::insert(tx_target, RateLimit::global(RateLimitKind::Exact(5))); + LastSeen::::insert(usage_target, None::, 10); + System::set_block_number(12); + + let err = validate_with_tx_extension(&extension, &call) + .expect_err("usage grouping should rate limit"); + match err { + TransactionValidityError::Invalid(InvalidTransaction::Custom(code)) => { + assert_eq!(code, RATE_LIMIT_DENIED); + } + other => panic!("unexpected error: {:?}", other), + } + }); + } + + #[test] + fn tx_extension_respects_config_group_sharing() { + new_test_ext().execute_with(|| { + let extension = new_tx_extension(); + assert_ok!(RateLimiting::create_group( + RuntimeOrigin::root(), + b"cfg".to_vec(), + GroupSharing::ConfigOnly, + )); + let group = RateLimiting::next_group_id().saturating_sub(1); + + let call = remark_call(); + let identifier = identifier_for(&call); + assert_ok!(RateLimiting::register_call( + RuntimeOrigin::root(), + Box::new(call.clone()), + Some(group), + )); + + let tx_target = RateLimitTarget::Transaction(identifier); + let group_target = RateLimitTarget::Group(group); + Limits::::remove(tx_target); + Limits::::insert(group_target, RateLimit::global(RateLimitKind::Exact(5))); + LastSeen::::insert(tx_target, None::, 10); + System::set_block_number(12); + + let err = validate_with_tx_extension(&extension, &call) + .expect_err("config grouping should rate limit"); + match err { + TransactionValidityError::Invalid(InvalidTransaction::Custom(code)) => { + assert_eq!(code, RATE_LIMIT_DENIED); + } + other => panic!("unexpected error: {:?}", other), + } + }); + } +} diff --git a/pallets/rate-limiting/src/types.rs b/pallets/rate-limiting/src/types.rs new file mode 100644 index 0000000000..1faff7c300 --- /dev/null +++ b/pallets/rate-limiting/src/types.rs @@ -0,0 +1,348 @@ +use codec::{Decode, DecodeWithMemTracking, Encode, MaxEncodedLen}; +use frame_support::{pallet_prelude::DispatchError, traits::GetCallMetadata}; +use scale_info::TypeInfo; +use serde::{Deserialize, Serialize}; +use sp_std::collections::btree_map::BTreeMap; + +/// Resolves the optional identifier within which a rate limit applies and can optionally adjust +/// enforcement behaviour. +pub trait RateLimitScopeResolver { + /// Returns `Some(scope)` when the limit should be applied per-scope, or `None` for global + /// limits. + fn context(origin: &Origin, call: &Call) -> Option; + + /// Returns `true` when the rate limit should be bypassed for the provided origin/call pair. + /// Defaults to `false`. + fn should_bypass(_origin: &Origin, _call: &Call) -> bool { + false + } + + /// Optionally adjusts the effective span used during enforcement. Defaults to the original + /// `span`. + fn adjust_span(_origin: &Origin, _call: &Call, span: Span) -> Span { + span + } +} + +/// Resolves the optional usage tracking key applied when enforcing limits. +pub trait RateLimitUsageResolver { + /// Returns `Some(usage)` when usage should be tracked per-key, or `None` for global usage + /// tracking. + fn context(origin: &Origin, call: &Call) -> Option; +} + +/// Identifies a runtime call by pallet and extrinsic indices. +#[derive( + Serialize, + Deserialize, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub struct TransactionIdentifier { + /// Pallet variant index. + pub pallet_index: u8, + /// Call variant index within the pallet. + pub extrinsic_index: u8, +} + +/// Target identifier for rate limit and usage configuration. +#[derive( + Serialize, + Deserialize, + Clone, + Copy, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub enum RateLimitTarget { + /// Per-transaction configuration keyed by pallet/extrinsic indices. + Transaction(TransactionIdentifier), + /// Shared configuration for a named group. + Group(GroupId), +} + +impl RateLimitTarget { + /// Returns the transaction identifier when the target represents a single extrinsic. + pub fn as_transaction(&self) -> Option<&TransactionIdentifier> { + match self { + RateLimitTarget::Transaction(identifier) => Some(identifier), + RateLimitTarget::Group(_) => None, + } + } + + /// Returns the group identifier when the target represents a group configuration. + pub fn as_group(&self) -> Option<&GroupId> { + match self { + RateLimitTarget::Transaction(_) => None, + RateLimitTarget::Group(id) => Some(id), + } + } +} + +/// Sharing mode configured for a group. +#[derive( + Serialize, + Deserialize, + Clone, + Copy, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub enum GroupSharing { + /// Limits remain per transaction; usage is shared by the group. + UsageOnly, + /// Limits are shared by the group; usage remains per transaction. + ConfigOnly, + /// Both limits and usage are shared by the group. + ConfigAndUsage, +} + +impl GroupSharing { + /// Returns `true` when configuration for this group should use the group target key. + pub fn config_uses_group(self) -> bool { + matches!( + self, + GroupSharing::ConfigOnly | GroupSharing::ConfigAndUsage + ) + } + + /// Returns `true` when usage tracking for this group should use the group target key. + pub fn usage_uses_group(self) -> bool { + matches!(self, GroupSharing::UsageOnly | GroupSharing::ConfigAndUsage) + } +} + +/// Metadata describing a configured group. +#[derive( + Serialize, + Deserialize, + Clone, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub struct RateLimitGroup { + /// Stable identifier assigned to the group. + pub id: GroupId, + /// Human readable group name. + pub name: Name, + /// Sharing configuration enforced for the group. + pub sharing: GroupSharing, +} + +impl TransactionIdentifier { + /// Builds a new identifier from pallet/extrinsic indices. + pub const fn new(pallet_index: u8, extrinsic_index: u8) -> Self { + Self { + pallet_index, + extrinsic_index, + } + } + + /// Returns the pallet and extrinsic names associated with this identifier. + pub fn names(&self) -> Result<(&'static str, &'static str), DispatchError> + where + T: crate::pallet::Config, + I: 'static, + >::RuntimeCall: GetCallMetadata, + { + let modules = >::RuntimeCall::get_module_names(); + let pallet_name = modules + .get(self.pallet_index as usize) + .copied() + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + let call_names = >::RuntimeCall::get_call_names(pallet_name); + let extrinsic_name = call_names + .get(self.extrinsic_index as usize) + .copied() + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + Ok((pallet_name, extrinsic_name)) + } + + /// Builds an identifier from a runtime call by extracting pallet/extrinsic indices. + pub fn from_call( + call: &>::RuntimeCall, + ) -> Result + where + T: crate::pallet::Config, + I: 'static, + { + call.using_encoded(|encoded| { + let pallet_index = *encoded + .get(0) + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + let extrinsic_index = *encoded + .get(1) + .ok_or(crate::pallet::Error::::InvalidRuntimeCall)?; + Ok(TransactionIdentifier::new(pallet_index, extrinsic_index)) + }) + } +} + +/// Policy describing the block span enforced by a rate limit. +#[derive( + Serialize, + Deserialize, + Clone, + Copy, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + MaxEncodedLen, + Debug, +)] +pub enum RateLimitKind { + /// Use the pallet-level default rate limit. + Default, + /// Apply an exact rate limit measured in blocks. + Exact(BlockNumber), +} + +/// Stored rate limit configuration for a transaction identifier. +/// +/// The configuration is mutually exclusive: either the call is globally limited or it stores a set +/// of per-scope spans. +#[derive( + Serialize, + Deserialize, + Clone, + PartialEq, + Eq, + Encode, + Decode, + DecodeWithMemTracking, + TypeInfo, + Debug, +)] +#[serde( + bound = "Scope: Ord + serde::Serialize + serde::de::DeserializeOwned, BlockNumber: serde::Serialize + serde::de::DeserializeOwned" +)] +pub enum RateLimit { + /// Global span applied to every invocation. + Global(RateLimitKind), + /// Per-scope spans keyed by `Scope`. + Scoped(BTreeMap>), +} + +impl RateLimit +where + Scope: Ord, +{ + /// Convenience helper to build a global configuration. + pub fn global(kind: RateLimitKind) -> Self { + Self::Global(kind) + } + + /// Convenience helper to build a scoped configuration containing a single entry. + pub fn scoped_single(scope: Scope, kind: RateLimitKind) -> Self { + let mut map = BTreeMap::new(); + map.insert(scope, kind); + Self::Scoped(map) + } + + /// Returns the span configured for the provided scope, if any. + pub fn kind_for(&self, scope: Option<&Scope>) -> Option<&RateLimitKind> { + match self { + RateLimit::Global(kind) => Some(kind), + RateLimit::Scoped(map) => scope.and_then(|key| map.get(key)), + } + } + + /// Inserts or updates a scoped entry, converting from a global configuration if needed. + pub fn upsert_scope(&mut self, scope: Scope, kind: RateLimitKind) { + match self { + RateLimit::Global(_) => { + let mut map = BTreeMap::new(); + map.insert(scope, kind); + *self = RateLimit::Scoped(map); + } + RateLimit::Scoped(map) => { + map.insert(scope, kind); + } + } + } + + /// Removes a scoped entry, returning whether one existed. + pub fn remove_scope(&mut self, scope: &Scope) -> bool { + match self { + RateLimit::Global(_) => false, + RateLimit::Scoped(map) => map.remove(scope).is_some(), + } + } + + /// Returns true when the scoped configuration contains no entries. + pub fn is_scoped_empty(&self) -> bool { + matches!(self, RateLimit::Scoped(map) if map.is_empty()) + } +} + +#[cfg(test)] +mod tests { + use sp_runtime::DispatchError; + + use super::*; + use crate::{mock::*, pallet::Error}; + + #[test] + fn transaction_identifier_from_call_matches_expected_indices() { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + + let identifier = TransactionIdentifier::from_call::(&call).expect("identifier"); + + // System is the first pallet in the mock runtime, RateLimiting is second. + assert_eq!(identifier.pallet_index, 1); + // set_default_rate_limit has call_index 4. + assert_eq!(identifier.extrinsic_index, 4); + } + + #[test] + fn transaction_identifier_names_matches_call_metadata() { + let call = + RuntimeCall::RateLimiting(RateLimitingCall::set_default_rate_limit { block_span: 0 }); + let identifier = TransactionIdentifier::from_call::(&call).expect("identifier"); + + let (pallet, extrinsic) = identifier.names::().expect("call metadata"); + assert_eq!(pallet, "RateLimiting"); + assert_eq!(extrinsic, "set_default_rate_limit"); + } + + #[test] + fn transaction_identifier_names_error_for_unknown_indices() { + let identifier = TransactionIdentifier::new(99, 0); + + let err = identifier.names::().expect_err("should fail"); + let expected: DispatchError = Error::::InvalidRuntimeCall.into(); + assert_eq!(err, expected); + } +} diff --git a/pallets/subtensor/Cargo.toml b/pallets/subtensor/Cargo.toml index 2e35a89d19..a7885e32a8 100644 --- a/pallets/subtensor/Cargo.toml +++ b/pallets/subtensor/Cargo.toml @@ -55,6 +55,7 @@ sha2.workspace = true rand_chacha.workspace = true pallet-crowdloan.workspace = true pallet-subtensor-proxy.workspace = true +pallet-rate-limiting.workspace = true [dev-dependencies] pallet-balances = { workspace = true, features = ["std"] } @@ -114,6 +115,7 @@ std = [ "pallet-crowdloan/std", "pallet-drand/std", "pallet-subtensor-proxy/std", + "pallet-rate-limiting/std", "pallet-subtensor-swap/std", "subtensor-swap-interface/std", "pallet-subtensor-utility/std", diff --git a/pallets/subtensor/src/utils/rate_limiting.rs b/pallets/subtensor/src/utils/rate_limiting.rs index 85f58cfc64..468aecd1c1 100644 --- a/pallets/subtensor/src/utils/rate_limiting.rs +++ b/pallets/subtensor/src/utils/rate_limiting.rs @@ -1,3 +1,5 @@ +use codec::{Decode, Encode}; +use scale_info::TypeInfo; use subtensor_runtime_common::NetUid; use super::*; diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 9760ac1b53..5d40215c49 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -38,6 +38,7 @@ frame-system = { workspace = true } frame-try-runtime = { workspace = true, optional = true } pallet-timestamp.workspace = true pallet-transaction-payment.workspace = true +pallet-rate-limiting.workspace = true pallet-subtensor-utility.workspace = true frame-executive.workspace = true frame-metadata-hash-extension.workspace = true @@ -52,6 +53,7 @@ sp-inherents.workspace = true sp-offchain.workspace = true sp-runtime.workspace = true sp-session.workspace = true +sp-io.workspace = true sp-std.workspace = true sp-transaction-pool.workspace = true sp-version.workspace = true @@ -153,7 +155,6 @@ ethereum.workspace = true [dev-dependencies] frame-metadata.workspace = true -sp-io.workspace = true sp-tracing.workspace = true [build-dependencies] @@ -187,6 +188,7 @@ std = [ "pallet-timestamp/std", "pallet-transaction-payment-rpc-runtime-api/std", "pallet-transaction-payment/std", + "pallet-rate-limiting/std", "pallet-subtensor-utility/std", "pallet-sudo/std", "pallet-multisig/std", @@ -328,6 +330,7 @@ try-runtime = [ "pallet-insecure-randomness-collective-flip/try-runtime", "pallet-timestamp/try-runtime", "pallet-transaction-payment/try-runtime", + "pallet-rate-limiting/try-runtime", "pallet-subtensor-utility/try-runtime", "pallet-safe-mode/try-runtime", "pallet-subtensor/try-runtime", diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 9ece1dd025..2c642c9af5 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -12,6 +12,7 @@ use core::num::NonZeroU64; pub mod check_nonce; mod migrations; +mod rate_limiting; pub mod transaction_payment_wrapper; extern crate alloc; @@ -70,6 +71,10 @@ use subtensor_precompiles::Precompiles; use subtensor_runtime_common::{AlphaCurrency, TaoCurrency, time::*, *}; use subtensor_swap_interface::{Order, SwapHandler}; +pub use rate_limiting::{ + ScopeResolver as RuntimeScopeResolver, UsageResolver as RuntimeUsageResolver, +}; + // A few exports that help ease life for downstream crates. pub use frame_support::{ StorageValue, construct_runtime, parameter_types, diff --git a/runtime/src/rate_limiting/migration.rs b/runtime/src/rate_limiting/migration.rs new file mode 100644 index 0000000000..c4217b6dd4 --- /dev/null +++ b/runtime/src/rate_limiting/migration.rs @@ -0,0 +1,1038 @@ +use core::convert::TryFrom; + +use codec::Encode; +use frame_support::{pallet_prelude::Parameter, traits::Get, weights::Weight}; +use frame_system::pallet_prelude::BlockNumberFor; +use log::info; +use pallet_rate_limiting::{ + GroupSharing, RateLimit, RateLimitGroup, RateLimitKind, RateLimitTarget, TransactionIdentifier, +}; +use pallet_subtensor::{ + self, AssociatedEvmAddress, Axons, Config as SubtensorConfig, HasMigrationRun, + LastRateLimitedBlock, LastUpdate, MaxUidsTrimmingRateLimit, MechanismCountCurrent, + MechanismCountSetRateLimit, MechanismEmissionRateLimit, NetworkRateLimit, + OwnerHyperparamRateLimit, Pallet, Prometheus, RateLimitKey, TransactionKeyLastBlock, + TxChildkeyTakeRateLimit, TxDelegateTakeRateLimit, TxRateLimit, WeightsVersionKeyRateLimit, + utils::rate_limiting::{Hyperparameter, TransactionType}, +}; +use sp_io::{ + hashing::{blake2_128, twox_128}, + storage, +}; +use sp_runtime::traits::SaturatedConversion; +use sp_std::{ + collections::{btree_map::BTreeMap, btree_set::BTreeSet}, + vec, + vec::Vec, +}; +use subtensor_runtime_common::{MechId, NetUid, RateLimitScope, RateLimitUsageKey}; + +type RateLimitConfigOf = RateLimit>; +type RateLimitTargetOf = RateLimitTarget; +type RateLimitGroupOf = RateLimitGroup>; +type LimitEntries = Vec<(RateLimitTargetOf, RateLimitConfigOf)>; +type LastSeenKey = ( + RateLimitTargetOf, + Option::AccountId>>, +); +type LastSeenEntries = Vec<(LastSeenKey, BlockNumberFor)>; + +/// Pallet index assigned to `pallet_subtensor` in `construct_runtime!`. +const SUBTENSOR_PALLET_INDEX: u8 = 7; +/// Pallet index assigned to `pallet_admin_utils` in `construct_runtime!`. +const ADMIN_UTILS_PALLET_INDEX: u8 = 19; + +/// Marker stored in `HasMigrationRun` once the migration finishes. +const MIGRATION_NAME: &[u8] = b"migrate_rate_limiting"; + +/// `set_children` is rate-limited to once every 150 blocks. +const SET_CHILDREN_RATE_LIMIT: u64 = 150; +/// `set_sn_owner_hotkey` default interval (blocks). +const DEFAULT_SET_SN_OWNER_HOTKEY_LIMIT: u64 = 50_400; + +type GroupId = u32; + +struct GroupDefinition { + id: GroupId, + name: &'static [u8], + sharing: GroupSharing, + members: Vec, +} + +const GROUP_SERVE_AXON: GroupId = 0; +const GROUP_DELEGATE_TAKE: GroupId = 1; +const GROUP_WEIGHTS_SUBNET: GroupId = 2; +const GROUP_WEIGHTS_MECHANISM: GroupId = 3; +const GROUP_REGISTER_NETWORK: GroupId = 4; +const GROUP_OWNER_HPARAMS: GroupId = 5; + +fn hyperparameter_identifiers() -> Vec { + HYPERPARAMETERS + .iter() + .filter_map(|h| identifier_for_hyperparameter(*h)) + .collect() +} + +fn group_definitions() -> Vec { + vec![ + GroupDefinition { + id: GROUP_SERVE_AXON, + name: b"serve-axon", + sharing: GroupSharing::ConfigAndUsage, + members: vec![subtensor_identifier(4), subtensor_identifier(40)], + }, + GroupDefinition { + id: GROUP_DELEGATE_TAKE, + name: b"delegate-take", + sharing: GroupSharing::ConfigAndUsage, + members: vec![subtensor_identifier(66), subtensor_identifier(65)], + }, + GroupDefinition { + id: GROUP_WEIGHTS_SUBNET, + name: b"weights-subnet", + sharing: GroupSharing::ConfigAndUsage, + members: vec![ + subtensor_identifier(0), + subtensor_identifier(96), + subtensor_identifier(100), + subtensor_identifier(113), + ], + }, + GroupDefinition { + id: GROUP_WEIGHTS_MECHANISM, + name: b"weights-mechanism", + sharing: GroupSharing::ConfigAndUsage, + members: vec![ + subtensor_identifier(119), + subtensor_identifier(115), + subtensor_identifier(117), + subtensor_identifier(118), + ], + }, + GroupDefinition { + id: GROUP_REGISTER_NETWORK, + name: b"register-network", + sharing: GroupSharing::ConfigAndUsage, + members: vec![subtensor_identifier(59), subtensor_identifier(79)], + }, + GroupDefinition { + id: GROUP_OWNER_HPARAMS, + name: b"owner-hparams", + sharing: GroupSharing::ConfigOnly, + members: hyperparameter_identifiers(), + }, + ] +} + +/// Hyperparameter extrinsics routed through owner-or-root rate limiting. +const HYPERPARAMETERS: &[Hyperparameter] = &[ + Hyperparameter::ServingRateLimit, + Hyperparameter::MaxDifficulty, + Hyperparameter::AdjustmentAlpha, + Hyperparameter::ImmunityPeriod, + Hyperparameter::MinAllowedWeights, + Hyperparameter::MaxAllowedUids, + Hyperparameter::Kappa, + Hyperparameter::Rho, + Hyperparameter::ActivityCutoff, + Hyperparameter::PowRegistrationAllowed, + Hyperparameter::MinBurn, + Hyperparameter::MaxBurn, + Hyperparameter::BondsMovingAverage, + Hyperparameter::BondsPenalty, + Hyperparameter::CommitRevealEnabled, + Hyperparameter::LiquidAlphaEnabled, + Hyperparameter::AlphaValues, + Hyperparameter::WeightCommitInterval, + Hyperparameter::TransferEnabled, + Hyperparameter::AlphaSigmoidSteepness, + Hyperparameter::Yuma3Enabled, + Hyperparameter::BondsResetEnabled, + Hyperparameter::ImmuneNeuronLimit, + Hyperparameter::RecycleOrBurn, +]; + +#[derive(Clone, Copy)] +struct GroupInfo { + id: GroupId, + sharing: GroupSharing, +} + +#[derive(Default)] +struct Grouping { + assignments: BTreeMap, + members: BTreeMap>, + details: Vec, + next_group_id: GroupId, + max_group_id: Option, +} + +impl Grouping { + fn members(&self, id: GroupId) -> Option<&BTreeSet> { + self.members.get(&id) + } + + fn insert_group( + &mut self, + id: GroupId, + name: &[u8], + sharing: GroupSharing, + members: &[TransactionIdentifier], + ) { + let entry = self.members.entry(id).or_insert_with(BTreeSet::new); + for member in members { + self.assignments.insert(*member, GroupInfo { id, sharing }); + entry.insert(*member); + } + + self.details.push(RateLimitGroup { + id, + name: name.to_vec(), + sharing, + }); + + self.max_group_id = Some(self.max_group_id.map_or(id, |current| current.max(id))); + } + + fn finalize_next_id(&mut self) { + self.next_group_id = self.max_group_id.map_or(0, |id| id.saturating_add(1)); + } + + fn config_target(&self, identifier: TransactionIdentifier) -> RateLimitTargetOf { + if let Some(info) = self.assignments.get(&identifier) { + if info.sharing.config_uses_group() { + return RateLimitTarget::Group(info.id); + } + } + RateLimitTarget::Transaction(identifier) + } + + fn usage_target(&self, identifier: TransactionIdentifier) -> RateLimitTargetOf { + if let Some(info) = self.assignments.get(&identifier) { + if info.sharing.usage_uses_group() { + return RateLimitTarget::Group(info.id); + } + } + RateLimitTarget::Transaction(identifier) + } +} + +const SERVE_PROM_IDENTIFIER: TransactionIdentifier = subtensor_identifier(5); + +fn serve_calls(grouping: &Grouping) -> Vec { + let mut calls = Vec::new(); + if let Some(members) = grouping.members(GROUP_SERVE_AXON) { + calls.extend(members.iter().copied()); + } + calls.push(SERVE_PROM_IDENTIFIER); + calls +} + +fn weight_calls_subnet(grouping: &Grouping) -> Vec { + grouping + .members(GROUP_WEIGHTS_SUBNET) + .map(|m| m.iter().copied().collect()) + .unwrap_or_default() +} + +fn weight_calls_mechanism(grouping: &Grouping) -> Vec { + grouping + .members(GROUP_WEIGHTS_MECHANISM) + .map(|m| m.iter().copied().collect()) + .unwrap_or_default() +} + +fn build_grouping() -> Grouping { + let mut grouping = Grouping::default(); + + for definition in group_definitions() { + grouping.insert_group( + definition.id, + definition.name, + definition.sharing, + &definition.members, + ); + } + + grouping.finalize_next_id(); + grouping +} + +pub fn migrate_rate_limiting() -> Weight { + let mut weight = T::DbWeight::get().reads(1); + if HasMigrationRun::::get(MIGRATION_NAME) { + info!("Rate-limiting migration already executed. Skipping."); + return weight; + } + + let grouping = build_grouping(); + let (limits, limit_reads) = build_limits::(&grouping); + let (last_seen, seen_reads) = build_last_seen::(&grouping); + + let limit_writes = write_limits::(&limits); + let seen_writes = write_last_seen::(&last_seen); + let group_writes = write_groups::(&grouping); + + HasMigrationRun::::insert(MIGRATION_NAME, true); + + weight = weight + .saturating_add(T::DbWeight::get().reads(limit_reads.saturating_add(seen_reads))) + .saturating_add( + T::DbWeight::get().writes( + limit_writes + .saturating_add(seen_writes) + .saturating_add(group_writes) + .saturating_add(1), + ), + ); + + info!( + "Migrated {} rate-limit configs, {} last-seen entries, and {} groups into pallet-rate-limiting", + limits.len(), + last_seen.len(), + grouping.details.len() + ); + + weight +} + +fn build_limits(grouping: &Grouping) -> (LimitEntries, u64) { + let mut limits = LimitEntries::::new(); + let mut reads: u64 = 0; + + reads += gather_simple_limits::(&mut limits, grouping); + reads += gather_owner_hparam_limits::(&mut limits, grouping); + reads += gather_serving_limits::(&mut limits, grouping); + reads += gather_weight_limits::(&mut limits, grouping); + + (limits, reads) +} + +fn gather_simple_limits( + limits: &mut LimitEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + + reads += 1; + if let Some(span) = block_number::(TxRateLimit::::get()) { + set_global_limit::( + limits, + grouping.config_target(subtensor_identifier(70)), + span, + ); + } + + reads += 1; + if let Some(span) = block_number::(TxDelegateTakeRateLimit::::get()) { + if let Some(members) = grouping.members(GROUP_DELEGATE_TAKE) { + for call in members { + set_global_limit::(limits, grouping.config_target(*call), span); + } + } + } + + reads += 1; + if let Some(span) = block_number::(TxChildkeyTakeRateLimit::::get()) { + set_global_limit::( + limits, + grouping.config_target(subtensor_identifier(75)), + span, + ); + } + + reads += 1; + if let Some(span) = block_number::(NetworkRateLimit::::get()) { + if let Some(members) = grouping.members(GROUP_REGISTER_NETWORK) { + for call in members { + set_global_limit::(limits, grouping.config_target(*call), span); + } + } + } + + reads += 1; + if let Some(span) = block_number::(WeightsVersionKeyRateLimit::::get()) { + set_global_limit::( + limits, + grouping.config_target(admin_utils_identifier(6)), + span, + ); + } + + if let Some(span) = block_number::(DEFAULT_SET_SN_OWNER_HOTKEY_LIMIT) { + set_global_limit::( + limits, + grouping.config_target(admin_utils_identifier(67)), + span, + ); + } + + if let Some(span) = block_number::(::EvmKeyAssociateRateLimit::get()) { + set_global_limit::( + limits, + grouping.config_target(subtensor_identifier(93)), + span, + ); + } + + if let Some(span) = block_number::(MechanismCountSetRateLimit::::get()) { + set_global_limit::( + limits, + grouping.config_target(admin_utils_identifier(76)), + span, + ); + } + + if let Some(span) = block_number::(MechanismEmissionRateLimit::::get()) { + set_global_limit::( + limits, + grouping.config_target(admin_utils_identifier(77)), + span, + ); + } + + if let Some(span) = block_number::(MaxUidsTrimmingRateLimit::::get()) { + set_global_limit::( + limits, + grouping.config_target(admin_utils_identifier(78)), + span, + ); + } + + if let Some(span) = block_number::(SET_CHILDREN_RATE_LIMIT) { + set_global_limit::( + limits, + grouping.config_target(subtensor_identifier(67)), + span, + ); + } + + reads +} + +fn gather_owner_hparam_limits( + limits: &mut LimitEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + + reads += 1; + if let Some(span) = block_number::(u64::from(OwnerHyperparamRateLimit::::get())) { + for hparam in HYPERPARAMETERS { + if let Some(identifier) = identifier_for_hyperparameter(*hparam) { + set_global_limit::(limits, grouping.config_target(identifier), span); + } + } + } + + reads +} + +fn gather_serving_limits( + limits: &mut LimitEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + let netuids = Pallet::::get_all_subnet_netuids(); + + for netuid in netuids { + reads += 1; + if let Some(span) = block_number::(Pallet::::get_serving_rate_limit(netuid)) { + for call in serve_calls(grouping) { + set_scoped_limit::( + limits, + grouping.config_target(call), + RateLimitScope::Subnet(netuid), + span, + ); + } + } + } + + reads +} + +fn gather_weight_limits( + limits: &mut LimitEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + let netuids = Pallet::::get_all_subnet_netuids(); + + let mut subnet_limits = BTreeMap::>::new(); + let subnet_calls = weight_calls_subnet(grouping); + let mechanism_calls = weight_calls_mechanism(grouping); + for netuid in &netuids { + reads += 1; + if let Some(span) = block_number::(Pallet::::get_weights_set_rate_limit(*netuid)) { + subnet_limits.insert(*netuid, span); + for call in &subnet_calls { + set_scoped_limit::( + limits, + grouping.config_target(*call), + RateLimitScope::Subnet(*netuid), + span, + ); + } + } + } + + for netuid in &netuids { + reads += 1; + let mech_count: u8 = MechanismCountCurrent::::get(*netuid).into(); + if mech_count <= 1 { + continue; + } + let Some(span) = subnet_limits.get(netuid).copied() else { + continue; + }; + for mecid in 1..mech_count { + let scope = RateLimitScope::SubnetMechanism { + netuid: *netuid, + mecid: MechId::from(mecid), + }; + for call in &mechanism_calls { + set_scoped_limit::(limits, grouping.config_target(*call), scope.clone(), span); + } + } + } + + reads +} + +fn build_last_seen(grouping: &Grouping) -> (LastSeenEntries, u64) { + let mut last_seen = LastSeenEntries::::new(); + let mut reads: u64 = 0; + + reads += import_last_rate_limited_blocks::(&mut last_seen, grouping); + reads += import_transaction_key_last_blocks::(&mut last_seen, grouping); + reads += import_last_update_entries::(&mut last_seen, grouping); + reads += import_serving_entries::(&mut last_seen, grouping); + reads += import_evm_entries::(&mut last_seen, grouping); + + (last_seen, reads) +} + +fn import_last_rate_limited_blocks( + entries: &mut LastSeenEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + for (key, block) in LastRateLimitedBlock::::iter() { + reads += 1; + if block == 0 { + continue; + } + match key { + RateLimitKey::SetSNOwnerHotkey(netuid) => { + if let Some(identifier) = + identifier_for_transaction_type(TransactionType::SetSNOwnerHotkey) + { + record_last_seen_entry::( + entries, + grouping.usage_target(identifier), + Some(RateLimitUsageKey::Subnet(netuid)), + block, + ); + } + } + RateLimitKey::OwnerHyperparamUpdate(netuid, hyper) => { + if let Some(identifier) = identifier_for_hyperparameter(hyper) { + record_last_seen_entry::( + entries, + grouping.usage_target(identifier), + Some(RateLimitUsageKey::Subnet(netuid)), + block, + ); + } + } + RateLimitKey::LastTxBlock(account) => { + record_last_seen_entry::( + entries, + grouping.usage_target(subtensor_identifier(70)), + Some(RateLimitUsageKey::Account(account.clone())), + block, + ); + } + RateLimitKey::LastTxBlockDelegateTake(account) => { + record_last_seen_entry::( + entries, + grouping.usage_target(subtensor_identifier(66)), + Some(RateLimitUsageKey::Account(account.clone())), + block, + ); + } + RateLimitKey::NetworkLastRegistered => { + record_last_seen_entry::( + entries, + grouping.usage_target(subtensor_identifier(59)), + None, + block, + ); + } + RateLimitKey::LastTxBlockChildKeyTake(_) => { + // Deprecated storage; ignored. + } + } + } + reads +} + +fn import_transaction_key_last_blocks( + entries: &mut LastSeenEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + for ((account, netuid, tx_kind), block) in TransactionKeyLastBlock::::iter() { + reads += 1; + if block == 0 { + continue; + } + let tx_type = TransactionType::from(tx_kind); + let Some(identifier) = identifier_for_transaction_type(tx_type) else { + continue; + }; + let Some(usage) = usage_key_from_transaction_type(tx_type, &account, netuid) else { + continue; + }; + record_last_seen_entry::( + entries, + grouping.usage_target(identifier), + Some(usage), + block, + ); + } + reads +} + +fn import_last_update_entries( + entries: &mut LastSeenEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + let subnet_calls = weight_calls_subnet(grouping); + let mechanism_calls = weight_calls_mechanism(grouping); + for (index, blocks) in LastUpdate::::iter() { + reads += 1; + let netuid = Pallet::::get_netuid(index); + let sub_id = u16::from(index) + .checked_div(pallet_subtensor::subnets::mechanism::GLOBAL_MAX_SUBNET_COUNT) + .unwrap_or_default(); + let is_mechanism = sub_id != 0; + let Ok(sub_id) = u8::try_from(sub_id) else { + continue; + }; + let mecid = MechId::from(sub_id); + + for (uid, last_block) in blocks.into_iter().enumerate() { + if last_block == 0 { + continue; + } + let Ok(uid_u16) = u16::try_from(uid) else { + continue; + }; + let usage = if is_mechanism { + RateLimitUsageKey::SubnetMechanismNeuron { + netuid, + mecid, + uid: uid_u16, + } + } else { + RateLimitUsageKey::SubnetNeuron { + netuid, + uid: uid_u16, + } + }; + + let call_set: &[TransactionIdentifier] = if is_mechanism { + mechanism_calls.as_slice() + } else { + subnet_calls.as_slice() + }; + + for call in call_set { + record_last_seen_entry::( + entries, + grouping.usage_target(*call), + Some(usage.clone()), + last_block, + ); + } + } + } + reads +} + +fn import_serving_entries( + entries: &mut LastSeenEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + for (netuid, hotkey, axon) in Axons::::iter() { + reads += 1; + if axon.block == 0 { + continue; + } + let usage = RateLimitUsageKey::AccountSubnet { + account: hotkey.clone(), + netuid, + }; + let axon_calls: Vec<_> = grouping + .members(GROUP_SERVE_AXON) + .map(|m| m.iter().copied().collect()) + .unwrap_or_else(|| vec![subtensor_identifier(4), subtensor_identifier(40)]); + for call in axon_calls { + record_last_seen_entry::( + entries, + grouping.usage_target(call), + Some(usage.clone()), + axon.block, + ); + } + } + + for (netuid, hotkey, prom) in Prometheus::::iter() { + reads += 1; + if prom.block == 0 { + continue; + } + let usage = RateLimitUsageKey::AccountSubnet { + account: hotkey, + netuid, + }; + record_last_seen_entry::( + entries, + grouping.usage_target(SERVE_PROM_IDENTIFIER), + Some(usage), + prom.block, + ); + } + + reads +} + +fn import_evm_entries( + entries: &mut LastSeenEntries, + grouping: &Grouping, +) -> u64 { + let mut reads: u64 = 0; + for (netuid, uid, (_, block)) in AssociatedEvmAddress::::iter() { + reads += 1; + if block == 0 { + continue; + } + record_last_seen_entry::( + entries, + grouping.usage_target(subtensor_identifier(93)), + Some(RateLimitUsageKey::SubnetNeuron { netuid, uid }), + block, + ); + } + reads +} + +/// TODO(rate-limiting-storage): Swap these manual writes for +/// `pallet_rate_limiting::Pallet` APIs once the runtime wires the pallet in. +fn write_limits(limits: &LimitEntries) -> u64 { + if limits.is_empty() { + return 0; + } + let limits_prefix = storage_prefix("RateLimiting", "Limits"); + let mut writes = 0; + for (identifier, limit) in limits.iter() { + let limit_key = map_storage_key(&limits_prefix, identifier); + storage::set(&limit_key, &limit.encode()); + writes += 1; + } + writes +} + +fn write_last_seen(entries: &LastSeenEntries) -> u64 { + if entries.is_empty() { + return 0; + } + let prefix = storage_prefix("RateLimiting", "LastSeen"); + let mut writes = 0; + for ((identifier, usage), block) in entries.iter() { + let key = double_map_storage_key(&prefix, identifier, usage); + storage::set(&key, &block.encode()); + writes += 1; + } + writes +} + +fn write_groups(grouping: &Grouping) -> u64 { + let mut writes = 0; + let groups_prefix = storage_prefix("RateLimiting", "Groups"); + let members_prefix = storage_prefix("RateLimiting", "GroupMembers"); + let name_index_prefix = storage_prefix("RateLimiting", "GroupNameIndex"); + let call_groups_prefix = storage_prefix("RateLimiting", "CallGroups"); + let next_group_id_prefix = storage_prefix("RateLimiting", "NextGroupId"); + + for detail in &grouping.details { + let group_key = map_storage_key(&groups_prefix, detail.id); + storage::set(&group_key, &detail.encode()); + writes += 1; + + let name_key = map_storage_key(&name_index_prefix, detail.name.clone()); + storage::set(&name_key, &detail.id.encode()); + writes += 1; + } + + for (group, members) in &grouping.members { + let members_key = map_storage_key(&members_prefix, *group); + storage::set(&members_key, &members.encode()); + writes += 1; + } + + for (identifier, info) in &grouping.assignments { + let call_key = map_storage_key(&call_groups_prefix, *identifier); + storage::set(&call_key, &info.id.encode()); + writes += 1; + } + + storage::set(&next_group_id_prefix, &grouping.next_group_id.encode()); + writes += 1; + + writes +} + +fn block_number(value: u64) -> Option> { + if value == 0 { + return None; + } + Some(value.saturated_into::>()) +} + +fn set_global_limit( + limits: &mut LimitEntries, + target: RateLimitTargetOf, + span: BlockNumberFor, +) { + if let Some((_, config)) = limits.iter_mut().find(|(id, _)| *id == target) { + *config = RateLimit::global(RateLimitKind::Exact(span)); + } else { + limits.push((target, RateLimit::global(RateLimitKind::Exact(span)))); + } +} + +fn set_scoped_limit( + limits: &mut LimitEntries, + target: RateLimitTargetOf, + scope: RateLimitScope, + span: BlockNumberFor, +) { + if let Some((_, config)) = limits.iter_mut().find(|(id, _)| *id == target) { + match config { + RateLimit::Global(_) => { + *config = RateLimit::scoped_single(scope, RateLimitKind::Exact(span)); + } + RateLimit::Scoped(map) => { + map.insert(scope, RateLimitKind::Exact(span)); + } + } + } else { + limits.push(( + target, + RateLimit::scoped_single(scope, RateLimitKind::Exact(span)), + )); + } +} + +fn record_last_seen_entry( + entries: &mut LastSeenEntries, + target: RateLimitTargetOf, + usage: Option>, + block: u64, +) { + let Some(block_number) = block_number::(block) else { + return; + }; + + let key = (target, usage); + if let Some((_, existing)) = entries.iter_mut().find(|(entry_key, _)| *entry_key == key) { + if block_number > *existing { + *existing = block_number; + } + } else { + entries.push((key, block_number)); + } +} + +fn storage_prefix(pallet: &str, storage: &str) -> Vec { + let mut out = Vec::with_capacity(32); + out.extend_from_slice(&twox_128(pallet.as_bytes())); + out.extend_from_slice(&twox_128(storage.as_bytes())); + out +} + +fn map_storage_key(prefix: &[u8], key: impl Encode) -> Vec { + let mut final_key = Vec::with_capacity(prefix.len() + 32); + final_key.extend_from_slice(prefix); + let encoded = key.encode(); + let hash = blake2_128(&encoded); + final_key.extend_from_slice(&hash); + final_key.extend_from_slice(&encoded); + final_key +} + +fn double_map_storage_key(prefix: &[u8], key1: impl Encode, key2: impl Encode) -> Vec { + let mut final_key = Vec::with_capacity(prefix.len() + 64); + final_key.extend_from_slice(prefix); + let first = map_storage_key(&[], key1); + final_key.extend_from_slice(&first); + let second = map_storage_key(&[], key2); + final_key.extend_from_slice(&second); + final_key +} + +const fn admin_utils_identifier(call_index: u8) -> TransactionIdentifier { + TransactionIdentifier::new(ADMIN_UTILS_PALLET_INDEX, call_index) +} + +const fn subtensor_identifier(call_index: u8) -> TransactionIdentifier { + TransactionIdentifier::new(SUBTENSOR_PALLET_INDEX, call_index) +} + +/// Returns the `TransactionIdentifier` for the admin-utils extrinsic that controls `hparam`. +/// +/// Only hyperparameters that are currently rate-limited (i.e. routed through +/// `ensure_sn_owner_or_root_with_limits`) are mapped; others return `None`. +pub fn identifier_for_hyperparameter(hparam: Hyperparameter) -> Option { + use Hyperparameter::*; + + let identifier = match hparam { + Unknown | MaxWeightLimit => return None, + ServingRateLimit => admin_utils_identifier(3), + MaxDifficulty => admin_utils_identifier(5), + AdjustmentAlpha => admin_utils_identifier(9), + ImmunityPeriod => admin_utils_identifier(13), + MinAllowedWeights => admin_utils_identifier(14), + MaxAllowedUids => admin_utils_identifier(15), + Kappa => admin_utils_identifier(16), + Rho => admin_utils_identifier(17), + ActivityCutoff => admin_utils_identifier(18), + PowRegistrationAllowed => admin_utils_identifier(20), + MinBurn => admin_utils_identifier(22), + MaxBurn => admin_utils_identifier(23), + BondsMovingAverage => admin_utils_identifier(26), + BondsPenalty => admin_utils_identifier(60), + CommitRevealEnabled => admin_utils_identifier(49), + LiquidAlphaEnabled => admin_utils_identifier(50), + AlphaValues => admin_utils_identifier(51), + WeightCommitInterval => admin_utils_identifier(57), + TransferEnabled => admin_utils_identifier(61), + AlphaSigmoidSteepness => admin_utils_identifier(68), + Yuma3Enabled => admin_utils_identifier(69), + BondsResetEnabled => admin_utils_identifier(70), + ImmuneNeuronLimit => admin_utils_identifier(72), + RecycleOrBurn => admin_utils_identifier(80), + _ => return None, + }; + + Some(identifier) +} + +/// Returns the `TransactionIdentifier` for the extrinsic associated with the given transaction +/// type, mirroring current rate-limit enforcement. +pub fn identifier_for_transaction_type(tx: TransactionType) -> Option { + use TransactionType::*; + + let identifier = match tx { + SetChildren => subtensor_identifier(67), + SetChildkeyTake => subtensor_identifier(75), + RegisterNetwork => subtensor_identifier(59), + SetWeightsVersionKey => admin_utils_identifier(6), + SetSNOwnerHotkey => admin_utils_identifier(67), + OwnerHyperparamUpdate(hparam) => return identifier_for_hyperparameter(hparam), + MechanismCountUpdate => admin_utils_identifier(76), + MechanismEmission => admin_utils_identifier(77), + MaxUidsTrimming => admin_utils_identifier(78), + Unknown => return None, + _ => return None, + }; + + Some(identifier) +} + +/// Maps legacy `RateLimitKey` entries to the new usage-key representation. +pub fn usage_key_from_legacy_key( + key: &RateLimitKey, +) -> Option> +where + AccountId: Parameter + Clone, +{ + match key { + RateLimitKey::SetSNOwnerHotkey(netuid) => Some(RateLimitUsageKey::Subnet(*netuid)), + RateLimitKey::OwnerHyperparamUpdate(netuid, _) => Some(RateLimitUsageKey::Subnet(*netuid)), + RateLimitKey::NetworkLastRegistered => None, + RateLimitKey::LastTxBlock(account) + | RateLimitKey::LastTxBlockChildKeyTake(account) + | RateLimitKey::LastTxBlockDelegateTake(account) => { + Some(RateLimitUsageKey::Account(account.clone())) + } + } +} + +/// Produces the usage key for a `TransactionType` that was stored in `TransactionKeyLastBlock`. +pub fn usage_key_from_transaction_type( + tx: TransactionType, + account: &AccountId, + netuid: NetUid, +) -> Option> +where + AccountId: Parameter + Clone, +{ + match tx { + TransactionType::SetChildren | TransactionType::SetChildkeyTake => { + Some(RateLimitUsageKey::AccountSubnet { + account: account.clone(), + netuid, + }) + } + TransactionType::SetWeightsVersionKey => Some(RateLimitUsageKey::Subnet(netuid)), + TransactionType::MechanismCountUpdate + | TransactionType::MechanismEmission + | TransactionType::MaxUidsTrimming => Some(RateLimitUsageKey::AccountSubnet { + account: account.clone(), + netuid, + }), + TransactionType::OwnerHyperparamUpdate(_) => Some(RateLimitUsageKey::Subnet(netuid)), + TransactionType::RegisterNetwork => None, + TransactionType::SetSNOwnerHotkey => Some(RateLimitUsageKey::Subnet(netuid)), + TransactionType::Unknown => None, + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn maps_hyperparameters() { + assert_eq!( + identifier_for_hyperparameter(Hyperparameter::ServingRateLimit), + Some(admin_utils_identifier(3)) + ); + assert!(identifier_for_hyperparameter(Hyperparameter::MaxWeightLimit).is_none()); + } + + #[test] + fn maps_transaction_types() { + assert_eq!( + identifier_for_transaction_type(TransactionType::SetChildren), + Some(subtensor_identifier(67)) + ); + assert!(identifier_for_transaction_type(TransactionType::Unknown).is_none()); + } + + #[test] + fn maps_usage_keys() { + let acct = 42u64; + assert!(matches!( + usage_key_from_legacy_key(&RateLimitKey::LastTxBlock(acct)), + Some(RateLimitUsageKey::Account(42)) + )); + } +} diff --git a/runtime/src/rate_limiting/mod.rs b/runtime/src/rate_limiting/mod.rs new file mode 100644 index 0000000000..713c8bacf6 --- /dev/null +++ b/runtime/src/rate_limiting/mod.rs @@ -0,0 +1,261 @@ +use frame_system::RawOrigin; +use pallet_admin_utils::Call as AdminUtilsCall; +use pallet_rate_limiting::{RateLimitScopeResolver, RateLimitUsageResolver}; +use pallet_subtensor::{Call as SubtensorCall, Tempo}; +use subtensor_runtime_common::{BlockNumber, NetUid, RateLimitScope, RateLimitUsageKey}; + +use crate::{AccountId, Runtime, RuntimeCall, RuntimeOrigin}; + +pub(crate) mod migration; + +fn signed_origin(origin: &RuntimeOrigin) -> Option { + match origin.clone().into() { + Ok(RawOrigin::Signed(who)) => Some(who), + _ => None, + } +} + +fn tempo_scaled(netuid: NetUid, span: BlockNumber) -> BlockNumber { + if span == 0 { + return span; + } + let tempo = BlockNumber::from(Tempo::::get(netuid) as u32); + span.saturating_mul(tempo) +} + +fn neuron_identity(origin: &RuntimeOrigin, netuid: NetUid) -> Option<(AccountId, u16)> { + let hotkey = signed_origin(origin)?; + let uid = + pallet_subtensor::Pallet::::get_uid_for_net_and_hotkey(netuid, &hotkey).ok()?; + Some((hotkey, uid)) +} + +fn owner_hparam_netuid(call: &AdminUtilsCall) -> Option { + match call { + AdminUtilsCall::sudo_set_activity_cutoff { netuid, .. } + | AdminUtilsCall::sudo_set_adjustment_alpha { netuid, .. } + | AdminUtilsCall::sudo_set_alpha_sigmoid_steepness { netuid, .. } + | AdminUtilsCall::sudo_set_alpha_values { netuid, .. } + | AdminUtilsCall::sudo_set_bonds_moving_average { netuid, .. } + | AdminUtilsCall::sudo_set_bonds_penalty { netuid, .. } + | AdminUtilsCall::sudo_set_bonds_reset_enabled { netuid, .. } + | AdminUtilsCall::sudo_set_commit_reveal_weights_enabled { netuid, .. } + | AdminUtilsCall::sudo_set_commit_reveal_weights_interval { netuid, .. } + | AdminUtilsCall::sudo_set_immunity_period { netuid, .. } + | AdminUtilsCall::sudo_set_liquid_alpha_enabled { netuid, .. } + | AdminUtilsCall::sudo_set_max_allowed_uids { netuid, .. } + | AdminUtilsCall::sudo_set_max_burn { netuid, .. } + | AdminUtilsCall::sudo_set_max_difficulty { netuid, .. } + | AdminUtilsCall::sudo_set_min_allowed_weights { netuid, .. } + | AdminUtilsCall::sudo_set_min_burn { netuid, .. } + | AdminUtilsCall::sudo_set_network_pow_registration_allowed { netuid, .. } + | AdminUtilsCall::sudo_set_owner_immune_neuron_limit { netuid, .. } + | AdminUtilsCall::sudo_set_recycle_or_burn { netuid, .. } + | AdminUtilsCall::sudo_set_rho { netuid, .. } + | AdminUtilsCall::sudo_set_serving_rate_limit { netuid, .. } + | AdminUtilsCall::sudo_set_sn_owner_hotkey { netuid, .. } + | AdminUtilsCall::sudo_set_toggle_transfer { netuid, .. } + | AdminUtilsCall::sudo_set_weights_version_key { netuid, .. } + | AdminUtilsCall::sudo_set_yuma3_enabled { netuid, .. } => Some(*netuid), + _ => None, + } +} + +fn admin_scope_netuid(call: &AdminUtilsCall) -> Option { + owner_hparam_netuid(call).or_else(|| match call { + AdminUtilsCall::sudo_set_mechanism_count { netuid, .. } + | AdminUtilsCall::sudo_set_mechanism_emission_split { netuid, .. } + | AdminUtilsCall::sudo_trim_to_max_allowed_uids { netuid, .. } => Some(*netuid), + _ => None, + }) +} + +#[derive(Default)] +pub struct UsageResolver; + +impl RateLimitUsageResolver> + for UsageResolver +{ + fn context(origin: &RuntimeOrigin, call: &RuntimeCall) -> Option> { + match call { + RuntimeCall::SubtensorModule(inner) => match inner { + SubtensorCall::swap_hotkey { .. } => { + signed_origin(origin).map(RateLimitUsageKey::::Account) + } + SubtensorCall::register_network { .. } + | SubtensorCall::register_network_with_identity { .. } => { + signed_origin(origin).map(RateLimitUsageKey::::Account) + } + SubtensorCall::increase_take { hotkey, .. } => { + Some(RateLimitUsageKey::::Account(hotkey.clone())) + } + SubtensorCall::set_childkey_take { hotkey, netuid, .. } + | SubtensorCall::set_children { hotkey, netuid, .. } => { + Some(RateLimitUsageKey::::AccountSubnet { + account: hotkey.clone(), + netuid: *netuid, + }) + } + SubtensorCall::set_weights { netuid, .. } + | SubtensorCall::commit_weights { netuid, .. } + | SubtensorCall::reveal_weights { netuid, .. } + | SubtensorCall::batch_reveal_weights { netuid, .. } + | SubtensorCall::commit_timelocked_weights { netuid, .. } => { + let (_, uid) = neuron_identity(origin, *netuid)?; + Some(RateLimitUsageKey::::SubnetNeuron { + netuid: *netuid, + uid, + }) + } + SubtensorCall::set_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::commit_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::reveal_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::commit_crv3_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::commit_timelocked_mechanism_weights { netuid, mecid, .. } => { + let (_, uid) = neuron_identity(origin, *netuid)?; + Some(RateLimitUsageKey::::SubnetMechanismNeuron { + netuid: *netuid, + mecid: *mecid, + uid, + }) + } + SubtensorCall::serve_axon { netuid, .. } + | SubtensorCall::serve_axon_tls { netuid, .. } + | SubtensorCall::serve_prometheus { netuid, .. } => { + let hotkey = signed_origin(origin)?; + Some(RateLimitUsageKey::::AccountSubnet { + account: hotkey, + netuid: *netuid, + }) + } + SubtensorCall::associate_evm_key { netuid, .. } => { + let hotkey = signed_origin(origin)?; + let uid = pallet_subtensor::Pallet::::get_uid_for_net_and_hotkey( + *netuid, &hotkey, + ) + .ok()?; + Some(RateLimitUsageKey::::SubnetNeuron { + netuid: *netuid, + uid, + }) + } + SubtensorCall::add_stake { hotkey, netuid, .. } + | SubtensorCall::add_stake_limit { hotkey, netuid, .. } + | SubtensorCall::remove_stake { hotkey, netuid, .. } + | SubtensorCall::remove_stake_limit { hotkey, netuid, .. } + | SubtensorCall::remove_stake_full_limit { hotkey, netuid, .. } + | SubtensorCall::transfer_stake { + hotkey, + origin_netuid: netuid, + .. + } + | SubtensorCall::swap_stake { + hotkey, + origin_netuid: netuid, + .. + } + | SubtensorCall::swap_stake_limit { + hotkey, + origin_netuid: netuid, + .. + } + | SubtensorCall::move_stake { + origin_hotkey: hotkey, + origin_netuid: netuid, + .. + } + | SubtensorCall::recycle_alpha { hotkey, netuid, .. } + | SubtensorCall::burn_alpha { hotkey, netuid, .. } => { + let coldkey = signed_origin(origin)?; + Some(RateLimitUsageKey::::ColdkeyHotkeySubnet { + coldkey, + hotkey: hotkey.clone(), + netuid: *netuid, + }) + } + _ => None, + }, + RuntimeCall::AdminUtils(inner) => { + if let Some(netuid) = owner_hparam_netuid(inner) { + // Hyperparameter setters share a global span but are tracked per subnet. + Some(RateLimitUsageKey::::Subnet(netuid)) + } else { + match inner { + AdminUtilsCall::sudo_set_mechanism_count { netuid, .. } + | AdminUtilsCall::sudo_set_mechanism_emission_split { netuid, .. } + | AdminUtilsCall::sudo_trim_to_max_allowed_uids { netuid, .. } => { + let who = signed_origin(origin)?; + Some(RateLimitUsageKey::::AccountSubnet { + account: who, + netuid: *netuid, + }) + } + _ => None, + } + } + } + _ => None, + } + } +} + +#[derive(Default)] +pub struct ScopeResolver; + +impl RateLimitScopeResolver + for ScopeResolver +{ + fn context(_origin: &RuntimeOrigin, call: &RuntimeCall) -> Option { + match call { + RuntimeCall::SubtensorModule(inner) => match inner { + SubtensorCall::serve_axon { netuid, .. } + | SubtensorCall::serve_axon_tls { netuid, .. } + | SubtensorCall::serve_prometheus { netuid, .. } + | SubtensorCall::set_weights { netuid, .. } + | SubtensorCall::commit_weights { netuid, .. } + | SubtensorCall::reveal_weights { netuid, .. } + | SubtensorCall::batch_reveal_weights { netuid, .. } + | SubtensorCall::commit_timelocked_weights { netuid, .. } => { + Some(RateLimitScope::Subnet(*netuid)) + } + SubtensorCall::set_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::commit_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::reveal_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::commit_crv3_mechanism_weights { netuid, mecid, .. } + | SubtensorCall::commit_timelocked_mechanism_weights { netuid, mecid, .. } => { + Some(RateLimitScope::SubnetMechanism { + netuid: *netuid, + mecid: *mecid, + }) + } + _ => None, + }, + RuntimeCall::AdminUtils(inner) => { + if owner_hparam_netuid(inner).is_some() { + // Hyperparameter setters share a global limit span; usage is tracked per subnet. + None + } else { + admin_scope_netuid(inner).map(RateLimitScope::Subnet) + } + } + _ => None, + } + } + + fn should_bypass(origin: &RuntimeOrigin, _call: &RuntimeCall) -> bool { + matches!(origin.clone().into(), Ok(RawOrigin::Root)) + } + + fn adjust_span(_origin: &RuntimeOrigin, call: &RuntimeCall, span: BlockNumber) -> BlockNumber { + match call { + RuntimeCall::AdminUtils(inner) => { + if let Some(netuid) = owner_hparam_netuid(inner) { + tempo_scaled(netuid, span) + } else { + span + } + } + _ => span, + } + } +}