@@ -66,9 +66,9 @@ use rustc_index::{Idx, IndexVec};
6666use rustc_middle:: mir:: dump_mir;
6767use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
6868use rustc_middle:: mir:: * ;
69+ use rustc_middle:: ty:: CoroutineArgs ;
6970use rustc_middle:: ty:: InstanceDef ;
70- use rustc_middle:: ty:: { self , AdtDef , Ty , TyCtxt } ;
71- use rustc_middle:: ty:: { CoroutineArgs , GenericArgsRef } ;
71+ use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
7272use rustc_mir_dataflow:: impls:: {
7373 MaybeBorrowedLocals , MaybeLiveLocals , MaybeRequiresStorage , MaybeStorageLive ,
7474} ;
@@ -225,8 +225,6 @@ struct SuspensionPoint<'tcx> {
225225struct TransformVisitor < ' tcx > {
226226 tcx : TyCtxt < ' tcx > ,
227227 coroutine_kind : hir:: CoroutineKind ,
228- state_adt_ref : AdtDef < ' tcx > ,
229- state_args : GenericArgsRef < ' tcx > ,
230228
231229 // The type of the discriminant in the coroutine struct
232230 discr_ty : Ty < ' tcx > ,
@@ -245,21 +243,34 @@ struct TransformVisitor<'tcx> {
245243 always_live_locals : BitSet < Local > ,
246244
247245 // The original RETURN_PLACE local
248- new_ret_local : Local ,
246+ old_ret_local : Local ,
247+
248+ old_yield_ty : Ty < ' tcx > ,
249+
250+ old_ret_ty : Ty < ' tcx > ,
249251}
250252
251253impl < ' tcx > TransformVisitor < ' tcx > {
252254 fn insert_none_ret_block ( & self , body : & mut Body < ' tcx > ) -> BasicBlock {
253- let block = BasicBlock :: new ( body . basic_blocks . len ( ) ) ;
255+ assert ! ( matches! ( self . coroutine_kind , CoroutineKind :: Gen ( _ ) ) ) ;
254256
257+ let block = BasicBlock :: new ( body. basic_blocks . len ( ) ) ;
255258 let source_info = SourceInfo :: outermost ( body. span ) ;
259+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
256260
257- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( true ) ;
258- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
259261 let statements = vec ! [ Statement {
260262 kind: StatementKind :: Assign ( Box :: new( (
261263 Place :: return_place( ) ,
262- Rvalue :: Aggregate ( Box :: new( kind) , IndexVec :: new( ) ) ,
264+ Rvalue :: Aggregate (
265+ Box :: new( AggregateKind :: Adt (
266+ option_def_id,
267+ VariantIdx :: from_usize( 0 ) ,
268+ self . tcx. mk_args( & [ self . old_yield_ty. into( ) ] ) ,
269+ None ,
270+ None ,
271+ ) ) ,
272+ IndexVec :: new( ) ,
273+ ) ,
263274 ) ) ) ,
264275 source_info,
265276 } ] ;
@@ -273,23 +284,6 @@ impl<'tcx> TransformVisitor<'tcx> {
273284 block
274285 }
275286
276- fn coroutine_state_adt_and_variant_idx (
277- & self ,
278- is_return : bool ,
279- ) -> ( AggregateKind < ' tcx > , VariantIdx ) {
280- let idx = VariantIdx :: new ( match ( is_return, self . coroutine_kind ) {
281- ( true , hir:: CoroutineKind :: Coroutine ) => 1 , // CoroutineState::Complete
282- ( false , hir:: CoroutineKind :: Coroutine ) => 0 , // CoroutineState::Yielded
283- ( true , hir:: CoroutineKind :: Async ( _) ) => 0 , // Poll::Ready
284- ( false , hir:: CoroutineKind :: Async ( _) ) => 1 , // Poll::Pending
285- ( true , hir:: CoroutineKind :: Gen ( _) ) => 0 , // Option::None
286- ( false , hir:: CoroutineKind :: Gen ( _) ) => 1 , // Option::Some
287- } ) ;
288-
289- let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_args , None , None ) ;
290- ( kind, idx)
291- }
292-
293287 // Make a `CoroutineState` or `Poll` variant assignment.
294288 //
295289 // `core::ops::CoroutineState` only has single element tuple variants,
@@ -302,51 +296,99 @@ impl<'tcx> TransformVisitor<'tcx> {
302296 is_return : bool ,
303297 statements : & mut Vec < Statement < ' tcx > > ,
304298 ) {
305- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( is_return) ;
306-
307- match self . coroutine_kind {
308- // `Poll::Pending`
299+ let rvalue = match self . coroutine_kind {
309300 CoroutineKind :: Async ( _) => {
310- if !is_return {
311- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
312-
313- // FIXME(swatinem): assert that `val` is indeed unit?
314- statements. push ( Statement {
315- kind : StatementKind :: Assign ( Box :: new ( (
316- Place :: return_place ( ) ,
317- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
318- ) ) ) ,
319- source_info,
320- } ) ;
321- return ;
301+ let poll_def_id = self . tcx . require_lang_item ( LangItem :: Poll , None ) ;
302+ let args = self . tcx . mk_args ( & [ self . old_ret_ty . into ( ) ] ) ;
303+ if is_return {
304+ // Poll::Ready(val)
305+ Rvalue :: Aggregate (
306+ Box :: new ( AggregateKind :: Adt (
307+ poll_def_id,
308+ VariantIdx :: from_usize ( 0 ) ,
309+ args,
310+ None ,
311+ None ,
312+ ) ) ,
313+ IndexVec :: from_raw ( vec ! [ val] ) ,
314+ )
315+ } else {
316+ // Poll::Pending
317+ Rvalue :: Aggregate (
318+ Box :: new ( AggregateKind :: Adt (
319+ poll_def_id,
320+ VariantIdx :: from_usize ( 1 ) ,
321+ args,
322+ None ,
323+ None ,
324+ ) ) ,
325+ IndexVec :: new ( ) ,
326+ )
322327 }
323328 }
324- // `Option::None`
325329 CoroutineKind :: Gen ( _) => {
330+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
331+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) ] ) ;
326332 if is_return {
327- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
328-
329- statements. push ( Statement {
330- kind : StatementKind :: Assign ( Box :: new ( (
331- Place :: return_place ( ) ,
332- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
333- ) ) ) ,
334- source_info,
335- } ) ;
336- return ;
333+ // None
334+ Rvalue :: Aggregate (
335+ Box :: new ( AggregateKind :: Adt (
336+ option_def_id,
337+ VariantIdx :: from_usize ( 0 ) ,
338+ args,
339+ None ,
340+ None ,
341+ ) ) ,
342+ IndexVec :: new ( ) ,
343+ )
344+ } else {
345+ // Some(val)
346+ Rvalue :: Aggregate (
347+ Box :: new ( AggregateKind :: Adt (
348+ option_def_id,
349+ VariantIdx :: from_usize ( 1 ) ,
350+ args,
351+ None ,
352+ None ,
353+ ) ) ,
354+ IndexVec :: from_raw ( vec ! [ val] ) ,
355+ )
337356 }
338357 }
339- CoroutineKind :: Coroutine => { }
340- }
341-
342- // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
343- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
358+ CoroutineKind :: Coroutine => {
359+ let coroutine_state_def_id =
360+ self . tcx . require_lang_item ( LangItem :: CoroutineState , None ) ;
361+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) , self . old_ret_ty . into ( ) ] ) ;
362+ if is_return {
363+ // CoroutineState::Complete(val)
364+ Rvalue :: Aggregate (
365+ Box :: new ( AggregateKind :: Adt (
366+ coroutine_state_def_id,
367+ VariantIdx :: from_usize ( 1 ) ,
368+ args,
369+ None ,
370+ None ,
371+ ) ) ,
372+ IndexVec :: from_raw ( vec ! [ val] ) ,
373+ )
374+ } else {
375+ // CoroutineState::Yielded(val)
376+ Rvalue :: Aggregate (
377+ Box :: new ( AggregateKind :: Adt (
378+ coroutine_state_def_id,
379+ VariantIdx :: from_usize ( 0 ) ,
380+ args,
381+ None ,
382+ None ,
383+ ) ) ,
384+ IndexVec :: from_raw ( vec ! [ val] ) ,
385+ )
386+ }
387+ }
388+ } ;
344389
345390 statements. push ( Statement {
346- kind : StatementKind :: Assign ( Box :: new ( (
347- Place :: return_place ( ) ,
348- Rvalue :: Aggregate ( Box :: new ( kind) , [ val] . into ( ) ) ,
349- ) ) ) ,
391+ kind : StatementKind :: Assign ( Box :: new ( ( Place :: return_place ( ) , rvalue) ) ) ,
350392 source_info,
351393 } ) ;
352394 }
@@ -420,7 +462,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
420462
421463 let ret_val = match data. terminator ( ) . kind {
422464 TerminatorKind :: Return => {
423- Some ( ( true , None , Operand :: Move ( Place :: from ( self . new_ret_local ) ) , None ) )
465+ Some ( ( true , None , Operand :: Move ( Place :: from ( self . old_ret_local ) ) , None ) )
424466 }
425467 TerminatorKind :: Yield { ref value, resume, resume_arg, drop } => {
426468 Some ( ( false , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
@@ -1493,10 +1535,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
14931535
14941536impl < ' tcx > MirPass < ' tcx > for StateTransform {
14951537 fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
1496- let Some ( yield_ty ) = body. yield_ty ( ) else {
1538+ let Some ( old_yield_ty ) = body. yield_ty ( ) else {
14971539 // This only applies to coroutines
14981540 return ;
14991541 } ;
1542+ let old_ret_ty = body. return_ty ( ) ;
15001543
15011544 assert ! ( body. coroutine_drop( ) . is_none( ) ) ;
15021545
@@ -1520,34 +1563,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15201563
15211564 let is_async_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Async ( _) ) ) ;
15221565 let is_gen_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Gen ( _) ) ) ;
1523- let ( state_adt_ref , state_args ) = match body. coroutine_kind ( ) . unwrap ( ) {
1566+ let new_ret_ty = match body. coroutine_kind ( ) . unwrap ( ) {
15241567 CoroutineKind :: Async ( _) => {
15251568 // Compute Poll<return_ty>
15261569 let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
15271570 let poll_adt_ref = tcx. adt_def ( poll_did) ;
1528- let poll_args = tcx. mk_args ( & [ body . return_ty ( ) . into ( ) ] ) ;
1529- ( poll_adt_ref, poll_args)
1571+ let poll_args = tcx. mk_args ( & [ old_ret_ty . into ( ) ] ) ;
1572+ Ty :: new_adt ( tcx , poll_adt_ref, poll_args)
15301573 }
15311574 CoroutineKind :: Gen ( _) => {
15321575 // Compute Option<yield_ty>
15331576 let option_did = tcx. require_lang_item ( LangItem :: Option , None ) ;
15341577 let option_adt_ref = tcx. adt_def ( option_did) ;
1535- let option_args = tcx. mk_args ( & [ body . yield_ty ( ) . unwrap ( ) . into ( ) ] ) ;
1536- ( option_adt_ref, option_args)
1578+ let option_args = tcx. mk_args ( & [ old_yield_ty . into ( ) ] ) ;
1579+ Ty :: new_adt ( tcx , option_adt_ref, option_args)
15371580 }
15381581 CoroutineKind :: Coroutine => {
15391582 // Compute CoroutineState<yield_ty, return_ty>
15401583 let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
15411584 let state_adt_ref = tcx. adt_def ( state_did) ;
1542- let state_args = tcx. mk_args ( & [ yield_ty . into ( ) , body . return_ty ( ) . into ( ) ] ) ;
1543- ( state_adt_ref, state_args)
1585+ let state_args = tcx. mk_args ( & [ old_yield_ty . into ( ) , old_ret_ty . into ( ) ] ) ;
1586+ Ty :: new_adt ( tcx , state_adt_ref, state_args)
15441587 }
15451588 } ;
1546- let ret_ty = Ty :: new_adt ( tcx, state_adt_ref, state_args) ;
15471589
1548- // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1590+ // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
15491591 // RETURN_PLACE then is a fresh unused local with type ret_ty.
1550- let new_ret_local = replace_local ( RETURN_PLACE , ret_ty , body, tcx) ;
1592+ let old_ret_local = replace_local ( RETURN_PLACE , new_ret_ty , body, tcx) ;
15511593
15521594 // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
15531595 if is_async_kind {
@@ -1564,17 +1606,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15641606 } else {
15651607 body. local_decls [ resume_local] . ty
15661608 } ;
1567- let new_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1609+ let old_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
15681610
1569- // When first entering the coroutine, move the resume argument into its new local.
1611+ // When first entering the coroutine, move the resume argument into its old local
1612+ // (which is now a generator interior).
15701613 let source_info = SourceInfo :: outermost ( body. span ) ;
15711614 let stmts = & mut body. basic_blocks_mut ( ) [ START_BLOCK ] . statements ;
15721615 stmts. insert (
15731616 0 ,
15741617 Statement {
15751618 source_info,
15761619 kind : StatementKind :: Assign ( Box :: new ( (
1577- new_resume_local . into ( ) ,
1620+ old_resume_local . into ( ) ,
15781621 Rvalue :: Use ( Operand :: Move ( resume_local. into ( ) ) ) ,
15791622 ) ) ) ,
15801623 } ,
@@ -1610,14 +1653,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16101653 let mut transform = TransformVisitor {
16111654 tcx,
16121655 coroutine_kind : body. coroutine_kind ( ) . unwrap ( ) ,
1613- state_adt_ref,
1614- state_args,
16151656 remap,
16161657 storage_liveness,
16171658 always_live_locals,
16181659 suspension_points : Vec :: new ( ) ,
1619- new_ret_local ,
1660+ old_ret_local ,
16201661 discr_ty,
1662+ old_ret_ty,
1663+ old_yield_ty,
16211664 } ;
16221665 transform. visit_body ( body) ;
16231666
0 commit comments