@@ -13,20 +13,24 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
1313use rustc_session:: Session ;
1414use std:: mem:: take;
1515
16- type FunctionMap = FxHashMap < Word , Function > ;
16+ type FunctionMap = FxHashMap < Word , usize > ;
1717
1818pub fn inline ( sess : & Session , module : & mut Module ) -> super :: Result < ( ) > {
19+ let ( disallowed_argument_types, disallowed_return_types) =
20+ compute_disallowed_argument_and_return_types ( module) ;
21+ let mut to_delete: Vec < _ > = module
22+ . functions
23+ . iter ( )
24+ . map ( |f| should_inline ( & disallowed_argument_types, & disallowed_return_types, f) )
25+ . collect ( ) ;
1926 // This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
20- if module_has_recursion ( sess, module) {
21- return Err ( rustc_errors:: ErrorReported ) ;
22- }
27+ let postorder = compute_function_postorder ( sess, module, & mut to_delete) ?;
2328 let functions = module
2429 . functions
2530 . iter ( )
26- . map ( |f| ( f. def_id ( ) . unwrap ( ) , f. clone ( ) ) )
31+ . enumerate ( )
32+ . map ( |( idx, f) | ( f. def_id ( ) . unwrap ( ) , idx) )
2733 . collect ( ) ;
28- let ( disallowed_argument_types, disallowed_return_types) =
29- compute_disallowed_argument_and_return_types ( module) ;
3034 let void = module
3135 . types_global_values
3236 . iter ( )
@@ -35,23 +39,6 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3539 . unwrap_or ( 0 ) ;
3640 // Drop all the functions we'll be inlining. (This also means we won't waste time processing
3741 // inlines in functions that will get inlined)
38- let mut dropped_ids = FxHashSet :: default ( ) ;
39- module. functions . retain ( |f| {
40- if should_inline ( & disallowed_argument_types, & disallowed_return_types, f) {
41- // TODO: We should insert all defined IDs in this function.
42- dropped_ids. insert ( f. def_id ( ) . unwrap ( ) ) ;
43- false
44- } else {
45- true
46- }
47- } ) ;
48- // Drop OpName etc. for inlined functions
49- module. debug_names . retain ( |inst| {
50- !inst. operands . iter ( ) . any ( |op| {
51- op. id_ref_any ( )
52- . map_or ( false , |id| dropped_ids. contains ( & id) )
53- } )
54- } ) ;
5542 let mut inliner = Inliner {
5643 header : module. header . as_mut ( ) . unwrap ( ) ,
5744 types_global_values : & mut module. types_global_values ,
@@ -60,77 +47,122 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
6047 disallowed_argument_types : & disallowed_argument_types,
6148 disallowed_return_types : & disallowed_return_types,
6249 } ;
63- for function in & mut module . functions {
64- inliner. inline_fn ( function ) ;
65- fuse_trivial_branches ( function ) ;
50+ for index in postorder {
51+ inliner. inline_fn ( & mut module . functions , index ) ;
52+ fuse_trivial_branches ( & mut module . functions [ index ] ) ;
6653 }
54+ let mut dropped_ids = FxHashSet :: default ( ) ;
55+ for i in ( 0 ..module. functions . len ( ) ) . rev ( ) {
56+ if to_delete[ i] {
57+ dropped_ids. insert ( module. functions . remove ( i) . def_id ( ) . unwrap ( ) ) ;
58+ }
59+ }
60+ // Drop OpName etc. for inlined functions
61+ module. debug_names . retain ( |inst| {
62+ !inst. operands . iter ( ) . any ( |op| {
63+ op. id_ref_any ( )
64+ . map_or ( false , |id| dropped_ids. contains ( & id) )
65+ } )
66+ } ) ;
6767 Ok ( ( ) )
6868}
6969
70- // https://stackoverflow.com/a/53995651
71- fn module_has_recursion ( sess : & Session , module : & Module ) -> bool {
70+ /// Topological sorting algorithm due to T. Cormen
71+ /// Starts from module's entry points, so only reachable functions will be returned
72+ /// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`,
73+ /// `to_delete[i]` is set to true.
74+ fn compute_function_postorder (
75+ sess : & Session ,
76+ module : & Module ,
77+ to_delete : & mut [ bool ] ,
78+ ) -> super :: Result < Vec < usize > > {
7279 let func_to_index: FxHashMap < Word , usize > = module
7380 . functions
7481 . iter ( )
7582 . enumerate ( )
7683 . map ( |( index, func) | ( func. def_id ( ) . unwrap ( ) , index) )
7784 . collect ( ) ;
78- let mut discovered = vec ! [ false ; module. functions. len( ) ] ;
79- let mut finished = vec ! [ false ; module. functions. len( ) ] ;
85+ /// Possible node states for cycle-discovering DFS.
86+ #[ derive( Clone , PartialEq ) ]
87+ enum NodeState {
88+ /// Normal, not visited.
89+ NotVisited ,
90+ /// Currently being visited.
91+ Discovered ,
92+ /// DFS returned.
93+ Finished ,
94+ /// Not visited, entry point.
95+ Entry ,
96+ }
97+ let mut states = vec ! [ NodeState :: NotVisited ; module. functions. len( ) ] ;
98+ for opep in module. entry_points . iter ( ) {
99+ let func_id = opep. operands [ 1 ] . unwrap_id_ref ( ) ;
100+ states[ func_to_index[ & func_id] ] = NodeState :: Entry ;
101+ }
80102 let mut has_recursion = false ;
103+ let mut postorder = vec ! [ ] ;
81104 for index in 0 ..module. functions . len ( ) {
82- if !discovered [ index ] && !finished [ index] {
105+ if NodeState :: Entry == states [ index] {
83106 visit (
84107 sess,
85108 module,
86109 index,
87- & mut discovered,
88- & mut finished,
110+ & mut states[ ..] ,
89111 & mut has_recursion,
112+ & mut postorder,
90113 & func_to_index,
91114 ) ;
92115 }
93116 }
94117
118+ for index in 0 ..module. functions . len ( ) {
119+ if NodeState :: NotVisited == states[ index] {
120+ to_delete[ index] = true ;
121+ }
122+ }
123+
95124 fn visit (
96125 sess : & Session ,
97126 module : & Module ,
98127 current : usize ,
99- discovered : & mut Vec < bool > ,
100- finished : & mut Vec < bool > ,
128+ states : & mut [ NodeState ] ,
101129 has_recursion : & mut bool ,
130+ postorder : & mut Vec < usize > ,
102131 func_to_index : & FxHashMap < Word , usize > ,
103132 ) {
104- discovered [ current] = true ;
133+ states [ current] = NodeState :: Discovered ;
105134
106135 for next in calls ( & module. functions [ current] , func_to_index) {
107- if discovered[ next] {
108- let names = get_names ( module) ;
109- let current_name = get_name ( & names, module. functions [ current] . def_id ( ) . unwrap ( ) ) ;
110- let next_name = get_name ( & names, module. functions [ next] . def_id ( ) . unwrap ( ) ) ;
111- sess. err ( & format ! (
112- "module has recursion, which is not allowed: `{}` calls `{}`" ,
113- current_name, next_name
114- ) ) ;
115- * has_recursion = true ;
116- break ;
117- }
118-
119- if !finished[ next] {
120- visit (
121- sess,
122- module,
123- next,
124- discovered,
125- finished,
126- has_recursion,
127- func_to_index,
128- ) ;
136+ match states[ next] {
137+ NodeState :: Discovered => {
138+ let names = get_names ( module) ;
139+ let current_name =
140+ get_name ( & names, module. functions [ current] . def_id ( ) . unwrap ( ) ) ;
141+ let next_name = get_name ( & names, module. functions [ next] . def_id ( ) . unwrap ( ) ) ;
142+ sess. err ( & format ! (
143+ "module has recursion, which is not allowed: `{}` calls `{}`" ,
144+ current_name, next_name
145+ ) ) ;
146+ * has_recursion = true ;
147+ break ;
148+ }
149+ NodeState :: NotVisited | NodeState :: Entry => {
150+ visit (
151+ sess,
152+ module,
153+ next,
154+ states,
155+ has_recursion,
156+ postorder,
157+ func_to_index,
158+ ) ;
159+ }
160+ NodeState :: Finished => { }
129161 }
130162 }
131163
132- discovered [ current] = false ;
133- finished [ current] = true ;
164+ states [ current] = NodeState :: Finished ;
165+ postorder . push ( current)
134166 }
135167
136168 fn calls < ' a > (
@@ -146,7 +178,11 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool {
146178 } )
147179 }
148180
149- has_recursion
181+ if has_recursion {
182+ Err ( rustc_errors:: ErrorReported )
183+ } else {
184+ Ok ( postorder)
185+ }
150186}
151187
152188fn compute_disallowed_argument_and_return_types (
@@ -283,33 +319,39 @@ impl Inliner<'_, '_> {
283319 inst_id
284320 }
285321
286- fn inline_fn ( & mut self , function : & mut Function ) {
322+ fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
323+ let mut function = take ( & mut functions[ index] ) ;
287324 let mut block_idx = 0 ;
288325 while block_idx < function. blocks . len ( ) {
289- // If we successfully inlined a block, then repeat processing on the same block, in
290- // case the newly inlined block has more inlined calls.
291- // TODO: This is quadratic
292- if !self . inline_block ( function, block_idx) {
293- block_idx += 1 ;
294- }
326+ // If we successfully inlined a block, then continue processing on the next block or its tail.
327+ // TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments.
328+ // For the common case of "we knew which functions to inline", it is linear.
329+ self . inline_block ( & mut function, & functions, block_idx) ;
330+ block_idx += 1 ;
295331 }
332+ functions[ index] = function;
296333 }
297334
298- fn inline_block ( & mut self , caller : & mut Function , block_idx : usize ) -> bool {
335+ /// Inlines one block and returns whether inlining actually occurred.
336+ /// After calling this, blocks[block_idx] is finished processing.
337+ fn inline_block (
338+ & mut self ,
339+ caller : & mut Function ,
340+ functions : & [ Function ] ,
341+ block_idx : usize ,
342+ ) -> bool {
299343 // Find the first inlined OpFunctionCall
300344 let call = caller. blocks [ block_idx]
301345 . instructions
302346 . iter ( )
303347 . enumerate ( )
304348 . filter ( |( _, inst) | inst. class . opcode == Op :: FunctionCall )
305349 . map ( |( index, inst) | {
306- (
307- index,
308- inst,
309- self . functions
310- . get ( & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) )
311- . unwrap ( ) ,
312- )
350+ let idx = self
351+ . functions
352+ . get ( & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) )
353+ . unwrap ( ) ;
354+ ( index, inst, & functions[ * idx] )
313355 } )
314356 . find ( |( _, inst, f) | {
315357 should_inline (
@@ -374,17 +416,23 @@ impl Inliner<'_, '_> {
374416 ) ;
375417 }
376418
377- // Fuse the first block of the callee into the block of the caller. This is okay because
378- // it's illegal to branch to the first BB in a function.
379- let mut callee_header = inlined_blocks. remove ( 0 ) . instructions ;
419+ // Move the variables over from the inlined function to here.
420+ let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
380421 // TODO: OpLine handling
381- let num_variables = callee_header
382- . iter ( )
383- . position ( |inst| inst. class . opcode != Op :: Variable )
384- . unwrap_or ( callee_header. len ( ) ) ;
385- caller. blocks [ block_idx]
386- . instructions
387- . append ( & mut callee_header. split_off ( num_variables) ) ;
422+ let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
423+ // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
424+ // it, and we maintain the invariant that current block has finished processing.
425+ let jump_to = self . id ( ) ;
426+ inlined_blocks[ 0 ] = Block {
427+ label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
428+ instructions : callee_header. split_off ( num_variables) ,
429+ } ;
430+ caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
431+ Op :: Branch ,
432+ None ,
433+ None ,
434+ vec ! [ Operand :: IdRef ( jump_to) ] ,
435+ ) ) ;
388436 // Move the OpVariables of the callee to the caller.
389437 insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
390438
@@ -466,45 +514,22 @@ fn get_inlined_blocks(
466514fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
467515 let index = block
468516 . instructions
469- . iter ( )
470- . enumerate ( )
471- . find_map ( |( index, inst) | {
472- if inst. class . opcode != Op :: Variable {
473- Some ( index)
474- } else {
475- None
476- }
477- } ) ;
517+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
518+
478519 let inst = Instruction :: new (
479520 Op :: Variable ,
480521 Some ( ptr_ty) ,
481522 Some ( result_id) ,
482523 vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
483524 ) ;
484- match index {
485- Some ( index) => block. instructions . insert ( index, inst) ,
486- None => block. instructions . push ( inst) ,
487- }
525+ block. instructions . insert ( index, inst)
488526}
489527
490- fn insert_opvariables ( block : & mut Block , mut insts : Vec < Instruction > ) {
528+ fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
491529 let index = block
492530 . instructions
493- . iter ( )
494- . enumerate ( )
495- . find_map ( |( index, inst) | {
496- if inst. class . opcode != Op :: Variable {
497- Some ( index)
498- } else {
499- None
500- }
501- } ) ;
502- match index {
503- Some ( index) => {
504- block. instructions . splice ( index..index, insts) ;
505- }
506- None => block. instructions . append ( & mut insts) ,
507- }
531+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
532+ block. instructions . splice ( index..index, insts) ;
508533}
509534
510535fn fuse_trivial_branches ( function : & mut Function ) {
0 commit comments