test/Conversion/intel/load_store_to_llvm.mlir (33 additions, 0 deletions)
@@ -29,3 +29,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.return
}
}

// -----

#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
// CHECK-LABEL: global_store_with_attributes
tt.func @global_store_with_attributes(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
tt.store %6, %cst : tensor<256x!tt.ptr<f32>, #blocked0>
tt.store %6, %cst cacheModifier = ca : tensor<256x!tt.ptr<f32>, #blocked0>
tt.store %6, %cst cacheModifier = cg : tensor<256x!tt.ptr<f32>, #blocked0>
tt.store %6, %cst cacheModifier = wb : tensor<256x!tt.ptr<f32>, #blocked0>
tt.store %6, %cst cacheModifier = cs : tensor<256x!tt.ptr<f32>, #blocked0>
tt.store %6, %cst cacheModifier = wt : tensor<256x!tt.ptr<f32>, #blocked0>
tt.store %6, %cst cacheModifier = cv : tensor<256x!tt.ptr<f32>, #blocked0>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64, nontemporal} : vector<4xi32>, !llvm.ptr<1>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64, nontemporal} : vector<4xi32>, !llvm.ptr<1>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64, nontemporal} : vector<4xi32>, !llvm.ptr<1>
tt.return
}
}
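The seven tt.store ops above cover every Triton cache modifier, and the CHECK-COUNT-2 lines encode the expected lowering: with sizePerThread = [8], each thread writes its eight f32 values as two 16-byte vector<4xi32> stores, and only the cache-bypassing modifiers (cg, cs, cv) are expected to add the nontemporal attribute. As a standalone illustration (not part of the patch), the modifier-to-flag mapping exercised here can be sketched in plain C++ with a local stand-in enum instead of Triton's real triton::CacheModifier type:

#include <cassert>

// Stand-in for triton::CacheModifier, limited to the values used in the test.
enum class CacheModifier { NONE, CA, CG, WB, CS, WT, CV };

// Mirrors the mapping the CHECK lines verify: .cg, .cs and .cv bypass a cache
// level, so the lowered llvm.store carries the nontemporal attribute; every
// other modifier lowers to a plain 16-byte-aligned vector store.
static bool isNonTemporal(CacheModifier cm) {
  switch (cm) {
  case CacheModifier::CG:
  case CacheModifier::CS:
  case CacheModifier::CV:
    return true;
  default:
    return false;
  }
}

int main() {
  assert(!isNonTemporal(CacheModifier::NONE)); // 1st CHECK line: plain store
  assert(!isNonTemporal(CacheModifier::CA));   // 2nd: plain store
  assert(isNonTemporal(CacheModifier::CG));    // 3rd: nontemporal
  assert(!isNonTemporal(CacheModifier::WB));   // 4th: plain store
  assert(isNonTemporal(CacheModifier::CS));    // 5th: nontemporal
  assert(!isNonTemporal(CacheModifier::WT));   // 6th: plain store
  assert(isNonTemporal(CacheModifier::CV));    // 7th: nontemporal
  return 0;
}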
third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp (51 additions, 33 deletions)
@@ -196,12 +196,11 @@ struct LoadStoreConversionBase {
// }
// All the values are decomposed by `unpackLLElements` into a vector.
// Defines the indices for the block pointer struct.
- unsigned blockOffset = 0, blockShape = 1 * rank, blockStride = 2 * rank,
- blockBase = 3 * rank;
+ const unsigned blockOffset = 0, blockShape = 1 * rank,
+ blockStride = 2 * rank, blockBase = 3 * rank;
const SmallVector<Value> &blockPtr =
unpackLLElements(loc, blockPointerStruct, rewriter);

- unsigned numElems = getTotalElemsPerThread(tensorType);
+ const unsigned numElems = getTotalElemsPerThread(tensorType);

// Get the LLVM values for indices in block
auto indices = emitIndices(loc, rewriter, targetInfo,
@@ -293,6 +292,34 @@ struct LoadStoreConversionBase {
return std::make_tuple(ptrElems, maskElems, otherElems);
}

// Ensure the operation doesn't have attributes that the IGC predicated
// instruction cannot handle.
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, LoadOp, StoreOp>::value>>
bool canUsePredicatedInstructions(OpType op) const {
if (!usePredicatedInstructions)
return false;

if constexpr (std::is_same_v<OpType, LoadOp>)
return !op.getIsVolatile() && op.getCache() == CacheModifier::NONE;

return op.getCache() == CacheModifier::NONE;
}

template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, LoadOp, StoreOp>::value>>
bool getNonTemporalFlag(OpType op) const {
switch (op.getCache()) {
case triton::CacheModifier::CG:
case triton::CacheModifier::CS:
case triton::CacheModifier::CV:
return true;
case triton::CacheModifier::CA:
default:
return false;
}
}

protected:
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
const triton::intel::TargetInfo &targetInfo;
@@ -3035,11 +3062,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
: retTys[0];

Value other_ = b.undef(retTy);
- if (otherElems.size()) {
+ if (otherElems.empty()) {
+ other_ = rewriter.create<LLVM::ConstantOp>(loc, retTy,
+ rewriter.getZeroAttr(retTy));
+ } else {
for (size_t ii = 0; ii < nWords; ++ii) {
size_t size = width / valueElemNBits;

- auto vecTy = vec_ty(valueElemTy, size);
+ VectorType vecTy = vec_ty(valueElemTy, size);
Value v = b.undef(vecTy);
for (size_t s = 0; s < size; ++s) {
Value falseVal = otherElems[vecStart + ii * size + s];
@@ -3065,36 +3094,21 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,

v;
}
- } else {
- other_ = rewriter.create<LLVM::ConstantOp>(loc, retTy,
- rewriter.getZeroAttr(retTy));
- }
assert(other_ && "Expecting a valid value");

Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
uint32_t alignment = nWords * width / 8;
- auto createLoadWithAttrs = [&]() -> SmallVector<Value, 1> {
- auto getNonTemporalFlag = [](triton::LoadOp loadOp) {
- switch (loadOp.getCache()) {
- case triton::CacheModifier::CG:
- case triton::CacheModifier::CS:
- case triton::CacheModifier::CV:
- return true;
- case triton::CacheModifier::CA:
- default:
- return false;
- }
- };
-
- Value ret = b.load(retTy, addrElem, alignment, op.getIsVolatile(),
- getNonTemporalFlag(op));
- return {ret};
+ auto createLoadWithAttrs = [&]() {
+ return SmallVector<Value>{b.load(retTy, addrElem, alignment,
+ op.getIsVolatile(),
+ getNonTemporalFlag(op))};
+ };

Value ret;

if (!pred)
ret = createLoadWithAttrs()[0];
- else if (usePredicatedInstructions)
+ else if (canUsePredicatedInstructions(op))
ret = rewriter.create<TritonGEN::PredicatedLoadOp>(
loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
else {
Expand All @@ -3116,6 +3130,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
curr, LLVM::getVectorType(valueElemTy, width / valueElemNBits));
rets.push_back(curr);
}

int tmp = width / valueElemNBits;
for (size_t ii = 0; ii < vec; ++ii) {
Value loaded =
@@ -3528,18 +3543,21 @@ struct StoreOpConversion

Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
uint32_t alignment = nWords * width / 8;
- auto createStore = [&]() -> ArrayRef<Value> {
- b.store(vecWord, addrElem, alignment);
+ auto createStoreWithAttrs = [&]() {
+ bool isVolatile = false;
+ b.store(vecWord, addrElem, alignment, isVolatile,
+ getNonTemporalFlag(op));
return ArrayRef<Value>();
};

if (!maskVal)
- auto _ = createStore();
+ auto _ = createStoreWithAttrs();
- else if (usePredicatedInstructions)
+ else if (canUsePredicatedInstructions(op))
rewriter.create<TritonGEN::PredicatedStoreOp>(
loc, addrElem, vecWord, b.i64_val(alignment), maskVal);
else
- LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal, createStore);
+ LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal,
+ createStoreWithAttrs);
}

rewriter.eraseOp(op);
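Putting the two new helpers together, the store lowering (the load path is analogous, except that canUsePredicatedInstructions additionally rejects volatile loads) now chooses between three paths: an unmasked op becomes a plain llvm.store whose nontemporal flag comes from getNonTemporalFlag; a masked op with no cache modifier can still use the IGC predicated instruction (TritonGEN::PredicatedStoreOp); and a masked op with a cache modifier falls back to a predicated block around an ordinary store so the flag is not lost. As a standalone illustration (not part of the patch), that dispatch can be sketched with simplified stand-in types instead of the real MLIR ops and rewriter:

#include <cstdio>

// Stand-ins for the real Triton types; only what the dispatch needs.
enum class CacheModifier { NONE, CA, CG, WB, CS, WT, CV };
struct StoreOp {
  bool hasMask;
  CacheModifier cache;
};

// Mirrors getNonTemporalFlag: cg/cs/cv request a non-temporal store.
static bool getNonTemporalFlag(const StoreOp &op) {
  switch (op.cache) {
  case CacheModifier::CG:
  case CacheModifier::CS:
  case CacheModifier::CV:
    return true;
  default:
    return false;
  }
}

// Mirrors canUsePredicatedInstructions for stores: the IGC predicated store
// cannot express cache modifiers, so it is only used for unmodified ops.
static bool canUsePredicatedInstructions(const StoreOp &op,
                                         bool usePredicatedInstructions) {
  return usePredicatedInstructions && op.cache == CacheModifier::NONE;
}

// Sketch of the branch structure at the end of StoreOpConversion.
static void lowerStore(const StoreOp &op, bool usePredicatedInstructions) {
  if (!op.hasMask)
    std::printf("llvm.store, nontemporal=%d\n", getNonTemporalFlag(op));
  else if (canUsePredicatedInstructions(op, usePredicatedInstructions))
    std::printf("TritonGEN::PredicatedStoreOp\n");
  else
    std::printf("predicated block around llvm.store, nontemporal=%d\n",
                getNonTemporalFlag(op));
}

int main() {
  lowerStore({/*hasMask=*/false, CacheModifier::CG}, true);  // plain store, nontemporal
  lowerStore({/*hasMask=*/true, CacheModifier::NONE}, true); // predicated instruction
  lowerStore({/*hasMask=*/true, CacheModifier::CS}, true);   // predicated block
  return 0;
}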