@@ -40,7 +40,29 @@ def check_op_counts(
4040 self .assertTrue (op_counts_match (graph_module , expected_op_counts ))
4141
4242
43- class TestFusionPasses (TestFusionPassesBase ):
43+ class TestFuseMMWithAddPass (TestFusionPassesBase ):
44+ def test_no_fuse_for_3d_bias (self ) -> None :
45+ builder = GraphBuilder ()
46+ x = builder .placeholder ("x" , torch .randn (4 , 3 , dtype = torch .float32 ))
47+ y = builder .placeholder ("y" , torch .randn (3 , 5 , dtype = torch .float32 ))
48+ z = builder .placeholder ("z" , torch .randn (1 , 4 , 5 , dtype = torch .float32 ))
49+ mm = builder .call_operator (
50+ op = exir_ops .edge .aten .mm .default ,
51+ args = (x , y ),
52+ )
53+ output = builder .call_operator (op = exir_ops .edge .aten .add .Tensor , args = (mm , z ))
54+ builder .output ([output ])
55+ original_graph = builder .get_graph_module ()
56+
57+ p = FuseMMWithAdd ()
58+ converted_graph = cast (PassResult , p (original_graph )).graph_module
59+ converted_graph .graph .eliminate_dead_code ()
60+ self .assertEqual (
61+ count_node (converted_graph , exir_ops .edge .aten .addmm .default ), 0
62+ )
63+ self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .mm .default ), 1 )
64+ self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .add .Tensor ), 1 )
65+
4466 def test_fuse_mm_with_add (self ) -> None :
4567 builder = GraphBuilder ()
4668 x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
@@ -176,6 +198,8 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
176198 self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .mm .default ), 1 )
177199 self .assertEqual (count_node (converted_graph , exir_ops .edge .aten .add .Tensor ), 3 )
178200
201+
202+ class TestFusionPasses (TestFusionPassesBase ):
179203 def test_permute_transpose_fusion (self ) -> None :
180204 builder = GraphBuilder ()
181205 x = builder .placeholder ("x" , torch .randn (3 , 1 , 3 , 1 , 4 , dtype = torch .float32 ))
0 commit comments