Skip to content

Commit 75b0321

Browse files
authored
Support performance warning (intel#3922)
This commit adds a performance warning that is emitted as a remark when MMA v3 cannot be selected for `tl.dot` on Hopper. For the added test case, the output is:

```
test-warning.py:24:18: remark: Warning: can't use MMA V3 for the dot op
    c = tl.dot(a, b)
        ^
test-warning.py:24:18: note: see current operation:
%39 = tt.dot %37, %38, %cst, inputPrecision = tf32 :
```
1 parent 6f6d032 commit 75b0321

File tree

6 files changed

+70
-1
lines changed

6 files changed

+70
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
200200
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
201201
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
202202
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
203+
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
203204

204205
# Changelog
205206

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
3333
for (int baseVersion : versionsSupported) {
3434
if (supportMMA(op, baseVersion))
3535
return baseVersion;
36+
if (baseVersion == 3)
37+
op.emitRemark() << "Warning: can't use MMA V3 for the dot op";
3638
}
3739
return 0;
3840
}

python/src/ir.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1111
#include "mlir/IR/Builders.h"
1212
#include "mlir/IR/BuiltinOps.h"
13+
#include "mlir/IR/Diagnostics.h"
1314
#include "mlir/IR/MLIRContext.h"
1415
#include "mlir/IR/Verifier.h"
1516
#include "mlir/Parser/Parser.h"
@@ -27,6 +28,7 @@
2728
#include "triton/Dialect/Triton/IR/Types.h"
2829
#include "triton/Dialect/Triton/IR/Utility.h"
2930
#include "triton/Tools/Sys/GetEnv.hpp"
31+
#include "llvm/Support/SourceMgr.h"
3032

3133
namespace {
3234

@@ -201,7 +203,16 @@ void init_triton_ir(py::module &&m) {
201203
.value("IEEE", InputPrecision::IEEE)
202204
.export_values();
203205

204-
py::class_<MLIRContext>(m, "context", py::module_local()).def(py::init<>());
206+
py::class_<MLIRContext>(m, "context", py::module_local())
207+
.def(py::init<>())
208+
.def("printOpOnDiagnostic",
209+
[](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); })
210+
.def("printStackTraceOnDiagnostic", [](MLIRContext &self, bool v) {
211+
self.printStackTraceOnDiagnostic(v);
212+
});
213+
py::class_<SourceMgrDiagnosticHandler>(m, "source_mgr_diag",
214+
py::module_local())
215+
.def(py::init<llvm::SourceMgr &, MLIRContext *>());
205216

206217
m.def("load_dialects", [](MLIRContext &context) {
207218
DialectRegistry registry;

python/src/llvm.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/Passes/StandardInstrumentations.h"
1919
#include "llvm/Support/CodeGen.h"
2020
#include "llvm/Support/Signals.h"
21+
#include "llvm/Support/SourceMgr.h"
2122
#include "llvm/Support/TargetSelect.h"
2223
#include "llvm/Target/TargetMachine.h"
2324
#include "llvm/Transforms/IPO/AlwaysInliner.h"
@@ -150,6 +151,8 @@ void init_triton_llvm(py::module &&m) {
150151

151152
py::class_<llvm::LLVMContext>(m, "context", py::module_local())
152153
.def(py::init<>());
154+
py::class_<llvm::SourceMgr>(m, "source_mgr", py::module_local())
155+
.def(py::init<>());
153156

154157
py::class_<llvm::Module::FunctionListType>(m, "function_list")
155158
.def(
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import triton
2+
import triton.language as tl
3+
import os
4+
import pytest
5+
import torch
6+
7+
8+
def is_perf_warning_enabled():
    """Return True when the MLIR_ENABLE_REMARK environment variable is exactly '1'."""
    return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1'


def is_cuda():
    """Return True when the active Triton driver's current target reports the "cuda" backend."""
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def test_mma_remark(capfd):
    """Check that compiling an fp32 `tl.dot` kernel emits the "can't use MMA V3" remark.

    The test compiles a small block-pointer matmul with MLIR_ENABLE_REMARK=1 and
    asserts that both the remark and the accompanying "see current operation"
    note show up on stderr (captured through pytest's `capfd` fixture).

    The environment variable is restored to its previous state when the test
    finishes, whether it passes, fails, or raises.
    """
    if is_cuda():
        capability = torch.cuda.get_device_capability()
        if capability[0] < 9:
            pytest.skip("Requires sm >= 90 to run")
    # NOTE(review): non-CUDA targets fall through and run the checks below
    # unchanged; confirm that those backends are expected to emit this remark.

    # Remember the caller's setting so the test does not leak state into other tests.
    previous_setting = os.environ.get('MLIR_ENABLE_REMARK')
    os.environ['MLIR_ENABLE_REMARK'] = '1'
    try:

        @triton.jit
        def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn):
            a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
                                            block_shape=(32, 128), order=(1, 0))
            b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
                                            block_shape=(128, 32), order=(0, 1))
            c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
                                            block_shape=(32, 32), order=(1, 0))
            a = tl.load(a_block_ptr)
            b = tl.load(b_block_ptr)
            c = tl.dot(a, b)
            tl.store(c_block_ptr, c)

        triton.compile(
            triton.compiler.ASTSource(
                fn=matmul_kernel, signature={
                    0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32', 5: 'i32', 6: 'i32', 7: 'i32', 8: 'i32', 9:
                    'i32', 10: 'i32', 11: 'i32'
                }, constants={}))
        captured = capfd.readouterr()

        assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark"
        assert "note: see current operation:" in captured.err
    finally:
        # Put the variable back exactly as it was (unset stays unset).
        if previous_setting is None:
            os.environ.pop('MLIR_ENABLE_REMARK', None)
        else:
            os.environ['MLIR_ENABLE_REMARK'] = previous_setting

third_party/nvidia/backend/compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ def make_ttgir(mod, metadata, opt, capability):
158158
cluster_info.clusterDimX = opt.cluster_dims[0]
159159
cluster_info.clusterDimY = opt.cluster_dims[1]
160160
cluster_info.clusterDimZ = opt.cluster_dims[2]
161+
# Set up Diagnostic
162+
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
163+
srcMgr = llvm.source_mgr()
164+
diag = ir.source_mgr_diag(srcMgr, mod.context)
165+
mod.context.printOpOnDiagnostic(True)
161166
# TTIR -> TTGIR
162167
pm = ir.pass_manager(mod.context)
163168
pm.enable_debug()

0 commit comments

Comments
 (0)