1111
1212use crate :: transform:: { simplify, MirPass , MirSource } ;
1313use itertools:: Itertools as _;
14- use rustc_index:: vec:: IndexVec ;
14+ use rustc_index:: { bit_set:: BitSet , vec:: IndexVec } ;
15+ use rustc_middle:: mir:: visit:: { NonUseContext , PlaceContext , Visitor } ;
1516use rustc_middle:: mir:: * ;
16- use rustc_middle:: ty:: { Ty , TyCtxt } ;
17+ use rustc_middle:: ty:: { List , Ty , TyCtxt } ;
1718use rustc_target:: abi:: VariantIdx ;
1819use std:: iter:: { Enumerate , Peekable } ;
1920use std:: slice:: Iter ;
@@ -73,9 +74,20 @@ struct ArmIdentityInfo<'tcx> {
7374
7475 /// The statements that should be removed (turned into nops)
7576 stmts_to_remove : Vec < usize > ,
77+
78+ /// Indices of debug variables that need to be adjusted to point to
79+ // `{local_0}.{dbg_projection}`.
80+ dbg_info_to_adjust : Vec < usize > ,
81+
82+ /// The projection used to rewrite debug info.
83+ dbg_projection : & ' tcx List < PlaceElem < ' tcx > > ,
7684}
7785
78- fn get_arm_identity_info < ' a , ' tcx > ( stmts : & ' a [ Statement < ' tcx > ] ) -> Option < ArmIdentityInfo < ' tcx > > {
86+ fn get_arm_identity_info < ' a , ' tcx > (
87+ stmts : & ' a [ Statement < ' tcx > ] ,
88+ locals_count : usize ,
89+ debug_info : & ' a [ VarDebugInfo < ' tcx > ] ,
90+ ) -> Option < ArmIdentityInfo < ' tcx > > {
7991 // This can't possibly match unless there are at least 3 statements in the block
8092 // so fail fast on tiny blocks.
8193 if stmts. len ( ) < 3 {
@@ -187,7 +199,7 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
187199 try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
188200
189201 let ( get_variant_field_stmt, stmt) = stmt_iter. next ( ) ?;
190- let ( local_tmp_s0, local_1, vf_s0) = match_get_variant_field ( stmt) ?;
202+ let ( local_tmp_s0, local_1, vf_s0, dbg_projection ) = match_get_variant_field ( stmt) ?;
191203
192204 try_eat_storage_stmts ( & mut stmt_iter, & mut storage_live_stmts, & mut storage_dead_stmts) ;
193205
@@ -228,6 +240,19 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
228240 let stmt_to_overwrite =
229241 nop_stmts. iter ( ) . find ( |stmt_idx| live_idx < * * stmt_idx && * * stmt_idx < dead_idx) ;
230242
243+ let mut tmp_assigned_vars = BitSet :: new_empty ( locals_count) ;
244+ for ( l, r) in & tmp_assigns {
245+ tmp_assigned_vars. insert ( * l) ;
246+ tmp_assigned_vars. insert ( * r) ;
247+ }
248+
249+ let mut dbg_info_to_adjust = Vec :: new ( ) ;
250+ for ( i, var_info) in debug_info. iter ( ) . enumerate ( ) {
251+ if tmp_assigned_vars. contains ( var_info. place . local ) {
252+ dbg_info_to_adjust. push ( i) ;
253+ }
254+ }
255+
231256 Some ( ArmIdentityInfo {
232257 local_temp_0 : local_tmp_s0,
233258 local_1,
@@ -243,12 +268,16 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
243268 source_info : discr_stmt_source_info,
244269 storage_stmts,
245270 stmts_to_remove : nop_stmts,
271+ dbg_info_to_adjust,
272+ dbg_projection,
246273 } )
247274}
248275
249276fn optimization_applies < ' tcx > (
250277 opt_info : & ArmIdentityInfo < ' tcx > ,
251278 local_decls : & IndexVec < Local , LocalDecl < ' tcx > > ,
279+ local_uses : & IndexVec < Local , usize > ,
280+ var_debug_info : & [ VarDebugInfo < ' tcx > ] ,
252281) -> bool {
253282 trace ! ( "testing if optimization applies..." ) ;
254283
@@ -273,6 +302,7 @@ fn optimization_applies<'tcx>(
273302 // Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
274303 if opt_info. field_tmp_assignments . is_empty ( ) {
275304 trace ! ( "NO: no assignments found" ) ;
305+ return false ;
276306 }
277307 let mut last_assigned_to = opt_info. field_tmp_assignments [ 0 ] . 1 ;
278308 let source_local = last_assigned_to;
@@ -285,6 +315,35 @@ fn optimization_applies<'tcx>(
285315 last_assigned_to = * l;
286316 }
287317
318+ // Check that the first and last used locals are only used twice
319+ // since they are of the form:
320+ //
321+ // ```
322+ // _first = ((_x as Variant).n: ty);
323+ // _n = _first;
324+ // ...
325+ // ((_y as Variant).n: ty) = _n;
326+ // discriminant(_y) = z;
327+ // ```
328+ for ( l, r) in & opt_info. field_tmp_assignments {
329+ if local_uses[ * l] != 2 {
330+ warn ! ( "NO: FAILED assignment chain local {:?} was used more than twice" , l) ;
331+ return false ;
332+ } else if local_uses[ * r] != 2 {
333+ warn ! ( "NO: FAILED assignment chain local {:?} was used more than twice" , r) ;
334+ return false ;
335+ }
336+ }
337+
338+ // Check that debug info only points to full Locals and not projections.
339+ for dbg_idx in & opt_info. dbg_info_to_adjust {
340+ let dbg_info = & var_debug_info[ * dbg_idx] ;
341+ if !dbg_info. place . projection . is_empty ( ) {
342+ trace ! ( "NO: debug info for {:?} had a projection {:?}" , dbg_info. name, dbg_info. place) ;
343+ return false ;
344+ }
345+ }
346+
288347 if source_local != opt_info. local_temp_0 {
289348 trace ! (
290349 "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}" ,
@@ -312,11 +371,15 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
312371 }
313372
314373 trace ! ( "running SimplifyArmIdentity on {:?}" , source) ;
315- let ( basic_blocks, local_decls) = body. basic_blocks_and_local_decls_mut ( ) ;
374+ let local_uses = LocalUseCounter :: get_local_uses ( body) ;
375+ let ( basic_blocks, local_decls, debug_info) =
376+ body. basic_blocks_local_decls_mut_and_var_debug_info ( ) ;
316377 for bb in basic_blocks {
317- if let Some ( opt_info) = get_arm_identity_info ( & bb. statements ) {
378+ if let Some ( opt_info) =
379+ get_arm_identity_info ( & bb. statements , local_decls. len ( ) , debug_info)
380+ {
318381 trace ! ( "got opt_info = {:#?}" , opt_info) ;
319- if !optimization_applies ( & opt_info, local_decls) {
382+ if !optimization_applies ( & opt_info, local_decls, & local_uses , & debug_info ) {
320383 debug ! ( "optimization skipped for {:?}" , source) ;
321384 continue ;
322385 }
@@ -352,23 +415,57 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
352415
353416 bb. statements . retain ( |stmt| stmt. kind != StatementKind :: Nop ) ;
354417
418+ // Fix the debug info to point to the right local
419+ for dbg_index in opt_info. dbg_info_to_adjust {
420+ let dbg_info = & mut debug_info[ dbg_index] ;
421+ assert ! ( dbg_info. place. projection. is_empty( ) ) ;
422+ dbg_info. place . local = opt_info. local_0 ;
423+ dbg_info. place . projection = opt_info. dbg_projection ;
424+ }
425+
355426 trace ! ( "block is now {:?}" , bb. statements) ;
356427 }
357428 }
358429 }
359430}
360431
432+ struct LocalUseCounter {
433+ local_uses : IndexVec < Local , usize > ,
434+ }
435+
436+ impl LocalUseCounter {
437+ fn get_local_uses < ' tcx > ( body : & Body < ' tcx > ) -> IndexVec < Local , usize > {
438+ let mut counter = LocalUseCounter { local_uses : IndexVec :: from_elem ( 0 , & body. local_decls ) } ;
439+ counter. visit_body ( body) ;
440+ counter. local_uses
441+ }
442+ }
443+
444+ impl < ' tcx > Visitor < ' tcx > for LocalUseCounter {
445+ fn visit_local ( & mut self , local : & Local , context : PlaceContext , _location : Location ) {
446+ if context. is_storage_marker ( )
447+ || context == PlaceContext :: NonUse ( NonUseContext :: VarDebugInfo )
448+ {
449+ return ;
450+ }
451+
452+ self . local_uses [ * local] += 1 ;
453+ }
454+ }
455+
361456/// Match on:
362457/// ```rust
363458/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
364459/// ```
365- fn match_get_variant_field < ' tcx > ( stmt : & Statement < ' tcx > ) -> Option < ( Local , Local , VarField < ' tcx > ) > {
460+ fn match_get_variant_field < ' tcx > (
461+ stmt : & Statement < ' tcx > ,
462+ ) -> Option < ( Local , Local , VarField < ' tcx > , & ' tcx List < PlaceElem < ' tcx > > ) > {
366463 match & stmt. kind {
367464 StatementKind :: Assign ( box ( place_into, rvalue_from) ) => match rvalue_from {
368465 Rvalue :: Use ( Operand :: Copy ( pf) | Operand :: Move ( pf) ) => {
369466 let local_into = place_into. as_local ( ) ?;
370467 let ( local_from, vf) = match_variant_field_place ( * pf) ?;
371- Some ( ( local_into, local_from, vf) )
468+ Some ( ( local_into, local_from, vf, pf . projection ) )
372469 }
373470 _ => None ,
374471 } ,
0 commit comments