@@ -44,6 +44,19 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
4444 return true ;
4545}
4646
47+ static bool isLessThanOrEqualTargetBitWidth (Type t, unsigned targetBitWidth) {
48+ VectorType vecType = dyn_cast<VectorType>(t);
49+ // Reject index since getElementTypeBitWidth will abort for Index types.
50+ if (!vecType || vecType.getElementType ().isIndex ())
51+ return false ;
52+ // There are no dimension to fold if it is a 0-D vector.
53+ if (vecType.getRank () == 0 )
54+ return false ;
55+ unsigned trailingVecDimBitWidth =
56+ vecType.getShape ().back () * vecType.getElementTypeBitWidth ();
57+ return trailingVecDimBitWidth <= targetBitWidth;
58+ }
59+
4760namespace {
4861struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
4962 using OpConversionPattern::OpConversionPattern;
@@ -355,6 +368,88 @@ struct LinearizeVectorExtract final
355368 return success ();
356369 }
357370
371+ private:
372+ unsigned targetVectorBitWidth;
373+ };
374+
375+ // / This pattern converts the InsertOp to a ShuffleOp that works on a
376+ // / linearized vector.
377+ // / Following,
378+ // / vector.insert %source %destination [ position ]
379+ // / is converted to :
380+ // / %source_1d = vector.shape_cast %source
381+ // / %destination_1d = vector.shape_cast %destination
382+ // / %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
383+ // / ] %out_nd = vector.shape_cast %out_1d
384+ // / `shuffle_indices_1d` is computed using the position of the original insert.
385+ struct LinearizeVectorInsert final
386+ : public OpConversionPattern<vector::InsertOp> {
387+ using OpConversionPattern::OpConversionPattern;
388+ LinearizeVectorInsert (
389+ const TypeConverter &typeConverter, MLIRContext *context,
390+ unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
391+ PatternBenefit benefit = 1 )
392+ : OpConversionPattern(typeConverter, context, benefit),
393+ targetVectorBitWidth (targetVectBitWidth) {}
394+ LogicalResult
395+ matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
396+ ConversionPatternRewriter &rewriter) const override {
397+ Type dstTy = getTypeConverter ()->convertType (insertOp.getDestVectorType ());
398+ assert (!(insertOp.getDestVectorType ().isScalable () ||
399+ cast<VectorType>(dstTy).isScalable ()) &&
400+ " scalable vectors are not supported." );
401+
402+ if (!isLessThanOrEqualTargetBitWidth (insertOp.getSourceType (),
403+ targetVectorBitWidth))
404+ return rewriter.notifyMatchFailure (
405+ insertOp, " Can't flatten since targetBitWidth < OpSize" );
406+
407+ // dynamic position is not supported
408+ if (insertOp.hasDynamicPosition ())
409+ return rewriter.notifyMatchFailure (insertOp,
410+ " dynamic position is not supported." );
411+ auto srcTy = insertOp.getSourceType ();
412+ auto srcAsVec = dyn_cast<VectorType>(srcTy);
413+ uint64_t srcSize = 0 ;
414+ if (srcAsVec) {
415+ srcSize = srcAsVec.getNumElements ();
416+ } else {
417+ return rewriter.notifyMatchFailure (insertOp,
418+ " scalars are not supported." );
419+ }
420+
421+ auto dstShape = insertOp.getDestVectorType ().getShape ();
422+ const auto dstSize = insertOp.getDestVectorType ().getNumElements ();
423+ auto dstSizeForOffsets = dstSize;
424+
425+ // compute linearized offset
426+ int64_t linearizedOffset = 0 ;
427+ auto offsetsNd = insertOp.getStaticPosition ();
428+ for (auto [dim, offset] : llvm::enumerate (offsetsNd)) {
429+ dstSizeForOffsets /= dstShape[dim];
430+ linearizedOffset += offset * dstSizeForOffsets;
431+ }
432+
433+ llvm::SmallVector<int64_t , 2 > indices (dstSize);
434+ auto origValsUntil = indices.begin ();
435+ std::advance (origValsUntil, linearizedOffset);
436+ std::iota (indices.begin (), origValsUntil,
437+ 0 ); // original values that remain [0, offset)
438+ auto newValsUntil = origValsUntil;
439+ std::advance (newValsUntil, srcSize);
440+ std::iota (origValsUntil, newValsUntil,
441+ dstSize); // new values [offset, offset+srcNumElements)
442+ std::iota (newValsUntil, indices.end (),
443+ linearizedOffset + srcSize); // the rest of original values
444+ // [offset+srcNumElements, end)
445+
446+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
447+ insertOp, dstTy, adaptor.getDest (), adaptor.getSource (),
448+ rewriter.getI64ArrayAttr (indices));
449+
450+ return success ();
451+ }
452+
358453private:
359454 unsigned targetVectorBitWidth;
360455};
@@ -410,6 +505,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410505 : true ;
411506 });
412507 patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
413- LinearizeVectorExtractStridedSlice>(
508+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414509 typeConverter, patterns.getContext (), targetBitWidth);
415510}
0 commit comments