@@ -1759,6 +1759,65 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
+    AtenReflectionPad1dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value self = adaptor.getSelf();
+  auto selfTy = cast<RankedTensorType>(self.getType());
+  if (!selfTy.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(op, "only support static shape");
+  }
+  int64_t rank = selfTy.getRank();
+  int64_t dim = rank - 1;
+
+  SmallVector<int64_t> padInts;
+  if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only support constant int pad ranges");
+  }
+  if (padInts.size() != 2) {
+    return rewriter.notifyMatchFailure(op, "pad size must be 2");
+  }
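+  // Reflection excludes the edge element itself, so each pad amount must be
+  // strictly smaller than the padded dimension's size.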
+  if (padInts[0] >= selfTy.getDimSize(dim) ||
+      padInts[1] >= selfTy.getDimSize(dim)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "pad size must be less than dim size");
+  }
+
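+  // Left pad: slice the padInts[0] elements that follow the first element,
+  // then reverse them along the last dimension.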
+  Value left;
+  {
+    SmallVector<int64_t> startIndices(rank, 0);
+    SmallVector<int64_t> limitIndices(selfTy.getShape().begin(),
+                                      selfTy.getShape().end());
+    SmallVector<int64_t> strides(rank, 1);
+    startIndices[dim] = 1;
+    limitIndices[dim] = padInts[0] + 1;
+    left = rewriter.create<stablehlo::SliceOp>(loc, self, startIndices,
+                                               limitIndices, strides);
+    left = rewriter.create<stablehlo::ReverseOp>(loc, left,
+                                                 ArrayRef<int64_t>({dim}));
+  }
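+  // Right pad: slice the padInts[1] elements that precede the last element,
+  // then reverse them along the last dimension.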
+  Value right;
+  {
+    SmallVector<int64_t> startIndices(rank, 0);
+    SmallVector<int64_t> limitIndices(selfTy.getShape().begin(),
+                                      selfTy.getShape().end());
+    SmallVector<int64_t> strides(rank, 1);
+    startIndices[dim] = selfTy.getDimSize(dim) - 1 - padInts[1];
+    limitIndices[dim] = selfTy.getDimSize(dim) - 1;
+    right = rewriter.create<stablehlo::SliceOp>(loc, self, startIndices,
+                                                limitIndices, strides);
+    right = rewriter.create<stablehlo::ReverseOp>(loc, right,
+                                                  ArrayRef<int64_t>({dim}));
+  }
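+  // Concatenate the mirrored borders around the input: [left | self | right].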
+  Value result = rewriter.create<stablehlo::ConcatenateOp>(
+      loc, ValueRange{left, self, right}, dim);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
     AtenGeluBackwardOp op, OpAdaptor adaptor,
@@ -2269,6 +2328,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
   INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
+  INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
 
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);