Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

26 changes: 0 additions & 26 deletions lib/Conversion/PassDetail.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand All @@ -20,6 +22,10 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHCONVERSIONTOMLPROGRAM
#include "torch-mlir/Conversion/Passes.h.inc"

static constexpr StringRef getSeedGobalVarName() { return "global_seed"; }

Expand Down Expand Up @@ -102,7 +108,7 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {

namespace {
class ConvertTorchConversionToMLProgram
: public ConvertTorchConversionToMLProgramBase<
: public impl::ConvertTorchConversionToMLProgramBase<
ConvertTorchConversionToMLProgram> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
Expand Down Expand Up @@ -138,6 +144,8 @@ class ConvertTorchConversionToMLProgram
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::createConvertTorchConversionToMLProgramPass() {
createConvertTorchConversionToMLProgramPass() {
return std::make_unique<ConvertTorchConversionToMLProgram>();
}

} // namespace mlir::torch
24 changes: 0 additions & 24 deletions lib/Conversion/TorchOnnxToTorch/PassDetail.h

This file was deleted.

14 changes: 10 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "./PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
Expand All @@ -19,6 +20,10 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;
namespace mlir::torch::onnx_c {

#define GEN_PASS_DEF_CONVERTTORCHONNXTOTORCH
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc"

#define DEBUG_TYPE "torch-onnx"

Expand All @@ -37,7 +42,7 @@ int64_t getDefaultOpsetVersion(Operation *containerOp) {
}

class ConvertTorchOnnxToTorch
: public ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
: public impl::ConvertTorchOnnxToTorchBase<ConvertTorchOnnxToTorch> {
public:
ConvertTorchOnnxToTorch() = default;
void runOnOperation() override {
Expand Down Expand Up @@ -82,7 +87,8 @@ class ConvertTorchOnnxToTorch

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::onnx_c::createTorchOnnxToTorchPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass() {
return std::make_unique<ConvertTorchOnnxToTorch>();
}

} // namespace mlir::torch::onnx_c
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
Expand All @@ -25,6 +27,10 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOARITH
#include "torch-mlir/Conversion/Passes.h.inc"

// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
Expand Down Expand Up @@ -407,7 +413,7 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {

namespace {
class ConvertTorchToArith
: public ConvertTorchToArithBase<ConvertTorchToArith> {
: public impl::ConvertTorchToArithBase<ConvertTorchToArith> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
Expand Down Expand Up @@ -565,7 +571,8 @@ class ConvertTorchToArith
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToArithPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass() {
return std::make_unique<ConvertTorchToArith>();
}

} // namespace mlir::torch
15 changes: 11 additions & 4 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
Expand All @@ -24,6 +26,10 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOLINALG
#include "torch-mlir/Conversion/Passes.h.inc"

// -----------------------------------------------------------------------------
// The pass
Expand All @@ -34,7 +40,7 @@ using namespace mlir::torch::Torch;

namespace {
class ConvertTorchToLinalg
: public ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
: public impl::ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
Expand Down Expand Up @@ -89,7 +95,8 @@ class ConvertTorchToLinalg
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToLinalgPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass() {
return std::make_unique<ConvertTorchToLinalg>();
}

} // namespace mlir::torch
1 change: 0 additions & 1 deletion lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
Expand Down
16 changes: 12 additions & 4 deletions lib/Conversion/TorchToSCF/TorchToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -21,6 +23,10 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOSCF
#include "torch-mlir/Conversion/Passes.h.inc"

namespace {
class ConvertTorchPrimIfYieldOp : public OpConversionPattern<PrimIfYieldOp> {
Expand Down Expand Up @@ -312,7 +318,8 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
} // namespace

namespace {
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
class ConvertTorchToSCF
: public impl::ConvertTorchToSCFBase<ConvertTorchToSCF> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, arith::ArithDialect>();
Expand Down Expand Up @@ -345,7 +352,8 @@ class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToSCFPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass() {
return std::make_unique<ConvertTorchToSCF>();
}

} // namespace mlir::torch
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"

Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TorchToStablehlo/Rng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"

#include "stablehlo/dialect/StablehloOps.h"
Expand Down
17 changes: 12 additions & 5 deletions lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Conversion/Passes.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -23,11 +25,15 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir::torch {

#define GEN_PASS_DEF_CONVERTTORCHTOSTABLEHLO
#include "torch-mlir/Conversion/Passes.h.inc"

namespace {

class ConvertTorchToStablehlo
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
: public impl::ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
public:
ConvertTorchToStablehlo() = default;
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
Expand Down Expand Up @@ -87,13 +93,14 @@ class ConvertTorchToStablehlo
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToStablehloPass() {
createConvertTorchToStablehloPass() {
return std::make_unique<ConvertTorchToStablehlo>(false, false);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
bool enableI32Index) {
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index) {
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
enableI32Index);
}

} // namespace mlir::torch
Loading
Loading