@@ -133,18 +133,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
133133
134134 let mut patch = MirPatch :: new ( body) ;
135135
136- // create temp to store second discriminant in, `_s` in example above
137- let second_discriminant_temp =
138- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
136+ let ( second_discriminant_temp, second_operand) = if opt_data. need_hoist_discriminant {
137+ // create temp to store second discriminant in, `_s` in example above
138+ let second_discriminant_temp =
139+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
139140
140- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
141+ patch. add_statement (
142+ parent_end,
143+ StatementKind :: StorageLive ( second_discriminant_temp) ,
144+ ) ;
141145
142- // create assignment of discriminant
143- patch. add_assign (
144- parent_end,
145- Place :: from ( second_discriminant_temp) ,
146- Rvalue :: Discriminant ( opt_data. child_place ) ,
147- ) ;
146+ // create assignment of discriminant
147+ patch. add_assign (
148+ parent_end,
149+ Place :: from ( second_discriminant_temp) ,
150+ Rvalue :: Discriminant ( opt_data. child_place ) ,
151+ ) ;
152+ (
153+ Some ( second_discriminant_temp) ,
154+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
155+ )
156+ } else {
157+ ( None , Operand :: Copy ( opt_data. child_place ) )
158+ } ;
148159
149160 // create temp to store inequality comparison between the two discriminants, `_t` in
150161 // example above
@@ -153,11 +164,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
153164 let comp_temp = patch. new_temp ( comp_res_type, opt_data. child_source . span ) ;
154165 patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
155166
156- // create inequality comparison between the two discriminants
157- let comp_rvalue = Rvalue :: BinaryOp (
158- nequal,
159- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
160- ) ;
167+ // create inequality comparison
168+ let comp_rvalue =
169+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
161170 patch. add_statement (
162171 parent_end,
163172 StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -193,8 +202,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
193202 TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
194203 ) ;
195204
196- // generate StorageDead for the second_discriminant_temp not in use anymore
197- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
205+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
206+ // generate StorageDead for the second_discriminant_temp not in use anymore
207+ patch. add_statement (
208+ parent_end,
209+ StatementKind :: StorageDead ( second_discriminant_temp) ,
210+ ) ;
211+ }
198212
199213 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
200214 // the switch
@@ -222,6 +236,7 @@ struct OptimizationData<'tcx> {
222236 child_place : Place < ' tcx > ,
223237 child_ty : Ty < ' tcx > ,
224238 child_source : SourceInfo ,
239+ need_hoist_discriminant : bool ,
225240}
226241
227242fn evaluate_candidate < ' tcx > (
@@ -235,70 +250,128 @@ fn evaluate_candidate<'tcx>(
235250 return None ;
236251 } ;
237252 let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
238- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
239- // Someone could write code like this:
240- // ```rust
241- // let Q = val;
242- // if discriminant(P) == otherwise {
243- // let ptr = &mut Q as *mut _ as *mut u8;
244- // // It may be difficult for us to effectively determine whether values are valid.
245- // // Invalid values can come from all sorts of corners.
246- // unsafe { *ptr = 10; }
247- // }
248- //
249- // match P {
250- // A => match Q {
251- // A => {
252- // // code
253- // }
254- // _ => {
255- // // don't use Q
256- // }
257- // }
258- // _ => {
259- // // don't use Q
260- // }
261- // };
262- // ```
263- //
264- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
265- // of an invalid value, which is UB.
266- // In order to fix this, **we would either need to show that the discriminant computation of
267- // `place` is computed in all branches**.
268- // FIXME(#95162) For the moment, we adopt a conservative approach and
269- // consider only the `otherwise` branch has no statements and an unreachable terminator.
270- return None ;
271- }
272253 let ( _, child) = targets. iter ( ) . next ( ) ?;
273- let child_terminator = & bbs[ child] . terminator ( ) ;
274- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
275- & child_terminator. kind
254+
255+ let Terminator {
256+ kind : TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } ,
257+ source_info,
258+ } = bbs[ child] . terminator ( )
276259 else {
277260 return None ;
278261 } ;
279262 let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
280263 if child_ty != parent_ty {
281264 return None ;
282265 }
283- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind ) else {
266+
267+ // We only handle:
268+ // ```
269+ // bb4: {
270+ // _8 = discriminant((_3.1: Enum1));
271+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
272+ // }
273+ // ```
274+ // and
275+ // ```
276+ // bb2: {
277+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
278+ // }
279+ // ```
280+ if bbs[ child] . statements . len ( ) > 1 {
284281 return None ;
282+ }
283+
284+ // When thie BB has exactly one statement, this statement should be discriminant.
285+ let need_hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
286+ let child_place = if need_hoist_discriminant {
287+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
288+ // Someone could write code like this:
289+ // ```rust
290+ // let Q = val;
291+ // if discriminant(P) == otherwise {
292+ // let ptr = &mut Q as *mut _ as *mut u8;
293+ // // It may be difficult for us to effectively determine whether values are valid.
294+ // // Invalid values can come from all sorts of corners.
295+ // unsafe { *ptr = 10; }
296+ // }
297+ //
298+ // match P {
299+ // A => match Q {
300+ // A => {
301+ // // code
302+ // }
303+ // _ => {
304+ // // don't use Q
305+ // }
306+ // }
307+ // _ => {
308+ // // don't use Q
309+ // }
310+ // };
311+ // ```
312+ //
313+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
314+ // invalid value, which is UB.
315+ // In order to fix this, **we would either need to show that the discriminant computation of
316+ // `place` is computed in all branches**.
317+ // FIXME(#95162) For the moment, we adopt a conservative approach and
318+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
319+ return None ;
320+ }
321+ // Handle:
322+ // ```
323+ // bb4: {
324+ // _8 = discriminant((_3.1: Enum1));
325+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
326+ // }
327+ // ```
328+ let [
329+ Statement {
330+ kind : StatementKind :: Assign ( box ( _, Rvalue :: Discriminant ( child_place) ) ) ,
331+ ..
332+ } ,
333+ ] = bbs[ child] . statements . as_slice ( )
334+ else {
335+ return None ;
336+ } ;
337+ * child_place
338+ } else {
339+ // Handle:
340+ // ```
341+ // bb2: {
342+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
343+ // }
344+ // ```
345+ let Operand :: Copy ( child_place) = child_discr else {
346+ return None ;
347+ } ;
348+ * child_place
285349 } ;
286- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
287- return None ;
350+ let destination = if need_hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( )
351+ {
352+ child_targets. otherwise ( )
353+ } else {
354+ targets. otherwise ( )
288355 } ;
289- let destination = child_targets. otherwise ( ) ;
290356
291357 // Verify that the optimization is legal for each branch
292358 for ( value, child) in targets. iter ( ) {
293- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
359+ if !verify_candidate_branch (
360+ & bbs[ child] ,
361+ value,
362+ child_place,
363+ destination,
364+ need_hoist_discriminant,
365+ ) {
294366 return None ;
295367 }
296368 }
297369 Some ( OptimizationData {
298370 destination,
299- child_place : * child_place ,
371+ child_place,
300372 child_ty,
301- child_source : child_terminator. source_info ,
373+ child_source : * source_info,
374+ need_hoist_discriminant,
302375 } )
303376}
304377
@@ -307,31 +380,48 @@ fn verify_candidate_branch<'tcx>(
307380 value : u128 ,
308381 place : Place < ' tcx > ,
309382 destination : BasicBlock ,
383+ need_hoist_discriminant : bool ,
310384) -> bool {
311- // In order for the optimization to be correct, the branch must...
312- // ...have exactly one statement
313- if let [ statement] = branch. statements . as_slice ( )
314- // ...assign the discriminant of `place` in that statement
315- && let StatementKind :: Assign ( boxed) = & statement. kind
316- && let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed
317- && * from_place == place
318- // ...make that assignment to a local
319- && discr_place. projection . is_empty ( )
320- // ...terminate on a `SwitchInt` that invalidates that local
321- && let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } =
322- & branch. terminator ( ) . kind
323- && * switch_op == Operand :: Move ( * discr_place)
324- // ...fall through to `destination` if the switch misses
325- && destination == targets. otherwise ( )
326- // ...have a branch for value `value`
327- && let mut iter = targets. iter ( )
328- && let Some ( ( target_value, _) ) = iter. next ( )
329- && target_value == value
330- // ...and have no more branches
331- && iter. next ( ) . is_none ( )
332- {
333- true
385+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
386+ let TerminatorKind :: SwitchInt { discr : switch_op, targets } = & branch. terminator ( ) . kind else {
387+ return false ;
388+ } ;
389+ if need_hoist_discriminant {
390+ // If we need hoist discriminant, the branch must have exactly one statement.
391+ let [ statement] = branch. statements . as_slice ( ) else {
392+ return false ;
393+ } ;
394+ // The statement must assign the discriminant of `place`.
395+ let StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( from_place) ) ) =
396+ statement. kind
397+ else {
398+ return false ;
399+ } ;
400+ if from_place != place {
401+ return false ;
402+ }
403+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
404+ if !discr_place. projection . is_empty ( ) || * switch_op != Operand :: Move ( discr_place) {
405+ return false ;
406+ }
334407 } else {
335- false
408+ // If we don't need hoist discriminant, the branch must not have any statements.
409+ if !branch. statements . is_empty ( ) {
410+ return false ;
411+ }
412+ // The place on `SwitchInt` must be the same.
413+ if * switch_op != Operand :: Copy ( place) {
414+ return false ;
415+ }
336416 }
417+ // It must fall through to `destination` if the switch misses.
418+ if destination != targets. otherwise ( ) {
419+ return false ;
420+ }
421+ // It must have exactly one branch for value `value` and have no more branches.
422+ let mut iter = targets. iter ( ) ;
423+ let ( Some ( ( target_value, _) ) , None ) = ( iter. next ( ) , iter. next ( ) ) else {
424+ return false ;
425+ } ;
426+ target_value == value
337427}
0 commit comments