@@ -7,10 +7,12 @@ use rustc_middle::hir;
77use rustc_middle:: ich:: StableHashingContext ;
88use rustc_middle:: mir:: interpret:: Scalar ;
99use rustc_middle:: mir:: {
10- self , BasicBlock , BasicBlockData , CoverageData , Operand , Place , SourceInfo , StatementKind ,
11- Terminator , TerminatorKind , START_BLOCK ,
10+ self , traversal , BasicBlock , BasicBlockData , CoverageData , Operand , Place , SourceInfo ,
11+ StatementKind , Terminator , TerminatorKind , START_BLOCK ,
1212} ;
1313use rustc_middle:: ty;
14+ use rustc_middle:: ty:: query:: Providers ;
15+ use rustc_middle:: ty:: FnDef ;
1416use rustc_middle:: ty:: TyCtxt ;
1517use rustc_span:: def_id:: DefId ;
1618use rustc_span:: Span ;
@@ -19,6 +21,31 @@ use rustc_span::Span;
1921/// the intrinsic llvm.instrprof.increment.
2022pub struct InstrumentCoverage ;
2123
24+ /// The `query` provider for `CoverageData`, requested by `codegen_intrinsic_call()` when
25+ /// constructing the arguments for `llvm.instrprof.increment`.
26+ pub ( crate ) fn provide ( providers : & mut Providers < ' _ > ) {
27+ providers. coverage_data = |tcx, def_id| {
28+ let body = tcx. optimized_mir ( def_id) ;
29+ let count_code_region_fn =
30+ tcx. require_lang_item ( lang_items:: CountCodeRegionFnLangItem , None ) ;
31+ let mut num_counters: u32 = 0 ;
32+ for ( _, data) in traversal:: preorder ( body) {
33+ if let Some ( terminator) = & data. terminator {
34+ if let TerminatorKind :: Call { func : Operand :: Constant ( func) , .. } = & terminator. kind
35+ {
36+ if let FnDef ( called_fn_def_id, _) = func. literal . ty . kind {
37+ if called_fn_def_id == count_code_region_fn {
38+ num_counters += 1 ;
39+ }
40+ }
41+ }
42+ }
43+ }
44+ let hash = if num_counters > 0 { hash_mir_source ( tcx, def_id) } else { 0 } ;
45+ CoverageData { num_counters, hash }
46+ } ;
47+ }
48+
2249struct Instrumentor < ' tcx > {
2350 tcx : TyCtxt < ' tcx > ,
2451 num_counters : u32 ,
@@ -30,20 +57,12 @@ impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
3057 // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
3158 // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
3259 if src. promoted . is_none ( ) {
33- assert ! ( mir_body. coverage_data. is_none( ) ) ;
34-
35- let hash = hash_mir_source ( tcx, & src) ;
36-
3760 debug ! (
38- "instrumenting {:?}, hash: {}, span: {}" ,
61+ "instrumenting {:?}, span: {}" ,
3962 src. def_id( ) ,
40- hash,
4163 tcx. sess. source_map( ) . span_to_string( mir_body. span)
4264 ) ;
43-
44- let num_counters = Instrumentor :: new ( tcx) . inject_counters ( mir_body) ;
45-
46- mir_body. coverage_data = Some ( CoverageData { hash, num_counters } ) ;
65+ Instrumentor :: new ( tcx) . inject_counters ( mir_body) ;
4766 }
4867 }
4968 }
@@ -60,15 +79,13 @@ impl<'tcx> Instrumentor<'tcx> {
6079 next
6180 }
6281
63- fn inject_counters ( & mut self , mir_body : & mut mir:: Body < ' tcx > ) -> u32 {
82+ fn inject_counters ( & mut self , mir_body : & mut mir:: Body < ' tcx > ) {
6483 // FIXME(richkadel): As a first step, counters are only injected at the top of each
6584 // function. The complete solution will inject counters at each conditional code branch.
6685 let top_of_function = START_BLOCK ;
6786 let entire_function = mir_body. span ;
6887
6988 self . inject_counter ( mir_body, top_of_function, entire_function) ;
70-
71- self . num_counters
7289 }
7390
7491 fn inject_counter (
@@ -138,14 +155,9 @@ fn placeholder_block(span: Span) -> BasicBlockData<'tcx> {
138155 }
139156}
140157
141- fn hash_mir_source < ' tcx > ( tcx : TyCtxt < ' tcx > , src : & MirSource < ' tcx > ) -> u64 {
142- let fn_body_id = match tcx. hir ( ) . get_if_local ( src. def_id ( ) ) {
143- Some ( node) => match hir:: map:: associated_body ( node) {
144- Some ( body_id) => body_id,
145- _ => bug ! ( "instrumented MirSource does not include a function body: {:?}" , node) ,
146- } ,
147- None => bug ! ( "instrumented MirSource is not local: {:?}" , src) ,
148- } ;
158+ fn hash_mir_source < ' tcx > ( tcx : TyCtxt < ' tcx > , def_id : DefId ) -> u64 {
159+ let hir_node = tcx. hir ( ) . get_if_local ( def_id) . expect ( "DefId is local" ) ;
160+ let fn_body_id = hir:: map:: associated_body ( hir_node) . expect ( "HIR node is a function with body" ) ;
149161 let hir_body = tcx. hir ( ) . body ( fn_body_id) ;
150162 let mut hcx = tcx. create_no_span_stable_hashing_context ( ) ;
151163 hash ( & mut hcx, & hir_body. value ) . to_smaller_hash ( )
0 commit comments