@@ -135,6 +135,36 @@ TEST(Converters, ATenBoolToINT32TensorConvertsCorrectly) {
135135 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
136136}
137137
138+
139+ TEST (Converters, ATenToSingleConvertsCorrectly) {
140+ const auto graph = R"IR(
141+ graph(%y.1 : Tensor):
142+ %4 : int = prim::Constant[value=6]()
143+ %5 : bool = prim::Constant[value=0]()
144+ %6 : None = prim::Constant()
145+ %y0.1 : Tensor = aten::to(%y.1, %4, %5, %5, %6)
146+ return (%y0.1))IR" ;
147+
148+ auto g = std::make_shared<torch::jit::Graph>();
149+
150+ torch::jit::parseIR (graph, &*g);
151+
152+ auto in = at::randint (1 , 10 , {3 }, {at::kCUDA });
153+
154+ auto jit_in = at::clone (in);
155+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
156+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
157+
158+ auto trt_in = at::clone (in);
159+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
160+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
161+
162+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
163+ ASSERT_TRUE (jit_results[0 ].scalar_type () == trt.scalar_type ());
164+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
165+ }
166+
167+
138168TEST (Converters, ATenTypeAsConvertsCorrectly) {
139169 const auto graph = R"IR(
140170 graph(%0 : Tensor,
0 commit comments