@@ -47,6 +47,21 @@ void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* blo
   }
 }

+// Need to check if this makes sense; it might be a root cause of some overly aggressive fallback issues
+bool checkLoopEvaluatable(torch::jit::Node* n) {
+  bool compile_to_trt = true;
+  for (auto bn : n->blocks()[0]->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
+    } else {
+      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
+    }
+  }
+  return compile_to_trt;
+}
+
 // Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
 // we use a map to indicate the reason why it's fallback to torch
 // For any node that's not explicitly fallback, we set it to run in TensorRT for now
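
For context on the helper being moved here: a minimal sketch (not part of the patch) of how it could be exercised. The walker below recurses through a block and reports, for every `prim::Loop` it finds, whether `checkLoopEvaluatable` considers the loop body fully evaluatable at conversion time — the condition the new partitioning logic uses to keep such loops on the TensorRT side. The helper's enclosing namespace is not shown in this diff, so the forward declaration here is only a placeholder.

```cpp
#include <iostream>
#include <torch/csrc/jit/ir/ir.h>

// Forward declaration of the helper added above; in the real source it lives
// inside the partitioning namespace (not shown in this hunk).
bool checkLoopEvaluatable(torch::jit::Node* n);

// Illustrative walker (not part of the patch): report which prim::Loop nodes
// the partitioner would keep on the TensorRT side as conversion-time
// evaluatable, instead of segmenting them out to Torch.
void reportEvaluatableLoops(torch::jit::Block* block) {
  for (auto n : block->nodes()) {
    if (n->kind() == torch::jit::prim::Loop) {
      std::cout << "prim::Loop "
                << (checkLoopEvaluatable(n) ? "is" : "is not")
                << " fully evaluatable at conversion time" << std::endl;
    }
    // Recurse into nested blocks (e.g. prim::If arms) so inner loops are
    // inspected as well.
    for (auto sub : n->blocks()) {
      reportEvaluatableLoops(sub);
    }
  }
}
```

In the patched flow, such loops are marked NodeExecutorDecision::kCONVERT in setExplicitFallbackNodes rather than being split into a standalone Torch block in segmentGraph.
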
@@ -59,7 +74,9 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
       continue;
     }

-    if (!conversion::OpSupported(n)) {
+    if (n->kind() == torch::jit::prim::Loop && checkLoopEvaluatable(n)) {
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
+    } else if (!conversion::OpSupported(n)) {
       // If the op is not supported by the conversion phase it should run in PyTorch
       ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
     } else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
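
To spell out the decision order this hunk establishes in setExplicitFallbackNodes, here is a self-contained condensation. The enum, struct, and function below are stand-ins written for illustration only; the three checks and the kCONVERT/kUNSUPPORTED outcomes mirror the diff, while the forced-fallback outcome name and the default case are assumptions based on the surrounding comments.

```cpp
#include <string>
#include <unordered_set>

// Stand-in types for illustration only; the real code works on PartitioningCtx,
// NodeExecutorDecision and torch::jit::Node directly.
enum class Decision { kCONVERT, kUNSUPPORTED, kFORCED_FALLBACK };

struct NodeView {
  bool is_evaluatable_loop; // n->kind() == prim::Loop && checkLoopEvaluatable(n)
  bool op_supported;        // conversion::OpSupported(n)
  std::string qual_name;    // n->kind().toQualString()
};

Decision decide(const NodeView& n, const std::unordered_set<std::string>& forced_fallback_ops) {
  // 1. New in this patch: loops whose bodies are fully evaluatable at
  //    conversion time stay on the TensorRT side.
  if (n.is_evaluatable_loop) {
    return Decision::kCONVERT;
  }
  // 2. Ops with no converter must run in PyTorch.
  if (!n.op_supported) {
    return Decision::kUNSUPPORTED;
  }
  // 3. Ops the user explicitly asked to keep in PyTorch fall back.
  if (forced_fallback_ops.count(n.qual_name)) {
    return Decision::kFORCED_FALLBACK;
  }
  // 4. Everything else is tentatively marked for TensorRT (per the comment above).
  return Decision::kCONVERT;
}
```
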
@@ -336,21 +353,6 @@ void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }

-// Need to check if this makes sense; it might be a root cause of some overly aggressive fallback issues
-bool checkLoopEvaluatable(torch::jit::Node* n) {
-  bool compile_to_trt = true;
-  for (auto bn : n->blocks()[0]->nodes()) {
-    if (bn->kind() == torch::jit::prim::Loop) {
-      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
-    } else if (bn->kind() == torch::jit::prim::If) {
-      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
-    } else {
-      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
-    }
-  }
-  return compile_to_trt;
-}
-
 void finalizeNewBlock(
     PartitionedGraph& g,
     SegmentedBlock::SegmentedBlockTarget kind,
@@ -499,20 +501,6 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
       finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
       segmented_blocks.back().do_not_merge(true);
       continue;
-    } else if (n->kind() == torch::jit::prim::Loop) {
-      if (!in_prog_pyt_blk_nodes.empty()) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
-        cur_pyt_nodes_uses.clear();
-      }
-      if (checkLoopEvaluatable(n)) {
-        in_prog_trt_blk_nodes.push_back(n);
-        cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
-      } else {
-        auto loop_node = std::vector<torch::jit::Node*>{n};
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
-        segmented_blocks.back().do_not_merge(true);
-      }
-      continue;
-    }
     }
     in_prog_pyt_blk_nodes.push_back(n);
     cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());