@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }

-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes(
         LOG_INFO(
             "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
             << in->debugName() << " has type " << est_type_opt[i].value());
-        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+        spec[i].dtype = est_type_opt[i].value();
       } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+        spec[i].dtype = at::kFloat;
       } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes(
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
           // Overwrite type map with user settings
-          first_use_type_map[in][i] = {
-              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-              est_type_opt[i].value()) {
-            std::stringstream ss;
-            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting "
-               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-            ss << "compatibility with PyTorch's data type convention is required.\n";
-            ss << "If you do indeed see errors at runtime either:\n";
-            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-            ss << "- Disable partial compilation by setting require_full_compilation to True";
-            auto warn_str = ss.str();
-            LOG_WARNING(warn_str);
-            // Overwrite type map with user settings
-            first_use_type_map[in][i] = {
-                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-          }
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
+
+        } else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt[i].value() << std::endl;
+          ss << "The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
         }
       } else {
         // The user defined the type so no changes are necessary
       }
+
+      // Insert entry for Value pointer and determined ScalarType
+      inferred_dtypes.insert({in, {spec[i].dtype}});
     }
   }
-  // }
+  return inferred_dtypes;
 }

 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);

+  // Ensure none of the specified types are of acceptable input types incompatible with TRT
+  // Currently, only at::kLong is an acceptable, though TRT-incompatible type
+  for (auto value_to_dtypes : first_use_types) {
+    for (auto dtype : value_to_dtypes.second) {
+      TORCHTRT_CHECK(
+          !dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
+    }
+  }
+
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);

   return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

-  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  // Extract map of IValue to DType
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  // Check whether any of the input types are Long
+  bool user_requested_long = false;
+  for (auto dtype : type_map) {
+    user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
+  }
+
+  // Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
+  if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
+    auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+    user_requested_long &= (casts_inserted > 0);
+  }
+
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled &&
+  if (cfg.partitioning_info.enabled && !user_requested_long &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
       !outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   if (cfg.partitioning_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-        outputIsCollection)) {
+        outputIsCollection || user_requested_long)) {
     auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
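
Note on the control flow added in CompileGraph above: the per-input dtype map returned by MapInputsAndDetermineDTypes is scanned for at::kLong entries, and a surviving Long input steers compilation away from full TRT conversion and into BuildHybridGraph. The standalone sketch below restates that gating decision using plain standard-library types; ForceHybridGraph, the integer map key, and the hard-coded casts_inserted value are illustrative stand-ins for this sketch, not Torch-TensorRT API.

```cpp
#include <optional>
#include <unordered_map>

// Stand-ins for at::ScalarType and ir::TypeMap (the real map is keyed on torch::jit::Value*).
enum class ScalarType { Float, Int, Long };
using TypeMap = std::unordered_map<int, std::optional<ScalarType>>;

// Mirrors the gating logic above: a Long input forces the hybrid (partially
// compiled) path unless the lowering pass ends up inserting no casts for it.
bool ForceHybridGraph(const TypeMap& type_map, bool partitioning_enabled, bool truncate_long_and_double) {
  bool user_requested_long = false;
  for (const auto& kv : type_map) {
    user_requested_long |= kv.second && (*kv.second == ScalarType::Long);
  }

  if (partitioning_enabled && truncate_long_and_double && user_requested_long) {
    // In the compiler this step is lowering::AutocastLongInputs(...), which returns
    // the number of casts it inserted for Long inputs; 1 is a placeholder here.
    int casts_inserted = 1;
    user_requested_long &= (casts_inserted > 0);
  }
  return user_requested_long; // true -> BuildHybridGraph path; false -> eligible for full TRT conversion
}
```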