@@ -1122,6 +1122,34 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
11221122 }
11231123}
11241124
1125+ TEST (Converters, ATenUnbindEvaluatedTensor) {
1126+ const auto graph = R"IR(
1127+ graph(%x.1 : Tensor):
1128+ %2 : None = prim::Constant()
1129+ %3 : int[] = aten::size(%x.1)
1130+ %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2)
1131+ %5 : int = prim::Constant[value=-1]()
1132+ %6 : Tensor[] = aten::unbind(%z.1, %5)
1133+ %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6)
1134+ return (%o1.1, %o2.1))IR" ;
1135+
1136+ auto in = at::randint (1 , 10 , {2 }, {at::kCUDA });
1137+
1138+ auto g = std::make_shared<torch::jit::Graph>();
1139+
1140+ torch::jit::parseIR (graph, g.get ());
1141+
1142+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1143+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
1144+
1145+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
1146+
1147+ for (size_t i = 0 ; i < jit_results.size (); i++) {
1148+ auto trt = trt_results[i];
1149+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i].cuda (), trt, 2e-6 ));
1150+ }
1151+ }
1152+
11251153TEST (Converters, ScatterValueConvertsCorrectly) {
11261154 const auto graph = R"IR(
11271155 graph(%data : Tensor,
0 commit comments