Skip to content

Commit 0f2050b

Browse files
committed
Renumber locals after state transform.
1 parent 998af64 commit 0f2050b

File tree

4 files changed

+129
-103
lines changed

4 files changed

+129
-103
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 86 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
6868
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
6969
use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
7070
use rustc_index::{Idx, IndexVec};
71-
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
71+
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
7272
use rustc_middle::mir::*;
7373
use rustc_middle::ty::util::Discr;
7474
use rustc_middle::ty::{
@@ -110,6 +110,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
110110
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
111111
if *local == self.from {
112112
*local = self.to;
113+
} else if *local == self.to {
114+
*local = self.from;
113115
}
114116
}
115117

@@ -157,13 +159,15 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
157159
}
158160
}
159161

162+
#[tracing::instrument(level = "trace", skip(tcx))]
160163
fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
161164
place.local = new_base.local;
162165

163166
let mut new_projection = new_base.projection.to_vec();
164167
new_projection.append(&mut place.projection.to_vec());
165168

166169
place.projection = tcx.mk_place_elems(&new_projection);
170+
tracing::trace!(?place);
167171
}
168172

169173
const SELF_ARG: Local = Local::from_u32(1);
@@ -202,8 +206,8 @@ struct TransformVisitor<'tcx> {
202206
// The set of locals that have no `StorageLive`/`StorageDead` annotations.
203207
always_live_locals: DenseBitSet<Local>,
204208

205-
// The original RETURN_PLACE local
206-
old_ret_local: Local,
209+
// New local we just create to hold the `CoroutineState` value.
210+
new_ret_local: Local,
207211

208212
old_yield_ty: Ty<'tcx>,
209213

@@ -268,6 +272,7 @@ impl<'tcx> TransformVisitor<'tcx> {
268272
// `core::ops::CoroutineState` only has single element tuple variants,
269273
// so we can just write to the downcasted first field and then set the
270274
// discriminant to the appropriate variant.
275+
#[tracing::instrument(level = "trace", skip(self, statements))]
271276
fn make_state(
272277
&self,
273278
val: Operand<'tcx>,
@@ -341,11 +346,12 @@ impl<'tcx> TransformVisitor<'tcx> {
341346

342347
statements.push(Statement::new(
343348
source_info,
344-
StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
349+
StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
345350
));
346351
}
347352

348353
// Create a Place referencing a coroutine struct field
354+
#[tracing::instrument(level = "trace", skip(self), ret)]
349355
fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
350356
let self_place = Place::from(SELF_ARG);
351357
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
@@ -356,6 +362,7 @@ impl<'tcx> TransformVisitor<'tcx> {
356362
}
357363

358364
// Create a statement which changes the discriminant
365+
#[tracing::instrument(level = "trace", skip(self))]
359366
fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
360367
let self_place = Place::from(SELF_ARG);
361368
Statement::new(
@@ -368,6 +375,7 @@ impl<'tcx> TransformVisitor<'tcx> {
368375
}
369376

370377
// Create a statement which reads the discriminant into a temporary
378+
#[tracing::instrument(level = "trace", skip(self, body))]
371379
fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
372380
let temp_decl = LocalDecl::new(self.discr_ty, body.span);
373381
let local_decls_len = body.local_decls.push(temp_decl);
@@ -380,29 +388,41 @@ impl<'tcx> TransformVisitor<'tcx> {
380388
);
381389
(assign, temp)
382390
}
391+
392+
/// Swaps all references of `old_local` and `new_local`.
393+
#[tracing::instrument(level = "trace", skip(self, body))]
394+
fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
395+
body.local_decls.swap(old_local, new_local);
396+
397+
let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
398+
visitor.visit_body(body);
399+
for suspension in &mut self.suspension_points {
400+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
401+
let location = Location { block: START_BLOCK, statement_index: 0 };
402+
visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
403+
}
404+
}
383405
}
384406

385407
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
386408
fn tcx(&self) -> TyCtxt<'tcx> {
387409
self.tcx
388410
}
389411

390-
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
412+
#[tracing::instrument(level = "trace", skip(self), ret)]
413+
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
391414
assert!(!self.remap.contains(*local));
392415
}
393416

394-
fn visit_place(
395-
&mut self,
396-
place: &mut Place<'tcx>,
397-
_context: PlaceContext,
398-
_location: Location,
399-
) {
417+
#[tracing::instrument(level = "trace", skip(self), ret)]
418+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
400419
// Replace an Local in the remap with a coroutine struct access
401420
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
402421
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
403422
}
404423
}
405424

425+
#[tracing::instrument(level = "trace", skip(self, data), ret)]
406426
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
407427
// Remove StorageLive and StorageDead statements for remapped locals
408428
for s in &mut data.statements {
@@ -413,29 +433,35 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
413433
}
414434
}
415435

416-
let ret_val = match data.terminator().kind {
436+
for (statement_index, statement) in data.statements.iter_mut().enumerate() {
437+
let location = Location { block, statement_index };
438+
self.visit_statement(statement, location);
439+
}
440+
441+
let location = Location { block, statement_index: data.statements.len() };
442+
let mut terminator = data.terminator.take().unwrap();
443+
let source_info = terminator.source_info;
444+
match terminator.kind {
417445
TerminatorKind::Return => {
418-
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
446+
let mut v = Operand::Move(Place::return_place());
447+
self.visit_operand(&mut v, location);
448+
// We must assign the value first in case it gets declared dead below
449+
self.make_state(v, source_info, true, &mut data.statements);
450+
// State for returned
451+
let state = VariantIdx::new(CoroutineArgs::RETURNED);
452+
data.statements.push(self.set_discr(state, source_info));
453+
terminator.kind = TerminatorKind::Return;
419454
}
420-
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
421-
Some((false, Some((resume, resume_arg)), value.clone(), drop))
422-
}
423-
_ => None,
424-
};
425-
426-
if let Some((is_return, resume, v, drop)) = ret_val {
427-
let source_info = data.terminator().source_info;
428-
// We must assign the value first in case it gets declared dead below
429-
self.make_state(v, source_info, is_return, &mut data.statements);
430-
let state = if let Some((resume, mut resume_arg)) = resume {
431-
// Yield
432-
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
433-
455+
TerminatorKind::Yield { mut value, resume, mut resume_arg, drop } => {
434456
// The resume arg target location might itself be remapped if its base local is
435457
// live across a yield.
436-
if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
437-
replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
438-
}
458+
self.visit_operand(&mut value, location);
459+
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
460+
self.visit_place(&mut resume_arg, ctxt, location);
461+
// We must assign the value first in case it gets declared dead below
462+
self.make_state(value.clone(), source_info, false, &mut data.statements);
463+
// Yield
464+
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
439465

440466
let storage_liveness: GrowableBitSet<Local> =
441467
self.storage_liveness[block].clone().unwrap().into();
@@ -450,7 +476,6 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
450476
.push(Statement::new(source_info, StatementKind::StorageDead(l)));
451477
}
452478
}
453-
454479
self.suspension_points.push(SuspensionPoint {
455480
state,
456481
resume,
@@ -459,16 +484,13 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
459484
storage_liveness,
460485
});
461486

462-
VariantIdx::new(state)
463-
} else {
464-
// Return
465-
VariantIdx::new(CoroutineArgs::RETURNED) // state for returned
466-
};
467-
data.statements.push(self.set_discr(state, source_info));
468-
data.terminator_mut().kind = TerminatorKind::Return;
469-
}
470-
471-
self.super_basic_block_data(block, data);
487+
let state = VariantIdx::new(state);
488+
data.statements.push(self.set_discr(state, source_info));
489+
terminator.kind = TerminatorKind::Return;
490+
}
491+
_ => self.visit_terminator(&mut terminator, location),
492+
};
493+
data.terminator = Some(terminator);
472494
}
473495
}
474496

@@ -481,6 +503,7 @@ fn make_aggregate_adt<'tcx>(
481503
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
482504
}
483505

506+
#[tracing::instrument(level = "trace", skip(tcx, body))]
484507
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
485508
let coroutine_ty = body.local_decls[SELF_ARG].ty;
486509

@@ -493,6 +516,7 @@ fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Bo
493516
SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
494517
}
495518

519+
#[tracing::instrument(level = "trace", skip(tcx, body))]
496520
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
497521
let coroutine_ty = body.local_decls[SELF_ARG].ty;
498522

@@ -533,27 +557,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
533557
);
534558
}
535559

536-
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
537-
///
538-
/// `local` will be changed to a new local decl with type `ty`.
539-
///
540-
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
541-
/// valid value to it before its first use.
542-
fn replace_local<'tcx>(
543-
local: Local,
544-
ty: Ty<'tcx>,
545-
body: &mut Body<'tcx>,
546-
tcx: TyCtxt<'tcx>,
547-
) -> Local {
548-
let new_decl = LocalDecl::new(ty, body.span);
549-
let new_local = body.local_decls.push(new_decl);
550-
body.local_decls.swap(local, new_local);
551-
552-
RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
553-
554-
new_local
555-
}
556-
557560
/// Transforms the `body` of the coroutine applying the following transforms:
558561
///
559562
/// - Eliminates all the `get_context` calls that async lowering created.
@@ -575,6 +578,7 @@ fn replace_local<'tcx>(
575578
/// The async lowering step and the type / lifetime inference / checking are
576579
/// still using the `ResumeTy` indirection for the time being, and that indirection
577580
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
581+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
578582
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
579583
let context_mut_ref = Ty::new_task_context(tcx);
580584

@@ -628,6 +632,7 @@ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local
628632
}
629633

630634
#[cfg_attr(not(debug_assertions), allow(unused))]
635+
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
631636
fn replace_resume_ty_local<'tcx>(
632637
tcx: TyCtxt<'tcx>,
633638
body: &mut Body<'tcx>,
@@ -638,7 +643,7 @@ fn replace_resume_ty_local<'tcx>(
638643
// We have to replace the `ResumeTy` that is used for type and borrow checking
639644
// with `&mut Context<'_>` in MIR.
640645
#[cfg(debug_assertions)]
641-
{
646+
if local_ty != context_mut_ref {
642647
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
643648
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
644649
assert_eq!(*resume_ty_adt, expected_adt);
@@ -692,6 +697,7 @@ struct LivenessInfo {
692697
/// case none exist, the local is considered to be always live.
693698
/// - a local has to be stored if it is either directly used after the
694699
/// the suspend point, or if it is live and has been previously borrowed.
700+
#[tracing::instrument(level = "trace", skip(tcx, body))]
695701
fn locals_live_across_suspend_points<'tcx>(
696702
tcx: TyCtxt<'tcx>,
697703
body: &Body<'tcx>,
@@ -967,6 +973,7 @@ impl StorageConflictVisitor<'_, '_> {
967973
}
968974
}
969975

976+
#[tracing::instrument(level = "trace", skip(liveness, body))]
970977
fn compute_layout<'tcx>(
971978
liveness: LivenessInfo,
972979
body: &Body<'tcx>,
@@ -1071,7 +1078,9 @@ fn compute_layout<'tcx>(
10711078
variant_source_info,
10721079
storage_conflicts,
10731080
};
1081+
debug!(?remap);
10741082
debug!(?layout);
1083+
debug!(?storage_liveness);
10751084

10761085
(remap, layout, storage_liveness)
10771086
}
@@ -1243,6 +1252,7 @@ fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
12431252
}
12441253
}
12451254

1255+
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
12461256
fn create_coroutine_resume_function<'tcx>(
12471257
tcx: TyCtxt<'tcx>,
12481258
transform: TransformVisitor<'tcx>,
@@ -1321,7 +1331,7 @@ fn create_coroutine_resume_function<'tcx>(
13211331
}
13221332

13231333
/// An operation that can be performed on a coroutine.
1324-
#[derive(PartialEq, Copy, Clone)]
1334+
#[derive(PartialEq, Copy, Clone, Debug)]
13251335
enum Operation {
13261336
Resume,
13271337
Drop,
@@ -1336,6 +1346,7 @@ impl Operation {
13361346
}
13371347
}
13381348

1349+
#[tracing::instrument(level = "trace", skip(transform, body))]
13391350
fn create_cases<'tcx>(
13401351
body: &mut Body<'tcx>,
13411352
transform: &TransformVisitor<'tcx>,
@@ -1467,6 +1478,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14671478
// This only applies to coroutines
14681479
return;
14691480
};
1481+
tracing::trace!(def_id = ?body.source.def_id());
1482+
14701483
let old_ret_ty = body.return_ty();
14711484

14721485
assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
@@ -1513,10 +1526,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15131526
}
15141527
};
15151528

1516-
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1517-
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1518-
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
1519-
15201529
// We need to insert clean drop for unresumed state and perform drop elaboration
15211530
// (finally in open_drop_for_tuple) before async drop expansion.
15221531
// Async drops, produced by this drop elaboration, will be expanded,
@@ -1541,6 +1550,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15411550
cleanup_async_drops(body);
15421551
}
15431552

1553+
// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1554+
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1555+
let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
1556+
tracing::trace!(?new_ret_local);
1557+
15441558
let always_live_locals = always_storage_live_locals(body);
15451559
let movable = coroutine_kind.movability() == hir::Movability::Movable;
15461560
let liveness_info =
@@ -1575,13 +1589,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15751589
storage_liveness,
15761590
always_live_locals,
15771591
suspension_points: Vec::new(),
1578-
old_ret_local,
15791592
discr_ty,
1593+
new_ret_local,
15801594
old_ret_ty,
15811595
old_yield_ty,
15821596
};
15831597
transform.visit_body(body);
15841598

1599+
// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
1600+
transform.replace_local(RETURN_PLACE, new_ret_local, body);
1601+
15851602
// MIR parameters are not explicitly assigned-to when entering the MIR body.
15861603
// If we want to save their values inside the coroutine state, we need to do so explicitly.
15871604
let source_info = SourceInfo::outermost(body.span);

0 commit comments

Comments
 (0)