Skip to content

Commit dd71541

Browse files
committed
fix: correct launch configuration
1 parent 1381077 commit dd71541

File tree

8 files changed

+54
-20
lines changed

8 files changed

+54
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.252"
108+
Reactant_jll = "0.0.251"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
33
NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44
NSYNC_SHA256 = ""
55

6-
ENZYMEXLA_COMMIT = "8221b6147f497592205e6f558b1609e2964f3330"
7-
6+
ENZYMEXLA_COMMIT = "4d71da26119a84662cd6f5252a68a35ca1673eae"
87
ENZYMEXLA_SHA256 = ""
98

109
http_archive(

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ struct TritonMetadata{CK,MD,DP}
5959
max_num_threads::Int
6060
end
6161

62-
normalize_grid(grid_fn, metadata) = normalize_grid(grid_fn(metadata), metadata)
63-
normalize_grid(grid::Integer, metadata) = normalize_grid((grid,), metadata)
64-
function normalize_grid(grid::Dims{N}, metadata) where {N}
62+
canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata)
63+
canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata)
64+
function canonicalize_grid(grid::Dims{N}, metadata) where {N}
6565
@assert N <= 3
6666
@assert all(grid .> 0)
6767
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -82,6 +82,7 @@ function overlayed_pycall_with_triton(
8282
num_ctas::Integer=1,
8383
hints=nothing,
8484
)
85+
@assert num_ctas == 1 "TODO: num_ctas > 1 not supported"
8586
triton = tritonptr[]
8687

8788
mapped = map(signature_string, args)
@@ -163,7 +164,7 @@ function overlayed_pycall_with_triton(
163164
Int(n_max_threads[]),
164165
)
165166

166-
grid = normalize_grid(grid, metadata)
167+
grid = canonicalize_grid(grid, metadata)
167168

168169
return @opcall triton_call(
169170
pyconvert(String, compiled_kernel.asm["source"]),
@@ -177,5 +178,7 @@ function overlayed_pycall_with_triton(
177178
block_z=@opcall(constant(1)),
178179
num_ctas,
179180
num_warps,
181+
threads_per_warp=device_properties.warp_size,
182+
enable_source_remat=false,
180183
)
181184
end

src/CompileOptions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ function CompileOptions(;
230230
:just_batch,
231231
:none,
232232
:no_triton,
233+
:before_triton_lowering,
233234
]
234235
end
235236

src/Compiler.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,9 +1327,7 @@ function triton_optimization_passes(device_properties)
13271327
"cse",
13281328
"symbol-dce",
13291329
"triton-loop-unroll",
1330-
"preserve-triton-warps-ctas{save=true restore=false}",
1331-
"convert-triton-to-tritongpu{target=cuda:$(major_version)$(minor_version)}",
1332-
"preserve-triton-warps-ctas{save=false restore=true}",
1330+
"convert-triton-to-triton-gpu-preserving-module-attributes{target=cuda:$(major_version)$(minor_version)}",
13331331
"tritongpu-coalesce",
13341332
"tritongpu-F32DotTC",
13351333
"triton-nvidia-gpu-plan-cta",
@@ -1930,6 +1928,31 @@ function compile_mlir!(
19301928
),
19311929
"no_triton",
19321930
)
1931+
elseif compile_options.optimization_passes === :before_triton_lowering
1932+
run_pass_pipeline!(
1933+
mod,
1934+
join(
1935+
if compile_options.raise_first
1936+
["mark-func-memory-effects", opt_passes]
1937+
else
1938+
[
1939+
"mark-func-memory-effects",
1940+
opt_passes,
1941+
"enzyme-batch",
1942+
opt_passes2,
1943+
enzyme_pass,
1944+
opt_passes_with_triton,
1945+
"canonicalize",
1946+
"remove-unnecessary-enzyme-ops",
1947+
"enzyme-simplify-math",
1948+
legalize_chlo_to_stablehlo...,
1949+
opt_passes2,
1950+
]
1951+
end,
1952+
',',
1953+
),
1954+
"before_triton_lowering",
1955+
)
19331956
elseif compile_options.optimization_passes === :before_kernel
19341957
run_pass_pipeline!(
19351958
mod,

src/Ops.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,14 +1852,22 @@ function triton_call(
18521852
block_z::TracedRNumber{<:Integer},
18531853
num_ctas::Integer=1,
18541854
num_warps::Integer=4,
1855+
threads_per_warp::Integer=32,
1856+
enable_source_remat::Bool=false,
18551857
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18561858
)
18571859
_, symref, modop = _extract_function(
18581860
mlir_code; func_name, func_op_kind="tt.func", location
18591861
)
18601862

1861-
MLIR.IR.attr!(modop, "ttg.num-wraps", MLIR.IR.Attribute(Int32(num_warps)))
1862-
MLIR.IR.attr!(modop, "ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas)))
1863+
MLIR.IR.attr!(modop, "enzymexla.ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps)))
1864+
MLIR.IR.attr!(modop, "enzymexla.ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas)))
1865+
MLIR.IR.attr!(
1866+
modop, "enzymexla.ttg.threads-per-warp", MLIR.IR.Attribute(Int32(threads_per_warp))
1867+
)
1868+
if enable_source_remat
1869+
MLIR.IR.attr!(modop, "enzymexla.ttg.enable-source-remat", MLIR.IR.UnitAttribute())
1870+
end
18631871

18641872
result_types = MLIR.IR.Type[]
18651873
output_operand_aliases = MLIR.IR.Attribute[]

test/integration/triton/layer_norm.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ end
5353

5454
@testset "fused_layer_norm" begin
5555
if RunningOnCUDA
56-
x_ra = Reactant.to_rarray(rand(Float32, 256, 2056))
57-
weight_ra = Reactant.to_rarray(rand(Float32, 256))
58-
bias_ra = Reactant.to_rarray(rand(Float32, 256))
56+
x_ra = Reactant.to_rarray(rand(Float32, 257, 2056))
57+
weight_ra = Reactant.to_rarray(rand(Float32, 257))
58+
bias_ra = Reactant.to_rarray(rand(Float32, 257))
5959

6060
y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, false)
6161
y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra)
6262
y_ra3, mean_ra3, rstd_ra3 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, true)
6363

64-
@test_broken y_ra1 ≈ y_ra2
65-
@test_broken y_ra2 ≈ y_ra3
66-
@test_broken mean_ra1 ≈ mean_ra2
64+
@test y_ra1 ≈ y_ra2
65+
@test y_ra2 ≈ y_ra3
66+
@test mean_ra1 ≈ mean_ra2
6767
@test mean_ra2 ≈ mean_ra3
68-
@test_broken rstd_ra1 ≈ rstd_ra2
68+
@test rstd_ra1 ≈ rstd_ra2
6969
@test rstd_ra2 ≈ rstd_ra3
7070
end
7171
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6666
@safetestset "low_memory_dropout" include(
6767
"integration/triton/low_memory_dropout.jl"
6868
)
69-
@safetestset "layer norm" include("integration/triton/layer_norm.jl") # XXX
69+
@safetestset "layer norm" include("integration/triton/layer_norm.jl")
7070
# @safetestset "attention" include("integration/triton/attention.jl")
7171
@safetestset "libdevice" include("integration/triton/libdevice.jl")
7272
# @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl")

0 commit comments

Comments
 (0)