#include "core/util/prelude.h"

#include "torch/csrc/jit/api/function_impl.h"
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
namespace lowering {
namespace passes {

19- void replaceLinearWithBiasNonePattern (std::shared_ptr< torch::jit::Graph> graph ) {
20+ void replaceLinear ( torch::jit::Block* block ) {
2021 // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
2122 static torch::jit::CompilationUnit decompose_funcs (R"SCRIPT(
2223 def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
2324 return torch.matmul(self, mat1.t())
2425 )SCRIPT" );
2526
26- // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
27- auto block = graph->block ();
27+ // Define graph format for aten::linear with Tensor-type bias
28+ std::string fused_linear = R"IR(
29+ graph(%input, %weight, %bias):
30+ %1: int = prim::Constant[value=1]()
31+ %weight = aten::t(%weight)
32+ %mm: Tensor = aten::matmul(%input, %weight)
33+ %b_f: Tensor = trt::const(%bias)
34+ %out: Tensor = aten::add(%b_f, %mm, %1)
35+ return (%out))IR" ;
36+
37+ // Iterate through nodes in block, seaching for aten::linear
2838 for (auto it = block->nodes ().begin (); it != block->nodes ().end (); it++) {
2939 auto n = *it;
30- if (n->kind ().toQualString () == std::string (" aten::linear" )) {
40+
41+ // Recursively explore nested blocks, such as those arising from prim::If
42+ for (auto block : n->blocks ()) {
43+ replaceLinear (block);
44+ }
45+
46+ if ((n->kind ().toQualString () == std::string (" aten::linear" )) && (n->inputs ().size () >= 3 )) {
3147 auto input_values = n->inputs ();
32- // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
48+
49+ // input_values[2] is the bias
50+ // If Tensor, replace with fused-bias decomposed graph
51+ // Otherwise, replace it with the no-bias decomposed linear graph.
3352 if (input_values[2 ]->type ()->isSubtypeOf (c10::TensorType::get ())) {
34- continue ;
53+ torch::jit::WithInsertPoint guard (*it);
54+
55+ // Initialize new fused subgraph from IR code above
56+ auto fused_g = std::make_shared<torch::jit::Graph>();
57+ torch::jit::parseIR (fused_linear, fused_g.get ());
58+
59+ // Insert subgraph in place of aten::linear, replacing inputs and outputs accordingly
60+ torch::jit::Value* new_output = insertGraph (*it->owningGraph (), *fused_g, it->inputs ()).at (0 );
61+ new_output->setType (it->output ()->type ());
62+ it->output ()->replaceAllUsesWith (new_output);
63+ it.destroyCurrent ();
3564 } else {
3665 torch::jit::WithInsertPoint guard (*it);
66+
67+ // Initialized decomposed graph without bias term
3768 std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction (decompose_funcs.get_function (" linear" )).graph ();
3869 torch::jit::Value* new_output = insertGraph (*it->owningGraph (), *d_graph, it->inputs ()).at (0 );
70+
71+ // Insert function in place of aten::linear, replacing inputs and outputs accordingly
3972 new_output->setType (it->output ()->type ());
4073 it->output ()->replaceAllUsesWith (new_output);
4174 it.destroyCurrent ();
@@ -45,27 +78,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
4578}
4679
4780void LinearToAddMM (std::shared_ptr<torch::jit::Graph>& graph) {
48- // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
49- std::string flatten_linear_pattern = R"IR(
50- graph(%input, %weight, %bias):
51- %res = aten::linear(%input, %weight, %bias)
52- return (%res))IR" ;
53-
54- std::string fused_linear = R"IR(
55- graph(%input, %weight_t, %bias):
56- %1: int = prim::Constant[value=1]()
57- %weight = aten::t(%weight_t)
58- %mm: Tensor = aten::matmul(%input, %weight)
59- %b_f: Tensor = trt::const(%bias)
60- %out: Tensor = aten::add(%b_f, %mm, %1)
61- return (%out))IR" ;
62-
63- // First find and replace aten::linear nodes with non-tensor bias values.
64- replaceLinearWithBiasNonePattern (graph);
65-
66- torch::jit::SubgraphRewriter flatten_linear_to_linear;
67- flatten_linear_to_linear.RegisterRewritePattern (flatten_linear_pattern, fused_linear);
68- flatten_linear_to_linear.runOnGraph (graph);
81+ // Recursively find and replace all instances of aten::linear with the corresponding decomposed form
82+ replaceLinear (graph->block ());
6983}
7084
} // namespace passes