Skip to content

Commit 69025f7

Browse files
authored
[version converter] Fix DFT opset 20 (#2659)
Fixes pytorch/pytorch#148687 Axis is actually the third input of DFT.
1 parent ad83914 commit 69025f7

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

onnxscript/version_converter/_version_converter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,13 @@ def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) ->
155155
@register("DFT", node_version=19, up_conversion=True)
156156
def dft_19_20(node: ir.Node, op):
157157
input = node.inputs[0]
158+
dft_length = node.inputs[1] if len(node.inputs) > 1 else None
158159
inverse = _get_int_attribute(node, "inverse", 0)
159160
onesided = _get_int_attribute(node, "onesided", 0)
160161
axis = _get_int_attribute(node, "axis", None)
161162
if axis is not None:
162163
axis_value = op.Constant(value_int=axis)
163-
return op.DFT(input, axis_value, inverse=inverse, onesided=onesided)
164+
return op.DFT(input, dft_length, axis_value, inverse=inverse, onesided=onesided)
164165
return None
165166

166167

onnxscript/version_converter/_version_converter_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_version_convert_compatible(self):
144144
self.assertEqual(model.graph.node(3).version, 20)
145145
self.assertEqual(model.graph.node(3).op_type, "DFT")
146146
self.assertEqual(model.graph.node(3).version, 20)
147-
self.assertEqual(len(model.graph.node(3).inputs), 2)
147+
self.assertEqual(len(model.graph.node(3).inputs), 3)
148148

149149
def test_version_convert_gridsample_linear(self):
150150
model = ir.from_onnx_text(
@@ -241,7 +241,7 @@ def test_version_convert_inline(self):
241241
self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear")
242242
self.assertEqual(model.graph.node(6).op_type, "DFT")
243243
self.assertEqual(model.graph.node(6).version, 20)
244-
self.assertEqual(len(model.graph.node(6).inputs), 2)
244+
self.assertEqual(len(model.graph.node(6).inputs), 3)
245245

246246

247247
class VersionConverter20to21Test(unittest.TestCase):

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,20 @@ def forward(self, x):
238238
)
239239
_testing.assert_onnx_program(onnx_program)
240240

241+
def test_dft_axis_promoted_from_attribute_to_input(self):
242+
class Model(torch.nn.Module):
243+
def forward(self, x):
244+
return torch.ops.aten._fft_r2c(x, [0], normalization=1, onesided=True) # pylint: disable=protected-access
245+
246+
onnx_program = torch.onnx.export(
247+
Model(),
248+
(torch.randn(2, 3),),
249+
opset_version=20,
250+
dynamic_shapes=({0: "dim_x"},),
251+
dynamo=True,
252+
)
253+
_testing.assert_onnx_program(onnx_program)
254+
241255
def test_avg_pool(self):
242256
class Model(torch.nn.Module):
243257
def forward(self, x2d, x3d, x4d, x5d):

0 commit comments

Comments
 (0)