@@ -1224,3 +1224,35 @@ TEST(Converters, WhereConvertsCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
+
+TEST(Converters, WhereConvertsMismatchedShapesCorrectly) {
+  const auto graph = R"IR(
+    graph(%condition : Tensor,
+          %x : Tensor,
+          %y : Tensor):
+      %out : Tensor = aten::where(%condition, %x, %y)
+      return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  // As per Torch behavior, the input Tensors are expected to be broadcasted
+  // along their respective dimension in the largest-rank Tensor provided
+  auto condition = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool);
+  auto x = at::randn({2, 7, 5}, {at::kCUDA});
+  auto y = at::randn({5}, {at::kCUDA});
+
+  auto jit_condition = at::clone(condition);
+  auto jit_x = at::clone(x);
+  auto jit_y = at::clone(y);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y});
+
+  auto trt_condition = at::clone(condition);
+  auto trt_x = at::clone(x);
+  auto trt_y = at::clone(y);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
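
For reference, the broadcast behavior the comment in the new test describes can be checked with a minimal standalone sketch (not part of this commit): with condition of shape {7, 5}, x of shape {2, 7, 5}, and y of shape {5}, torch::where expands all three operands to the largest-rank shape {2, 7, 5}. The CPU tensors and main() wrapper below are illustrative assumptions, not the test harness used here.

// Sketch only: demonstrates the broadcast the test relies on.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Same shapes as the test, but on CPU for a quick standalone check.
  auto condition = torch::randint(0, 2, {7, 5}).to(torch::kBool);
  auto x = torch::randn({2, 7, 5});
  auto y = torch::randn({5});

  // condition {7, 5}, x {2, 7, 5}, and y {5} all broadcast to {2, 7, 5}.
  auto out = torch::where(condition, x, y);
  std::cout << out.sizes() << std::endl;  // expected: [2, 7, 5]
  return 0;
}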