3636//! cost by `MAX_COST`.
3737
3838use rustc_arena:: DroplessArena ;
39+ use rustc_const_eval:: interpret:: { ImmTy , Immediate , InterpCx , OpTy , Projectable } ;
3940use rustc_data_structures:: fx:: FxHashSet ;
4041use rustc_index:: bit_set:: BitSet ;
4142use rustc_index:: IndexVec ;
43+ use rustc_middle:: mir:: interpret:: Scalar ;
4244use rustc_middle:: mir:: visit:: Visitor ;
4345use rustc_middle:: mir:: * ;
44- use rustc_middle:: ty:: { self , ScalarInt , Ty , TyCtxt } ;
46+ use rustc_middle:: ty:: layout:: LayoutOf ;
47+ use rustc_middle:: ty:: { self , ScalarInt , TyCtxt } ;
4548use rustc_mir_dataflow:: value_analysis:: { Map , PlaceIndex , State , TrackElem } ;
49+ use rustc_span:: DUMMY_SP ;
4650use rustc_target:: abi:: { TagEncoding , Variants } ;
4751
4852use crate :: cost_checker:: CostChecker ;
53+ use crate :: dataflow_const_prop:: DummyMachine ;
4954
5055pub struct JumpThreading ;
5156
@@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
7176 let mut finder = TOFinder {
7277 tcx,
7378 param_env,
79+ ecx : InterpCx :: new ( tcx, DUMMY_SP , param_env, DummyMachine ) ,
7480 body,
7581 arena : & arena,
7682 map : & map,
@@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
8894 debug ! ( ?discr, ?bb) ;
8995
9096 let discr_ty = discr. ty ( body, tcx) . ty ;
91- let Ok ( discr_layout) = tcx . layout_of ( param_env . and ( discr_ty) ) else { continue } ;
97+ let Ok ( discr_layout) = finder . ecx . layout_of ( discr_ty) else { continue } ;
9298
9399 let Some ( discr) = finder. map . find ( discr. as_ref ( ) ) else { continue } ;
94100 debug ! ( ?discr) ;
@@ -142,6 +148,7 @@ struct ThreadingOpportunity {
142148struct TOFinder < ' tcx , ' a > {
143149 tcx : TyCtxt < ' tcx > ,
144150 param_env : ty:: ParamEnv < ' tcx > ,
151+ ecx : InterpCx < ' tcx , ' tcx , DummyMachine > ,
145152 body : & ' a Body < ' tcx > ,
146153 map : & ' a Map ,
147154 loop_headers : & ' a BitSet < BasicBlock > ,
@@ -329,25 +336,72 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
329336 }
330337
331338 #[ instrument( level = "trace" , skip( self ) ) ]
332- fn process_operand (
339+ fn process_immediate (
333340 & mut self ,
334341 bb : BasicBlock ,
335342 lhs : PlaceIndex ,
336- rhs : & Operand < ' tcx > ,
343+ rhs : ImmTy < ' tcx > ,
337344 state : & mut State < ConditionSet < ' a > > ,
338345 ) -> Option < !> {
339346 let register_opportunity = |c : Condition | {
340347 debug ! ( ?bb, ?c. target, "register" ) ;
341348 self . opportunities . push ( ThreadingOpportunity { chain : vec ! [ bb] , target : c. target } )
342349 } ;
343350
351+ let conditions = state. try_get_idx ( lhs, self . map ) ?;
352+ if let Immediate :: Scalar ( Scalar :: Int ( int) ) = * rhs {
353+ conditions. iter_matches ( int) . for_each ( register_opportunity) ;
354+ }
355+
356+ None
357+ }
358+
359+ #[ instrument( level = "trace" , skip( self ) ) ]
360+ fn process_operand (
361+ & mut self ,
362+ bb : BasicBlock ,
363+ lhs : PlaceIndex ,
364+ rhs : & Operand < ' tcx > ,
365+ state : & mut State < ConditionSet < ' a > > ,
366+ ) -> Option < !> {
344367 match rhs {
345368 // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
346369 Operand :: Constant ( constant) => {
347- let conditions = state. try_get_idx ( lhs, self . map ) ?;
348- let constant =
349- constant. const_ . normalize ( self . tcx , self . param_env ) . try_to_scalar_int ( ) ?;
350- conditions. iter_matches ( constant) . for_each ( register_opportunity) ;
370+ let constant = self . ecx . eval_mir_constant ( & constant. const_ , None , None ) . ok ( ) ?;
371+ self . map . for_each_projection_value (
372+ lhs,
373+ constant,
374+ & mut |elem, op| match elem {
375+ TrackElem :: Field ( idx) => self . ecx . project_field ( op, idx. as_usize ( ) ) . ok ( ) ,
376+ TrackElem :: Variant ( idx) => self . ecx . project_downcast ( op, idx) . ok ( ) ,
377+ TrackElem :: Discriminant => {
378+ let variant = self . ecx . read_discriminant ( op) . ok ( ) ?;
379+ let discr_value =
380+ self . ecx . discriminant_for_variant ( op. layout . ty , variant) . ok ( ) ?;
381+ Some ( discr_value. into ( ) )
382+ }
383+ TrackElem :: DerefLen => {
384+ let op: OpTy < ' _ > = self . ecx . deref_pointer ( op) . ok ( ) ?. into ( ) ;
385+ let len_usize = op. len ( & self . ecx ) . ok ( ) ?;
386+ let layout = self . ecx . layout_of ( self . tcx . types . usize ) . unwrap ( ) ;
387+ Some ( ImmTy :: from_uint ( len_usize, layout) . into ( ) )
388+ }
389+ } ,
390+ & mut |place, op| {
391+ if let Some ( conditions) = state. try_get_idx ( place, self . map )
392+ && let Ok ( imm) = self . ecx . read_immediate_raw ( op)
393+ && let Some ( imm) = imm. right ( )
394+ && let Immediate :: Scalar ( Scalar :: Int ( int) ) = * imm
395+ {
396+ conditions. iter_matches ( int) . for_each ( |c : Condition | {
397+ self . opportunities . push ( ThreadingOpportunity {
398+ chain : vec ! [ bb] ,
399+ target : c. target ,
400+ } )
401+ } )
402+ }
403+ } ,
404+ ) ;
351405 }
352406 // Transfer the conditions on the copied rhs.
353407 Operand :: Move ( rhs) | Operand :: Copy ( rhs) => {
@@ -374,18 +428,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
374428 // Below, `lhs` is the return value of `mutated_statement`,
375429 // the place to which `conditions` apply.
376430
377- let discriminant_for_variant = |enum_ty : Ty < ' tcx > , variant_index| {
378- let discr = enum_ty. discriminant_for_variant ( self . tcx , variant_index) ?;
379- let discr_layout = self . tcx . layout_of ( self . param_env . and ( discr. ty ) ) . ok ( ) ?;
380- let scalar = ScalarInt :: try_from_uint ( discr. val , discr_layout. size ) ?;
381- Some ( Operand :: const_from_scalar (
382- self . tcx ,
383- discr. ty ,
384- scalar. into ( ) ,
385- rustc_span:: DUMMY_SP ,
386- ) )
387- } ;
388-
389431 match & stmt. kind {
390432 // If we expect `discriminant(place) ?= A`,
391433 // we have an opportunity if `variant_index ?= A`.
@@ -395,7 +437,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
395437 // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
396438 // of a niche encoding. If we cannot ensure that we write to the discriminant, do
397439 // nothing.
398- let enum_layout = self . tcx . layout_of ( self . param_env . and ( enum_ty) ) . ok ( ) ?;
440+ let enum_layout = self . ecx . layout_of ( enum_ty) . ok ( ) ?;
399441 let writes_discriminant = match enum_layout. variants {
400442 Variants :: Single { index } => {
401443 assert_eq ! ( index, * variant_index) ;
@@ -408,8 +450,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
408450 } => * variant_index != untagged_variant,
409451 } ;
410452 if writes_discriminant {
411- let discr = discriminant_for_variant ( enum_ty, * variant_index) ?;
412- self . process_operand ( bb, discr_target, & discr, state) ?;
453+ let discr = self . ecx . discriminant_for_variant ( enum_ty, * variant_index) . ok ( ) ?;
454+ self . process_immediate ( bb, discr_target, discr, state) ?;
413455 }
414456 }
415457 // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
@@ -440,10 +482,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
440482 AggregateKind :: Adt ( _, variant_index, ..) if agg_ty. is_enum ( ) => {
441483 if let Some ( discr_target) =
442484 self . map . apply ( lhs, TrackElem :: Discriminant )
443- && let Some ( discr_value) =
444- discriminant_for_variant ( agg_ty, * variant_index)
485+ && let Ok ( discr_value) = self
486+ . ecx
487+ . discriminant_for_variant ( agg_ty, * variant_index)
445488 {
446- self . process_operand ( bb, discr_target, & discr_value, state) ;
489+ self . process_immediate (
490+ bb,
491+ discr_target,
492+ discr_value,
493+ state,
494+ ) ;
447495 }
448496 self . map . apply ( lhs, TrackElem :: Variant ( * variant_index) ) ?
449497 }
@@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
577625
578626 let discr = discr. place ( ) ?;
579627 let discr_ty = discr. ty ( self . body , self . tcx ) . ty ;
580- let discr_layout = self . tcx . layout_of ( self . param_env . and ( discr_ty) ) . ok ( ) ?;
628+ let discr_layout = self . ecx . layout_of ( discr_ty) . ok ( ) ?;
581629 let conditions = state. try_get ( discr. as_ref ( ) , self . map ) ?;
582630
583631 if let Some ( ( value, _) ) = targets. iter ( ) . find ( |& ( _, target) | target == target_bb) {
0 commit comments