@@ -39,7 +39,7 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
3939
4040 let mut curr: usize = 0 ;
4141 for bb in mir. basic_blocks_mut ( ) {
42- let idx = match get_aggregate_statement ( curr, & bb. statements ) {
42+ let idx = match get_aggregate_statement_index ( curr, & bb. statements ) {
4343 Some ( idx) => idx,
4444 None => continue ,
4545 } ;
@@ -48,7 +48,11 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
4848 let src_info = bb. statements [ idx] . source_info ;
4949 let suffix_stmts = bb. statements . split_off ( idx+1 ) ;
5050 let orig_stmt = bb. statements . pop ( ) . unwrap ( ) ;
51- let StatementKind :: Assign ( ref lhs, ref rhs) = orig_stmt. kind ;
51+ let ( lhs, rhs) = match orig_stmt. kind {
52+ StatementKind :: Assign ( ref lhs, ref rhs) => ( lhs, rhs) ,
53+ StatementKind :: SetDiscriminant { .. } =>
54+ span_bug ! ( src_info. span, "expected aggregate, not {:?}" , orig_stmt. kind) ,
55+ } ;
5256 let ( agg_kind, operands) = match rhs {
5357 & Rvalue :: Aggregate ( ref agg_kind, ref operands) => ( agg_kind, operands) ,
5458 _ => span_bug ! ( src_info. span, "expected aggregate, not {:?}" , rhs) ,
@@ -64,10 +68,14 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
6468 let ty = variant_def. fields [ i] . ty ( tcx, substs) ;
6569 let rhs = Rvalue :: Use ( op. clone ( ) ) ;
6670
67- // since we don't handle enums, we don't need a cast
68- let lhs_cast = lhs. clone ( ) ;
69-
70- // FIXME we cannot deaggregate enums issue: #35186
71+ let lhs_cast = if adt_def. variants . len ( ) > 1 {
72+ Lvalue :: Projection ( Box :: new ( LvalueProjection {
73+ base : lhs. clone ( ) ,
74+ elem : ProjectionElem :: Downcast ( adt_def, variant) ,
75+ } ) )
76+ } else {
77+ lhs. clone ( )
78+ } ;
7179
7280 let lhs_proj = Lvalue :: Projection ( Box :: new ( LvalueProjection {
7381 base : lhs_cast,
@@ -80,18 +88,34 @@ impl<'tcx> MirPass<'tcx> for Deaggregator {
8088 debug ! ( "inserting: {:?} @ {:?}" , new_statement, idx + i) ;
8189 bb. statements . push ( new_statement) ;
8290 }
91+
92+ // if the aggregate was an enum, we need to set the discriminant
93+ if adt_def. variants . len ( ) > 1 {
94+ let set_discriminant = Statement {
95+ kind : StatementKind :: SetDiscriminant {
96+ lvalue : lhs. clone ( ) ,
97+ variant_index : variant,
98+ } ,
99+ source_info : src_info,
100+ } ;
101+ bb. statements . push ( set_discriminant) ;
102+ } ;
103+
83104 curr = bb. statements . len ( ) ;
84105 bb. statements . extend ( suffix_stmts) ;
85106 }
86107 }
87108}
88109
89- fn get_aggregate_statement < ' a , ' tcx , ' b > ( curr : usize ,
110+ fn get_aggregate_statement_index < ' a , ' tcx , ' b > ( start : usize ,
90111 statements : & Vec < Statement < ' tcx > > )
91112 -> Option < usize > {
92- for i in curr ..statements. len ( ) {
113+ for i in start ..statements. len ( ) {
93114 let ref statement = statements[ i] ;
94- let StatementKind :: Assign ( _, ref rhs) = statement. kind ;
115+ let rhs = match statement. kind {
116+ StatementKind :: Assign ( _, ref rhs) => rhs,
117+ StatementKind :: SetDiscriminant { .. } => continue ,
118+ } ;
95119 let ( kind, operands) = match rhs {
96120 & Rvalue :: Aggregate ( ref kind, ref operands) => ( kind, operands) ,
97121 _ => continue ,
@@ -100,9 +124,8 @@ fn get_aggregate_statement<'a, 'tcx, 'b>(curr: usize,
100124 & AggregateKind :: Adt ( adt_def, variant, _) => ( adt_def, variant) ,
101125 _ => continue ,
102126 } ;
103- if operands. len ( ) == 0 || adt_def . variants . len ( ) > 1 {
127+ if operands. len ( ) == 0 {
104128 // don't deaggregate ()
105- // don't deaggregate enums ... for now
106129 continue ;
107130 }
108131 debug ! ( "getting variant {:?}" , variant) ;
0 commit comments