11use rustc_index:: IndexVec ;
22use rustc_middle:: mir:: * ;
3- use rustc_middle:: ty:: { ParamEnv , Ty , TyCtxt } ;
3+ use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
44use std:: iter;
55
66use super :: simplify:: simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838 should_cleanup = true ;
3939 continue ;
4040 }
41+ if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
42+ {
43+ should_cleanup = true ;
44+ continue ;
45+ }
4146 }
4247
4348 if should_cleanup {
@@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4752}
4853
4954trait SimplifyMatch < ' tcx > {
55+ /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise.
56+ /// Generic code is written here, and we generally don't need a custom implementation.
5057 fn simplify (
51- & self ,
58+ & mut self ,
5259 tcx : TyCtxt < ' tcx > ,
5360 local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
5461 bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
@@ -72,9 +79,7 @@ trait SimplifyMatch<'tcx> {
7279 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
7380 let discr_local = local_decls. push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
7481
75- // We already checked that first and second are different blocks,
76- // and bb_idx has a different terminator from both of them.
77- let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local. clone ( ) , discr_ty) ;
82+ let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local, discr_ty) ;
7883 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
7984 let ( from, first) = bbs. pick2_mut ( switch_bb_idx, first) ;
8085 from. statements
@@ -90,8 +95,11 @@ trait SimplifyMatch<'tcx> {
9095 true
9196 }
9297
98+ /// Check that the BBs to be simplified satisfies all distinct and
99+ /// that the terminator are the same.
100+ /// There are also conditions for different ways of simplification.
93101 fn can_simplify (
94- & self ,
102+ & mut self ,
95103 tcx : TyCtxt < ' tcx > ,
96104 targets : & SwitchTargets ,
97105 param_env : ParamEnv < ' tcx > ,
@@ -144,7 +152,7 @@ struct SimplifyToIf;
144152/// ```
145153impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToIf {
146154 fn can_simplify (
147- & self ,
155+ & mut self ,
148156 tcx : TyCtxt < ' tcx > ,
149157 targets : & SwitchTargets ,
150158 param_env : ParamEnv < ' tcx > ,
@@ -250,3 +258,211 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250258 new_stmts. collect ( )
251259 }
252260}
261+
262+ #[ derive( Default ) ]
263+ struct SimplifyToExp {
264+ transfrom_types : Vec < TransfromType > ,
265+ }
266+
267+ #[ derive( Clone , Copy ) ]
268+ enum CompareType < ' tcx , ' a > {
269+ Same ( & ' a StatementKind < ' tcx > ) ,
270+ Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
271+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
272+ }
273+
274+ enum TransfromType {
275+ Same ,
276+ Eq ,
277+ Discr ,
278+ }
279+
280+ impl From < CompareType < ' _ , ' _ > > for TransfromType {
281+ fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
282+ match compare_type {
283+ CompareType :: Same ( _) => TransfromType :: Same ,
284+ CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
285+ CompareType :: Discr ( _, _) => TransfromType :: Discr ,
286+ }
287+ }
288+ }
289+
290+ /// If we find that the value of match is the same as the assignment,
291+ /// merge a target block statements into the source block,
292+ /// using cast to transform different integer types.
293+ ///
294+ /// For example:
295+ ///
296+ /// ```ignore (MIR)
297+ /// bb0: {
298+ /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
299+ /// }
300+ ///
301+ /// bb1: {
302+ /// unreachable;
303+ /// }
304+ ///
305+ /// bb2: {
306+ /// _0 = const 1_i16;
307+ /// goto -> bb5;
308+ /// }
309+ ///
310+ /// bb3: {
311+ /// _0 = const 2_i16;
312+ /// goto -> bb5;
313+ /// }
314+ ///
315+ /// bb4: {
316+ /// _0 = const 3_i16;
317+ /// goto -> bb5;
318+ /// }
319+ /// ```
320+ ///
321+ /// into:
322+ ///
323+ /// ```ignore (MIR)
324+ /// bb0: {
325+ /// _0 = _3 as i16 (IntToInt);
326+ /// goto -> bb5;
327+ /// }
328+ /// ```
329+ impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToExp {
330+ fn can_simplify (
331+ & mut self ,
332+ tcx : TyCtxt < ' tcx > ,
333+ targets : & SwitchTargets ,
334+ param_env : ParamEnv < ' tcx > ,
335+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
336+ ) -> bool {
337+ if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
338+ return false ;
339+ }
340+ // We require that the possible target blocks all be distinct.
341+ if !targets. is_distinct ( ) {
342+ return false ;
343+ }
344+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
345+ return false ;
346+ }
347+ let mut target_iter = targets. iter ( ) ;
348+ let ( first_val, first_target) = target_iter. next ( ) . unwrap ( ) ;
349+ let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
350+ // Check that destinations are identical, and if not, then don't optimize this block
351+ if !targets
352+ . iter ( )
353+ . all ( |( _, other_target) | first_terminator_kind == & bbs[ other_target] . terminator ( ) . kind )
354+ {
355+ return false ;
356+ }
357+
358+ let first_stmts = & bbs[ first_target] . statements ;
359+ let ( second_val, second_target) = target_iter. next ( ) . unwrap ( ) ;
360+ let second_stmts = & bbs[ second_target] . statements ;
361+ if first_stmts. len ( ) != second_stmts. len ( ) {
362+ return false ;
363+ }
364+
365+ let mut compare_types = Vec :: new ( ) ;
366+ for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
367+ let compare_type = match ( & f. kind , & s. kind ) {
368+ // If two statements are exactly the same, we can optimize.
369+ ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
370+
371+ // If two statements are assignments with the match values to the same place, we can optimize.
372+ (
373+ StatementKind :: Assign ( box ( lhs_f, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
374+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
375+ ) if lhs_f == lhs_s
376+ && f_c. const_ . ty ( ) == s_c. const_ . ty ( )
377+ && f_c. const_ . ty ( ) . is_integral ( ) =>
378+ {
379+ match (
380+ f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
381+ s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
382+ ) {
383+ ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
384+ ( Some ( f) , Some ( s) )
385+ if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
386+ && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
387+ {
388+ CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
389+ }
390+ _ => return false ,
391+ }
392+ }
393+
394+ // Otherwise we cannot optimize. Try another block.
395+ _ => return false ,
396+ } ;
397+ compare_types. push ( compare_type) ;
398+ }
399+
400+ // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
401+ for ( other_val, other_target) in target_iter {
402+ let other_stmts = & bbs[ other_target] . statements ;
403+ if compare_types. len ( ) != other_stmts. len ( ) {
404+ return false ;
405+ }
406+ for ( f, s) in iter:: zip ( & compare_types, other_stmts) {
407+ match ( * f, & s. kind ) {
408+ ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
409+ (
410+ CompareType :: Eq ( lhs_f, f_ty, val) ,
411+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
412+ ) if lhs_f == lhs_s
413+ && s_c. const_ . ty ( ) == f_ty
414+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
415+ (
416+ CompareType :: Discr ( lhs_f, f_ty) ,
417+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
418+ ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
419+ let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
420+ return false ;
421+ } ;
422+ if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
423+ return false ;
424+ }
425+ }
426+ _ => return false ,
427+ }
428+ }
429+ }
430+ self . transfrom_types = compare_types. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
431+ true
432+ }
433+
434+ fn new_stmts (
435+ & self ,
436+ _tcx : TyCtxt < ' tcx > ,
437+ targets : & SwitchTargets ,
438+ _param_env : ParamEnv < ' tcx > ,
439+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
440+ discr_local : Local ,
441+ discr_ty : Ty < ' tcx > ,
442+ ) -> Vec < Statement < ' tcx > > {
443+ let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
444+ let first = & bbs[ first] ;
445+
446+ let new_stmts =
447+ iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
448+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
449+ (
450+ TransfromType :: Discr ,
451+ StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
452+ ) => {
453+ let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
454+ let r_val = if f_c. const_ . ty ( ) == discr_ty {
455+ Rvalue :: Use ( operand)
456+ } else {
457+ Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
458+ } ;
459+ Statement {
460+ source_info : s. source_info ,
461+ kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
462+ }
463+ }
464+ _ => unreachable ! ( ) ,
465+ } ) ;
466+ new_stmts. collect ( )
467+ }
468+ }
0 commit comments