Skip to content

Commit b155d8a

Browse files
authored
[BACKEND] Rework load-store redundant data masking (#5432)
This splits `getRedundantDataMask` into two functions, `getFreeVariableMasks` and `emitRedundantThreadPredicate`. The returned predicate doesn't include the register index, and instead you use the free variable mask to de-duplicate the registers while looping over them (i.e. we don't emit the instruction at all). This also allows us to fix predication for `AsyncCopyGlobalToLocal`, as we can explicitly zero out the block dim mask before calling `emitRedundantThreadPredicate`. I also return null values if the predicate is always true, which allows us to omit the predicate entirely if there is no redundant data.
1 parent 9829ce8 commit b155d8a

File tree

3 files changed

+255
-142
lines changed

3 files changed

+255
-142
lines changed

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20
22

33
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
44
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
@@ -115,34 +115,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
115115

116116
// Load 4 elements from vector0
117117
// CHECK: mov.u32 $0, 0x0
118-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
118+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
119119
// CHECK: mov.u32 $0, 0x0
120-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
120+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
121121
// CHECK: mov.u32 $0, 0x0
122-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
122+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
123123
// CHECK: mov.u32 $0, 0x0
124-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
124+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
125125

126126
// Load 4 elements from vector1
127127
// CHECK: mov.u32 $0, 0x0
128-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
128+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
129129
// CHECK: mov.u32 $0, 0x0
130-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
130+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
131131
// CHECK: mov.u32 $0, 0x0
132-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
132+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
133133
// CHECK: mov.u32 $0, 0x0
134-
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
134+
// CHECK: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
135135
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
136136
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
137137
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
138138
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
139139
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
140140

141141
// Store 4 elements to global
142-
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
143-
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
144-
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
145-
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
142+
// CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
143+
// CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
144+
// CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
145+
// CHECK: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
146146
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
147147
tt.return
148148
}
@@ -166,10 +166,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
166166
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
167167

168168
// Load 4 elements from A with a single vectorized load instruction
169-
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
169+
// CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
170170

171171
// Load 4 elements from B with a single vectorized load instruction
172-
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
172+
// CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
173173

174174
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
175175
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
@@ -178,7 +178,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
178178
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
179179

180180
// Store 4 elements to global with a single vectorized store instruction
181-
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
181+
// CHECK: st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
182182
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
183183
tt.return
184184
}
@@ -233,16 +233,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
233233
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
234234

235235
// Load 8 elements from A with four vectorized load instructions
236-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
237-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
238-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
239-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
236+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
237+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
238+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
239+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
240240

241241
// Load 8 elements from B with four vectorized load instructions
242-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
243-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
244-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
245-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
242+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
243+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
244+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
245+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
246246

247247
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
248248
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
@@ -251,10 +251,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
251251
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
252252

253253
// Store 8 elements to global with four vectorized store instructions
254-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
255-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
256-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
257-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
254+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
255+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
256+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
257+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
258258
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
259259
tt.return
260260
}
@@ -278,16 +278,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
278278
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
279279

280280
// Load 8 elements from A with four vectorized load instructions
281-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
282-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
283-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
284-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
281+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
282+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
283+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
284+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
285285

286286
// Load 8 elements from B with four vectorized load instructions
287-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
288-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
289-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
290-
// CHECK: @${{.*}} ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
287+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
288+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
289+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
290+
// CHECK: ld.global.v2.b32 { ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
291291

292292
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
293293
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
@@ -296,10 +296,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
296296
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
297297

298298
// Store 8 elements to global with four vectorized store instructions
299-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
300-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
301-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
302-
// CHECK: @${{.*}} st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
299+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
300+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
301+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
302+
// CHECK: st.global.v2.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}} };
303303
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
304304
tt.return
305305
}
@@ -323,12 +323,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
323323
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
324324

325325
// Load 8 elements from A with two vectorized load instructions
326-
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
327-
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
326+
// CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
327+
// CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
328328

329329
// Load 8 elements from B with two vectorized load instructions
330-
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
331-
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
330+
// CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
331+
// CHECK: ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
332332

333333
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
334334
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
@@ -337,13 +337,56 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
337337
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
338338

339339
// Store 8 elements to global with two vectorized store instructions
340-
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
341-
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
340+
// CHECK: st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
341+
// CHECK: st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
342342
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
343343
tt.return
344344
}
345345
}
346346

347+
// -----
348+
349+
// Slice layout with 2 unique elements, but 8 total elements per thread
350+
#blocked2d = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
351+
#slice = #ttg.slice<{dim = 1, parent = #blocked2d}>
352+
353+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
354+
// CHECK-LABEL: global_load_store_slice
355+
tt.func @global_load_store_slice(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
356+
%c128_i32 = arith.constant 128 : i32
357+
%0 = tt.get_program_id x : i32
358+
%1 = arith.muli %0, %c128_i32 : i32
359+
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #slice>
360+
%3 = tt.splat %1 : i32 -> tensor<128xi32, #slice>
361+
%4 = arith.addi %3, %2 : tensor<128xi32, #slice>
362+
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #slice>
363+
%6 = tt.addptr %5, %4 : tensor<128x!tt.ptr<f32>, #slice>, tensor<128xi32, #slice>
364+
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #slice>
365+
%8 = tt.addptr %7, %4 : tensor<128x!tt.ptr<f32>, #slice>, tensor<128xi32, #slice>
366+
367+
// Load 2 elements from vector0 without a predicate
368+
// CHECK: mov.u32 $0, 0x0
369+
// CHECK-NOT: @{{.*}} ld.global
370+
// CHECK-COUNT-2: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
371+
372+
// Load 2 elements from vector1 without predicate
373+
// CHECK: mov.u32 $0, 0x0
374+
// CHECK-NOT: @{{.*}} ld.global
375+
// CHECK-COUNT-2: ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
376+
%9 = tt.load %6 : tensor<128x!tt.ptr<f32>, #slice>
377+
%10 = tt.load %8 : tensor<128x!tt.ptr<f32>, #slice>
378+
%11 = arith.addf %9, %10 : tensor<128xf32, #slice>
379+
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #slice>
380+
%13 = tt.addptr %12, %4 : tensor<128x!tt.ptr<f32>, #slice>, tensor<128xi32, #slice>
381+
382+
// Store 2 elements to global without a predicate
383+
// CHECK-NOT: @{{.*}} st.global
384+
// CHECK-COUNT-2: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
385+
tt.store %13, %11 : tensor<128x!tt.ptr<f32>, #slice>
386+
tt.return
387+
}
388+
}
389+
347390
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
348391
// is from an addptr with const idx
349392

@@ -583,8 +626,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
583626
%tensor = ttg.local_alloc : () -> !ttg.memdesc<16x64xf32, #A, #smem, mutable>
584627
%index = arith.constant 1 : i32
585628

586-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
587-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
629+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
630+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
588631
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
589632
// CHECK-SAME: cp.async.commit_group
590633
%a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr<f32>, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable>
@@ -674,7 +717,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
674717
// CHECK: llvm.mlir.constant(16 : i32) : i32
675718
// CHECK: llvm.mul
676719
// CHECK: llvm.add
677-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4;"
720+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4;"
678721
// CHECK: llvm.inline_asm
679722
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
680723
// CHECK: llvm.inline_asm
@@ -1289,9 +1332,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
12891332
// CHECK-LABEL: store_f32
12901333
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
12911334
// CHECK: llvm.inline_asm
1292-
// CHECK-SAME: @$2 st.global.b32
1335+
// CHECK-SAME: st.global.b32
12931336
// CHECK: llvm.inline_asm
1294-
// CHECK-SAME: @$2 st.global.b32
1337+
// CHECK-SAME: st.global.b32
12951338
tt.store %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
12961339
tt.return
12971340
}

third_party/nvidia/include/TritonNVIDIAGPUToLLVM/PTXAsmFormat.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,19 @@ struct PTXInstrExecution {
298298

299299
// Prefix a predicate to the instruction.
300300
PTXInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") {
301+
assert(value);
301302
pred = instr->builder->newOperand(value, constraint);
302303
return *this;
303304
}
304305

306+
// Prefix a predicate to the instruction, if non-null
307+
PTXInstrExecution &maybePredicate(mlir::Value value,
308+
StringRef constraint = "b") {
309+
if (value)
310+
predicate(value, constraint);
311+
return *this;
312+
}
313+
305314
// Prefix a !predicate to the instruction.
306315
PTXInstrExecution &predicateNot(mlir::Value value, StringRef constraint) {
307316
pred = instr->builder->newOperand(value, constraint);

0 commit comments

Comments
 (0)