1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/control_flow_pass.h"
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ops/switch.h"
26 #include "ops/fusion/partial_fusion.h"
27 #include "include/errorcode.h"
28 #include "tools/optimizer/common/gllo_utils.h"
29 #include "src/common/log_adapter.h"
30 #include "tools/common/node_util.h"
31 #include "nnacl/op_base.h"
32 #include "include/registry/converter_context.h"
33
34 namespace mindspore::opt {
ReplaceNode(const FuncGraphPtr & fg,const std::unordered_map<AnfNodePtr,AnfNodePtr> & replace_pairs)35 void ControlFlowPass::ReplaceNode(const FuncGraphPtr &fg,
36 const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs) {
37 for (auto &node : fg->nodes()) {
38 if (!utils::isa<CNodePtr>(node)) {
39 continue;
40 }
41 auto cnode = node->cast<CNodePtr>();
42 MS_ASSERT(cnode != nullptr);
43 auto new_inputs = cnode->inputs();
44 for (auto &input : new_inputs) {
45 if (replace_pairs.find(input) == replace_pairs.end()) {
46 continue;
47 }
48 input = replace_pairs.at(input);
49 }
50 cnode->set_inputs(new_inputs);
51 }
52 }
53
VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg)54 void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
55 const std::vector<AnfNodePtr> &remain_nodes,
56 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
57 std::deque<AnfNodePtr> nodes{};
58 std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
59 std::set<AnfNodePtr> remain_nodes_set{};
60 nodes.assign(remain_nodes.begin(), remain_nodes.end());
61 while (!nodes.empty()) {
62 auto node = nodes.front();
63 nodes.pop_front();
64 remain_nodes_set.insert(node);
65 if (!utils::isa<CNodePtr>(node)) {
66 continue;
67 }
68 auto cnode = node->cast<CNodePtr>();
69 MS_ASSERT(cnode != nullptr);
70 for (auto &input : cnode->inputs()) {
71 if (visited_nodes.find(input) != visited_nodes.end() &&
72 visited_nodes_used_by_after_fg_set.find(input) == visited_nodes_used_by_after_fg_set.end()) {
73 visited_nodes_used_by_after_fg->push_back(input);
74 visited_nodes_used_by_after_fg_set.insert(input);
75 }
76 }
77 }
78 }
79
GetItemVisitedNums(const std::set<AnfNodePtr> & visited_nodes,const AnfNodePtr & tuple_node)80 size_t ControlFlowPass::GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node) {
81 size_t count = 0;
82 for (auto &node : visited_nodes) {
83 if (!utils::isa<CNodePtr>(node)) {
84 continue;
85 }
86 if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
87 continue;
88 }
89 auto get_item_cnode = node->cast<CNodePtr>();
90 MS_ASSERT(get_item_cnode != nullptr);
91 if (get_item_cnode->inputs()[kCNodeFirstInputIndex] == tuple_node) {
92 count++;
93 }
94 }
95 return count;
96 }
97
MoveGetItemToVisited(const size_t & need_size,const AnfNodePtr & tuple_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)98 void ControlFlowPass::MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node,
99 std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
100 size_t i = 0;
101 for (auto it = remain_nodes->begin(); it != remain_nodes->end();) {
102 if (!utils::isa<CNodePtr>(*it)) {
103 ++it;
104 continue;
105 }
106 if (!CheckPrimitiveType(*it, prim::kPrimTupleGetItem)) {
107 ++it;
108 continue;
109 }
110 auto get_item_cnode = (*it)->cast<CNodePtr>();
111 MS_ASSERT(get_item_cnode != nullptr);
112 if (get_item_cnode->inputs()[kCNodeFirstInputIndex] != tuple_node) {
113 ++it;
114 continue;
115 }
116 i++;
117 visited_nodes->insert(*it);
118 it = remain_nodes->erase(it);
119 if (need_size == i) {
120 return;
121 }
122 }
123 MS_LOG(INFO) << tuple_node->fullname_with_scope() << " not found enough get item, size: " << need_size - i;
124 }
125
BindGetItemNodes(std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)126 void ControlFlowPass::BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
127 std::deque<AnfNodePtr> multi_output_nodes{};
128 for (auto &node : *visited_nodes) {
129 if (!utils::isa<CNodePtr>(node)) {
130 continue;
131 }
132 if (utils::isa<abstract::AbstractTuple>(node->abstract())) {
133 multi_output_nodes.push_back(node);
134 }
135 }
136
137 while (!multi_output_nodes.empty()) {
138 auto cur_node = multi_output_nodes.front();
139 multi_output_nodes.pop_front();
140 size_t total_getitem_size = cur_node->abstract()->cast<abstract::AbstractTuplePtr>()->size();
141 size_t visited_getitem_size = GetItemVisitedNums(*visited_nodes, cur_node);
142 if (total_getitem_size == visited_getitem_size) {
143 continue;
144 }
145
146 size_t need_getitem_size = total_getitem_size - visited_getitem_size;
147 MoveGetItemToVisited(need_getitem_size, cur_node, visited_nodes, remain_nodes);
148 }
149 }
150
SplitGraph(const FuncGraphPtr & fg,AnfNodePtr * control_flow_node,std::set<AnfNodePtr> * visited_nodes,std::vector<AnfNodePtr> * remain_nodes)151 int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node,
152 std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
153 auto inputs = fg->get_inputs();
154
155 // notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
156 auto node_list = TopoSort(fg->get_return());
157 for (auto &node : node_list) {
158 MS_ASSERT(node != nullptr);
159 if (utils::isa<CNodePtr>(node) &&
160 (CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
161 *control_flow_node = node;
162 break;
163 }
164 }
165
166 std::deque<AnfNodePtr> q;
167 visited_nodes->insert(inputs.begin(), inputs.end());
168 q.push_back(*control_flow_node);
169 while (!q.empty()) {
170 auto node = q.front();
171 q.pop_front();
172 if (!utils::isa<CNodePtr>(node)) {
173 continue;
174 }
175 visited_nodes->insert(node);
176 auto cnode = utils::cast<CNodePtr>(node);
177 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed");
178 for (size_t i = 0; i < cnode->size(); i++) {
179 auto input = cnode->input(i);
180 if (visited_nodes->find(input) == visited_nodes->end()) {
181 q.push_back(input);
182 }
183 }
184 }
185
186 for (auto &node : node_list) {
187 if (visited_nodes->find(node) == visited_nodes->end()) {
188 remain_nodes->push_back(node);
189 }
190 }
191 visited_nodes->erase(*control_flow_node);
192
193 BindGetItemNodes(visited_nodes, remain_nodes);
194
195 return RET_SUCCESS;
196 }
197
CreateAfterGraph(const FuncGraphPtr & main_fg,const std::vector<AnfNodePtr> & remain_nodes,const CNodePtr & aim_cnode,FuncGraphPtr * after_fg)198 int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
199 const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) {
200 *after_fg = std::make_shared<FuncGraph>();
201 MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr");
202 auto manager = main_fg->manager();
203 MS_ASSERT(manager != nullptr);
204 manager->AddFuncGraph(*after_fg);
205 (*after_fg)->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
206 (*after_fg)->set_attr("graph_name", MakeValue(aim_cnode->fullname_with_scope() + "_after_fg"));
207 (*after_fg)->set_manager(main_fg->manager());
208
209 for (auto &cur_node : remain_nodes) {
210 if (cur_node->isa<ValueNode>()) {
211 continue;
212 }
213 if (cur_node == main_fg->get_return()) {
214 continue;
215 }
216 (*after_fg)->AddNode(cur_node);
217 if (!utils::isa<ValueNodePtr>(cur_node)) {
218 cur_node->set_func_graph(*after_fg);
219 }
220 if (cur_node == main_fg->output()) {
221 (*after_fg)->set_output(cur_node, false);
222 }
223 main_fg->DropNode(cur_node);
224 }
225 return RET_SUCCESS;
226 }
227
CreateWhileCondCallNode(const FuncGraphPtr & fg,const CNodePtr & while_cnode,const std::vector<AnfNodePtr> & visited_nodes_used_by_after_fg,CNodePtr * cond_call_cnode,std::vector<AnfNodePtr> * cond_nodes_used_by_after_partial,std::unordered_map<AnfNodePtr,AnfNodePtr> * visited_nodes_and_cond_fg_inputs_replace_pairs)228 int ControlFlowPass::CreateWhileCondCallNode(
229 const FuncGraphPtr &fg, const CNodePtr &while_cnode, const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
230 CNodePtr *cond_call_cnode, std::vector<AnfNodePtr> *cond_nodes_used_by_after_partial,
231 std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs) {
232 auto cond_vnode = while_cnode->input(kWhileCondIndex);
233 MS_CHECK_TRUE_MSG(cond_vnode != nullptr, lite::RET_NULL_PTR, "cnode is nullptr");
234 auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
235 if (cond_fg == nullptr) {
236 MS_LOG(ERROR) << "Get value as func graph failed.";
237 return RET_FAILED;
238 }
239
240 // create after partial node
241 ValueNodePtr cond_partial_anf_primitive = lite::GetPartialFusionPrim();
242 if (cond_partial_anf_primitive == nullptr) {
243 MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
244 return RET_FAILED;
245 }
246
247 std::vector<AnfNodePtr> cond_partial_cnode_inputs{cond_partial_anf_primitive, cond_vnode};
248 cond_partial_cnode_inputs.insert(cond_partial_cnode_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
249 while_cnode->inputs().end());
250
251 auto origin_cond_fg_inputs = cond_fg->get_inputs();
252 for (auto &item : visited_nodes_used_by_after_fg) {
253 bool found = false;
254 size_t input_index = 0;
255 for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
256 if (cond_partial_cnode_inputs[i] == item) {
257 found = true;
258 input_index = i - kPartialFirstInputSize;
259 break;
260 }
261 }
262
263 if (found) {
264 (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = origin_cond_fg_inputs.at(input_index);
265 cond_nodes_used_by_after_partial->push_back(origin_cond_fg_inputs.at(input_index));
266 continue;
267 }
268
269 // set after fg inputs to cond_partial_cnode inputs
270 cond_partial_cnode_inputs.push_back(item);
271 auto new_parameter = cond_fg->add_parameter();
272 MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
273 new_parameter->set_name(item->fullname_with_scope() + "_cond_fg_parameter");
274 new_parameter->set_abstract(item->abstract());
275 (*visited_nodes_and_cond_fg_inputs_replace_pairs)[item] = new_parameter;
276 cond_nodes_used_by_after_partial->push_back(new_parameter);
277 }
278
279 auto cond_partial_cnode = fg->NewCNode(cond_partial_cnode_inputs);
280 MS_CHECK_TRUE_MSG(cond_partial_cnode != nullptr, lite::RET_NULL_PTR, "cond_partial_cnode is nullptr");
281 cond_partial_cnode->set_fullname_with_scope("partial_" + cond_fg->get_attr("graph_name")->ToString());
282
283 // insert call node
284 std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
285 *cond_call_cnode = fg->NewCNode(call_node_inputs);
286 MS_CHECK_TRUE_MSG(*cond_call_cnode != nullptr, lite::RET_NULL_PTR, "new cnode is nullptr");
287 (*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
288
289 return RET_SUCCESS;
290 }
291
CreateWhileBodyPartialNode(const FuncGraphPtr & cond_fg,const CNodePtr & while_cnode,CNodePtr * body_partial_node)292 int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode,
293 CNodePtr *body_partial_node) {
294 auto body_vnode = while_cnode->input(kWhileBodyIndex);
295 MS_CHECK_TRUE_MSG(body_vnode != nullptr, RET_FAILED, "body_vnode is nullptr");
296 auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
297 if (body_fg == nullptr) {
298 MS_LOG(ERROR) << "Get value as func_graph failed.";
299 return RET_FAILED;
300 }
301
302 ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
303 if (partial_anf_primitive == nullptr) {
304 MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
305 return RET_FAILED;
306 }
307
308 std::vector<AnfNodePtr> body_partial_node_inputs{partial_anf_primitive, body_vnode};
309 // set body inputs to body partial inputs
310 auto cond_fg_inputs = cond_fg->get_inputs();
311 body_partial_node_inputs.insert(body_partial_node_inputs.end(), cond_fg_inputs.begin(), cond_fg_inputs.end());
312 *body_partial_node = cond_fg->NewCNode(body_partial_node_inputs);
313 MS_CHECK_TRUE_MSG(*body_partial_node != nullptr, RET_FAILED, "new cnode is nullptr");
314 (*body_partial_node)->set_fullname_with_scope("CNode_" + body_fg->get_attr("graph_name")->ToString());
315
316 // add after inputs for body fg to call cond fg
317 auto body_fg_inputs = body_fg->get_inputs();
318 auto origin_body_fg_inputs_size = body_fg_inputs.size();
319 for (size_t i = origin_body_fg_inputs_size; i < cond_fg_inputs.size(); ++i) {
320 if (!utils::isa<ParameterPtr>(cond_fg_inputs[i])) {
321 MS_LOG(ERROR) << "fg is not right.";
322 return RET_FAILED;
323 }
324 auto new_parameter = body_fg->add_parameter();
325 MS_CHECK_TRUE_MSG(new_parameter != nullptr, lite::RET_NULL_PTR, "new_parameter is nullptr");
326 new_parameter->set_name(cond_fg_inputs[i]->fullname_with_scope() + "_body_fg_parameter");
327 new_parameter->set_abstract(cond_fg_inputs[i]->abstract());
328 }
329
330 // call the cond fg
331 ValueNodePtr cond_partial_anf_primitive = lite::GetPartialFusionPrim();
332 if (cond_partial_anf_primitive == nullptr) {
333 MS_LOG(ERROR) << "`new cond_partial_anf_primitive failed.";
334 return RET_FAILED;
335 }
336 auto cond_partial_vnode = NewValueNode(cond_fg);
337 MS_CHECK_TRUE_MSG(cond_partial_vnode != nullptr, lite::RET_NULL_PTR, "cond_partial_vnode is nullptr");
338 std::vector<AnfNodePtr> cond_partial_inputs{cond_partial_anf_primitive, cond_partial_vnode};
339 // set body fg output
340 auto body_output = body_fg->output()->cast<CNodePtr>();
341 MS_ASSERT(body_output != nullptr);
342 if (CheckPrimitiveType(body_output, prim::kPrimMakeTuple)) {
343 for (size_t i = 1; i < body_output->size(); ++i) {
344 cond_partial_inputs.push_back(body_output->input(i));
345 }
346 body_fg->DropNode(body_output);
347 } else {
348 cond_partial_inputs.push_back(body_output);
349 }
350
351 body_fg_inputs = body_fg->get_inputs();
352 for (size_t i = origin_body_fg_inputs_size; i < body_fg_inputs.size(); ++i) {
353 cond_partial_inputs.push_back(body_fg_inputs[i]);
354 }
355
356 auto cond_partial_cnode = body_fg->NewCNode(cond_partial_inputs);
357 MS_CHECK_TRUE_MSG(cond_partial_cnode != nullptr, lite::RET_NULL_PTR, "cond_partial_cnode != nullptr");
358 cond_partial_cnode->set_fullname_with_scope(body_fg->get_attr("graph_name")->ToString() + "_call_cond_fg");
359
360 // insert call node
361 std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
362 auto cond_call_cnode = body_fg->NewCNode(call_node_inputs);
363 MS_CHECK_TRUE_MSG(cond_call_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
364 cond_call_cnode->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
365 body_fg->set_output(cond_call_cnode);
366
367 to_process_q.push_back(body_fg);
368 return RET_SUCCESS;
369 }
370
CreateWhileAfterPartialNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & cond_fg,const std::vector<AnfNodePtr> & remain_nodes,const std::vector<AnfNodePtr> & cond_nodes_used_by_after_partial,const std::unordered_map<AnfNodePtr,AnfNodePtr> & visited_nodes_and_cond_fg_inputs_replace_pairs,const CNodePtr * while_cnode,CNodePtr * after_partial_cnode)371 int ControlFlowPass::CreateWhileAfterPartialNode(
372 const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
373 const std::vector<AnfNodePtr> &cond_nodes_used_by_after_partial,
374 const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
375 const CNodePtr *while_cnode, CNodePtr *after_partial_cnode) {
376 // create after_fg
377 FuncGraphPtr after_fg = nullptr;
378 if (CreateAfterGraph(main_fg, remain_nodes, *while_cnode, &after_fg) != RET_SUCCESS) {
379 MS_LOG(ERROR) << "CreateAfterGraph failed.";
380 return RET_FAILED;
381 }
382
383 auto after_value_node = NewValueNode(after_fg);
384 MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "after_value_node is nullptr");
385 ValueNodePtr partial_anf_primitive = lite::GetPartialFusionPrim();
386 if (partial_anf_primitive == nullptr) {
387 MS_LOG(ERROR) << "GetPartialFusionPrim failed.";
388 return RET_FAILED;
389 }
390
391 std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_inputs_and_after_fg_inputs_replace_pairs{};
392 std::vector<AnfNodePtr> after_partial_cnode_inputs{partial_anf_primitive, after_value_node};
393 auto cond_fg_inputs = cond_fg->get_inputs();
394 for (const auto &node : after_fg->nodes()) {
395 if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
396 continue;
397 }
398 auto get_tuple_item_cnode = node->cast<CNodePtr>();
399 MS_ASSERT(get_tuple_item_cnode != nullptr);
400 MS_ASSERT(get_tuple_item_cnode->size() == kGetItemInputSize);
401 if (get_tuple_item_cnode->input(kCNodeFirstInputIndex) != *while_cnode) {
402 continue;
403 }
404 auto index_vnode = get_tuple_item_cnode->inputs().at(kCNodeSecondInputIndex);
405 if (!utils::isa<ValueNode>(index_vnode)) {
406 MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
407 return RET_FAILED;
408 }
409 auto value_node = utils::cast<ValueNodePtr>(index_vnode);
410 MS_ASSERT(value_node != nullptr);
411
412 auto input_index = value_node->value()->type()->number_type() == kNumberTypeInt64
413 ? GetValue<int64_t>(value_node->value())
414 : GetValue<int>(value_node->value());
415
416 after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index));
417 auto new_parameter = after_fg->add_parameter();
418 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
419 new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
420 new_parameter->set_abstract(node->abstract());
421 after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
422 }
423
424 std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_after_fg_replace_pair{};
425 for (auto &input : cond_nodes_used_by_after_partial) {
426 after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
427 auto new_parameter = after_fg->add_parameter();
428 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter != nullptr");
429 new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
430 new_parameter->set_abstract(input->abstract());
431 visited_nodes_after_fg_replace_pair[visited_nodes_and_cond_fg_inputs_replace_pairs.at(input)] = new_parameter;
432 }
433
434 ReplaceNode(after_fg, visited_nodes_and_cond_fg_inputs_replace_pairs);
435 ReplaceNode(after_fg, after_partial_inputs_and_after_fg_inputs_replace_pairs);
436 ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
437 *after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
438 MS_CHECK_TRUE_MSG(*after_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
439 (*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
440 return RET_SUCCESS;
441 }
442
ProcessWhileOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & while_node)443 int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
444 const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node) {
445 if (while_node == nullptr) {
446 MS_LOG(INFO) << "not found while, no need to process.";
447 return RET_SUCCESS;
448 }
449
450 auto while_cnode = while_node->cast<CNodePtr>();
451 MS_ASSERT(while_cnode != nullptr);
452 if (while_cnode->size() < kWhileMinInputSize) {
453 MS_LOG(ERROR) << "while input is not right.";
454 return RET_FAILED;
455 }
456
457 std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
458 VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
459
460 CNodePtr cond_call_cnode = nullptr;
461 std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_cond_fg_inputs_replace_pairs{};
462 std::vector<AnfNodePtr> cond_nodes_used_by_after_partial{};
463 int ret = CreateWhileCondCallNode(fg, while_cnode, visited_nodes_used_by_after_fg, &cond_call_cnode,
464 &cond_nodes_used_by_after_partial, &visited_nodes_and_cond_fg_inputs_replace_pairs);
465 if (ret != RET_SUCCESS) {
466 MS_LOG(ERROR) << "while create cond call cnode failed, ret: " << ret;
467 return ret;
468 }
469
470 auto cond_fg_cnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>();
471 MS_ASSERT(cond_fg_cnode != nullptr);
472 AnfNodePtr cond_fg_vnode = cond_fg_cnode->input(kCNodeFirstInputIndex);
473 MS_ASSERT(cond_fg_vnode != nullptr);
474 auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
475 MS_CHECK_TRUE_MSG(cond_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
476
477 CNodePtr body_partial_node = nullptr;
478 ret = CreateWhileBodyPartialNode(cond_fg, while_cnode, &body_partial_node);
479 if (ret != RET_SUCCESS) {
480 MS_LOG(ERROR) << "while create body partial cnode failed, ret: " << ret;
481 return ret;
482 }
483
484 CNodePtr after_partial_cnode = nullptr;
485 ret = CreateWhileAfterPartialNode(fg, cond_fg, remain_nodes, visited_nodes_used_by_after_fg,
486 visited_nodes_and_cond_fg_inputs_replace_pairs, &while_cnode, &after_partial_cnode);
487 if (ret != RET_SUCCESS) {
488 MS_LOG(ERROR) << "while create after partial cnode failed, ret: " << ret;
489 return ret;
490 }
491
492 // create switch cnode
493 ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
494 if (switch_anf_primitive == nullptr) {
495 MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
496 return lite::RET_ERROR;
497 }
498
499 // insert switch node
500 std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
501 after_partial_cnode};
502 auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
503 MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed");
504 switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
505
506 // insert call node
507 std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
508 auto call_node = cond_fg->NewCNode(call_node_inputs);
509 MS_CHECK_TRUE_MSG(call_node != nullptr, lite::RET_NULL_PTR, "call_node is nullptr");
510 call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
511 cond_fg->set_output(call_node);
512
513 fg->DropNode(while_cnode);
514 fg->set_output(cond_call_cnode);
515
516 auto after_cnode = after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>();
517 MS_ASSERT(after_cnode != nullptr);
518 auto after_fg = after_cnode->value()->cast<FuncGraphPtr>();
519 if (after_fg == nullptr) {
520 MS_LOG(ERROR) << "after_fg is nullptr.";
521 return RET_FAILED;
522 }
523 to_process_q.push_back(cond_fg);
524 to_process_q.push_back(after_fg);
525 return RET_SUCCESS;
526 }
527
CreateIfPartialNodeExternalInputs(const CNodePtr & if_cnode,const FuncGraphPtr & partial_fg,std::vector<AnfNodePtr> * then_partial_cnode_inputs)528 int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg,
529 std::vector<AnfNodePtr> *then_partial_cnode_inputs) {
530 auto if_inputs = if_cnode->inputs();
531 auto fg_name_attr = partial_fg->get_attr("graph_name");
532 MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
533 auto partial_fg_name = fg_name_attr->ToString();
534 std::vector<AnfNodePtr> if_external_inputs{};
535 if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end());
536 auto origin_then_fg_inputs = partial_fg->get_inputs();
537 if (if_external_inputs.size() < origin_then_fg_inputs.size()) {
538 MS_LOG(ERROR) << "graph is not right.";
539 return RET_FAILED;
540 } else if (if_external_inputs.size() == origin_then_fg_inputs.size()) {
541 then_partial_cnode_inputs->insert(then_partial_cnode_inputs->end(), if_external_inputs.begin(),
542 if_external_inputs.end());
543 return RET_SUCCESS;
544 } else {
545 for (auto &fg_input : origin_then_fg_inputs) {
546 auto fg_input_name = fg_input->fullname_with_scope();
547 auto pos = partial_fg_name.size() + sizeof("_input_");
548 auto pos2 = fg_input_name.find('_', pos);
549 auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1);
550 auto partial_idx = 0;
551 try {
552 partial_idx = std::stoi(idx_str);
553 } catch (const std::exception &e) {
554 MS_LOG(ERROR) << "Get index failed: " << e.what();
555 return RET_FAILED;
556 }
557 then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx));
558 }
559 }
560 return RET_SUCCESS;
561 }
562
CreateIfPartialNode(const FuncGraphPtr & fg,const size_t & index,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,const CNodePtr & if_cnode,const FuncGraphPtr & after_fg,CNodePtr * then_partial_cnode)563 int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
564 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
565 const CNodePtr &if_cnode, const FuncGraphPtr &after_fg,
566 CNodePtr *then_partial_cnode) {
567 auto then_vnode = if_cnode->input(index);
568 MS_ASSERT(then_vnode != nullptr);
569 auto then_fg = GetValueNode<std::shared_ptr<FuncGraph>>(then_vnode);
570 MS_CHECK_TRUE_MSG(then_fg != nullptr, RET_FAILED, "Get value as func_graph failed.");
571
572 // create then partial node
573 ValueNodePtr then_partial_anf_primitive = lite::GetPartialFusionPrim();
574 MS_CHECK_TRUE_MSG(then_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
575 std::vector<AnfNodePtr> then_partial_cnode_inputs{then_partial_anf_primitive, then_vnode};
576 if (CreateIfPartialNodeExternalInputs(if_cnode, then_fg, &then_partial_cnode_inputs) != RET_SUCCESS) {
577 MS_LOG(ERROR) << "CreateIfPartialNodeExternalInputs failed.";
578 return RET_FAILED;
579 }
580 std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_after_partial_inputs_replace_pairs{};
581 std::vector<AnfNodePtr> then_nodes_used_by_after_partial{};
582 // set fg inputs to then_partial_cnode inputs
583 auto origin_then_fg_inputs = then_fg->get_inputs();
584 for (auto &item : *visited_nodes_used_by_after_fg) {
585 bool found = false;
586 size_t input_index = 0;
587 for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
588 if (then_partial_cnode_inputs[i] == item) {
589 found = true;
590 input_index = i - kPartialFirstInputSize;
591 break;
592 }
593 }
594 if (found) {
595 visited_nodes_and_after_partial_inputs_replace_pairs[item] = origin_then_fg_inputs.at(input_index);
596 then_nodes_used_by_after_partial.push_back(origin_then_fg_inputs.at(input_index));
597 continue;
598 }
599
600 // set after fg inputs to cond_partial_cnode inputs
601 then_partial_cnode_inputs.push_back(item);
602 auto new_parameter = then_fg->add_parameter();
603 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "new_parameter is nullptr");
604 if (index == kIfThenIndex) {
605 new_parameter->set_name(item->fullname_with_scope() + "_then_fg_parameter");
606 } else {
607 new_parameter->set_name(item->fullname_with_scope() + "_else_fg_parameter");
608 }
609 new_parameter->set_abstract(item->abstract());
610 visited_nodes_and_after_partial_inputs_replace_pairs[item] = new_parameter;
611 then_nodes_used_by_after_partial.push_back(new_parameter);
612 }
613 *then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
614 MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr");
615 auto fg_name_attr = then_fg->get_attr("graph_name");
616 MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED);
617 auto then_fg_name = fg_name_attr->ToString();
618 (*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
619
620 // create after partial node
621 ValueNodePtr after_partial_anf_primitive = lite::GetPartialFusionPrim();
622 MS_CHECK_TRUE_MSG(after_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialFusionPrim failed.");
623 auto after_value_node = NewValueNode(after_fg);
624 MS_CHECK_TRUE_MSG(after_value_node != nullptr, RET_FAILED, "NewValueNode failed.");
625 // make the right after partial input
626 std::vector<AnfNodePtr> after_partial_cnode_inputs{after_partial_anf_primitive, after_value_node};
627 if (!CheckPrimitiveType(then_fg->output(), prim::kPrimMakeTuple)) {
628 after_partial_cnode_inputs.push_back(then_fg->output());
629 } else {
630 auto then_fg_output = then_fg->output()->cast<CNodePtr>();
631 MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed");
632 for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->size(); ++i) {
633 after_partial_cnode_inputs.push_back(then_fg_output->input(i));
634 }
635 then_fg->DropNode(then_fg_output);
636 }
637 size_t if_output_size = after_partial_cnode_inputs.size() - kCNodeSecondInputIndex;
638
639 // add after fg inputs to partial node
640 std::copy(then_nodes_used_by_after_partial.begin(), then_nodes_used_by_after_partial.end(),
641 std::back_inserter(after_partial_cnode_inputs));
642 // insert partial node
643 auto after_partial_cnode = then_fg->NewCNode(after_partial_cnode_inputs);
644 MS_CHECK_TRUE_MSG(after_partial_cnode != nullptr, RET_FAILED, "NewCNode failed");
645 auto after_fg_name = after_fg->get_attr("graph_name")->ToString();
646 after_partial_cnode->set_fullname_with_scope("partial_" + after_fg_name);
647
648 // insert call node
649 std::vector<AnfNodePtr> call_node_inputs{after_partial_cnode};
650 auto call_node = then_fg->NewCNode(call_node_inputs);
651 MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
652 call_node->set_fullname_with_scope("call_" + after_partial_cnode->fullname_with_scope());
653 then_fg->set_output(call_node);
654 to_process_q.push_back(then_fg);
655 ReplaceNode(after_fg, visited_nodes_and_after_partial_inputs_replace_pairs);
656
657 // check the inputs of after fg
658 auto after_fg_inputs_size = after_fg->get_inputs().size();
659 if (after_fg_inputs_size == after_partial_cnode_inputs.size() - kPartialFirstInputSize) {
660 return RET_SUCCESS;
661 }
662
663 // make the inputs of the after fg
664 std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_after_fg_replace_pairs{};
665 for (size_t i = kPartialFirstInputSize; i < after_partial_cnode_inputs.size(); ++i) {
666 auto &input = after_partial_cnode_inputs[i];
667 auto new_parameter = after_fg->add_parameter();
668 MS_CHECK_TRUE_MSG(new_parameter != nullptr, RET_FAILED, "add_parameter failed");
669 new_parameter->set_name(std::to_string(i - kPartialFirstInputSize) + "_" + input->fullname_with_scope());
670 new_parameter->set_abstract(input->abstract());
671 if (i < kPartialFirstInputSize + if_output_size) {
672 after_partial_after_fg_replace_pairs[if_cnode] = new_parameter;
673 } else {
674 after_partial_after_fg_replace_pairs[input] = new_parameter;
675 }
676 }
677 ReplaceNode(after_fg, after_partial_after_fg_replace_pairs);
678
679 return RET_SUCCESS;
680 }
681
CreateIfElsePartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,const CNodePtr & if_cnode,const FuncGraphPtr & after_fg,CNodePtr * else_partial_cnode)682 int ControlFlowPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
683 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
684 const CNodePtr &if_cnode, const FuncGraphPtr &after_fg,
685 CNodePtr *else_partial_cnode) {
686 return CreateIfPartialNode(main_fg, kIfElseIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
687 else_partial_cnode);
688 }
689
CreateIfThenPartialNode(const FuncGraphPtr & main_fg,std::vector<AnfNodePtr> * visited_nodes_used_by_after_fg,const CNodePtr & if_cnode,const FuncGraphPtr & after_fg,CNodePtr * then_partial_cnode)690 int ControlFlowPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
691 std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
692 const CNodePtr &if_cnode, const FuncGraphPtr &after_fg,
693 CNodePtr *then_partial_cnode) {
694 return CreateIfPartialNode(main_fg, kIfThenIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
695 then_partial_cnode);
696 }
697
ProcessIfOp(const FuncGraphPtr & fg,const std::set<AnfNodePtr> & visited_nodes,const std::vector<AnfNodePtr> & remain_nodes,const AnfNodePtr & if_node)698 int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
699 const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node) {
700 if (if_node == nullptr) {
701 MS_LOG(INFO) << "not found if, no need to process.";
702 return RET_SUCCESS;
703 }
704
705 auto if_cnode = if_node->cast<CNodePtr>();
706 MS_ASSERT(if_cnode != nullptr);
707 if (if_cnode->size() < kIfMinInputSize) {
708 MS_LOG(ERROR) << "if input is not right.";
709 return RET_FAILED;
710 }
711
712 // create after_fg
713 FuncGraphPtr after_fg = nullptr;
714 if (CreateAfterGraph(fg, remain_nodes, if_cnode, &after_fg) != RET_SUCCESS) {
715 MS_LOG(ERROR) << "CreateAfterGraph failed.";
716 return RET_FAILED;
717 }
718
719 // get fg input which is not used by after_parts
720 std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
721 VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
722
723 CNodePtr then_partial_cnode = nullptr;
724 int ret = CreateIfThenPartialNode(fg, &visited_nodes_used_by_after_fg, if_cnode, after_fg, &then_partial_cnode);
725 if (ret != RET_SUCCESS) {
726 MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret;
727 return ret;
728 }
729
730 CNodePtr else_partial_cnode = nullptr;
731 ret = CreateIfElsePartialNode(fg, &visited_nodes_used_by_after_fg, if_cnode, after_fg, &else_partial_cnode);
732 if (ret != RET_SUCCESS) {
733 MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret;
734 return ret;
735 }
736
737 // create switch cnode
738 ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim();
739 if (switch_anf_primitive == nullptr) {
740 MS_LOG(ERROR) << "GetSwitchAnfPrim failed.";
741 return RET_FAILED;
742 }
743
744 // insert switch node
745 std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), then_partial_cnode,
746 else_partial_cnode};
747 auto switch_cnode = fg->NewCNode(switch_node_inputs);
748 MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_FAILED, "NewCNode failed");
749 switch_cnode->set_fullname_with_scope("if-Switch-" + fg->get_attr("graph_name")->ToString());
750
751 // insert call node
752 std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
753 auto call_node = fg->NewCNode(call_node_inputs);
754 MS_CHECK_TRUE_MSG(call_node != nullptr, RET_FAILED, "NewCNode failed");
755 call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
756 fg->DropNode(if_cnode);
757 fg->set_output(call_node, true);
758
759 to_process_q.push_back(after_fg);
760 return RET_SUCCESS;
761 }
762
ProcessControlOp(const FuncGraphPtr & fg)763 int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
764 if (fg == nullptr) {
765 MS_LOG(ERROR) << "fg is nullptr.";
766 return RET_FAILED;
767 }
768
769 AnfNodePtr control_flow_node = nullptr;
770 std::vector<AnfNodePtr> remain_nodes{};
771 std::set<AnfNodePtr> visited_nodes{};
772 int ret = SplitGraph(fg, &control_flow_node, &visited_nodes, &remain_nodes);
773 if (ret != RET_SUCCESS) {
774 MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
775 return ret;
776 }
777
778 if (control_flow_node == nullptr) {
779 MS_LOG(INFO) << "not found control flow op, no need to process.";
780 return RET_SUCCESS;
781 }
782
783 if (CheckPrimitiveType(control_flow_node, prim::kPrimWhile)) {
784 ret = ProcessWhileOp(fg, visited_nodes, remain_nodes, control_flow_node);
785 if (ret != RET_SUCCESS) {
786 MS_LOG(ERROR) << "ProcessWhileOp failed.";
787 return ret;
788 }
789 }
790
791 if (CheckPrimitiveType(control_flow_node, prim::kPrimIf)) {
792 ret = ProcessIfOp(fg, visited_nodes, remain_nodes, control_flow_node);
793 if (ret != RET_SUCCESS) {
794 MS_LOG(ERROR) << "ProcessIfOp failed.";
795 return ret;
796 }
797 }
798 return RET_SUCCESS;
799 }
800
Run(const FuncGraphPtr & fg)801 bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
802 MS_ASSERT(fg != nullptr);
803 to_process_q.push_back(fg);
804 while (!to_process_q.empty()) {
805 auto cur_fg = to_process_q.front();
806 auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString();
807 int ret = ProcessControlOp(cur_fg);
808 if (ret != RET_SUCCESS) {
809 MS_LOG(ERROR) << "ProcessControlOp for graph: " << cur_fg_name << " failed.";
810 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
811 return false;
812 }
813 to_process_q.pop_front();
814 }
815 return true;
816 }
817 } // namespace mindspore::opt
818