Skip to content

Commit 6f743d1

Browse files
Fix test_matmul.py::test_lhs_in_tmem (#3866)
Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
1 parent 91ef068 commit 6f743d1

File tree

2 files changed: 7 additions (+7) and 6 deletions (-6).

python/test/unit/language/test_matmul.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,10 @@ def matmul_kernel( #
5050
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
5151
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
5252
for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
53-
mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)
53+
if not A_TRANS:
54+
mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)
55+
else:
56+
mask_a = (offs_k[:, None] + k * BLOCK_K < K) & (offs_am[None, :] < M)
5457
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
5558
if SCALE_A is not None:
5659
a = a * SCALE_A
@@ -551,11 +554,6 @@ def test_lhs_in_tmem(BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device,
551554
N = 512
552555
K = 256
553556
_knob_promote_lhs_to_tmem(monkeypatch)
554-
if is_xpu() and (M != BLOCK_M or N != BLOCK_N or K != BLOCK_K):
555-
# TODO: Make LHS TMEM promotion work for all problem sizes regardless of block dims
556-
pytest.skip(
557-
"LHS TMEM promotion produces incorrect results when the workload dimensions are not equal to the block dims"
558-
)
559557
torch.manual_seed(42)
560558
if dtype_src_str == "float8e5":
561559
a = torch.randint(20, 40, (M, K), dtype=torch.int8, device=device).view(torch.float8_e5m2)
@@ -581,6 +579,8 @@ def test_lhs_in_tmem(BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device,
581579
atol = 0.03
582580
rtol = 0.03
583581
torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
582+
if not is_cuda():
583+
return
584584
pattern = r"%\w+\s*=\s*ttng\.tmem_alloc[\s\S]*?tng\.tc_gen5_mma\s+%\w+,"
585585
ttgir = k.asm["ttgir"]
586586
assert re.search(pattern, ttgir)

scripts/skiplist/lts/language.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -316,3 +316,4 @@ test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-128-64]
316316
test/unit/language/test_pipeliner.py::test_indirect_matmul[5-128-64-128]
317317
test/unit/language/test_core.py::test_convert_mma2mma[mma_pair0-float16-256-256]
318318
test/unit/language/test_matmul.py::test_mxfp
319+
test/unit/language/test_matmul.py::test_lhs_in_tmem

0 commit comments

Comments
 (0)