Skip to content

Commit 5d374ba

Browse files
authored
Decompose AtenUpsampleNearest(1/2)dVecOp to interpolate (#4330)
Addresses #4327 --------- Signed-off-by: zjgarvey <zjgarvey@gmail.com>
1 parent cedf522 commit 5d374ba

File tree

4 files changed

+129
-1
lines changed

4 files changed

+129
-1
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5185,6 +5185,30 @@ class DecomposeAtenUnflattenIntOp
51855185
};
51865186
} // namespace
51875187

5188+
namespace {
5189+
template <typename UpsampleVecOp>
5190+
class DecomposeAtenUpsampleNearestVecOp
5191+
: public OpRewritePattern<UpsampleVecOp> {
5192+
public:
5193+
using OpRewritePattern<UpsampleVecOp>::OpRewritePattern;
5194+
LogicalResult matchAndRewrite(UpsampleVecOp op,
5195+
PatternRewriter &rewriter) const override {
5196+
Value scales = op.getScaleFactors();
5197+
static_assert(std::is_same_v<UpsampleVecOp, AtenUpsampleNearest1dVecOp> ||
5198+
std::is_same_v<UpsampleVecOp, AtenUpsampleNearest2dVecOp>);
5199+
Value cstMode = rewriter.create<Torch::ConstantStrOp>(
5200+
op.getLoc(), rewriter.getStringAttr("nearest"));
5201+
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
5202+
Value cstAntialias =
5203+
rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
5204+
rewriter.replaceOpWithNewOp<Aten__InterpolateSizeListScaleListOp>(
5205+
op, op.getType(), op.getInput(), op.getOutputSize(),
5206+
op.getScaleFactors(), cstMode, cstNone, cstNone, cstAntialias);
5207+
return success();
5208+
}
5209+
};
5210+
} // namespace
5211+
51885212
// Decompose aten.expand into aten.broadcast_to op.
51895213
namespace {
51905214
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
@@ -12983,6 +13007,12 @@ class DecomposeComplexOpsPass
1298313007
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
1298413008
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
1298513009
addPatternIfTargetOpIsIllegal<DecomposeAtenUnflattenIntOp>(patterns);
13010+
addPatternIfTargetOpIsIllegal<
13011+
DecomposeAtenUpsampleNearestVecOp<AtenUpsampleNearest1dVecOp>>(
13012+
patterns);
13013+
addPatternIfTargetOpIsIllegal<
13014+
DecomposeAtenUpsampleNearestVecOp<AtenUpsampleNearest2dVecOp>>(
13015+
patterns);
1298613016
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
1298713017
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
1298813018
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
593593
target.addIllegalOp<AtenLogaddexp2Op>();
594594
target.addIllegalOp<AtenKlDivOp>();
595595
target.addIllegalOp<AtenAsStridedOp>();
596+
target.addIllegalOp<AtenUpsampleNearest1dVecOp>();
597+
target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
596598

597599
for (auto &opName : backendLegalOpsSet) {
598600
target.addLegalOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,6 @@
497497
"CrossEntropyLossModule_basic",
498498
"CrossEntropyLossNoReductionModule_basic",
499499
"IsInfiniteModule_basic",
500-
"InterpolateDynamicModule_sizes_nearest",
501500
"IouOfModule_basic",
502501
"MeshgridIndexingIJ_basic",
503502
"MeshgridIndexingXY_basic",
@@ -915,8 +914,12 @@
915914
"TraceUnsignedIntModule_empty",
916915
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
917916
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
917+
"UpSampleNearest1dVecNoneScales_basic",
918+
"UpSampleNearest1dVecNoneShape_basic",
918919
"UpSampleNearest2dBackwardScalesNone_basic",
919920
"UpSampleNearest2dBackward_basic",
921+
"UpSampleNearest2dVecNoneScales_basic",
922+
"UpSampleNearest2dVecNoneShape_basic",
920923
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
921924
"ViewSizeFromOtherTensor_basic",
922925
# Error: `aten.as_strided` op is not supported
@@ -3956,8 +3959,13 @@
39563959
"TransposedConv2dNegativePadding_basic",
39573960
"TransposedConv3dNegativePadding_basic",
39583961
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
3962+
"InterpolateDynamicModule_sizes_nearest",
3963+
"UpSampleNearest1dVecNoneScales_basic",
3964+
"UpSampleNearest1dVecNoneShape_basic",
39593965
"UpSampleNearest2dBackwardScalesNone_basic",
39603966
"UpSampleNearest2dBackward_basic",
3967+
"UpSampleNearest2dVecNoneScales_basic",
3968+
"UpSampleNearest2dVecNoneShape_basic",
39613969
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
39623970
"ViewSizeFromOtherTensor_basic",
39633971
"VisionTransformerModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,94 @@ def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils):
10881088
module.forward(tu.rand(2, 3, 4, 4))
10891089

10901090

1091+
class UpSampleNearest2dVecNoneShape(torch.nn.Module):
1092+
def __init__(self):
1093+
super().__init__()
1094+
1095+
@export
1096+
@annotate_args(
1097+
[
1098+
None,
1099+
([-1, -1, -1, -1], torch.float64, True),
1100+
]
1101+
)
1102+
def forward(self, input):
1103+
return torch.ops.aten.upsample_nearest2d.vec(
1104+
input, output_size=None, scale_factors=[3.66, 4.2]
1105+
)
1106+
1107+
1108+
@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneShape())
1109+
def UpSampleNearest2dVecNoneShape_basic(module, tu: TestUtils):
1110+
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
1111+
1112+
1113+
class UpSampleNearest2dVecNoneScales(torch.nn.Module):
1114+
def __init__(self):
1115+
super().__init__()
1116+
1117+
@export
1118+
@annotate_args(
1119+
[
1120+
None,
1121+
([-1, -1, -1, -1], torch.float64, True),
1122+
]
1123+
)
1124+
def forward(self, input):
1125+
return torch.ops.aten.upsample_nearest2d.vec(
1126+
input,
1127+
output_size=[18, 48],
1128+
scale_factors=None,
1129+
)
1130+
1131+
1132+
@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneScales())
1133+
def UpSampleNearest2dVecNoneScales_basic(module, tu: TestUtils):
1134+
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
1135+
1136+
1137+
class UpSampleNearest1dVecNoneShape(torch.nn.Module):
1138+
def __init__(self):
1139+
super().__init__()
1140+
1141+
@export
1142+
@annotate_args(
1143+
[
1144+
None,
1145+
([-1, -1, -1], torch.float64, True),
1146+
]
1147+
)
1148+
def forward(self, input):
1149+
return torch.ops.aten.upsample_nearest1d.vec(
1150+
input, output_size=None, scale_factors=[3.0]
1151+
)
1152+
1153+
1154+
@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneShape())
1155+
def UpSampleNearest1dVecNoneShape_basic(module, tu: TestUtils):
1156+
module.forward(tu.rand(1, 1, 6).to(torch.float64))
1157+
1158+
1159+
class UpSampleNearest1dVecNoneScales(torch.nn.Module):
1160+
def __init__(self):
1161+
super().__init__()
1162+
1163+
@export
1164+
@annotate_args(
1165+
[
1166+
None,
1167+
([-1, -1, -1], torch.float64, True),
1168+
]
1169+
)
1170+
def forward(self, input):
1171+
return torch.ops.aten.upsample_nearest1d.vec(input, [18], None)
1172+
1173+
1174+
@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneScales())
1175+
def UpSampleNearest1dVecNoneScales_basic(module, tu: TestUtils):
1176+
module.forward(tu.rand(1, 1, 6).to(torch.float64))
1177+
1178+
10911179
class Conv1dModule(torch.nn.Module):
10921180
def __init__(self):
10931181
super().__init__()

0 commit comments

Comments
 (0)