@@ -67,3 +67,72 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
6767
6868 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
6969}
70+
71+ TEST (Converters, ATenBADDBMMConvertsCorrectly) {
72+ const auto graph = R"IR(
73+ graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
74+ %a : float = prim::Constant[value=1.5]()
75+ %b : float = prim::Constant[value=.2]()
76+ %2 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
77+ return (%2))IR" ;
78+
79+ auto g = std::make_shared<torch::jit::Graph>();
80+ torch::jit::parseIR (graph, g.get ());
81+
82+ auto self = at::randn ({10 , 3 , 5 }, {at::kCUDA });
83+ auto bat1 = at::randn ({10 , 3 , 4 }, {at::kCUDA });
84+ auto bat2 = at::randn ({10 , 4 , 5 }, {at::kCUDA });
85+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
86+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {self, bat1, bat2});
87+
88+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
89+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {self, bat1, bat2});
90+
91+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
92+ }
93+
94+ TEST (Converters, ATenBADDBMMAlphaBetaDisabledConvertsCorrectly) {
95+ const auto graph = R"IR(
96+ graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
97+ %a : float = prim::Constant[value=1]()
98+ %b : float = prim::Constant[value=0]()
99+ %2 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
100+ return (%2))IR" ;
101+
102+ auto g = std::make_shared<torch::jit::Graph>();
103+ torch::jit::parseIR (graph, g.get ());
104+
105+ auto self = at::randn ({10 , 3 , 5 }, {at::kCUDA });
106+ auto bat1 = at::randn ({10 , 3 , 4 }, {at::kCUDA });
107+ auto bat2 = at::randn ({10 , 4 , 5 }, {at::kCUDA });
108+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
109+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {self, bat1, bat2});
110+
111+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
112+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {self, bat1, bat2});
113+
114+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
115+ }
116+
117+ TEST (Converters, ATenBADDBMMScalarDefaultsConvertsCorrectly) {
118+ const auto graph = R"IR(
119+ graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
120+ %a : float = prim::Constant[value=1]()
121+ %b : float = prim::Constant[value=1]()
122+ %2 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
123+ return (%2))IR" ;
124+
125+ auto g = std::make_shared<torch::jit::Graph>();
126+ torch::jit::parseIR (graph, g.get ());
127+
128+ auto self = at::randn ({10 , 3 , 5 }, {at::kCUDA });
129+ auto bat1 = at::randn ({10 , 3 , 4 }, {at::kCUDA });
130+ auto bat2 = at::randn ({10 , 4 , 5 }, {at::kCUDA });
131+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
132+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {self, bat1, bat2});
133+
134+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
135+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {self, bat1, bat2});
136+
137+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
138+ }
0 commit comments