11use rustc_index:: IndexVec ;
22use rustc_middle:: mir:: * ;
33use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
4+ use rustc_target:: abi:: Size ;
45use std:: iter;
56
67use super :: simplify:: simplify_cfg;
@@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> {
6768 _ => unreachable ! ( ) ,
6869 } ;
6970
70- if !self . can_simplify ( tcx, targets, param_env, bbs) {
71+ let discr_ty = discr. ty ( local_decls, tcx) ;
72+ if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
7173 return false ;
7274 }
7375
7476 // Take ownership of items now that we know we can optimize.
7577 let discr = discr. clone ( ) ;
76- let discr_ty = discr. ty ( local_decls, tcx) ;
7778
7879 // Introduce a temporary for the discriminant value.
7980 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
@@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> {
104105 targets : & SwitchTargets ,
105106 param_env : ParamEnv < ' tcx > ,
106107 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
108+ discr_ty : Ty < ' tcx > ,
107109 ) -> bool ;
108110
109111 fn new_stmts (
@@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
157159 targets : & SwitchTargets ,
158160 param_env : ParamEnv < ' tcx > ,
159161 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
162+ _discr_ty : Ty < ' tcx > ,
160163 ) -> bool {
161164 if targets. iter ( ) . len ( ) != 1 {
162165 return false ;
@@ -268,7 +271,7 @@ struct SimplifyToExp {
268271enum CompareType < ' tcx , ' a > {
269272 Same ( & ' a StatementKind < ' tcx > ) ,
270273 Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
271- Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
274+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > , bool ) ,
272275}
273276
274277enum TransfromType {
@@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
282285 match compare_type {
283286 CompareType :: Same ( _) => TransfromType :: Same ,
284287 CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
285- CompareType :: Discr ( _, _) => TransfromType :: Discr ,
288+ CompareType :: Discr ( _, _, _ ) => TransfromType :: Discr ,
286289 }
287290 }
288291}
@@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
333336 targets : & SwitchTargets ,
334337 param_env : ParamEnv < ' tcx > ,
335338 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
339+ discr_ty : Ty < ' tcx > ,
336340 ) -> bool {
337341 if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
338342 return false ;
@@ -355,13 +359,19 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
355359 return false ;
356360 }
357361
362+ let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
358363 let first_stmts = & bbs[ first_target] . statements ;
359364 let ( second_val, second_target) = target_iter. next ( ) . unwrap ( ) ;
360365 let second_stmts = & bbs[ second_target] . statements ;
361366 if first_stmts. len ( ) != second_stmts. len ( ) {
362367 return false ;
363368 }
364369
370+ fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
371+ l. try_to_int ( l. size ( ) ) . unwrap ( )
372+ == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . try_to_int ( size) . unwrap ( )
373+ }
374+
365375 let mut compare_types = Vec :: new ( ) ;
366376 for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
367377 let compare_type = match ( & f. kind , & s. kind ) {
@@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
382392 ) {
383393 ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
384394 ( Some ( f) , Some ( s) )
385- if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
386- && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
395+ if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
396+ && int_equal ( f, first_val, discr_size)
397+ && int_equal ( s, second_val, discr_size) )
398+ || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
399+ && Some ( s)
400+ == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
387401 {
388- CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
402+ CompareType :: Discr (
403+ lhs_f,
404+ f_c. const_ . ty ( ) ,
405+ f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
406+ )
407+ }
408+ _ => {
409+ return false ;
389410 }
390- _ => return false ,
391411 }
392412 }
393413
@@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
413433 && s_c. const_ . ty ( ) == f_ty
414434 && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
415435 (
416- CompareType :: Discr ( lhs_f, f_ty) ,
436+ CompareType :: Discr ( lhs_f, f_ty, is_signed ) ,
417437 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
418438 ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
419439 let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
420440 return false ;
421441 } ;
422- if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
423- return false ;
442+ if is_signed
443+ && s_c. const_ . ty ( ) . is_signed ( )
444+ && int_equal ( f, other_val, discr_size)
445+ {
446+ continue ;
447+ }
448+ if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
449+ continue ;
424450 }
451+ return false ;
425452 }
426453 _ => return false ,
427454 }
0 commit comments