• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "vm/graph_partition.h"
18 #include <string>
19 #include <functional>
20 #include <utility>
21 #include <map>
22 #include <queue>
23 #include <stack>
24 #include <set>
25 #include <algorithm>
26 #include "base/core_ops.h"
27 #include "utils/utils.h"
28 #include "utils/ms_context.h"
29 #include "ps/ps_context.h"
30 #include "ir/anf_utils.h"
31 #ifdef ENABLE_GE
32 #include "transform/graph_ir/convert.h"
33 #endif
34 namespace mindspore {
// Backend name identifiers. kMsConvert and kGeVm are compared against GraphPartition::backend_name_ in IsCut.
const char kMsConvert[]{"ms"};
const char kMsVm[]{"vm"};
const char kGeVm[]{"ge"};
38 namespace compile {
39 namespace {
GetOtherTarget(const std::vector<AnfNodePtr> & nodes)40 std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) {
41   auto context_ptr = MsContext::GetInstance();
42   MS_EXCEPTION_IF_NULL(context_ptr);
43   std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
44   for (auto &node : nodes) {
45     MS_EXCEPTION_IF_NULL(node);
46     if (!node->isa<CNode>()) {
47       continue;
48     }
49     std::string cur_target = GetCNodeTarget(node);
50     if (cur_target != default_target) {
51       return cur_target;
52     }
53   }
54   return "";
55 }
56 
CalcNodeRefCount(const FuncGraphPtr & graph,std::map<AnfNodePtr,size_t> * nodes_ref)57 void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
58   MS_EXCEPTION_IF_NULL(graph);
59   MS_EXCEPTION_IF_NULL(nodes_ref);
60   std::queue<AnfNodePtr> queue;
61   queue.push(graph->get_return());
62   std::set<AnfNodePtr> visited;
63   while (!queue.empty()) {
64     auto node = queue.front();
65     queue.pop();
66     MS_EXCEPTION_IF_NULL(node);
67     if (!node->isa<CNode>()) {
68       continue;
69     }
70     auto cnode = node->cast<CNodePtr>();
71     MS_EXCEPTION_IF_NULL(cnode);
72     for (auto &input : cnode->inputs()) {
73       auto iter = nodes_ref->find(input);
74       if (iter != nodes_ref->end()) {
75         iter->second++;
76       } else {
77         (void)nodes_ref->emplace(input, 1UL);
78       }
79       if (visited.find(input) != visited.end()) {
80         continue;
81       }
82       visited.insert(input);
83       queue.push(input);
84     }
85   }
86 }
87 
ReorderVirtualNode(const std::vector<AnfNodePtr> & nodes,const PrimitivePtr & reorder_prim)88 std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes, const PrimitivePtr &reorder_prim) {
89   std::vector<AnfNodePtr> result;
90   std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
91   std::map<AnfNodePtr, size_t> node_positions;
92   auto add_insert_position = [&insert_positions, &node_positions](const AnfNodePtr &node, const AnfNodePtr &parent) {
93     if (parent == nullptr) {
94       return false;
95     }
96     auto iter = node_positions.find(parent);
97     if (iter != node_positions.end()) {
98       size_t position = iter->second;
99       auto iter_nodes = insert_positions.find(position);
100       if (iter_nodes != insert_positions.end()) {
101         iter_nodes->second.push_back(node);
102       } else {
103         (void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
104       }
105       return true;
106     }
107     return false;
108   };
109   for (auto &node : nodes) {
110     MS_EXCEPTION_IF_NULL(node);
111     if (IsPrimitiveCNode(node, reorder_prim)) {
112       auto cnode = node->cast<CNodePtr>();
113       MS_EXCEPTION_IF_NULL(cnode);
114       auto &inputs = cnode->inputs();
115       AnfNodePtr parent = nullptr;
116       const size_t depend_input_size = 2;
117       if (reorder_prim == prim::kPrimDepend && inputs.size() == depend_input_size + 1 && !inputs[1]->isa<CNode>()) {
118         parent = inputs[depend_input_size];
119       } else if (reorder_prim == prim::kPrimTupleGetItem && inputs.size() > 1) {
120         parent = inputs[1];
121       }
122       if (add_insert_position(node, parent)) {
123         continue;
124       }
125     }
126     result.emplace_back(node);
127     node_positions[node] = result.size();
128   }
129 
130   size_t insert_num = 0;
131   for (auto &item : insert_positions) {
132     auto position = SizeToLong(item.first + insert_num);
133     (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
134     insert_num += item.second.size();
135   }
136   return result;
137 }
138 
GetNextNodes(const AnfNodePtr & node,std::map<AnfNodePtr,size_t> * nodes_ref,std::vector<AnfNodePtr> * result)139 std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
140                                      std::vector<AnfNodePtr> *result) {
141   MS_EXCEPTION_IF_NULL(node);
142   MS_EXCEPTION_IF_NULL(nodes_ref);
143   MS_EXCEPTION_IF_NULL(result);
144   auto cnode = node->cast<CNodePtr>();
145   MS_EXCEPTION_IF_NULL(cnode);
146   auto node_inputs = cnode->inputs();
147   if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
148     std::reverse(node_inputs.begin(), node_inputs.end());
149     return node_inputs;
150   }
151   std::vector<AnfNodePtr> extend_inputs;
152   for (auto &input : node_inputs) {
153     MS_EXCEPTION_IF_NULL(input);
154     if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
155       auto iter = nodes_ref->find(input);
156       if (iter != nodes_ref->end() && iter->second == 1) {
157         iter->second--;
158         result->emplace_back(input);
159         auto partial_cnode = input->cast<CNodePtr>();
160         MS_EXCEPTION_IF_NULL(partial_cnode);
161         auto partial_inputs = partial_cnode->inputs();
162         std::reverse(partial_inputs.begin(), partial_inputs.end());
163         (void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
164         continue;
165       }
166     }
167     extend_inputs.emplace_back(input);
168   }
169   return extend_inputs;
170 }
171 
SplitSort(const FuncGraphPtr & graph,const std::string & default_target)172 std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
173   MS_EXCEPTION_IF_NULL(graph);
174   std::vector<AnfNodePtr> result;
175   std::stack<AnfNodePtr> to_visit;
176   std::stack<AnfNodePtr> next_to_visit;
177   std::map<AnfNodePtr, size_t> nodes_ref;
178   CalcNodeRefCount(graph, &nodes_ref);
179   std::string handle_target = default_target;
180   std::string next_target;
181   to_visit.push(graph->get_return());
182   while (!to_visit.empty() || !next_to_visit.empty()) {
183     if (to_visit.empty()) {
184       to_visit.swap(next_to_visit);
185       handle_target = next_target;
186     }
187     auto node = to_visit.top();
188     MS_EXCEPTION_IF_NULL(node);
189     to_visit.pop();
190     result.emplace_back(node);
191     if (!node->isa<CNode>()) {
192       continue;
193     }
194     auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
195     for (auto &input : next_nodes) {
196       MS_EXCEPTION_IF_NULL(input);
197       auto iter = nodes_ref.find(input);
198       if (iter != nodes_ref.end()) {
199         iter->second--;
200         if (iter->second != 0) {
201           continue;
202         }
203       }
204       if (!input->isa<CNode>()) {
205         to_visit.push(input);
206         continue;
207       }
208       std::string input_target = GetCNodeTarget(input);
209       if (input_target == handle_target) {
210         to_visit.push(input);
211       } else if (next_to_visit.empty() || input_target == next_target) {
212         next_to_visit.push(input);
213         next_target = input_target;
214       } else {
215         MS_LOG(EXCEPTION) << "Only support two different target";
216       }
217     }
218   }
219   std::reverse(result.begin(), result.end());
220   return result;
221 }
222 
223 struct GraphNodesDependencyInfo {
224   std::stack<AnfNodePtr> independent_nodes_;
225   std::map<AnfNodePtr, size_t> input_num_;
226   std::map<AnfNodePtr, std::vector<AnfNodePtr>> output_edges_;
227 };
228 
GetNodesDependencyInfo(const FuncGraphPtr & graph)229 GraphNodesDependencyInfo GetNodesDependencyInfo(const FuncGraphPtr &graph) {
230   MS_EXCEPTION_IF_NULL(graph);
231   GraphNodesDependencyInfo info;
232   std::stack<AnfNodePtr> to_visit;
233   std::map<AnfNodePtr, size_t> nodes_ref;
234   CalcNodeRefCount(graph, &nodes_ref);
235   to_visit.push(graph->get_return());
236   while (!to_visit.empty()) {
237     auto node = to_visit.top();
238     to_visit.pop();
239     if (node == nullptr || !node->isa<CNode>()) {
240       continue;
241     }
242     auto cnode = node->cast<CNodePtr>();
243     MS_EXCEPTION_IF_NULL(cnode);
244     auto node_inputs = cnode->inputs();
245     bool independent = true;
246     for (auto &input : node_inputs) {
247       MS_EXCEPTION_IF_NULL(input);
248       if (input->isa<CNode>()) {
249         independent = false;
250         auto output_edge_iter = info.output_edges_.find(input);
251         if (output_edge_iter != info.output_edges_.end()) {
252           auto &edges = output_edge_iter->second;
253           edges.emplace_back(node);
254         } else {
255           info.output_edges_[input] = {node};
256         }
257         auto input_num_iter = info.input_num_.find(node);
258         if (input_num_iter != info.input_num_.end()) {
259           input_num_iter->second++;
260         } else {
261           info.input_num_[node] = 1;
262         }
263       }
264       auto ref_iter = nodes_ref.find(input);
265       if (ref_iter != nodes_ref.end()) {
266         ref_iter->second--;
267         if (ref_iter->second != 0) {
268           continue;
269         }
270       }
271       to_visit.push(input);
272     }
273     if (independent) {
274       info.independent_nodes_.push(node);
275     }
276   }
277   return info;
278 }
279 
280 struct VisitNodesInfo {
281   std::queue<AnfNodePtr> default_target_nodes_;
282   std::queue<AnfNodePtr> other_target_nodes_;
283   std::map<AnfNodePtr, AnfNodePtr> seed_cast_next_node_;
284 };
285 
GetVisitNodesInfo(const GraphNodesDependencyInfo & dependency_info,const std::string & default_target,const std::string & other_target)286 VisitNodesInfo GetVisitNodesInfo(const GraphNodesDependencyInfo &dependency_info, const std::string &default_target,
287                                  const std::string &other_target) {
288   VisitNodesInfo result;
289   auto independent_nodes = dependency_info.independent_nodes_;
290   while (!independent_nodes.empty()) {
291     auto seed_node = independent_nodes.top();
292     independent_nodes.pop();
293     MS_EXCEPTION_IF_NULL(seed_node);
294     auto node_target = GetCNodeTarget(seed_node);
295     if (IsPrimitiveCNode(seed_node, prim::kPrimCast)) {
296       auto output_edges_iter = dependency_info.output_edges_.find(seed_node);
297       if (output_edges_iter != dependency_info.output_edges_.end() && output_edges_iter->second.size() == 1) {
298         auto &cast_next_node = output_edges_iter->second[0];
299         auto input_num_iter = dependency_info.input_num_.find(cast_next_node);
300         if (input_num_iter == dependency_info.input_num_.end()) {
301           MS_LOG(EXCEPTION) << "Node input num not found!";
302         }
303         if (input_num_iter->second > 1 && node_target == GetCNodeTarget(cast_next_node)) {
304           result.seed_cast_next_node_[cast_next_node] = seed_node;
305           continue;
306         }
307       }
308     }
309     if (node_target == default_target) {
310       result.default_target_nodes_.push(seed_node);
311     } else if (node_target == other_target) {
312       result.other_target_nodes_.push(seed_node);
313     } else {
314       MS_LOG(EXCEPTION) << "Only support two difference targets";
315     }
316   }
317   return result;
318 }
319 
ParallelSortDecideNextHandleTarget(const std::vector<AnfNodePtr> & output_edges,const std::string & node_target,std::map<AnfNodePtr,std::string> * node_input_target_map)320 std::string ParallelSortDecideNextHandleTarget(const std::vector<AnfNodePtr> &output_edges,
321                                                const std::string &node_target,
322                                                std::map<AnfNodePtr, std::string> *node_input_target_map) {
323   MS_EXCEPTION_IF_NULL(node_input_target_map);
324   std::string next_target = node_target;
325   for (auto &dst_node : output_edges) {
326     auto input_target_iter = node_input_target_map->find(dst_node);
327     if (input_target_iter != node_input_target_map->end() && input_target_iter->second != node_target) {
328       next_target = input_target_iter->second;
329       break;
330     }
331     auto dst_node_target = GetCNodeTarget(dst_node);
332     if (dst_node_target != node_target) {
333       next_target = dst_node_target;
334       break;
335     }
336   }
337   for (auto &dst_node : output_edges) {
338     (*node_input_target_map)[dst_node] = node_target;
339   }
340   return next_target;
341 }
342 
ParallelSortVisitNodeEdges(const std::vector<AnfNodePtr> & output_edges,GraphNodesDependencyInfo * dependency_info,VisitNodesInfo * visit_nodes_info,const std::string & default_target)343 void ParallelSortVisitNodeEdges(const std::vector<AnfNodePtr> &output_edges, GraphNodesDependencyInfo *dependency_info,
344                                 VisitNodesInfo *visit_nodes_info, const std::string &default_target) {
345   MS_EXCEPTION_IF_NULL(dependency_info);
346   MS_EXCEPTION_IF_NULL(visit_nodes_info);
347   for (auto &dst_node : output_edges) {
348     auto dst_node_target = GetCNodeTarget(dst_node);
349     auto input_num_iter = dependency_info->input_num_.find(dst_node);
350     if (input_num_iter == dependency_info->input_num_.end()) {
351       MS_LOG(EXCEPTION) << "Node input num not found!";
352     }
353     input_num_iter->second--;
354     if (input_num_iter->second == 1 &&
355         visit_nodes_info->seed_cast_next_node_.find(dst_node) != visit_nodes_info->seed_cast_next_node_.end()) {
356       input_num_iter->second--;
357     }
358     if (input_num_iter->second > 0) {
359       continue;
360     }
361     if (dst_node_target == default_target) {
362       visit_nodes_info->default_target_nodes_.push(dst_node);
363     } else {
364       visit_nodes_info->other_target_nodes_.push(dst_node);
365     }
366   }
367 }
368 
ParallelSort(const FuncGraphPtr & graph,const std::string & default_target,const std::string & other_target)369 std::vector<AnfNodePtr> ParallelSort(const FuncGraphPtr &graph, const std::string &default_target,
370                                      const std::string &other_target) {
371   MS_EXCEPTION_IF_NULL(graph);
372   auto dependency_info = GetNodesDependencyInfo(graph);
373   auto visit_nodes_info = GetVisitNodesInfo(dependency_info, default_target, other_target);
374   std::vector<AnfNodePtr> result;
375   std::string handle_target;
376   if (!visit_nodes_info.default_target_nodes_.empty()) {
377     handle_target = default_target;
378   } else {
379     handle_target = other_target;
380   }
381   std::map<AnfNodePtr, std::string> node_input_target_map;
382   while (!visit_nodes_info.default_target_nodes_.empty() || !visit_nodes_info.other_target_nodes_.empty()) {
383     AnfNodePtr ready_node;
384     if ((!visit_nodes_info.default_target_nodes_.empty() && handle_target == default_target) ||
385         visit_nodes_info.other_target_nodes_.empty()) {
386       ready_node = visit_nodes_info.default_target_nodes_.front();
387       visit_nodes_info.default_target_nodes_.pop();
388       handle_target = default_target;
389     } else {
390       ready_node = visit_nodes_info.other_target_nodes_.front();
391       visit_nodes_info.other_target_nodes_.pop();
392       handle_target = other_target;
393     }
394     MS_EXCEPTION_IF_NULL(ready_node);
395     auto cast_map_iter = visit_nodes_info.seed_cast_next_node_.find(ready_node);
396     if (cast_map_iter != visit_nodes_info.seed_cast_next_node_.end()) {
397       result.emplace_back(cast_map_iter->second);
398     }
399     result.emplace_back(ready_node);
400     auto output_edges_iter = dependency_info.output_edges_.find(ready_node);
401     if (output_edges_iter == dependency_info.output_edges_.end()) {
402       continue;
403     }
404     auto &output_edges = output_edges_iter->second;
405     handle_target = ParallelSortDecideNextHandleTarget(output_edges, handle_target, &node_input_target_map);
406     ParallelSortVisitNodeEdges(output_edges, &dependency_info, &visit_nodes_info, default_target);
407   }
408   return result;
409 }
410 
AddSegmentDependency(const FuncGraphPtr & graph,const std::map<AnfNodePtr,GraphSegmentPtr> & node_to_segment)411 void AddSegmentDependency(const FuncGraphPtr &graph, const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) {
412   MS_EXCEPTION_IF_NULL(graph);
413   std::stack<AnfNodePtr> to_visit;
414   std::map<AnfNodePtr, size_t> nodes_ref;
415   CalcNodeRefCount(graph, &nodes_ref);
416   to_visit.push(graph->get_return());
417   while (!to_visit.empty()) {
418     auto &node = to_visit.top();
419     MS_EXCEPTION_IF_NULL(node);
420     to_visit.pop();
421     if (!node->isa<CNode>()) {
422       continue;
423     }
424     auto cnode = node->cast<CNodePtr>();
425     MS_EXCEPTION_IF_NULL(cnode);
426     auto node_inputs = cnode->inputs();
427     GraphSegmentPtr node_segment{nullptr};
428     auto node_iter = node_to_segment.find(node);
429     if (node_iter != node_to_segment.end()) {
430       node_segment = node_iter->second;
431     }
432     for (auto &input : node_inputs) {
433       if (node_segment != nullptr && !node_segment->is_cut_ && input != nullptr && input->isa<CNode>()) {
434         GraphSegmentPtr input_segment{nullptr};
435         auto input_iter = node_to_segment.find(input);
436         if (input_iter != node_to_segment.end()) {
437           input_segment = input_iter->second;
438         }
439         if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) {
440           node_segment->AddPreSegment(input_segment);
441         }
442       }
443       auto ref_iter = nodes_ref.find(input);
444       if (ref_iter != nodes_ref.end()) {
445         ref_iter->second--;
446         if (ref_iter->second != 0) {
447           continue;
448         }
449       }
450       to_visit.push(input);
451     }
452   }
453 }
454 
RemoveUselessDependency(const std::vector<GraphSegmentPtr> * segments)455 void RemoveUselessDependency(const std::vector<GraphSegmentPtr> *segments) {
456   MS_EXCEPTION_IF_NULL(segments);
457   for (auto &segment : *segments) {
458     MS_EXCEPTION_IF_NULL(segment);
459     if (segment->is_cut_) {
460       continue;
461     }
462     bool total_virtual_node = true;
463     for (auto &node : segment->nodes_) {
464       if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
465           IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
466           IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
467           IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
468           IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
469         continue;
470       }
471       total_virtual_node = false;
472       break;
473     }
474     if (total_virtual_node) {
475       segment->pre_segments_.clear();
476     }
477   }
478 }
479 
IsSubGraph(const AnfNodePtr & node)480 bool IsSubGraph(const AnfNodePtr &node) {
481   MS_EXCEPTION_IF_NULL(node);
482   if (node->isa<CNode>()) {
483     auto cnode = node->cast<CNodePtr>();
484     auto &inputs = cnode->inputs();
485     if (inputs.empty()) {
486       MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
487     }
488 
489     AnfNodePtr fn = inputs[0];
490     if (!IsValueNode<Primitive>(fn)) {
491       return false;
492     }
493     auto node_prim = GetValueNode<PrimitivePtr>(fn);
494     if (node_prim->name() == prim::kPrimPartial->name()) {
495       return true;
496     }
497   } else if (IsValueNode<FuncGraph>(node)) {
498     return true;
499   }
500   return false;
501 }
502 
AddSegment(const std::vector<AnfNodePtr> & nodes,std::vector<GraphSegmentPtr> * segments,std::map<AnfNodePtr,GraphSegmentPtr> * node_to_segment)503 void AddSegment(const std::vector<AnfNodePtr> &nodes, std::vector<GraphSegmentPtr> *segments,
504                 std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
505   MS_EXCEPTION_IF_NULL(segments);
506   MS_EXCEPTION_IF_NULL(node_to_segment);
507   auto segment = std::make_shared<GraphSegment>(nodes, false);
508   segments->emplace_back(segment);
509   for (auto &node : nodes) {
510     (*node_to_segment)[node] = segment;
511   }
512 }
513 
514 struct SplitDynamicNodesHelper {
AddNodemindspore::compile::__anon303885100111::SplitDynamicNodesHelper515   void AddNode(const AnfNodePtr &node, bool is_dynamic_shape) {
516     if (is_dynamic_shape) {
517       pre_dynamic_nodes.emplace_back(node);
518       pre_dynamic_nodes_set.insert(node);
519     } else {
520       pre_common_nodes.emplace_back(node);
521       pre_common_nodes_set.insert(node);
522     }
523     pre_nodes.emplace_back(node);
524   }
525 
AddSegmentsmindspore::compile::__anon303885100111::SplitDynamicNodesHelper526   void AddSegments(std::vector<GraphSegmentPtr> *segments, std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
527     MS_EXCEPTION_IF_NULL(segments);
528     MS_EXCEPTION_IF_NULL(node_to_segment);
529     if (pre_nodes.size() < merge_node_threshold) {
530       AddSegment(pre_nodes, segments, node_to_segment);
531     } else {
532       if (!pre_common_nodes.empty()) {
533         AddSegment(pre_common_nodes, segments, node_to_segment);
534       }
535       if (!pre_dynamic_nodes.empty()) {
536         AddSegment(pre_dynamic_nodes, segments, node_to_segment);
537       }
538     }
539     pre_common_nodes.clear();
540     pre_common_nodes_set.clear();
541     pre_dynamic_nodes.clear();
542     pre_dynamic_nodes_set.clear();
543     pre_nodes.clear();
544   }
545   std::vector<AnfNodePtr> pre_nodes;
546   std::vector<AnfNodePtr> pre_dynamic_nodes;
547   std::vector<AnfNodePtr> pre_common_nodes;
548   std::set<AnfNodePtr> pre_common_nodes_set;
549   std::set<AnfNodePtr> pre_dynamic_nodes_set;
550   size_t merge_node_threshold = 6;
551 };
552 
SplitDynamicNodeSegment(const std::vector<AnfNodePtr> & segment_nodes,std::vector<GraphSegmentPtr> * segments,std::map<AnfNodePtr,GraphSegmentPtr> * node_to_segment,const std::set<AnfNodePtr> & dynamic_nodes_set)553 void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
554                              std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment,
555                              const std::set<AnfNodePtr> &dynamic_nodes_set) {
556   SplitDynamicNodesHelper helper;
557   for (auto &node : segment_nodes) {
558     MS_EXCEPTION_IF_NULL(node);
559     auto cnode = node->cast<CNodePtr>();
560     MS_EXCEPTION_IF_NULL(cnode);
561     auto &inputs = cnode->inputs();
562     bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end();
563     bool depend_common_node = false;
564     bool depend_dynamic_node = false;
565     bool is_last_node_dynamic = false;
566     for (size_t i = 1; i < inputs.size(); ++i) {
567       if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) {
568         has_dynamic_shape = true;
569       }
570       if (helper.pre_common_nodes_set.find(inputs[i]) != helper.pre_common_nodes_set.end()) {
571         depend_common_node = true;
572       }
573       if (helper.pre_dynamic_nodes_set.find(inputs[i]) != helper.pre_dynamic_nodes_set.end()) {
574         depend_dynamic_node = true;
575       }
576     }
577     if (has_dynamic_shape) {
578       if (depend_common_node) {
579         helper.AddSegments(segments, node_to_segment);
580       }
581       is_last_node_dynamic = true;
582     } else {
583       if (depend_dynamic_node) {
584         helper.AddSegments(segments, node_to_segment);
585       }
586       is_last_node_dynamic = false;
587     }
588     helper.AddNode(node, is_last_node_dynamic);
589   }
590   helper.AddSegments(segments, node_to_segment);
591 }
592 
NodesToSegments(const std::vector<AnfNodePtr> & segment_nodes,std::vector<GraphSegmentPtr> * segments,std::map<AnfNodePtr,GraphSegmentPtr> * node_to_segment)593 void NodesToSegments(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
594                      std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
595   if (segment_nodes.empty()) {
596     return;
597   }
598   auto segment_target = GetCNodeTarget(segment_nodes[0]);
599   if (segment_target != kAscendDevice) {
600     AddSegment(segment_nodes, segments, node_to_segment);
601     return;
602   }
603   MS_EXCEPTION_IF_NULL(segments);
604   MS_EXCEPTION_IF_NULL(node_to_segment);
605   std::set<AnfNodePtr> dynamic_nodes_set;
606   for (auto &node : segment_nodes) {
607     MS_EXCEPTION_IF_NULL(node);
608     auto cnode = node->cast<CNodePtr>();
609     if (AnfUtils::IsNodeOutputDynamicShape(cnode)) {
610       (void)dynamic_nodes_set.insert(node);
611     }
612   }
613   if (dynamic_nodes_set.empty()) {
614     AddSegment(segment_nodes, segments, node_to_segment);
615     return;
616   }
617   SplitDynamicNodeSegment(segment_nodes, segments, node_to_segment, dynamic_nodes_set);
618 }
619 }  // namespace
620 
GraphPartition(const std::vector<PrimitivePtr> & cut_list,const std::string & backend_name)621 GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name)
622     : cut_list_(cut_list), backend_name_(backend_name) {}
623 
IsCut(const AnfNodePtr & node)624 bool GraphPartition::IsCut(const AnfNodePtr &node) {
625   MS_EXCEPTION_IF_NULL(node);
626   if (node->isa<CNode>()) {
627     auto cnode = node->cast<CNodePtr>();
628     auto &inputs = cnode->inputs();
629     if (inputs.empty()) {
630       MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
631     }
632     AnfNodePtr fn = inputs[0];
633     if (!IsValueNode<Primitive>(fn)) {
634       return true;
635     }
636     auto node_prim = GetValueNode<PrimitivePtr>(fn);
637     for (auto &prim : cut_list_) {
638       MS_EXCEPTION_IF_NULL(prim);
639       if (prim->name() == node_prim->name()) {
640         if (prim->name() == prim::kPrimBpropCut->name()) {
641           auto ms_context = MsContext::GetInstance();
642           MS_EXCEPTION_IF_NULL(ms_context);
643           ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
644         }
645         if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
646           if (inputs.size() <= 1) {
647             return false;
648           }
649           auto ret = IsSubGraph(inputs[1]);
650           return ret;
651         }
652         return true;
653       }
654     }
655 #ifdef ENABLE_GE
656     if (backend_name_ == kGeVm) {
657       auto name = GetCNodeFuncName(cnode);
658       auto adpt = transform::DfGraphConvertor::FindAdapter(name);
659       if (adpt == nullptr) {
660         return true;
661       }
662     }
663 #endif
664   }
665   return false;
666 }
667 
Partition(const FuncGraphPtr & graph,bool * multi_target)668 std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph, bool *multi_target) {
669   MS_EXCEPTION_IF_NULL(graph);
670   auto nodes = TopoSort(graph->get_return());
671   MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
672   bool contain_multi_target = ContainMultiTarget(nodes);
673   if (multi_target != nullptr) {
674     *multi_target = contain_multi_target;
675   }
676 
677   auto context_ptr = MsContext::GetInstance();
678   MS_EXCEPTION_IF_NULL(context_ptr);
679   auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
680   std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
681   if (contain_multi_target || !enable_loop_sink) {
682     if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
683       auto other_target = GetOtherTarget(nodes);
684       nodes = ParallelSort(graph, default_target, other_target);
685     } else {
686       nodes = SplitSort(graph, default_target);
687     }
688     nodes = ReorderVirtualNode(nodes, prim::kPrimTupleGetItem);
689     nodes = ReorderVirtualNode(nodes, prim::kPrimDepend);
690   }
691   std::vector<GraphSegmentPtr> segments;
692   std::vector<AnfNodePtr> segment_nodes;
693   std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment;
694   std::string last_target;
695   for (auto &node : nodes) {
696     MS_EXCEPTION_IF_NULL(node);
697     if (IsCut(node)) {
698       NodesToSegments(segment_nodes, &segments, &node_to_segment);
699       segment_nodes.clear();
700       segment_nodes.emplace_back(node);
701       auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
702       segments.push_back(segment);
703       segment_nodes.clear();
704     } else if (node->isa<CNode>()) {
705       if (contain_multi_target) {
706         std::string cur_target = GetCNodeTarget(node);
707         if (cur_target != last_target && !last_target.empty()) {
708           NodesToSegments(segment_nodes, &segments, &node_to_segment);
709           segment_nodes.clear();
710         }
711         last_target = cur_target;
712       }
713       segment_nodes.emplace_back(node);
714     }
715   }
716   MS_LOG(DEBUG) << "Segment size:" << segments.size();
717   if (contain_multi_target) {
718     AddSegmentDependency(graph, node_to_segment);
719     RemoveUselessDependency(&segments);
720   }
721   return segments;
722 }
723 }  // namespace compile
724 }  // namespace mindspore
725