@@ -9,7 +9,7 @@ use rustc_errors::ErrorGuaranteed;
99use rustc_hir as hir;
1010use rustc_hir:: def:: DefKind ;
1111use rustc_hir:: def_id:: { DefId , LocalDefId } ;
12- use rustc_hir:: { CoroutineKind , Node } ;
12+ use rustc_hir:: Node ;
1313use rustc_index:: bit_set:: GrowableBitSet ;
1414use rustc_index:: { Idx , IndexSlice , IndexVec } ;
1515use rustc_infer:: infer:: { InferCtxt , TyCtxtInferExt } ;
@@ -177,7 +177,7 @@ struct Builder<'a, 'tcx> {
177177 check_overflow : bool ,
178178 fn_span : Span ,
179179 arg_count : usize ,
180- coroutine_kind : Option < CoroutineKind > ,
180+ coroutine : Option < Box < CoroutineInfo < ' tcx > > > ,
181181
182182 /// The current set of scopes, updated as we traverse;
183183 /// see the `scope` module for more details.
@@ -458,7 +458,6 @@ fn construct_fn<'tcx>(
458458) -> Body < ' tcx > {
459459 let span = tcx. def_span ( fn_def) ;
460460 let fn_id = tcx. local_def_id_to_hir_id ( fn_def) ;
461- let coroutine_kind = tcx. coroutine_kind ( fn_def) ;
462461
463462 // The representation of thir for `-Zunpretty=thir-tree` relies on
464463 // the entry expression being the last element of `thir.exprs`.
@@ -488,17 +487,15 @@ fn construct_fn<'tcx>(
488487
489488 let arguments = & thir. params ;
490489
491- let ( resume_ty, yield_ty, return_ty) = if coroutine_kind. is_some ( ) {
492- let coroutine_ty = arguments[ thir:: UPVAR_ENV_PARAM ] . ty ;
493- let coroutine_sig = match coroutine_ty. kind ( ) {
494- ty:: Coroutine ( _, gen_args, ..) => gen_args. as_coroutine ( ) . sig ( ) ,
495- _ => {
496- span_bug ! ( span, "coroutine w/o coroutine type: {:?}" , coroutine_ty)
497- }
498- } ;
499- ( Some ( coroutine_sig. resume_ty ) , Some ( coroutine_sig. yield_ty ) , coroutine_sig. return_ty )
500- } else {
501- ( None , None , fn_sig. output ( ) )
490+ let return_ty = fn_sig. output ( ) ;
491+ let coroutine = match tcx. type_of ( fn_def) . instantiate_identity ( ) . kind ( ) {
492+ ty:: Coroutine ( _, args) => Some ( Box :: new ( CoroutineInfo :: initial (
493+ tcx. coroutine_kind ( fn_def) . unwrap ( ) ,
494+ args. as_coroutine ( ) . yield_ty ( ) ,
495+ args. as_coroutine ( ) . resume_ty ( ) ,
496+ ) ) ) ,
497+ ty:: Closure ( ..) | ty:: FnDef ( ..) => None ,
498+ ty => span_bug ! ( span_with_body, "unexpected type of body: {ty:?}" ) ,
502499 } ;
503500
504501 if let Some ( custom_mir_attr) =
@@ -529,7 +526,7 @@ fn construct_fn<'tcx>(
529526 safety,
530527 return_ty,
531528 return_ty_span,
532- coroutine_kind ,
529+ coroutine ,
533530 ) ;
534531
535532 let call_site_scope =
@@ -563,11 +560,6 @@ fn construct_fn<'tcx>(
563560 None
564561 } ;
565562
566- if coroutine_kind. is_some ( ) {
567- body. coroutine . as_mut ( ) . unwrap ( ) . yield_ty = yield_ty;
568- body. coroutine . as_mut ( ) . unwrap ( ) . resume_ty = resume_ty;
569- }
570-
571563 body
572564}
573565
@@ -632,47 +624,62 @@ fn construct_const<'a, 'tcx>(
632624fn construct_error ( tcx : TyCtxt < ' _ > , def_id : LocalDefId , guar : ErrorGuaranteed ) -> Body < ' _ > {
633625 let span = tcx. def_span ( def_id) ;
634626 let hir_id = tcx. local_def_id_to_hir_id ( def_id) ;
635- let coroutine_kind = tcx. coroutine_kind ( def_id) ;
636627
637- let ( inputs, output, resume_ty , yield_ty ) = match tcx. def_kind ( def_id) {
628+ let ( inputs, output, coroutine ) = match tcx. def_kind ( def_id) {
638629 DefKind :: Const
639630 | DefKind :: AssocConst
640631 | DefKind :: AnonConst
641632 | DefKind :: InlineConst
642- | DefKind :: Static ( _) => ( vec ! [ ] , tcx. type_of ( def_id) . instantiate_identity ( ) , None , None ) ,
633+ | DefKind :: Static ( _) => ( vec ! [ ] , tcx. type_of ( def_id) . instantiate_identity ( ) , None ) ,
643634 DefKind :: Ctor ( ..) | DefKind :: Fn | DefKind :: AssocFn => {
644635 let sig = tcx. liberate_late_bound_regions (
645636 def_id. to_def_id ( ) ,
646637 tcx. fn_sig ( def_id) . instantiate_identity ( ) ,
647638 ) ;
648- ( sig. inputs ( ) . to_vec ( ) , sig. output ( ) , None , None )
649- }
650- DefKind :: Closure if coroutine_kind. is_some ( ) => {
651- let coroutine_ty = tcx. type_of ( def_id) . instantiate_identity ( ) ;
652- let ty:: Coroutine ( _, args) = coroutine_ty. kind ( ) else {
653- bug ! ( "expected type of coroutine-like closure to be a coroutine" )
654- } ;
655- let args = args. as_coroutine ( ) ;
656- let resume_ty = args. resume_ty ( ) ;
657- let yield_ty = args. yield_ty ( ) ;
658- let return_ty = args. return_ty ( ) ;
659- ( vec ! [ coroutine_ty, args. resume_ty( ) ] , return_ty, Some ( resume_ty) , Some ( yield_ty) )
639+ ( sig. inputs ( ) . to_vec ( ) , sig. output ( ) , None )
660640 }
661641 DefKind :: Closure => {
662642 let closure_ty = tcx. type_of ( def_id) . instantiate_identity ( ) ;
663- let ty:: Closure ( _, args) = closure_ty. kind ( ) else {
664- bug ! ( "expected type of closure to be a closure" )
665- } ;
666- let args = args. as_closure ( ) ;
667- let sig = tcx. liberate_late_bound_regions ( def_id. to_def_id ( ) , args. sig ( ) ) ;
668- let self_ty = match args. kind ( ) {
669- ty:: ClosureKind :: Fn => Ty :: new_imm_ref ( tcx, tcx. lifetimes . re_erased , closure_ty) ,
670- ty:: ClosureKind :: FnMut => Ty :: new_mut_ref ( tcx, tcx. lifetimes . re_erased , closure_ty) ,
671- ty:: ClosureKind :: FnOnce => closure_ty,
672- } ;
673- ( [ self_ty] . into_iter ( ) . chain ( sig. inputs ( ) . to_vec ( ) ) . collect ( ) , sig. output ( ) , None , None )
643+ match closure_ty. kind ( ) {
644+ ty:: Closure ( _, args) => {
645+ let args = args. as_closure ( ) ;
646+ let sig = tcx. liberate_late_bound_regions ( def_id. to_def_id ( ) , args. sig ( ) ) ;
647+ let self_ty = match args. kind ( ) {
648+ ty:: ClosureKind :: Fn => {
649+ Ty :: new_imm_ref ( tcx, tcx. lifetimes . re_erased , closure_ty)
650+ }
651+ ty:: ClosureKind :: FnMut => {
652+ Ty :: new_mut_ref ( tcx, tcx. lifetimes . re_erased , closure_ty)
653+ }
654+ ty:: ClosureKind :: FnOnce => closure_ty,
655+ } ;
656+ (
657+ [ self_ty] . into_iter ( ) . chain ( sig. inputs ( ) . to_vec ( ) ) . collect ( ) ,
658+ sig. output ( ) ,
659+ None ,
660+ )
661+ }
662+ ty:: Coroutine ( _, args) => {
663+ let args = args. as_coroutine ( ) ;
664+ let resume_ty = args. resume_ty ( ) ;
665+ let yield_ty = args. yield_ty ( ) ;
666+ let return_ty = args. return_ty ( ) ;
667+ (
668+ vec ! [ closure_ty, args. resume_ty( ) ] ,
669+ return_ty,
670+ Some ( Box :: new ( CoroutineInfo :: initial (
671+ tcx. coroutine_kind ( def_id) . unwrap ( ) ,
672+ yield_ty,
673+ resume_ty,
674+ ) ) ) ,
675+ )
676+ }
677+ _ => {
678+ span_bug ! ( span, "expected type of closure body to be a closure or coroutine" ) ;
679+ }
680+ }
674681 }
675- dk => bug ! ( "{:?} is not a body: {:?}" , def_id, dk) ,
682+ dk => span_bug ! ( span , "{:?} is not a body: {:?}" , def_id, dk) ,
676683 } ;
677684
678685 let source_info = SourceInfo { span, scope : OUTERMOST_SOURCE_SCOPE } ;
@@ -696,7 +703,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
696703
697704 cfg. terminate ( START_BLOCK , source_info, TerminatorKind :: Unreachable ) ;
698705
699- let mut body = Body :: new (
706+ Body :: new (
700707 MirSource :: item ( def_id. to_def_id ( ) ) ,
701708 cfg. basic_blocks ,
702709 source_scopes,
@@ -705,16 +712,9 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
705712 inputs. len ( ) ,
706713 vec ! [ ] ,
707714 span,
708- coroutine_kind ,
715+ coroutine ,
709716 Some ( guar) ,
710- ) ;
711-
712- body. coroutine . as_mut ( ) . map ( |gen| {
713- gen. yield_ty = yield_ty;
714- gen. resume_ty = resume_ty;
715- } ) ;
716-
717- body
717+ )
718718}
719719
720720impl < ' a , ' tcx > Builder < ' a , ' tcx > {
@@ -728,7 +728,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
728728 safety : Safety ,
729729 return_ty : Ty < ' tcx > ,
730730 return_span : Span ,
731- coroutine_kind : Option < CoroutineKind > ,
731+ coroutine : Option < Box < CoroutineInfo < ' tcx > > > ,
732732 ) -> Builder < ' a , ' tcx > {
733733 let tcx = infcx. tcx ;
734734 let attrs = tcx. hir ( ) . attrs ( hir_id) ;
@@ -759,7 +759,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
759759 cfg : CFG { basic_blocks : IndexVec :: new ( ) } ,
760760 fn_span : span,
761761 arg_count,
762- coroutine_kind ,
762+ coroutine ,
763763 scopes : scope:: Scopes :: new ( ) ,
764764 block_context : BlockContext :: new ( ) ,
765765 source_scopes : IndexVec :: new ( ) ,
@@ -803,7 +803,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
803803 self . arg_count ,
804804 self . var_debug_info ,
805805 self . fn_span ,
806- self . coroutine_kind ,
806+ self . coroutine ,
807807 None ,
808808 )
809809 }
0 commit comments