
Commit 4c52491

[NFI]: Reorganize lowering code for tt.load/tt.store (#5483)
This PR cleans up the lowering code for the tt.load and tt.store operations.

Signed-off-by: Ettore Tiotto <ettore.tiotto@intel.com>
1 parent c72740f commit 4c52491

File tree

1 file changed

+70 -71 lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 70 additions & 71 deletions
@@ -211,12 +211,11 @@ struct LoadStoreConversionBase {
         [](ArrayRef<Value> A, ArrayRef<Value> B, Value init,
            std::function<Value(const Value &, const Value &, const Value &)>
                linearizeFunc) {
-          auto rank = A.size();
+          unsigned rank = A.size();
           Value accumulate = init;
           if (rank > 0) {
-            for (auto [a, b] : llvm::zip(A, B)) {
+            for (auto [a, b] : llvm::zip(A, B))
               accumulate = linearizeFunc(a, b, accumulate);
-            }
           }
           return accumulate;
         };
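For context, the fold this linearize lambda performs can be sketched standalone. This is an illustrative sketch only, using a scalar stand-in instead of MLIR Values; linearizeSketch and rowMajorStep are hypothetical names, not the commit's API.

#include <cstddef>
#include <functional>
#include <vector>

using Scalar = long long; // hypothetical stand-in for an MLIR Value

// Thread an accumulator through paired elements of A and B, the same shape
// as the lambda above (llvm::zip replaced by an index loop).
Scalar linearizeSketch(const std::vector<Scalar> &A,
                       const std::vector<Scalar> &B, Scalar init,
                       const std::function<Scalar(Scalar, Scalar, Scalar)> &f) {
  Scalar accumulate = init;
  for (std::size_t i = 0; i < A.size(); ++i)
    accumulate = f(A[i], B[i], accumulate);
  return accumulate;
}

// One plausible combiner is row-major linearization, acc * dim + idx:
// indices {2, 3} against shape {4, 5} fold to (0 * 4 + 2) * 5 + 3 = 13.
Scalar rowMajorStep(Scalar idx, Scalar dim, Scalar acc) {
  return acc * dim + idx;
}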
@@ -226,11 +225,10 @@ struct LoadStoreConversionBase {
     SmallVector<Value> ptrElems(numElems);
     SmallVector<Value> maskElems;
     for (unsigned i = 0; i < numElems; ++i) {
-      auto index = indices[i];
+      SmallVector<Value> index = indices[i];
       SmallVector<Value> indicesInTensor(rank);
-      for (unsigned j = 0; j < rank; ++j) {
+      for (unsigned j = 0; j < rank; ++j)
         indicesInTensor[j] = b.add(index[j], blockPtr[blockOffset + j]);
-      }

       // Get the LLVM values for pointers
       Value offset = linearize(
@@ -272,22 +270,24 @@ struct LoadStoreConversionBase {
     SmallVector<Value> otherElems;
     if (padding) {
       Value other;
-      if (*padding == PaddingOption::PAD_ZERO) {
+      switch (*padding) {
+      case PaddingOption::PAD_ZERO:
         other = rewriter.create<LLVM::ConstantOp>(
             loc, valueElemTy, rewriter.getZeroAttr(valueElemTy));
-      } else if (*padding == PaddingOption::PAD_NAN) {
+
+        break;
+      case PaddingOption::PAD_NAN: {
         assert(!valueElemTy.isIntOrIndex() &&
                "Expect element type to be non-integer type");
         auto apNaN = llvm::APFloat::getNaN(
             cast<FloatType>(valueElemTy).getFloatSemantics());
         other = rewriter.create<LLVM::ConstantOp>(
             loc, valueElemTy, rewriter.getFloatAttr(valueElemTy, apNaN));
-      } else {
-        llvm_unreachable("Unexpected padding option");
+      } break;
       }
-      for (unsigned i = 0; i < numElems; ++i) {
+
+      for (unsigned i = 0; i < numElems; ++i)
         otherElems.push_back(other);
-      }
     }

     return std::make_tuple(ptrElems, maskElems, otherElems);
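The switch above selects the fill value used for out-of-bounds lanes of a boundary-checked block load. A minimal standalone sketch of that selection, assuming a local re-declaration of the enum (paddingValue is a hypothetical helper, not the commit's API):

#include <cassert>
#include <limits>

enum class PaddingOption { PAD_ZERO, PAD_NAN };

template <typename ElemTy> ElemTy paddingValue(PaddingOption padding) {
  switch (padding) {
  case PaddingOption::PAD_ZERO:
    return ElemTy(0);
  case PaddingOption::PAD_NAN:
    // Mirrors the assert in the diff: NaN padding only makes sense for
    // floating-point element types.
    assert(std::numeric_limits<ElemTy>::has_quiet_NaN &&
           "Expect element type to be non-integer type");
    return std::numeric_limits<ElemTy>::quiet_NaN();
  }
  return ElemTy(0); // unreachable for valid enum values
}

Note the dropped llvm_unreachable: with a switch covering every enumerator, the compiler can warn about unhandled cases instead of deferring the check to runtime.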
@@ -296,6 +296,8 @@ struct LoadStoreConversionBase {
 protected:
   const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
   const triton::intel::TargetInfo &targetInfo;
+  const bool usePredicatedInstructions =
+      triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED");
 };

 struct BlockIOConversionBase : public LoadStoreConversionBase {
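These two added lines are the core of the cleanup: the TRITON_INTEL_PREDICATED environment variable is now read once per conversion-pattern object, and the load/store lowerings below test the cached flag instead of calling triton::tools::getBoolEnv at each use. A minimal sketch of the idiom, with a simplified stand-in for getBoolEnv (the real helper may accept more spellings than "1"):

#include <cstdlib>
#include <string>

// Simplified stand-in for triton::tools::getBoolEnv.
static bool getBoolEnvSketch(const std::string &name) {
  const char *v = std::getenv(name.c_str());
  return v != nullptr && std::string(v) == "1";
}

struct LoweringBaseSketch {
protected:
  // Evaluated once when the pattern object is constructed,
  // not once per lowered tt.load/tt.store.
  const bool usePredicatedInstructions =
      getBoolEnvSketch("TRITON_INTEL_PREDICATED");
};

At runtime, e.g. setting TRITON_INTEL_PREDICATED=1 selects the TritonGEN predicated load/store ops in the hunks below.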
@@ -2940,9 +2942,16 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto typeConverter = getTypeConverter();
     MLIRContext *ctx = rewriter.getContext();
+
+    // original values
     Value ptr = op.getPtr();
     Value mask = op.getMask();
+    Value other = op.getOther();
+
+    // adaptor values
+    Value llPtr = adaptor.getPtr();
     Value llMask = adaptor.getMask();
+    Value llOther = adaptor.getOther();

     // Determine the vectorization size
     Type valueElemTy =
@@ -2960,13 +2969,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
       // fallback to gather load.
       auto tensorType = cast<RankedTensorType>(op.getType());
       std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
-          loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
-          op.getBoundaryCheck(), op.getPadding());
+          loc, llPtr, tensorType, valueElemTy, rewriter, op.getBoundaryCheck(),
+          op.getPadding());
     } else {
-      Value other = op.getOther();
-      Value llPtr = adaptor.getPtr();
-      Value llOther = adaptor.getOther();
-
       // Get the LLVM values for pointers
       ptrElems = unpackLLElements(loc, llPtr, rewriter);
       assert(ptrElems.size() == numElems);
@@ -2998,23 +3003,22 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     const int numVecs = numElems / vec;

     // Load redundantly in all dims except reg
-    auto freeVarMasks = getFreeVariableMasks(ptr.getType());
+    llvm::MapVector<StringAttr, int> freeVarMasks =
+        getFreeVariableMasks(ptr.getType());
     uint32_t regMask = freeVarMasks[str_attr("register")];

     SmallVector<Value> loadedVals;
     for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
-      if (auto canonicalVecStart = getCanonicalIndex(vecStart, regMask);
+      if (unsigned canonicalVecStart = getCanonicalIndex(vecStart, regMask);
           vecStart != canonicalVecStart) {
         // For redundant registers, refer back to the canonical load
-        for (auto iVec = 0; iVec < vec; ++iVec) {
+        for (int iVec = 0; iVec < vec; ++iVec)
           loadedVals.push_back(loadedVals[canonicalVecStart + iVec]);
-        }
+
         continue;
       }

       // TODO: optimization when ptr is GEP with constant offset
-      size_t in_off = 0;
-
       const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
       const size_t totalWidth = valueElemNBits * vec;
       const size_t width = std::min(totalWidth, maxWordWidth);
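For context on the dedup test in this loop: register indices that differ only in "free" bits hold replicated data, so each vector is loaded once at its canonical index and redundant positions copy those values. A plausible sketch of the masking idiom behind getCanonicalIndex (an assumption on my part; the stand-in below simply clears the free-variable bits):

#include <cstdint>

// Hypothetical stand-in: map an index to its canonical representative by
// clearing the bits the layout marks as free (replicated) variables.
inline unsigned getCanonicalIndexSketch(unsigned index, uint32_t freeVarMask) {
  return index & ~freeVarMask;
}

// Example: with freeVarMask = 0b100, indices 0b000 and 0b100 share canonical
// index 0b000, so the loop above skips the load at 0b100 and re-pushes the
// values already loaded at 0b000.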
@@ -3025,7 +3029,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,

       Value pred = maskElems.size() ? maskElems[vecStart] : Value{};

-      SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
+      SmallVector<Type> retTys(nWords, IntegerType::get(ctx, width));
       Type retTy = retTys.size() > 1
                        ? vec_ty(IntegerType::get(ctx, width), nWords)
                        : retTys[0];
@@ -3051,13 +3055,15 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
             v = b.int_val(width, splatVal);
           }

-          Value iiVal = createIndexAttrConstant(
-              rewriter, loc, typeConverter->getIndexType(), ii);
-          if (nWords > 1) {
-            other_ = b.insert_element(retTy, other_, v, iiVal);
-          } else {
-            other_ = v;
-          }
+          other_ =
+              (nWords > 1)
+                  ? b.insert_element(
+                        retTy, other_, v,
+                        createIndexAttrConstant(
+                            rewriter, loc, typeConverter->getIndexType(), ii))
+                  :
+
+                  v;
         }
       } else {
         other_ = rewriter.create<LLVM::ConstantOp>(loc, retTy,
@@ -3085,31 +3091,27 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
       };

       Value ret;
-      // Create a predicated load operation.
-      if (pred) {
-        if (triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED"))
-          ret = rewriter.create<TritonGEN::PredicatedLoadOp>(
-              loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
-        else {
-          Block &endBlock = LLVM::intel::createPredicatedBlock(
-              rewriter, loc, pred, SmallVector<Value, 1>{other_},
-              createLoadWithAttrs);
-          ret = *endBlock.args_begin();
-        }
-      } else {
+
+      if (!pred)
         ret = createLoadWithAttrs()[0];
+      else if (usePredicatedInstructions)
+        ret = rewriter.create<TritonGEN::PredicatedLoadOp>(
+            loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
+      else {
+        Block &endBlock = LLVM::intel::createPredicatedBlock(
+            rewriter, loc, pred, SmallVector<Value, 1>{other_},
+            createLoadWithAttrs);
+        ret = *endBlock.args_begin();
       }
+      assert(ret && "Expecting a valid value");

       // Extract and store return values
       SmallVector<Value> rets;
       for (unsigned int ii = 0; ii < nWords; ++ii) {
-        Value curr;
-        if (isa<VectorType>(retTy)) {
-          curr = b.extract_element(IntegerType::get(ctx, width), ret,
-                                   b.i32_val(ii));
-        } else {
-          curr = ret;
-        }
+        Value curr = isa<VectorType>(retTy)
+                         ? b.extract_element(IntegerType::get(ctx, width), ret,
+                                             b.i32_val(ii))
+                         : ret;
         curr = b.bitcast(
             curr, LLVM::getVectorType(valueElemTy, width / valueElemNBits));
         rets.push_back(curr);
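The restructured load now reads as a three-way dispatch: no mask → plain load; mask with the cached flag set → TritonGEN::PredicatedLoadOp; mask otherwise → an if/else diamond built by createPredicatedBlock whose merge block yields either the loaded value or other_. A condensed sketch of that shape with hypothetical stand-ins (stubbed so it compiles; none of these helpers are the commit's API):

#include <cassert>

struct Value {
  const void *impl = nullptr;
  explicit operator bool() const { return impl != nullptr; }
};

static const int tag = 0; // dummy payload for the stubs below
Value emitPlainLoad() { return Value{&tag}; }
Value emitPredicatedLoadOp(Value, Value) { return Value{&tag}; }
Value emitPredicatedDiamond(Value, Value) { return Value{&tag}; }

Value lowerMaskedLoad(Value pred, Value other,
                      bool usePredicatedInstructions) {
  Value ret;
  if (!pred)
    ret = emitPlainLoad(); // no mask: an unconditional load suffices
  else if (usePredicatedInstructions)
    ret = emitPredicatedLoadOp(pred, other); // one predicated instruction
  else
    ret = emitPredicatedDiamond(pred, other); // branch, load, merge w/ other
  assert(ret && "Expecting a valid value");
  return ret;
}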
@@ -3177,8 +3179,8 @@ struct StoreOpToBlockIOConversion
     // Limit vBlock to 1
     vBlocks = 1;

-    // TODO: use the axis info to general the handling for both regular pointer
-    // and block pointer.
+    // TODO: use the axis info to general the handling for both regular
+    // pointer and block pointer.
     const bool memoryRowMajor = isMemoryRowMajor(op);
     unsigned contiguousDim = memoryRowMajor ? 1 : 0;
     if (contiguousDim != colDim) {
@@ -3312,17 +3314,17 @@ struct StoreOpToBlockIOConversion
       Value addrElem = ptrElems[registerIdx];
       Value offsetX, offsetY;
       if (isBlockPointer) {
-        // Need to apply the linear layout to get the offsets to the base of the
-        // block pointer.
-        // TODO: add annotation uniform to the offsets. Make sure the IGC detect
-        // the offsets as uniform.
+        // Need to apply the linear layout to get the offsets to the base of
+        // the block pointer.
+        // TODO: add annotation uniform to the offsets. Make sure the IGC
+        // detect the offsets as uniform.
         auto offsets = applyLinearLayout(loc, rewriter, *llEncoding,
                                          {{kRegister, b.i32_val(registerIdx)},
                                           {kLane, b.i32_val(0)},
                                           {kWarp, warpId},
                                           {kBlock, b.i32_val(0)}});
-        // TODO: To support rank > 2 tensor, we need to add the offsets of other
-        // dim to the base.
+        // TODO: To support rank > 2 tensor, we need to add the offsets of
+        // other dim to the base.
         assert(offsets.size() == 2 && "only support 2D tensor for now.");
         offsetX = b.add(offsetBaseX, offsets[colDim].second);
         offsetY = b.add(offsetBaseY, offsets[rowDim].second);
@@ -3342,7 +3344,8 @@ struct StoreOpToBlockIOConversion
           assert(numPackedVals > 0 &&
                  "numPackedVals should be greater than zero.");
           // The offsetX of linear layout is in original elements.
-          // The 2d block io requires the offsetX in number of packed elements.
+          // The 2d block io requires the offsetX in number of packed
+          // elements.
           offsetX = b.udiv(offsetX, b.i32_val(numPackedVals));
         }
         if (!boundaryCheck.contains(rowDim)) {
@@ -3530,19 +3533,15 @@ struct StoreOpConversion
         return ArrayRef<Value>();
       };

-      if (maskVal) {
-        // Create a predicated store operation.
-        if (triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED"))
-          rewriter.create<TritonGEN::PredicatedStoreOp>(
-              loc, addrElem, vecWord, b.i64_val(alignment), maskVal);
-        else
-          LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal,
-                                             createStore);
-      } else {
+      if (!maskVal)
         auto _ = createStore();
-      }
+      else if (usePredicatedInstructions)
+        rewriter.create<TritonGEN::PredicatedStoreOp>(
+            loc, addrElem, vecWord, b.i64_val(alignment), maskVal);
+      else
+        LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal, createStore);
+    }

-    } // for
     rewriter.eraseOp(op);
     return success();
   }
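The store-side rewrite mirrors the load dispatch above, minus a result value. A condensed sketch (again hypothetical stand-ins, not the commit's API):

void emitPlainStore() {}        // stubs standing in for the three
void emitPredicatedStoreOp() {} // store-lowering strategies
void emitPredicatedStoreBlock() {}

void lowerMaskedStore(bool hasMask, bool usePredicatedInstructions) {
  if (!hasMask)
    emitPlainStore(); // unconditional store
  else if (usePredicatedInstructions)
    emitPredicatedStoreOp(); // single TritonGEN predicated store
  else
    emitPredicatedStoreBlock(); // branchy fallback: if (mask) { store }
}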
