@@ -211,12 +211,11 @@ struct LoadStoreConversionBase {
211211 [](ArrayRef<Value> A, ArrayRef<Value> B, Value init,
212212 std::function<Value (const Value &, const Value &, const Value &)>
213213 linearizeFunc) {
214- auto rank = A.size ();
214+ unsigned rank = A.size ();
215215 Value accumulate = init;
216216 if (rank > 0 ) {
217- for (auto [a, b] : llvm::zip (A, B)) {
217+ for (auto [a, b] : llvm::zip (A, B))
218218 accumulate = linearizeFunc (a, b, accumulate);
219- }
220219 }
221220 return accumulate;
222221 };
@@ -226,11 +225,10 @@ struct LoadStoreConversionBase {
226225 SmallVector<Value> ptrElems (numElems);
227226 SmallVector<Value> maskElems;
228227 for (unsigned i = 0 ; i < numElems; ++i) {
229- auto index = indices[i];
228+ SmallVector<Value> index = indices[i];
230229 SmallVector<Value> indicesInTensor (rank);
231- for (unsigned j = 0 ; j < rank; ++j) {
230+ for (unsigned j = 0 ; j < rank; ++j)
232231 indicesInTensor[j] = b.add (index[j], blockPtr[blockOffset + j]);
233- }
234232
235233 // Get the LLVM values for pointers
236234 Value offset = linearize (
@@ -272,22 +270,24 @@ struct LoadStoreConversionBase {
272270 SmallVector<Value> otherElems;
273271 if (padding) {
274272 Value other;
275- if (*padding == PaddingOption::PAD_ZERO) {
273+ switch (*padding) {
274+ case PaddingOption::PAD_ZERO:
276275 other = rewriter.create <LLVM::ConstantOp>(
277276 loc, valueElemTy, rewriter.getZeroAttr (valueElemTy));
278- } else if (*padding == PaddingOption::PAD_NAN) {
277+
278+ break ;
279+ case PaddingOption::PAD_NAN: {
279280 assert (!valueElemTy.isIntOrIndex () &&
280281 " Expect element type to be non-integer type" );
281282 auto apNaN = llvm::APFloat::getNaN (
282283 cast<FloatType>(valueElemTy).getFloatSemantics ());
283284 other = rewriter.create <LLVM::ConstantOp>(
284285 loc, valueElemTy, rewriter.getFloatAttr (valueElemTy, apNaN));
285- } else {
286- llvm_unreachable (" Unexpected padding option" );
286+ } break ;
287287 }
288- for (unsigned i = 0 ; i < numElems; ++i) {
288+
289+ for (unsigned i = 0 ; i < numElems; ++i)
289290 otherElems.push_back (other);
290- }
291291 }
292292
293293 return std::make_tuple (ptrElems, maskElems, otherElems);
@@ -296,6 +296,8 @@ struct LoadStoreConversionBase {
296296protected:
297297 const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
298298 const triton::intel::TargetInfo &targetInfo;
299+ const bool usePredicatedInstructions =
300+ triton::tools::getBoolEnv (" TRITON_INTEL_PREDICATED" );
299301};
300302
301303struct BlockIOConversionBase : public LoadStoreConversionBase {
@@ -2940,9 +2942,16 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
29402942 auto b = TritonLLVMOpBuilder (loc, rewriter);
29412943 auto typeConverter = getTypeConverter ();
29422944 MLIRContext *ctx = rewriter.getContext ();
2945+
2946+ // original values
29432947 Value ptr = op.getPtr ();
29442948 Value mask = op.getMask ();
2949+ Value other = op.getOther ();
2950+
2951+ // adaptor values
2952+ Value llPtr = adaptor.getPtr ();
29452953 Value llMask = adaptor.getMask ();
2954+ Value llOther = adaptor.getOther ();
29462955
29472956 // Determine the vectorization size
29482957 Type valueElemTy =
@@ -2960,13 +2969,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
29602969 // fallback to gather load.
29612970 auto tensorType = cast<RankedTensorType>(op.getType ());
29622971 std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
2963- loc, adaptor. getPtr () , tensorType, valueElemTy, rewriter,
2964- op.getBoundaryCheck (), op. getPadding ());
2972+ loc, llPtr , tensorType, valueElemTy, rewriter, op. getBoundaryCheck () ,
2973+ op.getPadding ());
29652974 } else {
2966- Value other = op.getOther ();
2967- Value llPtr = adaptor.getPtr ();
2968- Value llOther = adaptor.getOther ();
2969-
29702975 // Get the LLVM values for pointers
29712976 ptrElems = unpackLLElements (loc, llPtr, rewriter);
29722977 assert (ptrElems.size () == numElems);
@@ -2998,23 +3003,22 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
29983003 const int numVecs = numElems / vec;
29993004
30003005 // Load redundantly in all dims except reg
3001- auto freeVarMasks = getFreeVariableMasks (ptr.getType ());
3006+ llvm::MapVector<StringAttr, int > freeVarMasks =
3007+ getFreeVariableMasks (ptr.getType ());
30023008 uint32_t regMask = freeVarMasks[str_attr (" register" )];
30033009
30043010 SmallVector<Value> loadedVals;
30053011 for (size_t vecStart = 0 ; vecStart < numElems; vecStart += vec) {
3006- if (auto canonicalVecStart = getCanonicalIndex (vecStart, regMask);
3012+ if (unsigned canonicalVecStart = getCanonicalIndex (vecStart, regMask);
30073013 vecStart != canonicalVecStart) {
30083014 // For redundant registers, refer back to the canonical load
3009- for (auto iVec = 0 ; iVec < vec; ++iVec) {
3015+ for (int iVec = 0 ; iVec < vec; ++iVec)
30103016 loadedVals.push_back (loadedVals[canonicalVecStart + iVec]);
3011- }
3017+
30123018 continue ;
30133019 }
30143020
30153021 // TODO: optimization when ptr is GEP with constant offset
3016- size_t in_off = 0 ;
3017-
30183022 const size_t maxWordWidth = std::max<size_t >(32 , valueElemNBits);
30193023 const size_t totalWidth = valueElemNBits * vec;
30203024 const size_t width = std::min (totalWidth, maxWordWidth);
@@ -3025,7 +3029,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
30253029
30263030 Value pred = maskElems.size () ? maskElems[vecStart] : Value{};
30273031
3028- SmallVector<Type> retTys (nWords, IntegerType::get (getContext () , width));
3032+ SmallVector<Type> retTys (nWords, IntegerType::get (ctx , width));
30293033 Type retTy = retTys.size () > 1
30303034 ? vec_ty (IntegerType::get (ctx, width), nWords)
30313035 : retTys[0 ];
@@ -3051,13 +3055,15 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
30513055 v = b.int_val (width, splatVal);
30523056 }
30533057
3054- Value iiVal = createIndexAttrConstant (
3055- rewriter, loc, typeConverter->getIndexType (), ii);
3056- if (nWords > 1 ) {
3057- other_ = b.insert_element (retTy, other_, v, iiVal);
3058- } else {
3059- other_ = v;
3060- }
3058+ other_ =
3059+ (nWords > 1 )
3060+ ? b.insert_element (
3061+ retTy, other_, v,
3062+ createIndexAttrConstant (
3063+ rewriter, loc, typeConverter->getIndexType (), ii))
3064+ :
3065+
3066+ v;
30613067 }
30623068 } else {
30633069 other_ = rewriter.create <LLVM::ConstantOp>(loc, retTy,
@@ -3085,31 +3091,27 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
30853091 };
30863092
30873093 Value ret;
3088- // Create a predicated load operation.
3089- if (pred) {
3090- if (triton::tools::getBoolEnv (" TRITON_INTEL_PREDICATED" ))
3091- ret = rewriter.create <TritonGEN::PredicatedLoadOp>(
3092- loc, retTy, addrElem, b.i64_val (alignment), pred, other_);
3093- else {
3094- Block &endBlock = LLVM::intel::createPredicatedBlock (
3095- rewriter, loc, pred, SmallVector<Value, 1 >{other_},
3096- createLoadWithAttrs);
3097- ret = *endBlock.args_begin ();
3098- }
3099- } else {
3094+
3095+ if (!pred)
31003096 ret = createLoadWithAttrs ()[0 ];
3097+ else if (usePredicatedInstructions)
3098+ ret = rewriter.create <TritonGEN::PredicatedLoadOp>(
3099+ loc, retTy, addrElem, b.i64_val (alignment), pred, other_);
3100+ else {
3101+ Block &endBlock = LLVM::intel::createPredicatedBlock (
3102+ rewriter, loc, pred, SmallVector<Value, 1 >{other_},
3103+ createLoadWithAttrs);
3104+ ret = *endBlock.args_begin ();
31013105 }
3106+ assert (ret && " Expecting a valid value" );
31023107
31033108 // Extract and store return values
31043109 SmallVector<Value> rets;
31053110 for (unsigned int ii = 0 ; ii < nWords; ++ii) {
3106- Value curr;
3107- if (isa<VectorType>(retTy)) {
3108- curr = b.extract_element (IntegerType::get (ctx, width), ret,
3109- b.i32_val (ii));
3110- } else {
3111- curr = ret;
3112- }
3111+ Value curr = isa<VectorType>(retTy)
3112+ ? b.extract_element (IntegerType::get (ctx, width), ret,
3113+ b.i32_val (ii))
3114+ : ret;
31133115 curr = b.bitcast (
31143116 curr, LLVM::getVectorType (valueElemTy, width / valueElemNBits));
31153117 rets.push_back (curr);
@@ -3177,8 +3179,8 @@ struct StoreOpToBlockIOConversion
31773179 // Limit vBlock to 1
31783180 vBlocks = 1 ;
31793181
3180- // TODO: use the axis info to general the handling for both regular pointer
3181- // and block pointer.
3182+ // TODO: use the axis info to general the handling for both regular
3183+ // pointer and block pointer.
31823184 const bool memoryRowMajor = isMemoryRowMajor (op);
31833185 unsigned contiguousDim = memoryRowMajor ? 1 : 0 ;
31843186 if (contiguousDim != colDim) {
@@ -3312,17 +3314,17 @@ struct StoreOpToBlockIOConversion
33123314 Value addrElem = ptrElems[registerIdx];
33133315 Value offsetX, offsetY;
33143316 if (isBlockPointer) {
3315- // Need to apply the linear layout to get the offsets to the base of the
3316- // block pointer.
3317- // TODO: add annotation uniform to the offsets. Make sure the IGC detect
3318- // the offsets as uniform.
3317+ // Need to apply the linear layout to get the offsets to the base of
3318+ // the block pointer.
3319+ // TODO: add annotation uniform to the offsets. Make sure the IGC
3320+ // detect the offsets as uniform.
33193321 auto offsets = applyLinearLayout (loc, rewriter, *llEncoding,
33203322 {{kRegister , b.i32_val (registerIdx)},
33213323 {kLane , b.i32_val (0 )},
33223324 {kWarp , warpId},
33233325 {kBlock , b.i32_val (0 )}});
3324- // TODO: To support rank > 2 tensor, we need to add the offsets of other
3325- // dim to the base.
3326+ // TODO: To support rank > 2 tensor, we need to add the offsets of
3327+ // other dim to the base.
33263328 assert (offsets.size () == 2 && " only support 2D tensor for now." );
33273329 offsetX = b.add (offsetBaseX, offsets[colDim].second );
33283330 offsetY = b.add (offsetBaseY, offsets[rowDim].second );
@@ -3342,7 +3344,8 @@ struct StoreOpToBlockIOConversion
33423344 assert (numPackedVals > 0 &&
33433345 " numPackedVals should be greater than zero." );
33443346 // The offsetX of linear layout is in original elements.
3345- // The 2d block io requires the offsetX in number of packed elements.
3347+ // The 2d block io requires the offsetX in number of packed
3348+ // elements.
33463349 offsetX = b.udiv (offsetX, b.i32_val (numPackedVals));
33473350 }
33483351 if (!boundaryCheck.contains (rowDim)) {
@@ -3530,19 +3533,15 @@ struct StoreOpConversion
35303533 return ArrayRef<Value>();
35313534 };
35323535
3533- if (maskVal) {
3534- // Create a predicated store operation.
3535- if (triton::tools::getBoolEnv (" TRITON_INTEL_PREDICATED" ))
3536- rewriter.create <TritonGEN::PredicatedStoreOp>(
3537- loc, addrElem, vecWord, b.i64_val (alignment), maskVal);
3538- else
3539- LLVM::intel::createPredicatedBlock (rewriter, loc, maskVal,
3540- createStore);
3541- } else {
3536+ if (!maskVal)
35423537 auto _ = createStore ();
3543- }
3538+ else if (usePredicatedInstructions)
3539+ rewriter.create <TritonGEN::PredicatedStoreOp>(
3540+ loc, addrElem, vecWord, b.i64_val (alignment), maskVal);
3541+ else
3542+ LLVM::intel::createPredicatedBlock (rewriter, loc, maskVal, createStore);
3543+ }
35443544
3545- } // for
35463545 rewriter.eraseOp (op);
35473546 return success ();
35483547 }
0 commit comments