Skip to content

Commit f81c38a

Browse files
committed
fix: correct launch configuration
1 parent b34a9a0 commit f81c38a

File tree

6 files changed

+52
-17
lines changed

6 files changed

+52
-17
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ struct TritonMetadata{CK,MD,DP}
5959
max_num_threads::Int
6060
end
6161

62-
normalize_grid(grid_fn, metadata) = normalize_grid(grid_fn(metadata), metadata)
63-
normalize_grid(grid::Integer, metadata) = normalize_grid((grid,), metadata)
64-
function normalize_grid(grid::Dims{N}, metadata) where {N}
62+
canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata)
63+
canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata)
64+
function canonicalize_grid(grid::Dims{N}, metadata) where {N}
6565
@assert N <= 3
6666
@assert all(grid .> 0)
6767
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -82,6 +82,7 @@ function overlayed_pycall_with_triton(
8282
num_ctas::Integer=1,
8383
hints=nothing,
8484
)
85+
@assert num_ctas == 1 "TODO: num_ctas > 1 not supported"
8586
triton = tritonptr[]
8687

8788
mapped = map(signature_string, args)
@@ -163,7 +164,7 @@ function overlayed_pycall_with_triton(
163164
Int(n_max_threads[]),
164165
)
165166

166-
grid = normalize_grid(grid, metadata)
167+
grid = canonicalize_grid(grid, metadata)
167168

168169
return @opcall triton_call(
169170
pyconvert(String, compiled_kernel.asm["source"]),
@@ -177,5 +178,7 @@ function overlayed_pycall_with_triton(
177178
block_z=@opcall(constant(1)),
178179
num_ctas,
179180
num_warps,
181+
threads_per_warp=device_properties.warp_size,
182+
enable_source_remat=false,
180183
)
181184
end

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ function CompileOptions(;
230230
:just_batch,
231231
:none,
232232
:no_triton,
233+
:before_triton_lowering,
233234
]
234235
end
235236

src/Compiler.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,9 +1330,7 @@ function triton_optimization_passes(device_properties)
13301330
"cse",
13311331
"symbol-dce",
13321332
"triton-loop-unroll",
1333-
"preserve-triton-warps-ctas{save=true restore=false}",
1334-
"convert-triton-to-tritongpu{target=cuda:$(major_version)$(minor_version)}",
1335-
"preserve-triton-warps-ctas{save=false restore=true}",
1333+
"convert-triton-to-triton-gpu-preserving-module-attributes{target=cuda:$(major_version)$(minor_version)}",
13361334
"tritongpu-coalesce",
13371335
"tritongpu-F32DotTC",
13381336
"triton-nvidia-gpu-plan-cta",
@@ -1933,6 +1931,31 @@ function compile_mlir!(
19331931
),
19341932
"no_triton",
19351933
)
1934+
elseif compile_options.optimization_passes === :before_triton_lowering
1935+
run_pass_pipeline!(
1936+
mod,
1937+
join(
1938+
if compile_options.raise_first
1939+
["mark-func-memory-effects", opt_passes]
1940+
else
1941+
[
1942+
"mark-func-memory-effects",
1943+
opt_passes,
1944+
"enzyme-batch",
1945+
opt_passes2,
1946+
enzyme_pass,
1947+
opt_passes_with_triton,
1948+
"canonicalize",
1949+
"remove-unnecessary-enzyme-ops",
1950+
"enzyme-simplify-math",
1951+
legalize_chlo_to_stablehlo...,
1952+
opt_passes2,
1953+
]
1954+
end,
1955+
',',
1956+
),
1957+
"before_triton_lowering",
1958+
)
19361959
elseif compile_options.optimization_passes === :before_kernel
19371960
run_pass_pipeline!(
19381961
mod,

src/Ops.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,14 +1852,22 @@ function triton_call(
18521852
block_z::TracedRNumber{<:Integer},
18531853
num_ctas::Integer=1,
18541854
num_warps::Integer=4,
1855+
threads_per_warp::Integer=32,
1856+
enable_source_remat::Bool=false,
18551857
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18561858
)
18571859
_, symref, modop = _extract_function(
18581860
mlir_code; func_name, func_op_kind="tt.func", location
18591861
)
18601862

1861-
MLIR.IR.attr!(modop, "ttg.num-wraps", MLIR.IR.Attribute(Int32(num_warps)))
1862-
MLIR.IR.attr!(modop, "ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas)))
1863+
MLIR.IR.attr!(modop, "enzymexla.ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps)))
1864+
MLIR.IR.attr!(modop, "enzymexla.ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas)))
1865+
MLIR.IR.attr!(
1866+
modop, "enzymexla.ttg.threads-per-warp", MLIR.IR.Attribute(Int32(threads_per_warp))
1867+
)
1868+
if enable_source_remat
1869+
MLIR.IR.attr!(modop, "enzymexla.ttg.enable-source-remat", MLIR.IR.UnitAttribute())
1870+
end
18631871

18641872
result_types = MLIR.IR.Type[]
18651873
output_operand_aliases = MLIR.IR.Attribute[]

test/integration/triton/layer_norm.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ end
5353

5454
@testset "fused_layer_norm" begin
5555
if RunningOnCUDA
56-
x_ra = Reactant.to_rarray(rand(Float32, 256, 2056))
57-
weight_ra = Reactant.to_rarray(rand(Float32, 256))
58-
bias_ra = Reactant.to_rarray(rand(Float32, 256))
56+
x_ra = Reactant.to_rarray(rand(Float32, 257, 2056))
57+
weight_ra = Reactant.to_rarray(rand(Float32, 257))
58+
bias_ra = Reactant.to_rarray(rand(Float32, 257))
5959

6060
y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, false)
6161
y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra)
6262
y_ra3, mean_ra3, rstd_ra3 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, true)
6363

64-
@test_broken y_ra1 ≈ y_ra2
65-
@test_broken y_ra2 ≈ y_ra3
66-
@test_broken mean_ra1 ≈ mean_ra2
64+
@test y_ra1 ≈ y_ra2
65+
@test y_ra2 ≈ y_ra3
66+
@test mean_ra1 ≈ mean_ra2
6767
@test mean_ra2 ≈ mean_ra3
68-
@test_broken rstd_ra1 ≈ rstd_ra2
68+
@test rstd_ra1 ≈ rstd_ra2
6969
@test rstd_ra2 ≈ rstd_ra3
7070
end
7171
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6666
@safetestset "low_memory_dropout" include(
6767
"integration/triton/low_memory_dropout.jl"
6868
)
69-
@safetestset "layer norm" include("integration/triton/layer_norm.jl") # XXX
69+
@safetestset "layer norm" include("integration/triton/layer_norm.jl")
7070
# @safetestset "attention" include("integration/triton/attention.jl")
7171
@safetestset "libdevice" include("integration/triton/libdevice.jl")
7272
# @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl")

0 commit comments

Comments
 (0)