1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/common/func_graph_subgraph.h"
19 #include <set>
20 #include <string>
21 #include <vector>
22 #include <map>
23 #include <queue>
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "src/common/log_adapter.h"
26 #include "tools/common/node_util.h"
27 #include "tools/common/graph_util.h"
28 #include "tools/optimizer/common/gllo_utils.h"
29 #include "ops/fusion/partial_fusion.h"
30 #include "nnacl/op_base.h"
31
32 namespace mindspore::lite {
Init(const std::set<CNodePtr> & head_nodes)33 int SubGraph::Init(const std::set<CNodePtr> &head_nodes) {
34 auto ret = InitSubGraphNode(head_nodes);
35 if (ret != RET_OK) {
36 MS_LOG(ERROR) << "InitSubGraphNode failed";
37 return RET_ERROR;
38 }
39 ret = InitSubGraphInNode();
40 if (ret != RET_OK) {
41 MS_LOG(ERROR) << "InitSubGraphInNode failed";
42 return RET_ERROR;
43 }
44 ret = InitSubGraphOutNode();
45 if (ret != RET_OK) {
46 MS_LOG(ERROR) << "InitSubGraphOutNode failed";
47 return RET_ERROR;
48 }
49 return RET_OK;
50 }
51
Reset(const std::set<CNodePtr> & nodes,const std::set<CNodePtr> & head_nodes)52 int SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
53 this->nodes_ = nodes;
54 auto ret = InitSubGraphNode(head_nodes);
55 if (ret != RET_OK) {
56 MS_LOG(ERROR) << "InitSubGraphNode failed";
57 return RET_ERROR;
58 }
59 ret = InitSubGraphInNode();
60 if (ret != RET_OK) {
61 MS_LOG(ERROR) << "InitSubGraphInNode failed";
62 return RET_ERROR;
63 }
64 ret = InitSubGraphOutNode();
65 if (ret != RET_OK) {
66 MS_LOG(ERROR) << "InitSubGraphOutNode failed";
67 return RET_ERROR;
68 }
69 return RET_OK;
70 }
71
GetNodes() const72 std::set<CNodePtr> SubGraph::GetNodes() const { return this->nodes_; }
73
GetInCNodes() const74 std::set<CNodePtr> SubGraph::GetInCNodes() const { return this->in_nodes_; }
75
GetInputCNodes() const76 std::set<CNodePtr> SubGraph::GetInputCNodes() const {
77 std::set<CNodePtr> inputs;
78 for (const auto &in_node : in_nodes_) {
79 if (in_node == nullptr) {
80 continue;
81 }
82 auto input_cnodes = GetInputCNode(in_node);
83 inputs.insert(input_cnodes.begin(), input_cnodes.end());
84 }
85 return inputs;
86 }
87
GetOutCNodes() const88 std::set<CNodePtr> SubGraph::GetOutCNodes() const { return this->out_nodes_; }
89
FindCommonOutputs(const SubGraphPtr & subgraph) const90 std::set<CNodePtr> SubGraph::FindCommonOutputs(const SubGraphPtr &subgraph) const {
91 if (subgraph == nullptr) {
92 return {};
93 }
94 std::set<CNodePtr> outputs_this = this->GetOutputCNodes();
95 if (this == subgraph.get()) {
96 return outputs_this;
97 }
98 std::set<CNodePtr> outputs_other = subgraph->GetOutputCNodes();
99 std::set<CNodePtr> common_outputs;
100 for (const auto &output1 : outputs_this) {
101 if (output1 == nullptr) {
102 continue;
103 }
104 auto iter = outputs_other.find(output1);
105 if (iter == outputs_other.end()) {
106 continue;
107 }
108 if (GetInputCNode(output1).size() == 2) {
109 common_outputs.insert(output1);
110 }
111 }
112 return common_outputs;
113 }
114
IfDependOnSameNode(const SubGraphPtr & subgraph) const115 bool SubGraph::IfDependOnSameNode(const SubGraphPtr &subgraph) const {
116 if (subgraph == nullptr || this == subgraph.get()) {
117 return false;
118 }
119 std::set<CNodePtr> inputs_this = this->GetInputCNodes();
120 std::set<CNodePtr> inputs_other = subgraph->GetInputCNodes();
121 return std::any_of(inputs_this.begin(), inputs_this.end(), [&inputs_other](const CNodePtr &input_this) {
122 if (input_this == nullptr) {
123 return false;
124 }
125 return (inputs_other.count(input_this) > 0);
126 });
127 }
128
GetOutputCNodes() const129 std::set<CNodePtr> SubGraph::GetOutputCNodes() const {
130 MS_ASSERT(belong_anf_ != nullptr);
131 MS_ASSERT(belong_anf_->manager() != nullptr);
132 auto node_users = belong_anf_->manager()->node_users();
133 std::set<CNodePtr> outputs;
134 for (const auto &out_node : out_nodes_) {
135 if (out_node == nullptr) {
136 continue;
137 }
138 auto iter = node_users.find(out_node);
139 if (iter == node_users.end()) {
140 continue;
141 }
142 auto post_node_pairs = iter->second;
143 for (const auto &post_node_pair : post_node_pairs) {
144 auto post_node = post_node_pair.first;
145 if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
146 continue;
147 }
148 outputs.insert(utils::cast<CNodePtr>(post_node));
149 }
150 }
151 return outputs;
152 }
153
InitSubGraphNode(const std::set<CNodePtr> & head_nodes)154 int SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
155 MS_ASSERT(belong_anf_ != nullptr);
156 MS_ASSERT(belong_anf_->manager() != nullptr);
157 auto node_users = belong_anf_->manager()->node_users();
158 std::queue<CNodePtr> q{};
159 for (const auto &head_node : head_nodes) {
160 if (head_node == nullptr) {
161 continue;
162 }
163 q.push(head_node);
164 }
165 while (!q.empty()) {
166 auto cur_node = q.front();
167 MS_CHECK_TRUE_MSG(cur_node != nullptr, RET_NULL_PTR, "cur_node is nullptr");
168 q.pop();
169 this->nodes_.insert(cur_node);
170 // check output-cnode of cur-node only depend on cur-node
171 auto iter = node_users.find(cur_node);
172 if (iter == node_users.end()) {
173 continue;
174 }
175 auto post_node_pairs = iter->second;
176 for (const auto &post_node_pair : post_node_pairs) {
177 auto post_node = post_node_pair.first;
178 if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
179 continue;
180 }
181 auto post_cnode = utils::cast<CNodePtr>(post_node);
182 MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "cast failed");
183 // return-node should not be include into subgraph absolutely // ut
184 if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) {
185 continue;
186 }
187 MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr");
188 bool non_depend = true;
189 // check all inputs of output-cnode
190 for (const auto &input : post_cnode->inputs()) {
191 if (input == nullptr) {
192 continue;
193 }
194 // input cnode is not contained in subgraph
195 if (utils::isa<CNodePtr>(input)) {
196 auto input_cnode = utils::cast<CNodePtr>(input);
197 MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
198 if (this->nodes_.count(input_cnode) == 0) {
199 non_depend = false;
200 break;
201 }
202 }
203 // input parameter is a graph input
204 if (utils::isa<ParameterPtr>(input)) {
205 auto input_parameter = utils::cast<ParameterPtr>(input);
206 MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_NULL_PTR, "cast failed");
207 if (!input_parameter->has_default()) {
208 non_depend = false;
209 break;
210 }
211 }
212 }
213 if (non_depend) {
214 q.push(post_cnode);
215 }
216 }
217 }
218 return RET_OK;
219 }
220
InitSubGraphInNode()221 int SubGraph::InitSubGraphInNode() {
222 MS_ASSERT(belong_anf_ != nullptr);
223 MS_ASSERT(belong_anf_->manager() != nullptr);
224 auto node_users = belong_anf_->manager()->node_users();
225 this->in_nodes_.clear();
226 for (const auto &node : this->nodes_) {
227 if (node == nullptr) {
228 continue;
229 }
230 if (std::any_of(node->inputs().begin(), node->inputs().end(), [this, &node_users](const auto &input) {
231 if (input == nullptr) {
232 return false;
233 }
234 if (utils::isa<CNodePtr>(input)) {
235 auto input_cnode = utils::cast<CNodePtr>(input);
236 MS_CHECK_TRUE_MSG(input_cnode != nullptr, false, "cast failed");
237 if (this->nodes_.count(input_cnode) == 0) {
238 return true;
239 }
240 }
241 // graph input or shared weight input // ut
242 if (utils::isa<ParameterPtr>(input)) {
243 auto input_parameter = utils::cast<ParameterPtr>(input);
244 MS_CHECK_TRUE_MSG(input_parameter != nullptr, false, "cast failed");
245 if (!input_parameter->has_default()) {
246 return true;
247 }
248 auto output_pair_iter = node_users.find(input);
249 if (output_pair_iter != node_users.end() && output_pair_iter->second.size() > 1) {
250 return true;
251 }
252 }
253 return false;
254 })) {
255 in_nodes_.insert(node);
256 }
257 }
258 return RET_OK;
259 }
260
InitSubGraphOutNode()261 int SubGraph::InitSubGraphOutNode() {
262 MS_ASSERT(belong_anf_ != nullptr);
263 MS_ASSERT(belong_anf_->manager() != nullptr);
264 auto node_users = belong_anf_->manager()->node_users();
265 this->out_nodes_.clear();
266 for (const auto &node : this->nodes_) {
267 if (node == nullptr) {
268 continue;
269 }
270 auto node_users_iter = node_users.find(node);
271 if (node_users_iter == node_users.end()) {
272 continue;
273 }
274 auto node_output_pairs = node_users_iter->second;
275 if (!std::any_of(node_output_pairs.begin(), node_output_pairs.end(),
276 [this](const std::pair<AnfNodePtr, int> &output_pair) {
277 auto output_node = output_pair.first;
278 if (output_node == nullptr || !utils::isa<CNodePtr>(output_node)) {
279 return false;
280 }
281 // graph output // ut
282 if (opt::CheckPrimitiveType(output_node, prim::kPrimReturn)) {
283 return true;
284 }
285 auto output_cnode = utils::cast<CNodePtr>(output_node);
286 MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "cast failed");
287 if (this->nodes_.count(output_cnode) == 0) {
288 return true;
289 }
290 return false;
291 }))
292 continue;
293 out_nodes_.insert(node);
294 }
295 return RET_OK;
296 }
297
MergeSubGraph(const SubGraphPtr & subgraph)298 bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
299 if (subgraph == nullptr || this == subgraph.get()) {
300 return false;
301 }
302 // if two subgraph has same output, and this output node has only two input cnode which exactly from two
303 // subgraph, we merge two subgraph, and find more post node
304 auto common_outputs = this->FindCommonOutputs(subgraph);
305 if (!common_outputs.empty()) {
306 auto new_nodes = this->GetNodes();
307 auto new_nodes2 = subgraph->GetNodes();
308 new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
309 new_nodes.insert(common_outputs.begin(), common_outputs.end());
310 if (this->Reset(new_nodes, common_outputs) != RET_OK) {
311 MS_LOG(ERROR) << "Reset failed";
312 return false;
313 }
314 return true;
315 }
316
317 if (this->IfDependOnSameNode(subgraph)) {
318 auto new_nodes = this->GetNodes();
319 auto new_nodes2 = subgraph->GetNodes();
320 new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
321 if (this->Reset(new_nodes) != RET_OK) {
322 MS_LOG(ERROR) << "Reset failed";
323 return false;
324 }
325 return true;
326 }
327 return false;
328 }
329
330 // iterate node from in_nodes of current subgraph up to input of belong_anf
FindBeforeSubGraphInBelongAnf() const331 SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
332 MS_ASSERT(belong_anf_ == nullptr);
333 // find before subgraph's nodes
334 std::queue<CNodePtr> q{};
335 std::set<CNodePtr> before_nodes{};
336 for (const auto &node : this->GetInCNodes()) {
337 for (const auto &in_cnode : lite::GetInputCNode(node)) {
338 if (in_cnode == nullptr) {
339 continue;
340 }
341 q.push(in_cnode);
342 }
343 }
344 while (!q.empty()) {
345 auto cur_cnode = q.front();
346 MS_CHECK_TRUE_MSG(cur_cnode != nullptr, nullptr, "cur_cnode is nullptr");
347 q.pop();
348 before_nodes.insert(cur_cnode);
349 for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
350 q.push(in_cnode);
351 }
352 }
353 // construct before subgraph
354 auto before_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/before_subgraph");
355 MS_CHECK_TRUE_MSG(before_subgraph != nullptr, nullptr, "before_subgraph is nullptr");
356 if (before_subgraph->Reset(before_nodes) != RET_OK) {
357 MS_LOG(ERROR) << "Reset failed";
358 return nullptr;
359 }
360 return before_subgraph;
361 }
362
363 // iterate node from output of belong_anf up to out_nodes of current subgraph and before subgraph
FindAfterSubGraphInBelongAnf() const364 SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
365 MS_ASSERT(belong_anf_ == nullptr);
366 // find before subgraph
367 auto before_subgraph = this->FindBeforeSubGraphInBelongAnf();
368 if (before_subgraph == nullptr) {
369 MS_LOG(ERROR) << "Find before subgraph failed";
370 return nullptr;
371 }
372 // find after subgraph's nodes
373 std::queue<CNodePtr> q{};
374 std::set<CNodePtr> after_nodes{};
375 auto output_node = belong_anf_->output();
376 if (!utils::isa<CNodePtr>(output_node)) {
377 MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope();
378 return nullptr;
379 }
380 q.push(utils::cast<CNodePtr>(output_node));
381 auto subgraph_out_nodes = this->GetOutCNodes();
382 auto before_out_nodes = before_subgraph->GetOutCNodes();
383 while (!q.empty()) {
384 auto cur_cnode = q.front();
385 MS_CHECK_TRUE_MSG(cur_cnode != nullptr, nullptr, "cur_cnode is nullptr");
386 q.pop();
387 after_nodes.insert(cur_cnode);
388 for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
389 if (subgraph_out_nodes.count(in_cnode) == 0 && before_out_nodes.count(in_cnode) == 0) {
390 q.push(in_cnode);
391 }
392 }
393 }
394 // construct before subgraph
395 auto after_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/after_subgraph");
396 MS_CHECK_TRUE_MSG(after_subgraph != nullptr, nullptr, "after_subgraph is nullptr");
397 if (after_subgraph->Reset(after_nodes) != RET_OK) {
398 MS_LOG(ERROR) << "Reset failed";
399 return nullptr;
400 }
401
402 return after_subgraph;
403 }
404
CreatePartialInBelongAnf()405 int SubGraph::CreatePartialInBelongAnf() {
406 MS_ASSERT(this->belong_anf_ != nullptr);
407 MS_ASSERT(this->belong_anf_->manager() != nullptr);
408 // determine func_graph name
409 std::string graph_name = this->name_;
410 if (graph_name.empty()) {
411 if (this->nodes_.empty()) {
412 graph_name = "subgraph";
413 } else {
414 graph_name = (*(this->nodes_.begin()))->fullname_with_scope() + "/subgraph";
415 }
416 }
417 // create func_graph of partial
418 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
419 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
420 auto manager = belong_anf_->manager();
421 manager->AddFuncGraph(func_graph);
422 func_graph->set_attr("graph_name", MakeValue(graph_name));
423 func_graph->set_manager(manager);
424 // create cnode and parameter for func_graph of partial
425 std::vector<AnfNodePtr> partial_inputs;
426 std::map<AnfNodePtr, AnfNodePtr> partial_inputs_and_subgraph_input_map;
427 auto ret = CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
428 if (ret != RET_OK) {
429 MS_LOG(DEBUG) << "CreateParameterForPartialSubGraph failed";
430 return ret;
431 }
432 ret = CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
433 if (ret != RET_OK) {
434 MS_LOG(DEBUG) << "CreateCNodeForPartialSubGraph failed";
435 return ret;
436 }
437 // add return for func_graph of partial
438 auto sub_graph_outputs = this->GetOutCNodes();
439 MS_ASSERT(!sub_graph_outputs.empty());
440 ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
441 if (ret != RET_OK) {
442 MS_LOG(DEBUG) << "Set subgraph output failed";
443 return ret;
444 }
445 // create partial cnode
446 auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
447 auto graph_value_node = NewValueNode(func_graph);
448 MS_CHECK_TRUE_MSG(partial_prim != nullptr, RET_NULL_PTR, "partial_prim is nullptr");
449 MS_CHECK_TRUE_MSG(graph_value_node != nullptr, RET_NULL_PTR, "graph_value_node is nullptr");
450 auto partial_prim_c = partial_prim->GetPrim();
451 MS_CHECK_TRUE_MSG(partial_prim_c != nullptr, RET_NULL_PTR, "partial_prim_c is nullptr");
452 partial_inputs.insert(partial_inputs.begin(), graph_value_node);
453 auto partial_cnode = belong_anf_->NewCNode(partial_prim_c, partial_inputs);
454 MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr");
455 partial_cnode->set_fullname_with_scope(graph_name + "/partial");
456 for (size_t i = 0; i < partial_inputs.size(); ++i) {
457 const auto &input = partial_inputs.at(i);
458 manager->SetEdge(partial_cnode, static_cast<int>(i + 1), input);
459 }
460 // create call cnode
461 std::vector<AnfNodePtr> call_node_inputs{partial_cnode};
462 auto call_cnode = belong_anf_->NewCNode(call_node_inputs);
463 MS_CHECK_TRUE_MSG(call_cnode != nullptr, RET_NULL_PTR, "call_cnode is nullptr");
464 call_cnode->set_fullname_with_scope(graph_name + "/call");
465 // replace belong-graph's output
466 auto return_node = belong_anf_->get_return();
467 // return node should has 2 inputs
468 MS_ASSERT(return_node != nullptr && return_node->size() == 2);
469 auto ori_output = return_node->inputs().at(1);
470 manager->Replace(ori_output, call_cnode);
471 return RET_OK;
472 }
473
SetFuncGraphOutput(const FuncGraphPtr & graph,const std::set<CNodePtr> & outputs)474 int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs) {
475 std::vector<AnfNodePtr> output_nodes;
476 output_nodes.insert(output_nodes.end(), outputs.begin(), outputs.end());
477 return lite::SetFuncGraphOutput(graph, output_nodes);
478 }
479
CreateParameterForPartialSubGraph(const FuncGraphPtr & sub_graph,std::vector<AnfNodePtr> * partial_inputs,std::map<AnfNodePtr,AnfNodePtr> * partial_inputs_and_subgraph_input_map)480 int SubGraph::CreateParameterForPartialSubGraph(
481 const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
482 std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map) {
483 MS_ASSERT(sub_graph != nullptr);
484 MS_ASSERT(partial_inputs != nullptr && partial_inputs->empty());
485 MS_ASSERT(partial_inputs_and_subgraph_input_map != nullptr && partial_inputs_and_subgraph_input_map->empty());
486 MS_CHECK_TRUE_MSG(sub_graph->get_attr("graph_name") != nullptr, RET_ERROR, "graph_name is nullptr");
487 std::string graph_name = sub_graph->get_attr("graph_name")->ToString();
488 for (const auto &in_cnode : this->GetInCNodes()) {
489 if (in_cnode == nullptr) {
490 continue;
491 }
492 for (size_t i = 1; i < in_cnode->size(); i++) {
493 auto input = in_cnode->input(i);
494 if (input == nullptr) {
495 continue;
496 }
497 auto iter = partial_inputs_and_subgraph_input_map->find(input);
498 if (iter != partial_inputs_and_subgraph_input_map->end()) {
499 continue;
500 }
501 // create subgraph input parameter from cnode and record partial inputs
502 if (utils::isa<CNodePtr>(input)) {
503 auto input_cnode = utils::cast<CNodePtr>(input);
504 MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
505 if (this->GetNodes().count(input_cnode) > 0) {
506 continue;
507 }
508 partial_inputs->emplace_back(input);
509 auto new_parameter = sub_graph->add_parameter();
510 new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
511 new_parameter->set_abstract(input->abstract());
512 (*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
513 }
514 // create subgraph input parameter from parameter and record partial inputs
515 // add parameter to func_graph
516 auto node_users = this->belong_anf_->manager()->node_users();
517 if (utils::isa<ParameterPtr>(input)) {
518 auto parameter = utils::cast<ParameterPtr>(input);
519 MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "cast ptr failed");
520 // graph input: create a parameter
521 if (!parameter->has_default()) {
522 auto new_parameter = sub_graph->add_parameter();
523 new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
524 new_parameter->set_abstract(input->abstract());
525 (*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
526 partial_inputs->emplace_back(new_parameter);
527 }
528 // weight parameter, it depends
529 auto output_pairs_iter = node_users.find(input);
530 if (output_pairs_iter != node_users.end() &&
531 output_pairs_iter->second.size() > 1) { // shared weight: create a parameter
532 auto new_parameter = sub_graph->add_parameter();
533 new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
534 new_parameter->set_abstract(input->abstract());
535 (*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
536 partial_inputs->emplace_back(new_parameter);
537 } else { // not shared weight: move into subgraph
538 sub_graph->AddNode(input);
539 input->set_func_graph(sub_graph);
540 this->belong_anf_->DropNode(input);
541 }
542 }
543 }
544 }
545 return RET_OK;
546 }
547
CreateCNodeForPartialSubGraph(const FuncGraphPtr & sub_graph,const std::map<AnfNodePtr,AnfNodePtr> & partial_inputs_and_subgraph_input_map)548 int SubGraph::CreateCNodeForPartialSubGraph(
549 const FuncGraphPtr &sub_graph, const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map) {
550 MS_ASSERT(sub_graph != nullptr);
551 // move cnode from belong_graph to subgraph
552 for (auto &node : this->GetNodes()) {
553 sub_graph->AddNode(node);
554 if (!utils::isa<ValueNodePtr>(node)) {
555 node->set_func_graph(sub_graph);
556 }
557 for (size_t i = 0; i < node->size(); i++) {
558 auto input = node->inputs().at(i);
559 if (input == nullptr) {
560 continue;
561 }
562 auto iter = partial_inputs_and_subgraph_input_map.find(input);
563 if (iter == partial_inputs_and_subgraph_input_map.end()) {
564 continue;
565 }
566 // use SetEdge not set_input, if not, node_user is not updated.
567 this->belong_anf_->manager()->SetEdge(node, static_cast<int>(i), iter->second);
568 }
569 this->belong_anf_->DropNode(node);
570 }
571 return RET_OK;
572 }
573
ApplySubGraph()574 int SubGraph::ApplySubGraph() {
575 // check
576 if (this->nodes_.empty()) {
577 return lite::RET_NO_CHANGE;
578 }
579 if (belong_anf_ == nullptr || belong_anf_->manager() == nullptr) {
580 MS_LOG(DEBUG) << "belong_anf_ or manager is nullptr";
581 return lite::RET_NO_CHANGE;
582 }
583 for (const auto &node : this->nodes_) {
584 if (node == nullptr) {
585 continue;
586 }
587 if (node->func_graph() != belong_anf_) {
588 MS_LOG(DEBUG) << "subgraph nodes belong to different func_graph";
589 return lite::RET_ERROR;
590 }
591 }
592
593 // create after partial // redirect input of after subgraph
594 auto after_subgraph = this->FindAfterSubGraphInBelongAnf();
595 if (after_subgraph == nullptr) {
596 MS_LOG(DEBUG) << "Create after subgraph failed";
597 return RET_ERROR;
598 }
599 auto ret = after_subgraph->CreatePartialInBelongAnf();
600 if (ret != RET_OK) {
601 MS_LOG(DEBUG) << "Create after partial failed";
602 return RET_ERROR;
603 }
604 // merge after partial into subgraph
605 auto subgraph_nodes = this->nodes_;
606 auto return_node = belong_anf_->get_return();
607 MS_ASSERT(return_node != nullptr && return_node->size() == 2);
608 auto call_node = return_node->inputs().at(1);
609 MS_ASSERT(call_node != nullptr && utils::isa<CNodePtr>(call_node));
610 auto call_cnode = utils::cast<CNodePtr>(call_node);
611 MS_ASSERT(call_cnode != nullptr && call_cnode->size() == 1);
612 auto after_partial_node = call_cnode->inputs().at(0);
613 MS_ASSERT(after_partial_node != nullptr && utils::isa<CNodePtr>(after_partial));
614 auto after_partial_cnode = utils::cast<CNodePtr>(after_partial_node);
615 MS_ASSERT(after_partial_cnode != nullptr);
616 subgraph_nodes.insert(after_partial_cnode);
617 subgraph_nodes.insert(call_cnode);
618 if (this->Reset(subgraph_nodes) != RET_OK) {
619 MS_LOG(ERROR) << "Reset failed";
620 return RET_ERROR;
621 }
622 // create subgraph partial // add partial to main subgraph
623 ret = this->CreatePartialInBelongAnf();
624 if (ret != RET_OK) {
625 MS_LOG(DEBUG) << "Create partial failed";
626 return RET_ERROR;
627 }
628 return RET_OK;
629 }
630 } // namespace mindspore::lite
631