diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 8f14515e425c..523e9a4d006c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -2102,26 +2102,27 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), c, cstBlockSizeSquare); cDivBlockSizeSquare = rewriter.create( binder.getLoc(), cDivBlockSizeSquare); - Value reshapeSizesList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(input.getContext())), - llvm::SmallVector{b, cstBlockSize, cstBlockSize, - cDivBlockSizeSquare, h, w}); int64_t cDivBlockSizeSquareInt = inputSizes[1] == Torch::kUnknownSize ? Torch::kUnknownSize : inputSizes[1] / (blockSize * blockSize); - SmallVector reshapeSizesInt{ - inputSizes[0], blockSize, blockSize, - cDivBlockSizeSquareInt, inputSizes[2], inputSizes[3]}; - Value reshapedInput = rewriter.create( - binder.getLoc(), - inputTy.getWithSizesAndDtype(reshapeSizesInt, - inputTy.getOptionalDtype()), - input, reshapeSizesList); Value transposedInput; + Value reshapeSizesList; if (mode == "DCR") { + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cstBlockSize, cstBlockSize, + cDivBlockSizeSquare, h, w}); + SmallVector reshapeSizesInt{ + inputSizes[0], blockSize, blockSize, + cDivBlockSizeSquareInt, inputSizes[2], inputSizes[3]}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), reshapedInput, /*dimA=*/1, /*dimB=*/3, transposedInput))) @@ -2134,6 +2135,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "Failed to create TorchTranspose op"); } else { // mode == "CRD" + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cDivBlockSizeSquare, cstBlockSize, + cstBlockSize, h, w}); + SmallVector reshapeSizesInt{ + inputSizes[0], cDivBlockSizeSquareInt, blockSize, + blockSize, inputSizes[2], inputSizes[3]}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); if (failed(createTorchTransposeOp( rewriter, binder.getLoc(), reshapedInput, /*dimA=*/2, /*dimB=*/4, transposedInput))) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 8d83dc181987..5c326c15e1ef 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1836,7 +1836,7 @@ func.func @test_depthtospace_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !t // CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32> // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %5, %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[INT]], %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,4,6],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32 %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "DCR"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> @@ -1859,7 +1859,7 @@ func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f3 // CHECK: %[[C4:.*]] = torch.constant.int 4 // CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float // CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int - // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[C2_0]], %[[C2_0]], %[[INT]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[INT]], %[[C2_0]], %[[C2_0]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list -> !torch.vtensor<[1,2,2,2,2,3],f32> // CHECK: %[[C2_1:.*]] = torch.constant.int 2 // CHECK: %[[C4_0:.*]] = torch.constant.int 4 @@ -1872,7 +1872,7 @@ func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f3 // CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32> // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %5, %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[INT]], %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,4,6],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32 %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32>