diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 68947a953b7a..3f4c63422e14 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1551,7 +1551,6 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( SmallVector insertSliceOffsets{c0, c0}; SmallVector inputSizes = getTensorSizes(rewriter, loc, input); - SmallVector sliceSizes{inputSizes[0], inputSizes[1]}; // For the case in which the padding dimension value is negative, // we will need to shrink the dimension. Note in the PyTorch @@ -1565,10 +1564,12 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( Value c2 = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(2)); for (size_t i = 0; i < numSpatialDims; i++) { + // Calculate inner size: (input_size - 1) * stride + 1 Value innerSize = rewriter.createOrFold(loc, inDims[i], c1); innerSize = rewriter.createOrFold( loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); innerSize = rewriter.createOrFold(loc, innerSize, c1); + innerSizes.push_back(innerSize); Value offset = rewriter.createOrFold(loc, weightDims[i], c1); offset = rewriter.createOrFold( @@ -1576,8 +1577,14 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( offset = rewriter.createOrFold( loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i])); + // We need to crop or pad from two sides - top&bottom or left&right. + // Therefore multiply by 2. Value outerSize = rewriter.createOrFold(loc, offset, c2); + + // Crop or pad based on the sign of offset outerSize = rewriter.createOrFold(loc, outerSize, innerSize); + + // Add optional padding values outerSize = rewriter.createOrFold( loc, outerSize, castIntToIndex(rewriter, loc, outputPaddingIntValues[i])); @@ -1587,45 +1594,74 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( // Make the negative value positive by multiplying by -1. anyDimensionPaddingIsNegative = true; auto offsetType = offset.getType(); - auto negOneConst = rewriter.createOrFold( - loc, offsetType, rewriter.getIntegerAttr(offsetType, -1)); + auto negOneConst = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(offsetType, -1)); auto posOffset = rewriter.createOrFold(loc, offset, negOneConst); - // Compute the reduced dimension size due to negative padding. - auto sizeReduction = - rewriter.createOrFold(loc, posOffset, c2); - sliceSizes.push_back(rewriter.createOrFold( - loc, inputSizes[i + 2], sizeReduction)); - extractSliceOffsets.push_back(posOffset); insertSliceOffsets.push_back(c0); } else { - sliceSizes.push_back(inputSizes[i + 2]); extractSliceOffsets.push_back(c0); insertSliceOffsets.push_back(offset); } } - Value initTensor = createInitTensor(rewriter, loc, outerSizes, inputDTy, pad); // Insert input into allocated tensor SmallVector strideIndexValues{c1, c1}; for (auto stride : strideIntValues) strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride)); - auto insertSliceOpInput = input; if (anyDimensionPaddingIsNegative) { - insertSliceOpInput = tensor::ExtractSliceOp::create( - rewriter, loc, - torch_to_linalg::removeSizeInformation(rewriter, loc, input), - extractSliceOffsets, sliceSizes, strideIndexValues); - } - auto paddedInput = tensor::InsertSliceOp::create( - rewriter, loc, - torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput), - initTensor, insertSliceOffsets, sliceSizes, strideIndexValues); - return paddedInput; + // Some dimensions may need padding and some dimensions need cropping + + // 1. Allocate a maxSizes buffer (max of inner and outer for each dim) + // 2. Insert the input into maxSizes buffer at appropriate offsets (if + // insertSliceOffsets is positive, pad; 0 no padding) and stride + // 3. Extract the final outerSizes from maxSizes buffer + + // Create the "max size" tensor to accommodate both padding and cropping + SmallVector maxSizes{inBatch, inChannels}; + for (size_t i = 0; i < numSpatialDims; ++i) { + Value innerDim = innerSizes[i + 2]; + Value outerDim = outerSizes[i + 2]; + Value isPadding = rewriter.createOrFold( + loc, arith::CmpIPredicate::ugt, outerDim, innerDim); + Value maxDim = rewriter.createOrFold(loc, isPadding, + outerDim, innerDim); + maxSizes.push_back(maxDim); + } + + Value initMaxTensor = + createInitTensor(rewriter, loc, maxSizes, inputDTy, pad); + + // Insert input + auto paddedTensor = rewriter.create( + loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + initMaxTensor, insertSliceOffsets, inputSizes, strideIndexValues); + + SmallVector allOnesStrides(inputSizes.size(), c1); + + // Crop. Extract the final tensor from the "max" tensor + auto finalTensor = rewriter.create( + loc, + torch_to_linalg::removeSizeInformation(rewriter, loc, paddedTensor), + extractSliceOffsets, outerSizes, allOnesStrides); + + return finalTensor; + + } else { + + Value initPaddedTensor = + createInitTensor(rewriter, loc, outerSizes, inputDTy, pad); + + // Insert the original input into the outer tensor with calculated offsets + auto paddedInput = rewriter.create( + loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + initPaddedTensor, insertSliceOffsets, inputSizes, strideIndexValues); + return paddedInput; + } } namespace { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0e1b3b67d102..87c455011120 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3961,7 +3961,9 @@ "TraceModule_empty", "TraceUnsignedIntModule_empty", "TransposedConv1dNegativePadding_basic", + "TransposedConv1dNegativePaddingLarge_basic", "TransposedConv2dNegativePadding_basic", + "TransposedConv2dPositiveAndNegativePadding_basic", "TransposedConv3dNegativePadding_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "InterpolateDynamicModule_sizes_nearest", @@ -5039,7 +5041,9 @@ "TraceUnsignedIntModule_basic", "TraceUnsignedIntModule_empty", "TransposedConv1dNegativePadding_basic", + "TransposedConv1dNegativePaddingLarge_basic", "TransposedConv2dNegativePadding_basic", + "TransposedConv2dPositiveAndNegativePadding_basic", "TransposedConv3dNegativePadding_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 2a1c627f6ee5..f3fd695bbeed 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1988,7 +1988,7 @@ def forward(self, inputVec, weight, bias): inputVec, weight, bias=bias, - stride=[1], + stride=[4], padding=[3], dilation=[1], transposed=True, @@ -2002,6 +2002,38 @@ def TransposedConv1dNegativePadding_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 7), tu.rand(1, 2, 3), tu.rand(2)) +class TransposedConv1dNegativePaddingLarge(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 17, 5], torch.float32, True), + ([17, 6, 3], torch.float32, True), + ([6], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.convolution( + inputVec, + weight, + bias=bias, + stride=[7], + padding=[10], + dilation=[4], + transposed=True, + output_padding=[0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: TransposedConv1dNegativePaddingLarge()) +def TransposedConv1dNegativePaddingLarge_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 17, 5), tu.rand(17, 6, 3), tu.rand(6)) + + class TransposedConv2dNegativePadding(torch.nn.Module): def __init__(self): super().__init__() @@ -2034,6 +2066,38 @@ def TransposedConv2dNegativePadding_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2)) +class TransposedConv2dPositiveAndNegativePadding(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 7], torch.float32, True), + ([1, 2, 3, 3], torch.float32, True), + ([2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.convolution( + inputVec, + weight, + bias=bias, + stride=[4, 4], + padding=[0, 3], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: TransposedConv2dPositiveAndNegativePadding()) +def TransposedConv2dPositiveAndNegativePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2)) + + class TransposedConv3dNegativePadding(torch.nn.Module): def __init__(self): super().__init__() @@ -2052,7 +2116,7 @@ def forward(self, inputVec, weight, bias): inputVec, weight, bias=bias, - stride=[1, 1, 1], + stride=[4, 4, 4], padding=[2, 1, 3], dilation=[1, 1, 1], transposed=True, diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir index 88627c166877..523e93effb35 100644 --- a/test/Conversion/TorchToLinalg/convolution.mlir +++ b/test/Conversion/TorchToLinalg/convolution.mlir @@ -152,12 +152,17 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) } // CHECK-LABEL: func.func @tranConv2dNegativePadding( -// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> -// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32> -// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32> -// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32> -// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32> -// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32> +// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} { +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32> +// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x8x7xf32> +// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x8x7xf32>) -> tensor<1x1x8x7xf32> +// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x8x7xf32> +// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 8, 5] [1, 1, 1, 1] : tensor<1x1x8x7xf32> to tensor<1x1x8x5xf32> +// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32> +// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32> func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> !torch.vtensor<[1, 2, 6, 3],f32> attributes {torch.assume_strict_symbolic_shapes} { %int0 = torch.constant.int 0 %true = torch.constant.bool true @@ -174,3 +179,31 @@ func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1, 1, 4, 7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1, 2, 6, 3],f32> return %6 : !torch.vtensor<[1, 2, 6, 3],f32> } + +// CHECK-LABEL: func.func @tranConv2dNegativeAndPositivePadding( +// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>, +// CHECK-SAME: %[[WEIGHTS_VTENSOR:.*]]: !torch.vtensor<[1,2,3,3],f32>, +// CHECK-SAME: %[[BIAS_VTENSOR:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> { +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32> +// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x17x25xf32> +// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x17x25xf32>) -> tensor<1x1x17x25xf32> +// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 4, 4] : tensor<1x1x4x7xf32> into tensor<1x1x17x25xf32> +// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 17, 23] [1, 1, 1, 1] : tensor<1x1x17x25xf32> to tensor<1x1x17x23xf32> +// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x17x23xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x15x21xf32>) -> tensor<1x2x15x21xf32> +// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x15x21xf32> -> !torch.vtensor<[1,2,15,21],f32> +func.func @tranConv2dNegativeAndPositivePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %int0 = torch.constant.int 0 + %int4 = torch.constant.int 4 + %true = torch.constant.bool true + %0 = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %true, %3, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,15,21],f32> + return %4 : !torch.vtensor<[1,2,15,21],f32> +}