Skip to content

Commit 6e40359

Browse files
committed
fix irrelevant ci errors on hopper
1 parent 6785b6e commit 6e40359

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

flashinfer/triton/kernels/cascade.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,9 @@ def variable_length_merge_states_kernel(
148148
for head_idx in tl.range(bdy):
149149
o, m, d = 0.0, -5e4, 1.0
150150
for iter in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)):
151-
s = tl.load(s_ptr + iter * num_heads + head_idx)
152-
v = tl.load(v_ptr + (iter * num_heads + head_idx) * head_dim + tx)
151+
iter_i64 = iter.to(tl.int64)
152+
s = tl.load(s_ptr + iter_i64 * num_heads + head_idx)
153+
v = tl.load(v_ptr + (iter_i64 * num_heads + head_idx) * head_dim + tx)
153154
o, m, d = state_merge(o, m, d, v, s, 1)
154155
o, m, d = state_normalize(o, m, d)
155156
tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)

0 commit comments

Comments
 (0)