@@ -67,13 +67,13 @@ trait SimplifyMatch<'tcx> {
6767 _ => unreachable ! ( ) ,
6868 } ;
6969
70- if !self . can_simplify ( tcx, targets, param_env, bbs) {
70+ let discr_ty = discr. ty ( local_decls, tcx) ;
71+ if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
7172 return false ;
7273 }
7374
7475 // Take ownership of items now that we know we can optimize.
7576 let discr = discr. clone ( ) ;
76- let discr_ty = discr. ty ( local_decls, tcx) ;
7777
7878 // Introduce a temporary for the discriminant value.
7979 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
@@ -103,6 +103,7 @@ trait SimplifyMatch<'tcx> {
103103 targets : & SwitchTargets ,
104104 param_env : ParamEnv < ' tcx > ,
105105 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
106+ discr_ty : Ty < ' tcx > ,
106107 ) -> bool ;
107108
108109 fn new_stmts (
@@ -156,6 +157,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
156157 targets : & SwitchTargets ,
157158 param_env : ParamEnv < ' tcx > ,
158159 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
160+ _discr_ty : Ty < ' tcx > ,
159161 ) -> bool {
160162 if targets. iter ( ) . len ( ) != 1 {
161163 return false ;
@@ -267,7 +269,7 @@ struct SimplifyToExp {
267269enum CompareType < ' tcx , ' a > {
268270 Same ( & ' a StatementKind < ' tcx > ) ,
269271 Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
270- Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
272+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > , bool ) ,
271273}
272274
273275enum TransfromType {
@@ -281,7 +283,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
281283 match compare_type {
282284 CompareType :: Same ( _) => TransfromType :: Same ,
283285 CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
284- CompareType :: Discr ( _, _) => TransfromType :: Discr ,
286+ CompareType :: Discr ( _, _, _ ) => TransfromType :: Discr ,
285287 }
286288 }
287289}
@@ -332,6 +334,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
332334 targets : & SwitchTargets ,
333335 param_env : ParamEnv < ' tcx > ,
334336 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
337+ discr_ty : Ty < ' tcx > ,
335338 ) -> bool {
336339 if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
337340 return false ;
@@ -354,6 +357,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
354357 return false ;
355358 }
356359
360+ let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
357361 let first_stmts = & bbs[ first_target] . statements ;
358362 let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
359363 let second_stmts = & bbs[ second_target] . statements ;
@@ -381,12 +385,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
381385 ) {
382386 ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
383387 ( Some ( f) , Some ( s) )
384- if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
385- && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
388+ if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
389+ && f. try_to_int ( f. size ( ) ) . unwrap ( )
390+ == ScalarInt :: try_from_uint ( first_val, discr_size)
391+ . unwrap ( )
392+ . try_to_int ( discr_size)
393+ . unwrap ( )
394+ && s. try_to_int ( s. size ( ) ) . unwrap ( )
395+ == ScalarInt :: try_from_uint ( second_val, discr_size)
396+ . unwrap ( )
397+ . try_to_int ( discr_size)
398+ . unwrap ( ) )
399+ || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
400+ && Some ( s)
401+ == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
386402 {
387- CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
403+ CompareType :: Discr (
404+ lhs_f,
405+ f_c. const_ . ty ( ) ,
406+ f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
407+ )
408+ }
409+ _ => {
410+ return false ;
388411 }
389- _ => return false ,
390412 }
391413 }
392414
@@ -411,15 +433,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
411433 && s_c. const_ . ty ( ) == f_ty
412434 && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
413435 (
414- CompareType :: Discr ( lhs_f, f_ty) ,
436+ CompareType :: Discr ( lhs_f, f_ty, is_signed ) ,
415437 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
416438 ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
417439 let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
418440 return false ;
419441 } ;
420- if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
421- return false ;
442+ if is_signed
443+ && s_c. const_ . ty ( ) . is_signed ( )
444+ && f. try_to_int ( f. size ( ) ) . unwrap ( )
445+ == ScalarInt :: try_from_uint ( other_val, discr_size)
446+ . unwrap ( )
447+ . try_to_int ( discr_size)
448+ . unwrap ( )
449+ {
450+ continue ;
451+ }
452+ if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
453+ continue ;
422454 }
455+ return false ;
423456 }
424457 _ => return false ,
425458 }
0 commit comments