 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/ir/irparser.h"
 
 #include "core/util/prelude.h"
 
@@ -7,78 +8,91 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
+void replaceConv(
+    torch::jit::Block* block,
+    const std::string& node_kind,
+    const std::string& unwrapped_conv,
+    const size_t num_input_args) {
+  // Iterate through nodes in block, searching for aten::conv*
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto nested_block : n->blocks()) {
+      replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+    }
+
+    // If node matches desired kind and number of input arguments, replace it
+    if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+      // Establish insert point within block
+      torch::jit::WithInsertPoint guard(*it);
+
+      // Initialize new fused subgraph from IR code provided
+      auto fused_g = std::make_shared<torch::jit::Graph>();
+      torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+      // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+      torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+      new_output->setType(it->output()->type());
+      it->output()->replaceAllUsesWith(new_output);
+      it.destroyCurrent();
+    }
+  }
+}
 
-  std::string convolution_pattern = R"IR(
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv1d_node_kind = "aten::conv1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
      %2 : int[] = prim::Constant[value=[0]]()
      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
      return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-  map_conv1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
 }
 
 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv_transpose1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %o, %g, %d):
-      %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %o, %g, %d):
       %1 : bool = prim::Constant[value=1]()
       %2 : bool = prim::Constant[value=1]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
       return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
   LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
 }
 
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv2d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv2d_node_kind = "aten::conv2d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-  map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-  map_conv2d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }
 
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv3d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv3d_node_kind = "aten::conv3d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-  map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-  map_conv3d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }
 
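For reference, a minimal sketch (untested) of how one of these lowering passes could be exercised in isolation: parse a small TorchScript IR graph containing an aten::conv1d node with torch::jit::parseIR, run Conv1DToConvolution over it, and dump the rewritten graph. The core/lowering/passes/passes.h include and the core::lowering::passes qualification are assumptions about how the declarations are exposed; adjust them to the actual enclosing namespace of this file.

// Usage sketch (assumption: Conv1DToConvolution is declared in
// core/lowering/passes/passes.h under core::lowering::passes).
#include <iostream>
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

#include "core/lowering/passes/passes.h"

int main() {
  // A graph with a single aten::conv1d node using the 7-input schema that
  // replaceConv matches on.
  const std::string source = R"IR(
    graph(%x : Tensor, %w : Tensor, %b : Tensor):
      %s : int[] = prim::Constant[value=[1]]()
      %p : int[] = prim::Constant[value=[0]]()
      %d : int[] = prim::Constant[value=[1]]()
      %g : int = prim::Constant[value=1]()
      %out : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
      return (%out))IR";

  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, graph.get());

  // After the pass, the aten::conv1d node should have been rewritten to
  // aten::_convolution with the extra constant arguments filled in.
  core::lowering::passes::Conv1DToConvolution(graph);
  std::cout << *graph << std::endl;
  return 0;
}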