@@ -263,33 +263,49 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
263263 }
264264}
265265
266+ /// Check if the cast constant using `IntToInt` is equal to the target constant.
267+ fn can_cast (
268+ from_val : impl Into < u128 > ,
269+ from_size : Size ,
270+ target_scalar : ScalarInt ,
271+ from_is_signed : bool ,
272+ ) -> bool {
273+ let from_scalar = ScalarInt :: try_from_uint ( from_val. into ( ) , from_size) . unwrap ( ) ;
274+ let to_size = target_scalar. size ( ) ;
275+ let cast_scalar = if from_is_signed {
276+ ScalarInt :: try_from_int ( from_scalar. to_int ( from_size) , to_size) . unwrap ( )
277+ } else {
278+ ScalarInt :: try_from_uint ( from_scalar. to_uint ( from_size) , to_size) . unwrap ( )
279+ } ;
280+ cast_scalar == target_scalar
281+ }
282+
266283#[ derive( Default ) ]
267284struct SimplifyToExp {
268- transfrom_types : Vec < TransfromType > ,
285+ transfrom_kinds : Vec < TransfromKind > ,
269286}
270287
271288#[ derive( Clone , Copy ) ]
272- enum CompareType < ' tcx , ' a > {
289+ enum ExpectedTransformKind < ' tcx , ' a > {
273290 /// Identical statements.
274291 Same ( & ' a StatementKind < ' tcx > ) ,
275292 /// Assignment statements have the same value.
276- Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
293+ SameByEq { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , scalar : ScalarInt } ,
277294 /// Enum variant comparison type.
278- Discr { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , is_signed : bool } ,
295+ Cast { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > } ,
279296}
280297
281- enum TransfromType {
298+ enum TransfromKind {
282299 Same ,
283- Eq ,
284- Discr ,
300+ Cast ,
285301}
286302
287- impl From < CompareType < ' _ , ' _ > > for TransfromType {
288- fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
303+ impl From < ExpectedTransformKind < ' _ , ' _ > > for TransfromKind {
304+ fn from ( compare_type : ExpectedTransformKind < ' _ , ' _ > ) -> Self {
289305 match compare_type {
290- CompareType :: Same ( _) => TransfromType :: Same ,
291- CompareType :: Eq ( _ , _ , _ ) => TransfromType :: Eq ,
292- CompareType :: Discr { .. } => TransfromType :: Discr ,
306+ ExpectedTransformKind :: Same ( _) => TransfromKind :: Same ,
307+ ExpectedTransformKind :: SameByEq { .. } => TransfromKind :: Same ,
308+ ExpectedTransformKind :: Cast { .. } => TransfromKind :: Cast ,
293309 }
294310 }
295311}
@@ -353,7 +369,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
353369 return None ;
354370 }
355371 let mut target_iter = targets. iter ( ) ;
356- let ( first_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
372+ let ( first_case_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
357373 let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
358374 // Check that destinations are identical, and if not, then don't optimize this block
359375 if !targets
@@ -365,22 +381,18 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
365381
366382 let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
367383 let first_stmts = & bbs[ first_target] . statements ;
368- let ( second_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
384+ let ( second_case_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
369385 let second_stmts = & bbs[ second_target] . statements ;
370386 if first_stmts. len ( ) != second_stmts. len ( ) {
371387 return None ;
372388 }
373389
374- fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
375- l. to_bits_unchecked ( ) == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . to_bits_unchecked ( )
376- }
377-
378390 // We first compare the two branches, and then the other branches need to fulfill the same conditions.
379- let mut compare_types = Vec :: new ( ) ;
391+ let mut expected_transform_kinds = Vec :: new ( ) ;
380392 for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
381393 let compare_type = match ( & f. kind , & s. kind ) {
382394 // If two statements are exactly the same, we can optimize.
383- ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
395+ ( f_s, s_s) if f_s == s_s => ExpectedTransformKind :: Same ( f_s) ,
384396
385397 // If two statements are assignments with the match values to the same place, we can optimize.
386398 (
@@ -394,22 +406,23 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
394406 f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
395407 s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
396408 ) {
397- ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
398- // Enum variants can also be simplified to an assignment statement if their values are equal.
399- // We need to consider both unsigned and signed scenarios here.
409+ ( Some ( f) , Some ( s) ) if f == s => ExpectedTransformKind :: SameByEq {
410+ place : lhs_f,
411+ ty : f_c. const_ . ty ( ) ,
412+ scalar : f,
413+ } ,
414+ // Enum variants can also be simplified to an assignment statement,
415+ // if we can use `IntToInt` cast to get an equal value.
400416 ( Some ( f) , Some ( s) )
401- if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
402- && int_equal ( f, first_val, discr_size)
403- && int_equal ( s, second_val, discr_size) )
404- || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
405- && Some ( s)
406- == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
417+ if ( can_cast ( first_case_val, discr_size, f, discr_ty. is_signed ( ) )
418+ && can_cast (
419+ second_case_val,
420+ discr_size,
421+ s,
422+ discr_ty. is_signed ( ) ,
423+ ) ) =>
407424 {
408- CompareType :: Discr {
409- place : lhs_f,
410- ty : f_c. const_ . ty ( ) ,
411- is_signed : f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
412- }
425+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_c. const_ . ty ( ) }
413426 }
414427 _ => {
415428 return None ;
@@ -420,47 +433,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
420433 // Otherwise we cannot optimize. Try another block.
421434 _ => return None ,
422435 } ;
423- compare_types . push ( compare_type) ;
436+ expected_transform_kinds . push ( compare_type) ;
424437 }
425438
426439 // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
427440 for ( other_val, other_target) in target_iter {
428441 let other_stmts = & bbs[ other_target] . statements ;
429- if compare_types . len ( ) != other_stmts. len ( ) {
442+ if expected_transform_kinds . len ( ) != other_stmts. len ( ) {
430443 return None ;
431444 }
432- for ( f, s) in iter:: zip ( & compare_types , other_stmts) {
445+ for ( f, s) in iter:: zip ( & expected_transform_kinds , other_stmts) {
433446 match ( * f, & s. kind ) {
434- ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
447+ ( ExpectedTransformKind :: Same ( f_s) , s_s) if f_s == s_s => { }
435448 (
436- CompareType :: Eq ( lhs_f, f_ty, val ) ,
449+ ExpectedTransformKind :: SameByEq { place : lhs_f, ty : f_ty, scalar } ,
437450 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
438451 ) if lhs_f == lhs_s
439452 && s_c. const_ . ty ( ) == f_ty
440- && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val ) => { }
453+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( scalar ) => { }
441454 (
442- CompareType :: Discr { place : lhs_f, ty : f_ty, is_signed } ,
455+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_ty } ,
443456 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
444- ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
445- let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
446- return None ;
447- } ;
448- if is_signed
449- && s_c. const_ . ty ( ) . is_signed ( )
450- && int_equal ( f, other_val, discr_size)
451- {
452- continue ;
453- }
454- if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
455- continue ;
456- }
457- return None ;
458- }
457+ ) if let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env)
458+ && lhs_f == lhs_s
459+ && s_c. const_ . ty ( ) == f_ty
460+ && can_cast ( other_val, discr_size, f, discr_ty. is_signed ( ) ) => { }
459461 _ => return None ,
460462 }
461463 }
462464 }
463- self . transfrom_types = compare_types . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
465+ self . transfrom_kinds = expected_transform_kinds . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
464466 Some ( ( ) )
465467 }
466468
@@ -478,13 +480,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
478480 let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
479481 let first = & bbs[ first] ;
480482
481- for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
483+ for ( t, s) in iter:: zip ( & self . transfrom_kinds , & first. statements ) {
482484 match ( t, & s. kind ) {
483- ( TransfromType :: Same , _ ) | ( TransfromType :: Eq , _) => {
485+ ( TransfromKind :: Same , _) => {
484486 patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
485487 }
486488 (
487- TransfromType :: Discr ,
489+ TransfromKind :: Cast ,
488490 StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
489491 ) => {
490492 let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
0 commit comments