3131import numpy as np
3232import onnx
3333import onnx .parser as oprs
34+ from onnx .helper import make_opsetid
3435
3536import qonnx .core .onnx_exec as oxe
3637from qonnx .core .datatype import DataType
@@ -328,7 +329,7 @@ def create_conv_upsample():
328329 return model
329330
330331
331- def create_resize ():
332+ def create_resize (opset ):
332333 """
333334 Creates an model for testing the 3D to 4D transform of the resize node.
334335 """
@@ -380,16 +381,24 @@ def create_resize():
380381 value_info = list_of_value_infos ,
381382 )
382383
383- onnx_model = qonnx_make_model (graph , producer_name = "4d_conversion_resize_test-model" )
384+ onnx_model = qonnx_make_model (
385+ graph , producer_name = "4d_conversion_resize_test-model" , opset_imports = [make_opsetid ("" , opset )]
386+ )
384387 model = ModelWrapper (onnx_model )
385- model = model . transform ( InferShapes ())
388+
386389 model .set_initializer ("sizes_resize1" , np .array ([1 , 32 , 8 ], dtype = np .int64 ))
387- model .set_initializer ("scales_resize1" , np .array ([], dtype = np .float32 ))
390+ if opset == 11 :
391+ model .set_initializer ("scales_resize1" , np .array ([], dtype = np .float32 ))
392+ elif opset == 13 :
393+ model .graph .node [0 ].input [2 ] = ""
394+ else :
395+ assert False , f"Undefined opset { opset } for Resize testcase creator"
388396 model .set_initializer ("scales_resize2" , np .array ([1.0 , 1.0 , 2.0 ], dtype = np .float32 ))
397+ model = model .transform (InferShapes ())
389398 return model
390399
391400
392- @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" , "Resize " ])
401+ @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" , "Resize11" , "Resize13 " ])
393402def test_4d_conversion (test_model ):
394403 """
395404 Test for the 3D to 4D transformation with a valid graph.
@@ -401,8 +410,8 @@ def test_4d_conversion(test_model):
401410 model = create_arbitrary_model_vgg ()
402411 elif test_model == "ConvUpsample" :
403412 model = create_conv_upsample ()
404- elif test_model == "Resize" :
405- model = create_resize ()
413+ elif "Resize" in test_model :
414+ model = create_resize (opset = int ( test_model . replace ( "Resize" , "" )) )
406415 else :
407416 raise Exception ("Unknown test_model in test_4d_conversion" )
408417
0 commit comments