@@ -273,12 +273,25 @@ struct SharedMemoryObject {
273273 ArrayRef<Value> offsets)
274274 : base(base), baseElemType(baseElemType),
275275 strides (strides.begin(), strides.end()),
276- offsets(offsets.begin(), offsets.end()) {}
276+ offsets(offsets.begin(), offsets.end()) {
277+ assert (strides.size () == offsets.size ());
278+ }
277279
278280 SharedMemoryObject (Value base, Type baseElemType, ArrayRef<int64_t > shape,
279- ArrayRef< unsigned > order , Location loc,
281+ triton::gpu::SharedEncodingAttr layout , Location loc,
280282 RewriterBase &rewriter)
281283 : base(base), baseElemType(baseElemType) {
284+ SmallVector<unsigned > order (shape.size ());
285+ // Default minor-to-major order
286+ std::iota (order.rbegin (), order.rend (), 0 );
287+ if (layout) {
288+ auto layoutOrder = convertType<int >(layout.getOrder ());
289+ int rankDiff = layoutOrder.size () - shape.size ();
290+ auto minRank = std::min (shape.size (), layoutOrder.size ());
291+ for (size_t i = 0 ; i < minRank; ++i)
292+ order[i] = layoutOrder[i] - rankDiff;
293+ }
294+ assert (isPermutationOfIota (order) && " Invalid order" );
282295 strides = getStridesFromShapeAndOrder (shape, order, loc, rewriter);
283296 offsets.append (order.size (), i32_val (0 ));
284297 }
@@ -304,14 +317,14 @@ struct SharedMemoryObject {
304317 return types;
305318 }
306319
307- Value getCSwizzleOffset (int order ) const {
308- assert (order >= 0 && order < strides.size ());
309- return offsets[order ];
320+ Value getCSwizzleOffset (int dim ) const {
321+ assert (dim >= 0 && dim < strides.size ());
322+ return offsets[dim ];
310323 }
311324
312- Value getBaseBeforeSlice (int order , Location loc,
325+ Value getBaseBeforeSlice (int dim , Location loc,
313326 RewriterBase &rewriter) const {
314- Value cSwizzleOffset = getCSwizzleOffset (order );
327+ Value cSwizzleOffset = getCSwizzleOffset (dim );
315328 Value offset = sub (i32_val (0 ), cSwizzleOffset);
316329 Type type = base.getType ();
317330 return gep (type, baseElemType, base, offset);
0 commit comments