1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "vm/graph_partition.h"
18 #include <string>
19 #include <functional>
20 #include <utility>
21 #include <map>
22 #include <queue>
23 #include <stack>
24 #include <set>
25 #include <algorithm>
26 #include "base/core_ops.h"
27 #include "utils/utils.h"
28 #include "utils/ms_context.h"
29 #include "ps/ps_context.h"
30 #include "ir/anf_utils.h"
31 #ifdef ENABLE_GE
32 #include "transform/graph_ir/convert.h"
33 #endif
34 namespace mindspore {
// Backend names compared against GraphPartition::backend_name_.
// kMsConvert selects the special MakeTuple cut rule in GraphPartition::IsCut;
// kGeVm enables the GE adapter check in IsCut when built with ENABLE_GE.
// kMsVm is not referenced by the code visible in this file section.
const char kMsConvert[] = "ms";
const char kMsVm[] = "vm";
const char kGeVm[] = "ge";
38 namespace compile {
39 namespace {
GetOtherTarget(const std::vector<AnfNodePtr> & nodes)40 std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) {
41 auto context_ptr = MsContext::GetInstance();
42 MS_EXCEPTION_IF_NULL(context_ptr);
43 std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
44 for (auto &node : nodes) {
45 MS_EXCEPTION_IF_NULL(node);
46 if (!node->isa<CNode>()) {
47 continue;
48 }
49 std::string cur_target = GetCNodeTarget(node);
50 if (cur_target != default_target) {
51 return cur_target;
52 }
53 }
54 return "";
55 }
56
CalcNodeRefCount(const FuncGraphPtr & graph,std::map<AnfNodePtr,size_t> * nodes_ref)57 void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
58 MS_EXCEPTION_IF_NULL(graph);
59 MS_EXCEPTION_IF_NULL(nodes_ref);
60 std::queue<AnfNodePtr> queue;
61 queue.push(graph->get_return());
62 std::set<AnfNodePtr> visited;
63 while (!queue.empty()) {
64 auto node = queue.front();
65 queue.pop();
66 MS_EXCEPTION_IF_NULL(node);
67 if (!node->isa<CNode>()) {
68 continue;
69 }
70 auto cnode = node->cast<CNodePtr>();
71 MS_EXCEPTION_IF_NULL(cnode);
72 for (auto &input : cnode->inputs()) {
73 auto iter = nodes_ref->find(input);
74 if (iter != nodes_ref->end()) {
75 iter->second++;
76 } else {
77 (void)nodes_ref->emplace(input, 1UL);
78 }
79 if (visited.find(input) != visited.end()) {
80 continue;
81 }
82 visited.insert(input);
83 queue.push(input);
84 }
85 }
86 }
87
ReorderVirtualNode(const std::vector<AnfNodePtr> & nodes,const PrimitivePtr & reorder_prim)88 std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes, const PrimitivePtr &reorder_prim) {
89 std::vector<AnfNodePtr> result;
90 std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
91 std::map<AnfNodePtr, size_t> node_positions;
92 auto add_insert_position = [&insert_positions, &node_positions](const AnfNodePtr &node, const AnfNodePtr &parent) {
93 if (parent == nullptr) {
94 return false;
95 }
96 auto iter = node_positions.find(parent);
97 if (iter != node_positions.end()) {
98 size_t position = iter->second;
99 auto iter_nodes = insert_positions.find(position);
100 if (iter_nodes != insert_positions.end()) {
101 iter_nodes->second.push_back(node);
102 } else {
103 (void)insert_positions.emplace(position, std::vector<AnfNodePtr>{node});
104 }
105 return true;
106 }
107 return false;
108 };
109 for (auto &node : nodes) {
110 MS_EXCEPTION_IF_NULL(node);
111 if (IsPrimitiveCNode(node, reorder_prim)) {
112 auto cnode = node->cast<CNodePtr>();
113 MS_EXCEPTION_IF_NULL(cnode);
114 auto &inputs = cnode->inputs();
115 AnfNodePtr parent = nullptr;
116 const size_t depend_input_size = 2;
117 if (reorder_prim == prim::kPrimDepend && inputs.size() == depend_input_size + 1 && !inputs[1]->isa<CNode>()) {
118 parent = inputs[depend_input_size];
119 } else if (reorder_prim == prim::kPrimTupleGetItem && inputs.size() > 1) {
120 parent = inputs[1];
121 }
122 if (add_insert_position(node, parent)) {
123 continue;
124 }
125 }
126 result.emplace_back(node);
127 node_positions[node] = result.size();
128 }
129
130 size_t insert_num = 0;
131 for (auto &item : insert_positions) {
132 auto position = SizeToLong(item.first + insert_num);
133 (void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
134 insert_num += item.second.size();
135 }
136 return result;
137 }
138
GetNextNodes(const AnfNodePtr & node,std::map<AnfNodePtr,size_t> * nodes_ref,std::vector<AnfNodePtr> * result)139 std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
140 std::vector<AnfNodePtr> *result) {
141 MS_EXCEPTION_IF_NULL(node);
142 MS_EXCEPTION_IF_NULL(nodes_ref);
143 MS_EXCEPTION_IF_NULL(result);
144 auto cnode = node->cast<CNodePtr>();
145 MS_EXCEPTION_IF_NULL(cnode);
146 auto node_inputs = cnode->inputs();
147 if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
148 std::reverse(node_inputs.begin(), node_inputs.end());
149 return node_inputs;
150 }
151 std::vector<AnfNodePtr> extend_inputs;
152 for (auto &input : node_inputs) {
153 MS_EXCEPTION_IF_NULL(input);
154 if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
155 auto iter = nodes_ref->find(input);
156 if (iter != nodes_ref->end() && iter->second == 1) {
157 iter->second--;
158 result->emplace_back(input);
159 auto partial_cnode = input->cast<CNodePtr>();
160 MS_EXCEPTION_IF_NULL(partial_cnode);
161 auto partial_inputs = partial_cnode->inputs();
162 std::reverse(partial_inputs.begin(), partial_inputs.end());
163 (void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
164 continue;
165 }
166 }
167 extend_inputs.emplace_back(input);
168 }
169 return extend_inputs;
170 }
171
SplitSort(const FuncGraphPtr & graph,const std::string & default_target)172 std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
173 MS_EXCEPTION_IF_NULL(graph);
174 std::vector<AnfNodePtr> result;
175 std::stack<AnfNodePtr> to_visit;
176 std::stack<AnfNodePtr> next_to_visit;
177 std::map<AnfNodePtr, size_t> nodes_ref;
178 CalcNodeRefCount(graph, &nodes_ref);
179 std::string handle_target = default_target;
180 std::string next_target;
181 to_visit.push(graph->get_return());
182 while (!to_visit.empty() || !next_to_visit.empty()) {
183 if (to_visit.empty()) {
184 to_visit.swap(next_to_visit);
185 handle_target = next_target;
186 }
187 auto node = to_visit.top();
188 MS_EXCEPTION_IF_NULL(node);
189 to_visit.pop();
190 result.emplace_back(node);
191 if (!node->isa<CNode>()) {
192 continue;
193 }
194 auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
195 for (auto &input : next_nodes) {
196 MS_EXCEPTION_IF_NULL(input);
197 auto iter = nodes_ref.find(input);
198 if (iter != nodes_ref.end()) {
199 iter->second--;
200 if (iter->second != 0) {
201 continue;
202 }
203 }
204 if (!input->isa<CNode>()) {
205 to_visit.push(input);
206 continue;
207 }
208 std::string input_target = GetCNodeTarget(input);
209 if (input_target == handle_target) {
210 to_visit.push(input);
211 } else if (next_to_visit.empty() || input_target == next_target) {
212 next_to_visit.push(input);
213 next_target = input_target;
214 } else {
215 MS_LOG(EXCEPTION) << "Only support two different target";
216 }
217 }
218 }
219 std::reverse(result.begin(), result.end());
220 return result;
221 }
222
// Producer/consumer summary of the CNodes reachable from a graph's return node,
// filled in by GetNodesDependencyInfo.
struct GraphNodesDependencyInfo {
  // CNodes none of whose inputs is a CNode (ready to schedule immediately).
  std::stack<AnfNodePtr> independent_nodes_;
  // Per CNode: number of CNode inputs, counted per edge (a repeated input counts twice).
  // Only CNodes with at least one CNode input have an entry.
  std::map<AnfNodePtr, size_t> input_num_;
  // Per CNode: the CNodes that consume it, one entry per input edge.
  std::map<AnfNodePtr, std::vector<AnfNodePtr>> output_edges_;
};
228
GetNodesDependencyInfo(const FuncGraphPtr & graph)229 GraphNodesDependencyInfo GetNodesDependencyInfo(const FuncGraphPtr &graph) {
230 MS_EXCEPTION_IF_NULL(graph);
231 GraphNodesDependencyInfo info;
232 std::stack<AnfNodePtr> to_visit;
233 std::map<AnfNodePtr, size_t> nodes_ref;
234 CalcNodeRefCount(graph, &nodes_ref);
235 to_visit.push(graph->get_return());
236 while (!to_visit.empty()) {
237 auto node = to_visit.top();
238 to_visit.pop();
239 if (node == nullptr || !node->isa<CNode>()) {
240 continue;
241 }
242 auto cnode = node->cast<CNodePtr>();
243 MS_EXCEPTION_IF_NULL(cnode);
244 auto node_inputs = cnode->inputs();
245 bool independent = true;
246 for (auto &input : node_inputs) {
247 MS_EXCEPTION_IF_NULL(input);
248 if (input->isa<CNode>()) {
249 independent = false;
250 auto output_edge_iter = info.output_edges_.find(input);
251 if (output_edge_iter != info.output_edges_.end()) {
252 auto &edges = output_edge_iter->second;
253 edges.emplace_back(node);
254 } else {
255 info.output_edges_[input] = {node};
256 }
257 auto input_num_iter = info.input_num_.find(node);
258 if (input_num_iter != info.input_num_.end()) {
259 input_num_iter->second++;
260 } else {
261 info.input_num_[node] = 1;
262 }
263 }
264 auto ref_iter = nodes_ref.find(input);
265 if (ref_iter != nodes_ref.end()) {
266 ref_iter->second--;
267 if (ref_iter->second != 0) {
268 continue;
269 }
270 }
271 to_visit.push(input);
272 }
273 if (independent) {
274 info.independent_nodes_.push(node);
275 }
276 }
277 return info;
278 }
279
// Scheduling state used by ParallelSort.
struct VisitNodesInfo {
  // Ready-to-emit nodes whose target is the default device target.
  std::queue<AnfNodePtr> default_target_nodes_;
  // Ready-to-emit nodes whose target is the other device target.
  std::queue<AnfNodePtr> other_target_nodes_;
  // Consumer node -> independent Cast seed that is held back and emitted
  // immediately before that consumer.
  std::map<AnfNodePtr, AnfNodePtr> seed_cast_next_node_;
};
285
GetVisitNodesInfo(const GraphNodesDependencyInfo & dependency_info,const std::string & default_target,const std::string & other_target)286 VisitNodesInfo GetVisitNodesInfo(const GraphNodesDependencyInfo &dependency_info, const std::string &default_target,
287 const std::string &other_target) {
288 VisitNodesInfo result;
289 auto independent_nodes = dependency_info.independent_nodes_;
290 while (!independent_nodes.empty()) {
291 auto seed_node = independent_nodes.top();
292 independent_nodes.pop();
293 MS_EXCEPTION_IF_NULL(seed_node);
294 auto node_target = GetCNodeTarget(seed_node);
295 if (IsPrimitiveCNode(seed_node, prim::kPrimCast)) {
296 auto output_edges_iter = dependency_info.output_edges_.find(seed_node);
297 if (output_edges_iter != dependency_info.output_edges_.end() && output_edges_iter->second.size() == 1) {
298 auto &cast_next_node = output_edges_iter->second[0];
299 auto input_num_iter = dependency_info.input_num_.find(cast_next_node);
300 if (input_num_iter == dependency_info.input_num_.end()) {
301 MS_LOG(EXCEPTION) << "Node input num not found!";
302 }
303 if (input_num_iter->second > 1 && node_target == GetCNodeTarget(cast_next_node)) {
304 result.seed_cast_next_node_[cast_next_node] = seed_node;
305 continue;
306 }
307 }
308 }
309 if (node_target == default_target) {
310 result.default_target_nodes_.push(seed_node);
311 } else if (node_target == other_target) {
312 result.other_target_nodes_.push(seed_node);
313 } else {
314 MS_LOG(EXCEPTION) << "Only support two difference targets";
315 }
316 }
317 return result;
318 }
319
ParallelSortDecideNextHandleTarget(const std::vector<AnfNodePtr> & output_edges,const std::string & node_target,std::map<AnfNodePtr,std::string> * node_input_target_map)320 std::string ParallelSortDecideNextHandleTarget(const std::vector<AnfNodePtr> &output_edges,
321 const std::string &node_target,
322 std::map<AnfNodePtr, std::string> *node_input_target_map) {
323 MS_EXCEPTION_IF_NULL(node_input_target_map);
324 std::string next_target = node_target;
325 for (auto &dst_node : output_edges) {
326 auto input_target_iter = node_input_target_map->find(dst_node);
327 if (input_target_iter != node_input_target_map->end() && input_target_iter->second != node_target) {
328 next_target = input_target_iter->second;
329 break;
330 }
331 auto dst_node_target = GetCNodeTarget(dst_node);
332 if (dst_node_target != node_target) {
333 next_target = dst_node_target;
334 break;
335 }
336 }
337 for (auto &dst_node : output_edges) {
338 (*node_input_target_map)[dst_node] = node_target;
339 }
340 return next_target;
341 }
342
ParallelSortVisitNodeEdges(const std::vector<AnfNodePtr> & output_edges,GraphNodesDependencyInfo * dependency_info,VisitNodesInfo * visit_nodes_info,const std::string & default_target)343 void ParallelSortVisitNodeEdges(const std::vector<AnfNodePtr> &output_edges, GraphNodesDependencyInfo *dependency_info,
344 VisitNodesInfo *visit_nodes_info, const std::string &default_target) {
345 MS_EXCEPTION_IF_NULL(dependency_info);
346 MS_EXCEPTION_IF_NULL(visit_nodes_info);
347 for (auto &dst_node : output_edges) {
348 auto dst_node_target = GetCNodeTarget(dst_node);
349 auto input_num_iter = dependency_info->input_num_.find(dst_node);
350 if (input_num_iter == dependency_info->input_num_.end()) {
351 MS_LOG(EXCEPTION) << "Node input num not found!";
352 }
353 input_num_iter->second--;
354 if (input_num_iter->second == 1 &&
355 visit_nodes_info->seed_cast_next_node_.find(dst_node) != visit_nodes_info->seed_cast_next_node_.end()) {
356 input_num_iter->second--;
357 }
358 if (input_num_iter->second > 0) {
359 continue;
360 }
361 if (dst_node_target == default_target) {
362 visit_nodes_info->default_target_nodes_.push(dst_node);
363 } else {
364 visit_nodes_info->other_target_nodes_.push(dst_node);
365 }
366 }
367 }
368
ParallelSort(const FuncGraphPtr & graph,const std::string & default_target,const std::string & other_target)369 std::vector<AnfNodePtr> ParallelSort(const FuncGraphPtr &graph, const std::string &default_target,
370 const std::string &other_target) {
371 MS_EXCEPTION_IF_NULL(graph);
372 auto dependency_info = GetNodesDependencyInfo(graph);
373 auto visit_nodes_info = GetVisitNodesInfo(dependency_info, default_target, other_target);
374 std::vector<AnfNodePtr> result;
375 std::string handle_target;
376 if (!visit_nodes_info.default_target_nodes_.empty()) {
377 handle_target = default_target;
378 } else {
379 handle_target = other_target;
380 }
381 std::map<AnfNodePtr, std::string> node_input_target_map;
382 while (!visit_nodes_info.default_target_nodes_.empty() || !visit_nodes_info.other_target_nodes_.empty()) {
383 AnfNodePtr ready_node;
384 if ((!visit_nodes_info.default_target_nodes_.empty() && handle_target == default_target) ||
385 visit_nodes_info.other_target_nodes_.empty()) {
386 ready_node = visit_nodes_info.default_target_nodes_.front();
387 visit_nodes_info.default_target_nodes_.pop();
388 handle_target = default_target;
389 } else {
390 ready_node = visit_nodes_info.other_target_nodes_.front();
391 visit_nodes_info.other_target_nodes_.pop();
392 handle_target = other_target;
393 }
394 MS_EXCEPTION_IF_NULL(ready_node);
395 auto cast_map_iter = visit_nodes_info.seed_cast_next_node_.find(ready_node);
396 if (cast_map_iter != visit_nodes_info.seed_cast_next_node_.end()) {
397 result.emplace_back(cast_map_iter->second);
398 }
399 result.emplace_back(ready_node);
400 auto output_edges_iter = dependency_info.output_edges_.find(ready_node);
401 if (output_edges_iter == dependency_info.output_edges_.end()) {
402 continue;
403 }
404 auto &output_edges = output_edges_iter->second;
405 handle_target = ParallelSortDecideNextHandleTarget(output_edges, handle_target, &node_input_target_map);
406 ParallelSortVisitNodeEdges(output_edges, &dependency_info, &visit_nodes_info, default_target);
407 }
408 return result;
409 }
410
AddSegmentDependency(const FuncGraphPtr & graph,const std::map<AnfNodePtr,GraphSegmentPtr> & node_to_segment)411 void AddSegmentDependency(const FuncGraphPtr &graph, const std::map<AnfNodePtr, GraphSegmentPtr> &node_to_segment) {
412 MS_EXCEPTION_IF_NULL(graph);
413 std::stack<AnfNodePtr> to_visit;
414 std::map<AnfNodePtr, size_t> nodes_ref;
415 CalcNodeRefCount(graph, &nodes_ref);
416 to_visit.push(graph->get_return());
417 while (!to_visit.empty()) {
418 auto &node = to_visit.top();
419 MS_EXCEPTION_IF_NULL(node);
420 to_visit.pop();
421 if (!node->isa<CNode>()) {
422 continue;
423 }
424 auto cnode = node->cast<CNodePtr>();
425 MS_EXCEPTION_IF_NULL(cnode);
426 auto node_inputs = cnode->inputs();
427 GraphSegmentPtr node_segment{nullptr};
428 auto node_iter = node_to_segment.find(node);
429 if (node_iter != node_to_segment.end()) {
430 node_segment = node_iter->second;
431 }
432 for (auto &input : node_inputs) {
433 if (node_segment != nullptr && !node_segment->is_cut_ && input != nullptr && input->isa<CNode>()) {
434 GraphSegmentPtr input_segment{nullptr};
435 auto input_iter = node_to_segment.find(input);
436 if (input_iter != node_to_segment.end()) {
437 input_segment = input_iter->second;
438 }
439 if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) {
440 node_segment->AddPreSegment(input_segment);
441 }
442 }
443 auto ref_iter = nodes_ref.find(input);
444 if (ref_iter != nodes_ref.end()) {
445 ref_iter->second--;
446 if (ref_iter->second != 0) {
447 continue;
448 }
449 }
450 to_visit.push(input);
451 }
452 }
453 }
454
RemoveUselessDependency(const std::vector<GraphSegmentPtr> * segments)455 void RemoveUselessDependency(const std::vector<GraphSegmentPtr> *segments) {
456 MS_EXCEPTION_IF_NULL(segments);
457 for (auto &segment : *segments) {
458 MS_EXCEPTION_IF_NULL(segment);
459 if (segment->is_cut_) {
460 continue;
461 }
462 bool total_virtual_node = true;
463 for (auto &node : segment->nodes_) {
464 if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
465 IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
466 IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
467 IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
468 IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
469 continue;
470 }
471 total_virtual_node = false;
472 break;
473 }
474 if (total_virtual_node) {
475 segment->pre_segments_.clear();
476 }
477 }
478 }
479
IsSubGraph(const AnfNodePtr & node)480 bool IsSubGraph(const AnfNodePtr &node) {
481 MS_EXCEPTION_IF_NULL(node);
482 if (node->isa<CNode>()) {
483 auto cnode = node->cast<CNodePtr>();
484 auto &inputs = cnode->inputs();
485 if (inputs.empty()) {
486 MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
487 }
488
489 AnfNodePtr fn = inputs[0];
490 if (!IsValueNode<Primitive>(fn)) {
491 return false;
492 }
493 auto node_prim = GetValueNode<PrimitivePtr>(fn);
494 if (node_prim->name() == prim::kPrimPartial->name()) {
495 return true;
496 }
497 } else if (IsValueNode<FuncGraph>(node)) {
498 return true;
499 }
500 return false;
501 }
502
AddSegment(const std::vector<AnfNodePtr> & nodes,std::vector<GraphSegmentPtr> * segments,std::map<AnfNodePtr,GraphSegmentPtr> * node_to_segment)503 void AddSegment(const std::vector<AnfNodePtr> &nodes, std::vector<GraphSegmentPtr> *segments,
504 std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
505 MS_EXCEPTION_IF_NULL(segments);
506 MS_EXCEPTION_IF_NULL(node_to_segment);
507 auto segment = std::make_shared<GraphSegment>(nodes, false);
508 segments->emplace_back(segment);
509 for (auto &node : nodes) {
510 (*node_to_segment)[node] = segment;
511 }
512 }
513
514 struct SplitDynamicNodesHelper {
AddNodemindspore::compile::__anon303885100111::SplitDynamicNodesHelper515 void AddNode(const AnfNodePtr &node, bool is_dynamic_shape) {
516 if (is_dynamic_shape) {
517 pre_dynamic_nodes.emplace_back(node);
518 pre_dynamic_nodes_set.insert(node);
519 } else {
520 pre_common_nodes.emplace_back(node);
521 pre_common_nodes_set.insert(node);
522 }
523 pre_nodes.emplace_back(node);
524 }
525
AddSegmentsmindspore::compile::__anon303885100111::SplitDynamicNodesHelper526 void AddSegments(std::vector<GraphSegmentPtr> *segments, std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
527 MS_EXCEPTION_IF_NULL(segments);
528 MS_EXCEPTION_IF_NULL(node_to_segment);
529 if (pre_nodes.size() < merge_node_threshold) {
530 AddSegment(pre_nodes, segments, node_to_segment);
531 } else {
532 if (!pre_common_nodes.empty()) {
533 AddSegment(pre_common_nodes, segments, node_to_segment);
534 }
535 if (!pre_dynamic_nodes.empty()) {
536 AddSegment(pre_dynamic_nodes, segments, node_to_segment);
537 }
538 }
539 pre_common_nodes.clear();
540 pre_common_nodes_set.clear();
541 pre_dynamic_nodes.clear();
542 pre_dynamic_nodes_set.clear();
543 pre_nodes.clear();
544 }
545 std::vector<AnfNodePtr> pre_nodes;
546 std::vector<AnfNodePtr> pre_dynamic_nodes;
547 std::vector<AnfNodePtr> pre_common_nodes;
548 std::set<AnfNodePtr> pre_common_nodes_set;
549 std::set<AnfNodePtr> pre_dynamic_nodes_set;
550 size_t merge_node_threshold = 6;
551 };
552
SplitDynamicNodeSegment(const std::vector<AnfNodePtr> & segment_nodes,std::vector<GraphSegmentPtr> * segments,std::map<AnfNodePtr,GraphSegmentPtr> * node_to_segment,const std::set<AnfNodePtr> & dynamic_nodes_set)553 void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
554 std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment,
555 const std::set<AnfNodePtr> &dynamic_nodes_set) {
556 SplitDynamicNodesHelper helper;
557 for (auto &node : segment_nodes) {
558 MS_EXCEPTION_IF_NULL(node);
559 auto cnode = node->cast<CNodePtr>();
560 MS_EXCEPTION_IF_NULL(cnode);
561 auto &inputs = cnode->inputs();
562 bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end();
563 bool depend_common_node = false;
564 bool depend_dynamic_node = false;
565 bool is_last_node_dynamic = false;
566 for (size_t i = 1; i < inputs.size(); ++i) {
567 if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) {
568 has_dynamic_shape = true;
569 }
570 if (helper.pre_common_nodes_set.find(inputs[i]) != helper.pre_common_nodes_set.end()) {
571 depend_common_node = true;
572 }
573 if (helper.pre_dynamic_nodes_set.find(inputs[i]) != helper.pre_dynamic_nodes_set.end()) {
574 depend_dynamic_node = true;
575 }
576 }
577 if (has_dynamic_shape) {
578 if (depend_common_node) {
579 helper.AddSegments(segments, node_to_segment);
580 }
581 is_last_node_dynamic = true;
582 } else {
583 if (depend_dynamic_node) {
584 helper.AddSegments(segments, node_to_segment);
585 }
586 is_last_node_dynamic = false;
587 }
588 helper.AddNode(node, is_last_node_dynamic);
589 }
590 helper.AddSegments(segments, node_to_segment);
591 }
592
NodesToSegments(const std::vector<AnfNodePtr> & segment_nodes,std::vector<GraphSegmentPtr> * segments,std::map<AnfNodePtr,GraphSegmentPtr> * node_to_segment)593 void NodesToSegments(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
594 std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
595 if (segment_nodes.empty()) {
596 return;
597 }
598 auto segment_target = GetCNodeTarget(segment_nodes[0]);
599 if (segment_target != kAscendDevice) {
600 AddSegment(segment_nodes, segments, node_to_segment);
601 return;
602 }
603 MS_EXCEPTION_IF_NULL(segments);
604 MS_EXCEPTION_IF_NULL(node_to_segment);
605 std::set<AnfNodePtr> dynamic_nodes_set;
606 for (auto &node : segment_nodes) {
607 MS_EXCEPTION_IF_NULL(node);
608 auto cnode = node->cast<CNodePtr>();
609 if (AnfUtils::IsNodeOutputDynamicShape(cnode)) {
610 (void)dynamic_nodes_set.insert(node);
611 }
612 }
613 if (dynamic_nodes_set.empty()) {
614 AddSegment(segment_nodes, segments, node_to_segment);
615 return;
616 }
617 SplitDynamicNodeSegment(segment_nodes, segments, node_to_segment, dynamic_nodes_set);
618 }
619 } // namespace
620
// Stores the primitives at which the graph is cut (cut_list_) and the backend
// name (backend_name_); both are consulted by GraphPartition::IsCut.
GraphPartition::GraphPartition(const std::vector<PrimitivePtr> &cut_list, const std::string &backend_name)
    : cut_list_(cut_list), backend_name_(backend_name) {}
623
IsCut(const AnfNodePtr & node)624 bool GraphPartition::IsCut(const AnfNodePtr &node) {
625 MS_EXCEPTION_IF_NULL(node);
626 if (node->isa<CNode>()) {
627 auto cnode = node->cast<CNodePtr>();
628 auto &inputs = cnode->inputs();
629 if (inputs.empty()) {
630 MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
631 }
632 AnfNodePtr fn = inputs[0];
633 if (!IsValueNode<Primitive>(fn)) {
634 return true;
635 }
636 auto node_prim = GetValueNode<PrimitivePtr>(fn);
637 for (auto &prim : cut_list_) {
638 MS_EXCEPTION_IF_NULL(prim);
639 if (prim->name() == node_prim->name()) {
640 if (prim->name() == prim::kPrimBpropCut->name()) {
641 auto ms_context = MsContext::GetInstance();
642 MS_EXCEPTION_IF_NULL(ms_context);
643 ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, true);
644 }
645 if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
646 if (inputs.size() <= 1) {
647 return false;
648 }
649 auto ret = IsSubGraph(inputs[1]);
650 return ret;
651 }
652 return true;
653 }
654 }
655 #ifdef ENABLE_GE
656 if (backend_name_ == kGeVm) {
657 auto name = GetCNodeFuncName(cnode);
658 auto adpt = transform::DfGraphConvertor::FindAdapter(name);
659 if (adpt == nullptr) {
660 return true;
661 }
662 }
663 #endif
664 }
665 return false;
666 }
667
Partition(const FuncGraphPtr & graph,bool * multi_target)668 std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph, bool *multi_target) {
669 MS_EXCEPTION_IF_NULL(graph);
670 auto nodes = TopoSort(graph->get_return());
671 MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size();
672 bool contain_multi_target = ContainMultiTarget(nodes);
673 if (multi_target != nullptr) {
674 *multi_target = contain_multi_target;
675 }
676
677 auto context_ptr = MsContext::GetInstance();
678 MS_EXCEPTION_IF_NULL(context_ptr);
679 auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
680 std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
681 if (contain_multi_target || !enable_loop_sink) {
682 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
683 auto other_target = GetOtherTarget(nodes);
684 nodes = ParallelSort(graph, default_target, other_target);
685 } else {
686 nodes = SplitSort(graph, default_target);
687 }
688 nodes = ReorderVirtualNode(nodes, prim::kPrimTupleGetItem);
689 nodes = ReorderVirtualNode(nodes, prim::kPrimDepend);
690 }
691 std::vector<GraphSegmentPtr> segments;
692 std::vector<AnfNodePtr> segment_nodes;
693 std::map<AnfNodePtr, GraphSegmentPtr> node_to_segment;
694 std::string last_target;
695 for (auto &node : nodes) {
696 MS_EXCEPTION_IF_NULL(node);
697 if (IsCut(node)) {
698 NodesToSegments(segment_nodes, &segments, &node_to_segment);
699 segment_nodes.clear();
700 segment_nodes.emplace_back(node);
701 auto segment = std::make_shared<GraphSegment>(segment_nodes, true);
702 segments.push_back(segment);
703 segment_nodes.clear();
704 } else if (node->isa<CNode>()) {
705 if (contain_multi_target) {
706 std::string cur_target = GetCNodeTarget(node);
707 if (cur_target != last_target && !last_target.empty()) {
708 NodesToSegments(segment_nodes, &segments, &node_to_segment);
709 segment_nodes.clear();
710 }
711 last_target = cur_target;
712 }
713 segment_nodes.emplace_back(node);
714 }
715 }
716 MS_LOG(DEBUG) << "Segment size:" << segments.size();
717 if (contain_multi_target) {
718 AddSegmentDependency(graph, node_to_segment);
719 RemoveUselessDependency(&segments);
720 }
721 return segments;
722 }
723 } // namespace compile
724 } // namespace mindspore
725