@@ -116,6 +116,83 @@ class ConvertAtenConstantPadNdOp
116116
117117namespace {
118118
119+ class ConvertAtenReplicationPad1dOp
120+ : public OpConversionPattern<AtenReplicationPad1dOp> {
121+ public:
122+ using OpConversionPattern::OpConversionPattern;
123+
124+ LogicalResult
125+ matchAndRewrite (AtenReplicationPad1dOp op, OpAdaptor adaptor,
126+ ConversionPatternRewriter &rewriter) const override {
127+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
128+ return failure ();
129+
130+ Location loc = op.getLoc ();
131+ Value input = adaptor.getSelf ();
132+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
133+ int64_t inputRank = inputType.getRank ();
134+
135+ if (inputRank < 2 )
136+ return rewriter.notifyMatchFailure (op, " input rank must be at least 2" );
137+
138+ SmallVector<int64_t > padInts;
139+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (padInts)))
140+ return rewriter.notifyMatchFailure (
141+ op, " only support constant int pad ranges" );
142+
143+ if (padInts.size () != 2 )
144+ return rewriter.notifyMatchFailure (
145+ op, " pad range must have exactly two values" );
146+
147+ int64_t leftPad = padInts[0 ];
148+ int64_t rightPad = padInts[1 ];
149+
150+ int64_t dimToPad = inputRank - 1 ;
151+ Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
152+
153+ SmallVector<Value> inputShape = getTensorSizes (rewriter, loc, input);
154+ Value widthSize = inputShape[dimToPad];
155+ Value widthMinusOne = rewriter.create <arith::SubIOp>(loc, widthSize, one);
156+
157+ // Build offset and size arrays for slicing
158+ SmallVector<OpFoldResult> allOneStrides (inputRank,
159+ rewriter.getIndexAttr (1 ));
160+ SmallVector<OpFoldResult> leftOffsets (inputRank, rewriter.getIndexAttr (0 ));
161+ SmallVector<OpFoldResult> rightOffsets (inputRank, rewriter.getIndexAttr (0 ));
162+ SmallVector<OpFoldResult> sizes (inputRank, rewriter.getIndexAttr (0 ));
163+ for (int i = 0 ; i < inputRank; ++i)
164+ sizes[i] = (i == dimToPad) ? rewriter.getIndexAttr (1 )
165+ : getAsOpFoldResult (inputShape[i]);
166+
167+ rightOffsets[dimToPad] = getAsOpFoldResult (widthMinusOne);
168+
169+ // Extract leftmost and rightmost slices
170+ Value leftSlice = rewriter.create <tensor::ExtractSliceOp>(
171+ loc, input, leftOffsets, sizes, allOneStrides);
172+ Value rightSlice = rewriter.create <tensor::ExtractSliceOp>(
173+ loc, input, rightOffsets, sizes, allOneStrides);
174+
175+ // Aggregate slices to concat together
176+ SmallVector<Value> resultParts;
177+ resultParts.reserve (leftPad + rightPad + 1 );
178+
179+ resultParts.append (leftPad, leftSlice);
180+ resultParts.push_back (input);
181+ resultParts.append (rightPad, rightSlice);
182+
183+ Value result =
184+ rewriter.create <tensor::ConcatOp>(loc, dimToPad, resultParts);
185+ Type resultType = getTypeConverter ()->convertType (op.getType ());
186+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, result);
187+
188+ return success ();
189+ }
190+ };
191+
192+ } // namespace
193+
194+ namespace {
195+
119196// Lower aten.replication_pad2d operator into a sequence of
120197// tensor.extract_slice and tensor.concat operations.
121198
@@ -621,6 +698,8 @@ void mlir::torch::torch_to_linalg::
621698 MLIRContext *context = patterns.getContext ();
622699 target.addIllegalOp <AtenReplicationPad2dOp>();
623700 patterns.add <ConvertAtenReplicationPad2dOp>(typeConverter, context);
701+ target.addIllegalOp <AtenReplicationPad1dOp>();
702+ patterns.add <ConvertAtenReplicationPad1dOp>(typeConverter, context);
624703 target.addIllegalOp <AtenConstantPadNdOp>();
625704 patterns.add <ConvertAtenConstantPadNdOp>(typeConverter, context);
626705 target.addIllegalOp <AtenZerosOp, AtenOnesOp>();
0 commit comments