@@ -165,6 +165,60 @@ TEST(Converters, ATenSelectEmptyTensorConvertsCorrectly) {
165165 ASSERT_TRUE (torch_tensorrt::tests::util::sameShape (jit_results[0 ], trt_results[0 ]));
166166}
167167
168+ TEST (Converters, ATenIndexSelectConvertsCorrectly) {
169+ const auto graph = R"IR(
170+ graph(%0 : Tensor, %index : Int (2)):
171+ %2 : int = prim::Constant[value=0]()
172+ %3 : Tensor = aten::index_select(%0, %2, %index)
173+ return (%3))IR" ;
174+ auto g = std::make_shared<torch::jit::Graph>();
175+ torch::jit::parseIR (graph, g.get ());
176+ auto in = at::randint (1 , 10 , {4 , 4 , 4 }, {at::kCUDA });
177+ auto index = at::randint (0 , 4 , {2 }, {at::kCUDA }).to (torch::kI32 );
178+
179+ auto jit_in = at::clone (in);
180+ auto jit_index = at::clone (index);
181+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_index});
182+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
183+
184+ auto trt_in = at::clone (in);
185+ auto trt_index = at::clone (index);
186+ auto trt_params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_index});
187+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, trt_params, {trt_in});
188+
189+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
190+
191+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
192+ }
193+
194+ TEST (Converters, ATenIndexSelectNegativeDimConvertsCorrectly) {
195+ const auto graph = R"IR(
196+ graph(%0 : Tensor, %index : Int (5)):
197+ %2 : int = prim::Constant[value=-1]()
198+ %3 : Tensor = aten::index_select(%0, %2, %index)
199+ return (%3))IR" ;
200+ auto g = std::make_shared<torch::jit::Graph>();
201+
202+ torch::jit::parseIR (graph, g.get ());
203+
204+ auto in = at::randint (1 , 10 , {5 , 3 , 9 }, {at::kCUDA });
205+ auto index = at::randint (0 , 9 , {5 }, {at::kCUDA }).to (torch::kI32 );
206+
207+ auto jit_in = at::clone (in);
208+ auto jit_index = at::clone (index);
209+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {jit_index});
210+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
211+
212+ auto trt_in = at::clone (in);
213+ auto trt_index = at::clone (index);
214+ auto trt_params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {trt_index});
215+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, trt_params, {trt_in});
216+
217+ auto trt = trt_results[0 ].reshape (jit_results[0 ].sizes ());
218+
219+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
220+ }
221+
168222TEST (Converters, ATenNarrowStartScalarConvertsCorrectly) {
169223 const auto graph = R"IR(
170224 graph(%x.1 : Tensor):
0 commit comments