@@ -284,7 +284,7 @@ def _check_model(
284284 opt = onnx .reference .ReferenceEvaluator (optimized_model , new_ops = [FusedMatMul ])
285285 expected = ref .run (None , feeds )
286286 got = opt .run (None , feeds )
287- self .assertEqual (len (expected ), len (got ))
287+ self .assertEqual (len (got ), len (expected ))
288288 for a , b in zip (expected , got ):
289289 np .testing .assert_allclose (a , b , atol = atol , rtol = rtol )
290290
@@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty
319319 rule_set = fused_matmul_rule_sets .fused_matmul_rule_sets ()
320320 rule_set .apply_to_model (ir_model )
321321 rewritten_model = ir .serde .serialize_model (ir_model )
322- self .assertEqual (["Constant" , "FusedMatMul" ], [ n .op_type for n in ir_model .graph ])
322+ self .assertEqual ([n .op_type for n in ir_model .graph ], [ "Constant" , "FusedMatMul" ])
323323 self ._check_model (model_proto , rewritten_model , atol = 1e-6 )
324324
325325 @parameterized .parameterized .expand (
@@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func):
354354 ir_model = ir .serde .deserialize_model (model_proto )
355355 self ._apply_fusion_rules (ir_model )
356356 rewritten_model = ir .serde .serialize_model (ir_model )
357- self .assertEqual (["FusedMatMul" ], [ n .op_type for n in ir_model .graph ])
357+ self .assertEqual ([n .op_type for n in ir_model .graph ], [ "FusedMatMul" ])
358358 self ._check_model (model_proto , rewritten_model , atol = 1e-6 )
359359
360360 @parameterized .parameterized .expand ([("should_not_match" , _should_not_match )])
@@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func):
366366 self ._apply_fusion_rules (ir_model )
367367 rewritten_model = ir .serde .serialize_model (ir_model )
368368 self .assertEqual (
369- ["Transpose" , "MatMul" , "Transpose" ],
370369 [n .op_type for n in ir_model .graph ],
370+ ["Transpose" , "MatMul" , "Transpose" ],
371371 )
372372 self ._check_model (model_proto , rewritten_model , atol = 1e-6 )
373373
@@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func):
391391 common_passes .ShapeInferencePass ()(ir_model )
392392 self ._apply_fusion_rules (ir_model )
393393 rewritten_model = ir .serde .serialize_model (ir_model )
394- self .assertEqual (["Identity" , "FusedMatMul" ], [ n .op_type for n in ir_model .graph ])
394+ self .assertEqual ([n .op_type for n in ir_model .graph ], [ "Identity" , "FusedMatMul" ])
395395 self ._check_model (model_proto , rewritten_model , atol = 1e-6 )
396396
397397 @parameterized .parameterized .expand (
@@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func):
440440 ir_model = ir .serde .deserialize_model (model_proto )
441441 self ._apply_fusion_rules (ir_model )
442442 rewritten_model = ir .serde .serialize_model (ir_model )
443- self .assertEqual (["FusedMatMul" ], [ n .op_type for n in ir_model .graph ])
443+ self .assertEqual ([n .op_type for n in ir_model .graph ], [ "FusedMatMul" ])
444444 self ._check_model (model_proto , rewritten_model , atol = 1e-6 )
445445
446446
0 commit comments