@@ -151,9 +151,8 @@ pub struct ObligationForest<O: ForestObligation> {
151151 /// comments in `process_obligation` for details.
152152 active_cache : FxHashMap < O :: Predicate , usize > ,
153153
154- /// A scratch vector reused in various operations, to avoid allocating new
155- /// vectors.
156- scratch : RefCell < Vec < usize > > ,
154+ /// A vector reused in compress(), to avoid allocating new vectors.
155+ node_rewrites : RefCell < Vec < usize > > ,
157156
158157 obligation_tree_id_generator : ObligationTreeIdGenerator ,
159158
@@ -235,10 +234,6 @@ enum NodeState {
235234 /// This obligation was resolved to an error. Error nodes are
236235 /// removed from the vector by the compression step.
237236 Error ,
238-
239- /// This is a temporary state used in DFS loops to detect cycles,
240- /// it should not exist outside of these DFSes.
241- OnDfsStack ,
242237}
243238
244239#[ derive( Debug ) ]
@@ -279,7 +274,7 @@ impl<O: ForestObligation> ObligationForest<O> {
279274 nodes : vec ! [ ] ,
280275 done_cache : Default :: default ( ) ,
281276 active_cache : Default :: default ( ) ,
282- scratch : RefCell :: new ( vec ! [ ] ) ,
277+ node_rewrites : RefCell :: new ( vec ! [ ] ) ,
283278 obligation_tree_id_generator : ( 0 ..) . map ( ObligationTreeId ) ,
284279 error_cache : Default :: default ( ) ,
285280 }
@@ -305,9 +300,10 @@ impl<O: ForestObligation> ObligationForest<O> {
305300
306301 match self . active_cache . entry ( obligation. as_predicate ( ) . clone ( ) ) {
307302 Entry :: Occupied ( o) => {
303+ let index = * o. get ( ) ;
308304 debug ! ( "register_obligation_at({:?}, {:?}) - duplicate of {:?}!" ,
309- obligation, parent, o . get ( ) ) ;
310- let node = & mut self . nodes [ * o . get ( ) ] ;
305+ obligation, parent, index ) ;
306+ let node = & mut self . nodes [ index ] ;
311307 if let Some ( parent_index) = parent {
312308 // If the node is already in `active_cache`, it has already
313309 // had its chance to be marked with a parent. So if it's
@@ -342,7 +338,8 @@ impl<O: ForestObligation> ObligationForest<O> {
342338 if already_failed {
343339 Err ( ( ) )
344340 } else {
345- v. insert ( self . nodes . len ( ) ) ;
341+ let new_index = self . nodes . len ( ) ;
342+ v. insert ( new_index) ;
346343 self . nodes . push ( Node :: new ( parent, obligation, obligation_tree_id) ) ;
347344 Ok ( ( ) )
348345 }
@@ -352,15 +349,16 @@ impl<O: ForestObligation> ObligationForest<O> {
352349
353350 /// Converts all remaining obligations to the given error.
354351 pub fn to_errors < E : Clone > ( & mut self , error : E ) -> Vec < Error < O , E > > {
355- let mut errors = vec ! [ ] ;
356- for ( index , node) in self . nodes . iter ( ) . enumerate ( ) {
357- if let NodeState :: Pending = node . state . get ( ) {
358- errors . push ( Error {
352+ let errors = self . nodes . iter ( ) . enumerate ( )
353+ . filter ( | ( _index , node) | node . state . get ( ) == NodeState :: Pending )
354+ . map ( | ( index , _node ) | {
355+ Error {
359356 error : error. clone ( ) ,
360357 backtrace : self . error_at ( index) ,
361- } ) ;
362- }
363- }
358+ }
359+ } )
360+ . collect ( ) ;
361+
364362 let successful_obligations = self . compress ( DoCompleted :: Yes ) ;
365363 assert ! ( successful_obligations. unwrap( ) . is_empty( ) ) ;
366364 errors
@@ -370,15 +368,14 @@ impl<O: ForestObligation> ObligationForest<O> {
370368 pub fn map_pending_obligations < P , F > ( & self , f : F ) -> Vec < P >
371369 where F : Fn ( & O ) -> P
372370 {
373- self . nodes
374- . iter ( )
375- . filter ( |n| n. state . get ( ) == NodeState :: Pending )
376- . map ( |n| f ( & n. obligation ) )
371+ self . nodes . iter ( )
372+ . filter ( |node| node. state . get ( ) == NodeState :: Pending )
373+ . map ( |node| f ( & node. obligation ) )
377374 . collect ( )
378375 }
379376
380- fn insert_into_error_cache ( & mut self , node_index : usize ) {
381- let node = & self . nodes [ node_index ] ;
377+ fn insert_into_error_cache ( & mut self , index : usize ) {
378+ let node = & self . nodes [ index ] ;
382379 self . error_cache
383380 . entry ( node. obligation_tree_id )
384381 . or_default ( )
@@ -408,10 +405,10 @@ impl<O: ForestObligation> ObligationForest<O> {
408405 // `self.active_cache`. This means that `self.active_cache` can get
409406 // out of sync with `nodes`. It's not very common, but it does
410407 // happen, and code in `compress` has to allow for it.
411- let result = match node. state . get ( ) {
412- NodeState :: Pending => processor . process_obligation ( & mut node . obligation ) ,
413- _ => continue
414- } ;
408+ if node. state . get ( ) != NodeState :: Pending {
409+ continue ;
410+ }
411+ let result = processor . process_obligation ( & mut node . obligation ) ;
415412
416413 debug ! ( "process_obligations: node {} got result {:?}" , index, result) ;
417414
@@ -476,64 +473,53 @@ impl<O: ForestObligation> ObligationForest<O> {
476473 fn process_cycles < P > ( & self , processor : & mut P )
477474 where P : ObligationProcessor < Obligation =O >
478475 {
479- let mut stack = self . scratch . replace ( vec ! [ ] ) ;
480- debug_assert ! ( stack. is_empty( ) ) ;
476+ let mut stack = vec ! [ ] ;
481477
482478 debug ! ( "process_cycles()" ) ;
483479
484480 for ( index, node) in self . nodes . iter ( ) . enumerate ( ) {
485481 // For some benchmarks this state test is extremely
486482 // hot. It's a win to handle the no-op cases immediately to avoid
487483 // the cost of the function call.
488- match node. state . get ( ) {
489- // Match arms are in order of frequency. Pending, Success and
490- // Waiting dominate; the others are rare.
491- NodeState :: Pending => { } ,
492- NodeState :: Success => self . find_cycles_from_node ( & mut stack, processor, index) ,
493- NodeState :: Waiting | NodeState :: Done | NodeState :: Error => { } ,
494- NodeState :: OnDfsStack => self . find_cycles_from_node ( & mut stack, processor, index) ,
484+ if node. state . get ( ) == NodeState :: Success {
485+ self . find_cycles_from_node ( & mut stack, processor, index) ;
495486 }
496487 }
497488
498489 debug ! ( "process_cycles: complete" ) ;
499490
500491 debug_assert ! ( stack. is_empty( ) ) ;
501- self . scratch . replace ( stack) ;
502492 }
503493
504494 fn find_cycles_from_node < P > ( & self , stack : & mut Vec < usize > , processor : & mut P , index : usize )
505495 where P : ObligationProcessor < Obligation =O >
506496 {
507497 let node = & self . nodes [ index] ;
508- match node. state . get ( ) {
509- NodeState :: OnDfsStack => {
510- let rpos = stack. iter ( ) . rposition ( |& n| n == index) . unwrap ( ) ;
511- processor. process_backedge ( stack[ rpos..] . iter ( ) . map ( GetObligation ( & self . nodes ) ) ,
512- PhantomData ) ;
513- }
514- NodeState :: Success => {
515- node. state . set ( NodeState :: OnDfsStack ) ;
516- stack. push ( index) ;
517- for & index in node. dependents . iter ( ) {
518- self . find_cycles_from_node ( stack, processor, index) ;
498+ if node. state . get ( ) == NodeState :: Success {
499+ match stack. iter ( ) . rposition ( |& n| n == index) {
500+ None => {
501+ stack. push ( index) ;
502+ for & index in node. dependents . iter ( ) {
503+ self . find_cycles_from_node ( stack, processor, index) ;
504+ }
505+ stack. pop ( ) ;
506+ node. state . set ( NodeState :: Done ) ;
507+ }
508+ Some ( rpos) => {
509+ // Cycle detected.
510+ processor. process_backedge (
511+ stack[ rpos..] . iter ( ) . map ( GetObligation ( & self . nodes ) ) ,
512+ PhantomData
513+ ) ;
519514 }
520- stack. pop ( ) ;
521- node. state . set ( NodeState :: Done ) ;
522- } ,
523- NodeState :: Waiting | NodeState :: Pending => {
524- // This node is still reachable from some pending node. We
525- // will get to it when they are all processed.
526- }
527- NodeState :: Done | NodeState :: Error => {
528- // Already processed that node.
529515 }
530- } ;
516+ }
531517 }
532518
533519 /// Returns a vector of obligations for `p` and all of its
534520 /// ancestors, putting them into the error state in the process.
535521 fn error_at ( & self , mut index : usize ) -> Vec < O > {
536- let mut error_stack = self . scratch . replace ( vec ! [ ] ) ;
522+ let mut error_stack: Vec < usize > = vec ! [ ] ;
537523 let mut trace = vec ! [ ] ;
538524
539525 loop {
@@ -554,23 +540,32 @@ impl<O: ForestObligation> ObligationForest<O> {
554540
555541 while let Some ( index) = error_stack. pop ( ) {
556542 let node = & self . nodes [ index] ;
557- match node. state . get ( ) {
558- NodeState :: Error => continue ,
559- _ => node. state . set ( NodeState :: Error ) ,
543+ if node. state . get ( ) != NodeState :: Error {
544+ node . state . set ( NodeState :: Error ) ;
545+ error_stack . extend ( node. dependents . iter ( ) ) ;
560546 }
561-
562- error_stack. extend ( node. dependents . iter ( ) ) ;
563547 }
564548
565- self . scratch . replace ( error_stack) ;
566549 trace
567550 }
568551
569552 // This always-inlined function is for the hot call site.
570553 #[ inline( always) ]
571554 fn inlined_mark_neighbors_as_waiting_from ( & self , node : & Node < O > ) {
572555 for & index in node. dependents . iter ( ) {
573- self . mark_as_waiting_from ( & self . nodes [ index] ) ;
556+ let node = & self . nodes [ index] ;
557+ match node. state . get ( ) {
558+ NodeState :: Waiting | NodeState :: Error => { }
559+ NodeState :: Success => {
560+ node. state . set ( NodeState :: Waiting ) ;
561+ // This call site is cold.
562+ self . uninlined_mark_neighbors_as_waiting_from ( node) ;
563+ }
564+ NodeState :: Pending | NodeState :: Done => {
565+ // This call site is cold.
566+ self . uninlined_mark_neighbors_as_waiting_from ( node) ;
567+ }
568+ }
574569 }
575570 }
576571
@@ -596,37 +591,28 @@ impl<O: ForestObligation> ObligationForest<O> {
596591 }
597592 }
598593
599- fn mark_as_waiting_from ( & self , node : & Node < O > ) {
600- match node. state . get ( ) {
601- NodeState :: Waiting | NodeState :: Error | NodeState :: OnDfsStack => return ,
602- NodeState :: Success => node. state . set ( NodeState :: Waiting ) ,
603- NodeState :: Pending | NodeState :: Done => { } ,
604- }
605-
606- // This call site is cold.
607- self . uninlined_mark_neighbors_as_waiting_from ( node) ;
608- }
609-
610- /// Compresses the vector, removing all popped nodes. This adjusts
611- /// the indices and hence invalidates any outstanding
612- /// indices. Cannot be used during a transaction.
594+ /// Compresses the vector, removing all popped nodes. This adjusts the
595+ /// indices and hence invalidates any outstanding indices.
613596 ///
614597 /// Beforehand, all nodes must be marked as `Done` and no cycles
615598 /// on these nodes may be present. This is done by e.g., `process_cycles`.
616599 #[ inline( never) ]
617600 fn compress ( & mut self , do_completed : DoCompleted ) -> Option < Vec < O > > {
618- let nodes_len = self . nodes . len ( ) ;
619- let mut node_rewrites: Vec < _ > = self . scratch . replace ( vec ! [ ] ) ;
620- node_rewrites. extend ( 0 ..nodes_len) ;
601+ let orig_nodes_len = self . nodes . len ( ) ;
602+ let mut node_rewrites: Vec < _ > = self . node_rewrites . replace ( vec ! [ ] ) ;
603+ debug_assert ! ( node_rewrites. is_empty( ) ) ;
604+ node_rewrites. extend ( 0 ..orig_nodes_len) ;
621605 let mut dead_nodes = 0 ;
606+ let mut removed_done_obligations: Vec < O > = vec ! [ ] ;
622607
623- // Now move all popped nodes to the end. Try to keep the order.
608+ // Now move all Done/Error nodes to the end, preserving the order of
609+ // the Pending/Waiting nodes.
624610 //
625611 // LOOP INVARIANT:
626612 // self.nodes[0..index - dead_nodes] are the first remaining nodes
627613 // self.nodes[index - dead_nodes..index] are all dead
628614 // self.nodes[index..] are unchanged
629- for index in 0 ..self . nodes . len ( ) {
615+ for index in 0 ..orig_nodes_len {
630616 let node = & self . nodes [ index] ;
631617 match node. state . get ( ) {
632618 NodeState :: Pending | NodeState :: Waiting => {
@@ -637,7 +623,7 @@ impl<O: ForestObligation> ObligationForest<O> {
637623 }
638624 NodeState :: Done => {
639625 // This lookup can fail because the contents of
640- // `self.active_cache` is not guaranteed to match those of
626+ // `self.active_cache` are not guaranteed to match those of
641627 // `self.nodes`. See the comment in `process_obligation`
642628 // for more details.
643629 if let Some ( ( predicate, _) ) =
@@ -647,61 +633,50 @@ impl<O: ForestObligation> ObligationForest<O> {
647633 } else {
648634 self . done_cache . insert ( node. obligation . as_predicate ( ) . clone ( ) ) ;
649635 }
650- node_rewrites[ index] = nodes_len;
636+ if do_completed == DoCompleted :: Yes {
637+ // Extract the success stories.
638+ removed_done_obligations. push ( node. obligation . clone ( ) ) ;
639+ }
640+ node_rewrites[ index] = orig_nodes_len;
651641 dead_nodes += 1 ;
652642 }
653643 NodeState :: Error => {
654644 // We *intentionally* remove the node from the cache at this point. Otherwise
655645 // tests must come up with a different type on every type error they
656646 // check against.
657647 self . active_cache . remove ( node. obligation . as_predicate ( ) ) ;
658- node_rewrites[ index] = nodes_len;
659- dead_nodes += 1 ;
660648 self . insert_into_error_cache ( index) ;
649+ node_rewrites[ index] = orig_nodes_len;
650+ dead_nodes += 1 ;
661651 }
662- NodeState :: OnDfsStack | NodeState :: Success => unreachable ! ( )
652+ NodeState :: Success => unreachable ! ( )
663653 }
664654 }
665655
666- // No compression needed.
667- if dead_nodes == 0 {
668- node_rewrites. truncate ( 0 ) ;
669- self . scratch . replace ( node_rewrites) ;
670- return if do_completed == DoCompleted :: Yes { Some ( vec ! [ ] ) } else { None } ;
656+ if dead_nodes > 0 {
657+ // Remove the dead nodes and rewrite indices.
658+ self . nodes . truncate ( orig_nodes_len - dead_nodes) ;
659+ self . apply_rewrites ( & node_rewrites) ;
671660 }
672661
673- // Pop off all the nodes we killed and extract the success stories.
674- let successful = if do_completed == DoCompleted :: Yes {
675- Some ( ( 0 ..dead_nodes)
676- . map ( |_| self . nodes . pop ( ) . unwrap ( ) )
677- . flat_map ( |node| {
678- match node. state . get ( ) {
679- NodeState :: Error => None ,
680- NodeState :: Done => Some ( node. obligation ) ,
681- _ => unreachable ! ( )
682- }
683- } )
684- . collect ( ) )
685- } else {
686- self . nodes . truncate ( self . nodes . len ( ) - dead_nodes) ;
687- None
688- } ;
689- self . apply_rewrites ( & node_rewrites) ;
690-
691662 node_rewrites. truncate ( 0 ) ;
692- self . scratch . replace ( node_rewrites) ;
663+ self . node_rewrites . replace ( node_rewrites) ;
693664
694- successful
665+ if do_completed == DoCompleted :: Yes {
666+ Some ( removed_done_obligations)
667+ } else {
668+ None
669+ }
695670 }
696671
697672 fn apply_rewrites ( & mut self , node_rewrites : & [ usize ] ) {
698- let nodes_len = node_rewrites. len ( ) ;
673+ let orig_nodes_len = node_rewrites. len ( ) ;
699674
700675 for node in & mut self . nodes {
701676 let mut i = 0 ;
702677 while i < node. dependents . len ( ) {
703678 let new_index = node_rewrites[ node. dependents [ i] ] ;
704- if new_index >= nodes_len {
679+ if new_index >= orig_nodes_len {
705680 node. dependents . swap_remove ( i) ;
706681 if i == 0 && node. has_parent {
707682 // We just removed the parent.
@@ -718,7 +693,7 @@ impl<O: ForestObligation> ObligationForest<O> {
718693 // removal of nodes within `compress` can fail. See above.
719694 self . active_cache . retain ( |_predicate, index| {
720695 let new_index = node_rewrites[ * index] ;
721- if new_index >= nodes_len {
696+ if new_index >= orig_nodes_len {
722697 false
723698 } else {
724699 * index = new_index;
0 commit comments