Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 124 additions & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,126 @@ function GPUCompiler.finish_module!(job::oneAPICompilerJob, mod::LLVM.Module,
return entry
end

# finish_ir! runs later in the pipeline, after optimizations that create nested insertvalue
function GPUCompiler.finish_ir!(job::oneAPICompilerJob, mod::LLVM.Module,
                                entry::LLVM.Function)
    # Run the generic SPIR-V `finish_ir!` method first and continue with the
    # entry function it returns.
    spirv_signature = Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(mod), typeof(entry)}
    entry = invoke(GPUCompiler.finish_ir!, spirv_signature, job, mod, entry)

    # Workaround for https://github.com/JuliaGPU/oneAPI.jl/issues/259:
    # Intel's SPIR-V runtime mishandles OpCompositeInsert with nested indices
    # (e.g., "1 0") and corrupts adjacent struct fields, so every nested
    # insertvalue in the module is rewritten into single-index operations.
    flatten_nested_insertvalue!(mod)

    return entry
end

# Flatten nested insertvalue instructions
# This works around a bug in Intel's SPIR-V runtime where OpCompositeInsert
# with nested array indices corrupts adjacent struct fields.
"""
    flatten_nested_insertvalue!(mod::LLVM.Module) -> Bool

Rewrite every `insertvalue` instruction in `mod` that carries more than one
index by passing it to `flatten_insert!`.

A failure on an individual instruction is reported with `@warn` and does not
abort the pass; the remaining instructions are still processed.

Returns `true` if at least one instruction was rewritten, `false` otherwise.
"""
function flatten_nested_insertvalue!(mod::LLVM.Module)
    changed = false

    # Functions without blocks contribute no iterations to the inner loop.
    for f in functions(mod), bb in blocks(f)
        # Snapshot the candidates first: rewriting inserts and erases
        # instructions, which must not happen while iterating the block.
        nested = LLVM.Instruction[]
        for inst in instructions(bb)
            opcode = LLVM.API.LLVMGetInstructionOpcode(inst)
            opcode == LLVM.API.LLVMInsertValue || continue
            LLVM.API.LLVMGetNumIndices(inst) > 1 && push!(nested, inst)
        end

        for inst in nested
            try
                flatten_insert!(inst)
                changed = true
            catch e
                # Best effort: keep going so one bad instruction does not
                # prevent the rest of the module from being processed.
                @warn "Failed to flatten nested insertvalue" exception=(e, catch_backtrace())
            end
        end
    end

    return changed
end

"""
    flatten_insert!(inst::LLVM.Instruction)

Replace a multi-index `insertvalue` instruction with an equivalent chain of
single-index `extractvalue` / `insertvalue` instructions.

For `insertvalue %base, %val, i, j, k` the following is emitted in front of
`inst`:

    %t1 = extractvalue %base, i
    %t2 = extractvalue %t1, j
    %t3 = insertvalue %t2, %val, k
    %t4 = insertvalue %t1, %t3, j
    %r  = insertvalue %base, %t4, i

All uses of `inst` are then redirected to `%r` and `inst` is erased.

Throws an `ArgumentError` if `inst` has fewer than two indices; errors raised
by the LLVM builder calls propagate to the caller.
"""
function flatten_insert!(inst::LLVM.Instruction)
    composite = LLVM.operands(inst)[1]
    value = LLVM.operands(inst)[2]

    num_indices = LLVM.API.LLVMGetNumIndices(inst)
    idx_ptr = LLVM.API.LLVMGetIndices(inst)
    # Copy the indices: the wrapped memory is reached through `inst`, which is
    # erased at the end of this function.
    indices = copy(unsafe_wrap(Array, idx_ptr, num_indices))
    length(indices) > 1 ||
        throw(ArgumentError("expected an insertvalue with at least two indices"))

    builder = LLVM.IRBuilder()
    try
        LLVM.position!(builder, inst)

        # Walk down: parents[n] is the aggregate addressed by indices[n].
        # The vector is typed as `LLVM.Value` because the builder may return
        # values of different concrete types (e.g. folded constants).
        parents = LLVM.Value[composite]
        for idx in indices[1:(end - 1)]
            push!(parents, LLVM.extract_value!(builder, parents[end], idx))
        end

        # Walk back up: insert the new value into the innermost aggregate,
        # then re-insert each rebuilt aggregate into its parent.
        rebuilt = value
        for level in length(indices):-1:1
            rebuilt = LLVM.insert_value!(builder, parents[level], rebuilt, indices[level])
        end

        LLVM.replace_uses!(inst, rebuilt)
        LLVM.API.LLVMInstructionEraseFromParent(inst)
    finally
        # Release the builder even when one of the calls above throws.
        LLVM.dispose(builder)
    end

    return nothing
end


## compiler implementation (cache, configure, compile, and link)

Expand Down Expand Up @@ -68,7 +188,10 @@ end
supports_fp64 = oneL0.module_properties(device()).fp64flags & oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64

# TODO: emit printf format strings in constant memory
extensions = String["SPV_EXT_relaxed_printf_string_address_space"]
extensions = String[
"SPV_EXT_relaxed_printf_string_address_space",
"SPV_EXT_shader_atomic_float_add"
]

# create GPUCompiler objects
target = SPIRVCompilerTarget(; extensions, supports_fp16, supports_fp64, kwargs...)
Expand Down
45 changes: 45 additions & 0 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,48 @@ using oneAPI
mask = oneArray(Bool[true, false, true, false, false, true])
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
end

@testset "CartesianIndices with mapreduce" begin
# Regression test: mapreduce over an array zipped with CartesianIndices, using
# a reduction that operates on tuples.
# Previously failed due to SPIR-V codegen issues with nested insertvalue instructions
# when combining tuples of (bool, CartesianIndex) in reduction operations.
# The fix involved properly handling nested struct insertions in SPIR-V codegen.

# Each case maps (value, index) pairs through `tuple` and reduces them with a
# binary operator that takes and returns such tuples.

# 2D case: sum the values; the index slot is carried along but not asserted
x = oneArray(ones(Int, 2, 2))
indices = CartesianIndices((2, 2))

# Map to tuple of (value, index), then reduce by summing the values
result = mapreduce(tuple, (t1, t2) -> (t1[1] + t2[1], t1[2]), x, indices;
init = (0, CartesianIndex(0, 0)))
@test result[1] == 4 # sum of four 1s

# 1D case: same reduction over a vector with one-dimensional indices
y = oneArray(ones(Int, 4))
indices_1d = CartesianIndices((4,))
result_1d = mapreduce(tuple, (t1, t2) -> (t1[1] + t2[1], t1[2]), y, indices_1d;
init = (0, CartesianIndex(0,)))
@test result_1d[1] == 4

# Test with boolean array and index comparison (closer to original failure case)
# This pattern is similar to what findfirst would use internally
z = oneArray([false, true, false, true])
indices_z = CartesianIndices((4,))
result_z = mapreduce(tuple,
(t1, t2) -> begin
(found1, idx1), (found2, idx2) = t1, t2
# Keep the left operand whenever it has found a value; otherwise take the
# right operand. No index comparison is performed here.
# NOTE(review): this yields the smallest index only if the reduction always
# passes the lower-index operand on the left -- confirm.
if found1
return (found1, idx1)
else
return (found2, idx2)
end
end,
z, indices_z;
init = (false, CartesianIndex(0,)))
@test result_z[1] == true # Found a true value
@test result_z[2] == CartesianIndex(2,) # First true is at index 2
end