11use rspirv:: binary:: Disassemble ;
22use rspirv:: dr:: { Instruction , Module , Operand } ;
3- use rspirv:: spirv:: Op ;
3+ use rspirv:: spirv:: { Op , StorageClass } ;
44use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
55use rustc_session:: Session ;
66
7+ // bool is if this needs stored
8+ #[ derive( Debug , Clone , PartialEq ) ]
9+ struct NormalizedInstructions {
10+ vars : Vec < Instruction > ,
11+ insts : Vec < Instruction > ,
12+ root : u32 ,
13+ }
14+
15+ impl NormalizedInstructions {
16+ fn new ( id : u32 ) -> Self {
17+ NormalizedInstructions {
18+ vars : Vec :: new ( ) ,
19+ insts : Vec :: new ( ) ,
20+ root : id,
21+ }
22+ }
23+
24+ fn extend ( & mut self , o : NormalizedInstructions ) {
25+ self . vars . extend ( o. vars ) ;
26+ self . insts . extend ( o. insts ) ;
27+ }
28+
29+ fn is_empty ( & self ) -> bool {
30+ self . insts . is_empty ( ) && self . vars . is_empty ( )
31+ }
32+
33+ fn fix_ids ( & mut self , bound : & mut u32 , new_root : u32 ) {
34+ let mut id_map: FxHashMap < u32 , u32 > = FxHashMap :: default ( ) ;
35+ id_map. insert ( self . root , new_root) ;
36+ for inst in & mut self . vars {
37+ Self :: fix_instruction ( self . root , inst, & mut id_map, bound, new_root) ;
38+ }
39+ for inst in & mut self . insts {
40+ Self :: fix_instruction ( self . root , inst, & mut id_map, bound, new_root) ;
41+ }
42+ }
43+
44+ fn fix_instruction (
45+ root : u32 ,
46+ inst : & mut Instruction ,
47+ id_map : & mut FxHashMap < u32 , u32 > ,
48+ bound : & mut u32 ,
49+ new_root : u32 ,
50+ ) {
51+ for op in & mut inst. operands {
52+ match op {
53+ Operand :: IdRef ( id) => match id_map. get ( id) {
54+ Some ( new_id) => {
55+ * id = * new_id;
56+ }
57+ _ => { }
58+ } ,
59+ _ => { }
60+ }
61+ }
62+ if let Some ( id) = & mut inst. result_id {
63+ if * id != root {
64+ id_map. insert ( * id, * bound) ;
65+ * id = * bound;
66+ * bound += 1 ;
67+ } else {
68+ * id = new_root;
69+ }
70+ }
71+ }
72+ }
73+
774#[ derive( Debug , Clone , PartialEq ) ]
875enum FunctionArg {
976 Invalid ,
10- Insts ( Vec < Instruction > ) ,
77+ Insts ( NormalizedInstructions ) ,
1178}
1279
1380pub fn inline_global_varaibles ( sess : & Session , module : & mut Module ) -> super :: Result < ( ) > {
@@ -36,14 +103,30 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
36103 }
37104 // then we keep track of which function parameter are always called with the same expression that only uses global variables
38105 let mut function_args: FxHashMap < ( u32 , u32 ) , FunctionArg > = FxHashMap :: default ( ) ;
106+ let mut bound = module. header . as_ref ( ) . unwrap ( ) . bound ;
39107 for caller in & module. functions {
40108 let mut insts: FxHashMap < u32 , Instruction > = FxHashMap :: default ( ) ;
109+ // for variables that only stored once and it's stored as a ref
110+ let mut ref_stores: FxHashMap < u32 , Option < u32 > > = FxHashMap :: default ( ) ;
41111 for block in & caller. blocks {
42112 for inst in & block. instructions {
43113 if inst. result_id . is_some ( ) {
44114 insts. insert ( inst. result_id . unwrap ( ) , inst. clone ( ) ) ;
45115 }
46- if inst. class . opcode == Op :: FunctionCall {
116+ if inst. class . opcode == Op :: Store {
117+ if let Operand :: IdRef ( to) = inst. operands [ 0 ] {
118+ if let Operand :: IdRef ( from) = inst. operands [ 1 ] {
119+ match ref_stores. get ( & to) {
120+ None => {
121+ ref_stores. insert ( to, Some ( from) ) ;
122+ }
123+ Some ( _) => {
124+ ref_stores. insert ( to, None ) ;
125+ }
126+ }
127+ }
128+ }
129+ } else if inst. class . opcode == Op :: FunctionCall {
47130 let function_id = match & inst. operands [ 0 ] {
48131 & Operand :: IdRef ( w) => w,
49132 _ => panic ! ( ) ,
@@ -52,16 +135,19 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
52135 let key = ( function_id, i as u32 - 1 ) ;
53136 match & inst. operands [ i] {
54137 & Operand :: IdRef ( w) => match & function_args. get ( & key) {
55- None => match get_const_arg_insts ( & variables, & insts, w) {
56- Some ( insts) => {
57- function_args. insert ( key, FunctionArg :: Insts ( insts) ) ;
58- }
59- None => {
60- function_args. insert ( key, FunctionArg :: Invalid ) ;
138+ None => {
139+ match get_const_arg_insts ( bound, & variables, & insts, & ref_stores, w) {
140+ Some ( insts) => {
141+ function_args. insert ( key, FunctionArg :: Insts ( insts) ) ;
142+ }
143+ None => {
144+ function_args. insert ( key, FunctionArg :: Invalid ) ;
145+ }
61146 }
62- } ,
147+ }
63148 Some ( FunctionArg :: Insts ( w2) ) => {
64- let new_insts = get_const_arg_insts ( & variables, & insts, w) ;
149+ let new_insts =
150+ get_const_arg_insts ( bound, & variables, & insts, & ref_stores, w) ;
65151 match new_insts {
66152 Some ( new_insts) => {
67153 if new_insts != * w2 {
@@ -93,11 +179,10 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
93179 if function_args. is_empty ( ) {
94180 return Ok ( false ) ;
95181 }
96- let mut bound = module. header . as_ref ( ) . unwrap ( ) . bound ;
97182 for function in & mut module. functions {
98183 let def = function. def . as_mut ( ) . unwrap ( ) ;
99184 let fid = def. result_id . unwrap ( ) ;
100- let mut insts: Vec < Instruction > = Vec :: new ( ) ;
185+ let mut insts = NormalizedInstructions :: new ( 0 ) ;
101186 let mut j: u32 = 0 ;
102187 let mut i = 0 ;
103188 let mut removed_indexes: Vec < u32 > = Vec :: new ( ) ;
@@ -108,10 +193,7 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
108193 Some ( FunctionArg :: Insts ( arg) ) => {
109194 let parameter = function. parameters . remove ( i) ;
110195 let mut arg = arg. clone ( ) ;
111- arg. reverse ( ) ;
112- insts_replacing_captured_ids ( & mut arg, & mut bound) ;
113- let index = arg. len ( ) - 1 ;
114- arg[ index] . result_id = parameter. result_id ;
196+ arg. fix_ids ( & mut bound, parameter. result_id . unwrap ( ) ) ;
115197 insts. extend ( arg) ;
116198 removed_indexes. push ( j) ;
117199 removed = true ;
@@ -132,15 +214,16 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
132214 for i in removed_indexes. iter ( ) . rev ( ) {
133215 let i = * i as usize + 1 ;
134216 function_type. operands . remove ( i) ;
135- function_type. result_id = Some ( tid) ;
136217 }
218+ function_type. result_id = Some ( tid) ;
137219 def. operands [ 1 ] = Operand :: IdRef ( tid) ;
138220 module. types_global_values . push ( function_type) ;
139221 }
140222 }
141223 // callee side. insert initialization instructions, which reuse the ids of the removed parameters
142224 if !function. blocks . is_empty ( ) {
143225 let first_block = & mut function. blocks [ 0 ] ;
226+ first_block. instructions . splice ( 0 ..0 , insts. vars ) ;
144227 // skip some instructions that must be at top of block
145228 let mut i = 0 ;
146229 loop {
@@ -154,7 +237,7 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
154237 }
155238 i += 1 ;
156239 }
157- first_block. instructions . splice ( i..i, insts) ;
240+ first_block. instructions . splice ( i..i, insts. insts ) ;
158241 }
159242 // caller side, remove parameters from function call
160243 for block in & mut function. blocks {
@@ -181,76 +264,90 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
181264 Ok ( true )
182265}
183266
184- fn insts_replacing_captured_ids ( arg : & mut Vec < Instruction > , bound : & mut u32 ) {
185- let mut id_map: FxHashMap < u32 , u32 > = FxHashMap :: default ( ) ;
186- for ins in arg {
187- if let Some ( id) = & mut ins. result_id {
188- for op in & mut ins. operands {
189- match op {
190- Operand :: IdRef ( id) => match id_map. get ( id) {
191- Some ( new_id) => {
192- * id = * new_id;
193- }
194- _ => { }
195- } ,
196- _ => { }
197- }
198- }
199- id_map. insert ( * id, * bound) ;
200- * id = * bound;
201- * bound += 1 ;
202- }
203- }
204- }
205-
206267fn get_const_arg_operands (
207268 variables : & FxHashSet < u32 > ,
208269 insts : & FxHashMap < u32 , Instruction > ,
270+ ref_stores : & FxHashMap < u32 , Option < u32 > > ,
209271 operand : & Operand ,
210- ) -> Option < Vec < Instruction > > {
272+ ) -> Option < NormalizedInstructions > {
211273 match operand {
212274 Operand :: IdRef ( id) => {
213- let insts = get_const_arg_insts ( variables, insts, * id) ?;
275+ let insts = get_const_arg_insts_rec ( variables, insts, ref_stores , * id) ?;
214276 return Some ( insts) ;
215277 }
216- Operand :: LiteralInt32 ( _) => { } ,
217- Operand :: LiteralInt64 ( _) => { } ,
218- Operand :: LiteralFloat32 ( _) => { } ,
219- Operand :: LiteralFloat64 ( _) => { } ,
220- Operand :: LiteralExtInstInteger ( _) => { } ,
221- Operand :: LiteralSpecConstantOpInteger ( _) => { } ,
222- Operand :: LiteralString ( _) => { } ,
278+ Operand :: LiteralInt32 ( _) => { }
279+ Operand :: LiteralInt64 ( _) => { }
280+ Operand :: LiteralFloat32 ( _) => { }
281+ Operand :: LiteralFloat64 ( _) => { }
282+ Operand :: LiteralExtInstInteger ( _) => { }
283+ Operand :: LiteralSpecConstantOpInteger ( _) => { }
284+ Operand :: LiteralString ( _) => { }
223285 _ => {
224286 // TOOD add more cases
225287 return None ;
226288 }
227289 }
228- return Some ( Vec :: new ( ) ) ;
290+ return Some ( NormalizedInstructions :: new ( 0 ) ) ;
229291}
230292
231293fn get_const_arg_insts (
294+ mut bound : u32 ,
295+ variables : & FxHashSet < u32 > ,
296+ insts : & FxHashMap < u32 , Instruction > ,
297+ ref_stores : & FxHashMap < u32 , Option < u32 > > ,
298+ id : u32 ,
299+ ) -> Option < NormalizedInstructions > {
300+ let mut res = get_const_arg_insts_rec ( variables, insts, ref_stores, id) ?;
301+ res. insts . reverse ( ) ;
302+ // the bound passed in is always the same
303+ // we need to normalize the ids, so they are the same when compared
304+ let fake_root = bound;
305+ bound += 1 ;
306+ res. fix_ids ( & mut bound, fake_root) ;
307+ res. root = fake_root;
308+ Some ( res)
309+ }
310+
311+ fn get_const_arg_insts_rec (
232312 variables : & FxHashSet < u32 > ,
233313 insts : & FxHashMap < u32 , Instruction > ,
314+ ref_stores : & FxHashMap < u32 , Option < u32 > > ,
234315 id : u32 ,
235- ) -> Option < Vec < Instruction > > {
236- let mut result: Vec < Instruction > = Vec :: new ( ) ;
316+ ) -> Option < NormalizedInstructions > {
317+ let mut result = NormalizedInstructions :: new ( id ) ;
237318 if variables. contains ( & id) {
238319 return Some ( result) ;
239320 }
240321 let par: & Instruction = insts. get ( & id) ?;
241322 if par. class . opcode == Op :: AccessChain {
242- result. push ( par. clone ( ) ) ;
323+ result. insts . push ( par. clone ( ) ) ;
243324 for oprand in & par. operands {
244- let insts = get_const_arg_operands ( variables, insts, oprand) ?;
325+ let insts = get_const_arg_operands ( variables, insts, ref_stores , oprand) ?;
245326 result. extend ( insts) ;
246327 }
247328 } else if par. class . opcode == Op :: FunctionCall {
248- result. push ( par. clone ( ) ) ;
329+ result. insts . push ( par. clone ( ) ) ;
249330 // skip first, first is function id
250331 for oprand in & par. operands [ 1 ..] {
251- let insts = get_const_arg_operands ( variables, insts, oprand) ?;
332+ let insts = get_const_arg_operands ( variables, insts, ref_stores , oprand) ?;
252333 result. extend ( insts) ;
253334 }
335+ } else if par. class . opcode == Op :: Variable {
336+ result. vars . push ( par. clone ( ) ) ;
337+ let stored = ref_stores. get ( & id) ?;
338+ let stored = ( * stored) ?;
339+ result. insts . push ( Instruction :: new (
340+ Op :: Store ,
341+ None ,
342+ None ,
343+ vec ! [ Operand :: IdRef ( id) , Operand :: IdRef ( stored) ] ,
344+ ) ) ;
345+ let new_insts = get_const_arg_insts_rec ( variables, insts, ref_stores, stored) ?;
346+ result. extend ( new_insts) ;
347+ } else if par. class . opcode == Op :: ArrayLength {
348+ result. insts . push ( par. clone ( ) ) ;
349+ let insts = get_const_arg_operands ( variables, insts, ref_stores, & par. operands [ 0 ] ) ?;
350+ result. extend ( insts) ;
254351 } else {
255352 // TOOD add more cases
256353 return None ;
0 commit comments