Skip to content

Commit 3bbc37b

Browse files
committed
fix: cluster dims
1 parent 548c28d commit 3bbc37b

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ function overlayed_pycall_with_triton(
166166

167167
grid = canonicalize_grid(grid, metadata)
168168

169+
# TODO: actual cluster_x/y/z
170+
169171
return @opcall triton_call(
170172
pyconvert(String, compiled_kernel.asm["source"]),
171173
filter(x -> x isa Reactant.TracedType, args)...;
@@ -176,6 +178,9 @@ function overlayed_pycall_with_triton(
176178
block_x=@opcall(constant(num_warps * device_properties.warp_size)),
177179
block_y=@opcall(constant(1)),
178180
block_z=@opcall(constant(1)),
181+
cluster_x=@opcall(constant(1)),
182+
cluster_y=@opcall(constant(1)),
183+
cluster_z=@opcall(constant(1)),
179184
num_ctas,
180185
num_warps,
181186
threads_per_warp=device_properties.warp_size,

src/Ops.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,6 +1850,9 @@ function triton_call(
18501850
block_x::TracedRNumber{<:Integer},
18511851
block_y::TracedRNumber{<:Integer},
18521852
block_z::TracedRNumber{<:Integer},
1853+
cluster_x::TracedRNumber{<:Integer},
1854+
cluster_y::TracedRNumber{<:Integer},
1855+
cluster_z::TracedRNumber{<:Integer},
18531856
num_ctas::Integer=1,
18541857
num_warps::Integer=4,
18551858
threads_per_warp::Integer=32,
@@ -1894,6 +1897,9 @@ function triton_call(
18941897
block_x.mlir_data,
18951898
block_y.mlir_data,
18961899
block_z.mlir_data,
1900+
cluster_x.mlir_data,
1901+
cluster_y.mlir_data,
1902+
cluster_z.mlir_data,
18971903
[Reactant.TracedUtils.get_mlir_data(a) for a in args];
18981904
fn=symref,
18991905
result_0=result_types,

src/mlir/Dialects/TritonExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ function call(
2020
blockx::Value,
2121
blocky::Value,
2222
blockz::Value,
23+
clusterx::Value,
24+
clustery::Value,
25+
clusterz::Value,
2326
inputs::Vector{Value};
2427
result_0::Vector{IR.Type},
2528
fn,
@@ -33,7 +36,9 @@ function call(
3336
location=Location(),
3437
)
3538
op_ty_results = IR.Type[result_0...,]
36-
operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...]
39+
operands = Value[
40+
gridx, gridy, gridz, blockx, blocky, blockz, clusterx, clustery, clusterz, inputs...
41+
]
3742
owned_regions = Region[]
3843
successors = Block[]
3944
attributes = NamedAttribute[namedattribute("fn", fn),]

0 commit comments

Comments
 (0)