1- use rustc_index:: IndexVec ;
1+ use rustc_index:: IndexSlice ;
2+ use rustc_middle:: mir:: patch:: MirPatch ;
23use rustc_middle:: mir:: * ;
34use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
45use rustc_target:: abi:: Size ;
@@ -17,9 +18,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
1718 let def_id = body. source . def_id ( ) ;
1819 let param_env = tcx. param_env_reveal_all_normalized ( def_id) ;
1920
20- let bbs = body. basic_blocks . as_mut ( ) ;
2121 let mut should_cleanup = false ;
22- for bb_idx in bbs. indices ( ) {
22+ for i in 0 ..body. basic_blocks . len ( ) {
23+ let bbs = & * body. basic_blocks ;
24+ let bb_idx = BasicBlock :: from_usize ( i) ;
2325 if !tcx. consider_optimizing ( || format ! ( "MatchBranchSimplification {def_id:?} " ) ) {
2426 continue ;
2527 }
@@ -35,12 +37,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3537 _ => continue ,
3638 } ;
3739
38- if SimplifyToIf . simplify ( tcx, & mut body. local_decls , bbs , bb_idx, param_env) {
40+ if SimplifyToIf . simplify ( tcx, body, bb_idx, param_env) {
3941 should_cleanup = true ;
4042 continue ;
4143 }
42- if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
43- {
44+ if SimplifyToExp :: default ( ) . simplify ( tcx, body, bb_idx, param_env) {
4445 should_cleanup = true ;
4546 continue ;
4647 }
@@ -58,41 +59,39 @@ trait SimplifyMatch<'tcx> {
5859 fn simplify (
5960 & mut self ,
6061 tcx : TyCtxt < ' tcx > ,
61- local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
62- bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
62+ body : & mut Body < ' tcx > ,
6363 switch_bb_idx : BasicBlock ,
6464 param_env : ParamEnv < ' tcx > ,
6565 ) -> bool {
66+ let bbs = & body. basic_blocks ;
6667 let ( discr, targets) = match bbs[ switch_bb_idx] . terminator ( ) . kind {
6768 TerminatorKind :: SwitchInt { ref discr, ref targets, .. } => ( discr, targets) ,
6869 _ => unreachable ! ( ) ,
6970 } ;
7071
71- let discr_ty = discr. ty ( local_decls, tcx) ;
72+ let discr_ty = discr. ty ( body . local_decls ( ) , tcx) ;
7273 if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
7374 return false ;
7475 }
7576
77+ let mut patch = MirPatch :: new ( body) ;
78+
7679 // Take ownership of items now that we know we can optimize.
7780 let discr = discr. clone ( ) ;
7881
7982 // Introduce a temporary for the discriminant value.
8083 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
81- let discr_local = local_decls . push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
84+ let discr_local = patch . new_temp ( discr_ty, source_info. span ) ;
8285
83- let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local, discr_ty) ;
8486 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
85- let ( from, first) = bbs. pick2_mut ( switch_bb_idx, first) ;
86- from. statements
87- . push ( Statement { source_info, kind : StatementKind :: StorageLive ( discr_local) } ) ;
88- from. statements . push ( Statement {
89- source_info,
90- kind : StatementKind :: Assign ( Box :: new ( ( Place :: from ( discr_local) , Rvalue :: Use ( discr) ) ) ) ,
91- } ) ;
92- from. statements . extend ( new_stmts) ;
93- from. statements
94- . push ( Statement { source_info, kind : StatementKind :: StorageDead ( discr_local) } ) ;
95- from. terminator_mut ( ) . kind = first. terminator ( ) . kind . clone ( ) ;
87+ let statement_index = bbs[ switch_bb_idx] . statements . len ( ) ;
88+ let parent_end = Location { block : switch_bb_idx, statement_index } ;
89+ patch. add_statement ( parent_end, StatementKind :: StorageLive ( discr_local) ) ;
90+ patch. add_assign ( parent_end, Place :: from ( discr_local) , Rvalue :: Use ( discr) ) ;
91+ self . new_stmts ( tcx, targets, param_env, & mut patch, parent_end, bbs, discr_local, discr_ty) ;
92+ patch. add_statement ( parent_end, StatementKind :: StorageDead ( discr_local) ) ;
93+ patch. patch_terminator ( switch_bb_idx, bbs[ first] . terminator ( ) . kind . clone ( ) ) ;
94+ patch. apply ( body) ;
9695 true
9796 }
9897
@@ -104,7 +103,7 @@ trait SimplifyMatch<'tcx> {
104103 tcx : TyCtxt < ' tcx > ,
105104 targets : & SwitchTargets ,
106105 param_env : ParamEnv < ' tcx > ,
107- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
106+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
108107 discr_ty : Ty < ' tcx > ,
109108 ) -> bool ;
110109
@@ -113,10 +112,12 @@ trait SimplifyMatch<'tcx> {
113112 tcx : TyCtxt < ' tcx > ,
114113 targets : & SwitchTargets ,
115114 param_env : ParamEnv < ' tcx > ,
116- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
115+ patch : & mut MirPatch < ' tcx > ,
116+ parent_end : Location ,
117+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
117118 discr_local : Local ,
118119 discr_ty : Ty < ' tcx > ,
119- ) -> Vec < Statement < ' tcx > > ;
120+ ) ;
120121}
121122
122123struct SimplifyToIf ;
@@ -158,7 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
158159 tcx : TyCtxt < ' tcx > ,
159160 targets : & SwitchTargets ,
160161 param_env : ParamEnv < ' tcx > ,
161- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
162+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
162163 _discr_ty : Ty < ' tcx > ,
163164 ) -> bool {
164165 if targets. iter ( ) . len ( ) != 1 {
@@ -209,20 +210,23 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
209210 tcx : TyCtxt < ' tcx > ,
210211 targets : & SwitchTargets ,
211212 param_env : ParamEnv < ' tcx > ,
212- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
213+ patch : & mut MirPatch < ' tcx > ,
214+ parent_end : Location ,
215+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
213216 discr_local : Local ,
214217 discr_ty : Ty < ' tcx > ,
215- ) -> Vec < Statement < ' tcx > > {
218+ ) {
216219 let ( val, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
217220 let second = targets. otherwise ( ) ;
218221 // We already checked that first and second are different blocks,
219222 // and bb_idx has a different terminator from both of them.
220223 let first = & bbs[ first] ;
221224 let second = & bbs[ second] ;
222-
223- let new_stmts = iter:: zip ( & first. statements , & second. statements ) . map ( |( f, s) | {
225+ for ( f, s) in iter:: zip ( & first. statements , & second. statements ) {
224226 match ( & f. kind , & s. kind ) {
225- ( f_s, s_s) if f_s == s_s => ( * f) . clone ( ) ,
227+ ( f_s, s_s) if f_s == s_s => {
228+ patch. add_statement ( parent_end, f. kind . clone ( ) ) ;
229+ }
226230
227231 (
228232 StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
@@ -233,7 +237,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
233237 let s_b = s_c. const_ . try_eval_bool ( tcx, param_env) . unwrap ( ) ;
234238 if f_b == s_b {
235239 // Same value in both blocks. Use statement as is.
236- ( * f ) . clone ( )
240+ patch . add_statement ( parent_end , f . kind . clone ( ) ) ;
237241 } else {
238242 // Different value between blocks. Make value conditional on switch condition.
239243 let size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
@@ -248,17 +252,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
248252 op,
249253 Box :: new ( ( Operand :: Copy ( Place :: from ( discr_local) ) , const_cmp) ) ,
250254 ) ;
251- Statement {
252- source_info : f. source_info ,
253- kind : StatementKind :: Assign ( Box :: new ( ( * lhs, rhs) ) ) ,
254- }
255+ patch. add_assign ( parent_end, * lhs, rhs) ;
255256 }
256257 }
257258
258259 _ => unreachable ! ( ) ,
259260 }
260- } ) ;
261- new_stmts. collect ( )
261+ }
262262 }
263263}
264264
@@ -335,7 +335,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
335335 tcx : TyCtxt < ' tcx > ,
336336 targets : & SwitchTargets ,
337337 param_env : ParamEnv < ' tcx > ,
338- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
338+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
339339 discr_ty : Ty < ' tcx > ,
340340 ) -> bool {
341341 if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
@@ -372,6 +372,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
372372 == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . try_to_int ( size) . unwrap ( )
373373 }
374374
375+ // We first compare the two branches, and then the other branches need to fulfill the same conditions.
375376 let mut compare_types = Vec :: new ( ) ;
376377 for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
377378 let compare_type = match ( & f. kind , & s. kind ) {
@@ -391,6 +392,8 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
391392 s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
392393 ) {
393394 ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
395+ // Enum variants can also be simplified to an assignment statement if their values are equal.
396+ // We need to consider both unsigned and signed scenarios here.
394397 ( Some ( f) , Some ( s) )
395398 if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
396399 && int_equal ( f, first_val, discr_size)
@@ -463,16 +466,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
463466 _tcx : TyCtxt < ' tcx > ,
464467 targets : & SwitchTargets ,
465468 _param_env : ParamEnv < ' tcx > ,
466- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
469+ patch : & mut MirPatch < ' tcx > ,
470+ parent_end : Location ,
471+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
467472 discr_local : Local ,
468473 discr_ty : Ty < ' tcx > ,
469- ) -> Vec < Statement < ' tcx > > {
474+ ) {
470475 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
471476 let first = & bbs[ first] ;
472477
473- let new_stmts =
474- iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
475- ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
478+ for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
479+ match ( t, & s. kind ) {
480+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => {
481+ patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
482+ }
476483 (
477484 TransfromType :: Discr ,
478485 StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
@@ -483,13 +490,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
483490 } else {
484491 Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
485492 } ;
486- Statement {
487- source_info : s. source_info ,
488- kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
489- }
493+ patch. add_assign ( parent_end, * lhs, r_val) ;
490494 }
491495 _ => unreachable ! ( ) ,
492- } ) ;
493- new_stmts . collect ( )
496+ }
497+ }
494498 }
495499}
0 commit comments