• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
19 
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
28 
29 namespace mindspore {
30 namespace opt {
31 namespace irpass {
32 // {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} ->
33 // {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}}
34 // {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} ->
35 // {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
36 class MergeAddN : public AnfVisitor {
37  public:
operator()38   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
39     Reset();
40     mng_ = optimizer->manager();
41     is_outer_ = true;
42     AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
43     // do not hold this manager
44     mng_ = nullptr;
45     if (!is_match_ || node->func_graph() == nullptr) {
46       return nullptr;
47     }
48 
49     auto cnode = node->cast<CNodePtr>();
50     auto addn = NewValueNode(GetValueNode(cnode->input(0)));
51 
52     // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
53     (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple));
54     auto fg = node->func_graph();
55     auto make_node = fg->NewCNode(args_);
56 
57     return fg->NewCNode({addn, make_node});
58   }
59 
Visit(const CNodePtr & cnode)60   void Visit(const CNodePtr &cnode) override {
61     if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
62       return;
63     }
64 
65     auto &inputs = cnode->inputs();
66 
67     if (is_outer_) {
68       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_));
69 
70       is_outer_ = false;
71       is_inner_ = true;
72 
73       // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
74       const auto &first_input = inputs.at(1);
75       AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(first_input);
76       if (is_match_) {
77         if (!is_unique(first_input)) {
78           is_match_ = false;
79           return;
80         }
81 
82         if (!IsStateEquivalent(cnode, first_input)) {
83           is_match_ = false;
84           return;
85         }
86 
87         (void)Ys_.erase(Ys_.begin());
88         (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
89         (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
90         return;
91       }
92 
93       // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
94       const auto &last_input = inputs.back();
95       AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(last_input);
96       if (is_match_) {
97         if (!is_unique(last_input)) {
98           is_match_ = false;
99           return;
100         }
101 
102         if (!IsStateEquivalent(cnode, last_input)) {
103           is_match_ = false;
104           return;
105         }
106 
107         Ys_.pop_back();
108         (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
109         (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
110         return;
111       }
112 
113       return;
114     }
115 
116     if (is_inner_) {
117       is_match_ = true;
118       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
119     }
120   }
121 
is_unique(const AnfNodePtr & node)122   bool is_unique(const AnfNodePtr &node) {
123     auto &node_users = mng_->node_users();
124     if (node_users.find(node) == node_users.end()) {
125       return false;
126     }
127 
128     size_t n_use = node_users[node].size();
129     return n_use == 1;
130   }
131 
Reset()132   void Reset() {
133     Xs_.clear();
134     Ys_.clear();
135     args_.clear();
136     is_inner_ = false;
137     is_outer_ = false;
138     is_match_ = false;
139   }
140 
141  private:
142   FuncGraphManagerPtr mng_{nullptr};
143   std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{};
144   bool is_inner_{false}, is_outer_{false}, is_match_{false};
145 };
146 
147 // {PrimAddN, {kPrimMakeTuple, Xs}}
148 class AddNZeroFilter : public AnfVisitor {
149  public:
operator()150   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
151     Reset();
152     AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
153 
154     if (filtered_Xs_.empty() || node->func_graph() == nullptr) {
155       return nullptr;
156     }
157 
158     // if only two node in filtered_nodes, {make_tuple, x}. return x.
159     if (filtered_Xs_.size() == 2) {
160       return filtered_Xs_[1];
161     }
162 
163     // if only one node in filtered_nodes, all node is zerolike, return one of the input.
164     if (filtered_Xs_.size() == 1 && Xs_.size() > 0) {
165       return Xs_[0];
166     }
167 
168     if (!has_zero_like_) {
169       return nullptr;
170     }
171 
172     auto cnode = node->cast<CNodePtr>();
173     auto addn = NewValueNode(GetValueNode(cnode->input(0)));
174     auto fg = node->func_graph();
175     auto make_tuple = fg->NewCNode(filtered_Xs_);
176     return fg->NewCNode({addn, make_tuple});
177   }
178 
Visit(const CNodePtr & cnode)179   void Visit(const CNodePtr &cnode) override {
180     if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
181       return;
182     }
183 
184     auto &inputs = cnode->inputs();
185     (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
186 
187     // {kPrimMakeTuple, X1, X2, ...}
188     filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
189     for (auto &x : Xs_) {
190       if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
191         filtered_Xs_.push_back(x);
192       } else {
193         has_zero_like_ = true;
194       }
195     }
196   }
197 
Reset()198   void Reset() {
199     Xs_.clear();
200     filtered_Xs_.clear();
201     has_zero_like_ = false;
202   }
203 
204  private:
205   std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
206   bool has_zero_like_{false};
207 };
208 }  // namespace irpass
209 }  // namespace opt
210 }  // namespace mindspore
211 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
212