/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_

#include <vector>
#include <algorithm>
#include <memory>

#include "frontend/optimizer/irpass.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "utils/anf_utils.h"

namespace mindspore {
namespace opt {
namespace irpass {
// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}}
// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} ->
// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}}
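// Illustrative example of the rewrite above: AddN((AddN((x1, x2)), y1)) becomes AddN((x1, x2, y1)),
// i.e. a nested AddN at the head or tail of the tuple is flattened into the outer AddN.
// The nested AddN must have the outer node as its only user (see is_unique below), otherwise
// the pattern is rejected and no merge is performed.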
class MergeAddN : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    Reset();
    mng_ = optimizer->manager();
    is_outer_ = true;
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);
    // do not hold this manager
    mng_ = nullptr;
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }
    addn_nodes_.push_back(node);

    auto cnode = node->cast<CNodePtr>();
    auto addn = NewValueNode(GetValueNode(cnode->input(0)));

    // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs}
    (void)args_.insert(args_.cbegin(), NewValueNode(prim::kPrimMakeTuple));
    auto fg = node->func_graph();
    auto make_node = fg->NewCNode(args_);

    auto new_node = fg->NewCNode({addn, make_node});
    UpdateDumpFlag(new_node);
    new_node->AddFusedDebugInfoList(addn_nodes_);
    return new_node;
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    auto &inputs = cnode->inputs();

    if (is_outer_) {
      (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_));

      is_outer_ = false;
      is_inner_ = true;

      // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
      const auto &first_input = inputs.at(1);
      AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(first_input);
      if (is_match_) {
        if (!is_unique(first_input)) {
          is_match_ = false;
          return;
        }

        if (!IsStateEquivalent(cnode, first_input)) {
          is_match_ = false;
          return;
        }

        addn_nodes_.push_back(first_input);
        (void)Ys_.erase(Ys_.cbegin());
        (void)std::copy(Xs_.cbegin(), Xs_.cend(), std::back_inserter(args_));
        (void)std::copy(Ys_.cbegin(), Ys_.cend(), std::back_inserter(args_));
        return;
      }

      // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
      const auto &last_input = inputs.back();
      AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(last_input);
      if (is_match_) {
        if (!is_unique(last_input)) {
          is_match_ = false;
          return;
        }

        if (!IsStateEquivalent(cnode, last_input)) {
          is_match_ = false;
          return;
        }

        addn_nodes_.push_back(last_input);
        Ys_.pop_back();
        (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_));
        (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_));
        return;
      }

      return;
    }

    if (is_inner_) {
      is_match_ = true;
      (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
    }
  }

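  // Returns true only when `node` has exactly one user in the graph manager, so merging the
  // nested AddN cannot break other consumers of its result.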
  bool is_unique(const AnfNodePtr &node) {
    auto &node_users = mng_->node_users();
    if (node_users.find(node) == node_users.end()) {
      return false;
    }

    size_t n_use = node_users[node].size();
    return n_use == 1;
  }

  void Reset() {
    Xs_.clear();
    Ys_.clear();
    args_.clear();
    addn_nodes_.clear();
    is_inner_ = false;
    is_outer_ = false;
    is_match_ = false;
  }

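  // Propagate the dump flag: if any of the merged AddN nodes was marked for dump,
  // mark the newly created node as well.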
  void UpdateDumpFlag(const AnfNodePtr &node) {
    if (node == nullptr) {
      return;
    }
    for (const auto &addn : addn_nodes_) {
      if (AnfUtils::GetDumpFlag(addn)) {
        AnfUtils::SetDumpFlag(node);
        return;
      }
    }
  }

 private:
  FuncGraphManagerPtr mng_{nullptr};
  std::vector<AnfNodePtr> Xs_{}, Ys_{}, args_{}, addn_nodes_{};
  bool is_inner_{false}, is_outer_{false}, is_match_{false};
};
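
// A minimal usage sketch (hypothetical names; in practice these visitors are registered with the
// irpass optimizer rather than called directly):
//   MergeAddN merge_addn;
//   AnfNodePtr merged = merge_addn(optimizer, addn_node);  // merged AddN, or nullptr if no match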

// {PrimAddN, {kPrimMakeTuple, Xs}}
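// AddNZeroFilter drops ZerosLike inputs from the AddN tuple: if every input is ZerosLike, one of
// them is returned directly; if exactly one non-zero input remains, that input replaces the AddN;
// otherwise a new AddN is built over the filtered tuple.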
class AddNZeroFilter : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);

    if (filtered_Xs_.empty() || node->func_graph() == nullptr) {
      return nullptr;
    }

    // If only two nodes remain in filtered_Xs_, i.e. {make_tuple, x}, return x directly.
    constexpr auto input_size = 2;
    if (filtered_Xs_.size() == input_size) {
      return filtered_Xs_[1];
    }

    // If only the make_tuple remains in filtered_Xs_, every input is ZerosLike; return one of the inputs.
    if (filtered_Xs_.size() == 1 && Xs_.size() > 0) {
      return Xs_[0];
    }

    if (!has_zero_like_) {
      return nullptr;
    }

    auto cnode = node->cast<CNodePtr>();
    auto addn = NewValueNode(GetValueNode(cnode->input(0)));
    auto fg = node->func_graph();
    auto make_tuple = fg->NewCNode(filtered_Xs_);
    return fg->NewCNode({addn, make_tuple});
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }

    auto &inputs = cnode->inputs();
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));

    // {kPrimMakeTuple, X1, X2, ...}
    filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (auto &x : Xs_) {
      if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
        filtered_Xs_.push_back(x);
      } else {
        has_zero_like_ = true;
      }
    }
  }

  void Reset() {
    Xs_.clear();
    filtered_Xs_.clear();
    has_zero_like_ = false;
  }

 private:
  std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
  bool has_zero_like_{false};
};

// {PrimAddN, {kPrimMakeTuple, Xs}}
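// AddNCheckDump only inspects AddN nodes created under the "Gradients/" scope: when every input of
// the make_tuple already carries the dump flag (looking through TupleGetItem/Depend), the flag is
// also set on the AddN itself. It never rewrites the graph and always returns nullptr.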
class AddNCheckDump : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node);

    // Only handle gradient addn.
    if (node->scope()->name().find("Gradients/") != 0) {
      return nullptr;
    }

    if (set_dump_) {
      AnfUtils::SetDumpFlag(node);
    }

    return nullptr;
  }

  void Visit(const CNodePtr &cnode) override {
    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return;
    }
    if (cnode->size() < kSizeThree) {
      return;
    }

    // When all inputs have the dump flag, set the dump flag for the AddN as well.
    set_dump_ = true;
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto input = cnode->input(i);
      MS_EXCEPTION_IF_NULL(input);
      if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
        input = input->cast<CNodePtr>()->input(kIndexOne);
      }
      if (!input->isa<CNode>() || !AnfUtils::GetDumpFlag(input)) {
        set_dump_ = false;
        break;
      }
    }
  }

  void Reset() { set_dump_ = false; }

 private:
  bool set_dump_{false};
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_MERGE_ADDN_H_