Commit ebbd2c4

Apply simplification to range indexing in order to reuse block size symbols (#809)
1 parent 70c4120 commit ebbd2c4

10 files changed: +146 -37 lines changed

helion/_compiler/compile_environment.py (4 additions & 1 deletion)

@@ -522,11 +522,14 @@ def mark_alternate_size(self, size: torch.SymInt | int | None) -> None:
         self.size = size
         if size is not None:
             env = CompileEnvironment.current()
+            # Refresh the var_to_val hint to match the resolved block size
+            hint = env.size_hint(size)
+            env.shape_env.var_to_val[self.symbol()] = sympy.Integer(hint)
             with contextlib.suppress(KeyError):
                 # update the size hint now that we know the size
                 env.config_spec.block_sizes.block_id_lookup(
                     self.block_id
-                ).update_hint(env.size_hint(size))
+                ).update_hint(hint)
         elif size is None or self.size is None or self.size != size:
             self.size = None
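To illustrate what the refreshed hint buys downstream code, here is a standalone sympy-level sketch of the idea; the dict stands in for shape_env.var_to_val, and the symbol name and numbers are hypothetical rather than Helion's actual internals:

import sympy

# A block-size symbol starts out with a stale default hint.
block_size_0 = sympy.Symbol("u_block_size_0", integer=True, positive=True)
var_to_val = {block_size_0: sympy.Integer(8192)}

# Once the alternate size is resolved, the stored hint is overwritten so that
# later size-hint lookups for this symbol return the resolved value.
resolved_size = 128
var_to_val[block_size_0] = sympy.Integer(resolved_size)
print(var_to_val[block_size_0])  # 128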

helion/_compiler/indexing_strategy.py (7 additions & 3 deletions)

@@ -527,6 +527,10 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+
+        def _is_size_one(size: int | torch.SymInt) -> bool:
+            return env.known_equal(size, 1)
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -544,7 +548,7 @@ def create(
                 index_values.append(f"({index_var}){expand}")
                 if (
                     mask := state.codegen.mask_var(origin.origin.block_id)
-                ) and fake_value.size(i) != 1:
+                ) and not _is_size_one(fake_value.size(i)):
                     mask_values.setdefault(f"({mask}){expand}")
                 output_idx += 1
             else:
@@ -576,7 +580,7 @@ def create(
                     index_values.append(f"{start}{expand}")
                 else:
                     # Full slice or slice without step
-                    if size != 1:
+                    if not _is_size_one(size):
                         rdim = env.allocate_reduction_dimension(size)
                         block_idx = rdim.block_id
                         index_var = state.codegen.index_var(block_idx)
@@ -620,7 +624,7 @@ def create(
         assert len(index_values) == fake_value.ndim
         index_expr = []
         for i, idx in enumerate(index_values):
-            if fake_value.size(i) != 1:
+            if not _is_size_one(fake_value.size(i)):
                 stride = state.device_function.tensor_stride(fake_value, i).name
                 index_expr.append(f"{idx} * {stride}")
         if not index_expr:
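The switch from "size != 1" to env.known_equal(size, 1) matters once sizes can be symbolic block-size expressions: a plain comparison cannot prove anything about a free symbol. Below is a standalone sympy analogue of that check, not the actual Helion/PyTorch implementation:

import sympy

def known_equal(size, value) -> bool:
    # True only when equality is statically provable; an expression with a
    # free symbol (an unknown block size) returns False instead of guessing.
    if isinstance(size, sympy.Expr):
        return sympy.simplify(size - value) == 0
    return size == value

s = sympy.Symbol("block_size_0", integer=True, positive=True)
print(known_equal(1, 1))      # True: concrete sizes compare directly
print(known_equal(s, 1))      # False: s == 1 cannot be proven
print(known_equal(2 * s, 1))  # False: still symbolic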

helion/_compiler/type_propagation.py (11 additions & 0 deletions)

@@ -37,6 +37,7 @@
 from .compile_environment import FixedBlockSizeSource
 from .compile_environment import LoopSpecBlockSizeSource
 from .compile_environment import warning
+from .device_function import contains_only_block_size_symbols
 from .host_function import HostFunction
 from .host_function import SymbolOrigin
 from .output_header import library_imports
@@ -473,6 +474,16 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
             if self.origin.is_device():
                 output_sizes.append(output_size)
             elif output_size != 1:
+                # If all symbols in output_size are block size symbols, we reuse them
+                if isinstance(output_size, torch.SymInt):
+                    expr = output_size._sympy_()
+                    if (
+                        isinstance(expr, sympy.Expr)
+                        and expr.free_symbols
+                        and contains_only_block_size_symbols(expr)
+                    ):
+                        output_sizes.append(output_size)
+                        continue
                 rdim = CompileEnvironment.current().allocate_reduction_dimension(
                     output_size
                 )
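The new branch keeps an output size whose expression is built entirely from block-size symbols, rather than allocating a fresh reduction dimension for it. A rough sketch of what such a check amounts to; the real contains_only_block_size_symbols lives in helion/_compiler/device_function.py, and this stand-in simply tests free symbols against a known set (the diff's caller additionally requires the set to be non-empty):

import sympy

block_size_symbols = {sympy.Symbol("block_size_0"), sympy.Symbol("block_size_1")}

def contains_only_block_size_symbols(expr: sympy.Expr) -> bool:
    # Stand-in check: every free symbol must already be a block-size symbol.
    return expr.free_symbols <= block_size_symbols

bs0 = sympy.Symbol("block_size_0")
u0 = sympy.Symbol("u0")  # e.g. an unbacked input size, not a block size
print(contains_only_block_size_symbols(2 * bs0))   # True: reuse the existing symbol
print(contains_only_block_size_symbols(bs0 + u0))  # False: fall back to a new rdim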

helion/_compiler/utils.py (3 additions & 1 deletion)

@@ -26,4 +26,6 @@ def compute_slice_size(
         step = slice_obj.step
         return (stop - start + step - 1) // step
     # Full slice or slice without step
-    return original_size
+    start = slice_obj.start if slice_obj.start is not None else 0
+    stop = slice_obj.stop if slice_obj.stop is not None else original_size
+    return stop - start
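With this change a bounded slice such as x[:, :k] reports k - start elements instead of the full original size. A self-contained sketch of the function after the edit, under the assumption that the unshown upper part of compute_slice_size sets start and stop the same way for the explicit-step branch:

def compute_slice_size(slice_obj: slice, original_size: int) -> int:
    start = slice_obj.start if slice_obj.start is not None else 0
    stop = slice_obj.stop if slice_obj.stop is not None else original_size
    if slice_obj.step is not None:
        step = slice_obj.step
        return (stop - start + step - 1) // step
    # Full slice or slice without step
    return stop - start

print(compute_slice_size(slice(None), 128))      # 128
print(compute_slice_size(slice(0, 96), 128))     # 96, not the original 128
print(compute_slice_size(slice(0, 96, 2), 128))  # 48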

test/test_constexpr.expected (4 additions & 4 deletions)

@@ -31,11 +31,11 @@ def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2
     offset_2 = pid_1 * _BLOCK_SIZE_2
     indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
-    for offset_3 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
-        indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32)
+    for offset_0 in tl.range(0, 16, loop_unroll_factor=4, num_stages=1, disallow_acc_multi_buffer=True, flatten=True):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
         acc_copy = acc
         acc_copy_0 = acc_copy
-        packed = tl.load(B + (indices_3[:, None] * 16 + indices_2[None, :] * 1), None)
+        packed = tl.load(B + (indices_0[:, None] * 16 + indices_2[None, :] * 1), None)
         v_0 = tl.full([], 4, tl.int8)
         v_1 = packed << v_0
         v_2 = tl.full([], 4, tl.int8)
@@ -54,7 +54,7 @@ def _helion_matmul_int4_block_expr(B, A, C, _NUM_SM: tl.constexpr, _BLOCK_SIZE_2
         mask_1 = broadcast_idx == 1
         stacked_result = tl.where(mask_1, expanded_1, stacked_result)
         unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
-        mul_5 = 2 * offset_3
+        mul_5 = 2 * offset_0
         iota = mul_5 + tl.arange(0, mul)
         a_tile = tl.load(A + (indices_1[:, None] * 32 + iota[None, :] * 1), None)
         dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
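For reference, the shift/stack/reshape sequence in this expected kernel is the int4 unpack: each int8 element of B holds two 4-bit values that get split and interleaved along K, doubling that dimension (hence the [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2] reshape). A PyTorch-level sketch of the same arithmetic; the packing order and signedness shown here are illustrative assumptions:

import torch

packed = torch.tensor([[0x21, 0x43]], dtype=torch.int8)  # one packed row of B
lo = (packed << 4) >> 4   # low nibble, sign-extended by the arithmetic shift
hi = packed >> 4          # high nibble
unpacked = torch.stack([lo, hi], dim=1).reshape(2 * packed.shape[0], packed.shape[1])
print(unpacked)  # rows [1, 3] and [2, 4]: the K dimension is doubled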

test/test_examples.expected (12 additions & 13 deletions)

@@ -1353,12 +1353,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
     mask_2 = indices_2 < N
     acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
     floordiv = triton_helpers.div_floor_integer(K, 2)
-    for offset_3 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
-        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_3 < floordiv
+    for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_0 < floordiv
         acc_copy = acc
         acc_copy_0 = acc_copy
-        b_tile = tl.load(B + (indices_3[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        b_tile = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
         v_0 = tl.full([], 4, tl.int8)
         v_1 = b_tile << v_0
         v_2 = tl.full([], 4, tl.int8)
@@ -1372,12 +1372,12 @@ def _helion_matmul_bf16_int4(B, A, C, A_stride_0, A_stride_1, B_stride_0, B_stri
         expanded_0 = tl.expand_dims(v_6, 1)
         expanded_1 = tl.expand_dims(v_7, 1)
         stacked_result = tl.zeros_like(expanded_0)
-        mask_4 = broadcast_idx == 0
-        stacked_result = tl.where(mask_4, expanded_0, stacked_result)
-        mask_5 = broadcast_idx == 1
-        stacked_result = tl.where(mask_5, expanded_1, stacked_result)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
         b_unpacked = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
-        mul_5 = 2 * offset_3
+        mul_5 = 2 * offset_0
         iota = mul_5 + tl.arange(0, mul)
         a_tile = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
         dot = tl.dot(tl.cast(a_tile, tl.bfloat16), tl.cast(b_unpacked, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
@@ -1406,7 +1406,6 @@ def matmul_bf16_int4(A: Tensor, B: Tensor, *, _launcher=_default_launcher):
     C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
     _BLOCK_SIZE_1 = 64
     _BLOCK_SIZE_2 = 32
-    _RDIM_SIZE_3 = triton.next_power_of_2(K)
     _BLOCK_SIZE_0 = 64
     _launcher(_helion_matmul_bf16_int4, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), B, A, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return C
@@ -3124,7 +3123,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher

 @triton.jit
-def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
+def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size_0, grad_out_stride_0, grad_out_stride_1, grad_weight_stride_0, grad_weight_stride_1, grad_x_stride_0, grad_x_stride_1, rsqrt_stride_0, weight_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_3 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
@@ -3162,7 +3161,7 @@ def _helion_rms_norm_bwd(x, grad_out, rsqrt, weight, grad_x, grad_weight, x_size
     v_18 = tl.cast(v_17, tl.float16)
     tl.store(grad_x + (indices_2[:, None] * grad_x_stride_0 + indices_3[None, :] * grad_x_stride_1), v_18, None)
     tile_id = offset_0 // _BLOCK_SIZE_0
-    tl.store(grad_weight + indices_3 * grad_weight_stride_1, grad_w_m, None)
+    tl.store(grad_weight + (tile_id * grad_weight_stride_0 + indices_3 * grad_weight_stride_1), grad_w_m, None)

 def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, rsqrt: torch.Tensor, *, _launcher=_default_launcher):
     """
@@ -3187,7 +3186,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
     grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_2 = 64
-    _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)
     return (grad_x, grad_weight.sum(0).to(weight.dtype))

 --- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dw)
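The rms_norm_bwd fix above makes each tile write its partial weight gradient into its own row of grad_weight (indexed via tile_id and grad_weight_stride_0) instead of every tile overwriting row 0; the host then reduces with grad_weight.sum(0). A PyTorch-level sketch of that row-per-tile accumulation pattern only; the actual per-tile formula involving rsqrt is elided:

import torch

M, N, block_m = 8, 4, 2
grad_out = torch.randn(M, N)
x = torch.randn(M, N)
num_tiles = (M + block_m - 1) // block_m
grad_weight = torch.zeros(num_tiles, N)
for tile_id in range(num_tiles):
    rows = slice(tile_id * block_m, (tile_id + 1) * block_m)
    # Each tile owns one row of grad_weight (a stand-in for the real grad_w_m).
    grad_weight[tile_id] = (grad_out[rows] * x[rows]).sum(dim=0)
dw = grad_weight.sum(0)  # host-side reduction over tiles
assert torch.allclose(dw, (grad_out * x).sum(dim=0))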

test/test_loops.expected (16 additions & 13 deletions)

@@ -982,20 +982,22 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher

 @triton.jit
-def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+def _helion_kernel_fixed_block_size(loss_sum, y_true, kl_loss, loss, loss_sum_stride_0, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
-    indices_4 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
-    full = tl.full([64, 64], 0.0, tl.float32)
-    tl.store(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), full, None)
-    for offset_2 in tl.range(0, 128, _BLOCK_SIZE_3):
-        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
-        y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
-        tl.store(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), y_true_val, None)
-        load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_2[None, :] * 1), None)
-        tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_2[None, :] * 1), load_1, mask=None, sem='relaxed')
-    load = tl.load(loss_sum + (indices_4[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), None)
+    indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    indices_6 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
+    mask_3 = indices_6 < _BLOCK_SIZE_0
+    full = tl.full([64, _BLOCK_SIZE_0], 0.0, tl.float32)
+    tl.store(loss_sum + (indices_5[:, None] * loss_sum_stride_0 + indices_6[None, :] * 1), full, mask_3[None, :])
+    for offset_4 in tl.range(0, 128, _BLOCK_SIZE_0):
+        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        y_true_val = tl.load(y_true + (indices_1[:, None] * 128 + indices_4[None, :] * 1), None)
+        tl.store(kl_loss + (indices_1[:, None] * 128 + indices_4[None, :] * 1), y_true_val, None)
+        load_1 = tl.load(kl_loss + (indices_1[:, None] * 128 + indices_4[None, :] * 1), None)
+        tl.atomic_add(loss_sum + (indices_1[:, None] * loss_sum_stride_0 + indices_4[None, :] * 1), load_1, mask=None, sem='relaxed')
+    load = tl.load(loss_sum + (indices_5[:, None] * loss_sum_stride_0 + indices_6[None, :] * 1), mask_3[None, :], other=0)
     sum_1 = tl.cast(tl.sum(load, 1), tl.float32)
     tl.store(loss + indices_1 * 1, sum_1, None)

@@ -1008,8 +1010,9 @@ def kernel_fixed_block_size(y_pred: torch.Tensor, y_true: torch.Tensor, *, _laun
     loss_sum = torch.zeros([BT_SIZE, block_size_n], dtype=torch.float32, device=y_pred.device)
     _BLOCK_SIZE_1 = 64
     _RDIM_SIZE_2 = 64
-    _BLOCK_SIZE_3 = 64
-    _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_0 = 128
+    _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_0)
+    _launcher(_helion_kernel_fixed_block_size, (triton.cdiv(64, _BLOCK_SIZE_1),), loss_sum, y_true, kl_loss, loss, loss_sum.stride(0), _BLOCK_SIZE_1, _RDIM_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return torch.sum(loss) / BT

 --- assertExpectedJournal(TestLoops.test_reorder_with_register_block_size)
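In the updated snapshot the reduction over the reused block size allocates _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_0) lanes and masks off the excess with indices_6 < _BLOCK_SIZE_0. A small sketch of that rounding-plus-mask arithmetic, using a hypothetical non-power-of-two block size (the snapshot's 128 needs no padding):

def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()

block_size_0 = 96                              # hypothetical block size
rdim_size = next_power_of_2(block_size_0)      # 128 lanes allocated
mask = [i < block_size_0 for i in range(rdim_size)]
print(rdim_size, sum(mask))                    # 128 96 -> 32 lanes masked off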

test/test_loops.py (1 addition & 2 deletions)

@@ -332,7 +332,6 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(spec.min_size, 32)
         self.assertEqual(spec.max_size, 256)

-    @skipIfRefEager("Triton codegen is disabled in ref eager mode")
     def test_register_block_size_codegen_size_hint(self):
         @helion.kernel(static_shapes=True)
         def kernel_fixed_block_size(
@@ -368,7 +367,7 @@ def kernel_fixed_block_size(
         code, result = code_and_output(kernel_fixed_block_size, args, block_sizes=[128])
         self.assertExpectedJournal(code)

-        expected = y_true[:, : y_pred.size(0)].sum() / y_pred.size(0)
+        expected = y_true[:, :].sum() / y_pred.size(0)
         torch.testing.assert_close(result, expected)

     def test_reorder_with_register_block_size(self):

test/test_matmul.expected (52 additions & 0 deletions)

@@ -193,6 +193,58 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
     _launcher(_helion_matmul, (triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out

+--- assertExpectedJournal(TestMatmul.test_matmul_packed_rhs)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_matmul_with_packed_b(A, B, C, A_stride_0, A_stride_1, B_stride_0, B_stride_1, C_stride_0, C_stride_1, M, N, K, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, mul_2: tl.constexpr):
+    num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_1)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < M
+    offset_2 = pid_1 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < N
+    acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+    floordiv = triton_helpers.div_floor_integer(K, 2)
+    for offset_0 in tl.range(0, floordiv.to(tl.int32), _BLOCK_SIZE_0):
+        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        mask_0 = indices_0 < floordiv
+        acc_copy = acc
+        acc_copy_0 = acc_copy
+        mul = 2 * offset_0
+        iota = mul + tl.arange(0, mul_2)
+        lhs = tl.load(A + (indices_1[:, None] * A_stride_0 + iota[None, :] * A_stride_1), mask_1[:, None], other=0)
+        packed = tl.load(B + (indices_0[:, None] * B_stride_0 + indices_2[None, :] * B_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        stack_idx = tl.arange(0, 2)
+        broadcast_idx = stack_idx[None, :, None]
+        expanded_0 = tl.expand_dims(packed, 1)
+        expanded_1 = tl.expand_dims(packed, 1)
+        stacked_result = tl.zeros_like(expanded_0)
+        mask_3 = broadcast_idx == 0
+        stacked_result = tl.where(mask_3, expanded_0, stacked_result)
+        mask_4 = broadcast_idx == 1
+        stacked_result = tl.where(mask_4, expanded_1, stacked_result)
+        rhs = tl.reshape(stacked_result, [2 * _BLOCK_SIZE_0, _BLOCK_SIZE_2])
+        acc = tl.dot(tl.cast(lhs, tl.float32), tl.cast(rhs, tl.float32), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    tl.store(C + (indices_1[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), acc, mask_1[:, None] & mask_2[None, :])
+
+def matmul_with_packed_b(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, *, _launcher=_default_launcher):
+    M, K = A.shape
+    _, N = B.shape
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_matmul_with_packed_b, (triton.cdiv(M, _BLOCK_SIZE_1) * triton.cdiv(N, _BLOCK_SIZE_2),), A, B, C, A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1), M, N, K, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestMatmul.test_matmul_split_k)
 from __future__ import annotations
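In this new expected kernel, the stack/where/reshape on packed simply duplicates each row of the packed B tile so the RHS operand has 2 * _BLOCK_SIZE_0 rows, matching the widened iota range on the A side. A PyTorch-level sketch of that row duplication; the tile sizes are illustrative:

import torch

block_k, block_n = 4, 3
packed = torch.arange(block_k * block_n, dtype=torch.float32).reshape(block_k, block_n)
rhs = torch.stack([packed, packed], dim=1).reshape(2 * block_k, block_n)
# Rows are interleaved pairwise: rhs[0::2] and rhs[1::2] both equal packed.
assert torch.equal(rhs[0::2], packed) and torch.equal(rhs[1::2], packed)
print(rhs.shape)  # torch.Size([8, 3])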
