Skip to content

Commit 0fbc96e

Browse files
authored
Flatten nested insertvalue #259 (#548)
1 parent 24110a8 commit 0fbc96e

File tree

2 files changed

+169
-1
lines changed

2 files changed

+169
-1
lines changed

src/compiler/compilation.jl

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,126 @@ function GPUCompiler.finish_module!(job::oneAPICompilerJob, mod::LLVM.Module,
3737
return entry
3838
end
3939

40+
# finish_ir! runs later in the pipeline, after optimizations that create nested insertvalue
41+
function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
42+
entry::LLVM.Function)
43+
entry = invoke(GPUCompiler.finish_ir!,
44+
Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(mod), typeof(entry)},
45+
job, mod, entry)
46+
47+
# FIX: Flatten nested insertvalue instructions to work around SPIR-V bug
48+
# See: https://github.com/JuliaGPU/oneAPI.jl/issues/259
49+
# Intel's SPIR-V runtime has a bug where OpCompositeInsert with nested
50+
# indices (e.g., "1 0") corrupts adjacent struct fields.
51+
flatten_nested_insertvalue!(mod)
52+
53+
return entry
54+
end
55+
56+
# Flatten nested insertvalue instructions
57+
# This works around a bug in Intel's SPIR-V runtime where OpCompositeInsert
58+
# with nested array indices corrupts adjacent struct fields.
59+
function flatten_nested_insertvalue!(mod::LLVM.Module)
60+
changed = false
61+
count = 0
62+
63+
for f in functions(mod)
64+
isempty(blocks(f)) && continue
65+
66+
for bb in blocks(f)
67+
# Collect instructions to process (can't modify while iterating)
68+
to_process = LLVM.Instruction[]
69+
70+
for inst in instructions(bb)
71+
# Check if this is an insertvalue with nested indices
72+
if LLVM.API.LLVMGetInstructionOpcode(inst) == LLVM.API.LLVMInsertValue
73+
num_indices = LLVM.API.LLVMGetNumIndices(inst)
74+
if num_indices > 1
75+
push!(to_process, inst)
76+
end
77+
end
78+
end
79+
80+
# Flatten each nested insertvalue
81+
for inst in to_process
82+
try
83+
flatten_insert!(inst)
84+
changed = true
85+
count += 1
86+
catch e
87+
@warn "Failed to flatten nested insertvalue" exception=(e, catch_backtrace())
88+
end
89+
end
90+
end
91+
end
92+
93+
return changed
94+
end
95+
96+
function flatten_insert!(inst::LLVM.Instruction)
97+
# Transform: insertvalue %base, %val, i, j, k...
98+
# Into: extractvalue %base, i
99+
# insertvalue %extracted, %val, j, k...
100+
# insertvalue %base, %modified, i
101+
102+
composite = LLVM.operands(inst)[1]
103+
value = LLVM.operands(inst)[2]
104+
105+
num_indices = LLVM.API.LLVMGetNumIndices(inst)
106+
idx_ptr = LLVM.API.LLVMGetIndices(inst)
107+
indices = unsafe_wrap(Array, idx_ptr, num_indices)
108+
109+
builder = LLVM.IRBuilder()
110+
LLVM.position!(builder, inst)
111+
112+
# Strategy: Recursively extract and insert for each nesting level
113+
# For insertvalue %base, %val, i, j, k
114+
# Do: %tmp1 = extractvalue %base, i
115+
# %tmp2 = extractvalue %tmp1, j
116+
# %tmp3 = insertvalue %tmp2, %val, k
117+
# %tmp4 = insertvalue %tmp1, %tmp3, j
118+
# %result = insertvalue %base, %tmp4, i
119+
120+
# But that's complex. Simpler approach for 2-3 levels:
121+
# Just do one level of flattening at a time
122+
first_idx = indices[1]
123+
rest_indices = indices[2:end]
124+
125+
# Extract the first level
126+
extracted = LLVM.extract_value!(builder, composite, first_idx)
127+
128+
# Now insert into the extracted value using remaining indices
129+
# The LLVM IR builder will handle this correctly
130+
inserted = extracted
131+
if length(rest_indices) == 1
132+
# Simple case: just one more level
133+
inserted = LLVM.insert_value!(builder, extracted, value, rest_indices[1])
134+
else
135+
# Multiple levels: need to extract down, insert, then insert back up
136+
# For now, recursively extract to the deepest level
137+
temps = [extracted]
138+
for i in 1:(length(rest_indices)-1)
139+
temp = LLVM.extract_value!(builder, temps[end], rest_indices[i])
140+
push!(temps, temp)
141+
end
142+
143+
# Insert the value at the deepest level
144+
inserted = LLVM.insert_value!(builder, temps[end], value, rest_indices[end])
145+
146+
# Insert back up the chain
147+
for i in (length(rest_indices)-1):-1:1
148+
inserted = LLVM.insert_value!(builder, temps[i], inserted, rest_indices[i])
149+
end
150+
end
151+
152+
# Insert the modified structure back into the original
153+
result = LLVM.insert_value!(builder, composite, inserted, first_idx)
154+
155+
LLVM.replace_uses!(inst, result)
156+
LLVM.API.LLVMInstructionEraseFromParent(inst)
157+
LLVM.dispose(builder)
158+
end
159+
40160

41161
## compiler implementation (cache, configure, compile, and link)
42162

@@ -68,7 +188,10 @@ end
68188
supports_fp64 = oneL0.module_properties(device()).fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64
69189

70190
# TODO: emit printf format strings in constant memory
71-
extensions = String["SPV_EXT_relaxed_printf_string_address_space"]
191+
extensions = String[
192+
"SPV_EXT_relaxed_printf_string_address_space",
193+
"SPV_EXT_shader_atomic_float_add"
194+
]
72195

73196
# create GPUCompiler objects
74197
target = SPIRVCompilerTarget(; extensions, supports_fp16, supports_fp64, kwargs...)

test/indexing.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,48 @@ using oneAPI
1818
mask = oneArray(Bool[true, false, true, false, false, true])
1919
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
2020
end
21+
22+
@testset "CartesianIndices with mapreduce" begin
23+
# Test for bug fix: mapreduce with CartesianIndices and tuple reduction
24+
# Previously failed due to SPIR-V codegen issues with nested insertvalue instructions
25+
# when combining tuples of (bool, CartesianIndex) in reduction operations.
26+
# The fix involved properly handling nested struct insertions in SPIR-V codegen.
27+
28+
# Test that we can zip CartesianIndices with array values in a mapreduce
29+
# This tests the fix for nested tuple operations in SPIR-V codegen
30+
31+
# Simple test: sum of values while tracking indices
32+
x = oneArray(ones(Int, 2, 2))
33+
indices = CartesianIndices((2, 2))
34+
35+
# Map to tuple of (value, index), then reduce by summing the values
36+
result = mapreduce(tuple, (t1, t2) -> (t1[1] + t2[1], t1[2]), x, indices;
37+
init = (0, CartesianIndex(0, 0)))
38+
@test result[1] == 4 # sum of four 1s
39+
40+
# Test with 1D array
41+
y = oneArray(ones(Int, 4))
42+
indices_1d = CartesianIndices((4,))
43+
result_1d = mapreduce(tuple, (t1, t2) -> (t1[1] + t2[1], t1[2]), y, indices_1d;
44+
init = (0, CartesianIndex(0,)))
45+
@test result_1d[1] == 4
46+
47+
# Test with boolean array and index comparison (closer to original failure case)
48+
# This pattern is similar to what findfirst would use internally
49+
z = oneArray([false, true, false, true])
50+
indices_z = CartesianIndices((4,))
51+
result_z = mapreduce(tuple,
52+
(t1, t2) -> begin
53+
(found1, idx1), (found2, idx2) = t1, t2
54+
# Return the first found index (smallest index if both found)
55+
if found1
56+
return (found1, idx1)
57+
else
58+
return (found2, idx2)
59+
end
60+
end,
61+
z, indices_z;
62+
init = (false, CartesianIndex(0,)))
63+
@test result_z[1] == true # Found a true value
64+
@test result_z[2] == CartesianIndex(2,) # First true is at index 2
65+
end

0 commit comments

Comments
 (0)