Skip to content

Commit ad70cd1

Browse files
committed
Guard HIR contracts based on compiler flag rather than lang_item
This allows the optimiser to properly eliminate contract code when runtime contract checks are disabled. It comes at the cost of having to recompile upstream crates (e.g. std) to enable contracts in them. However, this trade off is acceptable if it means disabled runtime contract checks do not affect the runtime performance of the functions they annotate. With the proper elimination of contract code, which this change introduces, the runtime performance of annotated functions should be the same as the original unannotated function.
1 parent b606b49 commit ad70cd1

File tree

2 files changed

+120
-37
lines changed

2 files changed

+120
-37
lines changed

compiler/rustc_ast_lowering/src/contract.rs

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1+
use thin_vec::thin_vec;
2+
13
use crate::LoweringContext;
24

35
impl<'a, 'hir> LoweringContext<'a, 'hir> {
6+
/// Lowered contracts are guarded with the `contract_checks` compiler flag,
7+
/// i.e. the flag turns into a boolean guard in the lowered HIR. The reason
8+
/// for not eliminating the contract code entirely when the `contract_checks`
9+
/// flag is disabled is so that contracts can be type checked, even when
10+
/// they are disabled, which avoids them becoming stale (i.e. out of sync
11+
/// with the codebase) over time.
12+
///
13+
/// The optimiser should be able to eliminate all contract code guarded
14+
/// by `if false`, leaving the original body intact when runtime contract
15+
/// checks are disabled.
416
pub(super) fn lower_contract(
517
&mut self,
618
body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
@@ -14,14 +26,20 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
1426
//
1527
// into:
1628
//
29+
// let __postcond = if contract_checks {
30+
// contract_check_requires(PRECOND);
31+
// Some(|ret_val| POSTCOND)
32+
// } else {
33+
// None
34+
// };
1735
// {
18-
// let __postcond = if contracts_checks() {
19-
// contract_check_requires(PRECOND);
20-
// Some(|ret_val| POSTCOND)
21-
// } else {
22-
// None
23-
// };
24-
// contract_check_ensures(__postcond, { body })
36+
// let ret = { body };
37+
//
38+
// if contract_checks {
39+
// contract_check_ensures(__postcond, ret)
40+
// } else {
41+
// ret
42+
// }
2543
// }
2644

2745
let precond = self.lower_precond(req);
@@ -41,13 +59,19 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
4159
//
4260
// into:
4361
//
62+
// let __postcond = if contract_checks {
63+
// Some(|ret_val| POSTCOND)
64+
// } else {
65+
// None
66+
// };
4467
// {
45-
// let __postcond = if contracts_check() {
46-
// Some(|ret_val| POSTCOND)
47-
// } else {
48-
// None
49-
// };
50-
// __postcond({ body })
68+
// let ret = { body };
69+
//
70+
// if contract_checks {
71+
// contract_check_ensures(__postcond, ret)
72+
// } else {
73+
// ret
74+
// }
5175
// }
5276

5377
let postcond_checker = self.lower_postcond_checker(ens);
@@ -66,7 +90,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
6690
// into:
6791
//
6892
// {
69-
// if contracts_check() {
93+
// if contracts_checks {
7094
// contract_requires(PRECOND);
7195
// }
7296
// body
@@ -129,11 +153,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
129153
let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
130154

131155
let precond_check = rustc_hir::ExprKind::If(
132-
self.expr_call_lang_item_fn(
133-
precond.span,
134-
rustc_hir::LangItem::ContractChecks,
135-
Default::default(),
136-
),
156+
self.arena.alloc(self.expr_bool_literal(precond.span, self.tcx.sess.contract_checks())),
137157
then_block,
138158
None,
139159
);
@@ -170,11 +190,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
170190
let else_block = self.arena.alloc(self.expr_block(else_block));
171191

172192
let contract_check = rustc_hir::ExprKind::If(
173-
self.expr_call_lang_item_fn(
174-
span,
175-
rustc_hir::LangItem::ContractChecks,
176-
Default::default(),
177-
),
193+
self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
178194
then_block,
179195
Some(else_block),
180196
);
@@ -249,12 +265,60 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
249265
cond_ident: rustc_span::Ident,
250266
cond_hir_id: rustc_hir::HirId,
251267
) -> &'hir rustc_hir::Expr<'hir> {
268+
// {
269+
// let ret = { body };
270+
//
271+
// if contract_checks {
272+
// contract_check_ensures(__postcond, ret)
273+
// } else {
274+
// ret
275+
// }
276+
// }
277+
let ret_ident: rustc_span::Ident = rustc_span::Ident::from_str_and_span("__ret", span);
278+
279+
// Set up the return `let` statement.
280+
let (ret_pat, ret_hir_id) =
281+
self.pat_ident_binding_mode_mut(span, ret_ident, rustc_hir::BindingMode::NONE);
282+
283+
let ret_stmt = self.stmt_let_pat(
284+
None,
285+
span,
286+
Some(expr),
287+
self.arena.alloc(ret_pat),
288+
rustc_hir::LocalSource::Contract,
289+
);
290+
291+
let ret = self.expr_ident(span, ret_ident, ret_hir_id);
292+
252293
let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
253-
let call_expr = self.expr_call_lang_item_fn_mut(
294+
let contract_check = self.expr_call_lang_item_fn_mut(
254295
span,
255296
rustc_hir::LangItem::ContractCheckEnsures,
256-
arena_vec![self; *cond_fn, *expr],
297+
arena_vec![self; *cond_fn, *ret],
257298
);
258-
self.arena.alloc(call_expr)
299+
let contract_check = self.arena.alloc(contract_check);
300+
let call_expr = self.block_expr_block(contract_check);
301+
302+
// same ident can't be used in 2 places, so we create a new one for the
303+
// else branch
304+
let ret = self.expr_ident(span, ret_ident, ret_hir_id);
305+
let ret_block = self.block_expr_block(ret);
306+
307+
let contracts_enabled: rustc_hir::Expr<'_> =
308+
self.expr_bool_literal(span, self.tcx.sess.contract_checks());
309+
let contract_check = self.arena.alloc(self.expr(
310+
span,
311+
rustc_hir::ExprKind::If(
312+
self.arena.alloc(contracts_enabled),
313+
call_expr,
314+
Some(ret_block),
315+
),
316+
));
317+
318+
let attrs: rustc_ast::AttrVec = thin_vec![self.unreachable_code_attr(span)];
319+
self.lower_attrs(contract_check.hir_id, &attrs, span);
320+
321+
let ret_block = self.block_all(span, arena_vec![self; ret_stmt], Some(contract_check));
322+
self.arena.alloc(self.expr_block(self.arena.alloc(ret_block)))
259323
}
260324
}

compiler/rustc_ast_lowering/src/expr.rs

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,16 +1916,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
19161916
)
19171917
};
19181918

1919-
// `#[allow(unreachable_code)]`
1920-
let attr = attr::mk_attr_nested_word(
1921-
&self.tcx.sess.psess.attr_id_generator,
1922-
AttrStyle::Outer,
1923-
Safety::Default,
1924-
sym::allow,
1925-
sym::unreachable_code,
1926-
try_span,
1927-
);
1928-
let attrs: AttrVec = thin_vec![attr];
1919+
let attrs: AttrVec = thin_vec![self.unreachable_code_attr(try_span)];
19291920

19301921
// `ControlFlow::Continue(val) => #[allow(unreachable_code)] val,`
19311922
let continue_arm = {
@@ -2262,6 +2253,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
22622253
self.expr(b.span, hir::ExprKind::Block(b, None))
22632254
}
22642255

2256+
/// Wrap an expression in a block, and wrap that block in an expression again.
2257+
/// Useful for constructing if-expressions, which require expressions of
2258+
/// kind block.
2259+
pub(super) fn block_expr_block(
2260+
&mut self,
2261+
expr: &'hir hir::Expr<'hir>,
2262+
) -> &'hir hir::Expr<'hir> {
2263+
let b = self.block_expr(expr);
2264+
self.arena.alloc(self.expr_block(b))
2265+
}
2266+
22652267
pub(super) fn expr_array_ref(
22662268
&mut self,
22672269
span: Span,
@@ -2275,6 +2277,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
22752277
self.expr(span, hir::ExprKind::AddrOf(hir::BorrowKind::Ref, hir::Mutability::Not, expr))
22762278
}
22772279

2280+
pub(super) fn expr_bool_literal(&mut self, span: Span, val: bool) -> hir::Expr<'hir> {
2281+
self.expr(span, hir::ExprKind::Lit(Spanned { node: LitKind::Bool(val), span }))
2282+
}
2283+
22782284
pub(super) fn expr(&mut self, span: Span, kind: hir::ExprKind<'hir>) -> hir::Expr<'hir> {
22792285
let hir_id = self.next_id();
22802286
hir::Expr { hir_id, kind, span: self.lower_span(span) }
@@ -2308,6 +2314,19 @@ impl<'hir> LoweringContext<'_, 'hir> {
23082314
body: expr,
23092315
}
23102316
}
2317+
2318+
/// `#[allow(unreachable_code)]`
2319+
pub(super) fn unreachable_code_attr(&mut self, span: Span) -> Attribute {
2320+
let attr = attr::mk_attr_nested_word(
2321+
&self.tcx.sess.psess.attr_id_generator,
2322+
AttrStyle::Outer,
2323+
Safety::Default,
2324+
sym::allow,
2325+
sym::unreachable_code,
2326+
span,
2327+
);
2328+
attr
2329+
}
23112330
}
23122331

23132332
/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind

0 commit comments

Comments
 (0)