@@ -5601,6 +5601,184 @@ void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
56015601 });
56025602}
56035603
5604+ namespace {
5605+
5606+ void expand (SmallVectorImpl<int64_t > ¶ms, int numSpatialDims) {
5607+ if (params.size () == 1 ) {
5608+ for (auto _ : llvm::seq<int >(0 , numSpatialDims - 1 )) {
5609+ params.push_back (params[0 ]);
5610+ }
5611+ }
5612+ }
5613+
5614+ template <typename AtenPoolOpT>
5615+ LogicalResult expandPoolParams (AtenPoolOpT op, int numSpatialDims,
5616+ mlir::PatternRewriter &rewriter,
5617+ Value &kernelSizeList, Value &stridesList,
5618+ Value &paddingList, Value &dilationsList) {
5619+
5620+ SmallVector<int64_t , 3 > kernelSizeInts, strideInts, paddingInts, dilationInts;
5621+ if (!matchPattern (op.getKernelSize (),
5622+ m_TorchListOfConstantInts (kernelSizeInts)))
5623+ return rewriter.notifyMatchFailure (
5624+ op, " Non-const kernel_size for pooling op unsupported" );
5625+
5626+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (paddingInts)))
5627+ return rewriter.notifyMatchFailure (
5628+ op, " Non-const padding factor for pooling op unsupported" );
5629+
5630+ if (!matchPattern (op.getStride (), m_TorchListOfConstantInts (strideInts)))
5631+ return rewriter.notifyMatchFailure (
5632+ op, " Non-const stride for pooling op unsupported" );
5633+
5634+ if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
5635+ std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
5636+ if (!matchPattern (op.getDilation (),
5637+ m_TorchListOfConstantInts (dilationInts)))
5638+ return rewriter.notifyMatchFailure (
5639+ op, " Non-const dilation for pooling op unsupported" );
5640+
5641+ if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5642+ strideInts.size () != 1 && dilationInts.size () != 1 ) {
5643+ return rewriter.notifyMatchFailure (
5644+ op,
5645+ " Expected one of kernel/stride/padding/dilation to be singleton." );
5646+ }
5647+
5648+ expand (dilationInts, numSpatialDims);
5649+
5650+ } else if (kernelSizeInts.size () != 1 && paddingInts.size () != 1 &&
5651+ strideInts.size () != 1 ) {
5652+ return rewriter.notifyMatchFailure (
5653+ op, " Expected one of kernel/stride/padding to be singleton." );
5654+ }
5655+
5656+ // expand singleton elements
5657+ expand (kernelSizeInts, numSpatialDims);
5658+ expand (paddingInts, numSpatialDims);
5659+ expand (strideInts, numSpatialDims);
5660+
5661+ Location loc = op.getLoc ();
5662+
5663+ SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations;
5664+ for (auto dim : llvm::seq<int >(0 , kernelSizeInts.size ())) {
5665+ cstKernel.push_back (rewriter.create <Torch::ConstantIntOp>(
5666+ loc, rewriter.getI64IntegerAttr (kernelSizeInts[dim])));
5667+ cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
5668+ loc, rewriter.getI64IntegerAttr (paddingInts[dim])));
5669+ cstStrides.push_back (rewriter.create <Torch::ConstantIntOp>(
5670+ loc, rewriter.getI64IntegerAttr (strideInts[dim])));
5671+ }
5672+
5673+ // set dilations separately as for AvgPool op it won't be set
5674+ for (auto dim : llvm::seq<int >(0 , dilationInts.size ())) {
5675+ cstDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
5676+ loc, rewriter.getI64IntegerAttr (dilationInts[dim])));
5677+ }
5678+
5679+ auto targetListType =
5680+ Torch::ListType::get (Torch::IntType::get (op->getContext ()));
5681+ kernelSizeList = rewriter.create <Torch::PrimListConstructOp>(
5682+ loc, targetListType, cstKernel);
5683+ paddingList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5684+ cstPadding);
5685+ stridesList = rewriter.create <Torch::PrimListConstructOp>(loc, targetListType,
5686+ cstStrides);
5687+ dilationsList = rewriter.create <Torch::PrimListConstructOp>(
5688+ loc, targetListType, cstDilations);
5689+
5690+ return success ();
5691+ }
5692+
5693+ template <typename AvgPoolOpT>
5694+ struct CanonicalizeAvgPoolWithSingleIntTuple
5695+ : public mlir::OpRewritePattern<AvgPoolOpT> {
5696+ CanonicalizeAvgPoolWithSingleIntTuple (mlir::MLIRContext *context)
5697+ : OpRewritePattern<AvgPoolOpT>(context, /* benefit=*/ 1 ) {}
5698+
5699+ LogicalResult
5700+ matchAndRewrite (AvgPoolOpT op,
5701+ mlir::PatternRewriter &rewriter) const override {
5702+ Value kernel, stride, pad, dilations;
5703+
5704+ auto numSpatialDims = 2 ;
5705+ if constexpr (std::is_same<AvgPoolOpT, AtenAvgPool3dOp>())
5706+ numSpatialDims = 3 ;
5707+
5708+ // Attempt to expand params if necessary.
5709+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5710+ pad, dilations)))
5711+ return rewriter.notifyMatchFailure (op,
5712+ " Failed to expand params for pooling" );
5713+
5714+ rewriter.replaceOpWithNewOp <AvgPoolOpT>(
5715+ op, op.getResult ().getType (), op.getSelf (), kernel, stride, pad,
5716+ op.getCeilMode (), op.getCountIncludePad (), op.getDivisorOverride ());
5717+ return success ();
5718+ }
5719+ };
5720+
5721+ template <typename MaxPoolOpT>
5722+ struct CanonicalizeMaxPoolWithSingleIntTuple
5723+ : public mlir::OpRewritePattern<MaxPoolOpT> {
5724+ CanonicalizeMaxPoolWithSingleIntTuple (mlir::MLIRContext *context)
5725+ : OpRewritePattern<MaxPoolOpT>(context, /* benefit=*/ 1 ) {}
5726+
5727+ LogicalResult
5728+ matchAndRewrite (MaxPoolOpT op,
5729+ mlir::PatternRewriter &rewriter) const override {
5730+ Value kernel, stride, pad, dilations;
5731+
5732+ auto numSpatialDims = 2 ;
5733+ if constexpr (std::is_same<MaxPoolOpT, AtenMaxPool3dOp>())
5734+ numSpatialDims = 3 ;
5735+
5736+ // Attempt to expand params if necessary.
5737+ if (failed (expandPoolParams (op, numSpatialDims, rewriter, kernel, stride,
5738+ pad, dilations)))
5739+ return rewriter.notifyMatchFailure (op,
5740+ " Failed to expand params for pooling" );
5741+
5742+ rewriter.replaceOpWithNewOp <MaxPoolOpT>(op, op.getResult ().getType (),
5743+ op.getSelf (), kernel, stride, pad,
5744+ dilations, op.getCeilMode ());
5745+ return success ();
5746+ }
5747+ };
5748+ } // namespace
5749+
//===----------------------------------------------------------------------===//
// AtenAvgPool2dOp
//===----------------------------------------------------------------------===//
// Registers the pattern that expands singleton kernel/stride/padding lists.
void AtenAvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool2dOp>>(context);
}
5757+
//===----------------------------------------------------------------------===//
// AtenAvgPool3dOp
//===----------------------------------------------------------------------===//
// Registers the pattern that expands singleton kernel/stride/padding lists.
void AtenAvgPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<CanonicalizeAvgPoolWithSingleIntTuple<AtenAvgPool3dOp>>(context);
}
5765+
//===----------------------------------------------------------------------===//
// AtenMaxPool2dOp
//===----------------------------------------------------------------------===//
// Registers the pattern that expands singleton kernel/stride/padding/dilation
// lists.
void AtenMaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool2dOp>>(context);
}
5773+
//===----------------------------------------------------------------------===//
// AtenMaxPool3dOp
//===----------------------------------------------------------------------===//
// Registers the pattern that expands singleton kernel/stride/padding/dilation
// lists.
void AtenMaxPool3dOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<CanonicalizeMaxPoolWithSingleIntTuple<AtenMaxPool3dOp>>(context);
}
5781+
56045782// ===----------------------------------------------------------------------===//
56055783// AtenLinalgCrossOp
56065784// ===----------------------------------------------------------------------===//
0 commit comments