diff --git a/Cargo.lock b/Cargo.lock index 722073ab859eb..56c2a9ae20186 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23701,4 +23701,4 @@ checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" dependencies = [ "cc", "pkg-config", -] +] \ No newline at end of file diff --git a/crates/forge/src/runner.rs b/crates/forge/src/runner.rs index 15d22fa1acc11..4c749136f3bcf 100644 --- a/crates/forge/src/runner.rs +++ b/crates/forge/src/runner.rs @@ -544,13 +544,15 @@ impl<'a> FunctionRunner<'a> { /// State modifications of before test txes and unit test function call are discarded after /// test ends, similar to `eth_call`. fn run_unit_test(mut self, func: &Function) -> TestResult { - let binding = self.executor.clone().into_owned(); - self.executor = Cow::Owned(binding); // Prepare unit test execution. + self.executor.strategy.runner.start_transaction(self.executor.strategy.context.as_ref()); if self.prepare_test(func).is_err() { + self.executor + .strategy + .runner + .rollback_transaction(self.executor.strategy.context.as_ref()); return self.result; } - self.executor.strategy.runner.start_transaction(self.executor.strategy.context.as_ref()); // Run current unit test. let (mut raw_call_result, reason) = match self.executor.call( @@ -580,11 +582,11 @@ impl<'a> FunctionRunner<'a> { return self.result; } }; - self.executor.strategy.runner.rollback_transaction(self.executor.strategy.context.as_ref()); let success = self.executor.is_raw_call_mut_success(self.address, &mut raw_call_result, false); self.result.single_result(success, reason, raw_call_result); + self.executor.strategy.runner.rollback_transaction(self.executor.strategy.context.as_ref()); self.result } @@ -598,8 +600,6 @@ impl<'a> FunctionRunner<'a> { /// - `bool[] public fixtureSwap = [true, false]` The `table_test` is then called with the pair /// of args `(2, true)` and `(5, false)`. fn run_table_test(mut self, func: &Function) -> TestResult { - let binding = self.executor.clone().into_owned(); - self.executor = Cow::Owned(binding); // Prepare unit test execution. if self.prepare_test(func).is_err() { return self.result; } @@ -729,8 +729,6 @@ impl<'a> FunctionRunner<'a> { identified_contracts: &ContractsByAddress, test_bytecode: &Bytes, ) -> TestResult { - let binding = self.executor.clone().into_owned(); - self.executor = Cow::Owned(binding); // First, run the test normally to see if it needs to be skipped. if let Err(EvmError::Skip(reason)) = self.executor.call( self.sender, diff --git a/crates/forge/tests/it/revive/cheat_snapshot.rs b/crates/forge/tests/it/revive/cheat_snapshot.rs new file mode 100644 index 0000000000000..36b04cabc91f3 --- /dev/null +++ b/crates/forge/tests/it/revive/cheat_snapshot.rs @@ -0,0 +1,16 @@ +use crate::{config::*, test_helpers::TEST_DATA_REVIVE}; +use foundry_test_utils::Filter; +use revive_strategy::ReviveRuntimeMode; +use revm::primitives::hardfork::SpecId; +use rstest::rstest; + +#[rstest] +#[case::pvm(ReviveRuntimeMode::Pvm)] +#[case::evm(ReviveRuntimeMode::Evm)] +#[tokio::test(flavor = "multi_thread")] +async fn test_snapshot_cheats(#[case] runtime_mode: ReviveRuntimeMode) { + let runner: forge::MultiContractRunner = TEST_DATA_REVIVE.runner_revive(runtime_mode); + let filter = Filter::new(".*", "StateSnapshotTest", ".*/revive/.*"); + + TestConfig::with_filter(runner, filter).spec_id(SpecId::PRAGUE).run().await; +} diff --git a/crates/forge/tests/it/revive/mod.rs b/crates/forge/tests/it/revive/mod.rs index 48ad437e3ff9e..5045bcc758951 100644 --- a/crates/forge/tests/it/revive/mod.rs +++ b/crates/forge/tests/it/revive/mod.rs @@ -6,6 +6,7 @@ pub mod cheat_mock_call; pub mod cheat_mock_calls; pub mod cheat_mock_functions; pub mod cheat_prank; +mod cheat_snapshot; pub mod cheat_store; pub mod cheats_individual; pub mod migration; diff --git a/crates/revive-strategy/src/cheatcodes/mod.rs b/crates/revive-strategy/src/cheatcodes/mod.rs index d6974da958ad2..cfe02e5868801 100644 --- a/crates/revive-strategy/src/cheatcodes/mod.rs +++ b/crates/revive-strategy/src/cheatcodes/mod.rs @@ -9,7 +9,8 @@ use foundry_cheatcodes::{ CommonCreateInput, Ecx, EvmCheatcodeInspectorStrategyRunner, Result, Vm::{ chainIdCall, coinbaseCall, dealCall, etchCall, getNonce_0Call, loadCall, pvmCall, - resetNonceCall, rollCall, setNonceCall, setNonceUnsafeCall, storeCall, warpCall, + resetNonceCall, revertToStateAndDeleteCall, revertToStateCall, rollCall, setNonceCall, + setNonceUnsafeCall, snapshotStateCall, storeCall, warpCall, }, journaled_account, precompile_error, }; @@ -319,6 +320,23 @@ impl CheatcodeInspectorStrategyRunner for PvmCheatcodeInspectorStrategyRunner { cheatcode.dyn_apply(ccx, executor) } + t if using_pvm && is::(t) => { + ctx.externalities.start_snapshotting(); + cheatcode.dyn_apply(ccx, executor) + } + t if using_pvm && is::(t) => { + let &revertToStateAndDeleteCall { snapshotId } = + cheatcode.as_any().downcast_ref().unwrap(); + + ctx.externalities.revert(snapshotId.try_into().unwrap()); + cheatcode.dyn_apply(ccx, executor) + } + t if using_pvm && is::(t) => { + let &revertToStateCall { snapshotId } = cheatcode.as_any().downcast_ref().unwrap(); + + ctx.externalities.revert(snapshotId.try_into().unwrap()); + cheatcode.dyn_apply(ccx, executor) + } t if using_pvm && is::(t) => { let &warpCall { newTimestamp } = cheatcode.as_any().downcast_ref().unwrap(); diff --git a/crates/revive-strategy/src/executor/runner.rs b/crates/revive-strategy/src/executor/runner.rs index 4d86635c3f83e..ce6c0cdee0331 100644 --- a/crates/revive-strategy/src/executor/runner.rs +++ b/crates/revive-strategy/src/executor/runner.rs @@ -157,12 +157,12 @@ impl ExecutorStrategyExt for ReviveExecutorStrategyRunner { fn start_transaction(&self, ctx: &dyn ExecutorStrategyContext) { let ctx = get_context_ref(ctx); let mut externalities = ctx.externalties.0.lock().unwrap(); - externalities.ext().storage_start_transaction(); + externalities.externalities.ext().storage_start_transaction(); } fn rollback_transaction(&self, ctx: &dyn ExecutorStrategyContext) { let ctx = get_context_ref(ctx); - let mut externalities = ctx.externalties.0.lock().unwrap(); - externalities.ext().storage_rollback_transaction().unwrap(); + let mut state = ctx.externalties.0.lock().unwrap(); + let _ = state.externalities.ext().storage_rollback_transaction(); } } diff --git a/crates/revive-strategy/src/state.rs b/crates/revive-strategy/src/state.rs index 2207459c35c11..e0e780111239e 100644 --- a/crates/revive-strategy/src/state.rs +++ b/crates/revive-strategy/src/state.rs @@ -6,6 +6,7 @@ use polkadot_sdk::{ Executable, Pallet, }, sp_core::{self, H160}, + sp_externalities::Externalities, sp_io::TestExternalities, }; use revive_env::{AccountId, BlockAuthor, ExtBuilder, Runtime, System, Timestamp}; @@ -13,15 +14,23 @@ use std::{ fmt::Debug, sync::{Arc, Mutex}, }; -pub struct TestEnv(pub Arc>); -impl Default for TestEnv { +pub(crate) struct Inner { + pub externalities: TestExternalities, + pub depth: usize, +} + +#[derive(Default)] +pub struct TestEnv(pub(crate) Arc>); + +impl Default for Inner { fn default() -> Self { - Self(Arc::new(Mutex::new( - ExtBuilder::default() + Self { + externalities: ExtBuilder::default() .balance_genesis_config(vec![(H160::from_low_u64_be(1), 1000)]) .build(), - ))) + depth: 0, + } } } @@ -33,9 +42,10 @@ impl Debug for TestEnv { impl Clone for TestEnv { fn clone(&self) -> Self { - let mut externalities = ExtBuilder::default().build(); - externalities.backend = self.0.lock().unwrap().as_backend(); - Self(Arc::new(Mutex::new(externalities))) + let mut inner: Inner = Default::default(); + inner.externalities.backend = self.0.lock().unwrap().externalities.as_backend(); + inner.depth = self.0.lock().unwrap().depth; + Self(Arc::new(Mutex::new(inner))) } } @@ -44,12 +54,28 @@ impl TestEnv { Self(self.0.clone()) } + pub fn start_snapshotting(&mut self) { + let mut state = self.0.lock().unwrap(); + state.depth += 1; + state.externalities.ext().storage_start_transaction(); + } + + pub fn revert(&mut self, depth: usize) { + let mut state = self.0.lock().unwrap(); + while state.depth > depth + 1 { + state.externalities.ext().storage_rollback_transaction().unwrap(); + state.depth -= 1; + } + state.externalities.ext().storage_rollback_transaction().unwrap(); + state.externalities.ext().storage_start_transaction(); + } + pub fn execute_with R>(&mut self, f: F) -> R { - self.0.lock().unwrap().execute_with(f) + self.0.lock().unwrap().externalities.execute_with(f) } pub fn get_nonce(&mut self, account: Address) -> u32 { - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { System::account_nonce(AccountId::to_fallback_account_id(&H160::from_slice( account.as_slice(), ))) @@ -57,7 +83,7 @@ impl TestEnv { } pub fn set_nonce(&mut self, address: Address, nonce: u64) { - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { let account_id = AccountId::to_fallback_account_id(&H160::from_slice(address.as_slice())); @@ -69,7 +95,7 @@ impl TestEnv { pub fn set_chain_id(&mut self, new_chain_id: u64) { // Set chain id in pallet-revive runtime. - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { ::ChainId::set( &new_chain_id, ); @@ -78,14 +104,14 @@ impl TestEnv { pub fn set_block_number(&mut self, new_height: U256) { // Set block number in pallet-revive runtime. - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { System::set_block_number(new_height.try_into().expect("Block number exceeds u64")); }); } pub fn set_timestamp(&mut self, new_timestamp: U256) { // Set timestamp in pallet-revive runtime (milliseconds). - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { let timestamp_ms = new_timestamp.saturating_to::().saturating_mul(1000); Timestamp::set_timestamp(timestamp_ms); }); @@ -97,7 +123,7 @@ impl TestEnv { new_runtime_code: &Bytes, ecx: Ecx<'_, '_, '_>, ) -> Result { - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { let origin_address = H160::from_slice(ecx.tx.caller.as_slice()); let origin_account = AccountId::to_fallback_account_id(&origin_address); @@ -153,6 +179,7 @@ impl TestEnv { self.0 .lock() .unwrap() + .externalities .execute_with(|| { pallet_revive::Pallet::::get_storage(target_address_h160, slot.into()) }) @@ -169,6 +196,7 @@ impl TestEnv { self.0 .lock() .unwrap() + .externalities .execute_with(|| { pallet_revive::Pallet::::set_storage( target_address_h160, @@ -184,7 +212,7 @@ impl TestEnv { let amount_pvm = sp_core::U256::from_little_endian(&amount.as_le_bytes()).min(u128::MAX.into()); - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { let h160_addr = H160::from_slice(address.as_slice()); pallet_revive::Pallet::::set_evm_balance(&h160_addr, amount_pvm) .expect("failed to set evm balance"); @@ -195,6 +223,7 @@ impl TestEnv { self.0 .lock() .unwrap() + .externalities .execute_with(|| { let h160_addr = H160::from_slice(address.as_slice()); pallet_revive::Pallet::::evm_balance(&h160_addr) @@ -204,7 +233,7 @@ impl TestEnv { } pub fn set_block_author(&mut self, new_author: Address) { - self.0.lock().unwrap().execute_with(|| { + self.0.lock().unwrap().externalities.execute_with(|| { let account_id32 = AccountId::to_fallback_account_id(&H160::from_slice(new_author.as_slice())); BlockAuthor::set(&account_id32); diff --git a/testdata/default/revive/Snapshot.t.sol b/testdata/default/revive/Snapshot.t.sol new file mode 100644 index 0000000000000..4e3928967a9eb --- /dev/null +++ b/testdata/default/revive/Snapshot.t.sol @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +pragma solidity ^0.8.18; + +import "ds-test/test.sol"; +import "cheats/Vm.sol"; + +contract Storage { + uint256 public slot0; + uint256 public slot1; + + function setSlots(uint256 a, uint256 b) public { + slot0 = a; + slot1 = b; + } + + function blockNumber() public returns (uint256) { + return block.number; + } + + function blockTimestamp() public returns (uint256) { + return block.timestamp; + } +} + +contract StateSnapshotTest is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + + Storage store; + + function setUp() public { + store = new Storage(); + store.setSlots(10, 20); + } + + function testStateSnapshot() public { + uint256 snapshotId = vm.snapshotState(); + store.setSlots(300, 400); + + assertEq(store.slot0(), 300); + assertEq(store.slot1(), 400); + + vm.revertToState(snapshotId); + assertEq(store.slot0(), 10, "snapshot revert for slot 0 unsuccessful"); + assertEq(store.slot1(), 20, "snapshot revert for slot 1 unsuccessful"); + } + + function testStateSnapshot2() public { + uint256 snapshotId = vm.snapshotState(); + store.setSlots(300, 400); + + assertEq(store.slot0(), 300); + assertEq(store.slot1(), 400); + + uint256 snapshotId2 = vm.snapshotState(); + store.setSlots(500, 600); + + assertEq(store.slot0(), 500); + assertEq(store.slot1(), 600); + + uint256 snapshotId3 = vm.snapshotState(); + store.setSlots(700, 800); + + assertEq(store.slot0(), 700); + assertEq(store.slot1(), 800); + + uint256 snapshotId4 = vm.snapshotState(); + store.setSlots(800, 900); + + assertEq(store.slot0(), 800); + assertEq(store.slot1(), 900); + + vm.revertToState(snapshotId4); + assertEq(store.slot0(), 700, "snapshot revert for slot 0 unsuccessful"); + assertEq(store.slot1(), 800, "snapshot revert for slot 1 unsuccessful"); + + vm.revertToState(snapshotId3); + assertEq(store.slot0(), 500, "snapshot revert for slot 0 unsuccessful"); + assertEq(store.slot1(), 600, "snapshot revert for slot 1 unsuccessful"); + + vm.revertToState(snapshotId2); + assertEq(store.slot0(), 300, "snapshot revert for slot 0 unsuccessful"); + assertEq(store.slot1(), 400, "snapshot revert for slot 1 unsuccessful"); + + vm.revertToState(snapshotId); + assertEq(store.slot0(), 10, "snapshot revert for slot 0 unsuccessful"); + assertEq(store.slot1(), 20, "snapshot revert for slot 1 unsuccessful"); + } + + function testStateSnapshotRevertDelete() public { + uint256 snapshotId = vm.snapshotState(); + store.setSlots(300, 400); + + assertEq(store.slot0(), 300); + assertEq(store.slot1(), 400); + + vm.revertToStateAndDelete(snapshotId); + assertEq(store.slot0(), 10, "snapshot revert for slot 0 unsuccessful"); + assertEq(store.slot1(), 20, "snapshot revert for slot 1 unsuccessful"); + // nothing to revert to anymore + assert(!vm.revertToState(snapshotId)); + } + + function testStateSnapshotDelete() public { + uint256 snapshotId = vm.snapshotState(); + store.setSlots(300, 400); + vm.deleteStateSnapshot(snapshotId); + // nothing to revert to anymore + assert(!vm.revertToState(snapshotId)); + } + + function testStateSnapshotDeleteAll() public { + uint256 snapshotId = vm.snapshotState(); + store.setSlots(300, 400); + vm.deleteStateSnapshots(); + // nothing to revert to anymore + assert(!vm.revertToState(snapshotId)); + } + + // + function testStateSnapshotsMany() public { + uint256 snapshotId; + for (uint256 c = 0; c < 10; c++) { + for (uint256 cc = 0; cc < 10; cc++) { + snapshotId = vm.snapshotState(); + vm.revertToStateAndDelete(snapshotId); + assert(!vm.revertToState(snapshotId)); + } + } + } + + // tests that snapshots can also revert changes to `block` + function testBlockValues() public { + uint256 num = store.blockNumber(); + uint256 time = store.blockTimestamp(); + + uint256 snapshotId = vm.snapshotState(); + Storage store2 = new Storage(); + store2.setSlots(300, 400); + + assertEq(store2.slot0(), 300); + assertEq(store2.slot1(), 400); + + vm.warp(1337); + assertEq(store.blockTimestamp(), 1337); + + vm.roll(99); + assertEq(store.blockNumber(), 99); + + assert(vm.revertToState(snapshotId)); + + assertEq(store.blockNumber(), num, "snapshot revert for block.number unsuccessful"); + assertEq(store.blockTimestamp(), time, "snapshot revert for block.timestamp unsuccessful"); + } +}