6060//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
6161//! we use this "by move" body instead.
6262
63- use itertools:: Itertools ;
64-
65- use rustc_data_structures:: unord:: UnordSet ;
63+ use rustc_data_structures:: unord:: UnordMap ;
6664use rustc_hir as hir;
65+ use rustc_middle:: hir:: place:: { Projection , ProjectionKind } ;
6766use rustc_middle:: mir:: visit:: MutVisitor ;
6867use rustc_middle:: mir:: { self , dump_mir, MirPass } ;
6968use rustc_middle:: ty:: { self , InstanceDef , Ty , TyCtxt , TypeVisitableExt } ;
70- use rustc_target:: abi:: FieldIdx ;
69+ use rustc_target:: abi:: { FieldIdx , VariantIdx } ;
7170
7271pub struct ByMoveBody ;
7372
@@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
116115 . tuple_fields ( )
117116 . len ( ) ;
118117
119- let mut by_ref_fields = UnordSet :: default ( ) ;
120- for ( idx, ( coroutine_capture, parent_capture) ) in tcx
118+ let mut field_remapping = UnordMap :: default ( ) ;
119+
120+ let mut parent_captures =
121+ tcx. closure_captures ( parent_def_id) . iter ( ) . copied ( ) . enumerate ( ) . peekable ( ) ;
122+
123+ for ( child_field_idx, child_capture) in tcx
121124 . closure_captures ( coroutine_def_id)
122125 . iter ( )
126+ . copied ( )
123127 // By construction we capture all the args first.
124128 . skip ( num_args)
125- . zip_eq ( tcx. closure_captures ( parent_def_id) )
126129 . enumerate ( )
127130 {
128- // This upvar is captured by-move from the parent closure, but by-ref
129- // from the inner async block. That means that it's being borrowed from
130- // the outer closure body -- we need to change the coroutine to take the
131- // upvar by value.
132- if coroutine_capture. is_by_ref ( ) && !parent_capture. is_by_ref ( ) {
133- assert_ne ! (
134- coroutine_kind,
135- ty:: ClosureKind :: FnOnce ,
136- "`FnOnce` coroutine-closures return coroutines that capture from \
137- their body; it will always result in a borrowck error!"
131+ loop {
132+ let Some ( & ( parent_field_idx, parent_capture) ) = parent_captures. peek ( ) else {
133+ bug ! ( "we ran out of parent captures!" )
134+ } ;
135+
136+ if !std:: iter:: zip (
137+ & child_capture. place . projections ,
138+ & parent_capture. place . projections ,
139+ )
140+ . all ( |( child, parent) | child. kind == parent. kind )
141+ {
142+ // Skip this field.
143+ let _ = parent_captures. next ( ) . unwrap ( ) ;
144+ continue ;
145+ }
146+
147+ let child_precise_captures =
148+ & child_capture. place . projections [ parent_capture. place . projections . len ( ) ..] ;
149+
150+ let needs_deref = child_capture. is_by_ref ( ) && !parent_capture. is_by_ref ( ) ;
151+ if needs_deref {
152+ assert_ne ! (
153+ coroutine_kind,
154+ ty:: ClosureKind :: FnOnce ,
155+ "`FnOnce` coroutine-closures return coroutines that capture from \
156+ their body; it will always result in a borrowck error!"
157+ ) ;
158+ }
159+
160+ let mut parent_capture_ty = parent_capture. place . ty ( ) ;
161+ parent_capture_ty = match parent_capture. info . capture_kind {
162+ ty:: UpvarCapture :: ByValue => parent_capture_ty,
163+ ty:: UpvarCapture :: ByRef ( kind) => Ty :: new_ref (
164+ tcx,
165+ tcx. lifetimes . re_erased ,
166+ parent_capture_ty,
167+ kind. to_mutbl_lossy ( ) ,
168+ ) ,
169+ } ;
170+
171+ field_remapping. insert (
172+ FieldIdx :: from_usize ( child_field_idx + num_args) ,
173+ (
174+ FieldIdx :: from_usize ( parent_field_idx + num_args) ,
175+ parent_capture_ty,
176+ needs_deref,
177+ child_precise_captures,
178+ ) ,
138179 ) ;
139- by_ref_fields. insert ( FieldIdx :: from_usize ( num_args + idx) ) ;
180+
181+ break ;
140182 }
183+ }
141184
142- // Make sure we're actually talking about the same capture.
143- // FIXME(async_closures): We could look at the `hir::Upvar` instead?
144- assert_eq ! ( coroutine_capture . place . ty ( ) , parent_capture . place . ty ( ) ) ;
185+ if coroutine_kind == ty :: ClosureKind :: FnOnce {
186+ assert_eq ! ( field_remapping . len ( ) , tcx . closure_captures ( parent_def_id ) . len ( ) ) ;
187+ return ;
145188 }
146189
147190 let by_move_coroutine_ty = tcx
@@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
157200 ) ;
158201
159202 let mut by_move_body = body. clone ( ) ;
160- MakeByMoveBody { tcx, by_ref_fields , by_move_coroutine_ty } . visit_body ( & mut by_move_body) ;
203+ MakeByMoveBody { tcx, field_remapping , by_move_coroutine_ty } . visit_body ( & mut by_move_body) ;
161204 dump_mir ( tcx, false , "coroutine_by_move" , & 0 , & by_move_body, |_, _| Ok ( ( ) ) ) ;
162205 by_move_body. source = mir:: MirSource :: from_instance ( InstanceDef :: CoroutineKindShim {
163206 coroutine_def_id : coroutine_def_id. to_def_id ( ) ,
@@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
168211
169212struct MakeByMoveBody < ' tcx > {
170213 tcx : TyCtxt < ' tcx > ,
171- by_ref_fields : UnordSet < FieldIdx > ,
214+ field_remapping : UnordMap < FieldIdx , ( FieldIdx , Ty < ' tcx > , bool , & ' tcx [ Projection < ' tcx > ] ) > ,
172215 by_move_coroutine_ty : Ty < ' tcx > ,
173216}
174217
@@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
184227 location : mir:: Location ,
185228 ) {
186229 if place. local == ty:: CAPTURE_STRUCT_LOCAL
187- && let Some ( ( & mir:: ProjectionElem :: Field ( idx, ty ) , projection) ) =
230+ && let Some ( ( & mir:: ProjectionElem :: Field ( idx, _ ) , projection) ) =
188231 place. projection . split_first ( )
189- && self . by_ref_fields . contains ( & idx)
232+ && let Some ( & ( remapped_idx, remapped_ty, needs_deref, additional_projections) ) =
233+ self . field_remapping . get ( & idx)
190234 {
191- let ( begin, end) = projection. split_first ( ) . unwrap ( ) ;
192- // FIXME(async_closures): I'm actually a bit surprised to see that we always
193- // initially deref the by-ref upvars. If this is not actually true, then we
194- // will at least get an ICE that explains why this isn't true :^)
195- assert_eq ! ( * begin, mir:: ProjectionElem :: Deref ) ;
196- // Peel one ref off of the ty.
197- let peeled_ty = ty. builtin_deref ( true ) . unwrap ( ) . ty ;
235+ let final_deref = if needs_deref {
236+ let Some ( ( mir:: ProjectionElem :: Deref , rest) ) = projection. split_first ( ) else {
237+ bug ! ( ) ;
238+ } ;
239+ rest
240+ } else {
241+ projection
242+ } ;
243+
244+ let additional_projections =
245+ additional_projections. iter ( ) . map ( |elem| match elem. kind {
246+ ProjectionKind :: Deref => mir:: ProjectionElem :: Deref ,
247+ ProjectionKind :: Field ( idx, VariantIdx :: ZERO ) => {
248+ mir:: ProjectionElem :: Field ( idx, elem. ty )
249+ }
250+ _ => unreachable ! ( "precise captures only through fields and derefs" ) ,
251+ } ) ;
252+
198253 * place = mir:: Place {
199254 local : place. local ,
200255 projection : self . tcx . mk_place_elems_from_iter (
201- [ mir:: ProjectionElem :: Field ( idx , peeled_ty ) ]
256+ [ mir:: ProjectionElem :: Field ( remapped_idx , remapped_ty ) ]
202257 . into_iter ( )
203- . chain ( end. iter ( ) . copied ( ) ) ,
258+ . chain ( additional_projections)
259+ . chain ( final_deref. iter ( ) . copied ( ) ) ,
204260 ) ,
205261 } ;
206262 }
0 commit comments