/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_

#include <map>
#include <set>
#include <tuple>
#include <utility>
#include <string>
#include <memory>
#include <vector>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "ir/func_graph.h"
#include "base/base.h"
#include "include/common/utils/utils.h"
#include "ops/array_op_name.h"
#include "include/backend/distributed/constants.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/cluster/cluster_context.h"
#else
#include "include/backend/distributed/cluster/dummy_cluster_context.h"
#endif
#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"

namespace mindspore {
namespace parallel {
using distributed::cluster::ClusterContext;
using AnfNodePtrSet = CompactSet<AnfNodePtr>;

constexpr char kEnvNeedFusion[] = "fusion";

// The distributed label of the operators (kernels), used to split the graph with send/recv nodes.
struct OperatorLabel {
  uint32_t rank_id;
  std::string ms_role;

  bool operator<(const OperatorLabel &label) const;
  bool operator==(const OperatorLabel &label) const;
  bool operator!=(const OperatorLabel &label) const;

  // Judge whether the labels are equal under looser conditions which depend on the execution mode. For example, this
  // method returns true when comparing two workers in PS mode.
  bool LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const;

  std::string to_string() const;
};

// The label for inter-process edges. This is used to classify the edges.
// For example, only edges with the same label should be fused.
struct InterProcessEdgeLabel {
  std::string label_name;
};

// The map of all nodes in the graph to their distributed split label.
using NodeLabels = std::map<AnfNodePtr, OperatorLabel>;

// The list of data-sync node pairs.
using DataSyncNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;

// The pair of control edge nodes.
using ControlEdgeNodePair = std::pair<CNodePtr, CNodePtr>;
// The pair list of control edge nodes.
using ControlEdgeNodePairList = std::vector<std::pair<CNodePtr, CNodePtr>>;
// The label-matching functions for the different execution modes, since the matching logic differs between modes. If
// the labels do not match, send and recv nodes should be inserted.
using LabelMatchingFunc = std::function<bool(const OperatorLabel &, const OperatorLabel &)>;
inline bool MatchLabelForPSMode(const OperatorLabel &label1, const OperatorLabel &label2) {
  // In Parameter Server training mode, workers have the same label regardless of their rank ids.
  bool both_worker = (label1.ms_role == label2.ms_role) && (label1.ms_role == distributed::kEnvRoleOfWorker);
  bool all_match = (label1.rank_id == label2.rank_id) && (label1.ms_role == label2.ms_role);
  if (both_worker || all_match) {
    return true;
  }
  return false;
}
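// For example (illustrative values only): in PS mode two workers match even though their rank ids differ, because
// both labels carry the worker role.
//   OperatorLabel worker0 = {0, distributed::kEnvRoleOfWorker};
//   OperatorLabel worker1 = {1, distributed::kEnvRoleOfWorker};
//   MatchLabelForPSMode(worker0, worker1);  // true: both are workers.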
inline bool MatchLabelForParallelMode(const OperatorLabel &label1, const OperatorLabel &label2) {
  // When parallel mode is enabled by using a MindSpore cluster, processes with the same role have the same label
  // regardless of their rank ids.
  return (label1.ms_role == label2.ms_role);
}

const std::map<distributed::DistExecutionMode, LabelMatchingFunc> kLabelMatchingFuncMap = {
  {distributed::DistExecutionMode::kPSMode, MatchLabelForPSMode},
  {distributed::DistExecutionMode::kEmbeddingCacheMode, MatchLabelForPSMode},
  {distributed::DistExecutionMode::kParallelMode, MatchLabelForParallelMode}};
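
// A minimal sketch of how a matching function could be looked up from kLabelMatchingFuncMap and applied.
// 'LabelsMatchExample' is a hypothetical helper added for illustration; it is not part of the splitter's API. The
// strict-equality fallback for unknown modes is an assumption of this example.
inline bool LabelsMatchExample(const OperatorLabel &label1, const OperatorLabel &label2,
                               distributed::DistExecutionMode mode) {
  auto iter = kLabelMatchingFuncMap.find(mode);
  if (iter == kLabelMatchingFuncMap.end()) {
    // No dedicated matching rule for this mode: fall back to strict label equality.
    return label1 == label2;
  }
  return iter->second(label1, label2);
}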

// A segment of the split graph, generated according to the topological sort of the graph.
struct SplitGraphSegment {
  std::vector<AnfNodePtr> nodes;
  OperatorLabel label;
};

// The inter-process edge with nodes. This represents the edge between two nodes on two processes.
struct InterProcessOpEdge {
  // The peers of this edge with nodes and their labels.
  AnfNodePtr src_node;
  OperatorLabel src_label;
  AnfNodePtr dst_node;
  OperatorLabel dst_label;

  // The label of this inter-process edge.
  InterProcessEdgeLabel edge_label;

  bool operator==(const InterProcessOpEdge &e) const { return to_string() == e.to_string(); }

  bool operator<(const InterProcessOpEdge &e) const { return to_string() < e.to_string(); }

  std::string to_string() const {
    return src_node->fullname_with_scope() + "_" + src_label.to_string() + "->" + dst_node->fullname_with_scope() +
           "_" + dst_label.to_string();
  }
};
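
// Note that equality and ordering of InterProcessOpEdge are purely string-based: an edge from node "x" with label A
// to node "y" with label B stringifies as "<x fullname>_<A.to_string()>-><y fullname>_<B.to_string()>".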

// The inter-process edge without nodes. This just represents a communication edge between two processes.
struct InterProcessEdgeWithIndex {
  OperatorLabel src_label;
  OperatorLabel dst_label;

  // If there are multiple independent edges between two processes, after rpc node fusion with segments, multiple
  // InterProcessEdgeWithIndex will be generated. In this case, index represents the segment index.
  size_t index;

  bool operator==(const InterProcessEdgeWithIndex &e) const { return to_string() == e.to_string(); }

  bool operator<(const InterProcessEdgeWithIndex &e) const { return to_string() < e.to_string(); }

  std::string to_string() const {
    return src_label.to_string() + "->" + dst_label.to_string() + "_" + std::to_string(index);
  }
};

// The connection relationship for Send and Recv nodes.
// First element represents the Send node.
// Second element represents the Recv node.
// Third element represents a node which uses the Recv node as an input.
// Fourth element represents the input index of the user node.
using InterProcessOpPair = std::tuple<CNodePtr, CNodePtr, CNodePtr, int>;
using InterProcessOpEdgesInfo = std::map<InterProcessOpEdge, InterProcessOpPair>;

// The connection relationship for fused Send and Recv nodes.
// First element represents the fused Send node.
// Second element represents the fused Recv node.
// Third element represents the output index of the fused Recv node.
// Fourth element represents the user node which uses the fused Recv node output as an input.
// Fifth element represents the input index of the user node.
using FusedInterProcessOpPair = std::tuple<CNodePtr, CNodePtr, int, CNodePtr, int>;
using InterProcessOpPairMap = std::map<InterProcessEdgeWithIndex, std::vector<InterProcessOpPair>>;
using FusedInterProcessOpPairMap = std::map<InterProcessEdgeWithIndex, std::vector<FusedInterProcessOpPair>>;
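
// A minimal sketch of unpacking these tuples with std::get, following the element order documented above.
// 'GetFusedRecvNodeExample' is a hypothetical helper for illustration only.
inline CNodePtr GetFusedRecvNodeExample(const FusedInterProcessOpPair &op_pair) {
  // Element 1 is the fused Recv node; elements 2..4 are its output index, the user node and the user's input index.
  return std::get<1>(op_pair);
}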

// The list of in and out degrees of one segment.
using InOutDegreeList = std::vector<std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>>;

constexpr char kPSOptimizerEdgeLabel[] = "ps_optimizer_edge_label";

constexpr char kAttrUpdateParameter[] = "update_parameter";
constexpr char kAttrParameterInputIndex[] = "parameter_input_index";
constexpr char kAttrGradientInputIndex[] = "gradient_input_index";
constexpr char kAttrIndicesInputIndex[] = "indices_input_index";

constexpr char kAttrGradientType[] = "gradient_type";
constexpr char kDenseGradient[] = "dense_gradient";
constexpr char kSparseGradient[] = "sparse_gradient";
// The accumulator operator names for different gradient types.
const std::map<std::string, std::string> kGradTypeToAccumOpName = {
  {kDenseGradient, kAddNOpName},
  {kSparseGradient, kConcatOpName},
};
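
// A minimal sketch (hypothetical helper) of resolving the accumulator op name for a gradient type: dense gradients
// accumulate with AddN, sparse gradients with Concat. Falling back to kAddNOpName for unknown types is an assumption
// made only for this example.
inline std::string GradAccumOpNameExample(const std::string &gradient_type) {
  auto iter = kGradTypeToAccumOpName.find(gradient_type);
  return (iter == kGradTypeToAccumOpName.end()) ? kAddNOpName : iter->second;
}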

// Nodes which are not physically on this process should still be created when splitting the graph. Such a node could
// be considered a virtual node which will be eliminated after the graph is split. For example, for a server in PS
// mode, virtual nodes representing computation launched on the workers should be created as the gradient accumulation
// node's inputs:
// VirtualNode  VirtualNode  RealBackwardNode
//      |            |               |
//      |            |               |
//      |            |               |
//       \           |               /
//        \          |              /
//         \         |             /
//          \        |            /
//         GradientAccumulationNode
constexpr char kVirtualNode[] = "VirtualNode";

// This method creates a fake tensor. Its type is the same as the origin_node's output if use_origin_node is set to
// true.
// Normally it is used to connect the edges for send/recv nodes.
ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node = nullptr,
                                 bool use_fake_shape = true);

// Create a TupleGetItem node from a node with tuple output.
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
                                size_t item_index);

// Create a MakeTuple node from multiple inputs.
CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs);
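
// Illustrative usage (hypothetical variables): take the second output of a multi-output node, then bundle it with
// another node into a tuple.
//   CNodePtr second_out = CreateTupleGetItemNode(func_graph, multi_output_node, 1);
//   CNodePtr tuple = CreateMakeTupleNode(func_graph, {second_out, other_node});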

// For some processes, the original output should be replaced with a node with the same abstract so that no error is
// raised in the Python layer.
AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output);

// Set attributes for send and recv nodes. These attributes are used in other stages like graph compiling, rpc route,
// etc.
void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge);
void SetRecvNodeAttr(const AnfNodePtr &recv_node, const InterProcessOpEdge &inter_process_edge);

// The inter-process edge between two nodes should be like this:
// input-->Send-->Recv-->peer.
// The Send node takes the 'input' node as one input, and its output's abstract is the same as a scalar value tensor's
// to save memory. The Recv node takes a scalar value tensor as one input, and its output's abstract is the same as
// the 'input' node's.
CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge);
CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge);

// Calculate the index to segment number map.
std::map<size_t, size_t> GetRealIndexToSeg(const std::vector<size_t> &split_segment, size_t real_size);
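// For example (an illustration of the description above, not a guaranteed contract): with split_segment {2, 5} and
// real_size 8, indices 0..2 would fall into segment 0, 3..5 into segment 1 and 6..7 into segment 2; the exact
// boundary semantics are defined by the implementation.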

bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input);

/**
 * @description: Generate the distributed strategy according to user configuration.
 * @return {distributed::DistExecutionMode}: The distributed strategy enum.
 */
distributed::DistExecutionMode GenerateStrategy();

/**
 * @description: Transform primal attributes of the cnode to normal attributes.
 * @param {CNodePtr} &cnode: The cnode which has the primal attributes.
 * @return {void}
 */
void TransformPrimAttrToAttr(const CNodePtr &cnode);

/**
 * @description: Judge whether this node has a label.
 * @param {AnfNodePtr} &node: An AnfNode in a func_graph.
 * @return {bool}: Whether this node has a label.
 */
bool NodeHasLabel(const AnfNodePtr &node);

/**
 * @description: Judge whether this graph has any label.
 * @param {FuncGraphPtr} &func_graph: The func_graph.
 * @return {bool}: Whether this graph has any label.
 */
bool GraphHasLabel(const FuncGraphPtr &func_graph);

/**
 * @description: Get the list of side effect nodes in the func_graph.
 * @param {AnfNodePtrList} &nodes: All nodes of the func_graph.
 * @return {CNodePtrList}: Side effect node list.
 */
CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes);

/**
 * @description: Get reference inputs of the cnode.
 * @param {CNodePtr} &cnode: Node with side effect.
 * @return {AnfNodePtrList}: The reference input node list.
 */
AnfNodePtrList GetRefInputs(const CNodePtr &cnode);

/**
 * @description: Find the UpdateState node which is the user of the input cnode.
 * @param {FuncGraphPtr} &func_graph: The graph.
 * @param {CNodePtr} &cnode: The node with side effect.
 * @return {CNodePtr}: UpdateState node.
 */
CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode);

/**
 * @description: Create a 'U' node which is the input of user-created side effect nodes like RpcRecv, UpdateState, etc.
 * @return {ValueNodePtr}: UMonad node.
 */
ValueNodePtr CreateUMonadNode();

/**
 * @description: Create an UpdateState node manually.
 * @param {FuncGraphPtr} &func_graph: The func_graph.
 * @param {AnfNodePtrList} &update_state_inputs: Inputs of the UpdateState node. Normally the first is a UMonad node,
 * which is created inside this method; the others are side effect nodes passed by the caller.
 * @return {CNodePtr}: UpdateState node.
 */
CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs);
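
// Illustrative usage (hypothetical variables): ordering a newly created RpcRecv node into the monad chain so that
// later nodes wait for it. The leading UMonad input is created inside CreateUpdateStateNode itself.
//   CNodePtr update_state = CreateUpdateStateNode(func_graph, {recv_node});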

// Filter out the 'func_graph' nodes' dependency matrix to the specified target node set.
std::map<AnfNodePtr, AnfNodePtrSet> FilterDependencyToTargetNode(const FuncGraphPtr &func_graph,
                                                                 const AnfNodePtrSet &target_nodes);

// After a new node is added, the depended set should be updated to keep the dependencies minimal.
AnfNodePtrSet UpdateDependedSet(const AnfNodePtr &new_node, const AnfNodePtrSet &old_depended_set,
                                const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency);

// Connect hung nodes to the output in case they are optimized out.
void HandleHungNodes(const FuncGraphPtr &func_graph, const NodeLabels &node_labels, OperatorLabel process_label,
                     const AnfNodePtrList &hung_nodes_list);

// Base class for different execution modes. It builds distributed graphs, optimizes execution performance, etc.
class DistributedExecutionMode {
 public:
  // Pass the dyed graph, node labels, process's role and rank id to construct the execution mode.
  explicit DistributedExecutionMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
                                    const std::string &role)
      : func_graph_(func_graph), node_labels_(node_labels), rank_id_(rank_id), role_(role) {}
  virtual ~DistributedExecutionMode() = default;

  // Prebuild the distributed graph to prepare for splitting it. For example: add extra accumulation nodes, replace
  // the gradient inputs of optimizer nodes, and dye newly created nodes so that the common split implementation can
  // be applied.
  // Input 'node_labels' represents the node labels of the origin graph. This method could modify this map.
  virtual void PreBuildDistributedGraph() {}

  // Do rpc node fusion to decrease the overhead of network communication.
  virtual FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) { return {}; }

  // Postbuild the distributed graph after splitting it. For example, add extra edges to the split graph.
  // Input 'node_labels' represents the node labels of the split graph.
  // Input 'comm_edges' represents the inter-process edges generated after splitting the graph.
  virtual void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {}
  virtual void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {}

 protected:
  FuncGraphPtr func_graph_;

  // The node labels set by the graph splitter. They could be modified by DistributedExecutionMode.
  NodeLabels *node_labels_;

  // Rank id and node role of this process. They are used to dye the graph with different labels, help build the split
  // graph, etc.
  uint32_t rank_id_;
  std::string role_;
};

// Gradient accumulation nodes are needed when the worker number is equal to or greater than 2.
constexpr uint32_t kMinGradAccumWorkerNum = 2;

// The execution mode of Parameter Server training.
class ParameterServerMode : public DistributedExecutionMode {
 public:
  explicit ParameterServerMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
                               const std::string &role)
      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
  ~ParameterServerMode() = default;

  void PreBuildDistributedGraph() override;
  FusedInterProcessOpPairMap DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) override;
  void PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) override;
  void PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) override;

 private:
  // Process the optimizers split to the parameter server.
  void ProcessForSplitOptimizer();

  // Filter out all optimizer nodes which are set on the parameter server from the graph.
  std::vector<CNodePtr> FilterServerAwareOptimizerList();

  // Create a gradient accumulator with a mean operator for the given optimizer. The gradients could be sparse or
  // dense.
  // 'total_gradient_number' represents how many workers' gradients will be accumulated for this optimizer.
  // The return value is a pair of the accumulation node and the RealDiv node.
  std::pair<CNodePtr, CNodePtr> CreateNodesForGradAccumulation(const AnfNodePtr &gradient_input,
                                                               size_t gradient_input_index,
                                                               const std::string &gradient_type,
                                                               size_t total_gradient_number);

  // Normally, after gradient accumulation, the mean value should be calculated.
  CNodePtr CreateGradMeanNode(const AnfNodePtr &gradient, size_t divisor);

  // Create MakeTuple and TupleGetItem nodes for the multiple inputs.
  std::pair<CNodePtr, CNodePtr> CreateNodesForMakeTuple(const AnfNodePtr &input, size_t total_inputs_number);

  // Create a node with multiple inputs. Some of the inputs could be fake nodes.
  // 'many_to_one_node_name' represents the name of the node to be created.
  // 'real_input' represents the input which is already in func_graph_. The other inputs will be created like this
  // input.
  // 'index_of_real_input': the input index of 'real_input' in the newly created node 'many_to_one_node_name'.
  // 'total_inputs_number': the total number of inputs of the created node.
  CNodePtr CreateNodeWithInterProcessEdgeOnPServer(const std::string &many_to_one_node_name,
                                                   const AnfNodePtr &real_input, size_t index_of_real_input,
                                                   uint32_t total_inputs_number);

  // Fuse RpcSend and RpcRecv nodes for Parameter Server optimizers. Exactly one fused send node should correspond to
  // one fused recv node, and vice versa.
  FusedInterProcessOpPairMap FuseRpcNodesForSplitOptimizer(
    const InterProcessOpEdgesInfo &comm_edges_of_server_optimizer);

  // Filter out all communication edges related to optimizers on the Parameter Server.
  InterProcessOpEdgesInfo FilterCommEdgesOfServerOptimizer(const InterProcessOpEdgesInfo &comm_edges) const;

  // Filter out all communication edges which are not related to any Parameter Server optimizer and convert them to a
  // FusedInterProcessOpPairMap.
  FusedInterProcessOpPairMap FilterNotServerOptimizerEdges(const InterProcessOpEdgesInfo &comm_edges) const;

  // Fuse the given rpc send node list. Only nodes which send data to the same peer can be fused.
  CNodePtr FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes);

  // Fuse the given rpc recv node list. Only nodes which recv data from the same peer can be fused.
  CNodePtr FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes);

  // Fuse communication edges with the same peers.
  std::vector<FusedInterProcessOpPair> FuseCommEdges(const std::vector<InterProcessOpPair> &inter_process_pairs);

  // The fusion config for rpc nodes connected with optimizers on the Parameter Server. This is similar to
  // 'all_reduce_fusion_split_indices'.
  std::vector<size_t> ps_optimizer_fusion_segments_;
};

class EmbeddingCacheMode : public DistributedExecutionMode {
 public:
  explicit EmbeddingCacheMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
                              const std::string &role)
      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
  ~EmbeddingCacheMode() = default;

  void PreBuildDistributedGraph() override;

 private:
  void AddEmbeddingCacheOps() const;

  OperatorLabel GetNodeLabel(const AnfNodePtr &node) const;
};

// Users may want to simply split a training graph onto multiple devices without any other extra feature. GeneralMode
// is for this scenario.
class GeneralMode : public DistributedExecutionMode {
 public:
  explicit GeneralMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
                       const std::string &role)
      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
  ~GeneralMode() = default;
};

// The mode applied when the AutoParallel/SemiAutoParallel feature is enabled.
class ParallelMode : public DistributedExecutionMode {
 public:
  explicit ParallelMode(const FuncGraphPtr &func_graph, NodeLabels *node_labels, uint32_t rank_id,
                        const std::string &role)
      : DistributedExecutionMode(func_graph, node_labels, rank_id, role) {}
  ~ParallelMode() = default;
};

// This class is used as an action in the pipeline. It processes the graph and splits the nodes to each process in the
// cluster.
class GraphSplitter {
 public:
  GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, const std::string &role);
  ~GraphSplitter();

  // Launch the action.
  void Run();

 private:
  // Dye the func_graph according to the split labels passed by the frontend. Only nodes with the same label will be
  // dyed with the same 'color'.
  void DyeGraph();

  // Create the execution mode.
  void CreateExecutionMode();

  // Traverse all nodes and split them into multiple segments according to the split label.
  std::vector<SplitGraphSegment> GenerateSplitSegments();

  /**
   * @description: Reassign the operator label for 'TupleGetItem' nodes. This is an optimization for nodes with
   * multiple outputs.
   * @return {void}
   */
  void ReassignTupleGetItemNodeLabel();

  /**
   * @description: Recursively visit TupleGetItem nodes and set their labels.
   * @param {CNodePtr} &tuple_get_item_node: The TupleGetItem node.
   * @return {OperatorLabel}: The OperatorLabel of this TupleGetItem node.
   */
  OperatorLabel RecursiveSetTupeGetItemLabel(const CNodePtr &tuple_get_item_node);

  /**
   * @description: Add data-sync node pairs for reference nodes like trainable parameters. These nodes are used to
   * synchronize parameter updates between processes.
   * @return {void}
   */
  void ProcessRefNodes();

  /**
   * @description: Add some extra control edges between nodes with different labels to keep the consistency of the
   * topological sort.
   * @return {void}
   */
  void AddExtraControlEdgeAcrossProcess();

  // Generate Send-Recv pairs for the nodes which have different split labels.
  // Because nodes whose split label differs from this process's will be on another machine, we use Send-Recv pairs to
  // do the network communication.
  InterProcessOpEdgesInfo GenerateInterProcessOperators();

  // Eliminate nodes which are on other machines' graphs and add control edges for the nodes of this process's graph.
  void SplitGraph(const std::vector<SplitGraphSegment> &segments, const InterProcessOpEdgesInfo &comm_edges);
  void SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);

  /**
   * @description: Add data-sync nodes for reference nodes. To ensure the control edges, the data-sync nodes should be
   * added to the UpdateState node's inputs:
   * SideEffectNode(Ref1, Ref2, U)
   *              |             |
   *              |             |
   *              |        DataSyncSrcNode(Ref1, Ref2)
   *              |             |
   *              |        DataSyncDstNode(Ref1, Ref2)
   * UpdateState(U, SideEffectNode, DataSyncDstNode)
   *
   * The topology relationship is shown above: after SideEffectNode is launched and could have updated Ref1 and Ref2,
   * data-sync nodes are inserted and connected to the UpdateState node so that nodes after UpdateState will not be
   * launched until the data is synchronized.
   * @param {CNodePtr} &side_effect_node: The side effect node whose reference inputs need to be synchronized.
   * @param {CNodePtr} &update_state_node: The UpdateState node which is the reference node's user.
   * @param {AnfNodePtrList} &ref_nodes: Reference nodes that need to be synchronized.
   * @return {void}
   */
  void AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
                       const AnfNodePtrList &ref_nodes);

  /**
   * @description: Create data-sync node pairs for the reference node. It may need to be synchronized to multiple
   * processes.
   * @param {CNodePtr} &side_effect_node: The node with side effect using the reference node as input.
   * @param {AnfNodePtr} &ref: The reference node.
   * @param {set<OperatorLabel>} &diff_labels: The operator label set of the processes to which the reference node's
   * data will be synchronized.
   * @return {DataSyncNodePairList}: The list of data-sync node pairs.
   */
  DataSyncNodePairList CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
                                           const std::set<OperatorLabel> &diff_labels);

  /**
   * @description: For processes without any in-degree, a control edge should be connected from the process with the
   * default label. This is to avoid these processes hanging because they have no input edges.
   * @return {void}
   */
  void AddControlEdgeForProcessWithoutIndegree();

  /**
   * @description: Create the src and dst nodes of a control edge with the specified src and dst operator labels.
   *              ControlSrc(1.0)
   *                   |
   *                   |
   *              ControlDst()
   * @param {OperatorLabel} &src_label: Control edge src label.
   * @param {OperatorLabel} &dst_label: Control edge dst label.
   * @return {ControlEdgeNodePair}: The node pair.
   */
  ControlEdgeNodePair CreateControlEdgeNode(const OperatorLabel &src_label, const OperatorLabel &dst_label);

  /**
   * @description: The data-sync nodes and control-edge nodes should be eliminated at the end of the splitting process.
   * These nodes are just for graph splitting and have no corresponding backend kernels.
   * @return {void}
   */
  void EliminateDataSyncNode();
  void EliminateControlEdgeNode();

  // Split the graph but don't eliminate the nodes, so that a global graph IR could be exported.
  void DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges);

  // Return the split label of this node. Only CNode is supported for now.
  // If the node has no split label, return the label of this process, which means this node should be in this
  // process's graph.
  OperatorLabel GetSplitLabel(const AnfNodePtr &node);

  // Consider Node-X the split node. Node-In is one of Node-X's inputs, and Node-Out takes Node-X as one input.
  // So the graph should be like this:
  // Node-In-->Node-X-->Node-Out.
  // After send and recv ops are inserted, the graph should be:
  // Node-In-->Send-->Recv-->Node-X-->Send-->Recv-->Node-Out.
  // The method GenerateInterProcessOpsForNodeInputs generates the Send-Recv pair between Node-In and Node-X.
  InterProcessOpEdgesInfo GenerateInterProcessOpsForNodeInputs(const AnfNodePtr &node);

  InterProcessEdgeLabel GenerateEdgeLabel(const AnfNodePtr &src_node, const AnfNodePtr &dst_node) const;

  // Segments will be independent of each other after the graph is cut, so the in-degrees and out-degrees of each
  // segment should be connected with control edges in case the nodes are optimized out.
  std::vector<AnfNodePtr> FindInterProcessInDegree(const std::vector<AnfNodePtr> &nodes,
                                                   const InterProcessOpEdgesInfo &comm_edges);
  std::vector<AnfNodePtr> FindInterProcessOutDegree(const std::vector<AnfNodePtr> &nodes,
                                                    const InterProcessOpEdgesInfo &comm_edges);

  // Generate the in and out degree list of the segments to add dependencies between segments.
  InOutDegreeList GenerateInOutDegreeList(const std::vector<SplitGraphSegment> &segments,
                                          const InterProcessOpEdgesInfo &comm_edges);

  // Extra dependency edges must be added for RpcSend and RpcRecv nodes in case they are optimized out or lose
  // explicit dependencies.
  void AddDependencyBetweenEdges(const InterProcessOpEdgesInfo &comm_edges);
  // For the segments on this process, dependency edges should be created so that they won't be optimized out.
  void AddDependencyBetweenSegments(const InOutDegreeList &in_out_degree_list);

  // Replace nodes' inputs with Recv nodes to eliminate extra nodes not on this process.
  void EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edges);

  // Replace nodes' inputs with Recv nodes.
  void ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);

  void AddSendRecvDependency(const InterProcessOpEdgesInfo &in_degree_comm_edges, const AnfNodePtrSet &send_src_nodes,
                             const std::map<AnfNodePtr, AnfNodePtrSet> &src_nodes_to_send_nodes,
                             const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency,
                             std::map<AnfNodePtr, bool> *is_send_node_hung);

  // Add output edges for send nodes so that they won't be optimized out.
  void AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs);

  // Judge whether two nodes have the same distributed label.
  bool IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2);

  // Check whether the distributed graph needs to be split.
  bool NeedSplitGraph() const;

  // Return whether this node has a corresponding label stored in node_labels_.
  bool NodeHasLabel(const AnfNodePtr &node);

  FuncGraphPtr func_graph_;

  // Rank id and node role of this process. They are used to dye the graph with different labels, help build the split
  // graph, etc.
  uint32_t rank_id_;
  std::string role_;

  // Created according to the execution mode. Used to build the distributed graph.
  distributed::DistExecutionMode mode_;
  std::unique_ptr<DistributedExecutionMode> exec_mode_;

  // The label of this process, which consists of its rank id and role.
  OperatorLabel this_process_label_;

  // For each mode, there is a default label. Every node in the graph should be launched on the process with this
  // label by default, unless it has a different split label.
  OperatorLabel default_label_;

  // The map of all nodes in the graph to their distributed split label.
  NodeLabels node_labels_;

  // All labels in the graph.
  std::set<OperatorLabel> all_labels_;

  // Whether rpc nodes need to be fused.
  bool need_fuse_rpc_nodes_;

  // TupleGetItem nodes visited while recursively setting their labels.
  std::map<AnfNodePtr, bool> visited_tuple_get_item_nodes_;
};
using GraphSplitterPtr = std::shared_ptr<GraphSplitter>;
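
// Illustrative usage (a hypothetical call site, e.g. inside a pipeline action): construct the splitter with this
// process's rank id and role, then run the whole dye/segment/split flow.
//   auto splitter = std::make_shared<GraphSplitter>(func_graph, rank_id, role);
//   splitter->Run();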
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GRAPH_SPLITTER_H_