@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
138138 torch::jit::Block* block,
139139 CompileSpec cfg,
140140 ir::StaticParams static_params,
141- ir::CollectionTypeMap first_use_types) {
141+ ir::CollectionTypeMap first_use_types,
142+ bool expect_full_compilation = false ) {
142143 auto convert_info = cfg.convert_info ;
143144 auto partitioning_info = cfg.partitioning_info ;
144145
@@ -149,17 +150,20 @@ partitioning::GraphAndMapping BuildHybridGraph(
149150 // TODO: Combine this within partition call
150151 partitioning::populateInputIValues (&partitioning_ctx);
151152
152- partitioning::partition (&partitioning_ctx);
153+ partitioning::partition (&partitioning_ctx, expect_full_compilation );
153154
154155 for (auto & partitioned_block : partitioning_ctx.partitioned_blocks ) {
155156 partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second ;
157+ int num_torch_segments = 0 ;
158+ int num_trt_segments = 0 ;
156159
157160 for (auto & seg_block : segmented_blocks) {
158161 LOG_INFO (" Block segment:" << seg_block);
159162 std::ostringstream trt_engine_id;
160163 trt_engine_id << reinterpret_cast <const int *>(&seg_block);
161164
162165 if (seg_block.target () == partitioning::SegmentedBlock::kTensorRT ) {
166+ num_trt_segments++;
163167 auto inputs = seg_block.construct_inputs_spec ();
164168 // update the input ranges for each segments
165169 convert_info.inputs = ir::associate_specs_with_inputs (seg_block.g (), inputs, static_params);
@@ -180,8 +184,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
180184 true );
181185
182186 seg_block.update_graph (temp_g);
187+ } else {
188+ num_torch_segments++;
189+
190+ // If full compilation is expected, ensure that all operators in Torch blocks are
191+ // for collections processing
192+ if (expect_full_compilation) {
193+ for (auto torch_node : seg_block.block ()->nodes ()) {
194+ if (partitioning::CollectionNodeKinds.find (torch_node->kind ()) == partitioning::CollectionNodeKinds.end ()) {
195+ TORCHTRT_THROW_ERROR (
196+ " Full compilation specified but node "
197+ << *torch_node
198+ << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
199+ << " Try recompiling with require_full_compilation=False." );
200+ }
201+ }
202+ }
183203 }
184204 }
205+
206+ // If full compilation is expected, there must be at most 2 Torch segments
207+ // (one for preprocessing inputs, one for post-processing outputs) and exactly 1 TRT segment
208+ if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1 )) {
209+ TORCHTRT_THROW_ERROR (
210+ " Full compilation was requested but unable to convert all operations to TensorRT."
211+ << " Try recompiling with require_full_compilation=False." );
212+ }
185213 }
186214
187215 return partitioning::stitch (&partitioning_ctx, block);
@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
191219 CompileSpec& cfg,
192220 std::shared_ptr<torch::jit::Graph>& g,
193221 ir::StaticParams& static_params,
194- ir::CollectionTypeMap& first_use_type_map) {
222+ ir::CollectionTypeMap& first_use_type_map,
223+ bool requires_collection_handling = false ) {
195224 cfg.convert_info .collection_input_spec_map =
196225 std::move (ir::associate_specs_with_collection_inputs (g, cfg.graph_inputs , static_params));
197226 cfg.partitioning_info .collection_input_spec_map =
@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
226255 " Cannot infer input type from calcuations in graph for input "
227256 << in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
228257 spec[i].dtype = at::kFloat ;
229- } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info .enabled ) {
258+ } else if (spec[i].dtype_is_user_defined && ( cfg.partitioning_info .enabled || requires_collection_handling) ) {
230259 if (!est_type_opt[i]) {
231260 LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
232261 std::stringstream ss;
@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
297326 return engine;
298327}
299328
329+ bool userRequestedFallback (CompileSpec& cfg) {
330+ return cfg.lower_info .forced_fallback_modules .size () != 0 ||
331+ cfg.partitioning_info .forced_fallback_operators .size () != 0 ;
332+ }
333+
300334torch::jit::Module CompileGraph (const torch::jit::Module& mod, CompileSpec cfg) {
301335 torch::jit::Module new_mod (mod._ivalue ()->name () + " _trt" );
302336
@@ -315,8 +349,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
315349 // Infer the type of an input from the weights of the calculation
316350 auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection (g->block ());
317351
352+ // Determine if the block is convertible/has collection output, and based on the result,
353+ // whether full compilation can be expected
354+ auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
355+ auto outputIsCollection = conversion::OutputIsCollection (g->block ());
356+ auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
357+
358+ // Determine whether user specifications necessitate partitioning
359+ auto isFallbackRequested = userRequestedFallback (cfg);
360+
318361 // Extract map of IValue to DType
319- auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types);
362+ auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, requires_collection_handling );
320363
321364 // Check whether any of the input types are Long
322365 bool user_requested_long = false ;
@@ -330,20 +373,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
330373 user_requested_long &= (casts_inserted > 0 );
331374 }
332375
333- auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
334- auto outputIsCollection = conversion::OutputIsCollection (g->block ());
335- if (cfg.partitioning_info .enabled && !user_requested_long &&
336- (cfg.lower_info .forced_fallback_modules .size () == 0 &&
337- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) &&
338- !outputIsCollection) {
376+ // Partitioning is required if:
377+ // 1. User requested some modules/operators fallback
378+ // 2. The block (graph) cannot be converted due to operator coverage
379+ // 3. The output of the graph is a collection
380+ // 4. The user requested a non-TRT data type input
381+ auto isPartitioningRequired =
382+ (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
383+
384+ // The user did not require full compilation, but the model can be fully compiled
385+ if (cfg.partitioning_info .enabled && !isPartitioningRequired) {
339386 LOG_INFO (" Skipping partitioning since model is fully supported" );
340387 }
341388
342- if (cfg.partitioning_info .enabled &&
343- (!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
344- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
345- outputIsCollection || user_requested_long)) {
346- auto graph_and_mapping = BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types);
389+ // The user did not require full compilation and the model cannot be fully compiled as-is,
390+ // or the user required full compilation but the convertible graph outputs a collection
391+ if ((cfg.partitioning_info .enabled && isPartitioningRequired) || requires_collection_handling) {
392+ // If the model is fully-compilable and the user has specified full compilation, run partitioning
393+ // to generate collection-processing code in Torch
394+ auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info .enabled );
395+
396+ auto graph_and_mapping =
397+ BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types, expect_full_compilation);
347398 new_g = graph_and_mapping.first ;
348399 // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
349400 for (size_t i = 0 ; i < new_g->inputs ().size (); ++i) {
0 commit comments