@@ -135,18 +135,29 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
135135
136136 let mut patch = MirPatch :: new ( body) ;
137137
138- // create temp to store second discriminant in, `_s` in example above
139- let second_discriminant_temp =
140- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
138+ let ( second_discriminant_temp, second_operand) = if opt_data. hoist_discriminant {
139+ // create temp to store second discriminant in, `_s` in example above
140+ let second_discriminant_temp =
141+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
141142
142- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
143+ patch. add_statement (
144+ parent_end,
145+ StatementKind :: StorageLive ( second_discriminant_temp) ,
146+ ) ;
143147
144- // create assignment of discriminant
145- patch. add_assign (
146- parent_end,
147- Place :: from ( second_discriminant_temp) ,
148- Rvalue :: Discriminant ( opt_data. child_place ) ,
149- ) ;
148+ // create assignment of discriminant
149+ patch. add_assign (
150+ parent_end,
151+ Place :: from ( second_discriminant_temp) ,
152+ Rvalue :: Discriminant ( opt_data. child_place ) ,
153+ ) ;
154+ (
155+ Some ( second_discriminant_temp) ,
156+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
157+ )
158+ } else {
159+ ( None , Operand :: Copy ( opt_data. child_place ) )
160+ } ;
150161
151162 // create temp to store inequality comparison between the two discriminants, `_t` in
152163 // example above
@@ -156,10 +167,8 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
156167 patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
157168
158169 // create inequality comparison between the two discriminants
159- let comp_rvalue = Rvalue :: BinaryOp (
160- nequal,
161- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
162- ) ;
170+ let comp_rvalue =
171+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
163172 patch. add_statement (
164173 parent_end,
165174 StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -194,8 +203,13 @@ impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
194203 TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
195204 ) ;
196205
197- // generate StorageDead for the second_discriminant_temp not in use anymore
198- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
206+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
207+ // generate StorageDead for the second_discriminant_temp not in use anymore
208+ patch. add_statement (
209+ parent_end,
210+ StatementKind :: StorageDead ( second_discriminant_temp) ,
211+ ) ;
212+ }
199213
200214 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
201215 // the switch
@@ -271,6 +285,7 @@ struct OptimizationData<'tcx> {
271285 child_place : Place < ' tcx > ,
272286 child_ty : Ty < ' tcx > ,
273287 child_source : SourceInfo ,
288+ hoist_discriminant : bool ,
274289}
275290
276291fn evaluate_candidate < ' tcx > (
@@ -284,38 +299,6 @@ fn evaluate_candidate<'tcx>(
284299 return None ;
285300 } ;
286301 let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
287- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
288- // Someone could write code like this:
289- // ```rust
290- // let Q = val;
291- // if discriminant(P) == otherwise {
292- // let ptr = &mut Q as *mut _ as *mut u8;
293- // // Any invalid value for the type. It is possible to be opaque, such as in other functions.
294- // unsafe { *ptr = 10; }
295- // }
296- //
297- // match P {
298- // A => match Q {
299- // A => {
300- // // code
301- // }
302- // _ => {
303- // // don't use Q
304- // }
305- // }
306- // _ => {
307- // // don't use Q
308- // }
309- // };
310- // ```
311- //
312- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
313- // invalid value, which is UB.
314- // In order to fix this, we would either need to show that the discriminant computation of
315- // `place` is computed in all branches.
316- // So we need the `otherwise` branch has no statements and an unreachable terminator.
317- return None ;
318- }
319302 let ( _, child) = targets. iter ( ) . next ( ) ?;
320303 let child_terminator = & bbs[ child] . terminator ( ) ;
321304 let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
@@ -327,31 +310,89 @@ fn evaluate_candidate<'tcx>(
327310 if child_ty != parent_ty {
328311 return None ;
329312 }
330- let Some ( StatementKind :: Assign ( boxed ) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x . kind ) else {
313+ if bbs[ child] . statements . len ( ) > 1 {
331314 return None ;
315+ }
316+ let hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
317+ let child_place = if hoist_discriminant {
318+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
319+ // Someone could write code like this:
320+ // ```rust
321+ // let Q = val;
322+ // if discriminant(P) == otherwise {
323+ // let ptr = &mut Q as *mut _ as *mut u8;
324+ // // Any invalid value for the type. It is possible to be opaque, such as in other functions.
325+ // unsafe { *ptr = 10; }
326+ // }
327+ //
328+ // match P {
329+ // A => match Q {
330+ // A => {
331+ // // code
332+ // }
333+ // _ => {
334+ // // don't use Q
335+ // }
336+ // }
337+ // _ => {
338+ // // don't use Q
339+ // }
340+ // };
341+ // ```
342+ //
343+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
344+ // invalid value, which is UB.
345+ // In order to fix this, we would either need to show that the discriminant computation of
346+ // `place` is computed in all branches.
347+ // So we need the `otherwise` branch has no statements and an unreachable terminator.
348+ return None ;
349+ }
350+ let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind )
351+ else {
352+ return None ;
353+ } ;
354+ let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
355+ return None ;
356+ } ;
357+ // Verify that the optimization is legal in general
358+ // We can hoist evaluating the child discriminant out of the branch
359+ if !may_hoist ( tcx, body, * child_place) {
360+ return None ;
361+ }
362+ * child_place
363+ } else {
364+ let TerminatorKind :: SwitchInt { discr, .. } = & bbs[ child] . terminator ( ) . kind else {
365+ return None ;
366+ } ;
367+ let Operand :: Copy ( child_place) = discr else {
368+ return None ;
369+ } ;
370+ * child_place
332371 } ;
333- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
334- return None ;
372+ let destination = if hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
373+ child_targets. otherwise ( )
374+ } else {
375+ targets. otherwise ( )
335376 } ;
336- let destination = child_targets. otherwise ( ) ;
337-
338- // Verify that the optimization is legal in general
339- // We can hoist evaluating the child discriminant out of the branch
340- if !may_hoist ( tcx, body, * child_place) {
341- return None ;
342- }
343377
344378 // Verify that the optimization is legal for each branch
345379 for ( value, child) in targets. iter ( ) {
346- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
380+ if !verify_candidate_branch (
381+ & bbs[ child] ,
382+ value,
383+ child_place,
384+ destination,
385+ hoist_discriminant,
386+ ) {
347387 return None ;
348388 }
349389 }
350390 Some ( OptimizationData {
351391 destination,
352- child_place : * child_place ,
392+ child_place,
353393 child_ty,
354394 child_source : child_terminator. source_info ,
395+ hoist_discriminant,
355396 } )
356397}
357398
@@ -360,29 +401,38 @@ fn verify_candidate_branch<'tcx>(
360401 value : u128 ,
361402 place : Place < ' tcx > ,
362403 destination : BasicBlock ,
404+ hoist_discriminant : bool ,
363405) -> bool {
364406 // In order for the optimization to be correct, the branch must...
365407 // ...have exactly one statement
366- if branch. statements . len ( ) != 1 {
367- return false ;
368- }
369- // ...assign the discriminant of `place` in that statement
370- let StatementKind :: Assign ( boxed) = & branch. statements [ 0 ] . kind else { return false } ;
371- let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed else { return false } ;
372- if * from_place != place {
373- return false ;
374- }
375- // ...make that assignment to a local
376- if discr_place. projection . len ( ) != 0 {
408+ if ( hoist_discriminant && branch. statements . len ( ) != 1 )
409+ || ( !hoist_discriminant && !branch. statements . is_empty ( ) )
410+ {
377411 return false ;
378412 }
379413 // ...terminate on a `SwitchInt` that invalidates that local
380414 let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } = & branch. terminator ( ) . kind
381415 else {
382416 return false ;
383417 } ;
384- if * switch_op != Operand :: Move ( * discr_place) {
385- return false ;
418+ if hoist_discriminant {
419+ // ...assign the discriminant of `place` in that statement
420+ let StatementKind :: Assign ( boxed) = & branch. statements [ 0 ] . kind else { return false } ;
421+ let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed else { return false } ;
422+ if * from_place != place {
423+ return false ;
424+ }
425+ // ...make that assignment to a local
426+ if discr_place. projection . len ( ) != 0 {
427+ return false ;
428+ }
429+ if * switch_op != Operand :: Move ( * discr_place) {
430+ return false ;
431+ }
432+ } else {
433+ if * switch_op != Operand :: Copy ( place) {
434+ return false ;
435+ }
386436 }
387437 // ...fall through to `destination` if the switch misses
388438 if destination != targets. otherwise ( ) {
@@ -397,7 +447,7 @@ fn verify_candidate_branch<'tcx>(
397447 return false ;
398448 }
399449 // ...and have no more branches
400- if let Some ( _ ) = iter. next ( ) {
450+ if iter. next ( ) . is_some ( ) {
401451 return false ;
402452 }
403453 true
0 commit comments