@@ -37,6 +37,126 @@ function GPUCompiler.finish_module!(job::oneAPICompilerJob, mod::LLVM.Module,
3737 return entry
3838end
3939
# `finish_ir!` is a late pipeline hook: it runs after the optimizations that
# create nested `insertvalue` instructions, so the rewrite below sees them.
function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
                                entry::LLVM.Function)
    # Run the generic SPIR-V method first; `invoke` is needed to reach it because
    # plain dispatch would select this more specific method again.
    parent_sig = Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(mod), typeof(entry)}
    entry = invoke(GPUCompiler.finish_ir!, parent_sig, job, mod, entry)

    # Workaround for https://github.com/JuliaGPU/oneAPI.jl/issues/259:
    # Intel's SPIR-V runtime has a bug where OpCompositeInsert with nested
    # indices (e.g., "1 0") corrupts adjacent struct fields, so every nested
    # insertvalue is flattened into single-index operations.
    flatten_nested_insertvalue!(mod)

    return entry
end
55+
"""
    flatten_nested_insertvalue!(mod::LLVM.Module) -> Bool

Rewrite every `insertvalue` instruction in `mod` that carries more than one index
into a chain of single-index `extractvalue`/`insertvalue` instructions.

This works around a bug in Intel's SPIR-V runtime where OpCompositeInsert with
nested indices corrupts adjacent struct fields.

Rewriting is best-effort: an instruction that fails to flatten is reported with
`@warn` and left for the remaining instructions to proceed. Returns `true` if at
least one instruction was rewritten, `false` otherwise.
"""
function flatten_nested_insertvalue!(mod::LLVM.Module)
    changed = false

    # Functions without a body have no blocks, so they are skipped naturally.
    for f in functions(mod), bb in blocks(f)
        # Snapshot the candidates first: flattening inserts and erases
        # instructions, which must not happen while iterating the block.
        nested = LLVM.Instruction[]
        for inst in instructions(bb)
            is_insertvalue =
                LLVM.API.LLVMGetInstructionOpcode(inst) == LLVM.API.LLVMInsertValue
            if is_insertvalue && LLVM.API.LLVMGetNumIndices(inst) > 1
                push!(nested, inst)
            end
        end

        for inst in nested
            try
                flatten_insert!(inst)
                changed = true
            catch e
                # Deliberately non-fatal: one failed rewrite should not abort
                # compilation of the whole module.
                @warn "Failed to flatten nested insertvalue" exception=(e, catch_backtrace())
            end
        end
    end

    return changed
end
95+
"""
    flatten_insert!(inst::LLVM.Instruction)

Replace a multi-index `insertvalue` instruction with an equivalent chain of
single-index operations, then erase the original instruction.

For `insertvalue %base, %val, i, j, k` the emitted sequence is:

    %t1 = extractvalue %base, i
    %t2 = extractvalue %t1, j
    %t3 = insertvalue %t2, %val, k
    %t4 = insertvalue %t1, %t3, j
    %r  = insertvalue %base, %t4, i

All uses of `inst` are redirected to the final result. An instruction with fewer
than two indices is left untouched. Returns `nothing`.
"""
function flatten_insert!(inst::LLVM.Instruction)
    composite = LLVM.operands(inst)[1]
    value = LLVM.operands(inst)[2]

    num_indices = LLVM.API.LLVMGetNumIndices(inst)
    # Nothing to flatten for a single-index insertvalue.
    num_indices > 1 || return nothing

    # Copy the indices: the wrapped buffer belongs to `inst`, which is erased below.
    indices = copy(unsafe_wrap(Array, LLVM.API.LLVMGetIndices(inst), num_indices))
    depth = length(indices)

    builder = LLVM.IRBuilder()
    try
        # New instructions are emitted immediately before the one being replaced.
        LLVM.position!(builder, inst)

        # Walk down: containers[k] is the aggregate that indices[k] indexes into.
        containers = LLVM.Value[composite]
        for level in 1:(depth - 1)
            inner = LLVM.extract_value!(builder, containers[end], indices[level])
            push!(containers, inner)
        end

        # Walk back up: insert the value at the deepest level, then re-insert each
        # rebuilt aggregate into its parent with a single index.
        rebuilt = value
        for level in depth:-1:1
            rebuilt = LLVM.insert_value!(builder, containers[level], rebuilt,
                                         indices[level])
        end

        LLVM.replace_uses!(inst, rebuilt)
        LLVM.API.LLVMInstructionEraseFromParent(inst)
    finally
        # Always release the builder, including when a build step throws; the
        # caller catches exceptions, so a leak here would otherwise accumulate.
        LLVM.dispose(builder)
    end

    return nothing
end
159+
40160
41161# # compiler implementation (cache, configure, compile, and link)
42162
68188 supports_fp64 = oneL0. module_properties (device ()). fp64flags & oneL0. ZE_DEVICE_MODULE_FLAG_FP64 == oneL0. ZE_DEVICE_MODULE_FLAG_FP64
69189
70190 # TODO : emit printf format strings in constant memory
71- extensions = String[" SPV_EXT_relaxed_printf_string_address_space" ]
191+ extensions = String[
192+ " SPV_EXT_relaxed_printf_string_address_space" ,
193+ " SPV_EXT_shader_atomic_float_add"
194+ ]
72195
73196 # create GPUCompiler objects
74197 target = SPIRVCompilerTarget (; extensions, supports_fp16, supports_fp64, kwargs... )
0 commit comments