3131import numpy as np
3232import onnx
3333import onnx .parser as oprs
34+ from onnx .helper import make_opsetid
3435
3536import qonnx .core .onnx_exec as oxe
3637from qonnx .core .datatype import DataType
@@ -328,7 +329,76 @@ def create_conv_upsample():
328329 return model
329330
330331
331- @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" ])
332+ def create_resize (opset ):
333+ """
334+ Creates an model for testing the 3D to 4D transform of the resize node.
335+ """
336+ resize_node1 = onnx .helper .make_node (
337+ "Resize" ,
338+ inputs = ["in_resize1" , "roi_resize1" , "scales_resize1" , "sizes_resize1" ],
339+ outputs = ["out_resize1" ],
340+ name = "Resize1" ,
341+ mode = "nearest" ,
342+ )
343+
344+ resize_node2 = onnx .helper .make_node (
345+ "Resize" ,
346+ inputs = ["out_resize1" , "roi_resize2" , "scales_resize2" ],
347+ outputs = ["out_resize2" ],
348+ name = "Resize2" ,
349+ mode = "nearest" ,
350+ )
351+
352+ in_resize1 = onnx .helper .make_tensor_value_info ("in_resize1" , onnx .TensorProto .FLOAT , [1 , 32 , 4 ])
353+ out_resize1 = onnx .helper .make_tensor_value_info ("out_resize1" , onnx .TensorProto .FLOAT , [1 , 32 , 8 ])
354+ out_resize2 = onnx .helper .make_tensor_value_info ("out_resize2" , onnx .TensorProto .FLOAT , [1 , 32 , 16 ])
355+
356+ roi_resize1 = onnx .helper .make_tensor_value_info ("roi_resize1" , onnx .TensorProto .FLOAT , [4 ])
357+ scales_resize1 = onnx .helper .make_tensor_value_info ("scales_resize1" , onnx .TensorProto .FLOAT , [])
358+ sizes_resize1 = onnx .helper .make_tensor_value_info ("sizes_resize1" , onnx .TensorProto .INT64 , [3 ])
359+
360+ roi_resize2 = onnx .helper .make_tensor_value_info ("roi_resize2" , onnx .TensorProto .FLOAT , [4 ])
361+ scales_resize2 = onnx .helper .make_tensor_value_info ("scales_resize2" , onnx .TensorProto .FLOAT , [3 ])
362+
363+ list_of_nodes = [
364+ resize_node1 ,
365+ resize_node2 ,
366+ ]
367+ list_of_value_infos = [
368+ out_resize1 ,
369+ roi_resize1 ,
370+ sizes_resize1 ,
371+ scales_resize1 ,
372+ roi_resize2 ,
373+ scales_resize2 ,
374+ ]
375+
376+ graph = onnx .helper .make_graph (
377+ nodes = list_of_nodes ,
378+ name = "4d_conversion_resize_test_graph" ,
379+ inputs = [in_resize1 ],
380+ outputs = [out_resize2 ],
381+ value_info = list_of_value_infos ,
382+ )
383+
384+ onnx_model = qonnx_make_model (
385+ graph , producer_name = "4d_conversion_resize_test-model" , opset_imports = [make_opsetid ("" , opset )]
386+ )
387+ model = ModelWrapper (onnx_model )
388+
389+ model .set_initializer ("sizes_resize1" , np .array ([1 , 32 , 8 ], dtype = np .int64 ))
390+ if opset == 11 :
391+ model .set_initializer ("scales_resize1" , np .array ([], dtype = np .float32 ))
392+ elif opset == 13 :
393+ model .graph .node [0 ].input [2 ] = ""
394+ else :
395+ assert False , f"Undefined opset { opset } for Resize testcase creator"
396+ model .set_initializer ("scales_resize2" , np .array ([1.0 , 1.0 , 2.0 ], dtype = np .float32 ))
397+ model = model .transform (InferShapes ())
398+ return model
399+
400+
401+ @pytest .mark .parametrize ("test_model" , ["Quartz" , "VGG" , "ConvUpsample" , "Resize11" , "Resize13" ])
332402def test_4d_conversion (test_model ):
333403 """
334404 Test for the 3D to 4D transformation with a valid graph.
@@ -340,6 +410,8 @@ def test_4d_conversion(test_model):
340410 model = create_arbitrary_model_vgg ()
341411 elif test_model == "ConvUpsample" :
342412 model = create_conv_upsample ()
413+ elif "Resize" in test_model :
414+ model = create_resize (opset = int (test_model .replace ("Resize" , "" )))
343415 else :
344416 raise Exception ("Unknown test_model in test_4d_conversion" )
345417
0 commit comments