@@ -66,6 +66,15 @@ class HasMatchingMaskTypeConstraint<string vector, string mask> :
6666 vector, mask,
6767 "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
6868
69+ class TileSliceMaskConstraint<string tile, string mask> :
70+ TypesMatchWith<
71+ "`" # mask # "` has i1 element type and the shape is a slice of `" # tile # "`",
72+ tile, mask,
73+ "VectorType("
74+ "VectorType::Builder("
75+ "::llvm::cast<mlir::VectorType>($_self)"
76+ ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1)))">;
77+
6978//===----------------------------------------------------------------------===//
7079// ArmSME attr definitions
7180//===----------------------------------------------------------------------===//
@@ -408,15 +417,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
408417}
409418
410419def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
411- AllTypesMatch<["tile", "result"]>,
412- TypesMatchWith<
413- "mask has i1 element type and is a slice of the result",
414- "result", "mask",
415- "VectorType("
416- "VectorType::Builder("
417- "::llvm::cast<mlir::VectorType>($_self)"
418- ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
419- ")">,
420+ AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
420421]> {
421422 let summary = "Tile slice load and update operation";
422423 let description = [{
@@ -432,9 +433,8 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
432433 dimensions since the operation is scalable, and the element type must be a
433434 scalar that matches the element type of the result.
434435
435- An SSA value `mask` specifies to mask out elements read from the MemRef.
436- The `mask` type is an `i1` vector with a shape that matches how elements
437- are read from the MemRef.
436+ The provided `mask` is used to specify which elements of the tile slice
437+ will be loaded.
438438
439439 Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
440440 ```mlir
@@ -474,7 +474,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
474474 }];
475475}
476476
477- def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
477+ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
478+ TileSliceMaskConstraint<"tile", "mask">
479+ ]> {
478480 let summary = "Tile slice store operation";
479481 let description = [{
480482 Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -489,22 +491,26 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
489491 dimensions since the operation is scalable, and the element type must be a
490492 scalar that matches the element type of the input tile.
491493
494+ The provided `mask` is used to specify which elements of the tile slice
495+ will be stored.
496+
492497 Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
493498 ```mlir
494- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
499+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, % base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1 >, memref<?x?xi8>
495500 ```
496501
497502 Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
498503 ```mlir
499- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
504+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, % base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1 >, memref<?x?xf32>
500505 ```
501506
502507 Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
503508 ```mlir
504- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
509+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, % base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1 >, memref<?x?xi128>
505510 ```
506511 }];
507- let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
512+ let arguments = (ins
513+ SMETile:$tile, Index:$tile_slice_index, SVEPredicate:$mask,
508514 Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
509515 Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
510516 );
@@ -518,8 +524,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
518524 }];
519525
520526 let assemblyFormat = [{
521- $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
522- attr-dict `:` type($base) `,` type($tile)
527+ $tile `,` $tile_slice_index `,` $mask `,` $ base `[` $indices `]` (`layout` `` $layout^)?
528+ attr-dict `:` type($base) `,` type($mask) `,` type($ tile)
523529 }];
524530}
525531
0 commit comments