• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
#include "tools/common/func_graph_subgraph.h"
#include <algorithm>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>
#include "mindspore/core/ops/framework_ops.h"
#include "src/common/log_adapter.h"
#include "tools/common/node_util.h"
#include "tools/common/graph_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ops/fusion/partial_fusion.h"
#include "nnacl/op_base.h"
31 
32 namespace mindspore::lite {
Init(const std::set<CNodePtr> & head_nodes)33 int SubGraph::Init(const std::set<CNodePtr> &head_nodes) {
34   auto ret = InitSubGraphNode(head_nodes);
35   if (ret != RET_OK) {
36     MS_LOG(ERROR) << "InitSubGraphNode failed";
37     return RET_ERROR;
38   }
39   ret = InitSubGraphInNode();
40   if (ret != RET_OK) {
41     MS_LOG(ERROR) << "InitSubGraphInNode failed";
42     return RET_ERROR;
43   }
44   ret = InitSubGraphOutNode();
45   if (ret != RET_OK) {
46     MS_LOG(ERROR) << "InitSubGraphOutNode failed";
47     return RET_ERROR;
48   }
49   return RET_OK;
50 }
51 
Reset(const std::set<CNodePtr> & nodes,const std::set<CNodePtr> & head_nodes)52 int SubGraph::Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes) {
53   this->nodes_ = nodes;
54   auto ret = InitSubGraphNode(head_nodes);
55   if (ret != RET_OK) {
56     MS_LOG(ERROR) << "InitSubGraphNode failed";
57     return RET_ERROR;
58   }
59   ret = InitSubGraphInNode();
60   if (ret != RET_OK) {
61     MS_LOG(ERROR) << "InitSubGraphInNode failed";
62     return RET_ERROR;
63   }
64   ret = InitSubGraphOutNode();
65   if (ret != RET_OK) {
66     MS_LOG(ERROR) << "InitSubGraphOutNode failed";
67     return RET_ERROR;
68   }
69   return RET_OK;
70 }
71 
GetNodes() const72 std::set<CNodePtr> SubGraph::GetNodes() const { return this->nodes_; }
73 
GetInCNodes() const74 std::set<CNodePtr> SubGraph::GetInCNodes() const { return this->in_nodes_; }
75 
GetInputCNodes() const76 std::set<CNodePtr> SubGraph::GetInputCNodes() const {
77   std::set<CNodePtr> inputs;
78   for (const auto &in_node : in_nodes_) {
79     if (in_node == nullptr) {
80       continue;
81     }
82     auto input_cnodes = GetInputCNode(in_node);
83     inputs.insert(input_cnodes.begin(), input_cnodes.end());
84   }
85   return inputs;
86 }
87 
GetOutCNodes() const88 std::set<CNodePtr> SubGraph::GetOutCNodes() const { return this->out_nodes_; }
89 
FindCommonOutputs(const SubGraphPtr & subgraph) const90 std::set<CNodePtr> SubGraph::FindCommonOutputs(const SubGraphPtr &subgraph) const {
91   if (subgraph == nullptr) {
92     return {};
93   }
94   std::set<CNodePtr> outputs_this = this->GetOutputCNodes();
95   if (this == subgraph.get()) {
96     return outputs_this;
97   }
98   std::set<CNodePtr> outputs_other = subgraph->GetOutputCNodes();
99   std::set<CNodePtr> common_outputs;
100   for (const auto &output1 : outputs_this) {
101     if (output1 == nullptr) {
102       continue;
103     }
104     auto iter = outputs_other.find(output1);
105     if (iter == outputs_other.end()) {
106       continue;
107     }
108     if (GetInputCNode(output1).size() == 2) {
109       common_outputs.insert(output1);
110     }
111   }
112   return common_outputs;
113 }
114 
IfDependOnSameNode(const SubGraphPtr & subgraph) const115 bool SubGraph::IfDependOnSameNode(const SubGraphPtr &subgraph) const {
116   if (subgraph == nullptr || this == subgraph.get()) {
117     return false;
118   }
119   std::set<CNodePtr> inputs_this = this->GetInputCNodes();
120   std::set<CNodePtr> inputs_other = subgraph->GetInputCNodes();
121   return std::any_of(inputs_this.begin(), inputs_this.end(), [&inputs_other](const CNodePtr &input_this) {
122     if (input_this == nullptr) {
123       return false;
124     }
125     return (inputs_other.count(input_this) > 0);
126   });
127 }
128 
GetOutputCNodes() const129 std::set<CNodePtr> SubGraph::GetOutputCNodes() const {
130   MS_ASSERT(belong_anf_ != nullptr);
131   MS_ASSERT(belong_anf_->manager() != nullptr);
132   auto node_users = belong_anf_->manager()->node_users();
133   std::set<CNodePtr> outputs;
134   for (const auto &out_node : out_nodes_) {
135     if (out_node == nullptr) {
136       continue;
137     }
138     auto iter = node_users.find(out_node);
139     if (iter == node_users.end()) {
140       continue;
141     }
142     auto post_node_pairs = iter->second;
143     for (const auto &post_node_pair : post_node_pairs) {
144       auto post_node = post_node_pair.first;
145       if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
146         continue;
147       }
148       outputs.insert(utils::cast<CNodePtr>(post_node));
149     }
150   }
151   return outputs;
152 }
153 
InitSubGraphNode(const std::set<CNodePtr> & head_nodes)154 int SubGraph::InitSubGraphNode(const std::set<CNodePtr> &head_nodes) {
155   MS_ASSERT(belong_anf_ != nullptr);
156   MS_ASSERT(belong_anf_->manager() != nullptr);
157   auto node_users = belong_anf_->manager()->node_users();
158   std::queue<CNodePtr> q{};
159   for (const auto &head_node : head_nodes) {
160     if (head_node == nullptr) {
161       continue;
162     }
163     q.push(head_node);
164   }
165   while (!q.empty()) {
166     auto cur_node = q.front();
167     MS_CHECK_TRUE_MSG(cur_node != nullptr, RET_NULL_PTR, "cur_node is nullptr");
168     q.pop();
169     this->nodes_.insert(cur_node);
170     // check output-cnode of cur-node only depend on cur-node
171     auto iter = node_users.find(cur_node);
172     if (iter == node_users.end()) {
173       continue;
174     }
175     auto post_node_pairs = iter->second;
176     for (const auto &post_node_pair : post_node_pairs) {
177       auto post_node = post_node_pair.first;
178       if (post_node == nullptr || !utils::isa<CNodePtr>(post_node)) {
179         continue;
180       }
181       auto post_cnode = utils::cast<CNodePtr>(post_node);
182       MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "cast failed");
183       // return-node should not be include into subgraph absolutely // ut
184       if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) {
185         continue;
186       }
187       MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr");
188       bool non_depend = true;
189       // check all inputs of output-cnode
190       for (const auto &input : post_cnode->inputs()) {
191         if (input == nullptr) {
192           continue;
193         }
194         // input cnode is not contained in subgraph
195         if (utils::isa<CNodePtr>(input)) {
196           auto input_cnode = utils::cast<CNodePtr>(input);
197           MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
198           if (this->nodes_.count(input_cnode) == 0) {
199             non_depend = false;
200             break;
201           }
202         }
203         // input parameter is a graph input
204         if (utils::isa<ParameterPtr>(input)) {
205           auto input_parameter = utils::cast<ParameterPtr>(input);
206           MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_NULL_PTR, "cast failed");
207           if (!input_parameter->has_default()) {
208             non_depend = false;
209             break;
210           }
211         }
212       }
213       if (non_depend) {
214         q.push(post_cnode);
215       }
216     }
217   }
218   return RET_OK;
219 }
220 
InitSubGraphInNode()221 int SubGraph::InitSubGraphInNode() {
222   MS_ASSERT(belong_anf_ != nullptr);
223   MS_ASSERT(belong_anf_->manager() != nullptr);
224   auto node_users = belong_anf_->manager()->node_users();
225   this->in_nodes_.clear();
226   for (const auto &node : this->nodes_) {
227     if (node == nullptr) {
228       continue;
229     }
230     if (std::any_of(node->inputs().begin(), node->inputs().end(), [this, &node_users](const auto &input) {
231           if (input == nullptr) {
232             return false;
233           }
234           if (utils::isa<CNodePtr>(input)) {
235             auto input_cnode = utils::cast<CNodePtr>(input);
236             MS_CHECK_TRUE_MSG(input_cnode != nullptr, false, "cast failed");
237             if (this->nodes_.count(input_cnode) == 0) {
238               return true;
239             }
240           }
241           // graph input or shared weight input // ut
242           if (utils::isa<ParameterPtr>(input)) {
243             auto input_parameter = utils::cast<ParameterPtr>(input);
244             MS_CHECK_TRUE_MSG(input_parameter != nullptr, false, "cast failed");
245             if (!input_parameter->has_default()) {
246               return true;
247             }
248             auto output_pair_iter = node_users.find(input);
249             if (output_pair_iter != node_users.end() && output_pair_iter->second.size() > 1) {
250               return true;
251             }
252           }
253           return false;
254         })) {
255       in_nodes_.insert(node);
256     }
257   }
258   return RET_OK;
259 }
260 
InitSubGraphOutNode()261 int SubGraph::InitSubGraphOutNode() {
262   MS_ASSERT(belong_anf_ != nullptr);
263   MS_ASSERT(belong_anf_->manager() != nullptr);
264   auto node_users = belong_anf_->manager()->node_users();
265   this->out_nodes_.clear();
266   for (const auto &node : this->nodes_) {
267     if (node == nullptr) {
268       continue;
269     }
270     auto node_users_iter = node_users.find(node);
271     if (node_users_iter == node_users.end()) {
272       continue;
273     }
274     auto node_output_pairs = node_users_iter->second;
275     if (!std::any_of(node_output_pairs.begin(), node_output_pairs.end(),
276                      [this](const std::pair<AnfNodePtr, int> &output_pair) {
277                        auto output_node = output_pair.first;
278                        if (output_node == nullptr || !utils::isa<CNodePtr>(output_node)) {
279                          return false;
280                        }
281                        // graph output // ut
282                        if (opt::CheckPrimitiveType(output_node, prim::kPrimReturn)) {
283                          return true;
284                        }
285                        auto output_cnode = utils::cast<CNodePtr>(output_node);
286                        MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "cast failed");
287                        if (this->nodes_.count(output_cnode) == 0) {
288                          return true;
289                        }
290                        return false;
291                      }))
292       continue;
293     out_nodes_.insert(node);
294   }
295   return RET_OK;
296 }
297 
MergeSubGraph(const SubGraphPtr & subgraph)298 bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) {
299   if (subgraph == nullptr || this == subgraph.get()) {
300     return false;
301   }
302   // if two subgraph has same output, and this output node has only two input cnode which exactly from two
303   // subgraph, we merge two subgraph, and find more post node
304   auto common_outputs = this->FindCommonOutputs(subgraph);
305   if (!common_outputs.empty()) {
306     auto new_nodes = this->GetNodes();
307     auto new_nodes2 = subgraph->GetNodes();
308     new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
309     new_nodes.insert(common_outputs.begin(), common_outputs.end());
310     if (this->Reset(new_nodes, common_outputs) != RET_OK) {
311       MS_LOG(ERROR) << "Reset failed";
312       return false;
313     }
314     return true;
315   }
316 
317   if (this->IfDependOnSameNode(subgraph)) {
318     auto new_nodes = this->GetNodes();
319     auto new_nodes2 = subgraph->GetNodes();
320     new_nodes.insert(new_nodes2.begin(), new_nodes2.end());
321     if (this->Reset(new_nodes) != RET_OK) {
322       MS_LOG(ERROR) << "Reset failed";
323       return false;
324     }
325     return true;
326   }
327   return false;
328 }
329 
330 // iterate node from in_nodes of current subgraph up to input of belong_anf
FindBeforeSubGraphInBelongAnf() const331 SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const {
332   MS_ASSERT(belong_anf_ == nullptr);
333   // find before subgraph's nodes
334   std::queue<CNodePtr> q{};
335   std::set<CNodePtr> before_nodes{};
336   for (const auto &node : this->GetInCNodes()) {
337     for (const auto &in_cnode : lite::GetInputCNode(node)) {
338       if (in_cnode == nullptr) {
339         continue;
340       }
341       q.push(in_cnode);
342     }
343   }
344   while (!q.empty()) {
345     auto cur_cnode = q.front();
346     MS_CHECK_TRUE_MSG(cur_cnode != nullptr, nullptr, "cur_cnode is nullptr");
347     q.pop();
348     before_nodes.insert(cur_cnode);
349     for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
350       q.push(in_cnode);
351     }
352   }
353   // construct before subgraph
354   auto before_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/before_subgraph");
355   MS_CHECK_TRUE_MSG(before_subgraph != nullptr, nullptr, "before_subgraph is nullptr");
356   if (before_subgraph->Reset(before_nodes) != RET_OK) {
357     MS_LOG(ERROR) << "Reset failed";
358     return nullptr;
359   }
360   return before_subgraph;
361 }
362 
363 // iterate node from output of belong_anf up to out_nodes of current subgraph and before subgraph
FindAfterSubGraphInBelongAnf() const364 SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const {
365   MS_ASSERT(belong_anf_ == nullptr);
366   // find before subgraph
367   auto before_subgraph = this->FindBeforeSubGraphInBelongAnf();
368   if (before_subgraph == nullptr) {
369     MS_LOG(ERROR) << "Find before subgraph failed";
370     return nullptr;
371   }
372   // find after subgraph's nodes
373   std::queue<CNodePtr> q{};
374   std::set<CNodePtr> after_nodes{};
375   auto output_node = belong_anf_->output();
376   if (!utils::isa<CNodePtr>(output_node)) {
377     MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope();
378     return nullptr;
379   }
380   q.push(utils::cast<CNodePtr>(output_node));
381   auto subgraph_out_nodes = this->GetOutCNodes();
382   auto before_out_nodes = before_subgraph->GetOutCNodes();
383   while (!q.empty()) {
384     auto cur_cnode = q.front();
385     MS_CHECK_TRUE_MSG(cur_cnode != nullptr, nullptr, "cur_cnode is nullptr");
386     q.pop();
387     after_nodes.insert(cur_cnode);
388     for (const auto &in_cnode : lite::GetInputCNode(cur_cnode)) {
389       if (subgraph_out_nodes.count(in_cnode) == 0 && before_out_nodes.count(in_cnode) == 0) {
390         q.push(in_cnode);
391       }
392     }
393   }
394   // construct before subgraph
395   auto after_subgraph = std::make_shared<SubGraph>(belong_anf_, this->name_ + "/after_subgraph");
396   MS_CHECK_TRUE_MSG(after_subgraph != nullptr, nullptr, "after_subgraph is nullptr");
397   if (after_subgraph->Reset(after_nodes) != RET_OK) {
398     MS_LOG(ERROR) << "Reset failed";
399     return nullptr;
400   }
401 
402   return after_subgraph;
403 }
404 
CreatePartialInBelongAnf()405 int SubGraph::CreatePartialInBelongAnf() {
406   MS_ASSERT(this->belong_anf_ != nullptr);
407   MS_ASSERT(this->belong_anf_->manager() != nullptr);
408   // determine func_graph name
409   std::string graph_name = this->name_;
410   if (graph_name.empty()) {
411     if (this->nodes_.empty()) {
412       graph_name = "subgraph";
413     } else {
414       graph_name = (*(this->nodes_.begin()))->fullname_with_scope() + "/subgraph";
415     }
416   }
417   // create func_graph of partial
418   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
419   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
420   auto manager = belong_anf_->manager();
421   manager->AddFuncGraph(func_graph);
422   func_graph->set_attr("graph_name", MakeValue(graph_name));
423   func_graph->set_manager(manager);
424   // create cnode and parameter for func_graph of partial
425   std::vector<AnfNodePtr> partial_inputs;
426   std::map<AnfNodePtr, AnfNodePtr> partial_inputs_and_subgraph_input_map;
427   auto ret = CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map);
428   if (ret != RET_OK) {
429     MS_LOG(DEBUG) << "CreateParameterForPartialSubGraph  failed";
430     return ret;
431   }
432   ret = CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map);
433   if (ret != RET_OK) {
434     MS_LOG(DEBUG) << "CreateCNodeForPartialSubGraph failed";
435     return ret;
436   }
437   // add return for func_graph of partial
438   auto sub_graph_outputs = this->GetOutCNodes();
439   MS_ASSERT(!sub_graph_outputs.empty());
440   ret = SetFuncGraphOutput(func_graph, sub_graph_outputs);
441   if (ret != RET_OK) {
442     MS_LOG(DEBUG) << "Set subgraph output failed";
443     return ret;
444   }
445   // create partial cnode
446   auto partial_prim = std::make_shared<mindspore::ops::PartialFusion>();
447   auto graph_value_node = NewValueNode(func_graph);
448   MS_CHECK_TRUE_MSG(partial_prim != nullptr, RET_NULL_PTR, "partial_prim is nullptr");
449   MS_CHECK_TRUE_MSG(graph_value_node != nullptr, RET_NULL_PTR, "graph_value_node is nullptr");
450   auto partial_prim_c = partial_prim->GetPrim();
451   MS_CHECK_TRUE_MSG(partial_prim_c != nullptr, RET_NULL_PTR, "partial_prim_c is nullptr");
452   partial_inputs.insert(partial_inputs.begin(), graph_value_node);
453   auto partial_cnode = belong_anf_->NewCNode(partial_prim_c, partial_inputs);
454   MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr");
455   partial_cnode->set_fullname_with_scope(graph_name + "/partial");
456   for (size_t i = 0; i < partial_inputs.size(); ++i) {
457     const auto &input = partial_inputs.at(i);
458     manager->SetEdge(partial_cnode, static_cast<int>(i + 1), input);
459   }
460   // create call cnode
461   std::vector<AnfNodePtr> call_node_inputs{partial_cnode};
462   auto call_cnode = belong_anf_->NewCNode(call_node_inputs);
463   MS_CHECK_TRUE_MSG(call_cnode != nullptr, RET_NULL_PTR, "call_cnode is nullptr");
464   call_cnode->set_fullname_with_scope(graph_name + "/call");
465   // replace belong-graph's output
466   auto return_node = belong_anf_->get_return();
467   // return node should has 2 inputs
468   MS_ASSERT(return_node != nullptr && return_node->size() == 2);
469   auto ori_output = return_node->inputs().at(1);
470   manager->Replace(ori_output, call_cnode);
471   return RET_OK;
472 }
473 
SetFuncGraphOutput(const FuncGraphPtr & graph,const std::set<CNodePtr> & outputs)474 int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs) {
475   std::vector<AnfNodePtr> output_nodes;
476   output_nodes.insert(output_nodes.end(), outputs.begin(), outputs.end());
477   return lite::SetFuncGraphOutput(graph, output_nodes);
478 }
479 
CreateParameterForPartialSubGraph(const FuncGraphPtr & sub_graph,std::vector<AnfNodePtr> * partial_inputs,std::map<AnfNodePtr,AnfNodePtr> * partial_inputs_and_subgraph_input_map)480 int SubGraph::CreateParameterForPartialSubGraph(
481   const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
482   std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map) {
483   MS_ASSERT(sub_graph != nullptr);
484   MS_ASSERT(partial_inputs != nullptr && partial_inputs->empty());
485   MS_ASSERT(partial_inputs_and_subgraph_input_map != nullptr && partial_inputs_and_subgraph_input_map->empty());
486   MS_CHECK_TRUE_MSG(sub_graph->get_attr("graph_name") != nullptr, RET_ERROR, "graph_name is nullptr");
487   std::string graph_name = sub_graph->get_attr("graph_name")->ToString();
488   for (const auto &in_cnode : this->GetInCNodes()) {
489     if (in_cnode == nullptr) {
490       continue;
491     }
492     for (size_t i = 1; i < in_cnode->size(); i++) {
493       auto input = in_cnode->input(i);
494       if (input == nullptr) {
495         continue;
496       }
497       auto iter = partial_inputs_and_subgraph_input_map->find(input);
498       if (iter != partial_inputs_and_subgraph_input_map->end()) {
499         continue;
500       }
501       // create subgraph input parameter from cnode and record partial inputs
502       if (utils::isa<CNodePtr>(input)) {
503         auto input_cnode = utils::cast<CNodePtr>(input);
504         MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed");
505         if (this->GetNodes().count(input_cnode) > 0) {
506           continue;
507         }
508         partial_inputs->emplace_back(input);
509         auto new_parameter = sub_graph->add_parameter();
510         new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
511         new_parameter->set_abstract(input->abstract());
512         (*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
513       }
514       // create subgraph input parameter from parameter and record partial inputs
515       // add parameter to func_graph
516       auto node_users = this->belong_anf_->manager()->node_users();
517       if (utils::isa<ParameterPtr>(input)) {
518         auto parameter = utils::cast<ParameterPtr>(input);
519         MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "cast ptr failed");
520         // graph input: create a parameter
521         if (!parameter->has_default()) {
522           auto new_parameter = sub_graph->add_parameter();
523           new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
524           new_parameter->set_abstract(input->abstract());
525           (*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
526           partial_inputs->emplace_back(new_parameter);
527         }
528         // weight parameter, it depends
529         auto output_pairs_iter = node_users.find(input);
530         if (output_pairs_iter != node_users.end() &&
531             output_pairs_iter->second.size() > 1) {  // shared weight: create a parameter
532           auto new_parameter = sub_graph->add_parameter();
533           new_parameter->set_name(graph_name + "_input_" + input->fullname_with_scope());
534           new_parameter->set_abstract(input->abstract());
535           (*partial_inputs_and_subgraph_input_map)[input] = new_parameter;
536           partial_inputs->emplace_back(new_parameter);
537         } else {  // not shared weight: move into subgraph
538           sub_graph->AddNode(input);
539           input->set_func_graph(sub_graph);
540           this->belong_anf_->DropNode(input);
541         }
542       }
543     }
544   }
545   return RET_OK;
546 }
547 
CreateCNodeForPartialSubGraph(const FuncGraphPtr & sub_graph,const std::map<AnfNodePtr,AnfNodePtr> & partial_inputs_and_subgraph_input_map)548 int SubGraph::CreateCNodeForPartialSubGraph(
549   const FuncGraphPtr &sub_graph, const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map) {
550   MS_ASSERT(sub_graph != nullptr);
551   // move cnode from belong_graph to subgraph
552   for (auto &node : this->GetNodes()) {
553     sub_graph->AddNode(node);
554     if (!utils::isa<ValueNodePtr>(node)) {
555       node->set_func_graph(sub_graph);
556     }
557     for (size_t i = 0; i < node->size(); i++) {
558       auto input = node->inputs().at(i);
559       if (input == nullptr) {
560         continue;
561       }
562       auto iter = partial_inputs_and_subgraph_input_map.find(input);
563       if (iter == partial_inputs_and_subgraph_input_map.end()) {
564         continue;
565       }
566       // use SetEdge not set_input, if not, node_user is not updated.
567       this->belong_anf_->manager()->SetEdge(node, static_cast<int>(i), iter->second);
568     }
569     this->belong_anf_->DropNode(node);
570   }
571   return RET_OK;
572 }
573 
ApplySubGraph()574 int SubGraph::ApplySubGraph() {
575   // check
576   if (this->nodes_.empty()) {
577     return lite::RET_NO_CHANGE;
578   }
579   if (belong_anf_ == nullptr || belong_anf_->manager() == nullptr) {
580     MS_LOG(DEBUG) << "belong_anf_ or manager is nullptr";
581     return lite::RET_NO_CHANGE;
582   }
583   for (const auto &node : this->nodes_) {
584     if (node == nullptr) {
585       continue;
586     }
587     if (node->func_graph() != belong_anf_) {
588       MS_LOG(DEBUG) << "subgraph nodes belong to different func_graph";
589       return lite::RET_ERROR;
590     }
591   }
592 
593   // create after partial // redirect input of after subgraph
594   auto after_subgraph = this->FindAfterSubGraphInBelongAnf();
595   if (after_subgraph == nullptr) {
596     MS_LOG(DEBUG) << "Create after subgraph failed";
597     return RET_ERROR;
598   }
599   auto ret = after_subgraph->CreatePartialInBelongAnf();
600   if (ret != RET_OK) {
601     MS_LOG(DEBUG) << "Create after partial failed";
602     return RET_ERROR;
603   }
604   // merge after partial into subgraph
605   auto subgraph_nodes = this->nodes_;
606   auto return_node = belong_anf_->get_return();
607   MS_ASSERT(return_node != nullptr && return_node->size() == 2);
608   auto call_node = return_node->inputs().at(1);
609   MS_ASSERT(call_node != nullptr && utils::isa<CNodePtr>(call_node));
610   auto call_cnode = utils::cast<CNodePtr>(call_node);
611   MS_ASSERT(call_cnode != nullptr && call_cnode->size() == 1);
612   auto after_partial_node = call_cnode->inputs().at(0);
613   MS_ASSERT(after_partial_node != nullptr && utils::isa<CNodePtr>(after_partial));
614   auto after_partial_cnode = utils::cast<CNodePtr>(after_partial_node);
615   MS_ASSERT(after_partial_cnode != nullptr);
616   subgraph_nodes.insert(after_partial_cnode);
617   subgraph_nodes.insert(call_cnode);
618   if (this->Reset(subgraph_nodes) != RET_OK) {
619     MS_LOG(ERROR) << "Reset failed";
620     return RET_ERROR;
621   }
622   // create subgraph partial // add partial to main subgraph
623   ret = this->CreatePartialInBelongAnf();
624   if (ret != RET_OK) {
625     MS_LOG(DEBUG) << "Create partial failed";
626     return RET_ERROR;
627   }
628   return RET_OK;
629 }
630 }  // namespace mindspore::lite
631