Skip to content

Commit f085357

Browse files
authored
fix(forge): dynamic test linking for try catch with custom return (#12050)
1 parent 8bab682 commit f085357

File tree

2 files changed

+131
-9
lines changed

2 files changed

+131
-9
lines changed

crates/common/src/preprocessor/deps.rs

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ enum BytecodeDependencyKind {
100100
value: Option<String>,
101101
/// `salt` (if any) used when creating contract.
102102
salt: Option<String>,
103-
/// Whether it's a try contract creation statement.
104-
try_stmt: bool,
103+
/// Whether it's a try contract creation statement, with custom return.
104+
try_stmt: Option<bool>,
105105
},
106106
}
107107

@@ -173,7 +173,6 @@ impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
173173
call_expr,
174174
call_args,
175175
named_args,
176-
false,
177176
) {
178177
self.collect_dependency(dependency);
179178
}
@@ -199,17 +198,30 @@ impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
199198
fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
200199
if let StmtKind::Try(stmt_try) = stmt.kind
201200
&& let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
202-
&& let Some(dependency) = handle_call_expr(
201+
&& let Some(mut dependency) = handle_call_expr(
203202
self.src,
204203
self.gcx.sess.source_map(),
205204
&stmt_try.expr,
206205
call_expr,
207206
call_args,
208207
named_args,
209-
true,
210208
)
211209
{
210+
let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
211+
&& clause.args.len() == 1
212+
&& let Some(ret_var) = clause.args.first()
213+
&& let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
214+
{
215+
true
216+
} else {
217+
false
218+
};
219+
220+
if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
221+
*try_stmt = Some(has_custom_return);
222+
}
212223
self.collect_dependency(dependency);
224+
213225
for clause in stmt_try.clauses {
214226
for &var in clause.args {
215227
self.visit_nested_var(var)?;
@@ -232,7 +244,6 @@ fn handle_call_expr(
232244
call_expr: &Expr<'_>,
233245
call_args: &CallArgs<'_>,
234246
named_args: &Option<&[NamedArg<'_>]>,
235-
try_stmt: bool,
236247
) -> Option<BytecodeDependency> {
237248
if let ExprKind::New(ty_new) = &call_expr.kind
238249
&& let TypeKind::Custom(item_id) = ty_new.kind
@@ -258,7 +269,7 @@ fn handle_call_expr(
258269
call_args_offset,
259270
value: named_arg(src, named_args, "value", source_map),
260271
salt: named_arg(src, named_args, "salt", source_map),
261-
try_stmt,
272+
try_stmt: None,
262273
},
263274
loc: span_to_range(source_map, call_expr.span),
264275
referenced_contract: contract_id,
@@ -284,6 +295,17 @@ fn named_arg(
284295

285296
/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
286297
/// invocations.
298+
///
299+
/// Special handling of try/catch statements with custom returns, where the try statement becomes
300+
/// ```solidity
301+
/// try this.addressToCounter() returns (Counter c)
302+
/// ```
303+
/// and helper to cast address is appended
304+
/// ```solidity
305+
/// function addressToCounter(address addr) returns (Counter) {
306+
/// return Counter(addr);
307+
/// }
308+
/// ```
287309
pub(crate) fn remove_bytecode_dependencies(
288310
gcx: Gcx<'_>,
289311
deps: &PreprocessorDependencies,
@@ -303,6 +325,7 @@ pub(crate) fn remove_bytecode_dependencies(
303325
let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
304326
// `address(uint160(uint256(keccak256("hevm cheat code"))))`
305327
let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
328+
let mut try_catch_helpers: HashSet<&str> = HashSet::default();
306329

307330
for dep in deps {
308331
let Some(ContractData { artifact, constructor_data, .. }) =
@@ -328,8 +351,14 @@ pub(crate) fn remove_bytecode_dependencies(
328351
salt,
329352
try_stmt,
330353
} => {
331-
let (mut update, closing_seq) = if *try_stmt {
332-
(String::new(), "})")
354+
let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
355+
if *has_ret {
356+
// try this.addressToCounter1() returns (Counter c)
357+
try_catch_helpers.insert(name);
358+
(format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
359+
} else {
360+
(String::new(), "})")
361+
}
333362
} else {
334363
(format!("{name}(payable("), "})))")
335364
};
@@ -369,6 +398,30 @@ pub(crate) fn remove_bytecode_dependencies(
369398
}
370399
};
371400
}
401+
402+
// Add try catch statements after last function of the test contract.
403+
if !try_catch_helpers.is_empty()
404+
&& let Some(last_fn_id) = contract.functions().last()
405+
{
406+
let last_fn_range =
407+
span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
408+
let to_address_fns = try_catch_helpers
409+
.iter()
410+
.map(|ty| {
411+
format!(
412+
r#"
413+
function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
414+
return {ty}(addr);
415+
}}
416+
"#,
417+
id = contract_id.get()
418+
)
419+
})
420+
.collect::<String>();
421+
422+
updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
423+
}
424+
372425
let helper_imports = used_helpers.into_iter().map(|id| {
373426
let id = id.get();
374427
format!(

crates/forge/tests/cli/test_optimizer.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,3 +1588,72 @@ Ran 1 test suite [ELAPSED]: 1 tests passed, 0 failed, 0 skipped (1 total tests)
15881588
15891589
"#]]);
15901590
});
1591+
1592+
// Preprocess test contracts with try constructor statements that bind return type.
1593+
forgetest_init!(preprocess_contract_with_try_ctor_stmt_and_returns, |prj, cmd| {
1594+
prj.wipe_contracts();
1595+
prj.update_config(|config| {
1596+
config.dynamic_test_linking = true;
1597+
});
1598+
1599+
prj.add_source(
1600+
"Counter.sol",
1601+
r#"
1602+
contract Counter {
1603+
uint256 number;
1604+
constructor(uint256 a) payable {
1605+
require(a > 0, "ctor failure");
1606+
number = a;
1607+
}
1608+
}
1609+
"#,
1610+
);
1611+
prj.add_test(
1612+
"CounterReturns.t.sol",
1613+
r#"
1614+
import {Test} from "forge-std/Test.sol";
1615+
import {Counter} from "../src/Counter.sol";
1616+
1617+
contract CounterReturnsTest is Test {
1618+
function test_try_counter_creation_returns_custom_type() public {
1619+
try new Counter(1) returns (Counter c) {
1620+
c;
1621+
} catch {
1622+
revert();
1623+
}
1624+
}
1625+
}
1626+
"#,
1627+
);
1628+
1629+
cmd.args(["test"]).with_no_redact().assert_success().stdout_eq(str![[r#"
1630+
...
1631+
Compiling 21 files with [..]
1632+
...
1633+
[PASS] test_try_counter_creation_returns_custom_type() (gas: [..])
1634+
...
1635+
1636+
"#]]);
1637+
1638+
// Change Counter to fail test in try statement, only Counter contract should be compiled.
1639+
prj.add_source(
1640+
"Counter.sol",
1641+
r#"
1642+
contract Counter {
1643+
uint256 number;
1644+
constructor(uint256 a) payable {
1645+
require(a == 0, "ctor failure");
1646+
number = a;
1647+
}
1648+
}
1649+
"#,
1650+
);
1651+
cmd.assert_failure().stdout_eq(str![[r#"
1652+
...
1653+
Compiling 1 files with [..]
1654+
...
1655+
[FAIL: ctor failure] test_try_counter_creation_returns_custom_type() (gas: [..])
1656+
...
1657+
1658+
"#]]);
1659+
});

0 commit comments

Comments
 (0)