diff --git a/crates/forge/src/runner.rs b/crates/forge/src/runner.rs index 0606302282cdd..9389ad31521c5 100644 --- a/crates/forge/src/runner.rs +++ b/crates/forge/src/runner.rs @@ -66,6 +66,8 @@ pub struct ContractRunner<'a> { tokio_handle: &'a tokio::runtime::Handle, /// The span of the contract. span: tracing::Span, + /// The start time of the test run. + start_time: Instant, /// The contract-level configuration. tcfg: Cow<'a, TestRunnerConfig>, /// The parent runner. @@ -98,6 +100,7 @@ impl<'a> ContractRunner<'a> { progress, tokio_handle, span, + start_time: Instant::now(), // Will be overwritten in `run_tests` tcfg: Cow::Borrowed(&mcr.tcfg), mcr, } @@ -283,58 +286,28 @@ impl<'a> ContractRunner<'a> { /// Runs all tests for a contract whose names match the provided regular expression pub fn run_tests(mut self, filter: &dyn TestFilter) -> SuiteResult { - let start = Instant::now(); + self.start_time = Instant::now(); let mut warnings = Vec::new(); - // Check if `setUp` function with valid signature declared. - let setup_fns: Vec<_> = - self.contract.abi.functions().filter(|func| func.name.is_setup()).collect(); - let call_setup = setup_fns.len() == 1 && setup_fns[0].name == "setUp"; - // There is a single miss-cased `setUp` function, so we add a warning - for &setup_fn in &setup_fns { - if setup_fn.name != "setUp" { - warnings.push(format!( - "Found invalid setup function \"{}\" did you mean \"setUp()\"?", - setup_fn.signature() - )); - } - } + let functions = self.contract.abi.functions(); - // There are multiple setUp function, so we return a single test result for `setUp` - if setup_fns.len() > 1 { - return SuiteResult::new( - start.elapsed(), - [("setUp()".to_string(), TestResult::fail("multiple setUp functions".to_string()))] - .into(), - warnings, - ); - } + let call_setup = match self.validate_special_function( + functions.clone().filter(|func| func.name.is_setup()), + "setUp", + &mut warnings, + ) { + Ok(call) => call, + Err(res) => return res, + }; - // Check if `afterInvariant` function with valid signature declared. - let after_invariant_fns: Vec<_> = - self.contract.abi.functions().filter(|func| func.name.is_after_invariant()).collect(); - if after_invariant_fns.len() > 1 { - // Return a single test result failure if multiple functions declared. - return SuiteResult::new( - start.elapsed(), - [( - "afterInvariant()".to_string(), - TestResult::fail("multiple afterInvariant functions".to_string()), - )] - .into(), - warnings, - ); - } - let call_after_invariant = after_invariant_fns.first().is_some_and(|after_invariant_fn| { - let match_sig = after_invariant_fn.name == "afterInvariant"; - if !match_sig { - warnings.push(format!( - "Found invalid afterInvariant function \"{}\" did you mean \"afterInvariant()\"?", - after_invariant_fn.signature() - )); - } - match_sig - }); + let call_after_invariant = match self.validate_special_function( + functions.clone().filter(|func| func.name.is_after_invariant()), + "afterInvariant", + &mut warnings, + ) { + Ok(call) => call, + Err(res) => return res, + }; // Invariant testing requires tracing to figure out what contracts were created. // We also want to disable `debug` for setup since we won't be using those traces. @@ -359,7 +332,7 @@ impl<'a> ContractRunner<'a> { "constructor()".to_string() }; return SuiteResult::new( - start.elapsed(), + self.start_time.elapsed(), [(fail_msg, TestResult::setup_result(setup))].into(), warnings, ); @@ -392,7 +365,7 @@ impl<'a> ContractRunner<'a> { TestResult::fail("`testFail*` has been removed. Consider changing to test_Revert[If|When]_Condition and expecting a revert".to_string()) }; let test_results = test_fail_functions.map(|func| (func.signature(), fail())).collect(); - return SuiteResult::new(start.elapsed(), test_results, warnings); + return SuiteResult::new(self.start_time.elapsed(), test_results, warnings); } let fail_fast = &self.tcfg.fail_fast; @@ -442,9 +415,47 @@ impl<'a> ContractRunner<'a> { }) .collect::>(); - let duration = start.elapsed(); + let duration = self.start_time.elapsed(); SuiteResult::new(duration, test_results, warnings) } + + /// Validates the presence and signature of a special function like `setUp` or `afterInvariant`. + fn validate_special_function<'b>( + &self, + functions: impl IntoIterator, + expected_name: &str, + warnings: &mut Vec, + ) -> Result { + let functions: Vec<_> = functions.into_iter().collect(); + + // Error if more than one function is found + if functions.len() > 1 { + return Err(SuiteResult::new( + self.start_time.elapsed(), + [( + format!("{expected_name}()"), + TestResult::fail(format!("multiple {expected_name} functions")), + )] + .into(), + warnings.clone(), + )); + } + + let Some(func) = functions.first() else { + return Ok(false); + }; + + // Add a warning if the function name is misspelled (e.g., `setup` instead of `setUp`) + if func.name != expected_name { + warnings.push(format!( + r#"Found invalid function "{}". Did you mean "{expected_name}()"?"#, + func.signature() + )); + Ok(false) + } else { + Ok(true) + } + } } /// Executes a single test function, returning a [`TestResult`].