@@ -102,20 +102,18 @@ def _helion_mul_relu_block_backward_kernel(x, y, dz, dx, dy, _BLOCK_SIZE_0: tl.c
102102 # src[test_control_flow.py:N]: relu_grad = torch.where(relu_mask, 1, 0)
103103 v_3 = tl.full([], 0, tl.int64)
104104 v_4 = tl.full([], 1, tl.int64)
105- v_5 = v_4[None, None]
106- v_6 = v_3[None, None]
107- v_7 = tl.where(v_2, v_5, v_6)
105+ v_5 = tl.where(v_2, v_4, v_3)
108106 # src[test_control_flow.py:N]: dx[tile_i, tile_j] = dz_tile * relu_grad * y_tile[:, None]
109- v_8 = tl.cast(v_7 , tl.float32)
110- v_9 = dz_tile * v_8
107+ v_6 = tl.cast(v_5 , tl.float32)
108+ v_7 = dz_tile * v_6
111109 subscript_1 = y_tile[:, None]
112- v_10 = v_9 * subscript_1
113- tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_10 , None)
110+ v_8 = v_7 * subscript_1
111+ tl.store(dx + (indices_0[:, None] * 1024 + indices_1[None, :] * 1), v_8 , None)
114112 # src[test_control_flow.py:N]: local_dy_grad = torch.sum(dz_tile * relu_grad * x_tile, dim=1)
115- v_11 = tl.cast(v_7 , tl.float32)
116- v_12 = dz_tile * v_11
117- v_13 = v_12 * x_tile
118- local_dy_grad = tl.cast(tl.sum(v_13 , 1), tl.float32)
113+ v_9 = tl.cast(v_5 , tl.float32)
114+ v_10 = dz_tile * v_9
115+ v_11 = v_10 * x_tile
116+ local_dy_grad = tl.cast(tl.sum(v_11 , 1), tl.float32)
119117 # src[test_control_flow.py:N]: hl.atomic_add(dy, [tile_i], local_dy_grad)
120118 tl.atomic_add(dy + indices_0 * 1, local_dy_grad, mask=None, sem='relaxed')
121119
0 commit comments