From 4572127e9afe78a41a47af3c24b641a2aeed7448 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Tue, 11 Nov 2025 10:06:21 -0800 Subject: [PATCH] [NFC] Switch to new pass generation tablegen definitions. This commit completes the migration from the deprecated GEN_PASS_CLASSES to the new GEN_PASS_DEF infrastructure across all torch-mlir passes. Changes include: 1. Remove PassDetail.h files (deprecated pattern) - Deleted lib/Conversion/PassDetail.h - Deleted lib/RefBackend/PassDetail.h - Deleted lib/Dialect/Torch/Transforms/PassDetail.h - Deleted lib/Dialect/TorchConversion/Transforms/PassDetail.h - Deleted lib/Dialect/TMTensor/Transforms/PassDetail.h 2. Migrate conversion passes to GEN_PASS_DEF - Updated all passes in lib/Conversion/ to use #define GEN_PASS_DEF_* - Removed GEN_PASS_DECL from .cpp files (move to headers where needed) - Fixed includes and namespace declarations 3. Migrate dialect transform passes - Updated Torch, TorchConversion, and TMTensor transform passes - Properly scoped GEN_PASS_DEF in namespace blocks 4. Handle passes with options (TorchToStablehlo, TorchToTosa) - Added GEN_PASS_DECL_* to headers - Implemented default and convenience create functions - Used generated constructors via `using BaseClass::BaseClass` 5. Handle passes without options (RefBackend) - Removed manual create function implementations - Let tablegen auto-generate create functions - Added using declarations for Base classes in impl namespace 6. Fix backend type conversion passes - Added missing create functions in BackendTypeConversionPasses.cpp - Fixed namespace scoping issues 7. Fix missing namespace closures - Added proper closing namespace comments in Verify*BackendContract.cpp The migration maintains full backward compatibility while adopting the recommended LLVM pass infrastructure patterns. All passes now use the generated base classes and follow consistent patterns based on whether they have options defined in tablegen. Signed-off-by: hanhanW --- .../Dialect/TMTensor/Transforms/PassDetail.h | 27 ---------- .../TorchToStablehlo/TorchToStablehlo.h | 8 +++ .../Conversion/TorchToTosa/TorchToTosa.h | 7 +++ include/torch-mlir/RefBackend/Passes.h | 13 ----- include/torch-mlir/RefBackend/Passes.td | 6 --- lib/Conversion/PassDetail.h | 26 ---------- .../TorchConversionToMLProgram.cpp | 14 +++-- lib/Conversion/TorchOnnxToTorch/PassDetail.h | 24 --------- .../TorchOnnxToTorch/TorchOnnxToTorch.cpp | 14 +++-- lib/Conversion/TorchToArith/TorchToArith.cpp | 15 ++++-- .../TorchToLinalg/TorchToLinalg.cpp | 15 ++++-- lib/Conversion/TorchToLinalg/Utils.cpp | 1 - lib/Conversion/TorchToSCF/TorchToSCF.cpp | 16 ++++-- lib/Conversion/TorchToStablehlo/Basic.cpp | 1 - .../TorchToStablehlo/GatherScatter.cpp | 1 - lib/Conversion/TorchToStablehlo/Linear.cpp | 1 - lib/Conversion/TorchToStablehlo/Pooling.cpp | 1 - lib/Conversion/TorchToStablehlo/Reduction.cpp | 1 - lib/Conversion/TorchToStablehlo/Rng.cpp | 1 - .../TorchToStablehlo/TorchToStablehlo.cpp | 35 ++++++++----- .../TorchToStablehlo/Uncategorized.cpp | 1 - lib/Conversion/TorchToStablehlo/ViewLike.cpp | 1 - .../TorchToTMTensor/TorchToTMTensor.cpp | 14 +++-- .../TorchToTensor/TorchToTensor.cpp | 16 ++++-- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 37 +++++++++----- lib/Dialect/TMTensor/Transforms/Bufferize.cpp | 12 +++-- .../TMTensor/Transforms/ConvertToLoops.cpp | 13 +++-- .../Transforms/AdjustCallingConventions.cpp | 15 ++++-- .../Torch/Transforms/DecomposeComplexOps.cpp | 26 ++++++---- .../DropAbstractInterpCalculations.cpp | 14 +++-- .../Transforms/EraseModuleInitializer.cpp | 14 +++-- .../Torch/Transforms/FuseQuantizedOps.cpp | 16 ++++-- .../Torch/Transforms/GlobalizeObjectGraph.cpp | 18 +++++-- .../Torch/Transforms/InlineGlobalSlots.cpp | 15 ++++-- .../Transforms/LowerToBackendContract.cpp | 45 ++++++++-------- .../Torch/Transforms/MatchQuantizedOps.cpp | 14 +++-- .../Transforms/MaximizeValueSemantics.cpp | 14 +++-- lib/Dialect/Torch/Transforms/PassDetail.h | 28 ---------- .../PrepareForGlobalizeObjectGraph.cpp | 14 +++-- .../Torch/Transforms/RecomposeComplexOps.cpp | 15 ++++-- .../Torch/Transforms/ReduceOpVariants.cpp | 24 +++++---- .../Torch/Transforms/RefinePublicReturn.cpp | 15 ++++-- .../Transforms/ReifyDtypeCalculations.cpp | 26 ++++++---- .../Transforms/ReifyShapeCalculations.cpp | 25 +++++---- .../Transforms/RestructureNonConstantAxes.cpp | 15 ++++-- .../Torch/Transforms/ScalarizeShapes.cpp | 15 ++++-- .../Transforms/SimplifyDtypeCalculations.cpp | 15 ++++-- .../Transforms/SimplifyShapeCalculations.cpp | 15 ++++-- .../BackendTypeConversionPasses.cpp | 35 ++++++++----- .../Transforms/ConvertCustomQuantOp.cpp | 15 ++++-- .../TorchConversion/Transforms/PassDetail.h | 29 ----------- .../Transforms/UnpackQuantTensor.cpp | 14 +++-- .../VerifyLinalgOnTensorsBackendContract.cpp | 13 +++-- .../VerifyStablehloBackendContract.cpp | 14 +++-- .../Transforms/VerifyTosaBackendContract.cpp | 16 ++++-- lib/RefBackend/PassDetail.h | 26 ---------- lib/RefBackend/RefBackend.cpp | 51 ++++++++----------- 57 files changed, 492 insertions(+), 440 deletions(-) delete mode 100644 include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h delete mode 100644 lib/Conversion/PassDetail.h delete mode 100644 lib/Conversion/TorchOnnxToTorch/PassDetail.h delete mode 100644 lib/Dialect/Torch/Transforms/PassDetail.h delete mode 100644 lib/Dialect/TorchConversion/Transforms/PassDetail.h delete mode 100644 lib/RefBackend/PassDetail.h diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h b/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h deleted file mode 100644 index 391280bbc1a0..000000000000 --- a/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- PassDetail.h - TMTensor Pass class details -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ -#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { -namespace TMTensor { - -#define GEN_PASS_CLASSES -#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" // IWYU pragma: keep - -} // namespace TMTensor -} // namespace torch -} // namespace mlir - -#endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h index c1926015989e..36f171a4b52b 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h @@ -16,10 +16,18 @@ namespace mlir { namespace torch { + +#define GEN_PASS_DECL_CONVERTTORCHTOSTABLEHLO +#include "torch-mlir/Conversion/Passes.h.inc" + std::unique_ptr> createConvertTorchToStablehloPass(); + +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); + } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index 8ee6fecaa015..c9d9688e04fd 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -19,6 +19,9 @@ namespace mlir { namespace torch { +#define GEN_PASS_DECL_CONVERTTORCHTOTOSA +#include "torch-mlir/Conversion/Passes.h.inc" + /// Collect a set of legal/illegal ops for converting Torch operations to Tosa /// dialect. void populateTorchToTosaConversionLegalOps(ConversionTarget &target); @@ -30,8 +33,12 @@ populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, RewritePatternSet &patterns); std::unique_ptr> createConvertTorchToTosaPass(); + +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> createConvertTorchToTosaPass(bool requireFullTosaConversion); + } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h index be5e43a1e63c..19879ace8194 100644 --- a/include/torch-mlir/RefBackend/Passes.h +++ b/include/torch-mlir/RefBackend/Passes.h @@ -15,25 +15,12 @@ #include "mlir/Pass/PassManager.h" namespace mlir { -class ModuleOp; - namespace torch { namespace RefBackend { /// Registers all RefBackend passes. void registerRefBackendPasses(); -std::unique_ptr> createMungeCallingConventionsPass(); - -std::unique_ptr> createExpandOpsForLLVMPass(); - -std::unique_ptr> createMLProgramBufferizePass(); - -std::unique_ptr> createMungeMemrefCopyPass(); - -std::unique_ptr> createGeneralizeTensorConcatPass(); - -std::unique_ptr> createGeneralizeTensorPadPass(); } // namespace RefBackend } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td index 3d8b7fd41b1b..2f08518f92c2 100644 --- a/include/torch-mlir/RefBackend/Passes.td +++ b/include/torch-mlir/RefBackend/Passes.td @@ -14,35 +14,29 @@ include "mlir/Pass/PassBase.td" def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> { let summary = "Munge calling conventions for calling via ExecutionEngine"; - let constructor = "mlir::torch::RefBackend::createMungeCallingConventionsPass();"; let dependentDialects = ["memref::MemRefDialect"]; } def MLProgramBufferize: Pass<"refback-mlprogram-bufferize", "ModuleOp"> { let summary = "Bufferize the MLProgram dialect ops"; - let constructor = "mlir::torch::RefBackend::createMLProgramBufferizePass();"; let dependentDialects = ["memref::MemRefDialect"]; } def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> { let summary = "Expand ops into more primitive ops before LLVM lowering."; - let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();"; } def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> { let summary = "Munge memref.copy to linalg.copy"; - let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();"; let dependentDialects = ["memref::MemRefDialect"]; } def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> { let summary = "Convert tensor.concat to other tensor ops"; - let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()"; } def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> { let summary = "Convert tensor.pad to linalg ops"; - let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; } #endif // TORCHMLIR_REFBACKEND_PASSES diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h deleted file mode 100644 index aa832141f1de..000000000000 --- a/lib/Conversion/PassDetail.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H -#define TORCHMLIR_CONVERSION_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Conversion/Passes.h.inc" - -} // namespace torch -} // end namespace mlir - -#endif // TORCHMLIR_CONVERSION_PASSDETAIL_H diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index e5dee5b4c3bb..d45506a16088 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -20,6 +22,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM +#include "torch-mlir/Conversion/Passes.h.inc" static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } @@ -102,7 +108,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern { namespace { class ConvertTorchConversionToMLProgram - : public ConvertTorchConversionToMLProgramBase< + : public impl::ConvertTorchConversionToMLProgramBase< ConvertTorchConversionToMLProgram> { public: void getDependentDialects(DialectRegistry ®istry) const override { @@ -138,6 +144,8 @@ class ConvertTorchConversionToMLProgram } // namespace std::unique_ptr> -mlir::torch::createConvertTorchConversionToMLProgramPass() { +createConvertTorchConversionToMLProgramPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchOnnxToTorch/PassDetail.h b/lib/Conversion/TorchOnnxToTorch/PassDetail.h deleted file mode 100644 index bbcd3413c59c..000000000000 --- a/lib/Conversion/TorchOnnxToTorch/PassDetail.h +++ /dev/null @@ -1,24 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H -#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir::torch::onnx_c { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" - -} // namespace mlir::torch::onnx_c - -#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp index d2f7517376d8..cc42d947175e 100644 --- a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp +++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp @@ -7,7 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "./PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" @@ -19,6 +20,10 @@ using llvm::dbgs; using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; +namespace mlir::torch::onnx_c { + +#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" #define DEBUG_TYPE "torch-onnx" @@ -37,7 +42,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) { } class ConvertTorchOnnxToTorch - : public ConvertTorchOnnxToTorchBase { + : public impl::ConvertTorchOnnxToTorchBase { public: ConvertTorchOnnxToTorch() = default; void runOnOperation() override { @@ -82,7 +87,8 @@ class ConvertTorchOnnxToTorch } // namespace -std::unique_ptr> -mlir::torch::onnx_c::createTorchOnnxToTorchPass() { +std::unique_ptr> createTorchOnnxToTorchPass() { return std::make_unique(); } + +} // namespace mlir::torch::onnx_c diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 17614f95ea16..2dd15ef2c651 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -25,6 +27,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOARITH +#include "torch-mlir/Conversion/Passes.h.inc" // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) @@ -407,7 +413,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern { namespace { class ConvertTorchToArith - : public ConvertTorchToArithBase { + : public impl::ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -565,7 +571,8 @@ class ConvertTorchToArith }; } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToArithPass() { +std::unique_ptr> createConvertTorchToArithPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 01b1d4b973b6..8b0c6ab8ad19 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -24,6 +26,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOLINALG +#include "torch-mlir/Conversion/Passes.h.inc" // ----------------------------------------------------------------------------- // The pass @@ -34,7 +40,7 @@ using namespace mlir::torch::Torch; namespace { class ConvertTorchToLinalg - : public ConvertTorchToLinalgBase { + : public impl::ConvertTorchToLinalgBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -89,7 +95,8 @@ class ConvertTorchToLinalg }; } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToLinalgPass() { +std::unique_ptr> createConvertTorchToLinalgPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index c2b584efdecc..8630e6a7ac1a 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -8,7 +8,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 8978a75c01a4..57ee17700f53 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" @@ -21,6 +23,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOSCF +#include "torch-mlir/Conversion/Passes.h.inc" namespace { class ConvertTorchPrimIfYieldOp : public OpConversionPattern { @@ -312,7 +318,8 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { } // namespace namespace { -class ConvertTorchToSCF : public ConvertTorchToSCFBase { +class ConvertTorchToSCF + : public impl::ConvertTorchToSCFBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -345,7 +352,8 @@ class ConvertTorchToSCF : public ConvertTorchToSCFBase { }; } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToSCFPass() { +std::unique_ptr> createConvertTorchToSCFPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index a22e6658a2ac..2b6f4a90ef7e 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 8ebb7050b124..95435dd5805b 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 56094c8d0f52..892a0158667b 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 5c0ecb19c5a4..45982e108d00 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index f66d9e040951..ec078080708a 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp index b71af126c69e..c7627431cf56 100644 --- a/lib/Conversion/TorchToStablehlo/Rng.cpp +++ b/lib/Conversion/TorchToStablehlo/Rng.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "./PopulatePatterns.h" #include "stablehlo/dialect/StablehloOps.h" diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 5a7eb398dc9b..03d36e0ec91b 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -23,17 +25,18 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOSTABLEHLO +#include "torch-mlir/Conversion/Passes.h.inc" namespace { class ConvertTorchToStablehlo - : public ConvertTorchToStablehloBase { + : public impl::ConvertTorchToStablehloBase { public: - ConvertTorchToStablehlo() = default; - ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) { - this->enableStaticShape = enableStaticShape; - this->enableI32Index = enableI32Index; - } + using impl::ConvertTorchToStablehloBase< + ConvertTorchToStablehlo>::ConvertTorchToStablehloBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -86,14 +89,20 @@ class ConvertTorchToStablehlo } // namespace +// Default pass creation function (required by tablegen) std::unique_ptr> -mlir::torch::createConvertTorchToStablehloPass() { - return std::make_unique(false, false); +createConvertTorchToStablehloPass() { + return std::make_unique(); } +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> -mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape, - bool enableI32Index) { - return std::make_unique(enableStaticShape, - enableI32Index); +createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index) { + ConvertTorchToStablehloOptions options; + options.enableStaticShape = enableStaticShape; + options.enableI32Index = enableI32Index; + return std::make_unique(options); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToStablehlo/Uncategorized.cpp b/lib/Conversion/TorchToStablehlo/Uncategorized.cpp index f8af1529ff15..5026e8dd09ee 100644 --- a/lib/Conversion/TorchToStablehlo/Uncategorized.cpp +++ b/lib/Conversion/TorchToStablehlo/Uncategorized.cpp @@ -10,7 +10,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index af48f84fc357..632d64a3eae1 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -9,7 +9,6 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 60a04bbd7e55..cb5baab07b67 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -8,8 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -33,6 +35,10 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TMTensor; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOTMTENSOR +#include "torch-mlir/Conversion/Passes.h.inc" // ----------------------------------------------------------------------------- // Patterns (as this grows, it should be organized into multiple files) @@ -2459,7 +2465,7 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { namespace { class ConvertTorchToTMTensor - : public ConvertTorchToTMTensorBase { + : public impl::ConvertTorchToTMTensorBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -2519,6 +2525,8 @@ class ConvertTorchToTMTensor } // namespace std::unique_ptr> -mlir::torch::createConvertTorchToTMTensorPass() { +createConvertTorchToTMTensorPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index 10fd2a160d0d..890ca4ec1860 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -8,8 +8,9 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" - -#include "../PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "torch-mlir/Conversion/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -21,6 +22,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOTENSOR +#include "torch-mlir/Conversion/Passes.h.inc" namespace { @@ -139,7 +144,7 @@ class ConvertAtenTensorOpPattern : public OpConversionPattern { }; class ConvertTorchToTensor - : public ConvertTorchToTensorBase { + : public impl::ConvertTorchToTensorBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -170,7 +175,8 @@ class ConvertTorchToTensor } // namespace -std::unique_ptr> -mlir::torch::createConvertTorchToTensorPass() { +std::unique_ptr> createConvertTorchToTensorPass() { return std::make_unique(); } + +} // namespace mlir::torch diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index c959f06c6a66..850ca3f3cfb9 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,13 +8,15 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/Passes.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -34,6 +36,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch { + +#define GEN_PASS_DEF_CONVERTTORCHTOTOSA +#include "torch-mlir/Conversion/Passes.h.inc" namespace { @@ -9033,12 +9039,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToTosa : public ConvertTorchToTosaBase { +class ConvertTorchToTosa + : public impl::ConvertTorchToTosaBase { public: - ConvertTorchToTosa() = default; - ConvertTorchToTosa(bool requireFullTosaConversion) { - this->requireFullTosaConversion = requireFullTosaConversion; - } + using impl::ConvertTorchToTosaBase< + ConvertTorchToTosa>::ConvertTorchToTosaBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -9081,7 +9086,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { }; } // namespace -void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { +void populateTorchToTosaConversionLegalOps(ConversionTarget &target) { // The following ops are never the primary reason why lowering fails. // The backend contract only allows functions to return tensors thus there // is always another op using them. @@ -9098,7 +9103,7 @@ void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { target.addLegalOp(); } -std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( +std::set populateTorchToTosaConversionPatternsAndIllegalOps( TypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); @@ -9411,12 +9416,18 @@ std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( return illegalOps; } -std::unique_ptr> -mlir::torch::createConvertTorchToTosaPass() { - return std::make_unique(true); +// Default pass creation function (required by tablegen) +std::unique_ptr> createConvertTorchToTosaPass() { + return std::make_unique(); } +// Convenience wrapper for users who want to pass options as individual +// parameters std::unique_ptr> -mlir::torch::createConvertTorchToTosaPass(bool requireFullTosaConversion) { - return std::make_unique(requireFullTosaConversion); +createConvertTorchToTosaPass(bool requireFullTosaConversion) { + ConvertTorchToTosaOptions options; + options.requireFullTosaConversion = requireFullTosaConversion; + return std::make_unique(options); } + +} // namespace mlir::torch diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index ca47cdd6033a..f70099e0a478 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -23,11 +23,14 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" -#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" using namespace ::mlir; using namespace ::mlir::torch::TMTensor; +namespace mlir::torch::TMTensor { + +#define GEN_PASS_DEF_TMTENSORBUFFERIZE +#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { auto memrefType = cast(memref.getType()); @@ -134,7 +137,7 @@ static Value materializeToTensor(OpBuilder &builder, TensorType type, /// Converts TMTensor operations that work on tensor-type operands or results to /// work on buffers. struct TMTensorBufferizePass - : public TMTensorBufferizeBase { + : public impl::TMTensorBufferizeBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -202,7 +205,8 @@ struct TMTensorBufferizePass }; } // namespace -std::unique_ptr> -torch::TMTensor::createTMTensorBufferizePass() { +std::unique_ptr> createTMTensorBufferizePass() { return std::make_unique(); } + +} // namespace mlir::torch::TMTensor diff --git a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp index 74d539ab6d8a..5c7755d210fa 100644 --- a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp +++ b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp @@ -20,7 +20,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" -#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" @@ -28,6 +27,10 @@ using namespace mlir; using namespace mlir::torch::TMTensor; +namespace mlir::torch::TMTensor { + +#define GEN_PASS_DEF_TMTENSORTOLOOPS +#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h.inc" /// Recursive method that lowers one dimension of the `ScalarLoopOpInterface` to /// scalar loops at a time. @@ -98,7 +101,8 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern { //===----------------------------------------------------------------------===// namespace { -struct TMTensorToLoopsPass : public TMTensorToLoopsBase { +struct TMTensorToLoopsPass + : public impl::TMTensorToLoopsBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert { }; } // namespace -std::unique_ptr> -torch::TMTensor::createTMTensorToLoopsPass() { +std::unique_ptr> createTMTensorToLoopsPass() { return std::make_unique(); } + +} // namespace mlir::torch::TMTensor diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 25a38c83627c..37c54b51d874 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_ADJUSTCALLINGCONVENTIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Map from func name and arg index to the type bound for that arg. // This is needed because to rewrite calls, we need the non-local information @@ -285,7 +289,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, namespace { class AdjustCallingConventionsPass - : public AdjustCallingConventionsBase { + : public impl::AdjustCallingConventionsBase { void runOnOperation() override { auto module = getOperation(); TypeBoundMap typeBoundMap; @@ -306,7 +310,8 @@ class AdjustCallingConventionsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createAdjustCallingConventionsPass() { +std::unique_ptr> createAdjustCallingConventionsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c9c42b43c463..08b25c9b6f60 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -7,11 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -27,6 +27,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_DECOMPOSECOMPLEXOPS +#define GEN_PASS_DEF_DECOMPOSECOMPLEXOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Helper function to check whether the `dtype` is None or Float type. static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { @@ -13047,7 +13052,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { namespace { class DecomposeComplexOpsPass - : public DecomposeComplexOpsBase { + : public impl::DecomposeComplexOpsBase { private: llvm::StringSet<> legalOpsSet; @@ -13068,10 +13073,8 @@ class DecomposeComplexOpsPass } public: - DecomposeComplexOpsPass() = default; - DecomposeComplexOpsPass(ArrayRef legalOps) { - this->legalOps = legalOps; - } + using impl::DecomposeComplexOpsBase< + DecomposeComplexOpsPass>::DecomposeComplexOpsBase; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -13392,7 +13395,10 @@ class DecomposeComplexOpsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createDecomposeComplexOpsPass( - ArrayRef legalOps) { - return std::make_unique(legalOps); +createDecomposeComplexOpsPass(ArrayRef legalOps) { + DecomposeComplexOpsOptions options; + options.legalOps.append(legalOps.begin(), legalOps.end()); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp b/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp index c3236c0324d1..f7cf7e5f4384 100644 --- a/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_DROPABSTRACTINTERPCALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { template @@ -39,7 +43,7 @@ class DropCalculateOp : public OpConversionPattern { namespace { class DropAbstractInterpCalculationsPass - : public DropAbstractInterpCalculationsBase< + : public impl::DropAbstractInterpCalculationsBase< DropAbstractInterpCalculationsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -61,6 +65,8 @@ class DropAbstractInterpCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createDropAbstractInterpCalculationsPass() { +createDropAbstractInterpCalculationsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp index db80714127e1..7602169f74ab 100644 --- a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp +++ b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp @@ -7,10 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,10 +17,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_ERASEMODULEINITIALIZER +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class EraseModuleInitializerPass - : public EraseModuleInitializerBase { + : public impl::EraseModuleInitializerBase { void runOnOperation() override { for (auto initializer : getOperation().getOps()) { @@ -37,7 +40,8 @@ class EraseModuleInitializerPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createEraseModuleInitializerPass() { +std::unique_ptr> createEraseModuleInitializerPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index e418a4b08ec0..4b04151f68cc 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_FUSEQUANTIZEDOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -438,7 +442,8 @@ template class RemoveUnused : public OpRewritePattern { } }; -class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { +class FuseQuantizedOpsPass + : public impl::FuseQuantizedOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); @@ -470,7 +475,8 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { } // namespace -std::unique_ptr> -mlir::torch::Torch::createFuseQuantizedOpsPass() { +std::unique_ptr> createFuseQuantizedOpsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index d47220e348ea..a3f15c07183e 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" @@ -21,6 +21,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_GLOBALIZEOBJECTGRAPH +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static FailureOr findRootNnModule(ModuleOp module) { NnModuleOp rootNnModule; @@ -299,6 +303,8 @@ struct Monomorphization { }; } // namespace +} // namespace mlir::torch::Torch + template <> struct llvm::DenseMapInfo { static Monomorphization getEmptyKey() { return Monomorphization{nullptr, {ArgInstance{-1, nullptr}}}; @@ -318,6 +324,8 @@ template <> struct llvm::DenseMapInfo { } }; +namespace mlir::torch::Torch { + // Populate `mapping` such that values of NnModuleType in the function are // mapped to appropriate global objects of NnModuleType. // @@ -696,7 +704,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { namespace { class GlobalizeObjectGraphPass - : public GlobalizeObjectGraphBase { + : public impl::GlobalizeObjectGraphBase { void runOnOperation() override { if (failed(globalizeObjectGraph(getOperation()))) return signalPassFailure(); @@ -704,7 +712,7 @@ class GlobalizeObjectGraphPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createGlobalizeObjectGraphPass() { +std::unique_ptr> createGlobalizeObjectGraphPass() { return std::make_unique(); } +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 12660cfee47c..d6308f11b8d8 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -23,12 +23,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/Support/Debug.h" @@ -38,6 +37,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_FLATSYMBOLREFLATTICEANCHOR +#define GEN_PASS_DEF_INLINEGLOBALSLOTS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" /// A program point representing a symbol. /// @@ -276,7 +280,7 @@ static bool isInitialValueTransitivelySafeToInline(Value initialValue, namespace { class InlineGlobalSlotsPass - : public InlineGlobalSlotsBase { + : public impl::InlineGlobalSlotsBase { void runOnOperation() override { ModuleOp module = getOperation(); @@ -417,7 +421,8 @@ class InlineGlobalSlotsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createInlineGlobalSlotsPass() { +std::unique_ptr> createInlineGlobalSlotsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index cfc8bb96118b..b149d172496c 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -24,6 +24,12 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_LOWERTOBACKENDCONTRACT +#define GEN_PASS_DEF_LOWERTOBACKENDCONTRACT +#define GEN_PASS_DEF_VERIFYBACKENDCONTRACTNODECOMPOSITIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// // Checking the backend contract. @@ -258,19 +264,10 @@ getBackendContractTarget(MLIRContext *context, bool decompose, namespace { class LowerToBackendContractPass - : public LowerToBackendContractBase { + : public impl::LowerToBackendContractBase { public: - LowerToBackendContractPass() = default; - LowerToBackendContractPass(int maxIterations, bool decompose, - bool shapeDtypeRefine, - ArrayRef backendLegalOps, - StringRef extraLibrary) { - this->maxIterations = maxIterations; - this->decompose = decompose; - this->shapeDtypeRefine = shapeDtypeRefine; - this->backendLegalOps = backendLegalOps; - this->extraLibrary = extraLibrary.str(); - } + using impl::LowerToBackendContractBase< + LowerToBackendContractPass>::LowerToBackendContractBase; void runOnOperation() override { ModuleOp module = getOperation(); MLIRContext *context = &getContext(); @@ -317,7 +314,7 @@ class LowerToBackendContractPass }; class VerifyBackendContractNoDecompositionsPass - : public VerifyBackendContractNoDecompositionsBase< + : public impl::VerifyBackendContractNoDecompositionsBase< VerifyBackendContractNoDecompositionsPass> { public: VerifyBackendContractNoDecompositionsPass() = default; @@ -336,17 +333,21 @@ class VerifyBackendContractNoDecompositionsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createLowerToBackendContractPass( +std::unique_ptr> createLowerToBackendContractPass( int maxIterations, bool decompose, bool shapeDtypeRefine, ArrayRef backendLegalOps, StringRef extraLibrary) { - return std::make_unique( - maxIterations, decompose, shapeDtypeRefine, backendLegalOps, - extraLibrary); + LowerToBackendContractOptions options; + options.maxIterations = maxIterations; + options.decompose = decompose; + options.shapeDtypeRefine = shapeDtypeRefine; + options.backendLegalOps.append(backendLegalOps.begin(), + backendLegalOps.end()); + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } std::unique_ptr> -mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { +createVerifyBackendContractNoDecompositionsPass() { return std::make_unique(); } @@ -606,3 +607,5 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, return backendLegalOpsSet.contains(opName); }); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index e5c279415c7f..d873892ba61e 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_MATCHQUANTIZEDCUSTOMOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -115,7 +119,7 @@ class MatchQuantizeOperator : public OpRewritePattern { }; class MatchQuantizedCustomOpsPass - : public MatchQuantizedCustomOpsBase { + : public impl::MatchQuantizedCustomOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); @@ -132,6 +136,8 @@ class MatchQuantizedCustomOpsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createMatchQuantizedCustomOpsPass() { +createMatchQuantizedCustomOpsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 10580b81876b..aec53aa6535a 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -19,6 +19,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_MAXIMIZEVALUESEMANTICS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static Value assertNonValueTensor(Value tensor) { assert(isa(tensor.getType()) && @@ -364,7 +368,7 @@ class RewriteViewLikeSubgraph namespace { class MaximizeValueSemanticsPass - : public MaximizeValueSemanticsBase { + : public impl::MaximizeValueSemanticsBase { void runOnOperation() override { MLIRContext *context = &getContext(); auto func = getOperation(); @@ -379,6 +383,8 @@ class MaximizeValueSemanticsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createMaximizeValueSemanticsPass() { +createMaximizeValueSemanticsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/PassDetail.h b/lib/Dialect/Torch/Transforms/PassDetail.h deleted file mode 100644 index 85fc116fe5ae..000000000000 --- a/lib/Dialect/Torch/Transforms/PassDetail.h +++ /dev/null @@ -1,28 +0,0 @@ -//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H -#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -class ModuleOp; -namespace torch { -namespace Torch { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" - -} // namespace Torch -} // namespace torch -} // end namespace mlir - -#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index c7ff95270d98..a53c24954627 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_PREPAREFORGLOBALIZEOBJECTGRAPH +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class ConvertPrimCallMethodToCall : public OpRewritePattern { @@ -63,7 +67,7 @@ class EraseUnusedConstantOp : public OpRewritePattern { namespace { class PrepareForGlobalizeObjectGraphPass - : public PrepareForGlobalizeObjectGraphBase< + : public impl::PrepareForGlobalizeObjectGraphBase< PrepareForGlobalizeObjectGraphPass> { void runOnOperation() override { @@ -105,6 +109,8 @@ class PrepareForGlobalizeObjectGraphPass } // namespace std::unique_ptr> -mlir::torch::Torch::createPrepareForGlobalizeObjectGraphPass() { +createPrepareForGlobalizeObjectGraphPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index 1d7c926473c2..dc533427d773 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_RECOMPOSECOMPLEXOPS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -806,7 +810,7 @@ class RecomposeMeshgridIndexingListUnpack namespace { class RecomposeComplexOpsPass - : public RecomposeComplexOpsBase { + : public impl::RecomposeComplexOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); @@ -841,7 +845,8 @@ class RecomposeComplexOpsPass }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createRecomposeComplexOpsPass() { +std::unique_ptr> createRecomposeComplexOpsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index b84b4465eab5..187d234183a3 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_REDUCEOPVARIANTS +#define GEN_PASS_DEF_REDUCEOPVARIANTS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Create an overwrite in a manner that preserves the // `OverwriteTensorContentsOp` invariant that both arguments @@ -403,11 +408,8 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op, namespace { struct ReduceOpVariantsPass - : public ReduceOpVariantsBase { - ReduceOpVariantsPass() = default; - ReduceOpVariantsPass(StringRef extraLibrary) { - this->extraLibrary = extraLibrary.str(); - } + : public impl::ReduceOpVariantsBase { + using impl::ReduceOpVariantsBase::ReduceOpVariantsBase; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); @@ -481,6 +483,10 @@ struct ReduceOpVariantsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) { - return std::make_unique(extraLibrary); +createReduceOpVariantsPass(StringRef extraLibrary) { + ReduceOpVariantsOptions options; + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 6f45e8876ee1..1040b8e9976d 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -7,20 +7,24 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_REFINEPUBLICRETURN +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class RefinePublicReturnPass - : public RefinePublicReturnBase { + : public impl::RefinePublicReturnBase { void runOnOperation() override { auto module = getOperation(); module.walk([&](func::FuncOp func) { @@ -101,7 +105,8 @@ class RefinePublicReturnPass } // namespace -std::unique_ptr> -mlir::torch::Torch::createRefinePublicReturnPass() { +std::unique_ptr> createRefinePublicReturnPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index 790fd80a2f71..e1b781957f0e 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -18,6 +18,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_REIFYDTYPECALCULATIONS +#define GEN_PASS_DEF_REIFYDTYPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" // Massage the op operands to match the dtype function signature. // The dtype function generally takes the same operands as the op, with a few @@ -62,11 +67,10 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, namespace { struct ReifyDtypeCalculationsPass - : public ReifyDtypeCalculationsBase { - ReifyDtypeCalculationsPass() = default; - ReifyDtypeCalculationsPass(StringRef extraLibrary) { - this->extraLibrary = extraLibrary.str(); - } + : public impl::ReifyDtypeCalculationsBase { + using impl::ReifyDtypeCalculationsBase< + ReifyDtypeCalculationsPass>::ReifyDtypeCalculationsBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -96,6 +100,10 @@ struct ReifyDtypeCalculationsPass } // namespace std::unique_ptr> -Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) { - return std::make_unique(extraLibrary); +createReifyDtypeCalculationsPass(StringRef extraLibrary) { + ReifyDtypeCalculationsOptions options; + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index 4b81970909d2..81bd5e45a30a 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -7,10 +7,10 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -19,6 +19,11 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DECL_REIFYSHAPECALCULATIONS +#define GEN_PASS_DEF_REIFYSHAPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static FailureOr> shapeFunctionArgsBuilder(OpBuilder &b, Location loc, @@ -57,11 +62,9 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc, namespace { struct ReifyShapeCalculationsPass - : public ReifyShapeCalculationsBase { - ReifyShapeCalculationsPass() = default; - ReifyShapeCalculationsPass(StringRef extraLibrary) { - this->extraLibrary = extraLibrary.str(); - } + : public impl::ReifyShapeCalculationsBase { + using impl::ReifyShapeCalculationsBase< + ReifyShapeCalculationsPass>::ReifyShapeCalculationsBase; void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -95,6 +98,10 @@ struct ReifyShapeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) { - return std::make_unique(extraLibrary); +createReifyShapeCalculationsPass(StringRef extraLibrary) { + ReifyShapeCalculationsOptions options; + options.extraLibrary = extraLibrary.str(); + return std::make_unique(options); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp index 0ea79a02a799..7a5d9270bf77 100644 --- a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp +++ b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp @@ -8,9 +8,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -26,6 +26,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_RESTRUCTURENONCONSTANTAXES +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -251,7 +255,8 @@ void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns, } class RestructureNonConstantAxesPass - : public RestructureNonConstantAxesBase { + : public impl::RestructureNonConstantAxesBase< + RestructureNonConstantAxesPass> { public: RestructureNonConstantAxesPass() = default; @@ -276,6 +281,8 @@ class RestructureNonConstantAxesPass } // namespace std::unique_ptr> -mlir::torch::Torch::createRestructureNonConstantAxesPass() { +createRestructureNonConstantAxesPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index d6db40d0c182..5c990b720d51 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -7,12 +7,13 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Iterators.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" @@ -23,6 +24,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_SCALARIZESHAPES +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { @@ -1534,7 +1539,8 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { } // namespace namespace { -class ScalarizeShapesPass : public ScalarizeShapesBase { +class ScalarizeShapesPass + : public impl::ScalarizeShapesBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -1615,7 +1621,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase { }; } // namespace -std::unique_ptr> -mlir::torch::Torch::createScalarizeShapesPass() { +std::unique_ptr> createScalarizeShapesPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 2432a3b4686d..25a3cbabb5b9 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "SimplifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" @@ -18,6 +18,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_SIMPLIFYDTYPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, int resultNum, @@ -192,7 +196,8 @@ class RefineNumToTensorScalarOpType namespace { class SimplifyDtypeCalculationsPass - : public SimplifyDtypeCalculationsBase { + : public impl::SimplifyDtypeCalculationsBase< + SimplifyDtypeCalculationsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -222,6 +227,8 @@ class SimplifyDtypeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createSimplifyDtypeCalculationsPass() { +createSimplifyDtypeCalculationsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 54a9fb07d72b..0f78e668310f 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "SimplifyAbstractInterpCalculationsUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -17,6 +17,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::Torch { + +#define GEN_PASS_DEF_SIMPLIFYSHAPECALCULATIONS +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" namespace { class DecomposeAtenSizeOp : public OpRewritePattern { @@ -186,7 +190,8 @@ class RefineShapeCalculateOp : public OpRewritePattern { namespace { class SimplifyShapeCalculationsPass - : public SimplifyShapeCalculationsBase { + : public impl::SimplifyShapeCalculationsBase< + SimplifyShapeCalculationsPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -219,6 +224,8 @@ class SimplifyShapeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createSimplifyShapeCalculationsPass() { +createSimplifyShapeCalculationsPass() { return std::make_unique(); } + +} // namespace mlir::torch::Torch diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index dadd865a54a7..8625a55205d3 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -7,12 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -22,6 +21,13 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_FUNCBACKENDTYPECONVERSION +#define GEN_PASS_DEF_FUNCBACKENDTYPECONVERSIONFORSTABLEHLO +#define GEN_PASS_DEF_FINALIZINGBACKENDTYPECONVERSION +#define GEN_PASS_DEF_FINALIZINGBACKENDTYPECONVERSIONFORSTABLEHLO +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// // FuncBackendTypeConversionPass @@ -74,7 +80,8 @@ void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter, } struct FuncBackendTypeConversionPass - : public FuncBackendTypeConversionBase { + : public impl::FuncBackendTypeConversionBase< + FuncBackendTypeConversionPass> { using FuncBackendTypeConversionBase< FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase; void getDependentDialects(DialectRegistry ®istry) const override { @@ -99,7 +106,7 @@ struct FuncBackendTypeConversionPass #ifdef TORCH_MLIR_ENABLE_STABLEHLO struct FuncBackendTypeConversionForStablehloPass - : public FuncBackendTypeConversionForStablehloBase< + : public impl::FuncBackendTypeConversionForStablehloBase< FuncBackendTypeConversionForStablehloPass> { using FuncBackendTypeConversionForStablehloBase< FuncBackendTypeConversionForStablehloPass>:: @@ -127,14 +134,14 @@ struct FuncBackendTypeConversionForStablehloPass #endif // TORCH_MLIR_ENABLE_STABLEHLO } // namespace -std::unique_ptr> -mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() { +// Create functions for passes +std::unique_ptr> createFuncBackendTypeConversionPass() { return std::make_unique(); } #ifdef TORCH_MLIR_ENABLE_STABLEHLO -std::unique_ptr> mlir::torch::TorchConversion:: - createFuncBackendTypeConversionForStablehloPass() { +std::unique_ptr> +createFuncBackendTypeConversionForStablehloPass() { return std::make_unique(); } #endif // TORCH_MLIR_ENABLE_STABLEHLO @@ -195,7 +202,7 @@ static void stripTorchAttrs(FunctionOpInterface func) { namespace { struct FinalizingBackendTypeConversionPass - : public FinalizingBackendTypeConversionBase< + : public impl::FinalizingBackendTypeConversionBase< FinalizingBackendTypeConversionPass> { using FinalizingBackendTypeConversionBase< FinalizingBackendTypeConversionPass>::FinalizingBackendTypeConversionBase; @@ -242,7 +249,7 @@ struct FinalizingBackendTypeConversionPass #ifdef TORCH_MLIR_ENABLE_STABLEHLO struct FinalizingBackendTypeConversionForStablehloPass - : public FinalizingBackendTypeConversionForStablehloBase< + : public impl::FinalizingBackendTypeConversionForStablehloBase< FinalizingBackendTypeConversionForStablehloPass> { using FinalizingBackendTypeConversionForStablehloBase< FinalizingBackendTypeConversionForStablehloPass>:: @@ -287,13 +294,15 @@ struct FinalizingBackendTypeConversionForStablehloPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { +createFinalizingBackendTypeConversionPass() { return std::make_unique(); } #ifdef TORCH_MLIR_ENABLE_STABLEHLO -std::unique_ptr> mlir::torch:: - TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() { +std::unique_ptr> +createFinalizingBackendTypeConversionForStablehloPass() { return std::make_unique(); } #endif // TORCH_MLIR_ENABLE_STABLEHLO + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index a4c28c2c3160..8b55b664a578 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -7,11 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -22,6 +22,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_CONVERTCUSTOMQUANTOP +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { @@ -191,8 +195,7 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { namespace { class ConvertCustomQuantOpPass - : public TorchConversion::ConvertCustomQuantOpBase< - ConvertCustomQuantOpPass> { + : public impl::ConvertCustomQuantOpBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -225,6 +228,8 @@ class ConvertCustomQuantOpPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { +createConvertCustomQuantOpPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h deleted file mode 100644 index cb80ebd89a3c..000000000000 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ /dev/null @@ -1,29 +0,0 @@ -//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H -#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H - -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -class ModuleOp; - -namespace torch { -namespace TorchConversion { - -#define GEN_PASS_CLASSES -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" - -} // namespace TorchConversion -} // namespace torch -} // end namespace mlir - -#endif // TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index fcc0beb8d0c3..b621eea1fcd7 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -19,6 +19,10 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_UNPACKQUANTTENSOR +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class UnpackQuantizedMatmulWeights @@ -119,7 +123,7 @@ class UnpackQuantizedMatmulWeights namespace { class UnpackQuantTensorPass - : public TorchConversion::UnpackQuantTensorBase { + : public impl::UnpackQuantTensorBase { using UnpackQuantTensorBase::UnpackQuantTensorBase; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -138,6 +142,8 @@ class UnpackQuantTensorPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createUnpackQuantTensorPass() { +createUnpackQuantTensorPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 5189a17fc942..f08a1f389f81 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -7,8 +7,6 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" - #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -19,6 +17,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -30,10 +29,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_VERIFYLINALGONTENSORSBACKENDCONTRACT +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class VerifyLinalgOnTensorsBackendContractPass - : public VerifyLinalgOnTensorsBackendContractBase< + : public impl::VerifyLinalgOnTensorsBackendContractBase< VerifyLinalgOnTensorsBackendContractPass> { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -105,6 +108,8 @@ class VerifyLinalgOnTensorsBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass() { +createVerifyLinalgOnTensorsBackendContractPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 3ff6e4732db2..9b0b8986bf8e 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -7,13 +7,12 @@ // //===----------------------------------------------------------------------===// #ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -22,10 +21,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_VERIFYSTABLEHLOBACKENDCONTRACT +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class VerifyStablehloBackendContractPass - : public VerifyStablehloBackendContractBase< + : public impl::VerifyStablehloBackendContractBase< VerifyStablehloBackendContractPass> { void runOnOperation() override { TypeConverter converter; @@ -66,7 +69,10 @@ class VerifyStablehloBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { +createVerifyStablehloBackendContractPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion + #endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp index efa40a02aeb0..3d48a4b8ef81 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp @@ -7,22 +7,26 @@ // //===----------------------------------------------------------------------===// #ifdef TORCH_MLIR_ENABLE_TOSA -#include "PassDetail.h" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::TorchConversion; +namespace mlir::torch::TorchConversion { + +#define GEN_PASS_DEF_VERIFYTOSABACKENDCONTRACT +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" namespace { class VerifyTosaBackendContractPass - : public VerifyTosaBackendContractBase { + : public impl::VerifyTosaBackendContractBase< + VerifyTosaBackendContractPass> { void runOnOperation() override { MLIRContext *context = &getContext(); auto module = getOperation(); @@ -59,8 +63,10 @@ class VerifyTosaBackendContractPass }; } // namespace -std::unique_ptr> -mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { +std::unique_ptr> createVerifyTosaBackendContractPass() { return std::make_unique(); } + +} // namespace mlir::torch::TorchConversion + #endif // TORCH_MLIR_ENABLE_TOSA diff --git a/lib/RefBackend/PassDetail.h b/lib/RefBackend/PassDetail.h deleted file mode 100644 index aad2c369168b..000000000000 --- a/lib/RefBackend/PassDetail.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- PassDetail.h - RefBackend Pass class details -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef REFBACKEND_PASSDETAIL_H -#define REFBACKEND_PASSDETAIL_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { - -#define GEN_PASS_CLASSES -#include "torch-mlir/RefBackend/Passes.h.inc" - -} // namespace torch -} // end namespace mlir - -#endif // REFBACKEND_PASSDETAIL_H diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 89c7fb5df21a..f5e005994432 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -14,7 +14,6 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -42,6 +41,26 @@ using namespace mlir::torch::RefBackend; // Pass registration //===----------------------------------------------------------------------===// +namespace mlir::torch::RefBackend { + +#define GEN_PASS_DEF_MUNGECALLINGCONVENTIONS +#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE +#define GEN_PASS_DEF_EXPANDOPSFORLLVM +#define GEN_PASS_DEF_MUNGEMEMREFCOPY +#define GEN_PASS_DEF_GENERALIZETENSORCONCAT +#define GEN_PASS_DEF_GENERALIZETENSORPAD +#include "torch-mlir/RefBackend/Passes.h.inc" + +} // namespace mlir::torch::RefBackend + +// Bring Base classes into scope for anonymous namespace passes +using mlir::torch::RefBackend::impl::ExpandOpsForLLVMBase; +using mlir::torch::RefBackend::impl::GeneralizeTensorConcatBase; +using mlir::torch::RefBackend::impl::GeneralizeTensorPadBase; +using mlir::torch::RefBackend::impl::MLProgramBufferizeBase; +using mlir::torch::RefBackend::impl::MungeCallingConventionsBase; +using mlir::torch::RefBackend::impl::MungeMemrefCopyBase; + namespace { #define GEN_PASS_REGISTRATION #include "torch-mlir/RefBackend/Passes.h.inc" @@ -220,11 +239,6 @@ class MungeCallingConventions }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createMungeCallingConventionsPass() { - return std::make_unique(); -} - //===----------------------------------------------------------------------===// // MLProgramBufferize //===----------------------------------------------------------------------===// @@ -346,11 +360,6 @@ class MLProgramBufferize : public MLProgramBufferizeBase { }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createMLProgramBufferizePass() { - return std::make_unique(); -} - //===----------------------------------------------------------------------===// // ExpandOpsForLLVM //===----------------------------------------------------------------------===// @@ -376,11 +385,6 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase { }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createExpandOpsForLLVMPass() { - return std::make_unique(); -} - //===----------------------------------------------------------------------===// // MungeMemrefCopy //===----------------------------------------------------------------------===// @@ -432,11 +436,6 @@ class MungeMemrefCopy : public MungeMemrefCopyBase { }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createMungeMemrefCopyPass() { - return std::make_unique(); -} - namespace { class GeneralizeTensorConcat : public GeneralizeTensorConcatBase { @@ -454,11 +453,6 @@ class GeneralizeTensorConcat }; } // namespace -std::unique_ptr> -mlir::torch::RefBackend::createGeneralizeTensorConcatPass() { - return std::make_unique(); -} - namespace { class GeneralizeTensorPad : public GeneralizeTensorPadBase { @@ -476,8 +470,3 @@ class GeneralizeTensorPad } }; } // namespace - -std::unique_ptr> -mlir::torch::RefBackend::createGeneralizeTensorPadPass() { - return std::make_unique(); -}