@@ -67,9 +67,9 @@ use rustc_index::{Idx, IndexVec};
6767use rustc_middle:: mir:: dump_mir;
6868use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
6969use rustc_middle:: mir:: * ;
70+ use rustc_middle:: ty:: CoroutineArgs ;
7071use rustc_middle:: ty:: InstanceDef ;
71- use rustc_middle:: ty:: { self , AdtDef , Ty , TyCtxt } ;
72- use rustc_middle:: ty:: { CoroutineArgs , GenericArgsRef } ;
72+ use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
7373use rustc_mir_dataflow:: impls:: {
7474 MaybeBorrowedLocals , MaybeLiveLocals , MaybeRequiresStorage , MaybeStorageLive ,
7575} ;
@@ -226,8 +226,6 @@ struct SuspensionPoint<'tcx> {
226226struct TransformVisitor < ' tcx > {
227227 tcx : TyCtxt < ' tcx > ,
228228 coroutine_kind : hir:: CoroutineKind ,
229- state_adt_ref : AdtDef < ' tcx > ,
230- state_args : GenericArgsRef < ' tcx > ,
231229
232230 // The type of the discriminant in the coroutine struct
233231 discr_ty : Ty < ' tcx > ,
@@ -246,21 +244,34 @@ struct TransformVisitor<'tcx> {
246244 always_live_locals : BitSet < Local > ,
247245
248246 // The original RETURN_PLACE local
249- new_ret_local : Local ,
247+ old_ret_local : Local ,
248+
249+ old_yield_ty : Ty < ' tcx > ,
250+
251+ old_ret_ty : Ty < ' tcx > ,
250252}
251253
252254impl < ' tcx > TransformVisitor < ' tcx > {
253255 fn insert_none_ret_block ( & self , body : & mut Body < ' tcx > ) -> BasicBlock {
254- let block = BasicBlock :: new ( body . basic_blocks . len ( ) ) ;
256+ assert ! ( matches! ( self . coroutine_kind , CoroutineKind :: Gen ( _ ) ) ) ;
255257
258+ let block = BasicBlock :: new ( body. basic_blocks . len ( ) ) ;
256259 let source_info = SourceInfo :: outermost ( body. span ) ;
260+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
257261
258- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( true ) ;
259- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
260262 let statements = vec ! [ Statement {
261263 kind: StatementKind :: Assign ( Box :: new( (
262264 Place :: return_place( ) ,
263- Rvalue :: Aggregate ( Box :: new( kind) , IndexVec :: new( ) ) ,
265+ Rvalue :: Aggregate (
266+ Box :: new( AggregateKind :: Adt (
267+ option_def_id,
268+ VariantIdx :: from_usize( 0 ) ,
269+ self . tcx. mk_args( & [ self . old_yield_ty. into( ) ] ) ,
270+ None ,
271+ None ,
272+ ) ) ,
273+ IndexVec :: new( ) ,
274+ ) ,
264275 ) ) ) ,
265276 source_info,
266277 } ] ;
@@ -274,23 +285,6 @@ impl<'tcx> TransformVisitor<'tcx> {
274285 block
275286 }
276287
277- fn coroutine_state_adt_and_variant_idx (
278- & self ,
279- is_return : bool ,
280- ) -> ( AggregateKind < ' tcx > , VariantIdx ) {
281- let idx = VariantIdx :: new ( match ( is_return, self . coroutine_kind ) {
282- ( true , hir:: CoroutineKind :: Coroutine ) => 1 , // CoroutineState::Complete
283- ( false , hir:: CoroutineKind :: Coroutine ) => 0 , // CoroutineState::Yielded
284- ( true , hir:: CoroutineKind :: Async ( _) ) => 0 , // Poll::Ready
285- ( false , hir:: CoroutineKind :: Async ( _) ) => 1 , // Poll::Pending
286- ( true , hir:: CoroutineKind :: Gen ( _) ) => 0 , // Option::None
287- ( false , hir:: CoroutineKind :: Gen ( _) ) => 1 , // Option::Some
288- } ) ;
289-
290- let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_args , None , None ) ;
291- ( kind, idx)
292- }
293-
294288 // Make a `CoroutineState` or `Poll` variant assignment.
295289 //
296290 // `core::ops::CoroutineState` only has single element tuple variants,
@@ -303,51 +297,99 @@ impl<'tcx> TransformVisitor<'tcx> {
303297 is_return : bool ,
304298 statements : & mut Vec < Statement < ' tcx > > ,
305299 ) {
306- let ( kind, idx) = self . coroutine_state_adt_and_variant_idx ( is_return) ;
307-
308- match self . coroutine_kind {
309- // `Poll::Pending`
300+ let rvalue = match self . coroutine_kind {
310301 CoroutineKind :: Async ( _) => {
311- if !is_return {
312- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
313-
314- // FIXME(swatinem): assert that `val` is indeed unit?
315- statements. push ( Statement {
316- kind : StatementKind :: Assign ( Box :: new ( (
317- Place :: return_place ( ) ,
318- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
319- ) ) ) ,
320- source_info,
321- } ) ;
322- return ;
302+ let poll_def_id = self . tcx . require_lang_item ( LangItem :: Poll , None ) ;
303+ let args = self . tcx . mk_args ( & [ self . old_ret_ty . into ( ) ] ) ;
304+ if is_return {
305+ // Poll::Ready(val)
306+ Rvalue :: Aggregate (
307+ Box :: new ( AggregateKind :: Adt (
308+ poll_def_id,
309+ VariantIdx :: from_usize ( 0 ) ,
310+ args,
311+ None ,
312+ None ,
313+ ) ) ,
314+ IndexVec :: from_raw ( vec ! [ val] ) ,
315+ )
316+ } else {
317+ // Poll::Pending
318+ Rvalue :: Aggregate (
319+ Box :: new ( AggregateKind :: Adt (
320+ poll_def_id,
321+ VariantIdx :: from_usize ( 1 ) ,
322+ args,
323+ None ,
324+ None ,
325+ ) ) ,
326+ IndexVec :: new ( ) ,
327+ )
323328 }
324329 }
325- // `Option::None`
326330 CoroutineKind :: Gen ( _) => {
331+ let option_def_id = self . tcx . require_lang_item ( LangItem :: Option , None ) ;
332+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) ] ) ;
327333 if is_return {
328- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
329-
330- statements. push ( Statement {
331- kind : StatementKind :: Assign ( Box :: new ( (
332- Place :: return_place ( ) ,
333- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
334- ) ) ) ,
335- source_info,
336- } ) ;
337- return ;
334+ // None
335+ Rvalue :: Aggregate (
336+ Box :: new ( AggregateKind :: Adt (
337+ option_def_id,
338+ VariantIdx :: from_usize ( 0 ) ,
339+ args,
340+ None ,
341+ None ,
342+ ) ) ,
343+ IndexVec :: new ( ) ,
344+ )
345+ } else {
346+ // Some(val)
347+ Rvalue :: Aggregate (
348+ Box :: new ( AggregateKind :: Adt (
349+ option_def_id,
350+ VariantIdx :: from_usize ( 1 ) ,
351+ args,
352+ None ,
353+ None ,
354+ ) ) ,
355+ IndexVec :: from_raw ( vec ! [ val] ) ,
356+ )
338357 }
339358 }
340- CoroutineKind :: Coroutine => { }
341- }
342-
343- // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
344- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
359+ CoroutineKind :: Coroutine => {
360+ let coroutine_state_def_id =
361+ self . tcx . require_lang_item ( LangItem :: CoroutineState , None ) ;
362+ let args = self . tcx . mk_args ( & [ self . old_yield_ty . into ( ) , self . old_ret_ty . into ( ) ] ) ;
363+ if is_return {
364+ // CoroutineState::Complete(val)
365+ Rvalue :: Aggregate (
366+ Box :: new ( AggregateKind :: Adt (
367+ coroutine_state_def_id,
368+ VariantIdx :: from_usize ( 1 ) ,
369+ args,
370+ None ,
371+ None ,
372+ ) ) ,
373+ IndexVec :: from_raw ( vec ! [ val] ) ,
374+ )
375+ } else {
376+ // CoroutineState::Yielded(val)
377+ Rvalue :: Aggregate (
378+ Box :: new ( AggregateKind :: Adt (
379+ coroutine_state_def_id,
380+ VariantIdx :: from_usize ( 0 ) ,
381+ args,
382+ None ,
383+ None ,
384+ ) ) ,
385+ IndexVec :: from_raw ( vec ! [ val] ) ,
386+ )
387+ }
388+ }
389+ } ;
345390
346391 statements. push ( Statement {
347- kind : StatementKind :: Assign ( Box :: new ( (
348- Place :: return_place ( ) ,
349- Rvalue :: Aggregate ( Box :: new ( kind) , [ val] . into ( ) ) ,
350- ) ) ) ,
392+ kind : StatementKind :: Assign ( Box :: new ( ( Place :: return_place ( ) , rvalue) ) ) ,
351393 source_info,
352394 } ) ;
353395 }
@@ -421,7 +463,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
421463
422464 let ret_val = match data. terminator ( ) . kind {
423465 TerminatorKind :: Return => {
424- Some ( ( true , None , Operand :: Move ( Place :: from ( self . new_ret_local ) ) , None ) )
466+ Some ( ( true , None , Operand :: Move ( Place :: from ( self . old_ret_local ) ) , None ) )
425467 }
426468 TerminatorKind :: Yield { ref value, resume, resume_arg, drop } => {
427469 Some ( ( false , Some ( ( resume, resume_arg) ) , value. clone ( ) , drop) )
@@ -1503,10 +1545,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
15031545
15041546impl < ' tcx > MirPass < ' tcx > for StateTransform {
15051547 fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
1506- let Some ( yield_ty ) = body. yield_ty ( ) else {
1548+ let Some ( old_yield_ty ) = body. yield_ty ( ) else {
15071549 // This only applies to coroutines
15081550 return ;
15091551 } ;
1552+ let old_ret_ty = body. return_ty ( ) ;
15101553
15111554 assert ! ( body. coroutine_drop( ) . is_none( ) ) ;
15121555
@@ -1528,34 +1571,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15281571
15291572 let is_async_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Async ( _) ) ) ;
15301573 let is_gen_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Gen ( _) ) ) ;
1531- let ( state_adt_ref , state_args ) = match body. coroutine_kind ( ) . unwrap ( ) {
1574+ let new_ret_ty = match body. coroutine_kind ( ) . unwrap ( ) {
15321575 CoroutineKind :: Async ( _) => {
15331576 // Compute Poll<return_ty>
15341577 let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
15351578 let poll_adt_ref = tcx. adt_def ( poll_did) ;
1536- let poll_args = tcx. mk_args ( & [ body . return_ty ( ) . into ( ) ] ) ;
1537- ( poll_adt_ref, poll_args)
1579+ let poll_args = tcx. mk_args ( & [ old_ret_ty . into ( ) ] ) ;
1580+ Ty :: new_adt ( tcx , poll_adt_ref, poll_args)
15381581 }
15391582 CoroutineKind :: Gen ( _) => {
15401583 // Compute Option<yield_ty>
15411584 let option_did = tcx. require_lang_item ( LangItem :: Option , None ) ;
15421585 let option_adt_ref = tcx. adt_def ( option_did) ;
1543- let option_args = tcx. mk_args ( & [ body . yield_ty ( ) . unwrap ( ) . into ( ) ] ) ;
1544- ( option_adt_ref, option_args)
1586+ let option_args = tcx. mk_args ( & [ old_yield_ty . into ( ) ] ) ;
1587+ Ty :: new_adt ( tcx , option_adt_ref, option_args)
15451588 }
15461589 CoroutineKind :: Coroutine => {
15471590 // Compute CoroutineState<yield_ty, return_ty>
15481591 let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
15491592 let state_adt_ref = tcx. adt_def ( state_did) ;
1550- let state_args = tcx. mk_args ( & [ yield_ty . into ( ) , body . return_ty ( ) . into ( ) ] ) ;
1551- ( state_adt_ref, state_args)
1593+ let state_args = tcx. mk_args ( & [ old_yield_ty . into ( ) , old_ret_ty . into ( ) ] ) ;
1594+ Ty :: new_adt ( tcx , state_adt_ref, state_args)
15521595 }
15531596 } ;
1554- let ret_ty = Ty :: new_adt ( tcx, state_adt_ref, state_args) ;
15551597
1556- // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1598+ // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
15571599 // RETURN_PLACE then is a fresh unused local with type ret_ty.
1558- let new_ret_local = replace_local ( RETURN_PLACE , ret_ty , body, tcx) ;
1600+ let old_ret_local = replace_local ( RETURN_PLACE , new_ret_ty , body, tcx) ;
15591601
15601602 // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
15611603 if is_async_kind {
@@ -1572,17 +1614,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15721614 } else {
15731615 body. local_decls [ resume_local] . ty
15741616 } ;
1575- let new_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1617+ let old_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
15761618
1577- // When first entering the coroutine, move the resume argument into its new local.
1619+ // When first entering the coroutine, move the resume argument into its old local
1620+ // (which is now a generator interior).
15781621 let source_info = SourceInfo :: outermost ( body. span ) ;
15791622 let stmts = & mut body. basic_blocks_mut ( ) [ START_BLOCK ] . statements ;
15801623 stmts. insert (
15811624 0 ,
15821625 Statement {
15831626 source_info,
15841627 kind : StatementKind :: Assign ( Box :: new ( (
1585- new_resume_local . into ( ) ,
1628+ old_resume_local . into ( ) ,
15861629 Rvalue :: Use ( Operand :: Move ( resume_local. into ( ) ) ) ,
15871630 ) ) ) ,
15881631 } ,
@@ -1618,14 +1661,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
16181661 let mut transform = TransformVisitor {
16191662 tcx,
16201663 coroutine_kind : body. coroutine_kind ( ) . unwrap ( ) ,
1621- state_adt_ref,
1622- state_args,
16231664 remap,
16241665 storage_liveness,
16251666 always_live_locals,
16261667 suspension_points : Vec :: new ( ) ,
1627- new_ret_local ,
1668+ old_ret_local ,
16281669 discr_ty,
1670+ old_ret_ty,
1671+ old_yield_ty,
16291672 } ;
16301673 transform. visit_body ( body) ;
16311674
0 commit comments