Commit 24b8d43

[RELAND][Layouts] Reuse existing materializations in backwards pass (#5430)
This PR enables reusing existing convert_layout ops in the backwards pass when they were not removed by some other means. This lets the compiler eliminate some tricky layout conversions by recognizing that the same computations can be reconstructed from other layout conversions. This relands the original change with the proper dominance fix included: when re-using existing rematerializations, the pass has to verify that they actually dominate the relevant part of the backwards slice. To facilitate this, I changed some functions to pass `OpOperand &` around, which gives more information about the `use->def` edges being analyzed.
1 parent: d7ebf79
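
To make the effect concrete, here is a simplified before/after sketch of the reuse this enables. It is adapted from the reuse_layout_conversion test added in this commit (layouts and SSA names follow that test); the precise output is what the test's CHECK lines pin down, not this sketch.

// Before: two conversions out of #blocked1; rematerializing the slice of %3
// would normally have to recreate the tt.trans in the other layout.
%cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1>
%0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
%1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
%2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1>
%3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>

// After: the backwards pass for %3 reuses %1, the existing materialization of
// %0 in #blocked, so the mulf is recomputed directly in #blocked and the
// second conversion disappears (the constant folds into the target layout).
%cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked>
%0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
%1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
%2 = arith.mulf %1, %cst : tensor<64x64xf32, #blocked>
// %1 and %2 are returned; the second ttg.convert_layout is gone.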

File tree

5 files changed: +170 additions, -52 deletions


include/triton/Dialect/TritonGPU/Transforms/Utility.h
Lines changed: 4 additions & 2 deletions

@@ -161,9 +161,11 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
 // Get backward slice of tensor values starting from the root node along with
 // encoding propagation.
 LogicalResult getConvertBackwardSlice(
-    Value root, SetVector<Value> &slice, Attribute rootEncoding,
+    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr,
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion =
+        nullptr);
 
 // Populate pattern to remove dead cycles in ForOp.
 void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Lines changed: 65 additions & 24 deletions

@@ -3,6 +3,7 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -116,17 +117,15 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
+
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
-  bool hasRematValue(Value value, Attribute encoding) {
-    return rematMapping.contains({value, encoding});
-  }
-  // Return the remat'ed value in the given encoding.
-  Value getRematValue(Value value, Attribute encoding) {
-    auto it = rematMapping.find({value, encoding});
-    assert(it != rematMapping.end());
-    return it->second;
+  // Get the remat'ed value in the given encoding, if one already exists and
+  // is different then the layout conversion root.
+  Value getRematValue(Value value, Attribute encoding) const {
+    return rematMapping.lookup({value, encoding});
   }
+
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -137,6 +136,11 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);
 
+  LogicalResult getRematerializableSlice(
+      OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
+      DenseMap<Value, Attribute> &layout,
+      std::function<bool(Operation *)> stopPropagation = nullptr);
+
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that needs to be updated when recreating
@@ -148,6 +152,7 @@ class LayoutRematerialization {
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
   SetVector<Operation *> opToDelete;
   FuncOp funcOp;
+  DominanceInfo domInfo;
 };
 
 void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -778,8 +783,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
+    if (Value remat = getRematValue(v, layoutIt->second)) {
+      mapping.map(v, remat);
       valuesWithExistingRemat.insert(v);
       continue;
     }
@@ -940,12 +945,36 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }
 
-LogicalResult getRematerializableSlice(
-    Value root, Attribute rootEncoding, SetVector<Value> &slice,
+LogicalResult LayoutRematerialization::getRematerializableSlice(
+    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr) {
-  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
-                                                 layout, stopPropagation);
+    std::function<bool(Operation *)> stopPropagation) {
+  // Allow re-using existing conversions for a value. Check dominance of any
+  // re-usable materializations against the root value. This is sufficient
+  // because the conversions are processed in post-order.
+  auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
+    Value remat = getRematValue(value.get(), encoding);
+    if (!remat)
+      return Value();
+    // `value` can be replaced with an existing rematerialization if it
+    // dominates the current use of value.
+    Operation *user = value.getOwner();
+    if (domInfo.properlyDominates(remat, user)) {
+      return remat;
+    }
+    // Alternatively, if the current use can be sunk below the existing
+    // rematerialization, then it is okay to use as well. E.g. the current use
+    // is a conversion that will be folded away when its result is
+    // rematerialized.
+    if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp() &&
+        domInfo.properlyDominates(user, remat.getDefiningOp())) {
+      return remat;
+    }
+    return Value();
+  };
+  LogicalResult result =
+      getConvertBackwardSlice(root, slice, rootEncoding, layout,
+                              stopPropagation, getExistingConversion);
   if (result.failed() || slice.empty())
     return failure();
 
@@ -966,6 +995,12 @@ void LayoutRematerialization::backwardRematerialization() {
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
     backwardRematerialization(convertOp);
+    if (!opToDelete.contains(convertOp)) {
+      // If the conversion didn't get removed, consider it for re-use in future
+      // backward slices.
+      addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
+                    convertOp.getResult());
+    }
   }
 }
 
@@ -976,6 +1011,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
     hoistConvertOnTopOfExtOrBroadcast(convertOp);
+    if (!opToDelete.contains(convertOp)) {
+      // If the conversion didn't get removed, consider it for re-use in future
+      // backward slices.
+      addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
+                    convertOp.getResult());
+    }
   }
 }
 
@@ -988,14 +1029,14 @@ void LayoutRematerialization::backwardRematerialization(
   // careful with the heuristics for both correctness and perf
   if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
     return;
-  Value oldV = convertOp->getOperand(0);
+  Value oldV = convertOp.getSrc();
  LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
   // Check to see if there are existing remat'ed values for the pair of oldValue
-  // and encoding.
-  if (hasRematValue(oldV, targetType.getEncoding())) {
+  // and encoding. Make sure it dominates the current conversion.
+  Value newV = getRematValue(oldV, targetType.getEncoding());
+  if (newV && domInfo.properlyDominates(newV, convertOp)) {
     // Replace it with the remat'ed value.
-    Value newV = getRematValue(oldV, targetType.getEncoding());
     convertOp.replaceAllUsesWith(newV);
     opToDelete.insert(convertOp);
     LDBG("found remat'ed value" << newV);
@@ -1007,7 +1048,7 @@ void LayoutRematerialization::backwardRematerialization(
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
   LogicalResult result = getRematerializableSlice(
-      convertOp.getSrc(), targetType.getEncoding(), slice, layout);
+      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
   if (result.failed()) {
     LDBG(" getRematerializableSlice failed");
     return;
@@ -1050,9 +1091,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
   // 1. Take a backward slice of all the tensor dependencies.
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
-  LogicalResult result =
-      getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(),
-                               slice, layout, isExtOrBroadcastOp);
+  LogicalResult result = getRematerializableSlice(
+      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
+      isExtOrBroadcastOp);
   if (result.failed())
     return;
 
@@ -1070,7 +1111,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     if (!srcEncoding)
       return;
     LogicalResult result = getRematerializableSlice(
-        op->getOperand(0), srcEncoding, tempSlice, tempLayout);
+        op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
     // If we can rematerialize the rest of the ext slice we can ignore this
    // ext as it won't need a convert.
    if (result.succeeded()) {
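
The dominance checks above are the fix the original landing was missing. A minimal sketch of the case they guard against, adapted from the respect_dominance test added in this commit: the only existing materialization of %0 in #blocked is defined after the use that the backwards slice of %3 would rewrite, so it must not be reused.

%0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
%2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1>
// %1 materializes %0 in #blocked, but it does not dominate the mulf above, so
// getExistingConversion refuses it and both conversions stay in place.
%1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
%3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>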

lib/Dialect/TritonGPU/Transforms/Utility.cpp
Lines changed: 40 additions & 23 deletions

@@ -770,44 +770,60 @@ static bool isFreeConvert(Operation *op) {
                          convertOp.getType());
 }
 
-LogicalResult
-getConvertBackwardSlice(Value root, SetVector<Value> &slice,
-                        Attribute rootEncoding,
-                        DenseMap<Value, Attribute> &layout,
-                        std::function<bool(Operation *)> stopPropagation) {
-  DenseSet<std::pair<Value, Attribute>> seen;
-  SmallVector<std::pair<Value, Attribute>> queue;
-
-  auto enqueue = [&](Value operand, Attribute encoding) {
-    auto x = std::make_pair(operand, encoding);
+LogicalResult getConvertBackwardSlice(
+    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation,
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
+  DenseSet<std::pair<OpOperand *, Attribute>> seen;
+  SmallVector<std::pair<OpOperand *, Attribute>> queue;
+
+  auto enqueue = [&](OpOperand &operand, Attribute encoding) {
+    auto x = std::make_pair(&operand, encoding);
     if (!seen.insert(x).second) {
       return; // Already enqueued, skip
     }
     queue.push_back(x);
   };
   enqueue(root, rootEncoding);
 
+  auto updateLayout = [&](Value value, Attribute encoding) {
+    assert((isa<RankedTensorType>(value.getType())));
+    slice.insert(value);
+    if (layout.find(value) != layout.end()) {
+      if (layout[value] != encoding)
+        return failure();
+    }
+    layout[value] = encoding;
+    return success();
+  };
+
   while (!queue.empty()) {
-    auto [currentValue, encoding] = queue.back();
+    auto [currentValueUse, encoding] = queue.back();
+    Value currentValue = currentValueUse->get();
     queue.pop_back();
     if (!isa<RankedTensorType>(currentValue.getType()))
       continue;
     // Skip propagating through for op results for now.
     // TODO: enable this based on needs.
     if (currentValue.getDefiningOp<scf::ForOp>())
       return failure();
-    slice.insert(currentValue);
-    if (layout.find(currentValue) != layout.end()) {
-      if (layout[currentValue] != encoding)
+    if (failed(updateLayout(currentValue, encoding)))
+      return failure();
+
+    Value existing;
+    if (getExistingConversion &&
+        (existing = getExistingConversion(*currentValueUse, encoding))) {
+      if (failed(updateLayout(existing, encoding)))
         return failure();
+      currentValue = existing;
     }
-    layout[currentValue] = encoding;
 
     if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
       unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
 
-      auto thenValue = ifOp.thenYield().getOperand(argIdx);
-      auto elseValue = ifOp.elseYield().getOperand(argIdx);
+      OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
+      OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx);
 
       enqueue(thenValue, encoding);
       enqueue(elseValue, encoding);
@@ -819,10 +835,11 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
      for (Value result : definingOp->getResults()) {
        if (result == currentValue || !isa<RankedTensorType>(result.getType()))
          continue;
-        enqueue(result, encoding);
+        if (failed(updateLayout(result, encoding)))
+          return failure();
      }
      if (isFreeConvert(definingOp)) {
-        enqueue(definingOp->getOperand(0), encoding);
+        enqueue(definingOp->getOpOperand(0), encoding);
        continue;
      }
      if (canFoldIntoConversion(definingOp, encoding))
@@ -837,10 +854,10 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
        auto srcEncoding = inferSrcEncoding(gather, encoding);
        if (!srcEncoding)
          return failure();
-        enqueue(gather.getIndices(), srcEncoding);
+        enqueue(gather.getIndicesMutable(), srcEncoding);
        continue;
      }
-      for (auto [i, operand] : llvm::enumerate(definingOp->getOperands())) {
+      for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) {
        auto srcEncoding = inferSrcEncoding(definingOp, encoding);
        if (!srcEncoding)
          return failure();
@@ -853,9 +870,9 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
      Operation *parentOp = block->getParentOp();
      if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
        OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
-        Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
+        OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand(
            blockArg.getArgNumber() - forOp.getNumInductionVars());
-        enqueue(initOperand->get(), encoding);
+        enqueue(*initOperand, encoding);
        enqueue(yieldOperand, encoding);
        continue;
      }
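
Threading OpOperand & through getConvertBackwardSlice lets the getExistingConversion callback check dominance against the actual use, which matters when uses sit in different regions: a conversion in one scf.if branch must not be reused for the sibling branch, since neither dominates the other. A rough sketch of that case, adapted from the remat_across_regions test added in this commit:

scf.if %arg0 {
  // This conversion does not dominate anything in the else branch...
  %0 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
  "test.keep"(%0) : (tensor<8x8xf32, #blocked1>) -> ()
} else {
  // ...so the use here cannot be redirected to it; each branch keeps its own convert.
  %1 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
  "test.keep"(%1) : (tensor<8x8xf32, #blocked1>) -> ()
}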

test/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ configure_lit_site_cfg(
 set(TRITON_TEST_DEPENDS
   triton-opt
   triton-tensor-layout
+  triton-llvm-opt
 )
 
 set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")

test/TritonGPU/combine.mlir
Lines changed: 60 additions & 3 deletions

@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s
 
 #layout0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 #layout1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -2427,8 +2427,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
    %3 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
    %4 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<128x64x!tt.ptr<i8>, #blocked>
-    // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>)
-    // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
+    // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
+    // CHECK-COUNT-4: convert_layout
+    // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    // CHECK: }
    // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 {
@@ -2772,3 +2773,59 @@ tt.func @do_not_remat(%arg0: tensor<64x64xf32, #blocked1>) -> tensor<1x64xf32, #
 }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+
+// CHECK-LABEL: reuse_layout_conversion
+tt.func @reuse_layout_conversion(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) {
+  // CHECK-NEXT: %cst = arith.constant {{.*}} tensor<64x64xf32, #blocked>
+  %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1>
+  // CHECK-NEXT: [[TRANS:%.*]] = tt.trans %arg0 {{.*}} tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+  %0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+  // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[TRANS]] : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  // CHECK-NEXT: [[RESULT:%.*]] = arith.mulf [[CVT]], %cst : tensor<64x64xf32, #blocked>
+  %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1>
+  %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  // CHECK-NEXT: return [[CVT]], [[RESULT]]
+  tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>
+}
+
+// CHECK-LABEL: respect_dominance
+tt.func @respect_dominance(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) {
+  %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1>
+
+  // CHECK-COUNT-2: convert_layout
+  %0 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+
+  %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1>
+  %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked>
+  tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>
+}
+
+// CHECK-LABEL: remat_across_regions
+tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {
+  // CHECK-NEXT: scf.if
+  scf.if %arg0 {
+    // CHECK-NEXT: convert_layout
+    %0 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+    "test.keep"(%0) : (tensor<8x8xf32, #blocked1>) -> ()
+    // CHECK: else
+  } else {
+    %0 = "test.dummy"() : () -> i32
+    // CHECK: convert_layout
+    %1 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1>
+    "test.keep"(%1) : (tensor<8x8xf32, #blocked1>) -> ()
+    // CHECK: }
+  }
+  // CHECK-NEXT: return
+  tt.return
+}
+
+}
