@@ -3,8 +3,10 @@ use std::iter;
33use rustc_index:: IndexSlice ;
44use rustc_middle:: mir:: patch:: MirPatch ;
55use rustc_middle:: mir:: * ;
6+ use rustc_middle:: ty:: layout:: { IntegerExt , TyAndLayout } ;
67use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
7- use rustc_target:: abi:: Size ;
8+ use rustc_target:: abi:: Integer ;
9+ use rustc_type_ir:: TyKind :: * ;
810
911use super :: simplify:: simplify_cfg;
1012
@@ -42,10 +44,7 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
4244 should_cleanup = true ;
4345 continue ;
4446 }
45- // unsound: https://github.com/rust-lang/rust/issues/124150
46- if tcx. sess . opts . unstable_opts . unsound_mir_opts
47- && SimplifyToExp :: default ( ) . simplify ( tcx, body, bb_idx, param_env) . is_some ( )
48- {
47+ if SimplifyToExp :: default ( ) . simplify ( tcx, body, bb_idx, param_env) . is_some ( ) {
4948 should_cleanup = true ;
5049 continue ;
5150 }
@@ -264,33 +263,56 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
264263 }
265264}
266265
266+ /// Check if the cast constant using `IntToInt` is equal to the target constant.
267+ fn can_cast (
268+ tcx : TyCtxt < ' _ > ,
269+ src_val : impl Into < u128 > ,
270+ src_layout : TyAndLayout < ' _ > ,
271+ cast_ty : Ty < ' _ > ,
272+ target_scalar : ScalarInt ,
273+ ) -> bool {
274+ let from_scalar = ScalarInt :: try_from_uint ( src_val. into ( ) , src_layout. size ) . unwrap ( ) ;
275+ let v = match src_layout. ty . kind ( ) {
276+ Uint ( _) => from_scalar. to_uint ( src_layout. size ) ,
277+ Int ( _) => from_scalar. to_int ( src_layout. size ) as u128 ,
278+ _ => unreachable ! ( "invalid int" ) ,
279+ } ;
280+ let size = match * cast_ty. kind ( ) {
281+ Int ( t) => Integer :: from_int_ty ( & tcx, t) . size ( ) ,
282+ Uint ( t) => Integer :: from_uint_ty ( & tcx, t) . size ( ) ,
283+ _ => unreachable ! ( "invalid int" ) ,
284+ } ;
285+ let v = size. truncate ( v) ;
286+ let cast_scalar = ScalarInt :: try_from_uint ( v, size) . unwrap ( ) ;
287+ cast_scalar == target_scalar
288+ }
289+
267290#[ derive( Default ) ]
268291struct SimplifyToExp {
269- transfrom_types : Vec < TransfromType > ,
292+ transfrom_kinds : Vec < TransfromKind > ,
270293}
271294
272295#[ derive( Clone , Copy ) ]
273- enum CompareType < ' tcx , ' a > {
296+ enum ExpectedTransformKind < ' tcx , ' a > {
274297 /// Identical statements.
275298 Same ( & ' a StatementKind < ' tcx > ) ,
276299 /// Assignment statements have the same value.
277- Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
300+ SameByEq { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , scalar : ScalarInt } ,
278301 /// Enum variant comparison type.
279- Discr { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , is_signed : bool } ,
302+ Cast { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > } ,
280303}
281304
282- enum TransfromType {
305+ enum TransfromKind {
283306 Same ,
284- Eq ,
285- Discr ,
307+ Cast ,
286308}
287309
288- impl From < CompareType < ' _ , ' _ > > for TransfromType {
289- fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
310+ impl From < ExpectedTransformKind < ' _ , ' _ > > for TransfromKind {
311+ fn from ( compare_type : ExpectedTransformKind < ' _ , ' _ > ) -> Self {
290312 match compare_type {
291- CompareType :: Same ( _) => TransfromType :: Same ,
292- CompareType :: Eq ( _ , _ , _ ) => TransfromType :: Eq ,
293- CompareType :: Discr { .. } => TransfromType :: Discr ,
313+ ExpectedTransformKind :: Same ( _) => TransfromKind :: Same ,
314+ ExpectedTransformKind :: SameByEq { .. } => TransfromKind :: Same ,
315+ ExpectedTransformKind :: Cast { .. } => TransfromKind :: Cast ,
294316 }
295317 }
296318}
@@ -354,7 +376,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
354376 return None ;
355377 }
356378 let mut target_iter = targets. iter ( ) ;
357- let ( first_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
379+ let ( first_case_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
358380 let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
359381 // Check that destinations are identical, and if not, then don't optimize this block
360382 if !targets
@@ -364,24 +386,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
364386 return None ;
365387 }
366388
367- let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
389+ let discr_layout = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) ;
368390 let first_stmts = & bbs[ first_target] . statements ;
369- let ( second_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
391+ let ( second_case_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
370392 let second_stmts = & bbs[ second_target] . statements ;
371393 if first_stmts. len ( ) != second_stmts. len ( ) {
372394 return None ;
373395 }
374396
375- fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
376- l. to_bits_unchecked ( ) == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . to_bits_unchecked ( )
377- }
378-
379397 // We first compare the two branches, and then the other branches need to fulfill the same conditions.
380- let mut compare_types = Vec :: new ( ) ;
398+ let mut expected_transform_kinds = Vec :: new ( ) ;
381399 for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
382400 let compare_type = match ( & f. kind , & s. kind ) {
383401 // If two statements are exactly the same, we can optimize.
384- ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
402+ ( f_s, s_s) if f_s == s_s => ExpectedTransformKind :: Same ( f_s) ,
385403
386404 // If two statements are assignments with the match values to the same place, we can optimize.
387405 (
@@ -395,22 +413,29 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
395413 f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
396414 s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
397415 ) {
398- ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
399- // Enum variants can also be simplified to an assignment statement if their values are equal.
400- // We need to consider both unsigned and signed scenarios here.
416+ ( Some ( f) , Some ( s) ) if f == s => ExpectedTransformKind :: SameByEq {
417+ place : lhs_f,
418+ ty : f_c. const_ . ty ( ) ,
419+ scalar : f,
420+ } ,
421+ // Enum variants can also be simplified to an assignment statement,
422+ // if we can use `IntToInt` cast to get an equal value.
401423 ( Some ( f) , Some ( s) )
402- if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
403- && int_equal ( f, first_val, discr_size)
404- && int_equal ( s, second_val, discr_size) )
405- || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
406- && Some ( s)
407- == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
424+ if ( can_cast (
425+ tcx,
426+ first_case_val,
427+ discr_layout,
428+ f_c. const_ . ty ( ) ,
429+ f,
430+ ) && can_cast (
431+ tcx,
432+ second_case_val,
433+ discr_layout,
434+ f_c. const_ . ty ( ) ,
435+ s,
436+ ) ) =>
408437 {
409- CompareType :: Discr {
410- place : lhs_f,
411- ty : f_c. const_ . ty ( ) ,
412- is_signed : f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
413- }
438+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_c. const_ . ty ( ) }
414439 }
415440 _ => {
416441 return None ;
@@ -421,47 +446,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
421446 // Otherwise we cannot optimize. Try another block.
422447 _ => return None ,
423448 } ;
424- compare_types . push ( compare_type) ;
449+ expected_transform_kinds . push ( compare_type) ;
425450 }
426451
427452 // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
428453 for ( other_val, other_target) in target_iter {
429454 let other_stmts = & bbs[ other_target] . statements ;
430- if compare_types . len ( ) != other_stmts. len ( ) {
455+ if expected_transform_kinds . len ( ) != other_stmts. len ( ) {
431456 return None ;
432457 }
433- for ( f, s) in iter:: zip ( & compare_types , other_stmts) {
458+ for ( f, s) in iter:: zip ( & expected_transform_kinds , other_stmts) {
434459 match ( * f, & s. kind ) {
435- ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
460+ ( ExpectedTransformKind :: Same ( f_s) , s_s) if f_s == s_s => { }
436461 (
437- CompareType :: Eq ( lhs_f, f_ty, val ) ,
462+ ExpectedTransformKind :: SameByEq { place : lhs_f, ty : f_ty, scalar } ,
438463 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
439464 ) if lhs_f == lhs_s
440465 && s_c. const_ . ty ( ) == f_ty
441- && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val ) => { }
466+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( scalar ) => { }
442467 (
443- CompareType :: Discr { place : lhs_f, ty : f_ty, is_signed } ,
468+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_ty } ,
444469 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
445- ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
446- let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
447- return None ;
448- } ;
449- if is_signed
450- && s_c. const_ . ty ( ) . is_signed ( )
451- && int_equal ( f, other_val, discr_size)
452- {
453- continue ;
454- }
455- if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
456- continue ;
457- }
458- return None ;
459- }
470+ ) if let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env)
471+ && lhs_f == lhs_s
472+ && s_c. const_ . ty ( ) == f_ty
473+ && can_cast ( tcx, other_val, discr_layout, f_ty, f) => { }
460474 _ => return None ,
461475 }
462476 }
463477 }
464- self . transfrom_types = compare_types . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
478+ self . transfrom_kinds = expected_transform_kinds . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
465479 Some ( ( ) )
466480 }
467481
@@ -479,13 +493,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
479493 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
480494 let first = & bbs[ first] ;
481495
482- for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
496+ for ( t, s) in iter:: zip ( & self . transfrom_kinds , & first. statements ) {
483497 match ( t, & s. kind ) {
484- ( TransfromType :: Same , _ ) | ( TransfromType :: Eq , _) => {
498+ ( TransfromKind :: Same , _) => {
485499 patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
486500 }
487501 (
488- TransfromType :: Discr ,
502+ TransfromKind :: Cast ,
489503 StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
490504 ) => {
491505 let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
0 commit comments