@@ -29,6 +29,29 @@ TEST(Converters, ATenMaxDimConvertsCorrectly) {
2929 torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ].reshape_as (jit_results[1 ]), 2e-6 ));
3030}
3131
32+ TEST (Converters, ATenMaxDimIntInputConvertsCorrectly) {
33+ const auto graph = R"IR(
34+ graph(%x.1 : Tensor):
35+ %2 : int = prim::Constant[value=0]()
36+ %3 : bool = prim::Constant[value=0]()
37+ %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
38+ return (%4, %5))IR" ;
39+
40+ auto g = std::make_shared<torch::jit::Graph>();
41+ torch::jit::parseIR (graph, g.get ());
42+
43+ auto in = at::randint (-5 , 5 , {5 , 5 }, {at::kCUDA });
44+
45+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
46+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
47+
48+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
49+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
50+
51+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
52+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ], 2e-6 ));
53+ }
54+
3255TEST (Converters, ATenMinDimConvertsCorrectly) {
3356 const auto graph = R"IR(
3457 graph(%x.1 : Tensor):
@@ -77,6 +100,28 @@ TEST(Converters, ATenArgMaxConvertsCorrectly) {
77100 torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
78101}
79102
103+ TEST (Converters, ATenArgMaxIntInputConvertsCorrectly) {
104+ const auto graph = R"IR(
105+ graph(%x.1 : Tensor):
106+ %2 : int = prim::Constant[value=0]()
107+ %3 : bool = prim::Constant[value=0]()
108+ %4 : Tensor = aten::argmax(%x.1, %2, %3)
109+ return (%4))IR" ;
110+
111+ auto g = std::make_shared<torch::jit::Graph>();
112+ torch::jit::parseIR (graph, g.get ());
113+
114+ auto in = at::randint (-5 , 5 , {5 , 5 }, {at::kCUDA });
115+
116+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
117+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
118+
119+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
120+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
121+
122+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
123+ }
124+
80125TEST (Converters, ATenArgMaxKeepdimConvertsCorrectly) {
81126 const auto graph = R"IR(
82127 graph(%x.1 : Tensor):
0 commit comments