1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_
19
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "ir/func_graph.h"
#include "base/base.h"
#include "include/common/utils/utils.h"
#include "ops/array_op_name.h"
#include "include/backend/distributed/constants.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/cluster/cluster_context.h"
#else
#include "include/backend/distributed/cluster/dummy_cluster_context.h"
#endif
#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
40
41 namespace mindspore {
42 namespace parallel {
43 using distributed::cluster::ClusterContext;
44 using AnfNodePtrSet = CompactSet<AnfNodePtr>;
45
46 constexpr char kEnvNeedFusion[] = "fusion";
47
48 // The distributed label of the operators(kernel) used to split graph with send/recv nodes.
49 struct OperatorLabel {
50 uint32_t rank_id;
51 std::string ms_role;
52
53 bool operator<(const OperatorLabel &label) const;
54 bool operator==(const OperatorLabel &label) const;
55 bool operator!=(const OperatorLabel &label) const;
56
57 // Judge whether the labels are equal but with looser conditions according to different modes. For example, this
58 // method returns true when comparing the workers in PS mode.
59 bool LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const;
60
61 std::string to_string() const;
62 };
63
64 // The label for inter-process edges. This is used for classify the edges.
65 // For example, only edges with same label should be fused.
66 struct InterProcessEdgeLabel {
67 std::string label_name;
68 };
69
70 // The map of all nodes in the graph to their distributed split label.
71 using NodeLabels = std::map<AnfNodePtr, OperatorLabel>;
72
73 // The list of data-sync node pairs.
74 using DataSyncNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
75
76 // The pair of control edge nodes.
77 using ControlEdgeNodePair = std::pair<CNodePtr, CNodePtr>;
78 // The pair list of control edge nodes.
79 using ControlEdgeNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
80
81 // The judging functions for different modes because the logic will change under different execution modes. If labels
82 // are not matched, the send and recv nodes should be inserted.
83 using LabelMatchingFunc = std::function<bool(const OperatorLabel &, const OperatorLabel &)>;
MatchLabelForPSMode(const OperatorLabel & label1,const OperatorLabel & label2)84 inline bool MatchLabelForPSMode(const OperatorLabel &label1, const OperatorLabel &label2) {
85 // In Parameter Server training mode, Workers have the same labels regardless of their rank id.
86 bool both_worker = (label1.ms_role == label2.ms_role) && (label1.ms_role == distributed::kEnvRoleOfWorker);
87 bool all_match = (label1.rank_id == label2.rank_id) && (label1.ms_role == label2.ms_role);
88 if (both_worker || all_match) {
89 return true;
90 }
91 return false;
92 }
MatchLabelForParallelMode(const OperatorLabel & label1,const OperatorLabel & label2)93 inline bool MatchLabelForParallelMode(const OperatorLabel &label1, const OperatorLabel &label2) {
94 // When parallel mode is enabled by using MindSpore cluster, processes with the same role has the same label
95 // regardless of their rank id.
96 return (label1.ms_role == label2.ms_role);
97 }
98
99 const std::map<distributed::DistExecutionMode, LabelMatchingFunc> kLabelMatchingFuncMap = {
100 {distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode},
101 {distributed::DistExecutionMode::kEmbeddingCacheMode, MatchLabelForPSMode},
102 {distributed::DistExecutionMode::kParallelMode, MatchLabelForParallelMode}};
103
104 // Split graph segment which is generated according to the topo sort of the graph.
105 struct SplitGraphSegment {
106 std::vector<AnfNodePtr> nodes;
107 OperatorLabel label;
108 };
109
110 // The inter-process edge with nodes. This represents the edge between two nodes on two processes.
111 struct InterProcessOpEdge {
112 // The peers of this edge with nodes and their labels.
113 AnfNodePtr src_node;
114 OperatorLabel src_label;
115 AnfNodePtr dst_node;
116 OperatorLabel dst_label;
117
118 // The label of this inter-process edge.
119 InterProcessEdgeLabel edge_label;
120
121 bool operator==(const InterProcessOpEdge &e) const { return to_string() == e.to_string(); }
122
123 bool operator<(const InterProcessOpEdge &e) const { return to_string() < e.to_string(); }
124
to_stringInterProcessOpEdge125 std::string to_string() const {
126 return src_node->fullname_with_scope() + "_" + src_label.to_string() + "->" + dst_node->fullname_with_scope() +
127 "_" + dst_label.to_string();
128 }
129 };
130
131 // The inter-process edge without nodes. This just represents communication edge between two processes.
132 struct InterProcessEdgeWithIndex {
133 OperatorLabel src_label;
134 OperatorLabel dst_label;
135
136 // If there are multiple independent edges between two processes, after rpc node fusion with segments, multiple
137 // InterProcessEdgeWithIndex will be generated. Index represents the segment index in this case.
138 size_t index;
139
140 bool operator==(const InterProcessEdgeWithIndex &e) const { return to_string() == e.to_string(); }
141
142 bool operator<(const InterProcessEdgeWithIndex &e) const { return to_string() < e.to_string(); }
143
to_stringInterProcessEdgeWithIndex144 std::string to_string() const {
145 return src_label.to_string() + "->" + dst_label.to_string() + "_" + std::to_string(index);
146 }
147 };
148
149 // The connection relationship for Send and Recv nodes.
150 // First element represents the Send node.
151 // Second element represents the Recv node.
152 // Third element represents a node which uses the Recv node as a input.
153 // Fourth element represents the input index of the user node.
154 using InterProcessOpPair = std::tuple<CNodePtr, CNodePtr, CNodePtr, int>;
155 using InterProcessOpEdgesInfo = std::map<InterProcessOpEdge, InterProcessOpPair>;
156
157 // The connection relationship for fused Send and Recv nodes.
158 // First element represents the fused Send node.
159 // Second element represents the fused Recv node.
160 // Third element represents the output index of the fused Recv node.
161 // Third element represents the user node which uses the fused Recv node output as an input.
162 // Fourth element represents the input index of the user node.
163 using FusedInterProcessOpPair = std::tuple<CNodePtr, CNodePtr, int, CNodePtr, int>;
164 using InterProcessOpPairMap = std::map<InterProcessEdgeWithIndex, std::vector<InterProcessOpPair>>;
165 using FusedInterProcessOpPairMap = std::map<InterProcessEdgeWithIndex, std::vector<FusedInterProcessOpPair>>;
166
167 // The list of in and out degrees of one segment.
168 using InOutDegreeList = std::vector<std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>>;
169
170 constexpr char kPSOptimizerEdgeLabel[] = "ps_optimizer_edge_label";
171
172 constexpr char kAttrUpdateParameter[] = "update_parameter";
173 constexpr char kAttrParameterInputIndex[] = "parameter_input_index";
174 constexpr char kAttrGradientInputIndex[] = "gradient_input_index";
175 constexpr char kAttrIndicesInputIndex[] = "indices_input_index";
176
177 constexpr char kAttrGradientType[] = "gradient_type";
178 constexpr char kDenseGradient[] = "dense_gradient";
179 constexpr char kSparseGradient[] = "sparse_gradient";
180 // The accumulator operator names for different gradient types.
181 const std::map<std::string, std::string> kGradTypeToAccumOpName = {
182 {kDenseGradient, kAddNOpName},
183 {kSparseGradient, kConcatOpName},
184 };
185
186 // Node which is not physically on this process should be created for splitting graph implementation. This could be
187 // considered as a virtual node which will be elimimated after splitting graph. For example, for server in PS mode, some
188 // virtual nodes which are launched on the workers should be created as gradient accumulation nodes' inputs:
189 // VirtualNode VirtualNode RealBackwardNode
190 // | | |
191 // | | |
192 // | | |
193 // \ | /
194 // \ | /
195 // \ | /
196 // \ | /
197 // GradientAccumulationNode
198 constexpr char kVirtualNode[] = "VirtualNode";
199
200 // This method creates a fake tensor. Its type is the same as the origin_node's output if use_origin_node is set
201 // true.
202 // Normally it is used to connect the edges for send/recv nodes.
203 ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr,
204 bool use_fake_shape = true);
205
206 // Create a TupleGetItem node from a node with tuple output.
207 CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
208 size_t item_index);
209
210 // Create a MakeTuple node from multiple inputs.
211 CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs);
212
213 // For some processes, the original output should be replaced with a node with the same abstract so error won't be
214 // raised in Python layer.
215 AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output);
216
217 // Set attributes for send and recv node. These attributes is used in other stages like graph compiling, rpc route,
218 // etc.
219 void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
220 void SetRecvNodeAttr(const AnfNodePtr &recv_node, const InterProcessOpEdge &inter_process_edge);
221
222 // The inter-process edge between two nodes should be like this:
223 // input-->Send-->Recv-->peer.
224 // Send node takes 'input' node as one input, its output's abstract is the same as a scalar value tensor's to save
225 // memory. Recv node takes a scalar value tensor as one input, its output's abstract is the same as the 'input'
226 // node's.
227 CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge);
228 CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge);
229
230 // Calculate the index to segment number map.
231 std::map<size_t, size_t> GetRealIndexToSeg(const std::vector<size_t> &split_segment, size_t real_size);
232
233 bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input);
234
235 /**
236 * @description: Generate the distributed strategy according to user configuration.
237 * @return {distributed::DistExecutionMode}: The distributed strategy enum.
238 */
239 distributed::DistExecutionMode GenerateStrategy();
240
241 /**
242 * @description: Transform primal attributes of cnode to normal attributes.
243 * @param {CNodePtr} &cnode: The cnode which has the primal attributes.
244 * @return {void}
245 */
246 void TransformPrimAttrToAttr(const CNodePtr &cnode);
247
248 /**
249 * @description: Judge whether this node has label.
250 * @param {AnfNodePtr} &node: AnfNode in a func_graph.
251 * @return {bool}: Whether this node has label.
252 */
253 bool NodeHasLabel(const AnfNodePtr &node);
254
255 /**
256 * @description: Judge whether this graph has any label.
257 * @param {FuncGraphPtr} &func_graph: The func_graph.
258 * @return {bool}: Whether this graph has label.
259 */
260 bool GraphHasLabel(const FuncGraphPtr &func_graph);
261
262 /**
263 * @description: Get node list of side effect nodes in the func_graph.
264 * @param {AnfNodePtrList} &nodes: All nodes of the func_graph.
265 * @return {CNodePtrList}: Side effect node list.
266 */
267 CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes);
268
269 /**
270 * @description: Get reference inputs of the cnode.
271 * @param {CNodePtr} &cnode: Node with side effect.
272 * @return {AnfNodePtrList}: The reference inputs node list.
273 */
274 AnfNodePtrList GetRefInputs(const CNodePtr &cnode);
275
276 /**
277 * @description: Find the UpdateState node which is the user of the input cnode.
278 * @param {FuncGraphPtr} &func_graph: The graph.
279 * @param {CNodePtr} &cnode: The node with side effect.
280 * @return {CNodePtr}: UpdateState node.
281 */
282 CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
283
284 /**
285 * @description: Create 'U' node which is the input of user created side effect nodes like RpcRecv, UpdateState, etc.
286 * @return {CNodePtr}: UMonad node.
287 */
288 ValueNodePtr CreateUMonadNode();
289
290 /**
291 * @description: Create UpdateState node manually.
292 * @param {FuncGraphPtr} &func_graph: The func_graph.
293 * @param {AnfNodePtrList} &update_state_inputs: Inputs of UpdateState node. Normally first is UMonadNode, which is
294 * created inside this method. the others are other side effect nodes passed by caller.
295 * @return {CNodePtr}: UpdateState node.
296 */
297 CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs);
298
299 // Filter out 'func_graph' nodes' dependency matrix to the specified target nodes set.
300 std::map<AnfNodePtr, AnfNodePtrSet> FilterDependencyToTargetNode(const FuncGraphPtr &func_graph,
301 const AnfNodePtrSet &target_nodes);
302
303 // After a new node is added, the depended set should be updated to keep the minimal dependencies.
304 AnfNodePtrSet UpdateDependedSet(const AnfNodePtr &new_node, const AnfNodePtrSet &old_depended_set,
305 const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency);
306
307 // Connect hung nodes to output in case they are optimized out.
308 void HandleHungNodes(const FuncGraphPtr &func_graph, const NodeLabels &node_labels, OperatorLabel process_label,
309 const AnfNodePtrList &hung_nodes_list);
310
311 // Base class for different execution modes. It builds distributed graphs, optimize execution performance, etc.
312 class DistributedExecutionMode {
313 public:
314 // Pass the dyed graph, node labels, process's role and rank id to construct execution mode.
DistributedExecutionMode(const FuncGraphPtr & func_graph,NodeLabels * node_labels,uint32_t rank_id,const std::string & role)315 explicit DistributedExecutionMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
316 const std::string &role)
317 : func_graph_(func_graph), node_labels_(node_labels), rank_id_(rank_id), role_(role) {}
318 virtual ~DistributedExecutionMode() = default;
319
320 // Prebuild the distributed graph to prepare for splitting graph. For example,adding extra accumulation nodes, replace
321 // gradient input of optimizer nodes, dying new created nodes so that common split implementation could applied.
322 // Input 'node_labels' represents node labels of the origin graph. This method could modify this map.
PreBuildDistributedGraph()323 virtual void PreBuildDistributedGraph() {}
324
325 // Do rpc node fusion to decrease the overhead of network communication.
DoRpcNodeFusion(InterProcessOpEdgesInfo * comm_edges_ptr)326 virtual FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) { return {}; }
327
328 // Postbuild the distributed graph after splitting graph. For example, adding extra edges to the split graph.
329 // Input 'node_labels' represents node labels of the split graph.
330 // Input 'comm_edges' represents the inter-process edges generated after splitting the graph.
PostBuildDistributedGraph(const InterProcessOpEdgesInfo & comm_edges)331 virtual void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {}
PostBuildDistributedGraph(const FusedInterProcessOpPairMap & fused_inter_process_op_pairs)332 virtual void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {}
333
334 protected:
335 FuncGraphPtr func_graph_;
336
337 // The node label set by graph splitter. It could be modified by DistributedExecutionMode.
338 NodeLabels *node_labels_;
339
340 // Rank id and node role of this process. They are used to dye graph with different labels, help build split graph,
341 // etc.
342 uint32_t rank_id_;
343 std::string role_;
344 };
345
346 // Gradient accumulation node is needed when the worker number is equal to or greater than 2.
347 constexpr uint32_t kMinGradAccumWorkerNum = 2;
348
349 // The execution of Parameter Server mode.
350 class ParameterServerMode : public DistributedExecutionMode {
351 public:
ParameterServerMode(const FuncGraphPtr & func_graph,NodeLabels * node_labels,uint32_t rank_id,const std::string & role)352 explicit ParameterServerMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
353 const std::string &role)
354 : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
355 ~ParameterServerMode() = default;
356
357 void PreBuildDistributedGraph() override;
358 FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) override;
359 void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) override;
360 void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) override;
361
362 private:
363 // Process optimizers split to the parameter server.
364 void ProcessForSplitOptimizer();
365
366 // Filter out all optimizer nodes which are set on parameter server from the graph.
367 std::vector<CNodePtr> FilterServerAwareOptimizerList();
368
369 // Create gradients accumulator with mean operator for the given optimizer. It could be sparse or dense gradients.
370 // 'total_gradient_number' represents how many workers' gradients will be accumulated for this optimizer.
371 // The return value is a pair of accumulation node to RealDiv node.
372 std::pair<CNodePtr, CNodePtr> CreateNodesForGradAccumulation(const AnfNodePtr &gradient_input,
373 size_t gradient_input_index,
374 const std::string &gradient_type,
375 size_t total_gradient_number);
376
377 // Normally after gradients accumulation, the mean value should be calculated.
378 CNodePtr CreateGradMeanNode(const AnfNodePtr &gradient, size_t divisor);
379
380 // Create MakeTupe and TupleGetItem nodes for the multiple inputs.
381 std::pair<CNodePtr, CNodePtr> CreateNodesForMakeTuple(const AnfNodePtr &input, size_t total_inputs_number);
382
383 // Create node with multiple inputs. Some of the inputs could be fake nodes.
384 // 'many_to_one_node_name' represents the name of the node to be created.
385 // 'real_input' represents the input which is already in the func_graph_. Other inputs will be created as this input.
386 // 'index_of_real_input': the input index of 'real_input' of this new created node: 'many_to_one_node_name'.
387 // 'total_inputs_number': the total inputs number of the created node.
388 CNodePtr CreateNodeWithInterProcessEdgeOnPServer(const std::string &many_to_one_node_name,
389 const AnfNodePtr &real_input, size_t index_of_real_input,
390 uint32_t total_inputs_number);
391
392 // Fuse RpcSend and RpcRecv nodes for Parameter Server optimizers. Only one fused send node should be corresponding to
393 // one fused recv node, vice versa.
394 FusedInterProcessOpPairMap FuseRpcNodesForSplitOptimizer(
395 const InterProcessOpEdgesInfo &comm_edges_of_server_optimizer);
396
397 // Filter out all communication edges related to optimizers on Parameter Server.
398 InterProcessOpEdgesInfo FilterCommEdgesOfServerOptimizer(const InterProcessOpEdgesInfo &comm_edges) const;
399
400 // Filter out all communication edges which are not related to any Parameter Server optimizers and convert them to
401 // FusedInterProcessOpPairMap.
402 FusedInterProcessOpPairMap FilterNotServerOptimizerEdges(const InterProcessOpEdgesInfo &comm_edges) const;
403
404 // Fuse the given rpc send nodes list. Only nodes which send data to the same peer can be fused.
405 CNodePtr FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);
406
407 // Fuse the given rpc recv nodes list. Only nodes which recv data from the same peer can be fused.
408 CNodePtr FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes);
409
410 // Fuse communication edges with same peers.
411 std::vector<FusedInterProcessOpPair> FuseCommEdges(const std::vector<InterProcessOpPair> &inter_process_pairs);
412
413 // The fusion config for rpc nodes connected with optimizers on Parameter Server. This is similar to
414 // 'all_reduce_fusion_split_indices'.
415 std::vector<size_t> ps_optimizer_fusion_segments_;
416 };
417
418 class EmbeddingCacheMode : public DistributedExecutionMode {
419 public:
EmbeddingCacheMode(const FuncGraphPtr & func_graph,NodeLabels * node_labels,uint32_t rank_id,const std::string & role)420 explicit EmbeddingCacheMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
421 const std::string &role)
422 : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
423 ~EmbeddingCacheMode() = default;
424
425 void PreBuildDistributedGraph() override;
426
427 private:
428 void AddEmbeddingCacheOps() const;
429
430 OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
431 };
432
433 // Users may want to simply split a training graph into multiple devices without other extra features. GeneralMode is
434 // for this scenario.
435 class GeneralMode : public DistributedExecutionMode {
436 public:
GeneralMode(const FuncGraphPtr & func_graph,NodeLabels * node_labels,uint32_t rank_id,const std::string & role)437 explicit GeneralMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
438 const std::string &role)
439 : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
440 ~GeneralMode() = default;
441 };
442
443 // The mode applied when AutoParallel/SemiAutoParallel feature is enabled.
444 class ParallelMode : public DistributedExecutionMode {
445 public:
ParallelMode(const FuncGraphPtr & func_graph,NodeLabels * node_labels,uint32_t rank_id,const std::string & role)446 explicit ParallelMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
447 const std::string &role)
448 : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
449 ~ParallelMode() = default;
450 };
451
452 // The class is used as an action in pipeline. It will process the graph and split the nodes to each process in the
453 // cluster.
454 class GraphSplitter {
455 public:
456 GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, const std::string &role);
457 ~GraphSplitter();
458
459 // Launch the action.
460 void Run();
461
462 private:
463 // Dyeing the func_graph according to the split label passed by frontend. Only nodes with the same label will be dyed
464 // with the same 'color'.
465 void DyeGraph();
466
467 // Create the execution mode.
468 void CreateExecutionMode();
469
470 // Traverse all nodes and split these nodes to multiple segments according to the split label.
471 std::vector<SplitGraphSegment> GenerateSplitSegments();
472
473 /**
474 * @description: Reassign the operator label for 'TupleGetItem' nodes. This is an optimization for nodes with multiple
475 * outputs.
476 * @return {void}
477 */
478 void ReassignTupleGetItemNodeLabel();
479
480 /**
481 * @description: Recursively visit TupeGetItem nodes and set their labels.
482 * @param {CNodePtr} &tuple_get_item_node: The TupeGetItem node.
483 * @return {OperatorLabel}: The OperatorLabel of this TupeGetItem node.
484 */
485 OperatorLabel RecursiveSetTupeGetItemLabel(const CNodePtr &tuple_get_item_node);
486
487 /**
488 * @description: Add data-sync node pairs for reference nodes like trainable parameters. These nodes are used to
489 * synchronize updates of parameters between nodes.
490 * @return {void}
491 */
492 void ProcessRefNodes();
493
494 /**
495 * @description: Add some extra control edges between nodes with different labels to keep the consistency of
496 * topo-sort.
497 * @return {void}
498 */
499 void AddExtraControlEdgeAcrossProcess();
500
501 // Generate Send-Recv pairs for the nodes which has different split.
502 // Because nodes with different split label from this proccess's with be on another machine, we use Send-Recv pairs to
503 // do network communication.
504 InterProcessOpEdgesInfo GenerateInterProcessOperators();
505
506 // Eliminate nodes which are on other machine's graphs and add control edges for nodes of this process's graph.
507 void SplitGraph(const std::vector<SplitGraphSegment> &segments, const InterProcessOpEdgesInfo &comm_edges);
508 void SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
509
510 /**
511 * @description: Add data-sync nodes for reference nodes. To ensure the control edges, the data-sync nodes should be
512 * add to the UpdateState node's input:
513 * SideEffectNode(Ref1, Ref2, U)
514 * | |
515 * | |
516 * | DataSyncSrcNode(Ref1, Ref2)
517 * | |
518 * | DataSyncDstNode(Ref1, Ref2)
519 * UpdateState(U, SideEffectNode, DataSyncDstNode)
520 *
521 * The topology relationship is shown above: After SideEffectNode is launched and could have updated Ref1 and Ref2,
522 * data-sync nodes are inserted and connected to UpdateState node so that nodes after UpdateState will not be launched
523 * until the data is synchronized.
524 * @param {CNodePtr} &update_state_node: The update state node which is the reference node's user.
525 * @param {AnfNodePtrList} &ref_nodes: Reference nodes need to be synchronized.
526 * @return {void}
527 */
528 void AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
529 const AnfNodePtrList &ref_nodes);
530
531 /**
532 * @description: Create data-sync node pairs for the reference node. It may need to be synchronized to multiple
533 * processes.
534 * @param {CNodePtr} &side_effect_node: The node with side effect using reference node as input.
535 * @param {AnfNodePtr} &ref: The reference node.
536 * @param {vector<OperatorLabel>} &diff_labels: The operator label list of each process to which the reference node
537 * data will be synchronized.
538 * @return {DataSyncNodePairList}: The list of data-sync nodes.
539 */
540 DataSyncNodePairList CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
541 const std::set<OperatorLabel> &diff_labels);
542
543 /**
544 * @description: For processes without any indegree, control edge should be connected from process with default label.
545 * This is to avoid these processes
546 * @return {void}
547 */
548 void AddControlEdgeForProcessWithoutIndegree();
549
550 /**
551 * @description: Create src and dst node of a control edge with the specified src and dst operator labels.
552 * ControlSrc(1.0)
553 * |
554 * |
555 * ControlDst()
556 * @param {OperatorLabel} &src_label: Control edge src label.
557 * @param {OperatorLabel} &dst_label: Control edge dst label.
558 * @return {ControlEdgeNodePair}: The nodes pair.
559 */
560 ControlEdgeNodePair CreateControlEdgeNode(const OperatorLabel &src_label, const OperatorLabel &dst_label);
561
562 /**
563 * @description: The data-sync nodes and control-edge nodes should be eliminated at the end of the splitting process.
564 * These nodes are just for graph splitting and have no corresponding backend kernels.
565 * @return {void}
566 */
567 void EliminateDataSyncNode();
568 void EliminateControlEdgeNode();
569
570 // Split the graph but don't eliminate the nodes so that a global graph ir could be exported.
571 void DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges);
572
573 // Return the split label of this node. Only CNode is supported for now.
574 // If the node has no split label, return the label of this process, which means this node should be in this process's
575 // graph.
576 OperatorLabel GetSplitLabel(const AnfNodePtr &node);
577
578 // Consider Node-X is the split node. Node-In is Node-X's one input, Node-Out takes Node-X as one input.
579 // So the graph should be like this:
580 // Node-In-->Node-X-->Node-Out.
581 // After send and recv op is inserted, the graph should be:
582 // Node-In-->Send-->Recv-->Node-X-->Send-->Recv-->Node-Out.
583 // So method GenerateInterProcessOpsForNodeInputs is for generating Send-Recv pair between Node-In and Node-X.
584 InterProcessOpEdgesInfo GenerateInterProcessOpsForNodeInputs(const AnfNodePtr &node);
585
586 InterProcessEdgeLabel GenerateEdgeLabel(const AnfNodePtr &src_node, const AnfNodePtr &dst_node) const;
587
588 // Segments will be independent with each other after the graph is cut, so in-degrees and out-degrees of each segment
589 // should be connected with control edges in case that the nodes are optimized out.
590 std::vector<AnfNodePtr> FindInterProcessInDegree(const std::vector<AnfNodePtr> &nodes,
591 const InterProcessOpEdgesInfo &comm_edges);
592 std::vector<AnfNodePtr> FindInterProcessOutDegree(const std::vector<AnfNodePtr> &nodes,
593 const InterProcessOpEdgesInfo &comm_edges);
594
595 // Generate in and out degrees list of the segments to add dependency between segments.
596 InOutDegreeList GenerateInOutDegreeList(const std::vector<SplitGraphSegment> &segments,
597 const InterProcessOpEdgesInfo &comm_edges);
598
599 // Must add extra dependency edge for RpcSend and RpcRecv nodes in case they are optimized out or lose explicit
600 // dependencies.
601 void AddDependencyBetweenEdges(const InterProcessOpEdgesInfo &comm_edges);
602 // For the segments on this process, dependency edges should be created so that they won't be optimized out.
603 void AddDependencyBetweenSegments(const InOutDegreeList &in_out_degree_list);
604
605 // Replace nodes inputs with Recv nodes to eliminate extra nodes not on this process.
606 void EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edges);
607
608 // Replace nodes inputs with Recv nodes.
609 void ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
610
611 void AddSendRecvDependency(const InterProcessOpEdgesInfo &in_degree_comm_edges, const AnfNodePtrSet &send_src_nodes,
612 const std::map<AnfNodePtr, AnfNodePtrSet> &src_nodes_to_send_nodes,
613 const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency,
614 std::map<AnfNodePtr, bool> *is_send_node_hung);
615
616 // Add outputs edges for send nodes so that they won't be optimized out.
617 void AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);
618
619 // Judge whether two nodes have the same distributed label.
620 bool IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2);
621
622 // Check whether need split distributed graph.
623 bool NeedSplitGraph() const;
624
625 // Return whether this node has corresponding label stored in node_labels_.
626 bool NodeHasLabel(const AnfNodePtr &node);
627
628 FuncGraphPtr func_graph_;
629
630 // Rank id and node role of this process. They are used to dye graph with different labels, help build split graph,
631 // etc.
632 uint32_t rank_id_;
633 std::string role_;
634
635 // Created according to the execution mode. Used to build the distributed graph.
636 distributed::DistExecutionMode mode_;
637 std::unique_ptr<DistributedExecutionMode> exec_mode_;
638
639 // The label of this process which consists of its rank and role.
640 OperatorLabel this_process_label_;
641
642 // For each mode, there is a default label. Every node in the graph should be launched on the process with this label
643 // defaultly unless it has a different split label.
644 OperatorLabel default_label_;
645
646 // The map of all nodes in the graph to their distributed split label.
647 NodeLabels node_labels_;
648
649 // All labels in the graph.
650 std::set<OperatorLabel> all_labels_;
651
652 // Whether need to fuse rpc nodes.
653 bool need_fuse_rpc_nodes_;
654
655 // Visited TupleGetItem nodes when recursively setting their labels.
656 std::map<AnfNodePtr, bool> visited_tuple_get_item_nodes_;
657 };
658 using GraphSplitterPtr = std::shared_ptr<GraphSplitter>;
659 } // namespace parallel
660 } // namespace mindspore
661
662 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_
663