22//!
33//! Currently, this pass only propagates scalar values.
44
5- use rustc_const_eval:: interpret:: { ImmTy , Immediate , InterpCx , OpTy , Projectable } ;
5+ use rustc_const_eval:: interpret:: { ImmTy , Immediate , InterpCx , OpTy , PlaceTy , Projectable } ;
66use rustc_data_structures:: fx:: FxHashMap ;
77use rustc_hir:: def:: DefKind ;
88use rustc_middle:: mir:: interpret:: { AllocId , ConstAllocation , InterpResult , Scalar } ;
99use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
1010use rustc_middle:: mir:: * ;
11- use rustc_middle:: ty:: layout:: TyAndLayout ;
11+ use rustc_middle:: ty:: layout:: { LayoutOf , TyAndLayout } ;
1212use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
1313use rustc_mir_dataflow:: value_analysis:: {
1414 Map , PlaceIndex , State , TrackElem , ValueAnalysis , ValueAnalysisWrapper , ValueOrPlace ,
1515} ;
1616use rustc_mir_dataflow:: { lattice:: FlatSet , Analysis , Results , ResultsVisitor } ;
1717use rustc_span:: def_id:: DefId ;
1818use rustc_span:: DUMMY_SP ;
19- use rustc_target:: abi:: { FieldIdx , VariantIdx } ;
19+ use rustc_target:: abi:: { Abi , FieldIdx , Size , VariantIdx , FIRST_VARIANT } ;
2020
21+ use crate :: const_prop:: throw_machine_stop_str;
2122use crate :: MirPass ;
2223
2324// These constants are somewhat random guesses and have not been optimized.
@@ -553,107 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
553554
554555 fn try_make_constant (
555556 & self ,
557+ ecx : & mut InterpCx < ' tcx , ' tcx , DummyMachine > ,
556558 place : Place < ' tcx > ,
557559 state : & State < FlatSet < Scalar > > ,
558560 map : & Map ,
559561 ) -> Option < Const < ' tcx > > {
560562 let ty = place. ty ( self . local_decls , self . patch . tcx ) . ty ;
563+ let layout = ecx. layout_of ( ty) . ok ( ) ?;
564+
565+ if layout. is_zst ( ) {
566+ return Some ( Const :: zero_sized ( ty) ) ;
567+ }
568+
569+ if layout. is_unsized ( ) {
570+ return None ;
571+ }
572+
561573 let place = map. find ( place. as_ref ( ) ) ?;
562- if let FlatSet :: Elem ( Scalar :: Int ( value) ) = state. get_idx ( place, map) {
563- Some ( Const :: Val ( ConstValue :: Scalar ( value. into ( ) ) , ty) )
564- } else {
565- let valtree = self . try_make_valtree ( place, ty, state, map) ?;
566- let constant = ty:: Const :: new_value ( self . patch . tcx , valtree, ty) ;
567- Some ( Const :: Ty ( constant) )
574+ if layout. abi . is_scalar ( )
575+ && let Some ( value) = propagatable_scalar ( place, state, map)
576+ {
577+ return Some ( Const :: Val ( ConstValue :: Scalar ( value) , ty) ) ;
578+ }
579+
580+ if matches ! ( layout. abi, Abi :: Scalar ( ..) | Abi :: ScalarPair ( ..) ) {
581+ let alloc_id = ecx
582+ . intern_with_temp_alloc ( layout, |ecx, dest| {
583+ try_write_constant ( ecx, dest, place, ty, state, map)
584+ } )
585+ . ok ( ) ?;
586+ return Some ( Const :: Val ( ConstValue :: Indirect { alloc_id, offset : Size :: ZERO } , ty) ) ;
568587 }
588+
589+ None
569590 }
591+ }
570592
571- fn try_make_valtree (
572- & self ,
573- place : PlaceIndex ,
574- ty : Ty < ' tcx > ,
575- state : & State < FlatSet < Scalar > > ,
576- map : & Map ,
577- ) -> Option < ty:: ValTree < ' tcx > > {
578- let tcx = self . patch . tcx ;
579- match ty. kind ( ) {
580- // ZSTs.
581- ty:: FnDef ( ..) => Some ( ty:: ValTree :: zst ( ) ) ,
582-
583- // Scalars.
584- ty:: Bool | ty:: Int ( _) | ty:: Uint ( _) | ty:: Float ( _) | ty:: Char => {
585- if let FlatSet :: Elem ( Scalar :: Int ( value) ) = state. get_idx ( place, map) {
586- Some ( ty:: ValTree :: Leaf ( value) )
587- } else {
588- None
589- }
590- }
593+ fn propagatable_scalar (
594+ place : PlaceIndex ,
595+ state : & State < FlatSet < Scalar > > ,
596+ map : & Map ,
597+ ) -> Option < Scalar > {
598+ if let FlatSet :: Elem ( value) = state. get_idx ( place, map) && value. try_to_int ( ) . is_ok ( ) {
599+ // Do not attempt to propagate pointers, as we may fail to preserve their identity.
600+ Some ( value)
601+ } else {
602+ None
603+ }
604+ }
591605
592- // Unsupported for now.
593- ty:: Array ( _, _) => None ,
594-
595- ty:: Tuple ( elem_tys) => {
596- let branches = elem_tys
597- . iter ( )
598- . enumerate ( )
599- . map ( |( i, ty) | {
600- let field = map. apply ( place, TrackElem :: Field ( FieldIdx :: from_usize ( i) ) ) ?;
601- self . try_make_valtree ( field, ty, state, map)
602- } )
603- . collect :: < Option < Vec < _ > > > ( ) ?;
604- Some ( ty:: ValTree :: Branch ( tcx. arena . alloc_from_iter ( branches. into_iter ( ) ) ) )
605- }
606+ #[ instrument( level = "trace" , skip( ecx, state, map) ) ]
607+ fn try_write_constant < ' tcx > (
608+ ecx : & mut InterpCx < ' _ , ' tcx , DummyMachine > ,
609+ dest : & PlaceTy < ' tcx > ,
610+ place : PlaceIndex ,
611+ ty : Ty < ' tcx > ,
612+ state : & State < FlatSet < Scalar > > ,
613+ map : & Map ,
614+ ) -> InterpResult < ' tcx > {
615+ let layout = ecx. layout_of ( ty) ?;
616+
617+ // Fast path for ZSTs.
618+ if layout. is_zst ( ) {
619+ return Ok ( ( ) ) ;
620+ }
621+
622+ // Fast path for scalars.
623+ if layout. abi . is_scalar ( )
624+ && let Some ( value) = propagatable_scalar ( place, state, map)
625+ {
626+ return ecx. write_immediate ( Immediate :: Scalar ( value) , dest) ;
627+ }
606628
607- ty:: Adt ( def, args) => {
608- if def. is_union ( ) {
609- return None ;
610- }
629+ match ty. kind ( ) {
630+ // ZSTs. Nothing to do.
631+ ty:: FnDef ( ..) => { }
611632
612- let ( variant_idx, variant_def, variant_place) = if def. is_enum ( ) {
613- let discr = map. apply ( place, TrackElem :: Discriminant ) ?;
614- let FlatSet :: Elem ( Scalar :: Int ( discr) ) = state. get_idx ( discr, map) else {
615- return None ;
616- } ;
617- let discr_bits = discr. assert_bits ( discr. size ( ) ) ;
618- let ( variant, _) =
619- def. discriminants ( tcx) . find ( |( _, var) | discr_bits == var. val ) ?;
620- let variant_place = map. apply ( place, TrackElem :: Variant ( variant) ) ?;
621- let variant_int = ty:: ValTree :: Leaf ( variant. as_u32 ( ) . into ( ) ) ;
622- ( Some ( variant_int) , def. variant ( variant) , variant_place)
623- } else {
624- ( None , def. non_enum_variant ( ) , place)
633+ // Those are scalars, must be handled above.
634+ ty:: Bool | ty:: Int ( _) | ty:: Uint ( _) | ty:: Float ( _) | ty:: Char => throw_machine_stop_str ! ( "primitive type with provenance" ) ,
635+
636+ ty:: Tuple ( elem_tys) => {
637+ for ( i, elem) in elem_tys. iter ( ) . enumerate ( ) {
638+ let Some ( field) = map. apply ( place, TrackElem :: Field ( FieldIdx :: from_usize ( i) ) ) else {
639+ throw_machine_stop_str ! ( "missing field in tuple" )
625640 } ;
641+ let field_dest = ecx. project_field ( dest, i) ?;
642+ try_write_constant ( ecx, & field_dest, field, elem, state, map) ?;
643+ }
644+ }
626645
627- let branches = variant_def
628- . fields
629- . iter_enumerated ( )
630- . map ( |( i, field) | {
631- let ty = field. ty ( tcx, args) ;
632- let field = map. apply ( variant_place, TrackElem :: Field ( i) ) ?;
633- self . try_make_valtree ( field, ty, state, map)
634- } )
635- . collect :: < Option < Vec < _ > > > ( ) ?;
636- Some ( ty:: ValTree :: Branch (
637- tcx. arena . alloc_from_iter ( variant_idx. into_iter ( ) . chain ( branches) ) ,
638- ) )
646+ ty:: Adt ( def, args) => {
647+ if def. is_union ( ) {
648+ throw_machine_stop_str ! ( "cannot propagate unions" )
639649 }
640650
641- // Do not attempt to support indirection in constants.
642- ty:: Ref ( ..) | ty:: RawPtr ( ..) | ty:: FnPtr ( ..) | ty:: Str | ty:: Slice ( _) => None ,
651+ let ( variant_idx, variant_def, variant_place, variant_dest) = if def. is_enum ( ) {
652+ let Some ( discr) = map. apply ( place, TrackElem :: Discriminant ) else {
653+ throw_machine_stop_str ! ( "missing discriminant for enum" )
654+ } ;
655+ let FlatSet :: Elem ( Scalar :: Int ( discr) ) = state. get_idx ( discr, map) else {
656+ throw_machine_stop_str ! ( "discriminant with provenance" )
657+ } ;
658+ let discr_bits = discr. assert_bits ( discr. size ( ) ) ;
659+ let Some ( ( variant, _) ) = def. discriminants ( * ecx. tcx ) . find ( |( _, var) | discr_bits == var. val ) else {
660+ throw_machine_stop_str ! ( "illegal discriminant for enum" )
661+ } ;
662+ let Some ( variant_place) = map. apply ( place, TrackElem :: Variant ( variant) ) else {
663+ throw_machine_stop_str ! ( "missing variant for enum" )
664+ } ;
665+ let variant_dest = ecx. project_downcast ( dest, variant) ?;
666+ ( variant, def. variant ( variant) , variant_place, variant_dest)
667+ } else {
668+ ( FIRST_VARIANT , def. non_enum_variant ( ) , place, dest. clone ( ) )
669+ } ;
670+
671+ for ( i, field) in variant_def. fields . iter_enumerated ( ) {
672+ let ty = field. ty ( * ecx. tcx , args) ;
673+ let Some ( field) = map. apply ( variant_place, TrackElem :: Field ( i) ) else {
674+ throw_machine_stop_str ! ( "missing field in ADT" )
675+ } ;
676+ let field_dest = ecx. project_field ( & variant_dest, i. as_usize ( ) ) ?;
677+ try_write_constant ( ecx, & field_dest, field, ty, state, map) ?;
678+ }
679+ ecx. write_discriminant ( variant_idx, dest) ?;
680+ }
643681
644- ty:: Never
645- | ty:: Foreign ( ..)
646- | ty:: Alias ( ..)
647- | ty:: Param ( _)
648- | ty:: Bound ( ..)
649- | ty:: Placeholder ( ..)
650- | ty:: Closure ( ..)
651- | ty:: Coroutine ( ..)
652- | ty:: Dynamic ( ..) => None ,
682+ // Unsupported for now.
683+ ty:: Array ( _, _)
653684
654- ty:: Error ( _) | ty:: Infer ( ..) | ty:: CoroutineWitness ( ..) => bug ! ( ) ,
655- }
685+ // Do not attempt to support indirection in constants.
686+ | ty:: Ref ( ..) | ty:: RawPtr ( ..) | ty:: FnPtr ( ..) | ty:: Str | ty:: Slice ( _)
687+
688+ | ty:: Never
689+ | ty:: Foreign ( ..)
690+ | ty:: Alias ( ..)
691+ | ty:: Param ( _)
692+ | ty:: Bound ( ..)
693+ | ty:: Placeholder ( ..)
694+ | ty:: Closure ( ..)
695+ | ty:: Coroutine ( ..)
696+ | ty:: Dynamic ( ..) => throw_machine_stop_str ! ( "unsupported type" ) ,
697+
698+ ty:: Error ( _) | ty:: Infer ( ..) | ty:: CoroutineWitness ( ..) => bug ! ( ) ,
656699 }
700+
701+ Ok ( ( ) )
657702}
658703
659704impl < ' mir , ' tcx >
@@ -671,8 +716,13 @@ impl<'mir, 'tcx>
671716 ) {
672717 match & statement. kind {
673718 StatementKind :: Assign ( box ( _, rvalue) ) => {
674- OperandCollector { state, visitor : self , map : & results. analysis . 0 . map }
675- . visit_rvalue ( rvalue, location) ;
719+ OperandCollector {
720+ state,
721+ visitor : self ,
722+ ecx : & mut results. analysis . 0 . ecx ,
723+ map : & results. analysis . 0 . map ,
724+ }
725+ . visit_rvalue ( rvalue, location) ;
676726 }
677727 _ => ( ) ,
678728 }
@@ -690,7 +740,12 @@ impl<'mir, 'tcx>
690740 // Don't overwrite the assignment if it already uses a constant (to keep the span).
691741 }
692742 StatementKind :: Assign ( box ( place, _) ) => {
693- if let Some ( value) = self . try_make_constant ( place, state, & results. analysis . 0 . map ) {
743+ if let Some ( value) = self . try_make_constant (
744+ & mut results. analysis . 0 . ecx ,
745+ place,
746+ state,
747+ & results. analysis . 0 . map ,
748+ ) {
694749 self . patch . assignments . insert ( location, value) ;
695750 }
696751 }
@@ -705,8 +760,13 @@ impl<'mir, 'tcx>
705760 terminator : & ' mir Terminator < ' tcx > ,
706761 location : Location ,
707762 ) {
708- OperandCollector { state, visitor : self , map : & results. analysis . 0 . map }
709- . visit_terminator ( terminator, location) ;
763+ OperandCollector {
764+ state,
765+ visitor : self ,
766+ ecx : & mut results. analysis . 0 . ecx ,
767+ map : & results. analysis . 0 . map ,
768+ }
769+ . visit_terminator ( terminator, location) ;
710770 }
711771}
712772
@@ -761,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
761821struct OperandCollector < ' tcx , ' map , ' locals , ' a > {
762822 state : & ' a State < FlatSet < Scalar > > ,
763823 visitor : & ' a mut Collector < ' tcx , ' locals > ,
824+ ecx : & ' map mut InterpCx < ' tcx , ' tcx , DummyMachine > ,
764825 map : & ' map Map ,
765826}
766827
@@ -773,15 +834,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
773834 location : Location ,
774835 ) {
775836 if let PlaceElem :: Index ( local) = elem
776- && let Some ( value) = self . visitor . try_make_constant ( local. into ( ) , self . state , self . map )
837+ && let Some ( value) = self . visitor . try_make_constant ( self . ecx , local. into ( ) , self . state , self . map )
777838 {
778839 self . visitor . patch . before_effect . insert ( ( location, local. into ( ) ) , value) ;
779840 }
780841 }
781842
782843 fn visit_operand ( & mut self , operand : & Operand < ' tcx > , location : Location ) {
783844 if let Some ( place) = operand. place ( ) {
784- if let Some ( value) = self . visitor . try_make_constant ( place, self . state , self . map ) {
845+ if let Some ( value) =
846+ self . visitor . try_make_constant ( self . ecx , place, self . state , self . map )
847+ {
785848 self . visitor . patch . before_effect . insert ( ( location, place) , value) ;
786849 } else if !place. projection . is_empty ( ) {
787850 // Try to propagate into `Index` projections.
@@ -804,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
804867 }
805868
806869 fn enforce_validity ( _ecx : & InterpCx < ' mir , ' tcx , Self > , _layout : TyAndLayout < ' tcx > ) -> bool {
807- unimplemented ! ( )
870+ false
808871 }
809872
810873 fn before_access_global (
@@ -816,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
816879 is_write : bool ,
817880 ) -> InterpResult < ' tcx > {
818881 if is_write {
819- crate :: const_prop :: throw_machine_stop_str!( "can't write to global" ) ;
882+ throw_machine_stop_str ! ( "can't write to global" ) ;
820883 }
821884
822885 // If the static allocation is mutable, then we can't const prop it as its content
823886 // might be different at runtime.
824887 if alloc. inner ( ) . mutability . is_mut ( ) {
825- crate :: const_prop :: throw_machine_stop_str!( "can't access mutable globals in ConstProp" ) ;
888+ throw_machine_stop_str ! ( "can't access mutable globals in ConstProp" ) ;
826889 }
827890
828891 Ok ( ( ) )
@@ -872,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
872935 _left : & rustc_const_eval:: interpret:: ImmTy < ' tcx , Self :: Provenance > ,
873936 _right : & rustc_const_eval:: interpret:: ImmTy < ' tcx , Self :: Provenance > ,
874937 ) -> interpret:: InterpResult < ' tcx , ( ImmTy < ' tcx , Self :: Provenance > , bool ) > {
875- crate :: const_prop :: throw_machine_stop_str!( "can't do pointer arithmetic" ) ;
938+ throw_machine_stop_str ! ( "can't do pointer arithmetic" ) ;
876939 }
877940
878941 fn expose_ptr (
0 commit comments