@@ -14,7 +14,7 @@ use itertools::Itertools as _;
1414use rustc_index:: { bit_set:: BitSet , vec:: IndexVec } ;
1515use rustc_middle:: mir:: visit:: { NonUseContext , PlaceContext , Visitor } ;
1616use rustc_middle:: mir:: * ;
17- use rustc_middle:: ty:: { List , Ty , TyCtxt } ;
17+ use rustc_middle:: ty:: { self , List , Ty , TyCtxt } ;
1818use rustc_target:: abi:: VariantIdx ;
1919use std:: iter:: { Enumerate , Peekable } ;
2020use std:: slice:: Iter ;
@@ -527,52 +527,239 @@ fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarFiel
527527pub struct SimplifyBranchSame ;
528528
529529impl < ' tcx > MirPass < ' tcx > for SimplifyBranchSame {
530- fn run_pass ( & self , _: TyCtxt < ' tcx > , _: MirSource < ' tcx > , body : & mut Body < ' tcx > ) {
531- let mut did_remove_blocks = false ;
532- let bbs = body. basic_blocks_mut ( ) ;
533- for bb_idx in bbs. indices ( ) {
534- let targets = match & bbs[ bb_idx] . terminator ( ) . kind {
535- TerminatorKind :: SwitchInt { targets, .. } => targets,
536- _ => continue ,
537- } ;
530+ fn run_pass ( & self , tcx : TyCtxt < ' tcx > , source : MirSource < ' tcx > , body : & mut Body < ' tcx > ) {
531+ trace ! ( "Running SimplifyBranchSame on {:?}" , source) ;
532+ let finder = SimplifyBranchSameOptimizationFinder { body, tcx } ;
533+ let opts = finder. find ( ) ;
534+
535+ let did_remove_blocks = opts. len ( ) > 0 ;
536+ for opt in opts. iter ( ) {
537+ trace ! ( "SUCCESS: Applying optimization {:?}" , opt) ;
538+ // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
539+ body. basic_blocks_mut ( ) [ opt. bb_to_opt_terminator ] . terminator_mut ( ) . kind =
540+ TerminatorKind :: Goto { target : opt. bb_to_goto } ;
541+ }
542+
543+ if did_remove_blocks {
544+ // We have dead blocks now, so remove those.
545+ simplify:: remove_dead_blocks ( body) ;
546+ }
547+ }
548+ }
549+
550+ #[ derive( Debug ) ]
551+ struct SimplifyBranchSameOptimization {
552+ /// All basic blocks are equal so go to this one
553+ bb_to_goto : BasicBlock ,
554+ /// Basic block where the terminator can be simplified to a goto
555+ bb_to_opt_terminator : BasicBlock ,
556+ }
557+
558+ struct SimplifyBranchSameOptimizationFinder < ' a , ' tcx > {
559+ body : & ' a Body < ' tcx > ,
560+ tcx : TyCtxt < ' tcx > ,
561+ }
538562
539- let mut iter_bbs_reachable = targets
540- . iter ( )
541- . map ( |idx| ( * idx, & bbs[ * idx] ) )
542- . filter ( |( _, bb) | {
543- // Reaching `unreachable` is UB so assume it doesn't happen.
544- bb. terminator ( ) . kind != TerminatorKind :: Unreachable
563+ impl < ' a , ' tcx > SimplifyBranchSameOptimizationFinder < ' a , ' tcx > {
564+ fn find ( & self ) -> Vec < SimplifyBranchSameOptimization > {
565+ self . body
566+ . basic_blocks ( )
567+ . iter_enumerated ( )
568+ . filter_map ( |( bb_idx, bb) | {
569+ let ( discr_switched_on, targets) = match & bb. terminator ( ) . kind {
570+ TerminatorKind :: SwitchInt { targets, discr, .. } => ( discr, targets) ,
571+ _ => return None ,
572+ } ;
573+
574+ // find the adt that has its discriminant read
575+ // assuming this must be the last statement of the block
576+ let adt_matched_on = match & bb. statements . last ( ) ?. kind {
577+ StatementKind :: Assign ( box ( place, rhs) )
578+ if Some ( * place) == discr_switched_on. place ( ) =>
579+ {
580+ match rhs {
581+ Rvalue :: Discriminant ( adt_place) if adt_place. ty ( self . body , self . tcx ) . ty . is_enum ( ) => adt_place,
582+ _ => {
583+ trace ! ( "NO: expected a discriminant read of an enum instead of: {:?}" , rhs) ;
584+ return None ;
585+ }
586+ }
587+ }
588+ other => {
589+ trace ! ( "NO: expected an assignment of a discriminant read to a place. Found: {:?}" , other) ;
590+ return None
591+ } ,
592+ } ;
593+
594+ let mut iter_bbs_reachable = targets
595+ . iter ( )
596+ . map ( |idx| ( * idx, & self . body . basic_blocks ( ) [ * idx] ) )
597+ . filter ( |( _, bb) | {
598+ // Reaching `unreachable` is UB so assume it doesn't happen.
599+ bb. terminator ( ) . kind != TerminatorKind :: Unreachable
545600 // But `asm!(...)` could abort the program,
546601 // so we cannot assume that the `unreachable` terminator itself is reachable.
547602 // FIXME(Centril): use a normalization pass instead of a check.
548603 || bb. statements . iter ( ) . any ( |stmt| match stmt. kind {
549604 StatementKind :: LlvmInlineAsm ( ..) => true ,
550605 _ => false ,
551606 } )
552- } )
553- . peekable ( ) ;
554-
555- // We want to `goto -> bb_first`.
556- let bb_first = iter_bbs_reachable. peek ( ) . map ( |( idx, _) | * idx) . unwrap_or ( targets[ 0 ] ) ;
557-
558- // All successor basic blocks should have the exact same form.
559- let all_successors_equivalent =
560- iter_bbs_reachable. map ( |( _, bb) | bb) . tuple_windows ( ) . all ( |( bb_l, bb_r) | {
561- bb_l. is_cleanup == bb_r. is_cleanup
562- && bb_l. terminator ( ) . kind == bb_r. terminator ( ) . kind
563- && bb_l. statements . iter ( ) . eq_by ( & bb_r. statements , |x, y| x. kind == y. kind )
564- } ) ;
565-
566- if all_successors_equivalent {
567- // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
568- bbs[ bb_idx] . terminator_mut ( ) . kind = TerminatorKind :: Goto { target : bb_first } ;
569- did_remove_blocks = true ;
607+ } )
608+ . peekable ( ) ;
609+
610+ let bb_first = iter_bbs_reachable. peek ( ) . map ( |( idx, _) | * idx) . unwrap_or ( targets[ 0 ] ) ;
611+ let mut all_successors_equivalent = StatementEquality :: TrivialEqual ;
612+
613+ // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
614+ for ( ( bb_l_idx, bb_l) , ( bb_r_idx, bb_r) ) in iter_bbs_reachable. tuple_windows ( ) {
615+ let trivial_checks = bb_l. is_cleanup == bb_r. is_cleanup
616+ && bb_l. terminator ( ) . kind == bb_r. terminator ( ) . kind ;
617+ let statement_check = || {
618+ bb_l. statements . iter ( ) . zip ( & bb_r. statements ) . try_fold ( StatementEquality :: TrivialEqual , |acc, ( l, r) | {
619+ let stmt_equality = self . statement_equality ( * adt_matched_on, & l, bb_l_idx, & r, bb_r_idx) ;
620+ if matches ! ( stmt_equality, StatementEquality :: NotEqual ) {
621+ // short circuit
622+ None
623+ } else {
624+ Some ( acc. combine ( & stmt_equality) )
625+ }
626+ } )
627+ . unwrap_or ( StatementEquality :: NotEqual )
628+ } ;
629+ if !trivial_checks {
630+ all_successors_equivalent = StatementEquality :: NotEqual ;
631+ break ;
632+ }
633+ all_successors_equivalent = all_successors_equivalent. combine ( & statement_check ( ) ) ;
634+ } ;
635+
636+ match all_successors_equivalent{
637+ StatementEquality :: TrivialEqual => {
638+ // statements are trivially equal, so just take first
639+ trace ! ( "Statements are trivially equal" ) ;
640+ Some ( SimplifyBranchSameOptimization {
641+ bb_to_goto : bb_first,
642+ bb_to_opt_terminator : bb_idx,
643+ } )
644+ }
645+ StatementEquality :: ConsideredEqual ( bb_to_choose) => {
646+ trace ! ( "Statements are considered equal" ) ;
647+ Some ( SimplifyBranchSameOptimization {
648+ bb_to_goto : bb_to_choose,
649+ bb_to_opt_terminator : bb_idx,
650+ } )
651+ }
652+ StatementEquality :: NotEqual => {
653+ trace ! ( "NO: not all successors of basic block {:?} were equivalent" , bb_idx) ;
654+ None
655+ }
656+ }
657+ } )
658+ . collect ( )
659+ }
660+
661+ /// Tests if two statements can be considered equal
662+ ///
663+ /// Statements can be trivially equal if the kinds match.
664+ /// But they can also be considered equal in the following case A:
665+ /// ```
666+ /// discriminant(_0) = 0; // bb1
667+ /// _0 = move _1; // bb2
668+ /// ```
669+ /// In this case the two statements are equal iff
670+ /// 1: _0 is an enum where the variant index 0 is fieldless, and
671+ /// 2: bb1 was targeted by a switch where the discriminant of _1 was switched on
672+ fn statement_equality (
673+ & self ,
674+ adt_matched_on : Place < ' tcx > ,
675+ x : & Statement < ' tcx > ,
676+ x_bb_idx : BasicBlock ,
677+ y : & Statement < ' tcx > ,
678+ y_bb_idx : BasicBlock ,
679+ ) -> StatementEquality {
680+ let helper = |rhs : & Rvalue < ' tcx > ,
681+ place : & Box < Place < ' tcx > > ,
682+ variant_index : & VariantIdx ,
683+ side_to_choose| {
684+ let place_type = place. ty ( self . body , self . tcx ) . ty ;
685+ let adt = match place_type. kind {
686+ ty:: Adt ( adt, _) if adt. is_enum ( ) => adt,
687+ _ => return StatementEquality :: NotEqual ,
688+ } ;
689+ let variant_is_fieldless = adt. variants [ * variant_index] . fields . is_empty ( ) ;
690+ if !variant_is_fieldless {
691+ trace ! ( "NO: variant {:?} was not fieldless" , variant_index) ;
692+ return StatementEquality :: NotEqual ;
693+ }
694+
695+ match rhs {
696+ Rvalue :: Use ( operand) if operand. place ( ) == Some ( adt_matched_on) => {
697+ StatementEquality :: ConsideredEqual ( side_to_choose)
698+ }
699+ _ => {
700+ trace ! (
701+ "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}" ,
702+ rhs,
703+ adt_matched_on
704+ ) ;
705+ StatementEquality :: NotEqual
706+ }
707+ }
708+ } ;
709+ match ( & x. kind , & y. kind ) {
710+ // trivial case
711+ ( x, y) if x == y => StatementEquality :: TrivialEqual ,
712+
713+ // check for case A
714+ (
715+ StatementKind :: Assign ( box ( _, rhs) ) ,
716+ StatementKind :: SetDiscriminant { place, variant_index } ,
717+ ) => {
718+ // choose basic block of x, as that has the assign
719+ helper ( rhs, place, variant_index, x_bb_idx)
720+ }
721+ (
722+ StatementKind :: SetDiscriminant { place, variant_index } ,
723+ StatementKind :: Assign ( box ( _, rhs) ) ,
724+ ) => {
725+ // choose basic block of y, as that has the assign
726+ helper ( rhs, place, variant_index, y_bb_idx)
727+ }
728+ _ => {
729+ trace ! ( "NO: statements `{:?}` and `{:?}` not considered equal" , x, y) ;
730+ StatementEquality :: NotEqual
570731 }
571732 }
733+ }
734+ }
572735
573- if did_remove_blocks {
574- // We have dead blocks now, so remove those.
575- simplify:: remove_dead_blocks ( body) ;
736+ #[ derive( Copy , Clone , Eq , PartialEq ) ]
737+ enum StatementEquality {
738+ /// The two statements are trivially equal; same kind
739+ TrivialEqual ,
740+ /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
741+ /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
742+ ConsideredEqual ( BasicBlock ) ,
743+ /// The two statements are not equal
744+ NotEqual ,
745+ }
746+
747+ impl StatementEquality {
748+ fn combine ( & self , other : & StatementEquality ) -> StatementEquality {
749+ use StatementEquality :: * ;
750+ match ( self , other) {
751+ ( TrivialEqual , TrivialEqual ) => TrivialEqual ,
752+ ( TrivialEqual , ConsideredEqual ( b) ) | ( ConsideredEqual ( b) , TrivialEqual ) => {
753+ ConsideredEqual ( * b)
754+ }
755+ ( ConsideredEqual ( b1) , ConsideredEqual ( b2) ) => {
756+ if b1 == b2 {
757+ ConsideredEqual ( * b1)
758+ } else {
759+ NotEqual
760+ }
761+ }
762+ ( _, NotEqual ) | ( NotEqual , _) => NotEqual ,
576763 }
577764 }
578765}
0 commit comments