@@ -48,18 +48,26 @@ int AutocastLongInputs(
4848 auto dtype = dtype_input->second .value ();
4949 // Currently, we do not autocast inputs for which the determined type is not long
5050 if (dtype != at::kLong ) {
51+ LOG_DEBUG (
52+ " Skipping autocast for tensor " << input->debugName () << " , since its dtype is " << dtype
53+ << " and not at::kLong" );
5154 continue ;
5255 }
5356
5457 LOG_DEBUG (" Inserting aten::to casting " << input->debugName () << " to dtype " << dtype);
5558
5659 // Generate cast node sending input tensors to the inferred or specified datatype (long)
60+ torch::jit::Value *const_false, *cuda, *none_val;
61+ if (num_autocasts == 0 ) {
62+ // Only generate constants once and reuse for all autocasts
63+ const_false = g->insertConstant (0 );
64+ const_false->setType (torch::jit::BoolType::get ());
65+ cuda = g->insertConstant (target_device_name);
66+ cuda->setType (torch::jit::DeviceObjType::get ());
67+ none_val = g->insertNode (g->createNone ())->output ();
68+ }
69+
5770 auto const_type = g->insertConstant (dtype);
58- auto const_false = g->insertConstant (0 );
59- const_false->setType (torch::jit::BoolType::get ());
60- auto cuda = g->insertConstant (target_device_name);
61- cuda->setType (torch::jit::DeviceObjType::get ());
62- auto none_val = g->insertNode (g->createNone ())->output ();
6371 auto cast_node = g->create (torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
6472
6573 // Replace all uses of the original tensor with that of the casted tensor
@@ -73,12 +81,16 @@ int AutocastLongInputs(
7381 }
7482 }
7583
76- LOG_WARNING (
77- " Input tensors to this Torch-TRT engine may have their data types in-place modified "
78- << " if the type does not match the determined required type for TRT. To disable this "
79- << " automatic casting, specify an Input dtype other than Long" );
84+ LOG_GRAPH (" Inserted " << num_autocasts << " autocasts" );
8085
81- LOG_GRAPH (" Graph after Autocast: " << *g);
86+ if (num_autocasts > 0 ) {
87+ LOG_WARNING (
88+ " Data types for input tensors have been modified by inserting "
89+ << " aten::to operations which cast INT64 inputs to INT32. "
90+ << " To disable this, please recompile using INT32 inputs" );
91+
92+ LOG_GRAPH (" Graph after Autocast: " << *g);
93+ }
8294
8395 return num_autocasts;
8496}
0 commit comments