|
| 1 | +use rustc_index::IndexVec; |
1 | 2 | use rustc_middle::mir::*; |
2 | | -use rustc_middle::ty::TyCtxt; |
| 3 | +use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; |
3 | 4 | use std::iter; |
4 | 5 |
|
5 | 6 | use super::simplify::simplify_cfg; |
6 | 7 |
|
7 | 8 | pub struct MatchBranchSimplification; |
8 | 9 |
|
| 10 | +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { |
| 11 | + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { |
| 12 | + sess.mir_opt_level() >= 1 |
| 13 | + } |
| 14 | + |
| 15 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| 16 | + let def_id = body.source.def_id(); |
| 17 | + let param_env = tcx.param_env_reveal_all_normalized(def_id); |
| 18 | + |
| 19 | + let bbs = body.basic_blocks.as_mut(); |
| 20 | + let mut should_cleanup = false; |
| 21 | + for bb_idx in bbs.indices() { |
| 22 | + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { |
| 23 | + continue; |
| 24 | + } |
| 25 | + |
| 26 | + match bbs[bb_idx].terminator().kind { |
| 27 | + TerminatorKind::SwitchInt { |
| 28 | + discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), |
| 29 | + ref targets, |
| 30 | + .. |
| 31 | + // We require that the possible target blocks don't contain this block. |
| 32 | + } if !targets.all_targets().contains(&bb_idx) => {} |
| 33 | + // Only optimize switch int statements |
| 34 | + _ => continue, |
| 35 | + }; |
| 36 | + |
| 37 | + if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) { |
| 38 | + should_cleanup = true; |
| 39 | + continue; |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + if should_cleanup { |
| 44 | + simplify_cfg(body); |
| 45 | + } |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +trait SimplifyMatch<'tcx> { |
| 50 | + fn simplify( |
| 51 | + &self, |
| 52 | + tcx: TyCtxt<'tcx>, |
| 53 | + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, |
| 54 | + bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 55 | + switch_bb_idx: BasicBlock, |
| 56 | + param_env: ParamEnv<'tcx>, |
| 57 | + ) -> bool { |
| 58 | + let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { |
| 59 | + TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), |
| 60 | + _ => unreachable!(), |
| 61 | + }; |
| 62 | + |
| 63 | + if !self.can_simplify(tcx, targets, param_env, bbs) { |
| 64 | + return false; |
| 65 | + } |
| 66 | + |
| 67 | + // Take ownership of items now that we know we can optimize. |
| 68 | + let discr = discr.clone(); |
| 69 | + let discr_ty = discr.ty(local_decls, tcx); |
| 70 | + |
| 71 | + // Introduce a temporary for the discriminant value. |
| 72 | + let source_info = bbs[switch_bb_idx].terminator().source_info; |
| 73 | + let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); |
| 74 | + |
| 75 | + // We already checked that first and second are different blocks, |
| 76 | + // and bb_idx has a different terminator from both of them. |
| 77 | + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); |
| 78 | + let (_, first) = targets.iter().next().unwrap(); |
| 79 | + let (from, first) = bbs.pick2_mut(switch_bb_idx, first); |
| 80 | + from.statements |
| 81 | + .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); |
| 82 | + from.statements.push(Statement { |
| 83 | + source_info, |
| 84 | + kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))), |
| 85 | + }); |
| 86 | + from.statements.extend(new_stmts); |
| 87 | + from.statements |
| 88 | + .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); |
| 89 | + from.terminator_mut().kind = first.terminator().kind.clone(); |
| 90 | + true |
| 91 | + } |
| 92 | + |
| 93 | + fn can_simplify( |
| 94 | + &self, |
| 95 | + tcx: TyCtxt<'tcx>, |
| 96 | + targets: &SwitchTargets, |
| 97 | + param_env: ParamEnv<'tcx>, |
| 98 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 99 | + ) -> bool; |
| 100 | + |
| 101 | + fn new_stmts( |
| 102 | + &self, |
| 103 | + tcx: TyCtxt<'tcx>, |
| 104 | + targets: &SwitchTargets, |
| 105 | + param_env: ParamEnv<'tcx>, |
| 106 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 107 | + discr_local: Local, |
| 108 | + discr_ty: Ty<'tcx>, |
| 109 | + ) -> Vec<Statement<'tcx>>; |
| 110 | +} |
| 111 | + |
| 112 | +struct SimplifyToIf; |
| 113 | + |
9 | 114 | /// If a source block is found that switches between two blocks that are exactly |
10 | 115 | /// the same modulo const bool assignments (e.g., one assigns true another false |
11 | 116 | /// to the same place), merge a target block statements into the source block, |
@@ -37,144 +142,111 @@ pub struct MatchBranchSimplification; |
37 | 142 | /// goto -> bb3; |
38 | 143 | /// } |
39 | 144 | /// ``` |
| 145 | +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { |
| 146 | + fn can_simplify( |
| 147 | + &self, |
| 148 | + tcx: TyCtxt<'tcx>, |
| 149 | + targets: &SwitchTargets, |
| 150 | + param_env: ParamEnv<'tcx>, |
| 151 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 152 | + ) -> bool { |
| 153 | + if targets.iter().len() != 1 { |
| 154 | + return false; |
| 155 | + } |
| 156 | + // We require that the possible target blocks all be distinct. |
| 157 | + let (_, first) = targets.iter().next().unwrap(); |
| 158 | + let second = targets.otherwise(); |
| 159 | + if first == second { |
| 160 | + return false; |
| 161 | + } |
| 162 | + // Check that destinations are identical, and if not, then don't optimize this block |
| 163 | + if bbs[first].terminator().kind != bbs[second].terminator().kind { |
| 164 | + return false; |
| 165 | + } |
40 | 166 |
|
41 | | -impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { |
42 | | - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { |
43 | | - sess.mir_opt_level() >= 1 |
44 | | - } |
45 | | - |
46 | | - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
47 | | - let def_id = body.source.def_id(); |
48 | | - let param_env = tcx.param_env_reveal_all_normalized(def_id); |
49 | | - |
50 | | - let bbs = body.basic_blocks.as_mut(); |
51 | | - let mut should_cleanup = false; |
52 | | - 'outer: for bb_idx in bbs.indices() { |
53 | | - if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { |
54 | | - continue; |
55 | | - } |
56 | | - |
57 | | - let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { |
58 | | - TerminatorKind::SwitchInt { |
59 | | - discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), |
60 | | - ref targets, |
61 | | - .. |
62 | | - } if targets.iter().len() == 1 => { |
63 | | - let (value, target) = targets.iter().next().unwrap(); |
64 | | - // We require that this block and the two possible target blocks all be |
65 | | - // distinct. |
66 | | - if target == targets.otherwise() |
67 | | - || bb_idx == target |
68 | | - || bb_idx == targets.otherwise() |
69 | | - { |
70 | | - continue; |
71 | | - } |
72 | | - (discr, value, target, targets.otherwise()) |
73 | | - } |
74 | | - // Only optimize switch int statements |
75 | | - _ => continue, |
76 | | - }; |
77 | | - |
78 | | - // Check that destinations are identical, and if not, then don't optimize this block |
79 | | - if bbs[first].terminator().kind != bbs[second].terminator().kind { |
80 | | - continue; |
| 167 | + // Check that blocks are assignments of consts to the same place or same statement, |
| 168 | + // and match up 1-1, if not don't optimize this block. |
| 169 | + let first_stmts = &bbs[first].statements; |
| 170 | + let second_stmts = &bbs[second].statements; |
| 171 | + if first_stmts.len() != second_stmts.len() { |
| 172 | + return false; |
| 173 | + } |
| 174 | + for (f, s) in iter::zip(first_stmts, second_stmts) { |
| 175 | + match (&f.kind, &s.kind) { |
| 176 | + // If two statements are exactly the same, we can optimize. |
| 177 | + (f_s, s_s) if f_s == s_s => {} |
| 178 | + |
| 179 | + // If two statements are const bool assignments to the same place, we can optimize. |
| 180 | + ( |
| 181 | + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), |
| 182 | + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), |
| 183 | + ) if lhs_f == lhs_s |
| 184 | + && f_c.const_.ty().is_bool() |
| 185 | + && s_c.const_.ty().is_bool() |
| 186 | + && f_c.const_.try_eval_bool(tcx, param_env).is_some() |
| 187 | + && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} |
| 188 | + |
| 189 | + // Otherwise we cannot optimize. Try another block. |
| 190 | + _ => return false, |
81 | 191 | } |
| 192 | + } |
| 193 | + true |
| 194 | + } |
82 | 195 |
|
83 | | - // Check that blocks are assignments of consts to the same place or same statement, |
84 | | - // and match up 1-1, if not don't optimize this block. |
85 | | - let first_stmts = &bbs[first].statements; |
86 | | - let scnd_stmts = &bbs[second].statements; |
87 | | - if first_stmts.len() != scnd_stmts.len() { |
88 | | - continue; |
89 | | - } |
90 | | - for (f, s) in iter::zip(first_stmts, scnd_stmts) { |
91 | | - match (&f.kind, &s.kind) { |
92 | | - // If two statements are exactly the same, we can optimize. |
93 | | - (f_s, s_s) if f_s == s_s => {} |
94 | | - |
95 | | - // If two statements are const bool assignments to the same place, we can optimize. |
96 | | - ( |
97 | | - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), |
98 | | - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), |
99 | | - ) if lhs_f == lhs_s |
100 | | - && f_c.const_.ty().is_bool() |
101 | | - && s_c.const_.ty().is_bool() |
102 | | - && f_c.const_.try_eval_bool(tcx, param_env).is_some() |
103 | | - && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} |
104 | | - |
105 | | - // Otherwise we cannot optimize. Try another block. |
106 | | - _ => continue 'outer, |
107 | | - } |
108 | | - } |
109 | | - // Take ownership of items now that we know we can optimize. |
110 | | - let discr = discr.clone(); |
111 | | - let discr_ty = discr.ty(&body.local_decls, tcx); |
112 | | - |
113 | | - // Introduce a temporary for the discriminant value. |
114 | | - let source_info = bbs[bb_idx].terminator().source_info; |
115 | | - let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); |
116 | | - |
117 | | - // We already checked that first and second are different blocks, |
118 | | - // and bb_idx has a different terminator from both of them. |
119 | | - let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); |
120 | | - |
121 | | - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { |
122 | | - match (&f.kind, &s.kind) { |
123 | | - (f_s, s_s) if f_s == s_s => (*f).clone(), |
124 | | - |
125 | | - ( |
126 | | - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), |
127 | | - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), |
128 | | - ) => { |
129 | | - // From earlier loop we know that we are dealing with bool constants only: |
130 | | - let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
131 | | - let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
132 | | - if f_b == s_b { |
133 | | - // Same value in both blocks. Use statement as is. |
134 | | - (*f).clone() |
135 | | - } else { |
136 | | - // Different value between blocks. Make value conditional on switch condition. |
137 | | - let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; |
138 | | - let const_cmp = Operand::const_from_scalar( |
139 | | - tcx, |
140 | | - discr_ty, |
141 | | - rustc_const_eval::interpret::Scalar::from_uint(val, size), |
142 | | - rustc_span::DUMMY_SP, |
143 | | - ); |
144 | | - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; |
145 | | - let rhs = Rvalue::BinaryOp( |
146 | | - op, |
147 | | - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), |
148 | | - ); |
149 | | - Statement { |
150 | | - source_info: f.source_info, |
151 | | - kind: StatementKind::Assign(Box::new((*lhs, rhs))), |
152 | | - } |
| 196 | + fn new_stmts( |
| 197 | + &self, |
| 198 | + tcx: TyCtxt<'tcx>, |
| 199 | + targets: &SwitchTargets, |
| 200 | + param_env: ParamEnv<'tcx>, |
| 201 | + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, |
| 202 | + discr_local: Local, |
| 203 | + discr_ty: Ty<'tcx>, |
| 204 | + ) -> Vec<Statement<'tcx>> { |
| 205 | + let (val, first) = targets.iter().next().unwrap(); |
| 206 | + let second = targets.otherwise(); |
| 207 | + // We already checked that first and second are different blocks, |
| 208 | + // and bb_idx has a different terminator from both of them. |
| 209 | + let first = &bbs[first]; |
| 210 | + let second = &bbs[second]; |
| 211 | + |
| 212 | + let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { |
| 213 | + match (&f.kind, &s.kind) { |
| 214 | + (f_s, s_s) if f_s == s_s => (*f).clone(), |
| 215 | + |
| 216 | + ( |
| 217 | + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), |
| 218 | + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), |
| 219 | + ) => { |
| 220 | + // From earlier loop we know that we are dealing with bool constants only: |
| 221 | + let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
| 222 | + let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); |
| 223 | + if f_b == s_b { |
| 224 | + // Same value in both blocks. Use statement as is. |
| 225 | + (*f).clone() |
| 226 | + } else { |
| 227 | + // Different value between blocks. Make value conditional on switch condition. |
| 228 | + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; |
| 229 | + let const_cmp = Operand::const_from_scalar( |
| 230 | + tcx, |
| 231 | + discr_ty, |
| 232 | + rustc_const_eval::interpret::Scalar::from_uint(val, size), |
| 233 | + rustc_span::DUMMY_SP, |
| 234 | + ); |
| 235 | + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; |
| 236 | + let rhs = Rvalue::BinaryOp( |
| 237 | + op, |
| 238 | + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), |
| 239 | + ); |
| 240 | + Statement { |
| 241 | + source_info: f.source_info, |
| 242 | + kind: StatementKind::Assign(Box::new((*lhs, rhs))), |
153 | 243 | } |
154 | 244 | } |
155 | | - |
156 | | - _ => unreachable!(), |
157 | 245 | } |
158 | | - }); |
159 | | - |
160 | | - from.statements |
161 | | - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); |
162 | | - from.statements.push(Statement { |
163 | | - source_info, |
164 | | - kind: StatementKind::Assign(Box::new(( |
165 | | - Place::from(discr_local), |
166 | | - Rvalue::Use(discr), |
167 | | - ))), |
168 | | - }); |
169 | | - from.statements.extend(new_stmts); |
170 | | - from.statements |
171 | | - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); |
172 | | - from.terminator_mut().kind = first.terminator().kind.clone(); |
173 | | - should_cleanup = true; |
174 | | - } |
175 | 246 |
|
176 | | - if should_cleanup { |
177 | | - simplify_cfg(body); |
178 | | - } |
| 247 | + _ => unreachable!(), |
| 248 | + } |
| 249 | + }); |
| 250 | + new_stmts.collect() |
179 | 251 | } |
180 | 252 | } |
0 commit comments