Skip to content

Commit 92e2404

Browse files
authored
[MLIR][XeVM] Update XeVM prefetch ops. (#166445)
`xevm.blockprefetch2d` op has its pointer operand marked as MemRead, which causes the op to get folded away by the canonicalize pass. Remove the side-effect mark and update the XeGPU to XeVM prefetch op conversion test cases to use the canonicalize pass.
1 parent 2dd7705 commit 92e2404

File tree

3 files changed

+43
-50
lines changed

3 files changed

+43
-50
lines changed

mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,9 @@ def XeVM_PrefetchOp
463463

464464
def XeVM_BlockPrefetch2dOp
465465
: XeVM_Op<"blockprefetch2d">,
466-
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, I32:$base_width,
467-
I32:$base_height, I32:$base_pitch, I32:$x, I32:$y,
468-
I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height,
469-
I32Attr:$v_blocks,
466+
Arguments<(ins LLVM_AnyPointer:$ptr, I32:$base_width, I32:$base_height,
467+
I32:$base_pitch, I32:$x, I32:$y, I32Attr:$elem_size_in_bits,
468+
I32Attr:$tile_width, I32Attr:$tile_height, I32Attr:$v_blocks,
470469
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
471470

472471
let summary = "2D block prefetch";

mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
1-
// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
1+
// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm -canonicalize | FileCheck %s
22

33
gpu.module @test {
44
// CHECK-LABEL: @load_gather_i64_src_value_offset
5-
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
6-
gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
5+
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
6+
// CHECK-SAME: %[[ARG3:.*]]: vector<1xi1>
7+
gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>, %mask: vector<1xi1>) {
8+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
9+
// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
10+
// CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
11+
// CHECK: %[[VAR2:.*]] = vector.extract %[[ARG3]][0] : i1 from vector<1xi1>
712
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
813
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
9-
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
10-
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
11-
%1 = arith.constant dense<1>: vector<1xi1>
12-
// CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
1314
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
1415
// CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
1516
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
1617
// CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
1718
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16
1819
// CHECK: scf.yield %[[VAR7]] : f16
1920
// CHECK: } else {
20-
// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16
2121
// CHECK: scf.yield %[[CST_0]] : f16
2222
// CHECK: }
23-
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
23+
%0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
2424
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
25+
%c0 = arith.constant 0 : index
26+
vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16>
2527
gpu.return
2628
}
2729
}
@@ -30,16 +32,16 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>)
3032
gpu.module @test {
3133
// CHECK-LABEL: @source_materialize_single_elem_vec
3234
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16>
33-
gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) {
34-
%1 = arith.constant dense<1>: vector<1xi1>
35-
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
35+
// CHECK-SAME: %[[ARG3:.*]]: vector<1xi1>
36+
gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>, %mask: vector<1xi1>) {
37+
%0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
3638
: i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
39+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
3740
// CHECK: %[[VAR_IF:.*]] = scf.if
3841
// CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16>
39-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
4042
// CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16>
4143
%c0 = arith.constant 0 : index
42-
vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16>
44+
vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16>
4345
gpu.return
4446
}
4547
}
@@ -48,24 +50,21 @@ gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>
4850

4951
gpu.module @test {
5052
// CHECK-LABEL: @store_scatter_i64_src_value_offset
51-
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
52-
gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
53+
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xi1>
54+
gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>, %mask: vector<1xi1>) {
55+
// CHECK: %[[CST_0:.*]] = arith.constant 2.900000e+00 : f32
56+
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
57+
// CHECK: %[[VAR2:.*]] = vector.extract %[[ARG2]][0] : i1 from vector<1xi1>
5358
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
5459
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
55-
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
56-
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
57-
%1 = arith.constant dense<1>: vector<1xi1>
58-
// CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
59-
// CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
60-
%2 = arith.constant dense<2.9>: vector<1xf32>
61-
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
60+
%0 = arith.constant dense<2.9>: vector<1xf32>
6261
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
6362
// CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64
6463
// CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
6564
// CHECK: scf.if %[[VAR2]] {
66-
// CHECK: llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1>
65+
// CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1>
6766
// CHECK: }
68-
xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
67+
xegpu.store %0, %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
6968
: vector<1xf32>, i64, vector<1xindex>, vector<1xi1>
7069
gpu.return
7170
}
@@ -76,9 +75,9 @@ gpu.module @test {
7675
// CHECK-LABEL: @prefetch_i64_src_value_offset
7776
// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
7877
gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
78+
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
7979
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
8080
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
81-
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
8281
// CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
8382
// CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64
8483
// CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1>
@@ -94,11 +93,11 @@ gpu.module @test {
9493
// CHECK-LABEL: @prefetch_memref_src_value_offset
9594
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
9695
gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
96+
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
9797
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
9898
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
9999
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
100100
// CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
101-
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
102101
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
103102
// CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
104103
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>

mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,29 @@
1-
// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
22

3-
gpu.module @fence_check {
4-
gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
3+
gpu.module @prefetch_nd_check {
4+
// CHECK-LABEL: gpu.func @prefetch_nd
5+
gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
6+
// CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.constant 64 : i32
7+
// CHECK: %[[LD_CREATE_DESC_I64:.*]] = arith.constant dense<0> : vector<4xi64>
8+
// CHECK: %[[PREF_BASE_H:.*]] = arith.constant 8 : i32
9+
// CHECK: %[[PREF_BASE_W:.*]] = arith.constant 16 : i32
10+
// CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32
511
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
6-
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
7-
812
// CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
9-
// CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
1013
// CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
1114
// CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
12-
// CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
13-
// CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
14-
// CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
15-
// CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
15+
// CHECK: %[[LD_DESC_2:.*]] = vector.insert %[[PREF_BASE_W]], %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
16+
// CHECK: %[[LD_DESC_3:.*]] = vector.insert %[[PREF_BASE_H]], %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
17+
// CHECK: %[[LD_DESC_4:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
18+
// CHECK: %[[LD_DESC:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
1619
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
1720
#xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
1821

1922
//CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
2023
//CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
21-
//CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
22-
//CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
23-
//CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
24-
//CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
25-
//CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
26-
//CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
2724
//CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
28-
//CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
29-
//CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
3025
//CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
31-
//CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
26+
//CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
3227
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
3328
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
3429
//CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)

0 commit comments

Comments
 (0)