@@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3737 . find ( |inst| inst. class . opcode == Op :: TypeVoid )
3838 . map ( |inst| inst. result_id . unwrap ( ) )
3939 . unwrap_or ( 0 ) ;
40- let ptr_map: FxHashMap < _ , _ > = module
41- . types_global_values
42- . iter ( )
43- . filter_map ( |inst| {
44- if inst. class . opcode == Op :: TypePointer
45- && inst. operands [ 0 ] . unwrap_storage_class ( ) == StorageClass :: Function
46- {
47- Some ( ( inst. operands [ 1 ] . unwrap_id_ref ( ) , inst. result_id . unwrap ( ) ) )
48- } else {
49- None
50- }
51- } )
52- . collect ( ) ;
40+
41+ let invalid_args = module. functions . iter ( ) . flat_map ( get_invalid_args) . collect ( ) ;
42+
5343 // Drop all the functions we'll be inlining. (This also means we won't waste time processing
5444 // inlines in functions that will get inlined)
5545 let mut inliner = Inliner {
5646 header : module. header . as_mut ( ) . unwrap ( ) ,
57- types_global_values : & mut module. types_global_values ,
5847 void,
59- ptr_map,
6048 functions : & functions,
6149 needs_inline : & to_delete,
50+ invalid_args,
6251 } ;
6352 for index in postorder {
6453 inliner. inline_fn ( & mut module. functions , index) ;
@@ -270,20 +259,21 @@ fn should_inline(
270259// This should be more general, but a very common problem is passing an OpAccessChain to an
271260// OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect
272261// that case and inline the call.
273- fn args_invalid ( function : & Function , call : & Instruction ) -> bool {
274- for inst in function. all_inst_iter ( ) {
262+ fn get_invalid_args < ' a > ( function : & ' a Function ) -> impl Iterator < Item = Word > + ' a {
263+ function. all_inst_iter ( ) . filter_map ( |inst| {
275264 if inst. class . opcode == Op :: AccessChain {
276- let inst_result = inst. result_id . unwrap ( ) ;
277- if call
278- . operands
279- . iter ( )
280- . any ( |op| * op == Operand :: IdRef ( inst_result) )
281- {
282- return true ;
283- }
265+ inst. result_id
266+ } else {
267+ None
284268 }
285- }
286- false
269+ } )
270+ }
271+
272+ fn args_invalid ( invalid_args : & FxHashSet < Word > , call : & Instruction ) -> bool {
273+ call. operands . iter ( ) . skip ( 1 ) . any ( |op| {
274+ op. id_ref_any ( )
275+ . map_or ( false , |arg| invalid_args. contains ( & arg) )
276+ } )
287277}
288278
289279// Steps:
@@ -294,11 +284,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool {
294284
295285struct Inliner < ' m , ' map > {
296286 header : & ' m mut ModuleHeader ,
297- types_global_values : & ' m mut Vec < Instruction > ,
298287 void : Word ,
299- ptr_map : FxHashMap < Word , Word > ,
300288 functions : & ' map FunctionMap ,
301289 needs_inline : & ' map [ bool ] ,
290+ invalid_args : FxHashSet < Word > ,
302291}
303292
304293impl Inliner < ' _ , ' _ > {
@@ -308,25 +297,6 @@ impl Inliner<'_, '_> {
308297 result
309298 }
310299
311- fn ptr_ty ( & mut self , pointee : Word ) -> Word {
312- let existing = self . ptr_map . get ( & pointee) ;
313- if let Some ( existing) = existing {
314- return * existing;
315- }
316- let inst_id = self . id ( ) ;
317- self . types_global_values . push ( Instruction :: new (
318- Op :: TypePointer ,
319- None ,
320- Some ( inst_id) ,
321- vec ! [
322- Operand :: StorageClass ( StorageClass :: Function ) ,
323- Operand :: IdRef ( pointee) ,
324- ] ,
325- ) ) ;
326- self . ptr_map . insert ( pointee, inst_id) ;
327- inst_id
328- }
329-
330300 fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
331301 let mut function = take ( & mut functions[ index] ) ;
332302 let mut block_idx = 0 ;
@@ -361,8 +331,8 @@ impl Inliner<'_, '_> {
361331 self . functions [ & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ] ,
362332 )
363333 } )
364- . find ( |( index , inst, func_idx) | {
365- self . needs_inline [ * func_idx] || args_invalid ( caller , inst)
334+ . find ( |( _ , inst, func_idx) | {
335+ self . needs_inline [ * func_idx] || args_invalid ( & self . invalid_args , inst)
366336 } ) ;
367337 let ( call_index, call_inst, callee_idx) = match call {
368338 None => return false ,
@@ -390,18 +360,23 @@ impl Inliner<'_, '_> {
390360 } ) ;
391361 let mut rewrite_rules = callee_parameters. zip ( call_arguments) . collect ( ) ;
392362
393- let return_variable = if call_result_type. is_some ( ) {
394- Some ( self . id ( ) )
395- } else {
396- None
397- } ;
398363 let return_jump = self . id ( ) ;
399364 // Rewrite OpReturns of the callee.
400- let mut inlined_blocks = get_inlined_blocks ( callee, return_variable , return_jump) ;
365+ let ( mut inlined_blocks, phi_pairs ) = get_inlined_blocks ( callee, return_jump) ;
401366 // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
402367 // fn is inlined multiple times.
403368 self . add_clone_id_rules ( & mut rewrite_rules, & inlined_blocks) ;
369+ // If any of the OpReturns were invalid, return will also be invalid.
370+ for ( value, _) in & phi_pairs {
371+ if self . invalid_args . contains ( value) {
372+ self . invalid_args . insert ( call_result_id) ;
373+ self . invalid_args
374+ . insert ( * rewrite_rules. get ( value) . unwrap_or ( value) ) ;
375+ }
376+ }
404377 apply_rewrite_rules ( & rewrite_rules, & mut inlined_blocks) ;
378+ // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
379+ // as no values from inside the inlined function ever make it directly out.
405380
406381 // Split the block containing the OpFunctionCall into two, around the call.
407382 let mut post_call_block_insts = caller. blocks [ block_idx]
@@ -411,32 +386,27 @@ impl Inliner<'_, '_> {
411386 let call = caller. blocks [ block_idx] . instructions . pop ( ) . unwrap ( ) ;
412387 assert ! ( call. class. opcode == Op :: FunctionCall ) ;
413388
414- if let Some ( call_result_type) = call_result_type {
415- // Generate the storage space for the return value: Do this *after* the split above,
416- // because if block_idx=0, inserting a variable here shifts call_index.
417- insert_opvariable (
418- & mut caller. blocks [ 0 ] ,
419- self . ptr_ty ( call_result_type) ,
420- return_variable. unwrap ( ) ,
421- ) ;
422- }
423-
424389 // Move the variables over from the inlined function to here.
425390 let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
426391 // TODO: OpLine handling
427392 let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
428393 // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
429394 // it, and we maintain the invariant that current block has finished processing.
430- let jump_to = self . id ( ) ;
395+ let first_block_id = self . id ( ) ;
431396 inlined_blocks[ 0 ] = Block {
432- label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
397+ label : Some ( Instruction :: new (
398+ Op :: Label ,
399+ None ,
400+ Some ( first_block_id) ,
401+ vec ! [ ] ,
402+ ) ) ,
433403 instructions : callee_header. split_off ( num_variables) ,
434404 } ;
435405 caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
436406 Op :: Branch ,
437407 None ,
438408 None ,
439- vec ! [ Operand :: IdRef ( jump_to ) ] ,
409+ vec ! [ Operand :: IdRef ( first_block_id ) ] ,
440410 ) ) ;
441411 // Move the OpVariables of the callee to the caller.
442412 insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
@@ -447,10 +417,17 @@ impl Inliner<'_, '_> {
447417 post_call_block_insts. insert (
448418 0 ,
449419 Instruction :: new (
450- Op :: Load ,
420+ Op :: Phi ,
451421 Some ( call_result_type) ,
452422 Some ( call_result_id) ,
453- vec ! [ Operand :: IdRef ( return_variable. unwrap( ) ) ] ,
423+ phi_pairs
424+ . into_iter ( )
425+ . flat_map ( |( value, parent) | {
426+ use std:: iter;
427+ iter:: once ( Operand :: IdRef ( * rewrite_rules. get ( & value) . unwrap_or ( & value) ) )
428+ . chain ( iter:: once ( Operand :: IdRef ( rewrite_rules[ & parent] ) ) )
429+ } )
430+ . collect ( ) ,
454431 ) ,
455432 ) ;
456433 }
@@ -483,51 +460,21 @@ impl Inliner<'_, '_> {
483460 }
484461}
485462
486- fn get_inlined_blocks (
487- function : & Function ,
488- return_variable : Option < Word > ,
489- return_jump : Word ,
490- ) -> Vec < Block > {
463+ fn get_inlined_blocks ( function : & Function , return_jump : Word ) -> ( Vec < Block > , Vec < ( Word , Word ) > ) {
491464 let mut blocks = function. blocks . clone ( ) ;
465+ let mut phipairs = Vec :: new ( ) ;
492466 for block in & mut blocks {
493467 let last = block. instructions . last ( ) . unwrap ( ) ;
494468 if let Op :: Return | Op :: ReturnValue = last. class . opcode {
495469 if Op :: ReturnValue == last. class . opcode {
496470 let return_value = last. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ;
497- block. instructions . insert (
498- block. instructions . len ( ) - 1 ,
499- Instruction :: new (
500- Op :: Store ,
501- None ,
502- None ,
503- vec ! [
504- Operand :: IdRef ( return_variable. unwrap( ) ) ,
505- Operand :: IdRef ( return_value) ,
506- ] ,
507- ) ,
508- ) ;
509- } else {
510- assert ! ( return_variable. is_none( ) ) ;
471+ phipairs. push ( ( return_value, block. label_id ( ) . unwrap ( ) ) )
511472 }
512473 * block. instructions . last_mut ( ) . unwrap ( ) =
513474 Instruction :: new ( Op :: Branch , None , None , vec ! [ Operand :: IdRef ( return_jump) ] ) ;
514475 }
515476 }
516- blocks
517- }
518-
519- fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
520- let index = block
521- . instructions
522- . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
523-
524- let inst = Instruction :: new (
525- Op :: Variable ,
526- Some ( ptr_ty) ,
527- Some ( result_id) ,
528- vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
529- ) ;
530- block. instructions . insert ( index, inst)
477+ ( blocks, phipairs)
531478}
532479
533480fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
@@ -539,6 +486,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
539486
540487fn fuse_trivial_branches ( function : & mut Function ) {
541488 let all_preds = compute_preds ( & function. blocks ) ;
489+ let mut rewrite_rules = FxHashMap :: default ( ) ;
542490 ' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
543491 // if there's two trivial branches in a row, the middle one might get inlined before the
544492 // last one, so when processing the last one, skip through to the first one.
@@ -555,12 +503,22 @@ fn fuse_trivial_branches(function: &mut Function) {
555503 let pred_insts = & function. blocks [ pred] . instructions ;
556504 if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
557505 let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
506+ dest_insts. retain ( |inst| {
507+ if inst. class . opcode == Op :: Phi {
508+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
509+ rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
510+ false
511+ } else {
512+ true
513+ }
514+ } ) ;
558515 let pred_insts = & mut function. blocks [ pred] . instructions ;
559516 pred_insts. pop ( ) ; // pop the branch
560517 pred_insts. append ( & mut dest_insts) ;
561518 }
562519 }
563520 function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
521+ apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
564522}
565523
566524fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
0 commit comments