1- #include < torch/csrc/jit/passes/subgraph_rewrite.h >
1+ #include " torch/csrc/jit/passes/dead_code_elimination.h "
22
33#include " core/util/prelude.h"
44
@@ -7,86 +7,52 @@ namespace core {
77namespace lowering {
88namespace passes {
99
10- void RemoveDropout (std::shared_ptr<torch::jit::Graph>& graph) {
11- std::string dropout_pattern = R"IR(
12- graph(%input, %4, %5):
13- %6 = aten::dropout(%input, %4, %5)
14- return (%6))IR" ;
15- std::string no_dropout_pattern = R"IR(
16- graph(%input, %4, %5):
17- return (%input))IR" ;
18-
19- torch::jit::SubgraphRewriter remove_dropout;
20- remove_dropout.RegisterRewritePattern (dropout_pattern, no_dropout_pattern);
21- remove_dropout.runOnGraph (graph);
22-
23- std::string dropout_inplace_pattern = R"IR(
24- graph(%input, %4, %5):
25- %6 = aten::dropout_(%input, %4, %5)
26- return (%6))IR" ;
27- std::string no_dropout_inplace_pattern = R"IR(
28- graph(%input, %4, %5):
29- return (%input))IR" ;
30-
31- torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
32- remove_dropout_inplace_pattern.RegisterRewritePattern (dropout_inplace_pattern, no_dropout_inplace_pattern);
33- remove_dropout_inplace_pattern.runOnGraph (graph);
34-
35- // remove feature_dropout
36- std::string feature_dropout_pattern = R"IR(
37- graph(%input, %4, %5):
38- %6 = aten::feature_dropout(%input, %4, %5)
39- return (%6))IR" ;
40- std::string no_feature_dropout_pattern = R"IR(
41- graph(%input, %4, %5):
42- return (%input))IR" ;
43-
44- torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
45- remove_feature_dropout_pattern.RegisterRewritePattern (feature_dropout_pattern, no_feature_dropout_pattern);
46- remove_feature_dropout_pattern.runOnGraph (graph);
47-
48- // remove feature_dropout inplace
49- std::string feature_dropout_inplace_pattern = R"IR(
50- graph(%input, %4, %5):
51- %6 = aten::feature_dropout_(%input, %4, %5)
52- return (%6))IR" ;
53- std::string no_feature_dropout_inplace_pattern = R"IR(
54- graph(%input, %4, %5):
55- return (%input))IR" ;
56-
57- torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
58- remove_feature_dropout_inplace_pattern.RegisterRewritePattern (
59- feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
60- remove_feature_dropout_inplace_pattern.runOnGraph (graph);
61-
62- // remove feature_alpha_dropout
63- std::string feature_alpha_dropout_pattern = R"IR(
64- graph(%input, %4, %5):
65- %6 = aten::feature_alpha_dropout(%input, %4, %5)
66- return (%6))IR" ;
67- std::string no_feature_alpha_dropout_pattern = R"IR(
68- graph(%input, %4, %5):
69- return (%input))IR" ;
70-
71- torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
72- remove_feature_alpha_dropout_pattern.RegisterRewritePattern (
73- feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
74- remove_feature_alpha_dropout_pattern.runOnGraph (graph);
75-
76- // remove feature_alpha_dropout inplace
77- std::string feature_alpha_dropout_inplace_pattern = R"IR(
78- graph(%input, %4, %5):
79- %6 = aten::feature_alpha_dropout_(%input, %4, %5)
80- return (%6))IR" ;
81- std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
82- graph(%input, %4, %5):
83- return (%input))IR" ;
84-
85- torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
86- remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern (
87- feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
88- remove_feature_alpha_dropout_inplace_pattern.runOnGraph (graph);
10+ // Schemas for dropout variants
11+ const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
12+ c10::Symbol::fromQualString (" aten::dropout" ),
13+ c10::Symbol::fromQualString (" aten::dropout_" ),
14+ c10::Symbol::fromQualString (" aten::feature_dropout" ),
15+ c10::Symbol::fromQualString (" aten::feature_dropout_" ),
16+ c10::Symbol::fromQualString (" aten::feature_alpha_dropout" ),
17+ c10::Symbol::fromQualString (" aten::feature_alpha_dropout_" ),
18+ };
19+
20+ void removeDropoutInBlock (torch::jit::Block* block) {
21+ /*
22+ Function adapted from:
23+ torch/csrc/jit/passes/remove_dropout.cpp
24+
25+ Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
26+ */
27+ std::vector<torch::jit::Node*> dropout_nodes_to_remove;
28+
29+ for (auto node : block->nodes ()) {
30+ // Remove dropout for each member block within a node
31+ for (auto block : node->blocks ()) {
32+ removeDropoutInBlock (block);
33+ }
34+
35+ // For each node having a dropout-variant Schema, remove the node
36+ if (DropoutNodeKinds.find (node->kind ()) != DropoutNodeKinds.end ()) {
37+ // Extract input and output tensors of dropout operator
38+ auto input_value = node->inputs ()[0 ];
39+ auto output_value = node->outputs ()[0 ];
40+
41+ output_value->replaceAllUsesWith (input_value);
42+ dropout_nodes_to_remove.push_back (node);
43+ }
44+ }
45+
46+ // Delete dropout nodes
47+ for (auto del_node : dropout_nodes_to_remove) {
48+ del_node->destroy ();
49+ }
50+ }
8951
52+ void RemoveDropout (std::shared_ptr<torch::jit::Graph>& graph) {
53+ // Remove all instances of dropout variants from graph
54+ removeDropoutInBlock (graph->block ());
55+ torch::jit::EliminateDeadCode (graph);
9056 LOG_GRAPH (" Post remove dropout: " << *graph);
9157}
9258
0 commit comments