
Commit 2c54c77 (parent: 8dddadc)

feat: use new device properties [skip ci]

3 files changed (+54 / -19 lines)


ext/ReactantPythonCallExt/pycall.jl (+24 -11)
```diff
@@ -47,14 +47,16 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
     return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
 end
 
-function normalize_grid_and_blocks(grid_fn, metadata)
-    return normalize_grid_and_blocks(grid_fn(metadata), metadata)
+function normalize_grid_and_blocks(grid_fn, metadata, device_properties)
+    return normalize_grid_and_blocks(
+        grid_fn(metadata, device_properties), metadata, device_properties
+    )
 end
 
-function normalize_grid_and_blocks(grid::Integer, metadata)
-    return normalize_grid_and_blocks((grid,), metadata)
+function normalize_grid_and_blocks(grid::Integer, metadata, device_properties)
+    return normalize_grid_and_blocks((grid,), metadata, device_properties)
 end
-function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N}
+function normalize_grid_and_blocks(grid::Dims{N}, metadata, device_properties) where {N}
     @assert N <= 3
     @assert all(grid .> 0)
     return (grid..., ntuple(_ -> 1, 3 - N)...)
```
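With the extra argument, a `grid` or `blocks` callback now receives the queried device properties alongside the compiled-kernel metadata, so launch dimensions can be derived from hardware limits. A minimal sketch of such a callback, assuming a hypothetical elementwise kernel over 4096 items (only `warp_size` is a field shown in this diff; the 8-warps-per-block choice is illustrative):

```julia
# Hypothetical grid callback: takes (metadata, device_properties), returns an
# Integer, which normalize_grid_and_blocks then pads to a 3-tuple like (16, 1, 1).
grid_fn = (metadata, device_properties) -> begin
    n = 4096                                   # illustrative problem size
    threads = 8 * device_properties.warp_size  # assumed: 8 warps per block
    return cld(n, threads)
end
```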
```diff
@@ -71,8 +73,9 @@ function overlayed_pycall_with_triton(
     args...;
     grid,
     blocks,
-    num_warps::Integer=1,
+    num_warps::Integer=4,
     num_stages::Integer=3,
+    num_ctas::Integer=1,
     hints=nothing,
 )
     triton = tritonptr[]
```
```diff
@@ -105,16 +108,23 @@ function overlayed_pycall_with_triton(
         fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
     )
 
+    # TODO: pass the device/client here from `compile`
+    client = Reactant.XLA.default_backend()
+    @assert Reactant.XLA.platform_name(client) == "cuda"
+    device = Reactant.XLA.default_device(client)
+    device_properties = Reactant.XLA.device_properties(device)
+
     target = triton.backends.compiler.GPUTarget(
-        "cuda",
-        parse(Int, Reactant.Compiler.cubinChip[][4:end]),
-        Reactant.Compiler.cuWarpSize[],
+        Reactant.XLA.platform_name(client),
+        parse(Int, "$(device_properties.major)$(device_properties.minor)"),
+        device_properties.warp_size,
     )
     backend = triton.compiler.make_backend(target)
     options = backend.parse_options(
         pydict(
             "num_warps" => num_warps,
             "num_stages" => num_stages,
+            "num_ctas" => num_ctas,
             "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
         ),
     )
```
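The target is now built from the queried device instead of the `cubinChip[]`/`cuWarpSize[]` globals. For a device with compute capability 8.0 (an assumed example), the values fed to `GPUTarget` work out as in this small sketch:

```julia
# Assumed example: device_properties.major == 8, device_properties.minor == 0.
major, minor = 8, 0
arch = parse(Int, "$(major)$(minor)")  # 80, the capability passed to GPUTarget
sm = "sm_$(major)$(minor)"             # "sm_80", the form Compiler.jl derives below
```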
```diff
@@ -123,8 +133,8 @@ function overlayed_pycall_with_triton(
     # we are compiling here + lowering again inside enzymejax
     ccinfo = triton.compile(src; target=target, options=options.__dict__)
 
-    grid = normalize_grid_and_blocks(grid, ccinfo.metadata)
-    blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata)
+    grid = normalize_grid_and_blocks(grid, ccinfo.metadata, device_properties)
+    blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata, device_properties)
 
     return @opcall triton_call(
         pyconvert(String, ccinfo.asm["source"]),
```
```diff
@@ -136,5 +146,8 @@ function overlayed_pycall_with_triton(
         block_x=@opcall(constant(blocks[1])),
         block_y=@opcall(constant(blocks[2])),
         block_z=@opcall(constant(blocks[3])),
+        # The following are written to module attributes and restored later on
+        num_ctas,
+        num_warps,
     )
 end
```
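End to end, a launch through this overlay now looks roughly like the following sketch (hypothetical: `my_kernel`, `x`, and `y` are placeholders, and in practice the overlay is reached implicitly when the Python kernel is invoked during tracing):

```julia
overlayed_pycall_with_triton(
    my_kernel, x, y;
    grid=(1024,),  # normalized to (1024, 1, 1)
    blocks=256,    # normalized to (256, 1, 1)
    num_warps=4,   # new default (previously 1)
    num_ctas=1,    # newly exposed; forwarded to Triton and the module attributes
)
```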

src/Compiler.jl (+20 -5)
```diff
@@ -703,6 +703,7 @@ function optimization_passes(
     max_constant_threshold::Int=1024,
     backend::String="gpu",
     enable_triton_passes::Bool=false,
+    device_properties::Union{Nothing,XLA.DeviceProperties}=nothing,
 )
     transform_passes_list = [
         "patterns=compare_op_canon<16>",
```
```diff
@@ -1305,14 +1306,20 @@ function optimization_passes(
     end
     push!(passes, func_passes)
     if enable_triton_passes && backend == "cuda"
-        push!(passes, triton_optimization_passes())
+        push!(passes, triton_optimization_passes(device_properties))
     end
     return join(passes, ',')
 end
 
 # https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
 # To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes
-function triton_optimization_passes()
+function triton_optimization_passes(device_properties)
+    @assert device_properties !== nothing "Device properties must be provided to run \
+                                           triton passes. This can happen if you are \
+                                           compiling a triton kernel for a non-CUDA backend."
+    major_version = device_properties.major
+    minor_version = device_properties.minor
+
     all_passes = join(
         [
             "canonicalize",
```
```diff
@@ -1323,7 +1330,9 @@ function triton_optimization_passes()
             "cse",
             "symbol-dce",
             "triton-loop-unroll",
-            "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            "preserve-triton-warps-ctas{save=true restore=false}",
+            "convert-triton-to-tritongpu{target=cuda:$(major_version)$(minor_version)}",
+            "preserve-triton-warps-ctas{save=false restore=true}",
             "tritongpu-coalesce",
             "tritongpu-F32DotTC",
             "triton-nvidia-gpu-plan-cta",
```
```diff
@@ -1743,6 +1752,9 @@ function compile_mlir!(
 
     toolkit = XLA.CUDA_DATA_DIR[]
 
+    default_device = XLA.default_device(client)
+    device_properties = XLA.device_properties(default_device)
+
     if backend == "cpu" || backend == "tpu"
         kern = "lower-kernel{backend=cpu},canonicalize"
         if backend == "tpu"
```
```diff
@@ -1757,9 +1769,7 @@ function compile_mlir!(
             "lower-kernel,canonicalize"
         end
 
-        device_properties = XLA.device_properties(XLA.default_device(client))
         cubinChip = "sm_$(device_properties.major)$(device_properties.minor)"
-
         if DEBUG_KERNEL[]
             curesulthandler = dlsym(
                 Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
```
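The device query that used to live inside the GPU branch is hoisted to the top of `compile_mlir!` and shared by every `optimization_passes` call below. A minimal standalone sketch of the same lookup, using only calls that appear in this diff:

```julia
client = Reactant.XLA.default_backend()
device = Reactant.XLA.default_device(client)
props = Reactant.XLA.device_properties(device)
cubin_chip = "sm_$(props.major)$(props.minor)"  # e.g. "sm_80" on an assumed A100
```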
```diff
@@ -1790,6 +1800,7 @@ function compile_mlir!(
         lower_comms,
         backend,
         enable_triton_passes=false,
+        device_properties,
     )
     opt_passes2 = optimization_passes(
         compile_options;
```
```diff
@@ -1798,6 +1809,7 @@ function compile_mlir!(
         lower_comms,
         backend,
         enable_triton_passes=false,
+        device_properties,
     )
     opt_passes_with_triton = optimization_passes(
         compile_options;
```
```diff
@@ -1806,6 +1818,7 @@ function compile_mlir!(
         lower_comms,
         backend,
         enable_triton_passes=true,
+        device_properties,
     )
 
     raise_passes = if raise isa String
```
```diff
@@ -1827,6 +1840,7 @@ function compile_mlir!(
             recognize_comms,
             lower_comms,
             backend,
+            device_properties,
         )
         result = result * "," * opt_passes_dus_to_concat
     end
```
```diff
@@ -2151,6 +2165,7 @@ function compile_mlir!(
             recognize_comms,
             lower_comms,
             backend,
+            device_properties,
         ),
         "post_op_transpose_reshape",
     )
```

src/Ops.jl (+10 -3)
```diff
@@ -1837,7 +1837,7 @@ function _extract_function(
         error("hlo_call: could not find function $func_name in the provided module")
     end
 
-    return fn, symref
+    return fn, symref, moduleop
 end
 
 function triton_call(
```
function triton_call(
@@ -1850,9 +1850,16 @@ function triton_call(
18501850
block_x::TracedRNumber{<:Integer},
18511851
block_y::TracedRNumber{<:Integer},
18521852
block_z::TracedRNumber{<:Integer},
1853+
num_ctas::Integer=1,
1854+
num_warps::Integer=4,
18531855
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18541856
)
1855-
_, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
1857+
_, symref, modop = _extract_function(
1858+
mlir_code; func_name, func_op_kind="tt.func", location
1859+
)
1860+
1861+
MLIR.IR.attr!(modop, "ttg.num-wraps", MLIR.IR.Attribute(Int32(num_warps)))
1862+
MLIR.IR.attr!(modop, "ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas)))
18561863

18571864
result_types = MLIR.IR.Type[]
18581865
output_operand_aliases = MLIR.IR.Attribute[]
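The attributes written here are what the `preserve-triton-warps-ctas` pass later saves and restores; they can be read back with the `MLIR.IR.attr` getter already used further down in this file. A one-line sketch (`modop` as returned by `_extract_function`):

```julia
num_warps_attr = MLIR.IR.attr(modop, "ttg.num-warps")  # Attribute wrapping an Int32
```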
```diff
@@ -1929,7 +1936,7 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location)
+    fn, symref, _ = _extract_function(code; func_name, func_op_kind="func.func", location)
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
```
