@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+void find_all_fallback_nodes(
+    std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+  // global_fallback_nodes are the fallback nodes that we maintain globally
   std::queue<torch::jit::Node*> q;
-  for (auto& node : fallback_nodes) {
+  for (auto& node : initial_fallback_nodes) {
     q.push(node.first);
   }
 
@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallba
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallba
       if (!isTensor(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+          if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
             q.push(node);
           }
         }
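
For reference, FallbackNodeType is not declared in these hunks. A minimal sketch of what the enum presumably looks like, reconstructed from the enumerator names in this diff and the old integer literals they replace (the explicit values are an inference: 0/1/2/4 come from the removed lines, which also pins kMIN_BLOCK_FALLBACK, the one new reason, to 3):

// Sketch only: presumably declared in the partitioning header; exact form may differ.
enum FallbackNodeType {
  kUNSUPPORTED = 0,        // op not supported by the conversion phase
  kOPERATOR_FALLBACK = 1,  // user forced this op to run in Torch
  kMODULE_FALLBACK = 2,    // enclosing module is set to run in Torch
  kMIN_BLOCK_FALLBACK = 3, // segment too small to satisfy min_block_size
  kNON_TENSOR = 4,         // NonTensor dependency on another fallback node
};

An unscoped enum fits the existing std::unordered_map<torch::jit::Node*, int> storage, since the enumerators convert to int implicitly while still allowing the FallbackNodeType::kNON_TENSOR qualification used above.
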
@@ -225,12 +229,14 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
 
 bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == 0) {
+    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
       LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 1) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
       LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 2) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
       LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
+      LOG_GRAPH("Node falls back to Torch because of min_block_size: " << util::node_info(n));
     } else {
       LOG_GRAPH(
           "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -267,39 +273,91 @@ void get_fallback_nodes(
 
     // If the op is not supported by the conversion phase it should run in PyTorch
     if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, 0});
+      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
     }
 
     // If the user specifies the op to run in Torch it should run in PyTorch
     if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, 1});
+      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
    }
 
     // If the user specifies the module containing this op to run in torch it should run in PyTorch
     const auto to_compile_sym = c10::Symbol::attr("to_compile");
     if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, 2});
+      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
     }
   }
   return;
 }
 
+std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
+    torch::jit::Block* block,
+    const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  auto nodes = block->nodes();
+  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::vector<torch::jit::Node*> min_block_fallback_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant)
+      continue;
+
+    // check whether the current node falls back or not
+    if (!global_fallback_nodes.count(n)) {
+      // if this node is not in the fallback nodes, then it's in a TRT segment
+      cur_trt_nodes.push_back(n);
+    } else {
+      if (cur_trt_nodes.size() < min_block_size) {
+        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      }
+      cur_trt_nodes.clear();
+    }
+  }
+  if (cur_trt_nodes.size() < min_block_size) {
+    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+  }
+  return min_block_fallback_nodes;
+}
+
+void find_min_block_size_fallback_nodes(
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+  auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
+
+  // keep falling back until all segments meet the min_block_size requirement
+  while (!min_block_fallback_nodes.empty()) {
+    for (const auto i : min_block_fallback_nodes) {
+      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
+    }
+    global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
+    // find the nodes that fall back because of dependencies on the min_block_size-caused fallback nodes
+    find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
+    // keep traversing the graph until no node falls back because of min_block_size
+    min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  }
+}
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
   // get the initial fallback nodes (nodes that are unsupported or forced fallback)
-  get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+  get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);
 
   // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
   // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
   // that produces this input should also fallback
   // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
-  find_all_fallback_nodes(fallback_nodes);
+  find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
+
+  // find all fallback nodes caused by the min_block_size requirement
+  find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
 
   auto nodes = block->nodes();
 
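
The while loop in find_min_block_size_fallback_nodes above is a fixed-point iteration: demoting a too-small TRT segment can, through the NonTensor-dependency BFS in find_all_fallback_nodes, pull additional nodes out of other segments, which may in turn drop below min_block_size. A self-contained toy model of that cascade, with plain ints standing in for torch::jit::Node* and a hypothetical deps table standing in for the BFS:

#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <vector>

int main() {
  const std::size_t min_block_size = 3;
  const int num_nodes = 10;
  // Hypothetical NonTensor dependencies: if node 8 falls back, node 2 must too.
  std::map<int, std::vector<int>> deps = {{8, {2}}};
  std::set<int> fallback = {7}; // initial fallback node

  // Collect every consecutive run of TRT nodes shorter than min_block_size.
  auto too_small = [&] {
    std::vector<int> demoted, run;
    auto flush = [&] {
      if (!run.empty() && run.size() < min_block_size)
        demoted.insert(demoted.end(), run.begin(), run.end());
      run.clear();
    };
    for (int n = 0; n < num_nodes; ++n) {
      if (fallback.count(n)) {
        flush(); // a fallback node ends the current TRT run
      } else {
        run.push_back(n);
      }
    }
    flush(); // close the trailing run
    return demoted;
  };

  // Demote, propagate dependencies, and re-scan until nothing changes,
  // mirroring the while loop in find_min_block_size_fallback_nodes.
  for (auto small = too_small(); !small.empty(); small = too_small()) {
    for (int n : small) {
      fallback.insert(n);
      for (int d : deps[n]) fallback.insert(d); // stand-in for the BFS
    }
  }

  for (int n = 0; n < num_nodes; ++n)
    std::cout << n << (fallback.count(n) ? ": Torch\n" : ": TRT\n");
}

Demoting the too-small run {8, 9} drags node 2 along via deps, which splits the leading run and demotes {0, 1} on the next sweep; only {3, 4, 5, 6} survives as a TRT segment. This is why the real loop must re-traverse until traverse_nodes_for_min_block_size returns nothing.
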
@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
       continue;
     }
 
-    if (check_node_fallback(n, fallback_nodes)) {
+    if (check_node_fallback(n, global_fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);
 
   // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
 
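With those changes in place, a caller-side sketch of the updated Partition entry point (the signature, PartitionInfo fields, and map types come from the diff above; the header path is an assumption, and the names are written unqualified as they would appear inside the partitioning namespace):

// Call-site sketch under the assumptions stated above; not repository code.
#include <unordered_map>
#include "core/partitioning/partitioning.h" // assumed header path

PartitionedGraph partition_with_fallback_tracking(
    torch::jit::Block* block,
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
    const PartitionInfo& partition_info) {
  // A single shared map records, for every fallback node, the FallbackNodeType
  // reason it fell back; segment_graph and Partition now thread it through as
  // global_fallback_nodes instead of the old per-call fallback_nodes.
  std::unordered_map<torch::jit::Node*, int> global_fallback_nodes;
  return Partition(block, example_tensor_map, partition_info, global_fallback_nodes);
}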