Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 58 additions & 22 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,6 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
SmallVector<Value> insertSliceOffsets{c0, c0};

SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
SmallVector<Value> sliceSizes{inputSizes[0], inputSizes[1]};

// For the case in which the padding dimension value is negative,
// we will need to shrink the dimension. Note in the PyTorch
Expand All @@ -1565,19 +1564,27 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
Value c2 = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(2));

for (size_t i = 0; i < numSpatialDims; i++) {
// Calculate inner size: (input_size - 1) * stride + 1
Value innerSize = rewriter.createOrFold<arith::SubIOp>(loc, inDims[i], c1);
innerSize = rewriter.createOrFold<arith::MulIOp>(
loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i]));
innerSize = rewriter.createOrFold<arith::AddIOp>(loc, innerSize, c1);
innerSizes.push_back(innerSize);

Value offset = rewriter.createOrFold<arith::SubIOp>(loc, weightDims[i], c1);
offset = rewriter.createOrFold<arith::MulIOp>(
loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i]));
offset = rewriter.createOrFold<arith::SubIOp>(
loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i]));

// We need to crop or pad from two sides - top&bottom or left&right.
// Therefore multiply by 2.
Value outerSize = rewriter.createOrFold<arith::MulIOp>(loc, offset, c2);

// Crop or pad based on the sign of offset
outerSize = rewriter.createOrFold<arith::AddIOp>(loc, outerSize, innerSize);

// Add optional padding values
outerSize = rewriter.createOrFold<arith::AddIOp>(
loc, outerSize,
castIntToIndex(rewriter, loc, outputPaddingIntValues[i]));
Expand All @@ -1587,45 +1594,74 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding(
// Make the negative value positive by multiplying by -1.
anyDimensionPaddingIsNegative = true;
auto offsetType = offset.getType();
auto negOneConst = rewriter.createOrFold<arith::ConstantOp>(
loc, offsetType, rewriter.getIntegerAttr(offsetType, -1));
auto negOneConst = arith::ConstantOp::create(
rewriter, loc, rewriter.getIntegerAttr(offsetType, -1));
auto posOffset =
rewriter.createOrFold<arith::MulIOp>(loc, offset, negOneConst);

// Compute the reduced dimension size due to negative padding.
auto sizeReduction =
rewriter.createOrFold<arith::MulIOp>(loc, posOffset, c2);
sliceSizes.push_back(rewriter.createOrFold<arith::SubIOp>(
loc, inputSizes[i + 2], sizeReduction));

extractSliceOffsets.push_back(posOffset);
insertSliceOffsets.push_back(c0);
} else {
sliceSizes.push_back(inputSizes[i + 2]);
extractSliceOffsets.push_back(c0);
insertSliceOffsets.push_back(offset);
}
}
Value initTensor = createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
for (auto stride : strideIntValues)
strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride));

auto insertSliceOpInput = input;
if (anyDimensionPaddingIsNegative) {
insertSliceOpInput = tensor::ExtractSliceOp::create(
rewriter, loc,
torch_to_linalg::removeSizeInformation(rewriter, loc, input),
extractSliceOffsets, sliceSizes, strideIndexValues);
}

auto paddedInput = tensor::InsertSliceOp::create(
rewriter, loc,
torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput),
initTensor, insertSliceOffsets, sliceSizes, strideIndexValues);
return paddedInput;
// Some dimensions may need padding and some dimensions need cropping

// 1. Allocate a maxSizes buffer (max of inner and outer for each dim)
// 2. Insert the input into maxSizes buffer at appropriate offsets (if
// insertSliceOffsets is positive, pad; 0 no padding) and stride
// 3. Extract the final outerSizes from maxSizes buffer

// Create the "max size" tensor to accommodate both padding and cropping
SmallVector<Value> maxSizes{inBatch, inChannels};
for (size_t i = 0; i < numSpatialDims; ++i) {
Value innerDim = innerSizes[i + 2];
Value outerDim = outerSizes[i + 2];
Value isPadding = rewriter.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::ugt, outerDim, innerDim);
Value maxDim = rewriter.createOrFold<arith::SelectOp>(loc, isPadding,
outerDim, innerDim);
maxSizes.push_back(maxDim);
}

Value initMaxTensor =
createInitTensor(rewriter, loc, maxSizes, inputDTy, pad);

// Insert input
auto paddedTensor = rewriter.create<tensor::InsertSliceOp>(
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
initMaxTensor, insertSliceOffsets, inputSizes, strideIndexValues);

SmallVector<Value> allOnesStrides(inputSizes.size(), c1);

// Crop. Extract the final tensor from the "max" tensor
auto finalTensor = rewriter.create<tensor::ExtractSliceOp>(
loc,
torch_to_linalg::removeSizeInformation(rewriter, loc, paddedTensor),
extractSliceOffsets, outerSizes, allOnesStrides);

return finalTensor;

} else {

Value initPaddedTensor =
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

// Insert the original input into the outer tensor with calculated offsets
auto paddedInput = rewriter.create<tensor::InsertSliceOp>(
loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input),
initPaddedTensor, insertSliceOffsets, inputSizes, strideIndexValues);
return paddedInput;
}
}

namespace {
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3961,7 +3961,9 @@
"TraceModule_empty",
"TraceUnsignedIntModule_empty",
"TransposedConv1dNegativePadding_basic",
"TransposedConv1dNegativePaddingLarge_basic",
"TransposedConv2dNegativePadding_basic",
"TransposedConv2dPositiveAndNegativePadding_basic",
"TransposedConv3dNegativePadding_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"InterpolateDynamicModule_sizes_nearest",
Expand Down Expand Up @@ -5039,7 +5041,9 @@
"TraceUnsignedIntModule_basic",
"TraceUnsignedIntModule_empty",
"TransposedConv1dNegativePadding_basic",
"TransposedConv1dNegativePaddingLarge_basic",
"TransposedConv2dNegativePadding_basic",
"TransposedConv2dPositiveAndNegativePadding_basic",
"TransposedConv3dNegativePadding_basic",
"TupleModule_basic",
"TypeAsDifferentModule_basic",
Expand Down
68 changes: 66 additions & 2 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,7 +1988,7 @@ def forward(self, inputVec, weight, bias):
inputVec,
weight,
bias=bias,
stride=[1],
stride=[4],
padding=[3],
dilation=[1],
transposed=True,
Expand All @@ -2002,6 +2002,38 @@ def TransposedConv1dNegativePadding_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 7), tu.rand(1, 2, 3), tu.rand(2))


class TransposedConv1dNegativePaddingLarge(torch.nn.Module):
    """1-D transposed convolution whose padding exceeds the dilated kernel reach.

    With kernel size 3, dilation 4 and padding 10, the per-side offset
    ``dilation * (kernel - 1) - padding = 4 * 2 - 10 = -2`` is negative,
    so the strided input has to be cropped rather than padded.  The
    non-unit stride (7) and multi-channel shapes make this a larger
    variant of the basic negative-padding test.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 17, 5], torch.float32, True),
            ([17, 6, 3], torch.float32, True),
            ([6], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Per the PyTorch transposed-convolution size formula the output
        # length is (5 - 1) * 7 - 2 * 10 + 4 * (3 - 1) + 0 + 1 = 17.
        return torch.ops.aten.convolution(
            inputVec,
            weight,
            bias=bias,
            stride=[7],
            padding=[10],
            dilation=[4],
            transposed=True,
            output_padding=[0],
            groups=1,
        )


@register_test_case(module_factory=lambda: TransposedConv1dNegativePaddingLarge())
def TransposedConv1dNegativePaddingLarge_basic(module, tu: TestUtils):
    # Random input, weight and bias whose shapes match the static shapes
    # annotated on TransposedConv1dNegativePaddingLarge.forward.
    module.forward(tu.rand(1, 17, 5), tu.rand(17, 6, 3), tu.rand(6))


class TransposedConv2dNegativePadding(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -2034,6 +2066,38 @@ def TransposedConv2dNegativePadding_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))


class TransposedConv2dPositiveAndNegativePadding(torch.nn.Module):
    """2-D transposed convolution that pads one spatial dim and crops the other.

    With a 3x3 kernel, dilation 1 and padding ``[0, 3]`` the per-side
    offsets ``dilation * (kernel - 1) - padding`` are ``2`` for the first
    spatial dimension (positive: pad) and ``-1`` for the second
    (negative: crop), so both cases are exercised in a single op.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([1, 1, 4, 7], torch.float32, True),
            ([1, 2, 3, 3], torch.float32, True),
            ([2], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Stride 4 in both spatial dims combined with mixed-sign offsets.
        return torch.ops.aten.convolution(
            inputVec,
            weight,
            bias=bias,
            stride=[4, 4],
            padding=[0, 3],
            dilation=[1, 1],
            transposed=True,
            output_padding=[0, 0],
            groups=1,
        )


@register_test_case(module_factory=lambda: TransposedConv2dPositiveAndNegativePadding())
def TransposedConv2dPositiveAndNegativePadding_basic(module, tu: TestUtils):
    # Random input, weight and bias whose shapes match the static shapes
    # annotated on TransposedConv2dPositiveAndNegativePadding.forward.
    module.forward(tu.rand(1, 1, 4, 7), tu.rand(1, 2, 3, 3), tu.rand(2))


class TransposedConv3dNegativePadding(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -2052,7 +2116,7 @@ def forward(self, inputVec, weight, bias):
inputVec,
weight,
bias=bias,
stride=[1, 1, 1],
stride=[4, 4, 4],
padding=[2, 1, 3],
dilation=[1, 1, 1],
transposed=True,
Expand Down
45 changes: 39 additions & 6 deletions test/Conversion/TorchToLinalg/convolution.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,17 @@ func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>)
}

// CHECK-LABEL: func.func @tranConv2dNegativePadding(
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32>
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[IN_TENSOR]][0, 0, 0, 1] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x7xf32> to tensor<1x1x4x5xf32>
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[INIT_TENSOR:.*]][0, 0, 2, 0] [1, 1, 4, 5] [1, 1, 1, 1] : tensor<1x1x4x5xf32> into tensor<1x1x8x5xf32>
// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[INSERTED_SLICE]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>) -> !torch.vtensor<[1,2,6,3],f32> attributes {torch.assume_strict_symbolic_shapes} {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x8x7xf32>
// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x8x7xf32>) -> tensor<1x1x8x7xf32>
// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 1, 1] : tensor<1x1x4x7xf32> into tensor<1x1x8x7xf32>
// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 8, 5] [1, 1, 1, 1] : tensor<1x1x8x7xf32> to tensor<1x1x8x5xf32>
// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x8x5xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x6x3xf32>) -> tensor<1x2x6x3xf32>
// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x6x3xf32> -> !torch.vtensor<[1,2,6,3],f32>
func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) -> !torch.vtensor<[1, 2, 6, 3],f32> attributes {torch.assume_strict_symbolic_shapes} {
%int0 = torch.constant.int 0
%true = torch.constant.bool true
Expand All @@ -174,3 +179,31 @@ func.func @tranConv2dNegativePadding(%arg0: !torch.vtensor<[1, 1, 4, 7],f32>) ->
%6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1, 1, 4, 7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1, 2, 6, 3],f32>
return %6 : !torch.vtensor<[1, 2, 6, 3],f32>
}

// CHECK-LABEL: func.func @tranConv2dNegativeAndPositivePadding(
// CHECK-SAME: %[[INPUT_VTENSOR:.*]]: !torch.vtensor<[1,1,4,7],f32>,
// CHECK-SAME: %[[WEIGHTS_VTENSOR:.*]]: !torch.vtensor<[1,2,3,3],f32>,
// CHECK-SAME: %[[BIAS_VTENSOR:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_VTENSOR]] : !torch.vtensor<[1,1,4,7],f32> -> tensor<1x1x4x7xf32>
// CHECK: %[[EMPTY_UNSTRIDED_TENSOR:.*]] = tensor.empty() : tensor<1x1x17x25xf32>
// CHECK: %[[ZEROS_UNSTRIDED_TENSOR:.*]] = linalg.fill ins(%[[C0F]] : f32) outs(%[[EMPTY_UNSTRIDED_TENSOR]] : tensor<1x1x17x25xf32>) -> tensor<1x1x17x25xf32>
// CHECK: %[[INPUT_UNSTRIDED_TENSOR:.*]] = tensor.insert_slice %[[INPUT_TENSOR]] into %[[ZEROS_UNSTRIDED_TENSOR]][0, 0, 2, 0] [1, 1, 4, 7] [1, 1, 4, 4] : tensor<1x1x4x7xf32> into tensor<1x1x17x25xf32>
// CHECK: %[[CROPPED_UNSTRIDED_TENSOR:.*]] = tensor.extract_slice %[[INPUT_UNSTRIDED_TENSOR]][0, 0, 0, 1] [1, 1, 17, 23] [1, 1, 1, 1] : tensor<1x1x17x25xf32> to tensor<1x1x17x23xf32>
// CHECK: %[[OUT_TENSOR:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[CROPPED_UNSTRIDED_TENSOR]], %[[WEIGHTS:.*]] : tensor<1x1x17x23xf32>, tensor<2x1x3x3xf32>) outs(%[[INIT_OUT_TENSOR:.*]] : tensor<1x2x15x21xf32>) -> tensor<1x2x15x21xf32>
// CHECK: %[[OUT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[OUT_TENSOR]] : tensor<1x2x15x21xf32> -> !torch.vtensor<[1,2,15,21],f32>
func.func @tranConv2dNegativeAndPositivePadding(%arg0: !torch.vtensor<[1,1,4,7],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,15,21],f32> {
// Transposed convolution of a 1x1x4x7 input with a 3x3 kernel where the
// first spatial dim must be padded and the second must be cropped.
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int0 = torch.constant.int 0
%int4 = torch.constant.int 4
%true = torch.constant.bool true
// stride = [4, 4]
%0 = torch.prim.ListConstruct %int4, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
// padding = [0, 3]: per-side offsets are 2 (pad) and -1 (crop) for a 3x3 kernel
%1 = torch.prim.ListConstruct %int0, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
// dilation = [1, 1]
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// output_padding = [0, 0]
%3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// transposed = true, groups = 1
%4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %true, %3, %int1 : !torch.vtensor<[1,1,4,7],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.vtensor<[2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,15,21],f32>
return %4 : !torch.vtensor<[1,2,15,21],f32>
}
Loading