@@ -1353,12 +1353,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
     mask_2 = indices_2 < N
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
     floordiv = triton_helpers.div_floor_integer(K, 2)
-    for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
-        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_3 < floordiv
+    for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_0 < floordiv
         acc_copy = acc
         acc_copy_0 = acc_copy
-        b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
         v_0 = tl.full([], 4, tl.int8)
         v_1 = b_tile << v_0
         v_2 = tl.full([], 4, tl.int8)
@@ -1372,12 +1372,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
         expanded_0 = tl.expand_dims(v_6, 1)
         expanded_1 = tl.expand_dims(v_7, 1)
         stacked_result = tl.zeros_like(expanded_0)
-        mask_4 = broadcast_idx == 0
-        stacked_result = tl.where(mask_4, expanded_0, stacked_result)
-        mask_5 = broadcast_idx == 1
-        stacked_result = tl.where(mask_5, expanded_1, stacked_result)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
         b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
-        mul_5 = 2 * offset_3
+        mul_5 = 2 * offset_0
         iota = mul_5 + tl.arange(0, mul)
         a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
         dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
@@ -1406,7 +1406,6 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
     C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 32
-    _RDIM_SIZE_3 = triton.next_power_of_2(K)
     _BLOCK_SIZE_0 = 64
     _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return C
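
Note on the matmul_bf16_int4 hunks: the loop-index rename (offset_3/indices_3 to offset_0/indices_0, with the where-masks renumbered to match) and the dropped _RDIM_SIZE_3 = triton.next_power_of_2(K), which nothing in the kernel consumed, do not change behavior. The packed layout they operate on is worth spelling out: each int8 of B holds two signed 4-bit values, so the K loop walks K // 2 packed rows while A is indexed by an iota starting at 2 * offset_0; the kernel sign-extends the low nibble with the shift pair (b << 4) >> 4, takes the high nibble with b >> 4, and interleaves the two planes through the broadcast_idx selects before reshaping to [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2]. A minimal PyTorch sketch of that unpacking scheme (the helper name and the low-nibble-first order are illustrative assumptions, not taken from the generated kernel):

import torch

def unpack_int4_rows(b_packed: torch.Tensor) -> torch.Tensor:
    # b_packed: [K // 2, N] int8, two signed 4-bit values per byte.
    lo = (b_packed << 4) >> 4   # shift pair sign-extends the low nibble
    hi = b_packed >> 4          # arithmetic shift sign-extends the high nibble
    # Interleave: unpacked row 2*i is the low nibble of packed row i,
    # row 2*i + 1 the high nibble, mirroring the kernel's
    # expand_dims / where / reshape sequence.
    return torch.stack((lo, hi), dim=1).reshape(-1, b_packed.shape[1])

packed = torch.randint(-128, 128, (32, 16), dtype=torch.int8)
unpacked = unpack_int4_rows(packed)
assert unpacked.shape == (64, 16)
assert int(unpacked.min()) >= -8 and int(unpacked.max()) <= 7
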
@@ -3124,7 +3123,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
+def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_0, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
@@ -3162,7 +3161,7 @@ def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size
     v_18 = tl.cast(v_17, tl.float16)
     tl.store(grad_x + (indices_2[:, None] * grad_x_stride_0 + indices_3[None, :] * grad_x_stride_1), v_18, None)
     tile_id = offset_0 // _BLOCK_SIZE_0
-    tl.store(grad_weight + indices_3 * grad_weight_stride_1, grad_w_m, None)
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, None)
 
 def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, rsqrt: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -3187,7 +3186,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
     grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_2 = 64
-    _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
     return (grad_x, grad_weight.sum(0).to(weight.dtype))
 
 --- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dw)
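
The substantive fix in this journal is the grad_weight store: without the tile_id * grad_weight_stride_0 term, every program stored its partial grad_w_m at row 0, racing with the others and leaving the remaining rows of the new_empty scratch buffer uninitialized, so the host-side grad_weight.sum(0) reduced over garbage (the dw test above catches this). With the row offset, each of the cdiv(M, _BLOCK_SIZE_0) programs owns one row of partials and sum(0) performs the second stage of the reduction. A small PyTorch sketch of that two-stage pattern, with x_hat standing in for x * rsqrt (names are illustrative, not the generated code):

import torch

def grad_weight_two_stage(grad_out: torch.Tensor, x_hat: torch.Tensor, m_block: int) -> torch.Tensor:
    # Stage 1: one partial-gradient row per tile of m_block input rows,
    # standing in for each program's tl.store at row tile_id.
    M, N = x_hat.shape
    num_tiles = (M + m_block - 1) // m_block
    partials = torch.zeros(num_tiles, N, dtype=torch.float32)
    for tile_id in range(num_tiles):
        rows = slice(tile_id * m_block, min((tile_id + 1) * m_block, M))
        partials[tile_id] = (grad_out[rows].float() * x_hat[rows].float()).sum(0)
    # Stage 2: reduce the per-tile rows on the host, as rms_norm_bwd
    # does with grad_weight.sum(0).
    return partials.sum(0)

grad_out = torch.randn(100, 64, dtype=torch.float16)
x_hat = torch.randn(100, 64, dtype=torch.float16)
two_stage = grad_weight_two_stage(grad_out, x_hat, m_block=32)
reference = (grad_out.float() * x_hat.float()).sum(0)
torch.testing.assert_close(two_stage, reference)

Per-tile partial rows avoid cross-program atomics at the cost of a [num_tiles, N] float32 scratch buffer, which is exactly what the x.new_empty call above allocates.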