@@ -81,6 +81,21 @@ TEST(Converters, ATenSignConvertsZerosCorrectly) {
8181 torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
8282}
8383
84+ TEST (Converters, ATenLogicalNotBoolConvertsCorrectly) {
85+ const auto graph = gen_test_graph (" logical_not" );
86+ auto g = std::make_shared<torch::jit::Graph>();
87+ torch::jit::parseIR (graph, g.get ());
88+ auto in = at::randint (0 , 2 , {7 , 3 , 1 , 5 }, {at::kCUDA }).to (torch::kBool );
89+
90+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
91+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
92+
93+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
94+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
95+
96+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
97+ }
98+
8499#define test_unary (unary, name ) \
85100 TEST (Converters, ATen##name##ConvertsCorrectly) { \
86101 const auto graph = gen_test_graph (#unary); \
@@ -122,5 +137,6 @@ test_unary(erf, Erf);
122137test_unary (asinh, Asinh);
123138test_unary (acosh, Acosh);
124139test_unary (atanh, Atanh);
140+ test_unary (logical_not, LogicalNot);
125141
126142#undef test_unary
0 commit comments