@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
9898 return stk;
9999}
100100
101- void find_all_fallback_nodes (std::unordered_map<torch::jit::Node*, int >& fallback_nodes) {
101+ void find_all_fallback_nodes (
102+ std::unordered_map<torch::jit::Node*, int >& initial_fallback_nodes,
103+ std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
104+ // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
105+ // global_fallback_nodes are the fallback nodes that we maintain globally
102106 std::queue<torch::jit::Node*> q;
103- for (auto & node : fallback_nodes ) {
107+ for (auto & node : initial_fallback_nodes ) {
104108 q.push (node.first );
105109 }
106110
@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
111115 // for every node that produces this fallback node's NonTensor input, they should fallback too
112116 for (auto input : cur_node->inputs ()) {
113117 if (!isTensor (input) && input->node ()->kind () != torch::jit::prim::Constant &&
114- fallback_nodes .insert ({input->node (), 4 }).second ) {
118+ global_fallback_nodes .insert ({input->node (), 4 }).second ) {
115119 q.push (input->node ());
116120 }
117121 }
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
120124 if (!isTensor (output)) {
121125 for (auto use : output->uses ()) {
122126 auto node = use.user ;
123- if (node->kind () != torch::jit::prim::Constant && fallback_nodes .insert ({node, 4 }).second ) {
127+ if (node->kind () != torch::jit::prim::Constant && global_fallback_nodes .insert ({node, 4 }).second ) {
124128 q.push (node);
125129 }
126130 }
@@ -231,6 +235,8 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::ji
231235 LOG_GRAPH (" Node explicitly set to run in torch: " << util::node_info (n));
232236 } else if (fallback_nodes.at (n) == 2 ) {
233237 LOG_GRAPH (" Node is within a module set to run in torch: " << util::node_info (n));
238+ } else if (fallback_nodes.at (n) == 3 ) {
239+ LOG_GRAPH (" Node fallback to Torch because of min_block_size" << util::node_info (n));
234240 } else {
235241 LOG_GRAPH (
236242 " Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -284,22 +290,74 @@ void get_fallback_nodes(
284290 return ;
285291}
286292
293+ std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size (
294+ torch::jit::Block* block,
295+ const std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes,
296+ size_t min_block_size) {
297+ auto nodes = block->nodes ();
298+ std::vector<torch::jit::Node*> cur_trt_nodes;
299+ std::vector<torch::jit::Node*> min_block_fallback_nodes;
300+ for (const auto n : nodes) {
301+ if (n->kind () == torch::jit::prim::Constant)
302+ continue ;
303+
304+ // check if current node fallback or not
305+ if (!global_fallback_nodes.count (n)) {
306+ // if this node is not in fallback nodes, then it's in trt segments
307+ cur_trt_nodes.push_back (n);
308+ } else {
309+ if (cur_trt_nodes.size () < min_block_size) {
310+ min_block_fallback_nodes.insert (min_block_fallback_nodes.end (), cur_trt_nodes.begin (), cur_trt_nodes.end ());
311+ }
312+ cur_trt_nodes.clear ();
313+ }
314+ }
315+ if (cur_trt_nodes.size () < min_block_size) {
316+ min_block_fallback_nodes.insert (min_block_fallback_nodes.end (), cur_trt_nodes.begin (), cur_trt_nodes.end ());
317+ }
318+ return min_block_fallback_nodes;
319+ }
320+
321+ void find_min_block_size_fallback_nodes (
322+ torch::jit::Block* block,
323+ std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes,
324+ size_t min_block_size) {
325+ // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
326+ auto min_block_fallback_nodes = traverse_nodes_for_min_block_size (block, global_fallback_nodes, min_block_size);
327+ std::unordered_map<torch::jit::Node*, int > initial_fallback_nodes;
328+
329+ // keep fallback until all segments meet the min_block_size requirement
330+ while (!min_block_fallback_nodes.empty ()) {
331+ for (const auto i : min_block_fallback_nodes) {
332+ initial_fallback_nodes.insert ({i, 3 });
333+ }
334+ global_fallback_nodes.insert (initial_fallback_nodes.begin (), initial_fallback_nodes.end ());
335+ // find the fallback nodes because of dependency with min_block_size caused fallback nodes
336+ find_all_fallback_nodes (initial_fallback_nodes, global_fallback_nodes);
337+ // keep traverse the graph until there is no node fallback because of min_block_size
338+ min_block_fallback_nodes = traverse_nodes_for_min_block_size (block, global_fallback_nodes, min_block_size);
339+ }
340+ }
341+
287342PartitionedGraph segment_graph (
288343 torch::jit::Block* block,
289344 const PartitionInfo& partition_info,
290- std::unordered_map<torch::jit::Node*, int >& fallback_nodes ) {
345+ std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes ) {
291346 auto min_block_size = partition_info.min_block_size ;
292347 std::unordered_set<std::string> forced_fallback_ops (
293348 partition_info.forced_fallback_operators .begin (), partition_info.forced_fallback_operators .end ());
294349
295350 // get the initial fallback nodes (nodes that are unsupported or forced fallback)
296- get_fallback_nodes (block, forced_fallback_ops, fallback_nodes );
351+ get_fallback_nodes (block, forced_fallback_ops, global_fallback_nodes );
297352
298353 // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
299354 // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
300355 // that produces this input should also fallback
301356 // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
302- find_all_fallback_nodes (fallback_nodes);
357+ find_all_fallback_nodes (global_fallback_nodes, global_fallback_nodes);
358+
359+ // find all fallback nodes because of the min_block_size requirement
360+ find_min_block_size_fallback_nodes (block, global_fallback_nodes, min_block_size);
303361
304362 auto nodes = block->nodes ();
305363
@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
313371 continue ;
314372 }
315373
316- if (check_node_fallback (n, fallback_nodes )) {
374+ if (check_node_fallback (n, global_fallback_nodes )) {
317375 in_prog_trt_blk_nodes.push_back (n);
318376
319377 // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
379437 torch::jit::Block* block,
380438 std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
381439 const PartitionInfo& partition_info,
382- std::unordered_map<torch::jit::Node*, int >& fallback_nodes ) {
440+ std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes ) {
383441 LOG_DEBUG (partition_info);
384442 // segment lowering global graph into blocks
385443 LOG_DEBUG (" Parititioning source module into PyTorch and TensorRT sub blocks" );
386- PartitionedGraph segmented_blocks = segment_graph (block, partition_info, fallback_nodes );
444+ PartitionedGraph segmented_blocks = segment_graph (block, partition_info, global_fallback_nodes );
387445
388446 // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
389447
0 commit comments