@@ -50,7 +50,10 @@ def matmul_kernel( #
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
     for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
-        mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)
+        if not A_TRANS:
+            mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)
+        else:
+            mask_a = (offs_k[:, None] + k * BLOCK_K < K) & (offs_am[None, :] < M)
         a = tl.load(a_ptrs, mask=mask_a, other=0.0)
         if SCALE_A is not None:
             a = a * SCALE_A
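The new branch makes the A-operand bounds mask layout-aware: for a (BLOCK_M, BLOCK_K) tile the M offsets bound the rows and the K offsets bound the columns, and when A is stored transposed the two comparisons swap axes. Below is a minimal NumPy sketch of the same broadcasting, with made-up sizes and simplified offsets (the real kernel derives offs_am from the program id):

import numpy as np

M, K, BLOCK_M, BLOCK_K, k = 20, 60, 32, 16, 3   # hypothetical sizes
offs_am = np.arange(BLOCK_M)                    # simplified row offsets into M
offs_k = np.arange(BLOCK_K)                     # offsets into the current K tile

# A stored as (M, K): rows bounded by M, columns by K
mask_a = (offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_K < K)

# A stored transposed as (K, M): rows bounded by K, columns by M
mask_a_t = (offs_k[:, None] + k * BLOCK_K < K) & (offs_am[None, :] < M)

assert mask_a.shape == (BLOCK_M, BLOCK_K)
assert mask_a_t.shape == (BLOCK_K, BLOCK_M)
assert np.array_equal(mask_a_t, mask_a.T)       # same mask, transposed layout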
@@ -551,11 +554,6 @@ def test_lhs_in_tmem(BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device,
     N = 512
     K = 256
     _knob_promote_lhs_to_tmem(monkeypatch)
-    if is_xpu() and (M != BLOCK_M or N != BLOCK_N or K != BLOCK_K):
-        # TODO: Make LHS TMEM promotion work for all problem sizes regardless of block dims
-        pytest.skip(
-            "LHS TMEM promotion produces incorrect results when the workload dimensions are not equal to the block dims"
-        )
     torch.manual_seed(42)
     if dtype_src_str == "float8e5":
         a = torch.randint(20, 40, (M, K), dtype=torch.int8, device=device).view(torch.float8_e5m2)
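For context, the float8e5 line above builds its input by drawing int8 bit patterns and reinterpreting them as float8_e5m2, since torch's random generators do not emit float8 tensors directly. A minimal sketch of that bit-reinterpretation trick, with assumed shapes and requiring a PyTorch build that has the float8 dtypes:

import torch

M, K = 64, 32                                              # assumed shapes
a_bits = torch.randint(20, 40, (M, K), dtype=torch.int8)   # narrow range of bit patterns
a = a_bits.view(torch.float8_e5m2)                         # same storage, reinterpreted as float8
a_ref = a.to(torch.float32)                                # upcast for a reference computation
print(a.dtype, a_ref.dtype)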
@@ -581,6 +579,8 @@ def test_lhs_in_tmem(BLOCK_M, BLOCK_N, BLOCK_K, a_trans, dtype_src_str, device,
     atol = 0.03
     rtol = 0.03
     torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
+    if not is_cuda():
+        return
     pattern = r"%\w+\s*=\s*ttng\.tmem_alloc[\s\S]*?tng\.tc_gen5_mma\s+%\w+,"
     ttgir = k.asm["ttgir"]
     assert re.search(pattern, ttgir)
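The added early return skips the TTGIR check on non-CUDA backends, where the ttng ops it looks for are not produced. The pattern matches a ttng.tmem_alloc followed, lazily, by a tc_gen5_mma whose first operand is an SSA value, which is how the test detects that the LHS tile was placed in TMEM. A quick illustration of what the regex accepts, run against a hand-written, simplified IR fragment (not real compiler output):

import re

pattern = r"%\w+\s*=\s*ttng\.tmem_alloc[\s\S]*?tng\.tc_gen5_mma\s+%\w+,"
fake_ttgir = """
  %0 = ttng.tmem_alloc %a : (...) -> (...)
  ttng.tc_gen5_mma %0, %1, %acc
"""
assert re.search(pattern, fake_ttgir)  # tmem_alloc precedes an MMA taking an SSA value first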