@@ -1138,3 +1138,33 @@ TEST(Converters, ScatterSrcConvertsCorrectly) {
11381138 ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
11391139 }
11401140}
1141+
1142+ TEST (Converters, WhereConvertsCorrectly) {
1143+ const auto graph = R"IR(
1144+ graph(%condition : Tensor,
1145+ %x : Tensor,
1146+ %y : Tensor):
1147+ %out : Tensor = aten::where(%condition, %x, %y)
1148+ return (%out))IR" ;
1149+
1150+ auto g = std::make_shared<torch::jit::Graph>();
1151+
1152+ torch::jit::parseIR (graph, g.get ());
1153+
1154+ auto condition = at::randint (0 , 2 , {5 , 5 }, {at::kCUDA }).to (torch::kBool );
1155+ auto x = at::randn ({5 , 5 }, {at::kCUDA });
1156+ auto y = at::randn ({5 , 5 }, {at::kCUDA });
1157+
1158+ auto jit_condition = at::clone (condition);
1159+ auto jit_x = at::clone (x);
1160+ auto jit_y = at::clone (y);
1161+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1162+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_condition, jit_x, jit_y});
1163+
1164+ auto trt_condition = at::clone (condition);
1165+ auto trt_x = at::clone (x);
1166+ auto trt_y = at::clone (y);
1167+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_condition, trt_x, trt_y});
1168+
1169+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
1170+ }
0 commit comments