@@ -3,8 +3,10 @@ use std::iter;
33use rustc_index:: IndexSlice ;
44use rustc_middle:: mir:: patch:: MirPatch ;
55use rustc_middle:: mir:: * ;
6+ use rustc_middle:: ty:: layout:: { IntegerExt , TyAndLayout } ;
67use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
7- use rustc_target:: abi:: Size ;
8+ use rustc_target:: abi:: Integer ;
9+ use rustc_type_ir:: TyKind :: * ;
810
911use super :: simplify:: simplify_cfg;
1012
@@ -264,33 +266,56 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
264266 }
265267}
266268
269+ /// Check if the cast constant using `IntToInt` is equal to the target constant.
270+ fn can_cast (
271+ tcx : TyCtxt < ' _ > ,
272+ src_val : impl Into < u128 > ,
273+ src_layout : TyAndLayout < ' _ > ,
274+ cast_ty : Ty < ' _ > ,
275+ target_scalar : ScalarInt ,
276+ ) -> bool {
277+ let from_scalar = ScalarInt :: try_from_uint ( src_val. into ( ) , src_layout. size ) . unwrap ( ) ;
278+ let v = match src_layout. ty . kind ( ) {
279+ Uint ( _) => from_scalar. to_uint ( src_layout. size ) ,
280+ Int ( _) => from_scalar. to_int ( src_layout. size ) as u128 ,
281+ _ => unreachable ! ( "invalid int" ) ,
282+ } ;
283+ let size = match * cast_ty. kind ( ) {
284+ Int ( t) => Integer :: from_int_ty ( & tcx, t) . size ( ) ,
285+ Uint ( t) => Integer :: from_uint_ty ( & tcx, t) . size ( ) ,
286+ _ => unreachable ! ( "invalid int" ) ,
287+ } ;
288+ let v = size. truncate ( v) ;
289+ let cast_scalar = ScalarInt :: try_from_uint ( v, size) . unwrap ( ) ;
290+ cast_scalar == target_scalar
291+ }
292+
267293#[ derive( Default ) ]
268294struct SimplifyToExp {
269- transfrom_types : Vec < TransfromType > ,
295+ transfrom_kinds : Vec < TransfromKind > ,
270296}
271297
272298#[ derive( Clone , Copy ) ]
273- enum CompareType < ' tcx , ' a > {
299+ enum ExpectedTransformKind < ' tcx , ' a > {
274300 /// Identical statements.
275301 Same ( & ' a StatementKind < ' tcx > ) ,
276302 /// Assignment statements have the same value.
277- Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
303+ SameByEq { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , scalar : ScalarInt } ,
278304 /// Enum variant comparison type.
279- Discr { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , is_signed : bool } ,
305+ Cast { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > } ,
280306}
281307
282- enum TransfromType {
308+ enum TransfromKind {
283309 Same ,
284- Eq ,
285- Discr ,
310+ Cast ,
286311}
287312
288- impl From < CompareType < ' _ , ' _ > > for TransfromType {
289- fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
313+ impl From < ExpectedTransformKind < ' _ , ' _ > > for TransfromKind {
314+ fn from ( compare_type : ExpectedTransformKind < ' _ , ' _ > ) -> Self {
290315 match compare_type {
291- CompareType :: Same ( _) => TransfromType :: Same ,
292- CompareType :: Eq ( _ , _ , _ ) => TransfromType :: Eq ,
293- CompareType :: Discr { .. } => TransfromType :: Discr ,
316+ ExpectedTransformKind :: Same ( _) => TransfromKind :: Same ,
317+ ExpectedTransformKind :: SameByEq { .. } => TransfromKind :: Same ,
318+ ExpectedTransformKind :: Cast { .. } => TransfromKind :: Cast ,
294319 }
295320 }
296321}
@@ -354,7 +379,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
354379 return None ;
355380 }
356381 let mut target_iter = targets. iter ( ) ;
357- let ( first_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
382+ let ( first_case_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
358383 let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
359384 // Check that destinations are identical, and if not, then don't optimize this block
360385 if !targets
@@ -364,24 +389,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
364389 return None ;
365390 }
366391
367- let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
392+ let discr_layout = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) ;
368393 let first_stmts = & bbs[ first_target] . statements ;
369- let ( second_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
394+ let ( second_case_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
370395 let second_stmts = & bbs[ second_target] . statements ;
371396 if first_stmts. len ( ) != second_stmts. len ( ) {
372397 return None ;
373398 }
374399
375- fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
376- l. to_bits_unchecked ( ) == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . to_bits_unchecked ( )
377- }
378-
379400 // We first compare the two branches, and then the other branches need to fulfill the same conditions.
380- let mut compare_types = Vec :: new ( ) ;
401+ let mut expected_transform_kinds = Vec :: new ( ) ;
381402 for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
382403 let compare_type = match ( & f. kind , & s. kind ) {
383404 // If two statements are exactly the same, we can optimize.
384- ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
405+ ( f_s, s_s) if f_s == s_s => ExpectedTransformKind :: Same ( f_s) ,
385406
386407 // If two statements are assignments with the match values to the same place, we can optimize.
387408 (
@@ -395,22 +416,29 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
395416 f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
396417 s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
397418 ) {
398- ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
399- // Enum variants can also be simplified to an assignment statement if their values are equal.
400- // We need to consider both unsigned and signed scenarios here.
419+ ( Some ( f) , Some ( s) ) if f == s => ExpectedTransformKind :: SameByEq {
420+ place : lhs_f,
421+ ty : f_c. const_ . ty ( ) ,
422+ scalar : f,
423+ } ,
424+ // Enum variants can also be simplified to an assignment statement,
425+ // if we can use `IntToInt` cast to get an equal value.
401426 ( Some ( f) , Some ( s) )
402- if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
403- && int_equal ( f, first_val, discr_size)
404- && int_equal ( s, second_val, discr_size) )
405- || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
406- && Some ( s)
407- == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
427+ if ( can_cast (
428+ tcx,
429+ first_case_val,
430+ discr_layout,
431+ f_c. const_ . ty ( ) ,
432+ f,
433+ ) && can_cast (
434+ tcx,
435+ second_case_val,
436+ discr_layout,
437+ f_c. const_ . ty ( ) ,
438+ s,
439+ ) ) =>
408440 {
409- CompareType :: Discr {
410- place : lhs_f,
411- ty : f_c. const_ . ty ( ) ,
412- is_signed : f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
413- }
441+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_c. const_ . ty ( ) }
414442 }
415443 _ => {
416444 return None ;
@@ -421,47 +449,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
421449 // Otherwise we cannot optimize. Try another block.
422450 _ => return None ,
423451 } ;
424- compare_types . push ( compare_type) ;
452+ expected_transform_kinds . push ( compare_type) ;
425453 }
426454
427455 // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
428456 for ( other_val, other_target) in target_iter {
429457 let other_stmts = & bbs[ other_target] . statements ;
430- if compare_types . len ( ) != other_stmts. len ( ) {
458+ if expected_transform_kinds . len ( ) != other_stmts. len ( ) {
431459 return None ;
432460 }
433- for ( f, s) in iter:: zip ( & compare_types , other_stmts) {
461+ for ( f, s) in iter:: zip ( & expected_transform_kinds , other_stmts) {
434462 match ( * f, & s. kind ) {
435- ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
463+ ( ExpectedTransformKind :: Same ( f_s) , s_s) if f_s == s_s => { }
436464 (
437- CompareType :: Eq ( lhs_f, f_ty, val ) ,
465+ ExpectedTransformKind :: SameByEq { place : lhs_f, ty : f_ty, scalar } ,
438466 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
439467 ) if lhs_f == lhs_s
440468 && s_c. const_ . ty ( ) == f_ty
441- && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val ) => { }
469+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( scalar ) => { }
442470 (
443- CompareType :: Discr { place : lhs_f, ty : f_ty, is_signed } ,
471+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_ty } ,
444472 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
445- ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
446- let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
447- return None ;
448- } ;
449- if is_signed
450- && s_c. const_ . ty ( ) . is_signed ( )
451- && int_equal ( f, other_val, discr_size)
452- {
453- continue ;
454- }
455- if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
456- continue ;
457- }
458- return None ;
459- }
473+ ) if let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env)
474+ && lhs_f == lhs_s
475+ && s_c. const_ . ty ( ) == f_ty
476+ && can_cast ( tcx, other_val, discr_layout, f_ty, f) => { }
460477 _ => return None ,
461478 }
462479 }
463480 }
464- self . transfrom_types = compare_types . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
481+ self . transfrom_kinds = expected_transform_kinds . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
465482 Some ( ( ) )
466483 }
467484
@@ -479,13 +496,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
479496 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
480497 let first = & bbs[ first] ;
481498
482- for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
499+ for ( t, s) in iter:: zip ( & self . transfrom_kinds , & first. statements ) {
483500 match ( t, & s. kind ) {
484- ( TransfromType :: Same , _ ) | ( TransfromType :: Eq , _) => {
501+ ( TransfromKind :: Same , _) => {
485502 patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
486503 }
487504 (
488- TransfromType :: Discr ,
505+ TransfromKind :: Cast ,
489506 StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
490507 ) => {
491508 let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
0 commit comments