Skip to content

Commit ea925e1

Browse files
[TEST] Fix triton_kernels failures after 318fa9c
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
1 parent 64cc693 commit ea925e1

File tree

1 file changed

+2
-3
lines changed
  • python/triton_kernels/triton_kernels/matmul_ogs_details

1 file changed

+2
-3
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@ def make_default_opt_flags_intel(
5555
k,
5656
routing_data,
5757
can_use_persistent_tma,
58-
can_use_fused_scatter,
58+
can_use_split_k,
5959
enforce_bitwise_invariance,
6060
epilogue_effective_itemsize,
6161
x_transpose,
6262
has_y_acc_in,
6363
constraints,
6464
):
65-
constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "max_allowable_mn"]
65+
constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "max_allowable_mn"]
6666
assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
6767
# tokens per expert
6868
if routing_data is None:
@@ -111,7 +111,6 @@ def make_default_opt_flags_intel(
111111
block_k=block_k,
112112
num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
113113
num_stages=constraints.get("num_stages", 2),
114-
fused_scatter=constraints.get('fused_scatter', False),
115114
group_m=group_m,
116115
xcd_swizzle=xcd_swizzle,
117116
w_cache_modifier=None,

0 commit comments

Comments
 (0)