@@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
1313use rustc_mir_dataflow:: { lattice:: FlatSet , Analysis , ResultsVisitor , SwitchIntEdgeEffects } ;
1414use rustc_span:: DUMMY_SP ;
1515use rustc_target:: abi:: Align ;
16+ use rustc_target:: abi:: VariantIdx ;
1617
1718use crate :: MirPass ;
1819
@@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
3031
3132 #[ instrument( skip_all level = "debug" ) ]
3233 fn run_pass ( & self , tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
34+ debug ! ( def_id = ?body. source. def_id( ) ) ;
3335 if tcx. sess . mir_opt_level ( ) < 4 && body. basic_blocks . len ( ) > BLOCK_LIMIT {
3436 debug ! ( "aborted dataflow const prop due too many basic blocks" ) ;
3537 return ;
3638 }
3739
38- // Decide which places to track during the analysis.
39- let map = Map :: from_filter ( tcx, body, Ty :: is_scalar) ;
40-
4140 // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
4241 // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
4342 // applications, where `h` is the height of the lattice. Because the height of our lattice
@@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
4645 // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
4746 // map nodes is strongly correlated to the number of tracked places, this becomes more or
4847 // less `O(n)` if we place a constant limit on the number of tracked places.
49- if tcx. sess . mir_opt_level ( ) < 4 && map . tracked_places ( ) > PLACE_LIMIT {
50- debug ! ( "aborted dataflow const prop due to too many tracked places" ) ;
51- return ;
52- }
48+ let place_limit = if tcx. sess . mir_opt_level ( ) < 4 { Some ( PLACE_LIMIT ) } else { None } ;
49+
50+ // Decide which places to track during the analysis.
51+ let map = Map :: from_filter ( tcx , body , Ty :: is_scalar , place_limit ) ;
5352
5453 // Perform the actual dataflow analysis.
5554 let analysis = ConstAnalysis :: new ( tcx, body, map) ;
@@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
6362 }
6463}
6564
66- struct ConstAnalysis < ' tcx > {
65+ struct ConstAnalysis < ' a , ' tcx > {
6766 map : Map ,
6867 tcx : TyCtxt < ' tcx > ,
68+ local_decls : & ' a LocalDecls < ' tcx > ,
6969 ecx : InterpCx < ' tcx , ' tcx , DummyMachine > ,
7070 param_env : ty:: ParamEnv < ' tcx > ,
7171}
7272
73- impl < ' tcx > ValueAnalysis < ' tcx > for ConstAnalysis < ' tcx > {
73+ impl < ' tcx > ConstAnalysis < ' _ , ' tcx > {
74+ fn eval_discriminant (
75+ & self ,
76+ enum_ty : Ty < ' tcx > ,
77+ variant_index : VariantIdx ,
78+ ) -> Option < ScalarTy < ' tcx > > {
79+ if !enum_ty. is_enum ( ) {
80+ return None ;
81+ }
82+ let discr = enum_ty. discriminant_for_variant ( self . tcx , variant_index) ?;
83+ let discr_layout = self . tcx . layout_of ( self . param_env . and ( discr. ty ) ) . ok ( ) ?;
84+ let discr_value = Scalar :: try_from_uint ( discr. val , discr_layout. size ) ?;
85+ Some ( ScalarTy ( discr_value, discr. ty ) )
86+ }
87+ }
88+
89+ impl < ' tcx > ValueAnalysis < ' tcx > for ConstAnalysis < ' _ , ' tcx > {
7490 type Value = FlatSet < ScalarTy < ' tcx > > ;
7591
7692 const NAME : & ' static str = "ConstAnalysis" ;
@@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
7995 & self . map
8096 }
8197
98+ fn handle_statement ( & self , statement : & Statement < ' tcx > , state : & mut State < Self :: Value > ) {
99+ match statement. kind {
100+ StatementKind :: SetDiscriminant { box ref place, variant_index } => {
101+ state. flood_discr ( place. as_ref ( ) , & self . map ) ;
102+ if self . map . find_discr ( place. as_ref ( ) ) . is_some ( ) {
103+ let enum_ty = place. ty ( self . local_decls , self . tcx ) . ty ;
104+ if let Some ( discr) = self . eval_discriminant ( enum_ty, variant_index) {
105+ state. assign_discr (
106+ place. as_ref ( ) ,
107+ ValueOrPlace :: Value ( FlatSet :: Elem ( discr) ) ,
108+ & self . map ,
109+ ) ;
110+ }
111+ }
112+ }
113+ _ => self . super_statement ( statement, state) ,
114+ }
115+ }
116+
82117 fn handle_assign (
83118 & self ,
84119 target : Place < ' tcx > ,
@@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
87122 ) {
88123 match rvalue {
89124 Rvalue :: Aggregate ( kind, operands) => {
90- let target = self . map ( ) . find ( target. as_ref ( ) ) ;
91- if let Some ( target) = target {
92- state. flood_idx_with ( target, self . map ( ) , FlatSet :: Bottom ) ;
93- let field_based = match * * kind {
94- AggregateKind :: Tuple | AggregateKind :: Closure ( ..) => true ,
95- AggregateKind :: Adt ( def_id, ..) => {
96- matches ! ( self . tcx. def_kind( def_id) , DefKind :: Struct )
125+ state. flood_with ( target. as_ref ( ) , self . map ( ) , FlatSet :: Bottom ) ;
126+ if let Some ( target_idx) = self . map ( ) . find ( target. as_ref ( ) ) {
127+ let ( variant_target, variant_index) = match * * kind {
128+ AggregateKind :: Tuple | AggregateKind :: Closure ( ..) => {
129+ ( Some ( target_idx) , None )
97130 }
98- _ => false ,
131+ AggregateKind :: Adt ( def_id, variant_index, ..) => {
132+ match self . tcx . def_kind ( def_id) {
133+ DefKind :: Struct => ( Some ( target_idx) , None ) ,
134+ DefKind :: Enum => ( Some ( target_idx) , Some ( variant_index) ) ,
135+ _ => ( None , None ) ,
136+ }
137+ }
138+ _ => ( None , None ) ,
99139 } ;
100- if field_based {
140+ if let Some ( target ) = variant_target {
101141 for ( field_index, operand) in operands. iter ( ) . enumerate ( ) {
102142 if let Some ( field) = self
103143 . map ( )
104144 . apply ( target, TrackElem :: Field ( Field :: from_usize ( field_index) ) )
105145 {
106146 let result = self . handle_operand ( operand, state) ;
107- state. assign_idx ( field, result, self . map ( ) ) ;
147+ state. insert_idx ( field, result, self . map ( ) ) ;
108148 }
109149 }
110150 }
151+ if let Some ( variant_index) = variant_index
152+ && let Some ( discr_idx) = self . map ( ) . apply ( target_idx, TrackElem :: Discriminant )
153+ {
154+ let enum_ty = target. ty ( self . local_decls , self . tcx ) . ty ;
155+ if let Some ( discr_val) = self . eval_discriminant ( enum_ty, variant_index) {
156+ state. insert_value_idx ( discr_idx, FlatSet :: Elem ( discr_val) , & self . map ) ;
157+ }
158+ }
111159 }
112160 }
113161 Rvalue :: CheckedBinaryOp ( op, box ( left, right) ) => {
162+ // Flood everything now, so we can use `insert_value_idx` directly later.
163+ state. flood ( target. as_ref ( ) , self . map ( ) ) ;
164+
114165 let target = self . map ( ) . find ( target. as_ref ( ) ) ;
115- if let Some ( target) = target {
116- // We should not track any projections other than
117- // what is overwritten below, but just in case...
118- state. flood_idx ( target, self . map ( ) ) ;
119- }
120166
121167 let value_target = target
122168 . and_then ( |target| self . map ( ) . apply ( target, TrackElem :: Field ( 0_u32 . into ( ) ) ) ) ;
@@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
127173 let ( val, overflow) = self . binary_op ( state, * op, left, right) ;
128174
129175 if let Some ( value_target) = value_target {
130- state. assign_idx ( value_target, ValueOrPlace :: Value ( val) , self . map ( ) ) ;
176+ // We have flooded `target` earlier.
177+ state. insert_value_idx ( value_target, val, self . map ( ) ) ;
131178 }
132179 if let Some ( overflow_target) = overflow_target {
133180 let overflow = match overflow {
@@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
142189 }
143190 FlatSet :: Bottom => FlatSet :: Bottom ,
144191 } ;
145- state. assign_idx (
146- overflow_target,
147- ValueOrPlace :: Value ( overflow) ,
148- self . map ( ) ,
149- ) ;
192+ // We have flooded `target` earlier.
193+ state. insert_value_idx ( overflow_target, overflow, self . map ( ) ) ;
150194 }
151195 }
152196 }
@@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
195239 FlatSet :: Bottom => ValueOrPlace :: Value ( FlatSet :: Bottom ) ,
196240 FlatSet :: Top => ValueOrPlace :: Value ( FlatSet :: Top ) ,
197241 } ,
242+ Rvalue :: Discriminant ( place) => {
243+ ValueOrPlace :: Value ( state. get_discr ( place. as_ref ( ) , self . map ( ) ) )
244+ }
198245 _ => self . super_rvalue ( rvalue, state) ,
199246 }
200247 }
@@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
268315 }
269316}
270317
271- impl < ' tcx > ConstAnalysis < ' tcx > {
272- pub fn new ( tcx : TyCtxt < ' tcx > , body : & Body < ' tcx > , map : Map ) -> Self {
318+ impl < ' a , ' tcx > ConstAnalysis < ' a , ' tcx > {
319+ pub fn new ( tcx : TyCtxt < ' tcx > , body : & ' a Body < ' tcx > , map : Map ) -> Self {
273320 let param_env = tcx. param_env ( body. source . def_id ( ) ) ;
274321 Self {
275322 map,
276323 tcx,
324+ local_decls : & body. local_decls ,
277325 ecx : InterpCx :: new ( tcx, DUMMY_SP , param_env, DummyMachine ) ,
278326 param_env : param_env,
279327 }
@@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
466514 _ => ( ) ,
467515 }
468516 }
517+
518+ fn visit_rvalue ( & mut self , rvalue : & Rvalue < ' tcx > , location : Location ) {
519+ match rvalue {
520+ Rvalue :: Discriminant ( place) => {
521+ match self . state . get_discr ( place. as_ref ( ) , self . visitor . map ) {
522+ FlatSet :: Top => ( ) ,
523+ FlatSet :: Elem ( value) => {
524+ self . visitor . before_effect . insert ( ( location, * place) , value) ;
525+ }
526+ FlatSet :: Bottom => ( ) ,
527+ }
528+ }
529+ _ => self . super_rvalue ( rvalue, location) ,
530+ }
531+ }
469532}
470533
471534struct DummyMachine ;
0 commit comments