@@ -82,6 +82,20 @@ void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
8282 LOG_GRAPH (" Post map conv2d -> _convolution: " << *graph);
8383}
8484
85+ void ConvTransposed2DToConvolution (std::shared_ptr<torch::jit::Graph>& graph) {
86+ const std::string conv_transpose2d_node_kind = " aten::conv_transpose2d" ;
87+ const std::string convolution_pattern = R"IR(
88+ graph(%x, %w, %b, %s, %p, %o, %g, %d):
89+ %1 : bool = prim::Constant[value=1]()
90+ %2 : bool = prim::Constant[value=1]()
91+ %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
92+ return (%4))IR" ;
93+
94+ // Schema is aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
95+ replaceConv (graph->block (), conv_transpose2d_node_kind, convolution_pattern, 8 );
96+ LOG_GRAPH (" Post map conv_transpose2d -> _convolution: " << *graph);
97+ }
98+
8599void Conv3DToConvolution (std::shared_ptr<torch::jit::Graph>& graph) {
86100 const std::string conv3d_node_kind = " aten::conv3d" ;
87101 const std::string convolution_pattern = R"IR(
@@ -96,6 +110,20 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
96110 LOG_GRAPH (" Post map conv3d -> _convolution: " << *graph);
97111}
98112
113+ void ConvTransposed3DToConvolution (std::shared_ptr<torch::jit::Graph>& graph) {
114+ const std::string conv_transpose3d_node_kind = " aten::conv_transpose3d" ;
115+ const std::string convolution_pattern = R"IR(
116+ graph(%x, %w, %b, %s, %p, %o, %g, %d):
117+ %1 : bool = prim::Constant[value=1]()
118+ %2 : bool = prim::Constant[value=1]()
119+ %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
120+ return (%4))IR" ;
121+
122+ // Schema is aten::conv_transpose3d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
123+ replaceConv (graph->block (), conv_transpose3d_node_kind, convolution_pattern, 8 );
124+ LOG_GRAPH (" Post map conv_transpose3d -> _convolution: " << *graph);
125+ }
126+
99127} // namespace passes
100128} // namespace lowering
101129} // namespace core
0 commit comments