1+ #include " core/lowering/passes/passes.h"
2+ #include " gtest/gtest.h"
3+ #include " torch/csrc/jit/ir/irparser.h"
4+
5+ TEST (LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
6+ // parseIR does not support " = prim::If(%51)" with no return value
7+ /* std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
8+ %3 : NoneType = prim::Constant()
9+ %4 : int = prim::Constant[value=0]()
10+ %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
11+ %47 : Tensor = aten::sum(%x.1, %3)
12+ %49 : Tensor = aten::sum(%y.1, %3)
13+ %50 : Tensor = aten::gt(%47, %49)
14+ %51 : bool = aten::Bool(%50)
15+ = prim::If(%51)
16+ block0():
17+ = prim::RaiseException(%45)
18+ -> ()
19+ block1():
20+ -> ()
21+ %z.1 : Tensor = aten::cat(%mod_list.1, %4)
22+ return (%z.1))IR";*/
23+
24+ auto g = std::make_shared<torch::jit::Graph>();
25+ auto x = g->insertInput (0 , " x" );
26+ auto y = g->insertInput (1 , " y" );
27+ torch::jit::IValue zero (0 );
28+ auto zero_const_val = g->insertConstant (zero);
29+ auto none_const_val = g->insertConstant (torch::jit::IValue ());
30+ torch::jit::IValue except (" EXCEPTION" );
31+ auto except_val = g->insertConstant (except);
32+ auto list_node = g->createList (x->type (), torch::jit::ArrayRef<torch::jit::Value*>(x));
33+ g->insertNode (list_node);
34+ auto sum_x_node = g->create (torch::jit::aten::sum, {x, none_const_val});
35+ g->insertNode (sum_x_node);
36+ auto sum_y_node = g->create (torch::jit::aten::sum, {y, none_const_val});
37+ g->insertNode (sum_y_node);
38+ auto gt_node = g->create (torch::jit::aten::gt, {sum_x_node->output (), sum_y_node->output ()});
39+ g->insertNode (gt_node);
40+ auto bool_node = g->create (torch::jit::aten::Bool, {gt_node->output ()});
41+ bool_node->output ()->setType (torch::jit::BoolType::get ());
42+ g->insertNode (bool_node);
43+ auto if_node = g->create (torch::jit::prim::If, {bool_node->output ()}, 0 );
44+ auto if_block0 = if_node->addBlock ();
45+ auto exception_node = g->create (torch::jit::prim::RaiseException, {except_val}, 0 );
46+ if_block0->appendNode (exception_node);
47+ auto if_block1 = if_node->addBlock ();
48+ g->insertNode (if_node);
49+ auto cat_node = g->create (torch::jit::aten::cat, {list_node->output (), zero_const_val});
50+ g->insertNode (cat_node);
51+ g->registerOutput (cat_node->output ());
52+
53+ torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern (g);
54+ for (auto node : g->nodes ()) {
55+ EXPECT_NE (node, if_node);
56+ }
57+ }
58+
59+ TEST (LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
60+ // parseIR does not support " = prim::If(%51)" with no return value
61+ /* std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
62+ %3 : NoneType = prim::Constant()
63+ %4 : int = prim::Constant[value=0]()
64+ %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
65+ %47 : Tensor = aten::sum(%x.1, %3)
66+ %49 : Tensor = aten::sum(%y.1, %3)
67+ %50 : Tensor = aten::gt(%47, %49)
68+ %51 : bool = aten::Bool(%50)
69+ = prim::If(%51)
70+ block0():
71+ -> ()
72+ block1():
73+ = prim::RaiseException(%45)
74+ -> ()
75+ %z.1 : Tensor = aten::cat(%mod_list.1, %4)
76+ return (%z.1))IR";*/
77+
78+ auto g = std::make_shared<torch::jit::Graph>();
79+ auto x = g->insertInput (0 , " x" );
80+ auto y = g->insertInput (1 , " y" );
81+ torch::jit::IValue zero (0 );
82+ auto zero_const_val = g->insertConstant (zero);
83+ auto none_const_val = g->insertConstant (torch::jit::IValue ());
84+ torch::jit::IValue except (" EXCEPTION" );
85+ auto except_val = g->insertConstant (except);
86+ auto list_node = g->createList (x->type (), torch::jit::ArrayRef<torch::jit::Value*>(x));
87+ g->insertNode (list_node);
88+ auto sum_x_node = g->create (torch::jit::aten::sum, {x, none_const_val});
89+ g->insertNode (sum_x_node);
90+ auto sum_y_node = g->create (torch::jit::aten::sum, {y, none_const_val});
91+ g->insertNode (sum_y_node);
92+ auto gt_node = g->create (torch::jit::aten::gt, {sum_x_node->output (), sum_y_node->output ()});
93+ g->insertNode (gt_node);
94+ auto bool_node = g->create (torch::jit::aten::Bool, {gt_node->output ()});
95+ bool_node->output ()->setType (torch::jit::BoolType::get ());
96+ g->insertNode (bool_node);
97+ auto if_node = g->create (torch::jit::prim::If, {bool_node->output ()}, 0 );
98+ auto if_block0 = if_node->addBlock ();
99+ auto if_block1 = if_node->addBlock ();
100+ auto exception_node = g->create (torch::jit::prim::RaiseException, {except_val}, 0 );
101+ if_block1->appendNode (exception_node);
102+ g->insertNode (if_node);
103+ auto cat_node = g->create (torch::jit::aten::cat, {list_node->output (), zero_const_val});
104+ g->insertNode (cat_node);
105+ g->registerOutput (cat_node->output ());
106+
107+ torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern (g);
108+ for (auto node : g->nodes ()) {
109+ EXPECT_NE (node, if_node);
110+ }
111+ }
112+
113+ TEST (LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
114+ // parseIR does not support " = prim::If(%51)" with no return value
115+ /* std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
116+ %3 : NoneType = prim::Constant()
117+ %4 : int = prim::Constant[value=0]()
118+ %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
119+ %47 : Tensor = aten::sum(%x.1, %3)
120+ %49 : Tensor = aten::sum(%y.1, %3)
121+ %50 : Tensor = aten::gt(%47, %49)
122+ %51 : bool = aten::Bool(%50)
123+ = prim::If(%51)
124+ block0():
125+ %10 : Tensor[] = aten::append(%mod_list.1, %y.1)
126+ -> ()
127+ block1():
128+ -> ()
129+ %z.1 : Tensor = aten::cat(%mod_list.1, %4)
130+ return (%z.1))IR";*/
131+
132+ auto g = std::make_shared<torch::jit::Graph>();
133+ auto x = g->insertInput (0 , " x" );
134+ auto y = g->insertInput (1 , " y" );
135+ torch::jit::IValue zero (0 );
136+ auto zero_const_val = g->insertConstant (zero);
137+ auto none_const_val = g->insertConstant (torch::jit::IValue ());
138+ auto list_node = g->createList (x->type (), torch::jit::ArrayRef<torch::jit::Value*>(x));
139+ g->insertNode (list_node);
140+ auto sum_x_node = g->create (torch::jit::aten::sum, {x, none_const_val});
141+ g->insertNode (sum_x_node);
142+ auto sum_y_node = g->create (torch::jit::aten::sum, {y, none_const_val});
143+ g->insertNode (sum_y_node);
144+ auto gt_node = g->create (torch::jit::aten::gt, {sum_x_node->output (), sum_y_node->output ()});
145+ g->insertNode (gt_node);
146+ auto bool_node = g->create (torch::jit::aten::Bool, {gt_node->output ()});
147+ bool_node->output ()->setType (torch::jit::BoolType::get ());
148+ g->insertNode (bool_node);
149+ auto if_node = g->create (torch::jit::prim::If, {bool_node->output ()}, 0 );
150+ auto if_block0 = if_node->addBlock ();
151+ auto append_node = g->create (torch::jit::aten::append, {list_node->output (), y});
152+ if_block0->appendNode (append_node);
153+ auto if_block1 = if_node->addBlock ();
154+ g->insertNode (if_node);
155+ auto cat_node = g->create (torch::jit::aten::cat, {list_node->output (), zero_const_val});
156+ g->insertNode (cat_node);
157+ g->registerOutput (cat_node->output ());
158+
159+ torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern (g);
160+ int if_count = 0 ;
161+ for (auto node : g->nodes ()) {
162+ if (node == if_node) {
163+ if_count++;
164+ }
165+ }
166+ EXPECT_EQ (1 , if_count);
167+ }
0 commit comments