@@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
6868use rustc_hir:: { CoroutineDesugaring , CoroutineKind } ;
6969use rustc_index:: bit_set:: { BitMatrix , DenseBitSet , GrowableBitSet } ;
7070use rustc_index:: { Idx , IndexVec } ;
71- use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
71+ use rustc_middle:: mir:: visit:: { MutVisitor , MutatingUseContext , PlaceContext , Visitor } ;
7272use rustc_middle:: mir:: * ;
7373use rustc_middle:: ty:: util:: Discr ;
7474use rustc_middle:: ty:: {
@@ -111,6 +111,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
111111 fn visit_local ( & mut self , local : & mut Local , _: PlaceContext , _: Location ) {
112112 if * local == self . from {
113113 * local = self . to ;
114+ } else if * local == self . to {
115+ * local = self . from ;
114116 }
115117 }
116118
@@ -160,13 +162,15 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
160162 }
161163}
162164
165+ #[ tracing:: instrument( level = "trace" , skip( tcx) ) ]
163166fn replace_base < ' tcx > ( place : & mut Place < ' tcx > , new_base : Place < ' tcx > , tcx : TyCtxt < ' tcx > ) {
164167 place. local = new_base. local ;
165168
166169 let mut new_projection = new_base. projection . to_vec ( ) ;
167170 new_projection. append ( & mut place. projection . to_vec ( ) ) ;
168171
169172 place. projection = tcx. mk_place_elems ( & new_projection) ;
173+ tracing:: trace!( ?place) ;
170174}
171175
172176const SELF_ARG : Local = Local :: from_u32 ( 1 ) ;
@@ -205,8 +209,8 @@ struct TransformVisitor<'tcx> {
205209 // The set of locals that have no `StorageLive`/`StorageDead` annotations.
206210 always_live_locals : DenseBitSet < Local > ,
207211
208- // The original RETURN_PLACE local
209- old_ret_local : Local ,
212+ // New local we just create to hold the `CoroutineState` value.
213+ new_ret_local : Local ,
210214
211215 old_yield_ty : Ty < ' tcx > ,
212216
@@ -271,6 +275,7 @@ impl<'tcx> TransformVisitor<'tcx> {
271275 // `core::ops::CoroutineState` only has single element tuple variants,
272276 // so we can just write to the downcasted first field and then set the
273277 // discriminant to the appropriate variant.
278+ #[ tracing:: instrument( level = "trace" , skip( self , statements) ) ]
274279 fn make_state (
275280 & self ,
276281 val : Operand < ' tcx > ,
@@ -344,11 +349,12 @@ impl<'tcx> TransformVisitor<'tcx> {
344349
345350 statements. push ( Statement :: new (
346351 source_info,
347- StatementKind :: Assign ( Box :: new ( ( Place :: return_place ( ) , rvalue) ) ) ,
352+ StatementKind :: Assign ( Box :: new ( ( self . new_ret_local . into ( ) , rvalue) ) ) ,
348353 ) ) ;
349354 }
350355
351356 // Create a Place referencing a coroutine struct field
357+ #[ tracing:: instrument( level = "trace" , skip( self ) , ret) ]
352358 fn make_field ( & self , variant_index : VariantIdx , idx : FieldIdx , ty : Ty < ' tcx > ) -> Place < ' tcx > {
353359 let self_place = Place :: from ( SELF_ARG ) ;
354360 let base = self . tcx . mk_place_downcast_unnamed ( self_place, variant_index) ;
@@ -359,6 +365,7 @@ impl<'tcx> TransformVisitor<'tcx> {
359365 }
360366
361367 // Create a statement which changes the discriminant
368+ #[ tracing:: instrument( level = "trace" , skip( self ) ) ]
362369 fn set_discr ( & self , state_disc : VariantIdx , source_info : SourceInfo ) -> Statement < ' tcx > {
363370 let self_place = Place :: from ( SELF_ARG ) ;
364371 Statement :: new (
@@ -371,6 +378,7 @@ impl<'tcx> TransformVisitor<'tcx> {
371378 }
372379
373380 // Create a statement which reads the discriminant into a temporary
381+ #[ tracing:: instrument( level = "trace" , skip( self , body) ) ]
374382 fn get_discr ( & self , body : & mut Body < ' tcx > ) -> ( Statement < ' tcx > , Place < ' tcx > ) {
375383 let temp_decl = LocalDecl :: new ( self . discr_ty , body. span ) ;
376384 let local_decls_len = body. local_decls . push ( temp_decl) ;
@@ -383,29 +391,48 @@ impl<'tcx> TransformVisitor<'tcx> {
383391 ) ;
384392 ( assign, temp)
385393 }
394+
395+ /// Allocates a new local and replaces all references of `local` with it. Returns the new local.
396+ ///
397+ /// `local` will be changed to a new local decl with type `ty`.
398+ ///
399+ /// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
400+ /// valid value to it before its first use.
401+ #[ tracing:: instrument( level = "trace" , skip( self , body) ) ]
402+ fn replace_local ( & mut self , local : Local , new_local : Local , body : & mut Body < ' tcx > ) -> Local {
403+ body. local_decls . swap ( local, new_local) ;
404+
405+ let mut visitor = RenameLocalVisitor { from : local, to : new_local, tcx : self . tcx } ;
406+ visitor. visit_body ( body) ;
407+ for suspension in & mut self . suspension_points {
408+ let ctxt = PlaceContext :: MutatingUse ( MutatingUseContext :: Yield ) ;
409+ let location = Location { block : START_BLOCK , statement_index : 0 } ;
410+ visitor. visit_place ( & mut suspension. resume_arg , ctxt, location) ;
411+ }
412+
413+ new_local
414+ }
386415}
387416
388417impl < ' tcx > MutVisitor < ' tcx > for TransformVisitor < ' tcx > {
389418 fn tcx ( & self ) -> TyCtxt < ' tcx > {
390419 self . tcx
391420 }
392421
393- fn visit_local ( & mut self , local : & mut Local , _: PlaceContext , _: Location ) {
422+ #[ tracing:: instrument( level = "trace" , skip( self ) , ret) ]
423+ fn visit_local ( & mut self , local : & mut Local , _: PlaceContext , _location : Location ) {
394424 assert ! ( !self . remap. contains( * local) ) ;
395425 }
396426
397- fn visit_place (
398- & mut self ,
399- place : & mut Place < ' tcx > ,
400- _context : PlaceContext ,
401- _location : Location ,
402- ) {
427+ #[ tracing:: instrument( level = "trace" , skip( self ) , ret) ]
428+ fn visit_place ( & mut self , place : & mut Place < ' tcx > , _: PlaceContext , _location : Location ) {
403429 // Replace an Local in the remap with a coroutine struct access
404430 if let Some ( & Some ( ( ty, variant_index, idx) ) ) = self . remap . get ( place. local ) {
405431 replace_base ( place, self . make_field ( variant_index, idx, ty) , self . tcx ) ;
406432 }
407433 }
408434
435+ #[ tracing:: instrument( level = "trace" , skip( self , data) , ret) ]
409436 fn visit_basic_block_data ( & mut self , block : BasicBlock , data : & mut BasicBlockData < ' tcx > ) {
410437 // Remove StorageLive and StorageDead statements for remapped locals
411438 for s in & mut data. statements {
@@ -416,29 +443,35 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
416443 }
417444 }
418445
419- let ret_val = match data. terminator ( ) . kind {
446+ for ( statement_index, statement) in data. statements . iter_mut ( ) . enumerate ( ) {
447+ let location = Location { block, statement_index } ;
448+ self . visit_statement ( statement, location) ;
449+ }
450+
451+ let location = Location { block, statement_index : data. statements . len ( ) } ;
452+ let mut terminator = data. terminator . take ( ) . unwrap ( ) ;
453+ let source_info = terminator. source_info ;
454+ match terminator. kind {
420455 TerminatorKind :: Return => {
421- Some ( ( true , None , Operand :: Move ( Place :: from ( self . old_ret_local ) ) , None ) )
422- }
423- TerminatorKind :: Yield { ref value, resume, resume_arg, drop } => {
424- Some ( ( false , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
456+ let mut v = Operand :: Move ( Place :: return_place ( ) ) ;
457+ self . visit_operand ( & mut v, location) ;
458+ // We must assign the value first in case it gets declared dead below
459+ self . make_state ( v, source_info, true , & mut data. statements ) ;
460+ // State for returned
461+ let state = VariantIdx :: new ( CoroutineArgs :: RETURNED ) ;
462+ data. statements . push ( self . set_discr ( state, source_info) ) ;
463+ terminator. kind = TerminatorKind :: Return ;
425464 }
426- _ => None ,
427- } ;
428-
429- if let Some ( ( is_return, resume, v, drop) ) = ret_val {
430- let source_info = data. terminator ( ) . source_info ;
431- // We must assign the value first in case it gets declared dead below
432- self . make_state ( v, source_info, is_return, & mut data. statements ) ;
433- let state = if let Some ( ( resume, mut resume_arg) ) = resume {
434- // Yield
435- let state = CoroutineArgs :: RESERVED_VARIANTS + self . suspension_points . len ( ) ;
436-
465+ TerminatorKind :: Yield { mut value, resume, mut resume_arg, drop } => {
437466 // The resume arg target location might itself be remapped if its base local is
438467 // live across a yield.
439- if let Some ( & Some ( ( ty, variant, idx) ) ) = self . remap . get ( resume_arg. local ) {
440- replace_base ( & mut resume_arg, self . make_field ( variant, idx, ty) , self . tcx ) ;
441- }
468+ self . visit_operand ( & mut value, location) ;
469+ let ctxt = PlaceContext :: MutatingUse ( MutatingUseContext :: Yield ) ;
470+ self . visit_place ( & mut resume_arg, ctxt, location) ;
471+ // We must assign the value first in case it gets declared dead below
472+ self . make_state ( value. clone ( ) , source_info, false , & mut data. statements ) ;
473+ // Yield
474+ let state = CoroutineArgs :: RESERVED_VARIANTS + self . suspension_points . len ( ) ;
442475
443476 let storage_liveness: GrowableBitSet < Local > =
444477 self . storage_liveness [ block] . clone ( ) . unwrap ( ) . into ( ) ;
@@ -453,7 +486,6 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
453486 . push ( Statement :: new ( source_info, StatementKind :: StorageDead ( l) ) ) ;
454487 }
455488 }
456-
457489 self . suspension_points . push ( SuspensionPoint {
458490 state,
459491 resume,
@@ -462,16 +494,13 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
462494 storage_liveness,
463495 } ) ;
464496
465- VariantIdx :: new ( state)
466- } else {
467- // Return
468- VariantIdx :: new ( CoroutineArgs :: RETURNED ) // state for returned
469- } ;
470- data. statements . push ( self . set_discr ( state, source_info) ) ;
471- data. terminator_mut ( ) . kind = TerminatorKind :: Return ;
472- }
473-
474- self . super_basic_block_data ( block, data) ;
497+ let state = VariantIdx :: new ( state) ;
498+ data. statements . push ( self . set_discr ( state, source_info) ) ;
499+ terminator. kind = TerminatorKind :: Return ;
500+ }
501+ _ => self . visit_terminator ( & mut terminator, location) ,
502+ } ;
503+ data. terminator = Some ( terminator) ;
475504 }
476505}
477506
@@ -484,6 +513,7 @@ fn make_aggregate_adt<'tcx>(
484513 Rvalue :: Aggregate ( Box :: new ( AggregateKind :: Adt ( def_id, variant_idx, args, None , None ) ) , operands)
485514}
486515
516+ #[ tracing:: instrument( level = "trace" , skip( tcx, body) ) ]
487517fn make_coroutine_state_argument_indirect < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
488518 let coroutine_ty = body. local_decls . raw [ 1 ] . ty ;
489519
@@ -496,6 +526,7 @@ fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Bo
496526 SelfArgVisitor :: new ( tcx, ProjectionElem :: Deref ) . visit_body ( body) ;
497527}
498528
529+ #[ tracing:: instrument( level = "trace" , skip( tcx, body) ) ]
499530fn make_coroutine_state_argument_pinned < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
500531 let ref_coroutine_ty = body. local_decls . raw [ 1 ] . ty ;
501532
@@ -512,27 +543,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
512543 . visit_body ( body) ;
513544}
514545
515- /// Allocates a new local and replaces all references of `local` with it. Returns the new local.
516- ///
517- /// `local` will be changed to a new local decl with type `ty`.
518- ///
519- /// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
520- /// valid value to it before its first use.
521- fn replace_local < ' tcx > (
522- local : Local ,
523- ty : Ty < ' tcx > ,
524- body : & mut Body < ' tcx > ,
525- tcx : TyCtxt < ' tcx > ,
526- ) -> Local {
527- let new_decl = LocalDecl :: new ( ty, body. span ) ;
528- let new_local = body. local_decls . push ( new_decl) ;
529- body. local_decls . swap ( local, new_local) ;
530-
531- RenameLocalVisitor { from : local, to : new_local, tcx } . visit_body ( body) ;
532-
533- new_local
534- }
535-
536546/// Transforms the `body` of the coroutine applying the following transforms:
537547///
538548/// - Eliminates all the `get_context` calls that async lowering created.
@@ -554,6 +564,7 @@ fn replace_local<'tcx>(
554564/// The async lowering step and the type / lifetime inference / checking are
555565/// still using the `ResumeTy` indirection for the time being, and that indirection
556566/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
567+ #[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
557568fn transform_async_context < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) -> Ty < ' tcx > {
558569 let context_mut_ref = Ty :: new_task_context ( tcx) ;
559570
@@ -607,6 +618,7 @@ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local
607618}
608619
609620#[ cfg_attr( not( debug_assertions) , allow( unused) ) ]
621+ #[ tracing:: instrument( level = "trace" , skip( tcx, body) , ret) ]
610622fn replace_resume_ty_local < ' tcx > (
611623 tcx : TyCtxt < ' tcx > ,
612624 body : & mut Body < ' tcx > ,
@@ -617,7 +629,7 @@ fn replace_resume_ty_local<'tcx>(
617629 // We have to replace the `ResumeTy` that is used for type and borrow checking
618630 // with `&mut Context<'_>` in MIR.
619631 #[ cfg( debug_assertions) ]
620- {
632+ if local_ty != context_mut_ref {
621633 if let ty:: Adt ( resume_ty_adt, _) = local_ty. kind ( ) {
622634 let expected_adt = tcx. adt_def ( tcx. require_lang_item ( LangItem :: ResumeTy , body. span ) ) ;
623635 assert_eq ! ( * resume_ty_adt, expected_adt) ;
@@ -671,6 +683,7 @@ struct LivenessInfo {
671683/// case none exist, the local is considered to be always live.
672684/// - a local has to be stored if it is either directly used after the
673685/// the suspend point, or if it is live and has been previously borrowed.
686+ #[ tracing:: instrument( level = "trace" , skip( tcx, body) ) ]
674687fn locals_live_across_suspend_points < ' tcx > (
675688 tcx : TyCtxt < ' tcx > ,
676689 body : & Body < ' tcx > ,
@@ -946,6 +959,7 @@ impl StorageConflictVisitor<'_, '_> {
946959 }
947960}
948961
962+ #[ tracing:: instrument( level = "trace" , skip( liveness, body) ) ]
949963fn compute_layout < ' tcx > (
950964 liveness : LivenessInfo ,
951965 body : & Body < ' tcx > ,
@@ -1050,7 +1064,9 @@ fn compute_layout<'tcx>(
10501064 variant_source_info,
10511065 storage_conflicts,
10521066 } ;
1067+ debug ! ( ?remap) ;
10531068 debug ! ( ?layout) ;
1069+ debug ! ( ?storage_liveness) ;
10541070
10551071 ( remap, layout, storage_liveness)
10561072}
@@ -1222,6 +1238,7 @@ fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
12221238 }
12231239}
12241240
1241+ #[ tracing:: instrument( level = "trace" , skip( tcx, transform, body) ) ]
12251242fn create_coroutine_resume_function < ' tcx > (
12261243 tcx : TyCtxt < ' tcx > ,
12271244 transform : TransformVisitor < ' tcx > ,
@@ -1300,7 +1317,7 @@ fn create_coroutine_resume_function<'tcx>(
13001317}
13011318
13021319/// An operation that can be performed on a coroutine.
1303- #[ derive( PartialEq , Copy , Clone ) ]
1320+ #[ derive( PartialEq , Copy , Clone , Debug ) ]
13041321enum Operation {
13051322 Resume ,
13061323 Drop ,
@@ -1315,6 +1332,7 @@ impl Operation {
13151332 }
13161333}
13171334
1335+ #[ tracing:: instrument( level = "trace" , skip( transform, body) ) ]
13181336fn create_cases < ' tcx > (
13191337 body : & mut Body < ' tcx > ,
13201338 transform : & TransformVisitor < ' tcx > ,
@@ -1446,6 +1464,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14461464 // This only applies to coroutines
14471465 return ;
14481466 } ;
1467+ tracing:: trace!( def_id = ?body. source. def_id( ) ) ;
1468+
14491469 let old_ret_ty = body. return_ty ( ) ;
14501470
14511471 assert ! ( body. coroutine_drop( ) . is_none( ) && body. coroutine_drop_async( ) . is_none( ) ) ;
@@ -1492,10 +1512,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
14921512 }
14931513 } ;
14941514
1495- // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
1496- // RETURN_PLACE then is a fresh unused local with type ret_ty.
1497- let old_ret_local = replace_local ( RETURN_PLACE , new_ret_ty, body, tcx) ;
1498-
14991515 // We need to insert clean drop for unresumed state and perform drop elaboration
15001516 // (finally in open_drop_for_tuple) before async drop expansion.
15011517 // Async drops, produced by this drop elaboration, will be expanded,
@@ -1520,6 +1536,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15201536 cleanup_async_drops ( body) ;
15211537 }
15221538
1539+ // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1540+ // RETURN_PLACE then is a fresh unused local with type ret_ty.
1541+ let new_ret_local = body. local_decls . push ( LocalDecl :: new ( new_ret_ty, body. span ) ) ;
1542+ tracing:: trace!( ?new_ret_local) ;
1543+
15231544 let always_live_locals = always_storage_live_locals ( body) ;
15241545 let movable = coroutine_kind. movability ( ) == hir:: Movability :: Movable ;
15251546 let liveness_info =
@@ -1554,13 +1575,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
15541575 storage_liveness,
15551576 always_live_locals,
15561577 suspension_points : Vec :: new ( ) ,
1557- old_ret_local,
15581578 discr_ty,
1579+ new_ret_local,
15591580 old_ret_ty,
15601581 old_yield_ty,
15611582 } ;
15621583 transform. visit_body ( body) ;
15631584
1585+ // Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
1586+ transform. replace_local ( RETURN_PLACE , new_ret_local, body) ;
1587+
15641588 // MIR parameters are not explicitly assigned-to when entering the MIR body.
15651589 // If we want to save their values inside the coroutine state, we need to do so explicitly.
15661590 let source_info = SourceInfo :: outermost ( body. span ) ;
0 commit comments