@@ -115,7 +115,7 @@ void find_all_fallback_nodes(
115115 // for every node that produces this fallback node's NonTensor input, they should fallback too
116116 for (auto input : cur_node->inputs ()) {
117117 if (!isTensor (input) && input->node ()->kind () != torch::jit::prim::Constant &&
118- global_fallback_nodes.insert ({input->node (), 4 }).second ) {
118+ global_fallback_nodes.insert ({input->node (), FallbackNodeType:: kNON_TENSOR }).second ) {
119119 q.push (input->node ());
120120 }
121121 }
@@ -124,7 +124,7 @@ void find_all_fallback_nodes(
124124 if (!isTensor (output)) {
125125 for (auto use : output->uses ()) {
126126 auto node = use.user ;
127- if (node->kind () != torch::jit::prim::Constant && global_fallback_nodes.insert ({node, 4 }).second ) {
127+ if (node->kind () != torch::jit::prim::Constant && global_fallback_nodes.insert ({node, FallbackNodeType:: kNON_TENSOR }).second ) {
128128 q.push (node);
129129 }
130130 }
@@ -229,13 +229,13 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
229229
230230bool check_node_fallback (torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int >& fallback_nodes) {
231231 if (fallback_nodes.count (n)) {
232- if (fallback_nodes.at (n) == 0 ) {
232+ if (fallback_nodes.at (n) == FallbackNodeType:: kUNSUPPORTED ) {
233233 LOG_GRAPH (" Node not supported by conversion: " << util::node_info (n));
234- } else if (fallback_nodes.at (n) == 1 ) {
234+ } else if (fallback_nodes.at (n) == FallbackNodeType:: kOPERATOR_FALLBACK ) {
235235 LOG_GRAPH (" Node explicitly set to run in torch: " << util::node_info (n));
236- } else if (fallback_nodes.at (n) == 2 ) {
236+ } else if (fallback_nodes.at (n) == FallbackNodeType:: kMODULE_FALLBACK ) {
237237 LOG_GRAPH (" Node is within a module set to run in torch: " << util::node_info (n));
238- } else if (fallback_nodes.at (n) == 3 ) {
238+ } else if (fallback_nodes.at (n) == FallbackNodeType:: kMIN_BLOCK_FALLBACK ) {
239239 LOG_GRAPH (" Node fallback to Torch because of min_block_size" << util::node_info (n));
240240 } else {
241241 LOG_GRAPH (
@@ -273,18 +273,18 @@ void get_fallback_nodes(
273273
274274 // If the op is not supported by the conversion phase it should run in PyTorch
275275 if (!conversion::OpSupported (n)) {
276- fallback_nodes.insert ({n, 0 });
276+ fallback_nodes.insert ({n, FallbackNodeType:: kUNSUPPORTED });
277277 }
278278
279279 // If the user specifies the op to run in Torch it should run in PyTorch
280280 if (forced_fallback_ops.find (n->kind ().toQualString ()) != forced_fallback_ops.end ()) {
281- fallback_nodes.insert ({n, 1 });
281+ fallback_nodes.insert ({n, FallbackNodeType:: kOPERATOR_FALLBACK });
282282 }
283283
284284 // If the user specifies the module containing this op to run in torch it should run in PyTorch
285285 const auto to_compile_sym = c10::Symbol::attr (" to_compile" );
286286 if (n->hasAttribute (to_compile_sym) && n->i (to_compile_sym) == (int64_t ) false ) {
287- fallback_nodes.insert ({n, 2 });
287+ fallback_nodes.insert ({n, FallbackNodeType:: kMODULE_FALLBACK });
288288 }
289289 }
290290 return ;
@@ -329,7 +329,7 @@ void find_min_block_size_fallback_nodes(
329329 // keep fallback until all segments meet the min_block_size requirement
330330 while (!min_block_fallback_nodes.empty ()) {
331331 for (const auto i : min_block_fallback_nodes) {
332- initial_fallback_nodes.insert ({i, 3 });
332+ initial_fallback_nodes.insert ({i, FallbackNodeType:: kMIN_BLOCK_FALLBACK });
333333 }
334334 global_fallback_nodes.insert (initial_fallback_nodes.begin (), initial_fallback_nodes.end ());
335335 // find the fallback nodes because of dependency with min_block_size caused fallback nodes
0 commit comments