/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_

#include <vector>
#include <algorithm>
#include <utility>
#include <memory>

#include "utils/hash_map.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
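// Clones `fg` and replaces its output `out` with the call {out, p1, ..., pn},
// where p1..pn are `nargs` newly added parameters (appended at the back when
// `xs_first` is true, inserted at the front otherwise). Clones are cached per
// (fg, nargs, xs_first) so repeated requests reuse the same graph.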
class CallOutputTransform {
 public:
  CallOutputTransform() : cache_() {}
  ~CallOutputTransform() = default;

  FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs, bool xs_first) {
    if (cache_.find(fg) == cache_.end()) {
      cache_[fg] = {};
    }

    auto &cache = cache_[fg];
    auto key = std::make_pair(nargs, xs_first);
    if (cache.find(key) == cache.end()) {
      FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("call"));

      std::vector<AnfNodePtr> new_items;
      new_items.push_back(new_fg->output());
      if (xs_first) {
        for (size_t i = 0; i < nargs; i++) {
          new_items.push_back(new_fg->add_parameter());
        }
      } else {
        for (size_t i = 0; i < nargs; i++) {
          new_items.push_back(new_fg->InsertFrontParameter());
        }
      }
      new_fg->set_output(new_fg->NewCNode(new_items));

      cache[key] = new_fg;
    }
    return cache[key];
  }

 private:
  mindspore::HashMap<FuncGraphPtr, mindspore::HashMap<std::pair<size_t, bool>, FuncGraphPtr, PairHasher>> cache_;
};
}  // namespace internal

// {{G, Xs}, Ys}
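// Informally, the pass rewrites a call whose callee is itself a call:
//   %0 = G(Xs...)           // {G, Xs}
//   %1 = %0(Ys...)          // {{G, Xs}, Ys}
// into
//   %1 = G'(Xs..., Ys...)   // {G', Xs, Ys}
// where G' is CallOutputTransform(G, |Ys|, xs_first): a clone of G whose
// output `out` is replaced by the call {out, Ys...}. If the last Xs is a
// monad, Ys is placed before Xs so that the monad argument stays last.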
class IncorporateCall : public AnfVisitor {
 public:
  IncorporateCall() : call_output_transform_() {}
  ~IncorporateCall() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs[0] == nullptr || !inputs[0]->isa<CNode>()) {
      return nullptr;
    }

    AnfVisitor::Visit(inputs[0]);
    if (fg_ == nullptr) {
      return nullptr;
    }

    auto xs_size = Xs_.size();
    auto ys_size = inputs.size() - 1;
    bool xs_first = true;
    if ((xs_size > 0) && (Xs_[xs_size - 1]->abstract() != nullptr) &&
        (Xs_[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) {
      xs_first = false;
    }
    auto new_fg = call_output_transform_(fg_, ys_size, xs_first);

    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(new_fg));

    if (xs_first) {
      if (xs_size > 0) {
        (void)args.insert(args.cend(), Xs_.cbegin(), Xs_.cend());
      }
      if (ys_size > 0) {
        (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
      }
    } else {
      if (ys_size > 0) {
        (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
      }
      if (xs_size > 0) {
        (void)args.insert(args.cend(), Xs_.cbegin(), Xs_.cend());
      }
    }
    return MakeNewNode(node, args);
  }

  AnfNodePtr MakeNewNode(const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
    auto new_node = node->func_graph()->NewCNode(args);
    new_node->set_abstract(node->abstract());
    // If the only other usage of {G, Xs} is UpdateState{s, {G, Xs}}, replace
    // UpdateState{s, {G, Xs}} with UpdateState{s, new_node}.
    const auto &manager = fg_->manager();
    MS_EXCEPTION_IF_NULL(manager);
    auto &node_users_map = manager->node_users();
    auto it = node_users_map.find(fg_call_cnode_);
    if (it != node_users_map.end()) {
      AnfNodePtr update_state_node = nullptr;
      auto &node_users = it->second;
      constexpr size_t users_size = 2;
      if (node_users.size() == users_size) {
        for (auto &node_user : node_users) {
          if (IsPrimitiveCNode(node_user.first, prim::kPrimUpdateState)) {
            update_state_node = node_user.first;
          }
        }
      }
      if (update_state_node != nullptr) {
        auto update_state_cnode = update_state_node->cast<CNodePtr>();
        // double check;
        const size_t attach_index = 2;
        if (update_state_cnode->input(attach_index) == fg_call_cnode_) {
          constexpr int recursive_level = 2;
          MS_LOG(DEBUG) << "Replace UpdateState node: " << update_state_cnode->DebugString(recursive_level)
                        << ", input 2 with: " << new_node->DebugString();
          manager->SetEdge(update_state_cnode, attach_index, new_node);
        }
      }
    }
    return new_node;
  }

  void Visit(const CNodePtr &cnode) override {
    // {G, Xs}
    if (cnode->size() < 1 || !IsValueNode<FuncGraph>(cnode->input(0))) {
      return;
    }

    auto &inputs = cnode->inputs();
    fg_ = GetValueNode<FuncGraphPtr>(inputs[0]);
    fg_call_cnode_ = cnode;
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
  }

  void Reset() {
    Xs_.clear();
    fg_ = nullptr;
    fg_call_cnode_ = nullptr;
  }

 private:
  FuncGraphPtr fg_;
  CNodePtr fg_call_cnode_{nullptr};
  std::vector<AnfNodePtr> Xs_{};
  internal::CallOutputTransform call_output_transform_;
};

// {{{prim::kPrimSwitch, X, G1, G2}, Xs}, Ys}
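// The same incorporation as IncorporateCall, but the inner callee is produced
// by a switch. Informally:
//   %0 = switch(X, G1, G2)
//   %1 = %0(Xs...)
//   %2 = %1(Ys...)
// becomes
//   %0 = switch(X, G1', G2')   // G1', G2' produced by CallOutputTransform
//   %2 = %0(Xs..., Ys...)      // or (Ys..., Xs...) when the last Xs is a monad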
class IncorporateCallSwitch : public AnfVisitor {
 public:
  IncorporateCallSwitch() : call_output_transform_() {}
  ~IncorporateCallSwitch() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }

    // {{...}, Ys}
    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs[0] == nullptr || !inputs[0]->isa<CNode>()) {
      return nullptr;
    }

    // {{...}, Xs}
    auto &inputs_x = inputs[0]->cast<CNodePtr>()->inputs();
    if (inputs_x[0] == nullptr || !inputs_x[0]->isa<CNode>()) {
      return nullptr;
    }

    // {prim::kPrimSwitch, X, G1, G2}
    AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(inputs_x[0]);
    if (g2_ == nullptr) {
      return nullptr;
    }

    auto fg = node->func_graph();
    auto xs_size = inputs_x.size() - 1;
    auto ys_size = inputs.size() - 1;
    bool xs_first = true;
    if ((xs_size > 0) && (inputs_x[xs_size]->abstract() != nullptr) &&
        (inputs_x[xs_size]->abstract()->isa<abstract::AbstractMonad>())) {
      xs_first = false;
    }
    auto new_g1 = call_output_transform_(g1_, ys_size, xs_first);
    auto new_g2 = call_output_transform_(g2_, ys_size, xs_first);
    auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});

    std::vector<AnfNodePtr> args{sw_node};
    if (xs_first) {
      if (xs_size > 0) {
        (void)args.insert(args.cend(), inputs_x.cbegin() + 1, inputs_x.cend());
      }
      if (ys_size > 0) {
        (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
      }
    } else {
      if (ys_size > 0) {
        (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
      }
      if (xs_size > 0) {
        (void)args.insert(args.cend(), inputs_x.cbegin() + 1, inputs_x.cend());
      }
    }

    auto new_node = fg->NewCNode(args);
    new_node->set_abstract(node->abstract());
    return new_node;
  }

  void Visit(const AnfNodePtr &node) override {
    if (x_ == nullptr) {
      x_ = node;
      return;
    }
    AnfVisitor::Visit(node);
  }

  void Visit(const ValueNodePtr &vnode) override {
    auto g = GetValueNode<FuncGraphPtr>(vnode);
    if (g1_ == nullptr) {
      g1_ = g;
    } else {
      g2_ = g;
    }
  }

  void Reset() {
    x_ = nullptr;
    g1_ = nullptr;
    g2_ = nullptr;
  }

 private:
  AnfNodePtr x_{nullptr};
  FuncGraphPtr g1_{nullptr}, g2_{nullptr};
  internal::CallOutputTransform call_output_transform_;
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_