11use rustc_index:: IndexVec ;
22use rustc_middle:: mir:: * ;
3- use rustc_middle:: ty:: { ParamEnv , Ty , TyCtxt } ;
3+ use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
44use std:: iter;
55
66use super :: simplify:: simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838 should_cleanup = true ;
3939 continue ;
4040 }
41+ if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
42+ {
43+ should_cleanup = true ;
44+ continue ;
45+ }
4146 }
4247
4348 if should_cleanup {
@@ -47,8 +52,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4752}
4853
4954trait SimplifyMatch < ' tcx > {
55+ /// Simplifies a match statement, returning true if the simplification succeeds, false otherwise.
56+ /// Generic code is written here, and we generally don't need a custom implementation.
5057 fn simplify (
51- & self ,
58+ & mut self ,
5259 tcx : TyCtxt < ' tcx > ,
5360 local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
5461 bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
@@ -72,9 +79,9 @@ trait SimplifyMatch<'tcx> {
7279 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
7380 let discr_local = local_decls. push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
7481
75- // We already checked that first and second are different blocks,
82+ // We already checked that targets are different blocks,
7683 // and bb_idx has a different terminator from both of them.
77- let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local. clone ( ) , discr_ty) ;
84+ let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local, discr_ty) ;
7885 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
7986 let ( from, first) = bbs. pick2_mut ( switch_bb_idx, first) ;
8087 from. statements
@@ -91,7 +98,7 @@ trait SimplifyMatch<'tcx> {
9198 }
9299
93100 fn can_simplify (
94- & self ,
101+ & mut self ,
95102 tcx : TyCtxt < ' tcx > ,
96103 targets : & SwitchTargets ,
97104 param_env : ParamEnv < ' tcx > ,
@@ -144,7 +151,7 @@ struct SimplifyToIf;
144151/// ```
145152impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToIf {
146153 fn can_simplify (
147- & self ,
154+ & mut self ,
148155 tcx : TyCtxt < ' tcx > ,
149156 targets : & SwitchTargets ,
150157 param_env : ParamEnv < ' tcx > ,
@@ -250,3 +257,210 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250257 new_stmts. collect ( )
251258 }
252259}
260+
261+ #[ derive( Default ) ]
262+ struct SimplifyToExp {
263+ transfrom_types : Vec < TransfromType > ,
264+ }
265+
266+ #[ derive( Clone , Copy ) ]
267+ enum CompareType < ' tcx , ' a > {
268+ Same ( & ' a StatementKind < ' tcx > ) ,
269+ Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
270+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
271+ }
272+
273+ enum TransfromType {
274+ Same ,
275+ Eq ,
276+ Discr ,
277+ }
278+
279+ impl From < CompareType < ' _ , ' _ > > for TransfromType {
280+ fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
281+ match compare_type {
282+ CompareType :: Same ( _) => TransfromType :: Same ,
283+ CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
284+ CompareType :: Discr ( _, _) => TransfromType :: Discr ,
285+ }
286+ }
287+ }
288+
289+ /// If we find that the value of match is the same as the assignment,
290+ /// merge a target block statements into the source block,
291+ /// using cast to transform different integer types.
292+ ///
293+ /// For example:
294+ ///
295+ /// ```ignore (MIR)
296+ /// bb0: {
297+ /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
298+ /// }
299+ ///
300+ /// bb1: {
301+ /// unreachable;
302+ /// }
303+ ///
304+ /// bb2: {
305+ /// _0 = const 1_i16;
306+ /// goto -> bb5;
307+ /// }
308+ ///
309+ /// bb3: {
310+ /// _0 = const 2_i16;
311+ /// goto -> bb5;
312+ /// }
313+ ///
314+ /// bb4: {
315+ /// _0 = const 3_i16;
316+ /// goto -> bb5;
317+ /// }
318+ /// ```
319+ ///
320+ /// into:
321+ ///
322+ /// ```ignore (MIR)
323+ /// bb0: {
324+ /// _0 = _3 as i16 (IntToInt);
325+ /// goto -> bb5;
326+ /// }
327+ /// ```
328+ impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToExp {
329+ fn can_simplify (
330+ & mut self ,
331+ tcx : TyCtxt < ' tcx > ,
332+ targets : & SwitchTargets ,
333+ param_env : ParamEnv < ' tcx > ,
334+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
335+ ) -> bool {
336+ if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
337+ return false ;
338+ }
339+ // We require that the possible target blocks all be distinct.
340+ if !targets. is_distinct ( ) {
341+ return false ;
342+ }
343+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
344+ return false ;
345+ }
346+ let mut iter = targets. iter ( ) ;
347+ let ( first_val, first_target) = iter. next ( ) . unwrap ( ) ;
348+ let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
349+ // Check that destinations are identical, and if not, then don't optimize this block
350+ if !targets
351+ . iter ( )
352+ . all ( |( _, other_target) | first_terminator_kind == & bbs[ other_target] . terminator ( ) . kind )
353+ {
354+ return false ;
355+ }
356+
357+ let first_stmts = & bbs[ first_target] . statements ;
358+ let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
359+ let second_stmts = & bbs[ second_target] . statements ;
360+ if first_stmts. len ( ) != second_stmts. len ( ) {
361+ return false ;
362+ }
363+
364+ let mut compare_types = Vec :: new ( ) ;
365+ for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
366+ let compare_type = match ( & f. kind , & s. kind ) {
367+ // If two statements are exactly the same, we can optimize.
368+ ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
369+
370+ // If two statements are assignments with the match values to the same place, we can optimize.
371+ (
372+ StatementKind :: Assign ( box ( lhs_f, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
373+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
374+ ) if lhs_f == lhs_s
375+ && f_c. const_ . ty ( ) == s_c. const_ . ty ( )
376+ && f_c. const_ . ty ( ) . is_integral ( ) =>
377+ {
378+ match (
379+ f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
380+ s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
381+ ) {
382+ ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
383+ ( Some ( f) , Some ( s) )
384+ if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
385+ && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
386+ {
387+ CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
388+ }
389+ _ => return false ,
390+ }
391+ }
392+
393+ // Otherwise we cannot optimize. Try another block.
394+ _ => return false ,
395+ } ;
396+ compare_types. push ( compare_type) ;
397+ }
398+
399+ for ( other_val, other_target) in iter {
400+ let other_stmts = & bbs[ other_target] . statements ;
401+ if compare_types. len ( ) != other_stmts. len ( ) {
402+ return false ;
403+ }
404+ for ( f, s) in iter:: zip ( & compare_types, other_stmts) {
405+ match ( * f, & s. kind ) {
406+ ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
407+ (
408+ CompareType :: Eq ( lhs_f, f_ty, val) ,
409+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
410+ ) if lhs_f == lhs_s
411+ && s_c. const_ . ty ( ) == f_ty
412+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
413+ (
414+ CompareType :: Discr ( lhs_f, f_ty) ,
415+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
416+ ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
417+ let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
418+ return false ;
419+ } ;
420+ if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
421+ return false ;
422+ }
423+ }
424+ _ => return false ,
425+ }
426+ }
427+ }
428+ self . transfrom_types = compare_types. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
429+ true
430+ }
431+
432+ fn new_stmts (
433+ & self ,
434+ _tcx : TyCtxt < ' tcx > ,
435+ targets : & SwitchTargets ,
436+ _param_env : ParamEnv < ' tcx > ,
437+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
438+ discr_local : Local ,
439+ discr_ty : Ty < ' tcx > ,
440+ ) -> Vec < Statement < ' tcx > > {
441+ let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
442+ let first = & bbs[ first] ;
443+
444+ let new_stmts =
445+ iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
446+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
447+ (
448+ TransfromType :: Discr ,
449+ StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
450+ ) => {
451+ let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
452+ let r_val = if f_c. const_ . ty ( ) == discr_ty {
453+ Rvalue :: Use ( operand)
454+ } else {
455+ Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
456+ } ;
457+ Statement {
458+ source_info : s. source_info ,
459+ kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
460+ }
461+ }
462+ _ => unreachable ! ( ) ,
463+ } ) ;
464+ new_stmts. collect ( )
465+ }
466+ }
0 commit comments