@@ -37,14 +37,25 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3737 . iter ( )
3838 . find ( |inst| inst. class . opcode == Op :: TypeVoid )
3939 . map_or ( 0 , |inst| inst. result_id . unwrap ( ) ) ;
40-
40+ let ptr_map: FxHashMap < _ , _ > = module
41+ . types_global_values
42+ . iter ( )
43+ . filter_map ( |inst| {
44+ if inst. class . opcode == Op :: TypePointer
45+ && inst. operands [ 0 ] . unwrap_storage_class ( ) == StorageClass :: Function
46+ {
47+ Some ( ( inst. operands [ 1 ] . unwrap_id_ref ( ) , inst. result_id . unwrap ( ) ) )
48+ } else {
49+ None
50+ }
51+ } )
52+ . collect ( ) ;
4153 let invalid_args = module. functions . iter ( ) . flat_map ( get_invalid_args) . collect ( ) ;
42-
43- // Drop all the functions we'll be inlining. (This also means we won't waste time processing
44- // inlines in functions that will get inlined)
4554 let mut inliner = Inliner {
4655 header : module. header . as_mut ( ) . unwrap ( ) ,
56+ types_global_values : & mut module. types_global_values ,
4757 void,
58+ ptr_map,
4859 functions : & functions,
4960 needs_inline : & to_delete,
5061 invalid_args,
@@ -270,7 +281,9 @@ fn args_invalid(invalid_args: &FxHashSet<Word>, call: &Instruction) -> bool {
270281
271282struct Inliner < ' m , ' map > {
272283 header : & ' m mut ModuleHeader ,
284+ types_global_values : & ' m mut Vec < Instruction > ,
273285 void : Word ,
286+ ptr_map : FxHashMap < Word , Word > ,
274287 functions : & ' map FunctionMap ,
275288 needs_inline : & ' map [ bool ] ,
276289 invalid_args : FxHashSet < Word > ,
@@ -283,6 +296,25 @@ impl Inliner<'_, '_> {
283296 result
284297 }
285298
299+ fn ptr_ty ( & mut self , pointee : Word ) -> Word {
300+ let existing = self . ptr_map . get ( & pointee) ;
301+ if let Some ( existing) = existing {
302+ return * existing;
303+ }
304+ let inst_id = self . id ( ) ;
305+ self . types_global_values . push ( Instruction :: new (
306+ Op :: TypePointer ,
307+ None ,
308+ Some ( inst_id) ,
309+ vec ! [
310+ Operand :: StorageClass ( StorageClass :: Function ) ,
311+ Operand :: IdRef ( pointee) ,
312+ ] ,
313+ ) ) ;
314+ self . ptr_map . insert ( pointee, inst_id) ;
315+ inst_id
316+ }
317+
286318 fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
287319 let mut function = take ( & mut functions[ index] ) ;
288320 let mut block_idx = 0 ;
@@ -346,23 +378,27 @@ impl Inliner<'_, '_> {
346378 } ) ;
347379 let mut rewrite_rules = callee_parameters. zip ( call_arguments) . collect ( ) ;
348380
381+ let return_variable = if call_result_type. is_some ( ) {
382+ Some ( self . id ( ) )
383+ } else {
384+ None
385+ } ;
349386 let return_jump = self . id ( ) ;
350387 // Rewrite OpReturns of the callee.
351- let ( mut inlined_blocks, phi_pairs) = get_inlined_blocks ( callee, return_jump) ;
388+ let ( mut inlined_blocks, return_values) =
389+ get_inlined_blocks ( callee, return_variable, return_jump) ;
352390 // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
353391 // fn is inlined multiple times.
354392 self . add_clone_id_rules ( & mut rewrite_rules, & inlined_blocks) ;
355393 // If any of the OpReturns were invalid, return will also be invalid.
356- for ( value, _ ) in & phi_pairs {
394+ for value in & return_values {
357395 if self . invalid_args . contains ( value) {
358396 self . invalid_args . insert ( call_result_id) ;
359397 self . invalid_args
360398 . insert ( * rewrite_rules. get ( value) . unwrap_or ( value) ) ;
361399 }
362400 }
363401 apply_rewrite_rules ( & rewrite_rules, & mut inlined_blocks) ;
364- // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
365- // as no values from inside the inlined function ever make it directly out.
366402
367403 // Split the block containing the OpFunctionCall into two, around the call.
368404 let mut post_call_block_insts = caller. blocks [ block_idx]
@@ -372,27 +408,32 @@ impl Inliner<'_, '_> {
372408 let call = caller. blocks [ block_idx] . instructions . pop ( ) . unwrap ( ) ;
373409 assert ! ( call. class. opcode == Op :: FunctionCall ) ;
374410
411+ if let Some ( call_result_type) = call_result_type {
412+ // Generate the storage space for the return value: Do this *after* the split above,
413+ // because if block_idx=0, inserting a variable here shifts call_index.
414+ insert_opvariable (
415+ & mut caller. blocks [ 0 ] ,
416+ self . ptr_ty ( call_result_type) ,
417+ return_variable. unwrap ( ) ,
418+ ) ;
419+ }
420+
375421 // Move the variables over from the inlined function to here.
376422 let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
377423 // TODO: OpLine handling
378424 let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
379425 // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
380426 // it, and we maintain the invariant that current block has finished processing.
381- let first_block_id = self . id ( ) ;
427+ let jump_to = self . id ( ) ;
382428 inlined_blocks[ 0 ] = Block {
383- label : Some ( Instruction :: new (
384- Op :: Label ,
385- None ,
386- Some ( first_block_id) ,
387- vec ! [ ] ,
388- ) ) ,
429+ label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
389430 instructions : callee_header. split_off ( num_variables) ,
390431 } ;
391432 caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
392433 Op :: Branch ,
393434 None ,
394435 None ,
395- vec ! [ Operand :: IdRef ( first_block_id ) ] ,
436+ vec ! [ Operand :: IdRef ( jump_to ) ] ,
396437 ) ) ;
397438 // Move the OpVariables of the callee to the caller.
398439 insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
@@ -403,17 +444,10 @@ impl Inliner<'_, '_> {
403444 post_call_block_insts. insert (
404445 0 ,
405446 Instruction :: new (
406- Op :: Phi ,
447+ Op :: Load ,
407448 Some ( call_result_type) ,
408449 Some ( call_result_id) ,
409- phi_pairs
410- . into_iter ( )
411- . flat_map ( |( value, parent) | {
412- use std:: iter;
413- iter:: once ( Operand :: IdRef ( * rewrite_rules. get ( & value) . unwrap_or ( & value) ) )
414- . chain ( iter:: once ( Operand :: IdRef ( rewrite_rules[ & parent] ) ) )
415- } )
416- . collect ( ) ,
450+ vec ! [ Operand :: IdRef ( return_variable. unwrap( ) ) ] ,
417451 ) ,
418452 ) ;
419453 }
@@ -446,21 +480,53 @@ impl Inliner<'_, '_> {
446480 }
447481}
448482
449- fn get_inlined_blocks ( function : & Function , return_jump : Word ) -> ( Vec < Block > , Vec < ( Word , Word ) > ) {
483+ fn get_inlined_blocks (
484+ function : & Function ,
485+ return_variable : Option < Word > ,
486+ return_jump : Word ,
487+ ) -> ( Vec < Block > , Vec < Word > ) {
450488 let mut blocks = function. blocks . clone ( ) ;
451- let mut phipairs = Vec :: new ( ) ;
489+ let mut values = Vec :: new ( ) ;
452490 for block in & mut blocks {
453491 let last = block. instructions . last ( ) . unwrap ( ) ;
454492 if let Op :: Return | Op :: ReturnValue = last. class . opcode {
455493 if Op :: ReturnValue == last. class . opcode {
456494 let return_value = last. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ;
457- phipairs. push ( ( return_value, block. label_id ( ) . unwrap ( ) ) ) ;
495+ values. push ( return_value) ;
496+ block. instructions . insert (
497+ block. instructions . len ( ) - 1 ,
498+ Instruction :: new (
499+ Op :: Store ,
500+ None ,
501+ None ,
502+ vec ! [
503+ Operand :: IdRef ( return_variable. unwrap( ) ) ,
504+ Operand :: IdRef ( return_value) ,
505+ ] ,
506+ ) ,
507+ ) ;
508+ } else {
509+ assert ! ( return_variable. is_none( ) ) ;
458510 }
459511 * block. instructions . last_mut ( ) . unwrap ( ) =
460512 Instruction :: new ( Op :: Branch , None , None , vec ! [ Operand :: IdRef ( return_jump) ] ) ;
461513 }
462514 }
463- ( blocks, phipairs)
515+ ( blocks, values)
516+ }
517+
518+ fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
519+ let index = block
520+ . instructions
521+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
522+
523+ let inst = Instruction :: new (
524+ Op :: Variable ,
525+ Some ( ptr_ty) ,
526+ Some ( result_id) ,
527+ vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
528+ ) ;
529+ block. instructions . insert ( index, inst)
464530}
465531
466532fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
@@ -472,7 +538,6 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
472538
473539fn fuse_trivial_branches ( function : & mut Function ) {
474540 let mut chain_list = compute_outgoing_1to1_branches ( & function. blocks ) ;
475- let mut rewrite_rules = FxHashMap :: default ( ) ;
476541
477542 for block_idx in 0 ..chain_list. len ( ) {
478543 let mut next = chain_list[ block_idx] . take ( ) ;
@@ -488,16 +553,6 @@ fn fuse_trivial_branches(function: &mut Function) {
488553 }
489554 Some ( next_idx) => {
490555 let mut dest_insts = take ( & mut function. blocks [ next_idx] . instructions ) ;
491- dest_insts. retain ( |inst| {
492- if inst. class . opcode == Op :: Phi {
493- assert_eq ! ( inst. operands. len( ) , 2 ) ;
494- rewrite_rules
495- . insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
496- false
497- } else {
498- true
499- }
500- } ) ;
501556 let self_insts = & mut function. blocks [ block_idx] . instructions ;
502557 self_insts. pop ( ) ; // pop the branch
503558 self_insts. append ( & mut dest_insts) ;
@@ -507,14 +562,6 @@ fn fuse_trivial_branches(function: &mut Function) {
507562 }
508563 }
509564 function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
510- // Calculate a closure, as these rules can be transitive
511- let mut rewrite_rules_new = rewrite_rules. clone ( ) ;
512- for value in rewrite_rules_new. values_mut ( ) {
513- while let Some ( next) = rewrite_rules. get ( value) {
514- * value = * next;
515- }
516- }
517- apply_rewrite_rules ( & rewrite_rules_new, & mut function. blocks ) ;
518565}
519566
520567fn compute_outgoing_1to1_branches ( blocks : & [ Block ] ) -> Vec < Option < usize > > {
0 commit comments