Skip to content

Commit 04cbf60

Browse files
committed
feat: use new device properties [skip ci]
1 parent 827fbf4 commit 04cbf60

File tree

4 files changed

+55
-20
lines changed

4 files changed

+55
-20
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "f2072aa2031eb6a1d5d1972d3a95340fb67c9480"
7+
ENZYMEXLA_COMMIT = "8221b6147f497592205e6f558b1609e2964f3330"
88

99
ENZYMEXLA_SHA256 = ""
1010

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,16 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
4747
return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
4848
end
4949

50-
function normalize_grid_and_blocks(grid_fn, metadata)
51-
return normalize_grid_and_blocks(grid_fn(metadata), metadata)
50+
function normalize_grid_and_blocks(grid_fn, metadata, device_properties)
51+
return normalize_grid_and_blocks(
52+
grid_fn(metadata, device_properties), metadata, device_properties
53+
)
5254
end
5355

54-
function normalize_grid_and_blocks(grid::Integer, metadata)
55-
return normalize_grid_and_blocks((grid,), metadata)
56+
function normalize_grid_and_blocks(grid::Integer, metadata, device_properties)
57+
return normalize_grid_and_blocks((grid,), metadata, device_properties)
5658
end
57-
function normalize_grid_and_blocks(grid::Dims{N}, metadata) where {N}
59+
function normalize_grid_and_blocks(grid::Dims{N}, metadata, device_properties) where {N}
5860
@assert N <= 3
5961
@assert all(grid .> 0)
6062
return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -71,8 +73,9 @@ function overlayed_pycall_with_triton(
7173
args...;
7274
grid,
7375
blocks,
74-
num_warps::Integer=1,
76+
num_warps::Integer=4,
7577
num_stages::Integer=3,
78+
num_ctas::Integer=1,
7679
hints=nothing,
7780
)
7881
triton = tritonptr[]
@@ -105,16 +108,23 @@ function overlayed_pycall_with_triton(
105108
fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
106109
)
107110

111+
# TODO: pass the device/client here from `compile`
112+
client = Reactant.XLA.default_backend()
113+
@assert Reactant.XLA.platform_name(client) == "cuda"
114+
device = Reactant.XLA.default_device(client)
115+
device_properties = Reactant.XLA.device_properties(device)
116+
108117
target = triton.backends.compiler.GPUTarget(
109-
"cuda",
110-
parse(Int, Reactant.Compiler.cubinChip[][4:end]),
111-
Reactant.Compiler.cuWarpSize[],
118+
Reactant.XLA.platform_name(client),
119+
parse(Int, "$(device_properties.major)$(device_properties.minor)"),
120+
device_properties.warp_size,
112121
)
113122
backend = triton.compiler.make_backend(target)
114123
options = backend.parse_options(
115124
pydict(
116125
"num_warps" => num_warps,
117126
"num_stages" => num_stages,
127+
"num_ctas" => num_ctas,
118128
"extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
119129
),
120130
)
@@ -123,8 +133,8 @@ function overlayed_pycall_with_triton(
123133
# we are compiling here + lowering again inside enzymejax
124134
ccinfo = triton.compile(src; target=target, options=options.__dict__)
125135

126-
grid = normalize_grid_and_blocks(grid, ccinfo.metadata)
127-
blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata)
136+
grid = normalize_grid_and_blocks(grid, ccinfo.metadata, device_properties)
137+
blocks = normalize_grid_and_blocks(blocks, ccinfo.metadata, device_properties)
128138

129139
return @opcall triton_call(
130140
pyconvert(String, ccinfo.asm["source"]),
@@ -136,5 +146,8 @@ function overlayed_pycall_with_triton(
136146
block_x=@opcall(constant(blocks[1])),
137147
block_y=@opcall(constant(blocks[2])),
138148
block_z=@opcall(constant(blocks[3])),
149+
# The following are written to module attributes and restored later on
150+
num_ctas,
151+
num_warps,
139152
)
140153
end

src/Compiler.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,7 @@ function optimization_passes(
703703
max_constant_threshold::Int=1024,
704704
backend::String="gpu",
705705
enable_triton_passes::Bool=false,
706+
device_properties::Union{Nothing,XLA.DeviceProperties}=nothing,
706707
)
707708
transform_passes_list = [
708709
"patterns=compare_op_canon<16>",
@@ -1302,14 +1303,20 @@ function optimization_passes(
13021303
end
13031304
push!(passes, func_passes)
13041305
if enable_triton_passes && backend == "cuda"
1305-
push!(passes, triton_optimization_passes())
1306+
push!(passes, triton_optimization_passes(device_properties))
13061307
end
13071308
return join(passes, ',')
13081309
end
13091310

13101311
# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc
13111312
# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes
1312-
function triton_optimization_passes()
1313+
function triton_optimization_passes(device_properties)
1314+
@assert device_properties !== nothing "Device properties must be provided to run \
1315+
triton passes. This might happen if you are \
1316+
compiling a triton kernel for non-cuda backend."
1317+
major_version = device_properties.major
1318+
minor_version = device_properties.minor
1319+
13131320
all_passes = join(
13141321
[
13151322
"canonicalize",
@@ -1320,7 +1327,9 @@ function triton_optimization_passes()
13201327
"cse",
13211328
"symbol-dce",
13221329
"triton-loop-unroll",
1323-
"convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
1330+
"preserve-triton-warps-ctas{save=true restore=false}",
1331+
"convert-triton-to-tritongpu{target=cuda:$(major_version)$(minor_version)}",
1332+
"preserve-triton-warps-ctas{save=false restore=true}",
13241333
"tritongpu-coalesce",
13251334
"tritongpu-F32DotTC",
13261335
"triton-nvidia-gpu-plan-cta",
@@ -1740,6 +1749,9 @@ function compile_mlir!(
17401749

17411750
toolkit = XLA.CUDA_DATA_DIR[]
17421751

1752+
default_device = XLA.default_device(client)
1753+
device_properties = XLA.device_properties(default_device)
1754+
17431755
if backend == "cpu" || backend == "tpu"
17441756
kern = "lower-kernel{backend=cpu},canonicalize"
17451757
if backend == "tpu"
@@ -1754,9 +1766,7 @@ function compile_mlir!(
17541766
"lower-kernel,canonicalize"
17551767
end
17561768

1757-
device_properties = XLA.device_properties(XLA.default_device(client))
17581769
cubinChip = "sm_$(device_properties.major)$(device_properties.minor)"
1759-
17601770
if DEBUG_KERNEL[]
17611771
curesulthandler = dlsym(
17621772
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
@@ -1787,6 +1797,7 @@ function compile_mlir!(
17871797
lower_comms,
17881798
backend,
17891799
enable_triton_passes=false,
1800+
device_properties,
17901801
)
17911802
opt_passes2 = optimization_passes(
17921803
compile_options;
@@ -1795,6 +1806,7 @@ function compile_mlir!(
17951806
lower_comms,
17961807
backend,
17971808
enable_triton_passes=false,
1809+
device_properties,
17981810
)
17991811
opt_passes_with_triton = optimization_passes(
18001812
compile_options;
@@ -1803,6 +1815,7 @@ function compile_mlir!(
18031815
lower_comms,
18041816
backend,
18051817
enable_triton_passes=true,
1818+
device_properties,
18061819
)
18071820

18081821
raise_passes = if raise isa String
@@ -1824,6 +1837,7 @@ function compile_mlir!(
18241837
recognize_comms,
18251838
lower_comms,
18261839
backend,
1840+
device_properties,
18271841
)
18281842
result = result * "," * opt_passes_dus_to_concat
18291843
end
@@ -2148,6 +2162,7 @@ function compile_mlir!(
21482162
recognize_comms,
21492163
lower_comms,
21502164
backend,
2165+
device_properties,
21512166
),
21522167
"post_op_transpose_reshape",
21532168
)

src/Ops.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,7 +1837,7 @@ function _extract_function(
18371837
error("hlo_call: could not find function $func_name in the provided module")
18381838
end
18391839

1840-
return fn, symref
1840+
return fn, symref, moduleop
18411841
end
18421842

18431843
function triton_call(
@@ -1850,9 +1850,16 @@ function triton_call(
18501850
block_x::TracedRNumber{<:Integer},
18511851
block_y::TracedRNumber{<:Integer},
18521852
block_z::TracedRNumber{<:Integer},
1853+
num_ctas::Integer=1,
1854+
num_warps::Integer=4,
18531855
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
18541856
)
1855-
_, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
1857+
_, symref, modop = _extract_function(
1858+
mlir_code; func_name, func_op_kind="tt.func", location
1859+
)
1860+
1861+
MLIR.IR.attr!(modop, "ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps)))
1862+
MLIR.IR.attr!(modop, "ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas)))
18561863

18571864
result_types = MLIR.IR.Type[]
18581865
output_operand_aliases = MLIR.IR.Attribute[]
@@ -1929,7 +1936,7 @@ julia> Reactant.@jit(
19291936
func_name="main",
19301937
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
19311938
)
1932-
fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location)
1939+
fn, symref, _ = _extract_function(code; func_name, func_op_kind="func.func", location)
19331940

19341941
ftype_attr = MLIR.IR.attr(fn, "function_type")
19351942
ftype = MLIR.IR.Type(ftype_attr)

0 commit comments

Comments
 (0)