1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "tools/optimizer/graph/control_flow_pass.h"
17 #include <vector>
18 #include <memory>
19 #include <algorithm>
20 #include "ops/switch.h"
21 #include "ops/fusion/partial_fusion.h"
22 #include "include/errorcode.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "src/common/log_adapter.h"
25 #include "tools/common/node_util.h"
26 #include "nnacl/op_base.h"
27
28 namespace mindspore::opt {
ReplaceNode(const FuncGraphPtr & fg,const std::unordered_map<AnfNodePtr,AnfNodePtr> & replace_pairs)29 void ControlFlowPass::ReplaceNode(const FuncGraphPtr &fg,
30 const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs) {
31 for (auto &node : fg->nodes()) {
32 if (!utils::isa<CNodePtr>(node)) {
33 continue;
34 }
35 auto cnode = node->cast<CNodePtr>();
36 MS_ASSERT(cnode != nullptr);
37 auto new_inputs = cnode->inputs();
38 for (auto &input : new_inputs) {
39 if (replace_pairs.find(input) == replace_pairs.end()) {
40 continue;
41 }
42 input = replace_pairs.at(input);
43 }
44 cnode->set_inputs(new_inputs);
45 }
46 }
47
VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg)48 void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
49 const std::vector<AnfNodePtr> &remain_nodes,
50 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
51 std::deque<AnfNodePtr> nodes{};
52 std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
53 std::set<AnfNodePtr> remain_nodes_set{};
54 nodes.assign(remain_nodes.begin(), remain_nodes.end());
55 while (!nodes.empty()) {
56 auto node = nodes.front();
57 nodes.pop_front();
58 remain_nodes_set.insert(node);
59 if (!utils::isa<CNodePtr>(node)) {
60 continue;
61 }
62 auto cnode = node->cast<CNodePtr>();
63 MS_ASSERT(cnode != nullptr);
64 for (auto &input : cnode->inputs()) {
65 if (visited_nodes.find(input) != visited_nodes.end() &&
66 visited_nodes_used_by_after_fg_set.find(input) == visited_nodes_used_by_after_fg_set.end()) {
67 visited_nodes_used_by_after_fg->push_back(input);
68 visited_nodes_used_by_after_fg_set.insert(input);
69 }
70 }
71 }
72 }
73
GetItemVisitedNums(const std::set<AnfNodePtr> & visited_nodes,const AnfNodePtr & tuple_node)74 size_t ControlFlowPass::GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node) {
75 size_t count = 0;
76 for (auto &node : visited_nodes) {
77 if (!utils::isa<CNodePtr>(node)) {
78 continue;
79 }
80 if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
81 continue;
82 }
83 auto get_item_cnode = node->cast<CNodePtr>();
84 MS_ASSERT(get_item_cnode != nullptr);
85 if (get_item_cnode->inputs()[kCNodeFirstInputIndex] == tuple_node) {
86 count++;
87 }
88 }
89 return count;
90 }
91
MoveGetItemToVisited(const size_t & need_size,const AnfNodePtr & tuple_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)92 void ControlFlowPass::MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node,
93 std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
94 size_t i = 0;
95 for (auto it = remain_nodes->begin(); it != remain_nodes->end();) {
96 if (!utils::isa<CNodePtr>(*it)) {
97 ++it;
98 continue;
99 }
100 if (!CheckPrimitiveType(*it, prim::kPrimTupleGetItem)) {
101 ++it;
102 continue;
103 }
104 auto get_item_cnode = (*it)->cast<CNodePtr>();
105 MS_ASSERT(get_item_cnode != nullptr);
106 if (get_item_cnode->inputs()[kCNodeFirstInputIndex] != tuple_node) {
107 ++it;
108 continue;
109 }
110 i++;
111 visited_nodes->insert(*it);
112 it = remain_nodes->erase(it);
113 if (need_size == i) {
114 return;
115 }
116 }
117 MS_LOG(INFO) << tuple_node->fullname_with_scope() << " not found enough get item, size: " << need_size - i;
118 }
119
BindGetItemNodes(std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)120 void ControlFlowPass::BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
121 std::deque<AnfNodePtr> multi_output_nodes{};
122 for (auto &node : *visited_nodes) {
123 if (!utils::isa<CNodePtr>(node)) {
124 continue;
125 }
126 if (utils::isa<abstract::AbstractTuple>(node->abstract())) {
127 multi_output_nodes.push_back(node);
128 }
129 }
130
131 while (!multi_output_nodes.empty()) {
132 auto cur_node = multi_output_nodes.front();
133 multi_output_nodes.pop_front();
134 size_t total_getitem_size = cur_node->abstract()->cast<abstract::AbstractTuplePtr>()->size();
135 size_t visited_getitem_size = GetItemVisitedNums(*visited_nodes, cur_node);
136 if (total_getitem_size == visited_getitem_size) {
137 continue;
138 }
139
140 size_t need_getitem_size = total_getitem_size - visited_getitem_size;
141 MoveGetItemToVisited(need_getitem_size, cur_node, visited_nodes, remain_nodes);
142 }
143 }
144
SplitGraph(const FuncGraphPtr & fg,AnfNodePtr * control_flow_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)145 int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node,
146 std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
147 auto inputs = fg->get_inputs();
148
149 // notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
150 auto node_list = TopoSort(fg->get_return());
151 for (auto &node : node_list) {
152 MS_ASSERT(node != nullptr);
153 if (utils::isa<CNodePtr>(node) &&
154 (CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
155 *control_flow_node = node;
156 break;
157 }
158 }
159
160 std::deque<AnfNodePtr> q;
161 visited_nodes->insert(inputs.begin(), inputs.end());
162 q.push_back(*control_flow_node);
163 while (!q.empty()) {
164 auto node = q.front();
165 q.pop_front();
166 if (!utils::isa<CNodePtr>(node)) {
167 continue;
168 }
169 visited_nodes->insert(node);
170 auto cnode = utils::cast<CNodePtr>(node);
171 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed");
172 for (size_t i = 0; i < cnode->inputs().size(); i++) {
173 auto input = cnode->input(i);
174 if (visited_nodes->find(input) == visited_nodes->end()) {
175 q.push_back(input);
176 }
177 }
178 }
179
180 for (auto &node : node_list) {
181 if (visited_nodes->find(node) == visited_nodes->end()) {
182 remain_nodes->push_back(node);
183 }
184 }
185 visited_nodes->erase(*control_flow_node);
186
187 BindGetItemNodes(visited_nodes, remain_nodes);
188
189 return RET_SUCCESS;
190 }
191
CreateAfterGraph(const FuncGraphPtr & main_fg,const std::vector<AnfNodePtr> & remain_nodes,const CNodePtr & aim_cnode,FuncGraphPtr * after_fg)192 int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
193 const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) {
194 *after_fg = std::make_shared<FuncGraph>();
195 MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
196 auto manager = main_fg->manager();
197 MS_ASSERT(manager != nullptr);
198 manager->AddFuncGraph(*after_fg);
199 (*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
200 (*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
201 (*after_fg)->set_manager(main_fg->manager());
202
203 for (auto &cur_node : remain_nodes) {
204 if (cur_node->isa<ValueNode>()) {
205 continue;
206 }
207 if (cur_node == main_fg->get_return()) {
208 continue;
209 }
210 (*after_fg)->AddNode(cur_node);
211 if (!utils::isa<ValueNodePtr>(cur_node)) {
212 cur_node->set_func_graph(*after_fg);
213 }
214 if (cur_node == main_fg->output()) {
215 (*after_fg)->set_output(cur_node, false);
216 }
217 main_fg->DropNode(cur_node);
218 }
219 return RET_SUCCESS;
220 }
221
CreateWhileCondCallNode(const FuncGraphPtr & fg,const CNodePtr & while_cnode,const std::vector<AnfNodePtr> & visited_nodes_used_by_after_fg,CNodePtr * cond_call_cnode,std::vector<AnfNodePtr> * cond_nodes_used_by_after_partial,std::unordered_map<AnfNodePtr,AnfNodePtr> * visited_nodes_and_cond_fg_inputs_replace_pairs)222 int ControlFlowPass::CreateWhileCondCallNode(
223 const FuncGraphPtr &fg, const CNodePtr &while_cnode, const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
224 CNodePtr *cond_call_cnode, std::vector<AnfNodePtr> *cond_nodes_used_by_after_partial,
225 std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs) {
226 auto cond_vnode = while_cnode->input(kWhileCondIndex);
227 MS_CHECK_TRUE_MSG(cond_vnode != nullptr, lite::RET_NULL_PTR, "cnode is nullptr");
228 auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
229 if (cond_fg == nullptr) {
230 MS_LOG(ERROR) << "Get value as func graph failed.";
231 return RET_FAILED;
232 }
233
234 // create after partial node
235 ValueNodePtr cond_partial_anf_primitive = lite::GetPartialFusionPrim();
236 if (cond_partial_anf_primitive == nullptr) {
237 MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
238 return RET_FAILED;
239 }
240
241 std::vector<AnfNodePtr> cond_partial_cnode_inputs{cond_partial_anf_primitive, cond_vnode};
242 cond_partial_cnode_inputs.insert(cond_partial_cnode_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
243 while_cnode->inputs().end());
244
245 auto origin_cond_fg_inputs = cond_fg->get_inputs();
246 for (auto &item : visited_nodes_used_by_after_fg) {
247 bool found = false;
248 size_t input_index = 0;
249 for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
250 if (cond_partial_cnode_inputs[i] == item) {
251 found = true;
252 input_index = i - kPartialFirstInputSize;
253 break;
254 }
255 }
256
257 if (found) {
258 (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = origin_cond_fg_inputs.at(input_index);
259 cond_nodes_used_by_after_partial->push_back(origin_cond_fg_inputs.at(input_index));
260 continue;
261 }
262
263 // set after fg inputs to cond_partial_cnode inputs
264 cond_partial_cnode_inputs.push_back(item);
265 auto new_parameter = cond_fg->add_parameter();
266 MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
267 new_parameter->set_name(item->fullname_with_scope() + "_cond_fg_parameter");
268 new_parameter->set_abstract(item->abstract());
269 (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = new_parameter;
270 cond_nodes_used_by_after_partial->push_back(new_parameter);
271 }
272
273 auto cond_partial_cnode = fg->NewCNode(cond_partial_cnode_inputs);
274 MS_CHECK_TRUE_MSG(cond_partial_cnode != nullptr, lite::RET_NULL_PTR, "cond_partial_cnode is nullptr");
275 cond_partial_cnode->set_fullname_with_scope("partial_" + cond_fg->get_attr("graph_name")->ToString());
276
277 // insert call node
278 std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
279 *cond_call_cnode = fg->NewCNode(call_node_inputs);
280 MS_CHECK_TRUE_MSG(*cond_call_cnode != nullptr, lite::RET_NULL_PTR, "new cnode is nullptr");
281 (*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
282
283 return RET_SUCCESS;
284 }
285
CreateWhileBodyPartialNode(const FuncGraphPtr & cond_fg,const CNodePtr & while_cnode,CNodePtr * body_partial_node)286 int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
287 CNodePtr *body_partial_node) {
288 auto body_vnode = while_cnode->input(kWhileBodyIndex);
289 MS_CHECK_TRUE_MSG(body_vnode != nullptr, RET_FAILED, "body_vnode is nullptr");
290 auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
291 if (body_fg == nullptr) {
292 MS_LOG(ERROR) << "Get value as func_graph failed.";
293 return RET_FAILED;
294 }
295
296 ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
297 if (partial_anf_primitive == nullptr) {
298 MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
299 return RET_FAILED;
300 }
301
302 std::vector<AnfNodePtr> body_partial_node_inputs{partial_anf_primitive, body_vnode};
303 // set body inputs to body partial inputs
304 auto cond_fg_inputs = cond_fg->get_inputs();
305 body_partial_node_inputs.insert(body_partial_node_inputs.end(), cond_fg_inputs.begin(), cond_fg_inputs.end());
306 *body_partial_node = cond_fg->NewCNode(body_partial_node_inputs);
307 MS_CHECK_TRUE_MSG(*body_partial_node != nullptr, RET_FAILED, "new cnode is nullptr");
308 (*body_partial_node)->set_fullname_with_scope("CNode_" + body_fg->get_attr("graph_name")->ToString());
309
310 // add after inputs for body fg to call cond fg
311 auto body_fg_inputs = body_fg->get_inputs();
312 auto origin_body_fg_inputs_size = body_fg_inputs.size();
313 for (size_t i = origin_body_fg_inputs_size; i < cond_fg_inputs.size(); ++i) {
314 if (!utils::isa<ParameterPtr>(cond_fg_inputs[i])) {
315 MS_LOG(ERROR) << "fg is not right.";
316 return RET_FAILED;
317 }
318 auto new_parameter = body_fg->add_parameter();
319 MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
320 new_parameter->set_name(cond_fg_inputs[i]->fullname_with_scope() + "_body_fg_parameter");
321 new_parameter->set_abstract(cond_fg_inputs[i]->abstract());
322 }
323
324 // call the cond fg
325 auto cond_partial_vnode = NewValueNode(cond_fg);
326 MS_CHECK_TRUE_MSG(cond_partial_vnode != nullptr, lite::RET_NULL_PTR, "cond_partial_vnode is nullptr");
327 std::vector<AnfNodePtr> cond_call_cnode_inputs{cond_partial_vnode};
328 // set body fg output
329 auto body_output = body_fg->output()->cast<CNodePtr>();
330 MS_ASSERT(body_output != nullptr);
331 if (CheckPrimitiveType(body_output, prim::kPrimMakeTuple)) {
332 for (size_t i = 1; i < body_output->inputs().size(); ++i) {
333 cond_call_cnode_inputs.push_back(body_output->input(i));
334 }
335 body_fg->DropNode(body_output);
336 } else {
337 cond_call_cnode_inputs.push_back(body_output);
338 }
339
340 body_fg_inputs = body_fg->get_inputs();
341 for (size_t i = origin_body_fg_inputs_size; i < body_fg_inputs.size(); ++i) {
342 cond_call_cnode_inputs.push_back(body_fg_inputs[i]);
343 }
344
345 auto cond_call_cnode = body_fg->NewCNode(cond_call_cnode_inputs);
346 MS_CHECK_TRUE_MSG(cond_call_cnode != nullptr, lite::RET_NULL_PTR, "cond_call_cnode != nullptr");
347 cond_call_cnode->set_fullname_with_scope(body_fg->get_attr("graph_name")->ToString() + "_call_cond_fg");
348 body_fg->set_output(cond_call_cnode);
349
350 to_process_q.push_back(body_fg);
351 return RET_SUCCESS;
352 }
353
CreateWhileAfterPartialNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & cond_fg,const std::vector<AnfNodePtr> & remain_nodes,const std::vector<AnfNodePtr> & cond_nodes_used_by_after_partial,const std::unordered_map<AnfNodePtr,AnfNodePtr> & visited_nodes_and_cond_fg_inputs_replace_pairs,CNodePtr * while_cnode,CNodePtr * after_partial_cnode)354 int ControlFlowPass::CreateWhileAfterPartialNode(
355 const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
356 const std::vector<AnfNodePtr> &cond_nodes_used_by_after_partial,
357 const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
358 CNodePtr *while_cnode, CNodePtr *after_partial_cnode) {
359 // create after_fg
360 FuncGraphPtr after_fg = nullptr;
361 if (CreateAfterGraph(main_fg, remain_nodes, *while_cnode, &after_fg) != RET_SUCCESS) {
362 MS_LOG(ERROR) << "CreateAfterGraph failed.";
363 return RET_FAILED;
364 }
365
366 auto after_value_node = NewValueNode(after_fg);
367 MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "after_value_node is nullptr");
368 ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
369 if (partial_anf_primitive == nullptr) {
370 MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
371 return RET_FAILED;
372 }
373
374 std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_inputs_and_after_fg_inputs_replace_pairs{};
375 std::vector<AnfNodePtr> after_partial_cnode_inputs{partial_anf_primitive, after_value_node};
376 auto cond_fg_inputs = cond_fg->get_inputs();
377 for (const auto &node : after_fg->nodes()) {
378 if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
379 continue;
380 }
381 auto get_tuple_item_cnode = node->cast<CNodePtr>();
382 MS_ASSERT(get_tuple_item_cnode != nullptr);
383 MS_ASSERT(get_tuple_item_cnode->inputs().size() == kGetItemInputSize);
384 if (get_tuple_item_cnode->input(kCNodeFirstInputIndex) != *while_cnode) {
385 continue;
386 }
387 auto index_vnode = get_tuple_item_cnode->inputs().at(kCNodeSecondInputIndex);
388 if (!utils::isa<ValueNode>(index_vnode)) {
389 MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
390 return RET_FAILED;
391 }
392 auto value_node = utils::cast<ValueNodePtr>(index_vnode);
393 MS_ASSERT(value_node != nullptr);
394
395 auto input_index = value_node->value()->type()->number_type() == kNumberTypeInt64
396 ? GetValue<int64_t>(value_node->value())
397 : GetValue<int>(value_node->value());
398
399 after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index));
400 auto new_parameter = after_fg->add_parameter();
401 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
402 new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
403 new_parameter->set_abstract(node->abstract());
404 after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
405 }
406
407 std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_after_fg_replace_pair{};
408 for (auto &input : cond_nodes_used_by_after_partial) {
409 after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
410 auto new_parameter = after_fg->add_parameter();
411 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
412 new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
413 new_parameter->set_abstract(input->abstract());
414 visited_nodes_after_fg_replace_pair[visited_nodes_and_cond_fg_inputs_replace_pairs.at(input)] = new_parameter;
415 }
416
417 ReplaceNode(after_fg, visited_nodes_and_cond_fg_inputs_replace_pairs);
418 ReplaceNode(after_fg, after_partial_inputs_and_after_fg_inputs_replace_pairs);
419 ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
420 *after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
421 MS_CHECK_TRUE_MSG(*after_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
422 (*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
423 return RET_SUCCESS;
424 }
425
ProcessWhileOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & while_node)426 int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
427 const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node) {
428 if (while_node == nullptr) {
429 MS_LOG(INFO) << "not found while, no need to process.";
430 return RET_SUCCESS;
431 }
432
433 auto while_cnode = while_node->cast<CNodePtr>();
434 MS_ASSERT(while_cnode != nullptr);
435 if (while_cnode->inputs().size() < kWhileMinInputSize) {
436 MS_LOG(ERROR) << "while input is not right.";
437 return RET_FAILED;
438 }
439
440 std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
441 VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
442
443 CNodePtr cond_call_cnode = nullptr;
444 std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_cond_fg_inputs_replace_pairs{};
445 std::vector<AnfNodePtr> cond_nodes_used_by_after_partial{};
446 int ret = CreateWhileCondCallNode(fg, while_cnode, visited_nodes_used_by_after_fg, &cond_call_cnode,
447 &cond_nodes_used_by_after_partial, &visited_nodes_and_cond_fg_inputs_replace_pairs);
448 if (ret != RET_SUCCESS) {
449 MS_LOG(ERROR) << "while create cond call cnode failed, ret: " << ret;
450 return ret;
451 }
452
453 auto cond_fg_cnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>();
454 MS_ASSERT(cond_fg_cnode != nullptr);
455 AnfNodePtr cond_fg_vnode = cond_fg_cnode->input(kCNodeFirstInputIndex);
456 MS_ASSERT(cond_fg_vnode != nullptr);
457 auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
458 MS_CHECK_TRUE_MSG(cond_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
459
460 CNodePtr body_partial_node = nullptr;
461 ret = CreateWhileBodyPartialNode(cond_fg, while_cnode, &body_partial_node);
462 if (ret != RET_SUCCESS) {
463 MS_LOG(ERROR) << "while create body partial cnode failed, ret: " << ret;
464 return ret;
465 }
466
467 CNodePtr after_partial_cnode = nullptr;
468 ret = CreateWhileAfterPartialNode(fg, cond_fg, remain_nodes, visited_nodes_used_by_after_fg,
469 visited_nodes_and_cond_fg_inputs_replace_pairs, &while_cnode, &after_partial_cnode);
470 if (ret != RET_SUCCESS) {
471 MS_LOG(ERROR) << "while create after partial cnode failed, ret: " << ret;
472 return ret;
473 }
474
475 // create switch cnode
476 ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
477 if (switch_anf_primitive == nullptr) {
478 MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
479 return lite::RET_ERROR;
480 }
481
482 // insert switch node
483 std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
484 after_partial_cnode};
485 auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
486 MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed");
487 switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
488
489 // insert call node
490 std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
491 auto call_node = cond_fg->NewCNode(call_node_inputs);
492 MS_CHECK_TRUE_MSG(call_node != nullptr, lite::RET_NULL_PTR, "call_node is nullptr");
493 call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
494 cond_fg->set_output(call_node);
495
496 fg->DropNode(while_cnode);
497 fg->set_output(cond_call_cnode);
498
499 auto after_cnode = after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>();
500 MS_ASSERT(after_cnode != nullptr);
501 auto after_fg = after_cnode->value()->cast<FuncGraphPtr>();
502 if (after_fg == nullptr) {
503 MS_LOG(ERROR) << "after_fg is nullptr.";
504 return RET_FAILED;
505 }
506 to_process_q.push_back(cond_fg);
507 to_process_q.push_back(after_fg);
508 return RET_SUCCESS;
509 }
510
CreateIfPartialNodeExternalInputs(const CNodePtr & if_cnode,const FuncGraphPtr & partial_fg,std::vector<AnfNodePtr> * then_partial_cnode_inputs)511 int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
512 std::vector<AnfNodePtr> *then_partial_cnode_inputs) {
513 auto if_inputs = if_cnode->inputs();
514 auto fg_name_attr = partial_fg->get_attr("graph_name");
515 MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
516 auto partial_fg_name = fg_name_attr->ToString();
517 std::vector<AnfNodePtr> if_external_inputs{};
518 if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end());
519 auto origin_then_fg_inputs = partial_fg->get_inputs();
520 if (if_external_inputs.size() < origin_then_fg_inputs.size()) {
521 MS_LOG(ERROR) << "graph is not right.";
522 return RET_FAILED;
523 } else if (if_external_inputs.size() == origin_then_fg_inputs.size()) {
524 then_partial_cnode_inputs->insert(then_partial_cnode_inputs->end(), if_external_inputs.begin(),
525 if_external_inputs.end());
526 return RET_SUCCESS;
527 } else {
528 for (auto &fg_input : origin_then_fg_inputs) {
529 auto fg_input_name = fg_input->fullname_with_scope();
530 auto pos = partial_fg_name.size() + sizeof("_input_");
531 auto pos2 = fg_input_name.find('_', pos);
532 auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1);
533 auto partial_idx = 0;
534 try {
535 partial_idx = std::stoi(idx_str);
536 } catch (const std::exception &e) {
537 MS_LOG(ERROR) << "Get index failed: " << e.what();
538 return RET_FAILED;
539 }
540 then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx));
541 }
542 }
543 return RET_SUCCESS;
544 }
545
CreateIfPartialNode(const FuncGraphPtr & fg,const size_t & index,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,CNodePtr * if_cnode,FuncGraphPtr * after_fg,CNodePtr * then_partial_cnode)546 int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
547 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg, CNodePtr *if_cnode,
548 FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
549 auto then_vnode = (*if_cnode)->input(index);
550 MS_ASSERT(then_vnode != nullptr);
551 auto then_fg = GetValueNode<std::shared_ptr<FuncGraph>>(then_vnode);
552 MS_CHECK_TRUE_MSG(then_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
553
554 // create then partial node
555 ValueNodePtr then_partial_anf_primitive = lite::GetPartialFusionPrim();
556 MS_CHECK_TRUE_MSG(then_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
557 std::vector<AnfNodePtr> then_partial_cnode_inputs{then_partial_anf_primitive, then_vnode};
558 if (CreateIfPartialNodeExternalInputs(*if_cnode, then_fg, &then_partial_cnode_inputs) != RET_SUCCESS) {
559 MS_LOG(ERROR) << "CreateIfPartialNodeExternalInputs failed.";
560 return RET_FAILED;
561 }
562 std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_after_partial_inputs_replace_pairs{};
563 std::vector<AnfNodePtr> then_nodes_used_by_after_partial{};
564 // set fg inputs to then_partial_cnode inputs
565 auto origin_then_fg_inputs = then_fg->get_inputs();
566 for (auto &item : *visited_nodes_used_by_after_fg) {
567 bool found = false;
568 size_t input_index = 0;
569 for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
570 if (then_partial_cnode_inputs[i] == item) {
571 found = true;
572 input_index = i - kPartialFirstInputSize;
573 break;
574 }
575 }
576 if (found) {
577 visited_nodes_and_after_partial_inputs_replace_pairs[item] = origin_then_fg_inputs.at(input_index);
578 then_nodes_used_by_after_partial.push_back(origin_then_fg_inputs.at(input_index));
579 continue;
580 }
581
582 // set after fg inputs to cond_partial_cnode inputs
583 then_partial_cnode_inputs.push_back(item);
584 auto new_parameter = then_fg->add_parameter();
585 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter is nullptr");
586 if (index == kIfThenIndex) {
587 new_parameter->set_name(item->fullname_with_scope() + "_then_fg_parameter");
588 } else {
589 new_parameter->set_name(item->fullname_with_scope() + "_else_fg_parameter");
590 }
591 new_parameter->set_abstract(item->abstract());
592 visited_nodes_and_after_partial_inputs_replace_pairs[item] = new_parameter;
593 then_nodes_used_by_after_partial.push_back(new_parameter);
594 }
595 *then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
596 MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
597 auto fg_name_attr = then_fg->get_attr("graph_name");
598 MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
599 auto then_fg_name = fg_name_attr->ToString();
600 (*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
601
602 // create after partial node
603 ValueNodePtr after_partial_anf_primitive = lite::GetPartialFusionPrim();
604 MS_CHECK_TRUE_MSG(after_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
605 auto after_value_node = NewValueNode(*after_fg);
606 MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "NewValueNode failed.");
607 // make the right after partial input
608 std::vector<AnfNodePtr> after_partial_cnode_inputs{after_partial_anf_primitive, after_value_node};
609 if (!CheckPrimitiveType(then_fg->output(), prim::kPrimMakeTuple)) {
610 after_partial_cnode_inputs.push_back(then_fg->output());
611 } else {
612 auto then_fg_output = then_fg->output()->cast<CNodePtr>();
613 MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed");
614 for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) {
615 after_partial_cnode_inputs.push_back(then_fg_output->input(i));
616 }
617 then_fg->DropNode(then_fg_output);
618 }
619 size_t if_output_size = after_partial_cnode_inputs.size() - kCNodeSecondInputIndex;
620
621 // add after fg inputs to partial node
622 std::copy(then_nodes_used_by_after_partial.begin(), then_nodes_used_by_after_partial.end(),
623 std::back_inserter(after_partial_cnode_inputs));
624 // insert partial node
625 auto after_partial_cnode = then_fg->NewCNode(after_partial_cnode_inputs);
626 MS_CHECK_TRUE_MSG(after_partial_cnode != nullptr, RET_FAILED, "NewCNode failed");
627 auto after_fg_name = (*after_fg)->get_attr("graph_name")->ToString();
628 after_partial_cnode->set_fullname_with_scope("partial_" + after_fg_name);
629
630 // insert call node
631 std::vector<AnfNodePtr> call_node_inputs{after_partial_cnode};
632 auto call_node = then_fg->NewCNode(call_node_inputs);
633 MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
634 call_node->set_fullname_with_scope("call_" + after_partial_cnode->fullname_with_scope());
635 then_fg->set_output(call_node);
636 to_process_q.push_back(then_fg);
637 ReplaceNode(*after_fg, visited_nodes_and_after_partial_inputs_replace_pairs);
638
639 // check the inputs of after fg
640 auto after_fg_inputs_size = (*after_fg)->get_inputs().size();
641 if (after_fg_inputs_size == after_partial_cnode_inputs.size() - kPartialFirstInputSize) {
642 MS_LOG(INFO) << "not need add after fg input parameters.";
643 return RET_SUCCESS;
644 }
645
646 // make the inputs of the after fg
647 std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_after_fg_replace_pairs{};
648 for (size_t i = kPartialFirstInputSize; i < after_partial_cnode_inputs.size(); ++i) {
649 auto &input = after_partial_cnode_inputs[i];
650 auto new_parameter = (*after_fg)->add_parameter();
651 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "add_parameter failed");
652 new_parameter->set_name(std::to_string(i - kPartialFirstInputSize) + "_" + input->fullname_with_scope());
653 new_parameter->set_abstract(input->abstract());
654 if (i < kPartialFirstInputSize + if_output_size) {
655 after_partial_after_fg_replace_pairs[*if_cnode] = new_parameter;
656 } else {
657 after_partial_after_fg_replace_pairs[input] = new_parameter;
658 }
659 }
660 ReplaceNode(*after_fg, after_partial_after_fg_replace_pairs);
661
662 return RET_SUCCESS;
663 }
664
CreateIfElsePartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,CNodePtr * if_cnode,FuncGraphPtr * after_fg,CNodePtr * else_partial_cnode)665 int ControlFlowPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
666 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
667 CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode) {
668 return CreateIfPartialNode(main_fg, kIfElseIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
669 else_partial_cnode);
670 }
671
CreateIfThenPartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,CNodePtr * if_cnode,FuncGraphPtr * after_fg,CNodePtr * then_partial_cnode)672 int ControlFlowPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
673 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
674 CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
675 return CreateIfPartialNode(main_fg, kIfThenIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
676 then_partial_cnode);
677 }
678
ProcessIfOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & if_node)679 int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
680 const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node) {
681 if (if_node == nullptr) {
682 MS_LOG(INFO) << "not found if, no need to process.";
683 return RET_SUCCESS;
684 }
685
686 auto if_cnode = if_node->cast<CNodePtr>();
687 MS_ASSERT(if_cnode != nullptr);
688 if (if_cnode->inputs().size() < kIfMinInputSize) {
689 MS_LOG(ERROR) << "if input is not right.";
690 return RET_FAILED;
691 }
692
693 // create after_fg
694 FuncGraphPtr after_fg = nullptr;
695 if (CreateAfterGraph(fg, remain_nodes, if_cnode, &after_fg) != RET_SUCCESS) {
696 MS_LOG(ERROR) << "CreateAfterGraph failed.";
697 return RET_FAILED;
698 }
699
700 // get fg input which is not used by after_parts
701 std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
702 VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
703
704 CNodePtr then_partial_cnode = nullptr;
705 int ret = CreateIfThenPartialNode(fg, &visited_nodes_used_by_after_fg, &if_cnode, &after_fg, &then_partial_cnode);
706 if (ret != RET_SUCCESS) {
707 MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret;
708 return ret;
709 }
710
711 CNodePtr else_partial_cnode = nullptr;
712 ret = CreateIfElsePartialNode(fg, &visited_nodes_used_by_after_fg, &if_cnode, &after_fg, &else_partial_cnode);
713 if (ret != RET_SUCCESS) {
714 MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret;
715 return ret;
716 }
717
718 // create switch cnode
719 ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
720 if (switch_anf_primitive == nullptr) {
721 MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
722 return RET_FAILED;
723 }
724
725 // insert switch node
726 std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), then_partial_cnode,
727 else_partial_cnode};
728 auto switch_cnode = fg->NewCNode(switch_node_inputs);
729 MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_FAILED, "NewCNode failed");
730 switch_cnode->set_fullname_with_scope("if-Switch-" + fg->get_attr("graph_name")->ToString());
731
732 // insert call node
733 std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
734 auto call_node = fg->NewCNode(call_node_inputs);
735 MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
736 call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
737 fg->DropNode(if_cnode);
738 fg->set_output(call_node, true);
739
740 to_process_q.push_back(after_fg);
741 return RET_SUCCESS;
742 }
743
ProcessControlOp(const FuncGraphPtr & fg)744 int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
745 if (fg == nullptr) {
746 MS_LOG(ERROR) << "fg is nullptr.";
747 return RET_FAILED;
748 }
749
750 AnfNodePtr control_flow_node = nullptr;
751 std::vector<AnfNodePtr> remain_nodes{};
752 std::set<AnfNodePtr> visited_nodes{};
753 int ret = SplitGraph(fg, &control_flow_node, &visited_nodes, &remain_nodes);
754 if (ret != RET_SUCCESS) {
755 MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
756 return ret;
757 }
758
759 if (control_flow_node == nullptr) {
760 MS_LOG(INFO) << "not found control flow op, no need to process.";
761 return RET_SUCCESS;
762 }
763
764 if (CheckPrimitiveType(control_flow_node, prim::kPrimWhile)) {
765 ret = ProcessWhileOp(fg, visited_nodes, remain_nodes, control_flow_node);
766 if (ret != RET_SUCCESS) {
767 MS_LOG(ERROR) << "ProcessWhileOp failed.";
768 return ret;
769 }
770 }
771
772 if (CheckPrimitiveType(control_flow_node, prim::kPrimIf)) {
773 ret = ProcessIfOp(fg, visited_nodes, remain_nodes, control_flow_node);
774 if (ret != RET_SUCCESS) {
775 MS_LOG(ERROR) << "ProcessIfOp failed.";
776 return ret;
777 }
778 }
779 return RET_SUCCESS;
780 }
781
Run(const FuncGraphPtr & fg)782 bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
783 MS_ASSERT(fg != nullptr);
784 to_process_q.push_back(fg);
785 while (!to_process_q.empty()) {
786 auto cur_fg = to_process_q.front();
787 auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString();
788 int ret = ProcessControlOp(cur_fg);
789 if (ret != RET_SUCCESS) {
790 MS_LOG(ERROR) << "ProcessControlOp for graph: " << cur_fg_name << " failed.";
791 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
792 return false;
793 }
794 to_process_q.pop_front();
795 }
796 return true;
797 }
798 } // namespace mindspore::opt
799