@@ -5617,6 +5617,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
56175617 });
56185618}
56195619
5620+ namespace {
5621+
5622+ void expand (SmallVectorImpl<int64_t > ¶ms, int numSpatialDims) {
5623+ if (params.size () == 1 ) {
5624+ for (auto _ : llvm::seq<int >(0 , numSpatialDims - 1 )) {
5625+ params.push_back (params[0 ]);
5626+ }
5627+ }
5628+ }
5629+
5630+ template <typename AtenPoolOpT>
5631+ LogicalResult expandPoolParams (AtenPoolOpT op, int numSpatialDims,
5632+ mlir::PatternRewriter &rewriter,
5633+ Value &kernelSizeList, Value &stridesList,
5634+ Value &paddingList, Value &dilationsList) {
5635+
5636+ SmallVector<int64_t , 3 > kernelSizeInts, strideInts, paddingInts, dilationInts;
5637+ if (!matchPattern (op.getKernelSize (),
5638+ m_TorchListOfConstantInts (kernelSizeInts)))
5639+ return rewriter.notifyMatchFailure (
5640+ op, " Non-const kernel_size for pooling op unsupported" );
5641+
5642+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
5643+ return rewriter.notifyMatchFailure (
5644+ op, " Non-const padding factor for pooling op unsupported" );
5645+
5646+ if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
5647+ return rewriter.notifyMatchFailure (
5648+ op, " Non-const stride for pooling op unsupported" );
5649+
5650+ if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5651+ std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5652+ if (!matchPattern (op.getDilation (),
5653+ m_TorchListOfConstantInts (dilationInts)))
5654+ return rewriter.notifyMatchFailure (
5655+ op, " Non-const dilation for pooling op unsupported" );
5656+
5657+ if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5658+ strideInts.size () != 1 && dilationInts.size () != 1 ) {
5659+ return rewriter.notifyMatchFailure (
5660+ op,
5661+ " Expected one of kernel/stride/padding/dilation to be singleton." );
5662+ }
5663+
5664+ expand (dilationInts, numSpatialDims);
5665+
5666+ } else if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5667+ strideInts.size () != 1 ) {
5668+ return rewriter.notifyMatchFailure (
5669+ op, " Expected one of kernel/stride/padding to be singleton." );
5670+ }
5671+
5672+ // expand singleton elements
5673+ expand (kernelSizeInts, numSpatialDims);
5674+ expand (paddingInts, numSpatialDims);
5675+ expand (strideInts, numSpatialDims);
5676+
5677+ Location loc = op.getLoc ();
5678+
5679+ SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5680+ for (auto dim : llvm::seq<int >(0 , kernelSizeInts.size ())) {
5681+ cstKernel.push_back (rewriter.create <Torch::ConstantIntOp>(
5682+ loc, rewriter.getI64IntegerAttr (kernelSizeInts[dim])));
5683+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
5684+ loc, rewriter.getI64IntegerAttr (paddingInts[dim])));
5685+ cstStrides.push_back (rewriter.create <Torch::ConstantIntOp>(
5686+ loc, rewriter.getI64IntegerAttr (strideInts[dim])));
5687+ }
5688+
5689+ // set dilations separately as for AvgPool op it won't be set
5690+ for (auto dim : llvm::seq<int >(0 , dilationInts.size ())) {
5691+ cstDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
5692+ loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
5693+ }
5694+
5695+ auto targetListType =
5696+ Torch::ListType::get (Torch::IntType::get (op->getContext ()));
5697+ kernelSizeList = rewriter.create <Torch::PrimListConstructOp>(
5698+ loc, targetListType, cstKernel);
5699+ paddingList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5700+ cstPadding);
5701+ stridesList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5702+ cstStrides);
5703+ dilationsList = rewriter.create <Torch::PrimListConstructOp>(
5704+ loc, targetListType, cstDilations);
5705+
5706+ return success ();
5707+ }
5708+
5709+ template <typename AvgPoolOpT>
5710+ struct CanonicalizeAvgPoolWithSingleIntTuple
5711+ : public mlir::OpRewritePattern<AvgPoolOpT> {
5712+ CanonicalizeAvgPoolWithSingleIntTuple (mlir::MLIRContext *context)
5713+ : OpRewritePattern<AvgPoolOpT>(context, /* benefit=*/ 1 ) {}
5714+
5715+ LogicalResult
5716+ matchAndRewrite (AvgPoolOpT op,
5717+ mlir::PatternRewriter &rewriter) const override {
5718+ Value kernel, stride, pad, dilations;
5719+
5720+ auto numSpatialDims = 2 ;
5721+ if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5722+ numSpatialDims = 3 ;
5723+
5724+ // Attempt to expand params if necessary.
5725+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5726+ pad, dilations)))
5727+ return rewriter.notifyMatchFailure (op,
5728+ " Failed to expand params for pooling" );
5729+
5730+ rewriter.replaceOpWithNewOp <AvgPoolOpT>(
5731+ op, op.getResult ().getType (), op.getSelf (), kernel, stride, pad,
5732+ op.getCeilMode (), op.getCountIncludePad (), op.getDivisorOverride ());
5733+ return success ();
5734+ }
5735+ };
5736+
5737+ template <typename MaxPoolOpT>
5738+ struct CanonicalizeMaxPoolWithSingleIntTuple
5739+ : public mlir::OpRewritePattern<MaxPoolOpT> {
5740+ CanonicalizeMaxPoolWithSingleIntTuple (mlir::MLIRContext *context)
5741+ : OpRewritePattern<MaxPoolOpT>(context, /* benefit=*/ 1 ) {}
5742+
5743+ LogicalResult
5744+ matchAndRewrite (MaxPoolOpT op,
5745+ mlir::PatternRewriter &rewriter) const override {
5746+ Value kernel, stride, pad, dilations;
5747+
5748+ auto numSpatialDims = 2 ;
5749+ if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5750+ numSpatialDims = 3 ;
5751+
5752+ // Attempt to expand params if necessary.
5753+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5754+ pad, dilations)))
5755+ return rewriter.notifyMatchFailure (op,
5756+ " Failed to expand params for pooling" );
5757+
5758+ rewriter.replaceOpWithNewOp <MaxPoolOpT>(op, op.getResult ().getType (),
5759+ op.getSelf (), kernel, stride, pad,
5760+ dilations, op.getCeilMode ());
5761+ return success ();
5762+ }
5763+ };
5764+ } // namespace
5765+
5766+ // ===----------------------------------------------------------------------===//
5767+ // AtenAvgPool2dOp
5768+ // ===----------------------------------------------------------------------===//
5769+ void AtenAvgPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5770+ MLIRContext *context) {
5771+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
5772+ }
5773+
5774+ // ===----------------------------------------------------------------------===//
5775+ // AtenAvgPool3dOp
5776+ // ===----------------------------------------------------------------------===//
5777+ void AtenAvgPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5778+ MLIRContext *context) {
5779+ patterns.add <CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
5780+ }
5781+
5782+ // ===----------------------------------------------------------------------===//
5783+ // AtenMaxPool2dOp
5784+ // ===----------------------------------------------------------------------===//
5785+ void AtenMaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5786+ MLIRContext *context) {
5787+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
5788+ }
5789+
5790+ // ===----------------------------------------------------------------------===//
5791+ // AtenMaxPool3dOp
5792+ // ===----------------------------------------------------------------------===//
5793+ void AtenMaxPool3dOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
5794+ MLIRContext *context) {
5795+ patterns.add <CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
5796+ }
5797+
56205798// ===----------------------------------------------------------------------===//
56215799// AtenLinalgCrossOp
56225800// ===----------------------------------------------------------------------===//