Skip to content

Commit 91ef068

Browse files
[FA] Port remaining performance features from advanced path (#3848)
`set_fast_math` didn't work on the default path, because the lowering pass generates LLVM operations. This PR changes the setting of the fastmath flag to be applied on LLVM IR instead. The setting of the fastmath flag is guarded behind an env var (`TRITON_INTEL_FAST_MATH`) to avoid accuracy failures. ![Screenshot 2025-04-06 220833](https://github.com/user-attachments/assets/6d53900a-28f2-4294-8e9a-722284b4ac4b) Observations: 1. the performance of the advanced path without the fastmath flag is the same as that of the default path; 2. the performance of the default path with the fastmath flag set is better than that of the advanced path. Since the default path (with the env var set) achieves performance no worse than the advanced path, this PR stops running FA with the advanced path. Closes #3286 --------- Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
1 parent 25f5666 commit 91ef068

File tree

4 files changed

+19
-23
lines changed

4 files changed

+19
-23
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -258,18 +258,6 @@ jobs:
258258
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
259259
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
260260
261-
- name: Run Triton FA fwd kernel benchmark - advanced path
262-
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_benchmark.py_advanced')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_benchmark.py_advanced') }}
263-
run: |
264-
cd benchmarks/triton_kernels_benchmark
265-
TRITON_INTEL_ADVANCED_PATH=1 \
266-
IGC_VISAOptions=" -enableBCR" \
267-
python flash_attention_benchmark.py --reports $REPORTS --n_runs $N_RUNS
268-
269-
TAG="${TAG}-adv"
270-
source ../../scripts/capture-hw-details.sh
271-
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-advanced-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
272-
273261
- name: Run Triton FA bwd kernel benchmark
274262
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_bwd_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_bwd_benchmark.py') }}
275263
run: |

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5050
"TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
5151
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
5252
"TRITON_INTEL_ENABLE_INSTR_SCHED",
53+
"TRITON_INTEL_FAST_MATH",
5354
"TRITON_INTEL_RAISE_BLOCK_POINTER",
5455
"TRITON_INTEL_REDUCE_TRANSPOSE",
5556
// clang-format on

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,6 @@ def make_llir(src, metadata, options):
346346
passes.ttgpuir.add_allocate_shared_memory(pm)
347347
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt)
348348
intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)
349-
intel.set_fast_math(mod)
350349
passes.convert.add_arith_to_llvmir(pm)
351350
passes.common.add_canonicalizer(pm)
352351
passes.common.add_cse(pm)
@@ -359,6 +358,8 @@ def make_llir(src, metadata, options):
359358
context = llvm.context()
360359
llvm_mod = llvm.to_module(mod, context)
361360
intel.set_spv_target_triple(llvm_mod)
361+
if os.getenv("TRITON_INTEL_FAST_MATH", "0") == "1":
362+
intel.set_fast_math(llvm_mod)
362363
if options.extern_libs:
363364
paths = [path for (name, path) in options.extern_libs]
364365
llvm.link_extern_libs(llvm_mod, paths)

third_party/intel/triton_xpu.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "mlir/Pass/PassManager.h"
33
#include "passes.h"
44

5+
#include "llvm/IR/InstIterator.h"
56
#include "llvm/IRReader/IRReader.h"
67
#include "llvm/Passes/PassBuilder.h"
78
#include "llvm/Passes/PassPlugin.h"
@@ -256,16 +257,21 @@ void init_triton_intel(py::module &&m) {
256257
return py::int_(ret);
257258
});
258259

259-
// May do this after llvm ir according to user fmath flag.
260-
m.def("set_fast_math", [](mlir::ModuleOp mod) {
261-
using namespace mlir;
262-
MLIRContext *ctx = mod.getContext();
263-
mod.walk([&](Operation *op) {
264-
if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
265-
op->setAttr(
266-
fmIf.getFastMathAttrName(),
267-
arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
268-
});
260+
// FIXME: This is for internal experimentation. In the end we will need a
261+
// producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
262+
// fast math semantics on all arithmetic operations.
263+
// https://github.com/intel/intel-xpu-backend-for-triton/issues/3862
264+
m.def("set_fast_math", [](llvm::Module *mod) {
265+
using namespace llvm;
266+
for (Function &func : *mod) {
267+
for (Instruction &inst : instructions(func)) {
268+
if (auto *op = dyn_cast<FPMathOperator>(&inst)) {
269+
FastMathFlags FMF;
270+
FMF.setFast(true);
271+
inst.setFastMathFlags(FMF);
272+
}
273+
}
274+
}
269275
});
270276

271277
m.def("set_spv_target_triple", [](llvm::Module *mod) {

0 commit comments

Comments (0)