@@ -65,13 +65,13 @@ trait SimplifyMatch<'tcx> {
6565 _ => unreachable ! ( ) ,
6666 } ;
6767
68- if !self . can_simplify ( tcx, targets, param_env, bbs) {
68+ let discr_ty = discr. ty ( local_decls, tcx) ;
69+ if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
6970 return false ;
7071 }
7172
7273 // Take ownership of items now that we know we can optimize.
7374 let discr = discr. clone ( ) ;
74- let discr_ty = discr. ty ( local_decls, tcx) ;
7575
7676 // Introduce a temporary for the discriminant value.
7777 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
@@ -101,6 +101,7 @@ trait SimplifyMatch<'tcx> {
101101 targets : & SwitchTargets ,
102102 param_env : ParamEnv < ' tcx > ,
103103 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
104+ discr_ty : Ty < ' tcx > ,
104105 ) -> bool ;
105106
106107 fn new_stmts (
@@ -154,6 +155,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
154155 targets : & SwitchTargets ,
155156 param_env : ParamEnv < ' tcx > ,
156157 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
158+ _discr_ty : Ty < ' tcx > ,
157159 ) -> bool {
158160 if targets. iter ( ) . len ( ) != 1 {
159161 return false ;
@@ -265,7 +267,7 @@ struct SimplifyToExp {
265267enum CompareType < ' tcx , ' a > {
266268 Same ( & ' a StatementKind < ' tcx > ) ,
267269 Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
268- Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
270+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > , bool ) ,
269271}
270272
271273enum TransfromType {
@@ -279,7 +281,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
279281 match compare_type {
280282 CompareType :: Same ( _) => TransfromType :: Same ,
281283 CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
282- CompareType :: Discr ( _, _) => TransfromType :: Discr ,
284+ CompareType :: Discr ( _, _, _ ) => TransfromType :: Discr ,
283285 }
284286 }
285287}
@@ -330,6 +332,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
330332 targets : & SwitchTargets ,
331333 param_env : ParamEnv < ' tcx > ,
332334 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
335+ discr_ty : Ty < ' tcx > ,
333336 ) -> bool {
334337 if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
335338 return false ;
@@ -352,6 +355,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
352355 return false ;
353356 }
354357
358+ let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
355359 let first_stmts = & bbs[ first_target] . statements ;
356360 let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
357361 let second_stmts = & bbs[ second_target] . statements ;
@@ -379,12 +383,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
379383 ) {
380384 ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
381385 ( Some ( f) , Some ( s) )
382- if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
383- && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
386+ if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
387+ && f. try_to_int ( f. size ( ) ) . unwrap ( )
388+ == ScalarInt :: try_from_uint ( first_val, discr_size)
389+ . unwrap ( )
390+ . try_to_int ( discr_size)
391+ . unwrap ( )
392+ && s. try_to_int ( s. size ( ) ) . unwrap ( )
393+ == ScalarInt :: try_from_uint ( second_val, discr_size)
394+ . unwrap ( )
395+ . try_to_int ( discr_size)
396+ . unwrap ( ) )
397+ || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
398+ && Some ( s)
399+ == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
384400 {
385- CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
401+ CompareType :: Discr (
402+ lhs_f,
403+ f_c. const_ . ty ( ) ,
404+ f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
405+ )
406+ }
407+ _ => {
408+ return false ;
386409 }
387- _ => return false ,
388410 }
389411 }
390412
@@ -409,15 +431,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
409431 && s_c. const_ . ty ( ) == f_ty
410432 && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
411433 (
412- CompareType :: Discr ( lhs_f, f_ty) ,
434+ CompareType :: Discr ( lhs_f, f_ty, is_signed ) ,
413435 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
414436 ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
415437 let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
416438 return false ;
417439 } ;
418- if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
419- return false ;
440+ if is_signed
441+ && s_c. const_ . ty ( ) . is_signed ( )
442+ && f. try_to_int ( f. size ( ) ) . unwrap ( )
443+ == ScalarInt :: try_from_uint ( other_val, discr_size)
444+ . unwrap ( )
445+ . try_to_int ( discr_size)
446+ . unwrap ( )
447+ {
448+ continue ;
449+ }
450+ if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
451+ continue ;
420452 }
453+ return false ;
421454 }
422455 _ => return false ,
423456 }
0 commit comments