@@ -83,6 +83,71 @@ TEST(Evaluators, FullEvaluatesCorrectly) {
   ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
 }
 
+TEST(Evaluators, FullLikeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %9 : None = prim::Constant()
+        %13 : float = prim::Constant[value=1.3]()
+        %14 : int = prim::Constant[value=4]()
+        %35 : Device = prim::Constant[value="cuda:0"]()
+        %19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
+        return (%19))IR";
+
+  auto in = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+  ASSERT_TRUE(jit_results[0].toTensor().dtype() == trt_results[0].toTensor().dtype());
+}
+
+TEST(Evaluators, FullLikeNewDtypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %9 : None = prim::Constant()
+        %13 : Scalar = prim::Constant[value=1]()
+        %14 : int = prim::Constant[value=11]()
+        %35 : Device = prim::Constant[value="cuda:0"]()
+        %19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
+        return (%19))IR";
+
+  auto in = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+  ASSERT_TRUE(jit_results[0].toTensor().dtype() == trt_results[0].toTensor().dtype());
+}
+
+TEST(Evaluators, FullLikeOldDtypeEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %9 : None = prim::Constant()
+        %13 : Scalar = prim::Constant[value=1.5]()
+        %35 : Device = prim::Constant[value="cuda:0"]()
+        %19 : Tensor = aten::full_like(%x.1, %13, %9, %9, %35, %9, %9)
+        return (%19))IR";
+
+  auto in = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA}).to(torch::kInt32);
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {in});
+
+  ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
+  ASSERT_TRUE(jit_results[0].toTensor().dtype() == trt_results[0].toTensor().dtype());
+}
+
 TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):