Skip to content

Commit 1f8966b

Browse files
authored
[Triton Tool] Enable the triton-tensor-layout tool for 3rd party Triton GPU Dialect (#5444)
Remove the protection of only the `ttg` dialect being allowed in triton-tensor-layout tool safely. Change the helper function to check whether the layout attribute implements the `toLinearLayout` interface. This PR does not include new test because the in-tree LIT test covers the changes for in-tree layout attribute of Triton GPU dialect.
1 parent 24b8d43 commit 1f8966b

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

bin/triton-tensor-layout.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ static cl::opt<std::string> TensorStr(
8080
//===--------------------------------------------------------------------===//
8181

8282
LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
83-
StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();
84-
85-
// Dispatch to the corresponding dialect helper function to print the layout.
86-
if (dialectName == "ttg") {
83+
// DistributedEncodingTrait and SharedEncodingAttr implements the
84+
// toLinearLayout interface.
85+
mlir::Attribute layout = tensorType.getEncoding();
86+
if (isa<mlir::triton::gpu::DistributedEncodingTrait,
87+
mlir::triton::gpu::SharedEncodingAttr>(layout)) {
8788
os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
8889
return success();
8990
}

0 commit comments

Comments
 (0)