 
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
 #include "torch/csrc/jit/ir/ir.h"
-#include "torch/csrc/jit/ir/ir_views.h"
 #include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/loop_unrolling.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,179 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
   return conversion::VerifyConverterSupportForBlock(g->block());
 }
 
-void AddSegmentedBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& g,
-    partitioning::SegmentedBlock& seg,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  // old_to_new_g contains: original global graph value => new global graph value,
-  // mini_to_new_g: mini graph value -> new graph value
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
-  size_t input_idx = 0;
-  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
-    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-      auto self = g->insertInput(0, "self_1");
-      self->setType(seg.inputs()[0]->type());
-    }
-    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
-  }
-
-  for (auto& raw_input : seg.raw_inputs()) {
-    if (old_to_new_g.count(raw_input)) {
-      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
-    }
-  }
-
-  for (const auto n : seg.nodes()) {
-    util::cloneNode(n, g, mini_to_new_g);
-  }
-
-  // original graph value => new global graph value
-  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
-    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
-  }
-  size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
-  for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
-    if (!old_to_new_g.count(seg.raw_inputs()[i])) {
-      old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
-    }
-  }
-
-  return;
-}
-
-typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
-    GraphAndMapping;
-
-void AddIfBlockToGraph(
-    std::shared_ptr<torch::jit::Graph>& new_g,
-    torch::jit::Node* if_node,
-    const std::vector<GraphAndMapping>& graph_and_mappings,
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-  torch::jit::IfView if_view(if_node);
-
-  // create a new if node in new_g and add corresponding inputs
-  auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-  new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-  // iterate over all blocks and add them to new created prim::If
-  for (auto graph_and_mapping : graph_and_mappings) {
-    auto new_if_block = new_if->addBlock();
-    auto cur_block_graph = graph_and_mapping.first;
-    auto cur_block_mapping = graph_and_mapping.second;
-    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
-      // it's mini graph's input
-      if (old_to_new_g.count(i.first)) {
-        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-      }
-    }
-
-    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-    new_if_block->cloneFrom(cur_block_graph->block(), env);
-    if (cur_block_graph->inputs().size() &&
-        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-      if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-        auto self = new_g->insertInput(0, "self_1");
-        self->setType(cur_block_graph->inputs()[0]->type());
-      }
-      block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-    }
-    for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-      new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-      new_if_block->eraseInput(i);
-    }
-  }
-  for (auto ov : if_view.outputs()) {
-    auto no = new_if->addOutput();
-    old_to_new_g[ov] = no;
-    no->copyMetadata(ov);
-  }
-  return;
-}
-
-GraphAndMapping ConstructFallbackGraph(
+partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::script::Module& new_mod,
     torch::jit::Block* block,
-    std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
-  auto convert_cfg = cfg.convert_info;
-  auto partition_info = cfg.partition_info;
-
-  auto new_g = std::make_shared<torch::jit::Graph>();
-
-  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
-
-  // the mapping from lowering graph => fallback global graph
-  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-  for (auto input : block->inputs()) {
-    util::getOrAddInputForValue(input, new_g, old_to_new_g);
-  }
-
-  for (auto& seg_block : segmented_blocks) {
-    LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
-    std::ostringstream trt_engine_id;
-    trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-
-    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-      auto shapes = seg_block.in_shapes();
-      auto types = seg_block.in_types();
-      std::vector<ir::Input> inputs;
-      for (size_t i = 0; i < shapes.size(); i++) {
-        auto in = ir::Input(shapes[i]);
-        in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-        inputs.push_back(in);
-      }
-      // update the input ranges for each segments
-      convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
-
-      // TODO mapping Inputs Ivalue to flatten one here
-      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
-      auto temp_g = std::make_shared<torch::jit::Graph>();
-      auto device_spec = convert_cfg.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-      seg_block.update_graph(temp_g);
-      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-    } else {
-      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-        auto if_node = seg_block.raw_nodes()[0];
-
-        // convert the 2 blocks in prim::if and get the converted graph with mappings
-        std::vector<GraphAndMapping> graph_and_mappings;
-        for (auto cur_block : if_node->blocks()) {
-          graph_and_mappings.push_back(
-              ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
+    ir::CollectionTypeMap first_use_types) {
+  auto convert_info = cfg.convert_info;
+  auto partitioning_info = cfg.partitioning_info;
+
+  auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
+  auto collection_input_ivalues_map =
+      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
+
+  partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
+
+  for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
+    partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+
+    for (auto& seg_block : segmented_blocks) {
+      LOG_INFO("Block segment:" << seg_block);
+      std::ostringstream trt_engine_id;
+      trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        auto shapes = seg_block.in_shapes();
+        auto types = seg_block.in_types();
+        std::vector<ir::Input> inputs;
+        for (size_t i = 0; i < shapes.size(); i++) {
+          auto in = ir::Input(shapes[i]);
+          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+          inputs.push_back(in);
         }
-        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+        // update the input ranges for each segments
+        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
-      } else {
-        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-      }
-    }
-  }
+        // TODO mapping Inputs Ivalue to flatten one here
+        auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
+        auto temp_g = std::make_shared<torch::jit::Graph>();
+        auto device_spec = convert_info.engine_settings.device;
+        auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+        AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
 
-  if (block->outputs().size() > 1) {
-    std::vector<torch::jit::Value*> fallback_graph_vector;
-    for (auto& output : block->outputs()) {
-      if (old_to_new_g.count(output)) {
-        fallback_graph_vector.push_back(old_to_new_g[output]);
+        seg_block.update_graph(temp_g);
       }
     }
-    torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
-    auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
-    new_g->block()->appendNode(return_tuple_node);
-    // Set the output as the produced tuple
-    new_g->registerOutput(return_tuple_node->outputs()[0]);
-  } else {
-    if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
-      new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
-    }
   }
-  return {new_g, old_to_new_g};
+
+  return partitioning::stitch(&partitioning_ctx, block);
 }
 
 void MapInputsAndDetermineDTypes(
@@ -310,6 +184,8 @@ void MapInputsAndDetermineDTypes(
     ir::CollectionTypeMap& first_use_type_map) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+  cfg.partitioning_info.collection_input_spec_map =
+      ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
@@ -339,7 +215,7 @@ void MapInputsAndDetermineDTypes(
339215 " Cannot infer input type from calcuations in graph for input "
340216 << in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
341217 spec[i].dtype = nvinfer1::DataType::kFLOAT ;
342- } else if (spec[i].dtype_is_user_defined && cfg.partition_info .enabled ) {
218+ } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info .enabled ) {
343219 if (!est_type_opt[i]) {
344220 LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
345221 std::stringstream ss;
@@ -424,22 +300,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partition_info.enabled &&
+  if (cfg.partitioning_info.enabled &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
+       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
       !outputIsCollection) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partition_info.enabled &&
+  if (cfg.partitioning_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
        outputIsCollection)) {
-    std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto collection_input_ivalues_map =
-        partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
-    auto graph_and_mapping = ConstructFallbackGraph(
-        new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
    for (size_t i = 0; i < new_g->inputs().size(); ++i) {
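
Note: below is a minimal sketch of how the refactored flow is driven, based only on the calls visible in this diff. The wrapper name CompileHybridModule and the include paths are hypothetical; the real call site is CompileGraph, shown in the last hunk above.

// Hypothetical illustration (not part of the diff): after this change, partitioning
// state lives in partitioning::PartitioningCtx and stitching is handled inside
// partitioning::stitch(), so a caller only needs the single BuildHybridGraph() call.
#include "core/compiler.h" // assumed: declares CompileSpec and BuildHybridGraph
#include "core/ir/ir.h"    // assumed: declares ir::StaticParams, ir::CollectionTypeMap

std::shared_ptr<torch::jit::Graph> CompileHybridModule(
    torch::jit::script::Module& mod,
    std::shared_ptr<torch::jit::Graph>& g,
    CompileSpec cfg,
    ir::StaticParams static_params,
    ir::CollectionTypeMap first_use_types) {
  // BuildHybridGraph partitions the block, converts the TensorRT-eligible segments
  // to engines, and returns the stitched graph plus the old->new value mapping.
  auto graph_and_mapping = BuildHybridGraph(mod, g->block(), cfg, static_params, first_use_types);
  return graph_and_mapping.first;
}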