@@ -483,52 +483,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
483483}
484484
485485fn fuse_trivial_branches ( function : & mut Function ) {
486- let all_preds = compute_preds ( & function. blocks ) ;
486+ let mut chain_list = compute_outgoing_1to1_branches ( & function. blocks ) ;
487487 let mut rewrite_rules = FxHashMap :: default ( ) ;
488- ' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
489- // if there's two trivial branches in a row, the middle one might get inlined before the
490- // last one, so when processing the last one, skip through to the first one.
491- let pred = loop {
492- if preds. len ( ) != 1 || preds[ 0 ] == dest_block {
493- continue ' outer;
494- }
495- let pred = preds[ 0 ] ;
496- if !function. blocks [ pred] . instructions . is_empty ( ) {
497- break pred;
498- }
499- preds = & all_preds[ pred] ;
500- } ;
501- let pred_insts = & function. blocks [ pred] . instructions ;
502- if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
503- let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
504- dest_insts. retain ( |inst| {
505- if inst. class . opcode == Op :: Phi {
506- assert_eq ! ( inst. operands. len( ) , 2 ) ;
507- rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
508- false
509- } else {
510- true
488+
489+ for block_idx in 0 ..chain_list. len ( ) {
490+ let mut next = chain_list[ block_idx] . take ( ) ;
491+ loop {
492+ match next {
493+ None => {
494+ // end of the chain list
495+ break ;
511496 }
512- } ) ;
513- let pred_insts = & mut function. blocks [ pred] . instructions ;
514- pred_insts. pop ( ) ; // pop the branch
515- pred_insts. append ( & mut dest_insts) ;
497+ Some ( x) if x == block_idx => {
498+ // loop detected
499+ break ;
500+ }
501+ Some ( next_idx) => {
502+ let mut dest_insts = take ( & mut function. blocks [ next_idx] . instructions ) ;
503+ dest_insts. retain ( |inst| {
504+ if inst. class . opcode == Op :: Phi {
505+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
506+ rewrite_rules
507+ . insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
508+ false
509+ } else {
510+ true
511+ }
512+ } ) ;
513+ let self_insts = & mut function. blocks [ block_idx] . instructions ;
514+ self_insts. pop ( ) ; // pop the branch
515+ self_insts. append ( & mut dest_insts) ;
516+ next = chain_list[ next_idx] . take ( ) ;
517+ }
518+ }
516519 }
517520 }
518521 function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
519522 apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
520523}
521524
522- fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
523- let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
525+ fn compute_outgoing_1to1_branches ( blocks : & [ Block ] ) -> Vec < Option < usize > > {
526+ let block_id_to_idx: FxHashMap < _ , _ > = blocks
527+ . iter ( )
528+ . enumerate ( )
529+ . map ( |( idx, block) | ( block. label_id ( ) . unwrap ( ) , idx) )
530+ . collect ( ) ;
531+ #[ derive( Clone ) ]
532+ enum NumIncoming {
533+ Zero ,
534+ One ( usize ) ,
535+ TooMany ,
536+ }
537+ let mut incoming = vec ! [ NumIncoming :: Zero ; blocks. len( ) ] ;
524538 for ( source_idx, source) in blocks. iter ( ) . enumerate ( ) {
525539 for dest_id in outgoing_edges ( source) {
526- let dest_idx = blocks
527- . iter ( )
528- . position ( |b| b. label_id ( ) . unwrap ( ) == dest_id)
529- . unwrap ( ) ;
530- result[ dest_idx] . push ( source_idx) ;
540+ let dest_idx = block_id_to_idx[ & dest_id] ;
541+ incoming[ dest_idx] = match incoming[ dest_idx] {
542+ NumIncoming :: Zero => NumIncoming :: One ( source_idx) ,
543+ _ => NumIncoming :: TooMany ,
544+ }
545+ }
546+ }
547+
548+ let mut result = vec ! [ None ; blocks. len( ) ] ;
549+
550+ for ( dest_idx, inc) in incoming. iter ( ) . enumerate ( ) {
551+ if let & NumIncoming :: One ( source_idx) = inc {
552+ if blocks[ source_idx] . instructions . last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
553+ result[ source_idx] = Some ( dest_idx) ;
554+ }
531555 }
532556 }
557+
533558 result
534559}
0 commit comments