11use rustc_index:: IndexVec ;
22use rustc_middle:: mir:: * ;
3- use rustc_middle:: ty:: { ParamEnv , Ty , TyCtxt } ;
3+ use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
44use std:: iter;
55
66use super :: simplify:: simplify_cfg;
@@ -38,6 +38,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
3838 should_cleanup = true ;
3939 continue ;
4040 }
41+ if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
42+ {
43+ should_cleanup = true ;
44+ continue ;
45+ }
4146 }
4247
4348 if should_cleanup {
@@ -48,7 +53,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4853
4954trait SimplifyMatch < ' tcx > {
5055 fn simplify (
51- & self ,
56+ & mut self ,
5257 tcx : TyCtxt < ' tcx > ,
5358 local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
5459 bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
@@ -72,7 +77,7 @@ trait SimplifyMatch<'tcx> {
7277 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
7378 let discr_local = local_decls. push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
7479
75- // We already checked that first and second are different blocks,
80+ // We already checked that targets are different blocks,
7681 // and bb_idx has a different terminator from both of them.
7782 let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local. clone ( ) , discr_ty) ;
7883 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
@@ -91,7 +96,7 @@ trait SimplifyMatch<'tcx> {
9196 }
9297
9398 fn can_simplify (
94- & self ,
99+ & mut self ,
95100 tcx : TyCtxt < ' tcx > ,
96101 targets : & SwitchTargets ,
97102 param_env : ParamEnv < ' tcx > ,
@@ -144,7 +149,7 @@ struct SimplifyToIf;
144149/// ```
145150impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToIf {
146151 fn can_simplify (
147- & self ,
152+ & mut self ,
148153 tcx : TyCtxt < ' tcx > ,
149154 targets : & SwitchTargets ,
150155 param_env : ParamEnv < ' tcx > ,
@@ -250,3 +255,210 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
250255 new_stmts. collect ( )
251256 }
252257}
258+
259+ #[ derive( Default ) ]
260+ struct SimplifyToExp {
261+ transfrom_types : Vec < TransfromType > ,
262+ }
263+
264+ #[ derive( Clone , Copy ) ]
265+ enum CompareType < ' tcx , ' a > {
266+ Same ( & ' a StatementKind < ' tcx > ) ,
267+ Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
268+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
269+ }
270+
271+ enum TransfromType {
272+ Same ,
273+ Eq ,
274+ Discr ,
275+ }
276+
277+ impl From < CompareType < ' _ , ' _ > > for TransfromType {
278+ fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
279+ match compare_type {
280+ CompareType :: Same ( _) => TransfromType :: Same ,
281+ CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
282+ CompareType :: Discr ( _, _) => TransfromType :: Discr ,
283+ }
284+ }
285+ }
286+
287+ /// If we find that the value of match is the same as the assignment,
288+ /// merge a target block statements into the source block,
289+ /// using cast to transform different integer types.
290+ ///
291+ /// For example:
292+ ///
293+ /// ```ignore (MIR)
294+ /// bb0: {
295+ /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
296+ /// }
297+ ///
298+ /// bb1: {
299+ /// unreachable;
300+ /// }
301+ ///
302+ /// bb2: {
303+ /// _0 = const 1_i16;
304+ /// goto -> bb5;
305+ /// }
306+ ///
307+ /// bb3: {
308+ /// _0 = const 2_i16;
309+ /// goto -> bb5;
310+ /// }
311+ ///
312+ /// bb4: {
313+ /// _0 = const 3_i16;
314+ /// goto -> bb5;
315+ /// }
316+ /// ```
317+ ///
318+ /// into:
319+ ///
320+ /// ```ignore (MIR)
321+ /// bb0: {
322+ /// _0 = _3 as i16 (IntToInt);
323+ /// goto -> bb5;
324+ /// }
325+ /// ```
326+ impl < ' tcx > SimplifyMatch < ' tcx > for SimplifyToExp {
327+ fn can_simplify (
328+ & mut self ,
329+ tcx : TyCtxt < ' tcx > ,
330+ targets : & SwitchTargets ,
331+ param_env : ParamEnv < ' tcx > ,
332+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
333+ ) -> bool {
334+ if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
335+ return false ;
336+ }
337+ // We require that the possible target blocks all be distinct.
338+ if !targets. is_distinct ( ) {
339+ return false ;
340+ }
341+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
342+ return false ;
343+ }
344+ let mut iter = targets. iter ( ) ;
345+ let ( first_val, first_target) = iter. next ( ) . unwrap ( ) ;
346+ let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
347+ // Check that destinations are identical, and if not, then don't optimize this block
348+ if !targets
349+ . iter ( )
350+ . all ( |( _, other_target) | first_terminator_kind == & bbs[ other_target] . terminator ( ) . kind )
351+ {
352+ return false ;
353+ }
354+
355+ let first_stmts = & bbs[ first_target] . statements ;
356+ let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
357+ let second_stmts = & bbs[ second_target] . statements ;
358+ if first_stmts. len ( ) != second_stmts. len ( ) {
359+ return false ;
360+ }
361+
362+ let mut compare_types = Vec :: new ( ) ;
363+ for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
364+ let compare_type = match ( & f. kind , & s. kind ) {
365+ // If two statements are exactly the same, we can optimize.
366+ ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
367+
368+ // If two statements are assignments with the match values to the same place, we can optimize.
369+ (
370+ StatementKind :: Assign ( box ( lhs_f, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
371+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
372+ ) if lhs_f == lhs_s
373+ && f_c. const_ . ty ( ) == s_c. const_ . ty ( )
374+ && f_c. const_ . ty ( ) . is_integral ( ) =>
375+ {
376+ match (
377+ f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
378+ s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
379+ ) {
380+ ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
381+ ( Some ( f) , Some ( s) )
382+ if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
383+ && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
384+ {
385+ CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
386+ }
387+ _ => return false ,
388+ }
389+ }
390+
391+ // Otherwise we cannot optimize. Try another block.
392+ _ => return false ,
393+ } ;
394+ compare_types. push ( compare_type) ;
395+ }
396+
397+ for ( other_val, other_target) in iter {
398+ let other_stmts = & bbs[ other_target] . statements ;
399+ if compare_types. len ( ) != other_stmts. len ( ) {
400+ return false ;
401+ }
402+ for ( f, s) in iter:: zip ( & compare_types, other_stmts) {
403+ match ( * f, & s. kind ) {
404+ ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
405+ (
406+ CompareType :: Eq ( lhs_f, f_ty, val) ,
407+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
408+ ) if lhs_f == lhs_s
409+ && s_c. const_ . ty ( ) == f_ty
410+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
411+ (
412+ CompareType :: Discr ( lhs_f, f_ty) ,
413+ StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
414+ ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
415+ let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
416+ return false ;
417+ } ;
418+ if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
419+ return false ;
420+ }
421+ }
422+ _ => return false ,
423+ }
424+ }
425+ }
426+ self . transfrom_types = compare_types. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
427+ true
428+ }
429+
430+ fn new_stmts (
431+ & self ,
432+ _tcx : TyCtxt < ' tcx > ,
433+ targets : & SwitchTargets ,
434+ _param_env : ParamEnv < ' tcx > ,
435+ bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
436+ discr_local : Local ,
437+ discr_ty : Ty < ' tcx > ,
438+ ) -> Vec < Statement < ' tcx > > {
439+ let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
440+ let first = & bbs[ first] ;
441+
442+ let new_stmts =
443+ iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
444+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
445+ (
446+ TransfromType :: Discr ,
447+ StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
448+ ) => {
449+ let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
450+ let r_val = if f_c. const_ . ty ( ) == discr_ty {
451+ Rvalue :: Use ( operand)
452+ } else {
453+ Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
454+ } ;
455+ Statement {
456+ source_info : s. source_info ,
457+ kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
458+ }
459+ }
460+ _ => unreachable ! ( ) ,
461+ } ) ;
462+ new_stmts. collect ( )
463+ }
464+ }
0 commit comments