Skip to content

Commit 5bb0bb4

Browse files
committed
fixup dpas layout per review comments
1 parent 4c92bc6 commit 5bb0bb4

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,11 +516,13 @@ struct LoadOpConversion
516516
"Only row_major or column_major is supported");
517517
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
518518

519-
auto dpasLayout = hasDpasLayout
520-
? cast<DpasEncodingAttr>(encoding)
521-
: cast<DpasEncodingAttr>(
522-
getDotEncoding(tensorType).value().getParent());
523-
auto dotOrder = dpasLayout.getThreadOrder();
519+
auto getDotOrder = [&]() {
520+
return hasDpasLayout
521+
? cast<DpasEncodingAttr>(encoding).getThreadOrder()
522+
: getDotEncoding(tensorType).value().getThreadOrder();
523+
};
524+
auto dotOrder = getDotOrder();
525+
524526
size_t rank = dotOrder.size();
525527
const bool valueRowMajor =
526528
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
@@ -537,11 +539,16 @@ struct LoadOpConversion
537539
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
538540
}
539541
};
540-
541542
auto opIdx = getOpIdx();
543+
542544
Type eltTy = tensorType.getElementType();
543545
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
544546

547+
auto dpasLayout = hasDpasLayout
548+
? cast<DpasEncodingAttr>(encoding)
549+
: cast<DpasEncodingAttr>(
550+
getDotEncoding(tensorType).value().getParent());
551+
545552
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
546553
unsigned numElems = getTotalElemsPerThread(resultType);
547554
SmallVector<int64_t> numReps =

0 commit comments

Comments
 (0)