@@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3737 . iter ( )
3838 . find ( |inst| inst. class . opcode == Op :: TypeVoid )
3939 . map_or ( 0 , |inst| inst. result_id . unwrap ( ) ) ;
40- let ptr_map: FxHashMap < _ , _ > = module
41- . types_global_values
42- . iter ( )
43- . filter_map ( |inst| {
44- if inst. class . opcode == Op :: TypePointer
45- && inst. operands [ 0 ] . unwrap_storage_class ( ) == StorageClass :: Function
46- {
47- Some ( ( inst. operands [ 1 ] . unwrap_id_ref ( ) , inst. result_id . unwrap ( ) ) )
48- } else {
49- None
50- }
51- } )
52- . collect ( ) ;
40+
41+ let invalid_args = module. functions . iter ( ) . flat_map ( get_invalid_args) . collect ( ) ;
42+
5343 // Drop all the functions we'll be inlining. (This also means we won't waste time processing
5444 // inlines in functions that will get inlined)
5545 let mut inliner = Inliner {
5646 header : module. header . as_mut ( ) . unwrap ( ) ,
57- types_global_values : & mut module. types_global_values ,
5847 void,
59- ptr_map,
6048 functions : & functions,
6149 needs_inline : & to_delete,
50+ invalid_args,
6251 } ;
6352 for index in postorder {
6453 inliner. inline_fn ( & mut module. functions , index) ;
@@ -268,20 +257,21 @@ fn should_inline(
268257// This should be more general, but a very common problem is passing an OpAccessChain to an
269258// OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect
270259// that case and inline the call.
271- fn args_invalid ( function : & Function , call : & Instruction ) -> bool {
272- for inst in function. all_inst_iter ( ) {
260+ fn get_invalid_args < ' a > ( function : & ' a Function ) -> impl Iterator < Item = Word > + ' a {
261+ function. all_inst_iter ( ) . filter_map ( |inst| {
273262 if inst. class . opcode == Op :: AccessChain {
274- let inst_result = inst. result_id . unwrap ( ) ;
275- if call
276- . operands
277- . iter ( )
278- . any ( |op| * op == Operand :: IdRef ( inst_result) )
279- {
280- return true ;
281- }
263+ inst. result_id
264+ } else {
265+ None
282266 }
283- }
284- false
267+ } )
268+ }
269+
270+ fn args_invalid ( invalid_args : & FxHashSet < Word > , call : & Instruction ) -> bool {
271+ call. operands . iter ( ) . skip ( 1 ) . any ( |op| {
272+ op. id_ref_any ( )
273+ . map_or ( false , |arg| invalid_args. contains ( & arg) )
274+ } )
285275}
286276
287277// Steps:
@@ -292,11 +282,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool {
292282
293283struct Inliner < ' m , ' map > {
294284 header : & ' m mut ModuleHeader ,
295- types_global_values : & ' m mut Vec < Instruction > ,
296285 void : Word ,
297- ptr_map : FxHashMap < Word , Word > ,
298286 functions : & ' map FunctionMap ,
299287 needs_inline : & ' map [ bool ] ,
288+ invalid_args : FxHashSet < Word > ,
300289}
301290
302291impl Inliner < ' _ , ' _ > {
@@ -306,25 +295,6 @@ impl Inliner<'_, '_> {
306295 result
307296 }
308297
309- fn ptr_ty ( & mut self , pointee : Word ) -> Word {
310- let existing = self . ptr_map . get ( & pointee) ;
311- if let Some ( existing) = existing {
312- return * existing;
313- }
314- let inst_id = self . id ( ) ;
315- self . types_global_values . push ( Instruction :: new (
316- Op :: TypePointer ,
317- None ,
318- Some ( inst_id) ,
319- vec ! [
320- Operand :: StorageClass ( StorageClass :: Function ) ,
321- Operand :: IdRef ( pointee) ,
322- ] ,
323- ) ) ;
324- self . ptr_map . insert ( pointee, inst_id) ;
325- inst_id
326- }
327-
328298 fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
329299 let mut function = take ( & mut functions[ index] ) ;
330300 let mut block_idx = 0 ;
@@ -359,8 +329,8 @@ impl Inliner<'_, '_> {
359329 self . functions [ & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ] ,
360330 )
361331 } )
362- . find ( |( index , inst, func_idx) | {
363- self . needs_inline [ * func_idx] || args_invalid ( caller , inst)
332+ . find ( |( _ , inst, func_idx) | {
333+ self . needs_inline [ * func_idx] || args_invalid ( & self . invalid_args , inst)
364334 } ) ;
365335 let ( call_index, call_inst, callee_idx) = match call {
366336 None => return false ,
@@ -388,18 +358,23 @@ impl Inliner<'_, '_> {
388358 } ) ;
389359 let mut rewrite_rules = callee_parameters. zip ( call_arguments) . collect ( ) ;
390360
391- let return_variable = if call_result_type. is_some ( ) {
392- Some ( self . id ( ) )
393- } else {
394- None
395- } ;
396361 let return_jump = self . id ( ) ;
397362 // Rewrite OpReturns of the callee.
398- let mut inlined_blocks = get_inlined_blocks ( callee, return_variable , return_jump) ;
363+ let ( mut inlined_blocks, phi_pairs ) = get_inlined_blocks ( callee, return_jump) ;
399364 // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
400365 // fn is inlined multiple times.
401366 self . add_clone_id_rules ( & mut rewrite_rules, & inlined_blocks) ;
367+ // If any of the OpReturns were invalid, return will also be invalid.
368+ for ( value, _) in & phi_pairs {
369+ if self . invalid_args . contains ( value) {
370+ self . invalid_args . insert ( call_result_id) ;
371+ self . invalid_args
372+ . insert ( * rewrite_rules. get ( value) . unwrap_or ( value) ) ;
373+ }
374+ }
402375 apply_rewrite_rules ( & rewrite_rules, & mut inlined_blocks) ;
376+ // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
377+ // as no values from inside the inlined function ever make it directly out.
403378
404379 // Split the block containing the OpFunctionCall into two, around the call.
405380 let mut post_call_block_insts = caller. blocks [ block_idx]
@@ -409,32 +384,27 @@ impl Inliner<'_, '_> {
409384 let call = caller. blocks [ block_idx] . instructions . pop ( ) . unwrap ( ) ;
410385 assert ! ( call. class. opcode == Op :: FunctionCall ) ;
411386
412- if let Some ( call_result_type) = call_result_type {
413- // Generate the storage space for the return value: Do this *after* the split above,
414- // because if block_idx=0, inserting a variable here shifts call_index.
415- insert_opvariable (
416- & mut caller. blocks [ 0 ] ,
417- self . ptr_ty ( call_result_type) ,
418- return_variable. unwrap ( ) ,
419- ) ;
420- }
421-
422387 // Move the variables over from the inlined function to here.
423388 let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
424389 // TODO: OpLine handling
425390 let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
426391 // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
427392 // it, and we maintain the invariant that current block has finished processing.
428- let jump_to = self . id ( ) ;
393+ let first_block_id = self . id ( ) ;
429394 inlined_blocks[ 0 ] = Block {
430- label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
395+ label : Some ( Instruction :: new (
396+ Op :: Label ,
397+ None ,
398+ Some ( first_block_id) ,
399+ vec ! [ ] ,
400+ ) ) ,
431401 instructions : callee_header. split_off ( num_variables) ,
432402 } ;
433403 caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
434404 Op :: Branch ,
435405 None ,
436406 None ,
437- vec ! [ Operand :: IdRef ( jump_to ) ] ,
407+ vec ! [ Operand :: IdRef ( first_block_id ) ] ,
438408 ) ) ;
439409 // Move the OpVariables of the callee to the caller.
440410 insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
@@ -445,10 +415,17 @@ impl Inliner<'_, '_> {
445415 post_call_block_insts. insert (
446416 0 ,
447417 Instruction :: new (
448- Op :: Load ,
418+ Op :: Phi ,
449419 Some ( call_result_type) ,
450420 Some ( call_result_id) ,
451- vec ! [ Operand :: IdRef ( return_variable. unwrap( ) ) ] ,
421+ phi_pairs
422+ . into_iter ( )
423+ . flat_map ( |( value, parent) | {
424+ use std:: iter;
425+ iter:: once ( Operand :: IdRef ( * rewrite_rules. get ( & value) . unwrap_or ( & value) ) )
426+ . chain ( iter:: once ( Operand :: IdRef ( rewrite_rules[ & parent] ) ) )
427+ } )
428+ . collect ( ) ,
452429 ) ,
453430 ) ;
454431 }
@@ -481,51 +458,21 @@ impl Inliner<'_, '_> {
481458 }
482459}
483460
484- fn get_inlined_blocks (
485- function : & Function ,
486- return_variable : Option < Word > ,
487- return_jump : Word ,
488- ) -> Vec < Block > {
461+ fn get_inlined_blocks ( function : & Function , return_jump : Word ) -> ( Vec < Block > , Vec < ( Word , Word ) > ) {
489462 let mut blocks = function. blocks . clone ( ) ;
463+ let mut phipairs = Vec :: new ( ) ;
490464 for block in & mut blocks {
491465 let last = block. instructions . last ( ) . unwrap ( ) ;
492466 if let Op :: Return | Op :: ReturnValue = last. class . opcode {
493467 if Op :: ReturnValue == last. class . opcode {
494468 let return_value = last. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ;
495- block. instructions . insert (
496- block. instructions . len ( ) - 1 ,
497- Instruction :: new (
498- Op :: Store ,
499- None ,
500- None ,
501- vec ! [
502- Operand :: IdRef ( return_variable. unwrap( ) ) ,
503- Operand :: IdRef ( return_value) ,
504- ] ,
505- ) ,
506- ) ;
507- } else {
508- assert ! ( return_variable. is_none( ) ) ;
469+ phipairs. push ( ( return_value, block. label_id ( ) . unwrap ( ) ) )
509470 }
510471 * block. instructions . last_mut ( ) . unwrap ( ) =
511472 Instruction :: new ( Op :: Branch , None , None , vec ! [ Operand :: IdRef ( return_jump) ] ) ;
512473 }
513474 }
514- blocks
515- }
516-
517- fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
518- let index = block
519- . instructions
520- . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
521-
522- let inst = Instruction :: new (
523- Op :: Variable ,
524- Some ( ptr_ty) ,
525- Some ( result_id) ,
526- vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
527- ) ;
528- block. instructions . insert ( index, inst)
475+ ( blocks, phipairs)
529476}
530477
531478fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
@@ -537,6 +484,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
537484
538485fn fuse_trivial_branches ( function : & mut Function ) {
539486 let all_preds = compute_preds ( & function. blocks ) ;
487+ let mut rewrite_rules = FxHashMap :: default ( ) ;
540488 ' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
541489 // if there's two trivial branches in a row, the middle one might get inlined before the
542490 // last one, so when processing the last one, skip through to the first one.
@@ -553,12 +501,22 @@ fn fuse_trivial_branches(function: &mut Function) {
553501 let pred_insts = & function. blocks [ pred] . instructions ;
554502 if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
555503 let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
504+ dest_insts. retain ( |inst| {
505+ if inst. class . opcode == Op :: Phi {
506+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
507+ rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
508+ false
509+ } else {
510+ true
511+ }
512+ } ) ;
556513 let pred_insts = & mut function. blocks [ pred] . instructions ;
557514 pred_insts. pop ( ) ; // pop the branch
558515 pred_insts. append ( & mut dest_insts) ;
559516 }
560517 }
561518 function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
519+ apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
562520}
563521
564522fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
0 commit comments