@@ -137,3 +137,33 @@ TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
137137 ASSERT_TRUE (
138138 torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
139139}
140+
141+ TEST (Converters, ATenBatchNormHalfConvertsCorrectly) {
142+ const auto graph = R"IR(
143+ graph(%input : Tensor, %running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0), %running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0)):
144+ %5 : bool = prim::Constant[value=0]()
145+ %4 : float = prim::Constant[value=0.01]()
146+ %3 : float = prim::Constant[value=0.001]()
147+ %2 : bool = prim::Constant[value=1]()
148+ %8 : Tensor = aten::batch_norm(%input, %running_var, %running_mean, %running_mean, %running_var, %5, %4, %3, %2)
149+ return (%8))IR" ;
150+
151+ auto g = std::make_shared<torch::jit::Graph>();
152+ torch::jit::parseIR (graph, &*g);
153+
154+ auto in = at::randn ({2 , 32 , 5 , 5 }, {at::kCUDA }).to (at::kHalf );
155+ auto mean = at::ones ({32 }, {at::kCUDA }).to (at::kHalf );
156+ auto var = at::zeros ({32 }, {at::kCUDA }).to (at::kHalf );
157+
158+ auto trt_in = at::clone (in);
159+ auto trt_mean = at::clone (mean);
160+ auto trt_var = at::clone (var);
161+
162+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {mean, var});
163+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
164+
165+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_mean, trt_var});
166+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in}, {nvinfer1::DataType::kHALF });
167+
168+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-2 ));
169+ }
0 commit comments