@@ -485,52 +485,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
485485}
486486
487487fn fuse_trivial_branches ( function : & mut Function ) {
488- let all_preds = compute_preds ( & function. blocks ) ;
488+ let mut chain_list = compute_outgoing_1to1_branches ( & function. blocks ) ;
489489 let mut rewrite_rules = FxHashMap :: default ( ) ;
490- ' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
491- // if there's two trivial branches in a row, the middle one might get inlined before the
492- // last one, so when processing the last one, skip through to the first one.
493- let pred = loop {
494- if preds. len ( ) != 1 || preds[ 0 ] == dest_block {
495- continue ' outer;
496- }
497- let pred = preds[ 0 ] ;
498- if !function. blocks [ pred] . instructions . is_empty ( ) {
499- break pred;
500- }
501- preds = & all_preds[ pred] ;
502- } ;
503- let pred_insts = & function. blocks [ pred] . instructions ;
504- if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
505- let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
506- dest_insts. retain ( |inst| {
507- if inst. class . opcode == Op :: Phi {
508- assert_eq ! ( inst. operands. len( ) , 2 ) ;
509- rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
510- false
511- } else {
512- true
490+
491+ for block_idx in 0 ..chain_list. len ( ) {
492+ let mut next = chain_list[ block_idx] . take ( ) ;
493+ loop {
494+ match next {
495+ None => {
496+ // end of the chain list
497+ break ;
513498 }
514- } ) ;
515- let pred_insts = & mut function. blocks [ pred] . instructions ;
516- pred_insts. pop ( ) ; // pop the branch
517- pred_insts. append ( & mut dest_insts) ;
499+ Some ( x) if x == block_idx => {
500+ // loop detected
501+ break ;
502+ }
503+ Some ( next_idx) => {
504+ let mut dest_insts = take ( & mut function. blocks [ next_idx] . instructions ) ;
505+ dest_insts. retain ( |inst| {
506+ if inst. class . opcode == Op :: Phi {
507+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
508+ rewrite_rules
509+ . insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
510+ false
511+ } else {
512+ true
513+ }
514+ } ) ;
515+ let self_insts = & mut function. blocks [ block_idx] . instructions ;
516+ self_insts. pop ( ) ; // pop the branch
517+ self_insts. append ( & mut dest_insts) ;
518+ next = chain_list[ next_idx] . take ( ) ;
519+ }
520+ }
518521 }
519522 }
520523 function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
521524 apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
522525}
523526
524- fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
525- let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
527+ fn compute_outgoing_1to1_branches ( blocks : & [ Block ] ) -> Vec < Option < usize > > {
528+ let block_id_to_idx: FxHashMap < _ , _ > = blocks
529+ . iter ( )
530+ . enumerate ( )
531+ . map ( |( idx, block) | ( block. label_id ( ) . unwrap ( ) , idx) )
532+ . collect ( ) ;
533+ #[ derive( Clone ) ]
534+ enum NumIncoming {
535+ Zero ,
536+ One ( usize ) ,
537+ TooMany ,
538+ }
539+ let mut incoming = vec ! [ NumIncoming :: Zero ; blocks. len( ) ] ;
526540 for ( source_idx, source) in blocks. iter ( ) . enumerate ( ) {
527541 for dest_id in outgoing_edges ( source) {
528- let dest_idx = blocks
529- . iter ( )
530- . position ( |b| b. label_id ( ) . unwrap ( ) == dest_id)
531- . unwrap ( ) ;
532- result[ dest_idx] . push ( source_idx) ;
542+ let dest_idx = block_id_to_idx[ & dest_id] ;
543+ incoming[ dest_idx] = match incoming[ dest_idx] {
544+ NumIncoming :: Zero => NumIncoming :: One ( source_idx) ,
545+ _ => NumIncoming :: TooMany ,
546+ }
547+ }
548+ }
549+
550+ let mut result = vec ! [ None ; blocks. len( ) ] ;
551+
552+ for ( dest_idx, inc) in incoming. iter ( ) . enumerate ( ) {
553+ if let & NumIncoming :: One ( source_idx) = inc {
554+ if blocks[ source_idx] . instructions . last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
555+ result[ source_idx] = Some ( dest_idx) ;
556+ }
533557 }
534558 }
559+
535560 result
536561}
0 commit comments