1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/graph_util/graph_splitter.h"
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "include/common/debug/draw.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/common/utils/parallel_context.h"
28 #include "include/common/utils/utils.h"
29 #include "mindspore/core/utils/ms_context.h"
30 #include "ops/array_op_name.h"
31 #include "ops/framework_ops.h"
32 #include "ops/math_op_name.h"
33 #include "ops/sequence_ops.h"
34 #if defined(__linux__) && defined(WITH_BACKEND)
35 #include "include/backend/distributed/ps/ps_context.h"
36 #endif
37 
38 namespace mindspore {
39 namespace parallel {
40 bool OperatorLabel::operator<(const OperatorLabel &label) const { return to_string() < label.to_string(); }
41 
42 bool OperatorLabel::operator==(const OperatorLabel &label) const { return to_string() == label.to_string(); }
43 
44 bool OperatorLabel::operator!=(const OperatorLabel &label) const { return !(*this == label); }
45 
46 bool OperatorLabel::LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const {
47   if (kLabelMatchingFuncMap.count(mode) == 0) {
48     MS_LOG(DEBUG) << "The mode " << mode << " does not need LooseEqual.";
49     return to_string() == label.to_string();
50   }
51   return kLabelMatchingFuncMap.at(mode)(label, *this);
52 }
53 
54 std::string OperatorLabel::to_string() const { return std::to_string(rank_id) + "_" + ms_role; }
55 
56 ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node, bool use_fake_shape) {
57   tensor::TensorPtr fake_tensor = nullptr;
58   if (use_origin_node) {
59     MS_EXCEPTION_IF_NULL(origin_node);
60     abstract::AbstractTensorPtr origin_abstract;
61     if (origin_node->abstract()->isa<abstract::AbstractTuple>()) {
62       // By default, if the origin node's output is a tuple, get the abstract of the first element.
63       auto get_one_tuple_element = origin_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[0];
64       origin_abstract = get_one_tuple_element->cast<abstract::AbstractTensorPtr>();
65     } else {
66       origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
67     }
68     MS_EXCEPTION_IF_NULL(origin_abstract);
69     auto element = origin_abstract->element();
70     MS_EXCEPTION_IF_NULL(element);
71     auto build_type = element->BuildType();
72     MS_EXCEPTION_IF_NULL(build_type);
73     auto type_id = build_type->type_id();
74     if (use_fake_shape) {
75       // Assign send's output shape as {1};
76       ShapeVector fake_shape = {kSizeOne};
77       fake_tensor = std::make_shared<tensor::Tensor>(type_id, fake_shape);
78     } else {
79       auto shape = origin_abstract->shape();
80       MS_EXCEPTION_IF_NULL(shape);
81       fake_tensor = std::make_shared<tensor::Tensor>(type_id, shape->shape());
82       fake_tensor->set_base_shape(shape->Clone());
83     }
84   } else {
85     fake_tensor = std::make_shared<tensor::Tensor>(1.0);
86   }
87 
88   MS_EXCEPTION_IF_NULL(fake_tensor);
89   auto fake_value = NewValueNode(fake_tensor);
90   MS_EXCEPTION_IF_NULL(fake_value);
91   fake_value->set_abstract(fake_tensor->ToAbstract());
92   return fake_value;
93 }
94 
95 CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
96                                 size_t item_index) {
97   MS_EXCEPTION_IF_NULL(func_graph);
98   MS_EXCEPTION_IF_NULL(node_with_tuple_output);
99   const auto &tuple_abstract = node_with_tuple_output->abstract();
100   MS_EXCEPTION_IF_NULL(tuple_abstract);
101   if (!tuple_abstract->isa<abstract::AbstractTuple>()) {
102     MS_LOG(EXCEPTION) << "Only create TupleGetItem for tuple output.";
103   }
104 
105   auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
106   MS_EXCEPTION_IF_NULL(item_index_value_node);
107 
108   std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(prim::kPrimTupleGetItem), node_with_tuple_output,
109                                                    item_index_value_node};
110   CNodePtr tuple_get_item_node = func_graph->NewCNode(tuple_get_item_inputs);
111   MS_EXCEPTION_IF_NULL(tuple_get_item_node);
112   tuple_get_item_node->set_abstract(tuple_abstract->cast<abstract::AbstractTuplePtr>()->elements()[item_index]);
113   return tuple_get_item_node;
114 }
115 
116 CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs) {
117   MS_EXCEPTION_IF_NULL(func_graph);
118   AnfNodePtrList new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
119   (void)new_make_tuple_inputs.insert(new_make_tuple_inputs.cend(), tuple_inputs.cbegin(), tuple_inputs.cend());
120   auto make_tuple_node = func_graph->NewCNode(new_make_tuple_inputs);
121   MS_EXCEPTION_IF_NULL(make_tuple_node);
122 
123   // MakeTuple's abstract must consist of all inputs' abstracts to avoid unexpected graph compiling errors.
124   AbstractBasePtrList abstract_list;
125   (void)std::for_each(tuple_inputs.cbegin(), tuple_inputs.cend(),
126                       [&](const auto &input) { (void)abstract_list.emplace_back(input->abstract()); });
127   make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
128   return make_tuple_node;
129 }
130 
131 AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
132   MS_EXCEPTION_IF_NULL(func_graph);
133   MS_EXCEPTION_IF_NULL(origin_output);
134   MS_EXCEPTION_IF_NULL(origin_output->abstract());
135   if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
136     auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(origin_output, kIndex0);
137     auto real_output = kernel_with_index.first;
138     if (!IsPrimitiveCNode(real_output, prim::kPrimMakeTuple)) {
139       MS_LOG(EXCEPTION) << "Tuple output is not a MakeTuple node: " << real_output->DebugString();
140     }
141     AnfNodePtrList tuple_inputs;
142     auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
143     for (size_t i = kIndex0; i < tuple_elements.size(); i++) {
144       // If the tuple input is a Parameter or ValueNode, use it directly as the new tuple's input.
145       const auto tuple_input = real_output->cast<CNodePtr>()->input(i + kSizeOne);
146       if (tuple_input->isa<Parameter>() || tuple_input->isa<ValueNode>()) {
147         MS_LOG(INFO) << "Use " << tuple_input->DebugString() << " as replaced output.";
148         (void)tuple_inputs.emplace_back(tuple_input);
149         continue;
150       }
151 
152       const auto &element = tuple_elements[i];
153       MS_EXCEPTION_IF_NULL(element);
154       auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
155       if (!tensor_abstract) {
156         MS_LOG(EXCEPTION) << "Only tuples whose elements are all tensors are supported for replacement.";
157       }
158       auto fake_tensor = std::make_shared<tensor::Tensor>(tensor_abstract->element()->BuildType()->type_id(),
159                                                           tensor_abstract->shape()->shape());
160       MS_EXCEPTION_IF_NULL(fake_tensor);
161       auto fake_value_node = NewValueNode(fake_tensor);
162       MS_EXCEPTION_IF_NULL(fake_value_node);
163       fake_value_node->set_abstract(fake_tensor->ToAbstract());
164       (void)tuple_inputs.emplace_back(fake_value_node);
165     }
166     return CreateMakeTupleNode(func_graph, tuple_inputs);
167   } else {
168     return CreateFakeValueNode(true, origin_output);
169   }
170 }
171 
172 void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
173   const auto &send_src_node = inter_process_edge.src_node;
174   const auto &send_dst_node = inter_process_edge.dst_node;
175   MS_EXCEPTION_IF_NULL(send_src_node);
176   MS_EXCEPTION_IF_NULL(send_dst_node);
177   MS_EXCEPTION_IF_NULL(send_node);
178 
179   std::string src_node_name = send_src_node->fullname_with_scope();
180   std::string dst_node_name = send_dst_node->fullname_with_scope();
181 
182   // These attributes are the inter-process edge information.
183   std::vector<uint32_t> dst_ranks = {inter_process_edge.dst_label.rank_id};
184   common::AnfAlgo::SetNodeAttr(kAttrSendDstRanks, MakeValue(dst_ranks), send_node);
185   std::vector<std::string> dst_roles = {inter_process_edge.dst_label.ms_role};
186   common::AnfAlgo::SetNodeAttr(kAttrSendDstRoles, MakeValue(dst_roles), send_node);
187 
188   common::AnfAlgo::SetNodeAttr(kAttrSendSrcNodeName, MakeValue(src_node_name), send_node);
189   common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(dst_node_name), send_node);
190   std::vector<std::string> inter_process_edges = {inter_process_edge.to_string()};
191   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
192 
193   // Set send node to CPU for now.
194   common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), send_node);
195 }
196 
197 void SetRecvNodeAttr(const AnfNodePtr &recv_node, const InterProcessOpEdge &inter_process_edge) {
198   const auto &recv_src_node = inter_process_edge.src_node;
199   const auto &recv_dst_node = inter_process_edge.dst_node;
200   MS_EXCEPTION_IF_NULL(recv_src_node);
201   MS_EXCEPTION_IF_NULL(recv_dst_node);
202   MS_EXCEPTION_IF_NULL(recv_node);
203 
204   std::string src_node_name = recv_src_node->fullname_with_scope();
205   std::string dst_node_name = recv_dst_node->fullname_with_scope();
206 
207   // These attributes are the inter-process edge information.
208   std::vector<uint32_t> src_ranks = {inter_process_edge.src_label.rank_id};
209   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRanks, MakeValue(src_ranks), recv_node);
210   std::vector<std::string> src_roles = {inter_process_edge.src_label.ms_role};
211   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
212 
213   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(src_node_name), recv_node);
214   common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(dst_node_name), recv_node);
215   std::vector<std::string> inter_process_edges = {inter_process_edge.to_string()};
216   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
217 
218   // Set recv node to CPU for now.
219   common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), recv_node);
220 }
221 
222 CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge) {
223   const auto &src_node = inter_process_edge.src_node;
224   const auto &dst_node = inter_process_edge.dst_node;
225   MS_EXCEPTION_IF_NULL(src_node);
226   MS_EXCEPTION_IF_NULL(dst_node);
227 
228   std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
229   ValueNodePtr mock_value = nullptr;
230   if (IsPrimitiveCNode(src_node, prim::kPrimUpdateState)) {
231     mock_value = CreateFakeValueNode(false);
232     send_inputs.push_back(mock_value);
233     send_inputs.push_back(src_node);
234   } else {
235     if (src_node->abstract()->isa<abstract::AbstractTuple>()) {
236       // If src_node's output is a tuple, get the first element of the tuple as Send's input.
237       auto tuple_get_item_node = CreateTupleGetItemNode(func_graph, src_node, kIndex0);
238       send_inputs.push_back(tuple_get_item_node);
239       mock_value = CreateFakeValueNode(true, tuple_get_item_node);
240     } else {
241       send_inputs.push_back(src_node);
242       mock_value = CreateFakeValueNode(true, src_node);
243     }
244   }
245   CNodePtr send_node = func_graph->NewCNode(send_inputs);
246   MS_EXCEPTION_IF_NULL(send_node);
247   send_node->set_abstract(mock_value->abstract());
248 
249   SetSendNodeAttr(send_node, inter_process_edge);
250   return send_node;
251 }
252 
253 CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge) {
254   const auto &src_node = inter_process_edge.src_node;
255   const auto &dst_node = inter_process_edge.dst_node;
256   MS_EXCEPTION_IF_NULL(src_node);
257   MS_EXCEPTION_IF_NULL(dst_node);
258 
259   std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
260   CNodePtr recv_node = nullptr;
261   AbstractBasePtr recv_node_abs = nullptr;
262   if (IsPrimitiveCNode(src_node, prim::kPrimUpdateState)) {
263     ValuePtr monad_value = nullptr;
264     if (HasAbstractUMonad(src_node)) {
265       monad_value = kUMonad;
266     } else if (HasAbstractIOMonad(src_node)) {
267       monad_value = kIOMonad;
268     } else {
269       MS_LOG(EXCEPTION) << "The src_node, which is a PrimUpdateState node, must have a monad abstract.";
270     }
271     auto monad_input = NewValueNode(monad_value);
272     MS_EXCEPTION_IF_NULL(monad_input);
273     monad_input->set_abstract(monad_value->ToAbstract());
274     recv_inputs.push_back(monad_input);
275     recv_node_abs = src_node->abstract();
276   } else {
277     if (src_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, src_node->cast<CNodePtr>()) &&
278         common::AnfAlgo::HasNodeAttr(kAttrParameterInputIndex, src_node->cast<CNodePtr>())) {
279       int64_t parameter_index = common::AnfAlgo::GetNodeAttr<int64_t>(src_node, kAttrParameterInputIndex);
280       auto kernel_with_index = common::AnfAlgo::VisitKernel(
281         common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), kIndex0);
282       auto param_node = kernel_with_index.first;
283       recv_inputs.push_back(param_node);
284 
285       // To update the parameter on the device side in the heterogeneous case, a side-effect node should be added to
286       // recv's input.
287       ValuePtr monad_value = kUMonad;
288       auto monad_input = NewValueNode(monad_value);
289       MS_EXCEPTION_IF_NULL(monad_input);
290       monad_input->set_abstract(monad_value->ToAbstract());
291       recv_inputs.push_back(monad_input);
292 
293       recv_node_abs = param_node->abstract();
294     } else if (src_node->isa<CNode>() && common::AnfAlgo::GetCNodeName(src_node) == distributed::kDataSyncSrcOpName) {
295       auto kernel_with_index =
296         common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), kIndex0), kIndex0);
297       auto param_node = kernel_with_index.first;
298       recv_inputs.push_back(param_node);
299 
300       ValuePtr monad_value = kUMonad;
301       auto monad_input = NewValueNode(monad_value);
302       MS_EXCEPTION_IF_NULL(monad_input);
303       monad_input->set_abstract(monad_value->ToAbstract());
304       recv_inputs.push_back(monad_input);
305 
306       recv_node_abs = param_node->abstract();
307     } else {
308       // Use the same shape as origin node's.
309       auto mock_value = CreateFakeValueNode(true, src_node, false);
310       MS_EXCEPTION_IF_NULL(mock_value);
311       recv_inputs.push_back(mock_value);
312       recv_node_abs = src_node->abstract();
313     }
314   }
315   recv_node = func_graph->NewCNode(recv_inputs);
316   MS_EXCEPTION_IF_NULL(recv_node);
317   recv_node->set_abstract(recv_node_abs);
318 
319   SetRecvNodeAttr(recv_node, inter_process_edge);
320   return recv_node;
321 }
322 
323 std::map<size_t, size_t> GetRealIndexToSeg(const std::vector<size_t> &split_segment, size_t real_size) {
324   std::map<size_t, size_t> result;
325   // If split_segment is empty, return an empty map.
326   if (split_segment.empty()) {
327     return result;
328   }
329 
330   // Check whether the vector of indices is valid.
331   if (!std::is_sorted(split_segment.begin(), split_segment.end())) {
332     MS_LOG(EXCEPTION) << "Indices of segments are not in ascending order: " << split_segment;
333   }
334 
335   size_t real_index = 0;
336   for (size_t seg_index = 0; seg_index < split_segment.size(); seg_index++) {
337     size_t upper_bound = split_segment[seg_index];
338     for (; real_index < real_size; real_index++) {
339       if (real_index <= upper_bound) {
340         result[real_index] = seg_index;
341       } else {
342         break;
343       }
344     }
345   }
346 
347   // Map the rest of real index to a segment.
348   if (real_size > (*split_segment.rbegin()) + 1) {
349     for (; real_index < real_size; real_index++) {
350       result[real_index] = split_segment.size();
351     }
352   }
353   return result;
354 }
355 
356 bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input) {
357   MS_EXCEPTION_IF_NULL(func_graph);
358   MS_EXCEPTION_IF_NULL(input);
359   auto all_inputs = func_graph->get_inputs();
360   return std::count(all_inputs.begin(), all_inputs.end(), input) != 0;
361 }
362 
363 distributed::DistExecutionMode GenerateStrategy() {
364   distributed::DistExecutionMode strategy;
365   bool enable_ps = false;
366   bool enable_embedding_cache = false;
367 #if defined(__linux__) && defined(WITH_BACKEND)
368   enable_ps = ps::PSContext::instance()->is_ps_mode();
369   enable_embedding_cache = ps::PSContext::instance()->cache_enable();
370 #endif
371   std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
372   MS_LOG(INFO) << "Current parallel mode is " << parallel_mode;
373   bool using_parallel = (parallel_mode != parallel::kStandalone);
374   // The conditions' priority is: EmbeddingCache > Parameter Server > General.
375   if (enable_embedding_cache) {
376     strategy = distributed::DistExecutionMode::kEmbeddingCacheMode;
377   } else if (enable_ps) {
378     strategy = distributed::DistExecutionMode::kPSMode;
379   } else if (using_parallel) {
380     strategy = distributed::DistExecutionMode::kParallelMode;
381   } else {
382     strategy = distributed::DistExecutionMode::kGeneralMode;
383   }
384   MS_LOG(INFO) << "Generated distributed strategy is " << strategy;
385   return strategy;
386 }
387 
388 void TransformPrimAttrToAttr(const CNodePtr &cnode) {
389   MS_EXCEPTION_IF_NULL(cnode);
390   auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
391   MS_EXCEPTION_IF_NULL(prim);
392   if (cnode->HasPrimalAttr(distributed::kOpLabelRankId)) {
393     MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'rank_id'.";
394     prim->set_attr(distributed::kOpLabelRankId, cnode->GetPrimalAttr(distributed::kOpLabelRankId));
395   }
396   if (cnode->HasPrimalAttr(distributed::kOpLabelRole)) {
397     MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'ms_role'.";
398     prim->set_attr(distributed::kOpLabelRole, cnode->GetPrimalAttr(distributed::kOpLabelRole));
399   }
400 }
401 
402 bool NodeHasLabel(const AnfNodePtr &node) {
403   MS_EXCEPTION_IF_NULL(node);
404   if (!node->isa<CNode>()) {
405     return false;
406   }
407 
408   bool has_label = false;
409   CNodePtr cnode = node->cast<CNodePtr>();
410   MS_EXCEPTION_IF_NULL(cnode);
411   auto prim_node = cnode->input(0);
412   MS_EXCEPTION_IF_NULL(prim_node);
413 
414   // As long as the node has 'ms_role' and 'rank_id' attributes, we consider this node labeled regardless of the
415   // values of these two attributes.
416   if (IsValueNode<Primitive>(prim_node)) {
417     auto prim = GetValueNode<PrimitivePtr>(prim_node);
418     MS_EXCEPTION_IF_NULL(prim);
419     if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
420       has_label = true;
421     }
422   } else {
423     // Get the label for a 'call' node. A 'call' node has no primitive to store attrs, so get them from the cnode.
424     if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
425       has_label = true;
426     }
427   }
428   return has_label;
429 }
430 
431 bool GraphHasLabel(const FuncGraphPtr &func_graph) {
432   MS_EXCEPTION_IF_NULL(func_graph);
433 
434   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph->get_return());
435   // If any node has a label, this graph has a label and thus needs to be split.
436   for (const auto &node : all_nodes) {
437     MS_EXCEPTION_IF_NULL(node);
438     if (NodeHasLabel(node)) {
439       return true;
440     }
441   }
442   return false;
443 }
444 
445 CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes) {
446   CNodePtrList side_effect_nodes;
447   for (const auto &node : nodes) {
448     MS_EXCEPTION_IF_NULL(node);
449     if (!node->isa<CNode>()) {
450       continue;
451     }
452     auto cnode = node->cast<CNodePtr>();
453     MS_EXCEPTION_IF_NULL(cnode);
454     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
455     MS_EXCEPTION_IF_NULL(prim);
456     if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM)) {
457       (void)side_effect_nodes.emplace_back(cnode);
458       MS_LOG(DEBUG) << "CNode with side effect mem: " << cnode->fullname_with_scope();
459     }
460   }
461   return side_effect_nodes;
462 }
463 
464 AnfNodePtrList GetRefInputs(const CNodePtr &cnode) {
465   MS_EXCEPTION_IF_NULL(cnode);
466   AnfNodePtrList ref_inputs;
467   for (size_t i = kIndex1; i < cnode->size(); ++i) {
468     auto &input = cnode->inputs().at(i);
469     if (common::AnfAlgo::HasAbstractRef(input)) {
470       ref_inputs.push_back(input);
471       MS_LOG(DEBUG) << "The ref input " << input->fullname_with_scope() << " of node " << cnode->fullname_with_scope();
472     }
473   }
474   return ref_inputs;
475 }
476 
477 CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
478   MS_EXCEPTION_IF_NULL(func_graph);
479   MS_EXCEPTION_IF_NULL(cnode);
480   auto cnode_users = func_graph->manager()->node_users()[cnode];
481   for (const auto &user : cnode_users) {
482     auto user_node = user.first;
483     MS_EXCEPTION_IF_NULL(user_node);
484     if (common::AnfAlgo::GetCNodeName(user_node) == kUpdateStateOpName) {
485       return user_node->cast<CNodePtr>();
486     }
487   }
488   return nullptr;
489 }
490 
491 ValueNodePtr CreateUMonadNode() {
492   ValuePtr monad_value = kUMonad;
493   auto monad_input = NewValueNode(monad_value);
494   MS_EXCEPTION_IF_NULL(monad_input);
495   monad_input->set_abstract(monad_value->ToAbstract());
496   return monad_input;
497 }
498 
499 CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs) {
500   if (update_state_inputs.empty()) {
501     MS_LOG(EXCEPTION) << "The inputs of UpdateState should not be empty.";
502   }
503   // The first input of UpdateState is an 'U'.
504   ValueNodePtr umonad_input = CreateUMonadNode();
505   MS_EXCEPTION_IF_NULL(umonad_input);
506   AnfNodePtrList inputs = {NewValueNode(prim::kPrimUpdateState), umonad_input};
507   (void)inputs.insert(inputs.end(), update_state_inputs.begin(), update_state_inputs.end());
508 
509   auto update_state_node = func_graph->NewCNode(inputs);
510   MS_EXCEPTION_IF_NULL(update_state_node);
511   update_state_node->set_abstract(umonad_input->abstract());
512   return update_state_node;
513 }
514 
515 std::map<AnfNodePtr, AnfNodePtrSet> FilterDependencyToTargetNode(const FuncGraphPtr &func_graph,
516                                                                  const AnfNodePtrSet &target_nodes) {
517   std::map<AnfNodePtr, AnfNodePtrSet> depend_matrix;
518   MS_EXCEPTION_IF_NULL(func_graph);
519   auto return_node = func_graph->get_return();
520   MS_EXCEPTION_IF_NULL(return_node);
521   AnfNodePtrList nodes = FuncGraph::TopoSort(return_node);
522   // Traverse all nodes in topological order so that the time complexity is O(n).
523   for (const auto &node : nodes) {
524     MS_EXCEPTION_IF_NULL(node);
525     if (!node->isa<CNode>()) {
526       continue;
527     }
528     CNodePtr cnode = node->cast<CNodePtr>();
529     MS_EXCEPTION_IF_NULL(cnode);
530     const auto &inputs = cnode->inputs();
531     // Traverse all inputs and only filter out inputs which are in the target nodes set.
532     for (const auto &input : inputs) {
533       MS_EXCEPTION_IF_NULL(input);
534       // If the input is stored already, it depends on some of the target nodes, so we propagate its dependency
535       // set to this node.
536       if (depend_matrix.count(input) != 0) {
537         (void)depend_matrix[node].insert(depend_matrix[input].begin(), depend_matrix[input].end());
538       }
539       // If input itself is in target nodes set, insert it as well.
540       if (target_nodes.contains(input)) {
541         (void)depend_matrix[node].insert(input);
542       }
543     }
544   }
545   return depend_matrix;
546 }
547 
548 AnfNodePtrSet UpdateDependedSet(const AnfNodePtr &new_node, const AnfNodePtrSet &old_depended_set,
549                                 const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency) {
550   AnfNodePtrSet updated = old_depended_set;
551   bool is_independent = true;
552   for (const auto &stored_node : old_depended_set) {
553     // If 'new_node' is already depended on by 'stored_node', no need to add 'new_node'.
554     if (node_dependency.count(stored_node) != 0 && node_dependency.at(stored_node).contains(new_node)) {
555       MS_LOG(DEBUG) << "Old node " << stored_node->fullname_with_scope() << " depends on "
556                     << new_node->fullname_with_scope() << ". Do not update.";
557       is_independent = false;
558       break;
559     }
560     // If 'new_node' depends on 'stored_node', replace 'stored_node' with 'new_node' to keep minimal dependency.
561     if (node_dependency.count(new_node) != 0 && node_dependency.at(new_node).contains(stored_node)) {
562       MS_LOG(DEBUG) << "Replace old node " << stored_node->fullname_with_scope() << " with new node "
563                     << new_node->fullname_with_scope();
564       (void)updated.erase(stored_node);
565       (void)updated.insert(new_node);
566     }
567   }
568   if (is_independent) {
569     MS_LOG(DEBUG) << "Add new node to depended set " << new_node->fullname_with_scope();
570     (void)updated.insert(new_node);
571   }
572   return updated;
573 }
574 
575 void HandleHungNodes(const FuncGraphPtr &func_graph, const NodeLabels &node_labels, OperatorLabel process_label,
576                      const AnfNodePtrList &hung_nodes_list) {
577   MS_EXCEPTION_IF_NULL(func_graph);
578   auto make_tuple_node = CreateMakeTupleNode(func_graph, hung_nodes_list);
579   MS_EXCEPTION_IF_NULL(make_tuple_node);
580 
581   const auto &origin_output = func_graph->output();
582   MS_EXCEPTION_IF_NULL(origin_output);
583   if (node_labels.count(origin_output) == 0) {
584     MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
585                       << " should have corresponding operator label.";
586   }
587   AnfNodePtr replaced_output = nullptr;
588   if (node_labels.at(origin_output) != process_label) {
589     replaced_output = CreateReplacedOutputNode(func_graph, origin_output);
590   } else {
591     replaced_output = origin_output;
592   }
593   MS_EXCEPTION_IF_NULL(replaced_output);
594 
595   // Add dependency and replace.
596   std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output, make_tuple_node};
597   auto final_output_node = func_graph->NewCNode(depend_inputs);
598   MS_EXCEPTION_IF_NULL(final_output_node);
599   final_output_node->set_abstract(replaced_output->abstract());
600   (void)func_graph->manager()->SetEdge(func_graph->get_return(), 1, final_output_node);
601 }
602 
603 void ParameterServerMode::PreBuildDistributedGraph() {
604   MS_LOG(INFO) << "Start pre-building distributed graph in Parameter Server mode.";
605   MS_EXCEPTION_IF_NULL(node_labels_);
606   ProcessForSplitOptimizer();
607   MS_LOG(INFO) << "End pre-building distributed graph in Parameter Server mode.";
608 }
609 
610 FusedInterProcessOpPairMap ParameterServerMode::DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) {
611   MS_EXCEPTION_IF_NULL(comm_edges_ptr);
612 
613   // The edges of server optimizers should be fused with the same peers. For example, edges from Worker_0 to
614   // Server_0 will be fused by segments.
615   InterProcessOpEdgesInfo comm_edges_of_server_optimizer = FilterCommEdgesOfServerOptimizer(*comm_edges_ptr);
616   FusedInterProcessOpPairMap optimizer_fused_edges = FuseRpcNodesForSplitOptimizer(comm_edges_of_server_optimizer);
617 
618   // The rest of the edges, such as those for EmbeddingLookup, are not fused, but the FusedInterProcessOpPairMap
619   // object should still be created.
620   FusedInterProcessOpPairMap rest_edges = FilterNotServerOptimizerEdges(*comm_edges_ptr);
621   (void)optimizer_fused_edges.insert(rest_edges.cbegin(), rest_edges.cend());
622   return optimizer_fused_edges;
623 }
624 
625 void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
626   MS_LOG(INFO) << "Start post-building distributed graph in Parameter Server mode.";
627   MS_EXCEPTION_IF_NULL(node_labels_);
628   // Validate the number of nodes for each role.
629   uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
630   if (worker_num == 0) {
631     MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
632   }
633   uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
634   if (server_num == 0) {
635     MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
636   }
637   // Only the multiple-worker scenario needs this optimization.
638   if (worker_num < kMinGradAccumWorkerNum) {
639     return;
640   }
641 
642   MS_EXCEPTION_IF_NULL(func_graph_);
643   std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList();
644 
645   // Duplicate out-degrees for ps optimizers because by default there's only one edge to the rank 0 worker.
646   for (const auto &ps_optimizer : ps_optimizer_node_list) {
647     for (const auto &edge_info : comm_edges) {
648       if (edge_info.first.src_node == ps_optimizer) {
649         // The optimizer's output should always connect to a Send node, which is the input of a MakeTuple node.
650         // We need to replace the MakeTuple node with a new one.
651         const auto &origin_send_node = std::get<0>(edge_info.second);
652         std::vector<AnfNodePtr> new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), origin_send_node};
653         AnfNodePtr dst_node = edge_info.first.dst_node;
654         for (uint32_t i = 1; i < worker_num; i++) {
655           OperatorLabel worker_label = {i, distributed::kEnvRoleOfWorker};
656           InterProcessOpEdge edge = {ps_optimizer, node_labels_->at(ps_optimizer), dst_node, worker_label};
657           auto duplicated_send_node = CreateSendNode(func_graph_, edge);
658           (void)node_labels_->insert(std::make_pair(duplicated_send_node, edge.src_label));
659           (void)new_make_tuple_inputs.emplace_back(duplicated_send_node);
660         }
661         auto new_make_tuple_node = func_graph_->NewCNode(new_make_tuple_inputs);
662         new_make_tuple_node->set_abstract(new_make_tuple_inputs.back()->abstract());
663         (void)func_graph_->manager()->Replace(origin_send_node, new_make_tuple_node);
664       }
665     }
666   }
667   MS_LOG(INFO) << "End post-building distributed graph in Parameter Server mode.";
668 }
669 
670 void ParameterServerMode::PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
671   MS_LOG(INFO) << "Start post-building distributed graph in Parameter Server mode.";
672   MS_EXCEPTION_IF_NULL(node_labels_);
673   // Validate the number of nodes for each role.
674   uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
675   if (worker_num == 0) {
676     MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
677   }
678   uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
679   if (server_num == 0) {
680     MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
681   }
682   // Only the multiple-worker scenario needs this optimization.
683   if (worker_num < kMinGradAccumWorkerNum) {
684     return;
685   }
686 
687   MS_EXCEPTION_IF_NULL(func_graph_);
688   std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList();
689   if (ps_optimizer_node_list.empty()) {
690     MS_LOG(INFO) << "This process has no ps optimizer on it. No need to do post building.";
691     return;
692   }
693 
694   // Duplicate out-degrees for ps optimizers because by default there's only one edge to the rank 0 worker.
695   for (const auto &op_pair_info : fused_inter_process_op_pairs) {
696     const auto &op_pairs = op_pair_info.second;
697     CNodePtr fused_send_node = std::get<0>(op_pairs[0]);
698     // Node's inputs except the primitive value node.
699     std::vector<AnfNodePtr> fused_send_node_inputs = fused_send_node->inputs();
700     (void)fused_send_node_inputs.erase(fused_send_node_inputs.cbegin());
701 
702     // Only handle the edge whose src_node is optimizer.
703     if (std::find_if(ps_optimizer_node_list.cbegin(), ps_optimizer_node_list.cend(), [&](const auto &ps_optimizer) {
704           return ps_optimizer.get() == fused_send_node_inputs[0].get();
705         }) == ps_optimizer_node_list.cend()) {
706       continue;
707     }
708 
709     std::vector<AnfNodePtr> new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), fused_send_node};
710     for (uint32_t i = 1; i < worker_num; i++) {
711       std::vector<CNodePtr> new_send_nodes;
712       OperatorLabel worker_label = {i, distributed::kEnvRoleOfWorker};
713       for (size_t j = 0; j < op_pairs.size(); j++) {
714         const auto &src_node = fused_send_node_inputs[j];
715         const auto &dst_node = std::get<3>(op_pairs[j]);
716         InterProcessOpEdge edge = {src_node, node_labels_->at(src_node), dst_node, worker_label};
717         auto duplicated_send_node = CreateSendNode(func_graph_, edge);
718         MS_EXCEPTION_IF_NULL(duplicated_send_node);
719         (void)node_labels_->insert(std::make_pair(duplicated_send_node, edge.src_label));
720         (void)new_send_nodes.emplace_back(duplicated_send_node);
721       }
722       CNodePtr new_fused_send_node = FuseRpcSendNodes(new_send_nodes);
723       MS_EXCEPTION_IF_NULL(new_fused_send_node);
724       (void)new_make_tuple_inputs.emplace_back(new_fused_send_node);
725     }
726     auto new_make_tuple_node = func_graph_->NewCNode(new_make_tuple_inputs);
727     new_make_tuple_node->set_abstract(fused_send_node->abstract());
728     (void)func_graph_->manager()->Replace(fused_send_node, new_make_tuple_node);
729   }
730   MS_LOG(INFO) << "End post-building distributed graph in Parameter Server mode.";
731 }
732 
733 void ParameterServerMode::ProcessForSplitOptimizer() {
734   MS_EXCEPTION_IF_NULL(func_graph_);
735   std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList();
736 
737   // Validate the number of nodes for each role.
738   uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
739   if (worker_num == 0) {
740     MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
741   }
742   uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
743   if (server_num == 0) {
744     MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
745   }
746   // Only the multiple-worker scenario needs this optimization.
747   if (worker_num < kMinGradAccumWorkerNum) {
748     return;
749   }
750 
751   for (const auto &ps_optimizer : ps_optimizer_node_list) {
752     MS_EXCEPTION_IF_NULL(ps_optimizer);
753     // Load attributes for this optimizer.
754     auto gradient_index = common::AnfAlgo::HasNodeAttr(kAttrGradientInputIndex, ps_optimizer)
755                             ? LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(ps_optimizer, kAttrGradientInputIndex))
756                             : UINT64_MAX;
757     size_t indices_index = common::AnfAlgo::HasNodeAttr(kAttrIndicesInputIndex, ps_optimizer)
758                              ? LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(ps_optimizer, kAttrIndicesInputIndex))
759                              : UINT64_MAX;
760     std::string gradient_type = (common::AnfAlgo::HasNodeAttr(kAttrGradientType, ps_optimizer))
761                                   ? common::AnfAlgo::GetNodeAttr<std::string>(ps_optimizer, kAttrGradientType)
762                                   : kDenseGradient;
763     if (kGradTypeToAccumOpName.count(gradient_type) == 0) {
764       MS_LOG(EXCEPTION) << "The gradient type " << gradient_type << " is invalid.";
765     }
766 
767     const std::string &opt_device_target = GetCNodeTarget(ps_optimizer);
768     for (size_t i = 0; i < common::AnfAlgo::GetInputNum(ps_optimizer); i++) {
769       auto input = common::AnfAlgo::GetInputNode(ps_optimizer, i);
770       // If the input is not a cnode, no inter-process edge is added, so no node with multiple inputs should be
771       // created, unless it's a real graph input.
772       if (!input->isa<CNode>()) {
773         if (IsOneOfRealGraphInput(func_graph_, input)) {
774           MS_LOG(INFO) << "The input " << i << " of optimizer " << ps_optimizer->fullname_with_scope() << ": "
775                        << input->fullname_with_scope() << " is a real input from data.";
776         } else {
777           continue;
778         }
779       }
780 
781       if (i == gradient_index) {
782         // Create the node to replace origin gradient which could be a RealDiv node.
783         std::pair<CNodePtr, CNodePtr> grad_accum_nodes = CreateNodesForGradAccumulation(
784           input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, gradient_type, worker_num);
785 
786         const auto &accum_node = grad_accum_nodes.first;
787         const auto &real_div_node = grad_accum_nodes.second;
788         func_graph_->manager()->SetEdge(ps_optimizer, i + 1, real_div_node);
789         (void)node_labels_->insert(std::make_pair(accum_node, node_labels_->at(ps_optimizer)));
790         (void)node_labels_->insert(std::make_pair(real_div_node, node_labels_->at(ps_optimizer)));
791         common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), accum_node);
792         common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), real_div_node);
793         common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), accum_node);
794         common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), real_div_node);
795       } else if (i == indices_index) {
796         // Create the node to replace origin indices.
797         AnfNodePtr new_indices_input = CreateNodeWithInterProcessEdgeOnPServer(
798           kConcatOpName, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, worker_num);
799 
800         func_graph_->manager()->SetEdge(ps_optimizer, i + 1, new_indices_input);
801         (void)node_labels_->insert(std::make_pair(new_indices_input, node_labels_->at(ps_optimizer)));
802         common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), new_indices_input);
803         common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), new_indices_input);
804       } else {
805         std::pair<CNodePtr, CNodePtr> make_tuple_get_item_nodes = CreateNodesForMakeTuple(input, worker_num);
806 
807         auto &make_tuple_node = make_tuple_get_item_nodes.first;
808         auto &tuple_get_item_node = make_tuple_get_item_nodes.second;
809         func_graph_->manager()->SetEdge(ps_optimizer, i + 1, tuple_get_item_node);
810         (void)node_labels_->insert(std::make_pair(make_tuple_node, node_labels_->at(ps_optimizer)));
811         (void)node_labels_->insert(std::make_pair(tuple_get_item_node, node_labels_->at(ps_optimizer)));
812         common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), make_tuple_node);
813         common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), tuple_get_item_node);
814         common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), make_tuple_node);
815         common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), tuple_get_item_node);
816       }
817     }
818   }
819 }
820 
821 std::vector<CNodePtr> ParameterServerMode::FilterServerAwareOptimizerList() {
822   MS_EXCEPTION_IF_NULL(func_graph_);
823   auto return_node = func_graph_->get_return();
824   MS_EXCEPTION_IF_NULL(return_node);
825 
826   std::vector<CNodePtr> ps_optim_list;
827   std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
828   for (const auto &node : nodes) {
829     if (!node->isa<CNode>()) {
830       continue;
831     }
832     const auto &cnode = node->cast<CNodePtr>();
833     if (common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, cnode)) {
834       (void)ps_optim_list.emplace_back(cnode);
835       common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), cnode);
836     }
837   }
838   return ps_optim_list;
839 }
840 
841 std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForGradAccumulation(const AnfNodePtr &gradient_input,
842                                                                                   size_t gradient_input_index,
843                                                                                   const std::string &gradient_type,
844                                                                                   size_t total_gradient_number) {
845   MS_EXCEPTION_IF_NULL(gradient_input);
846 
847   if (kGradTypeToAccumOpName.count(gradient_type) == 0) {
848     MS_LOG(EXCEPTION) << "The gradient type " << gradient_type << " is invalid.";
849   }
850   const std::string &accum_node_name = kGradTypeToAccumOpName.at(gradient_type);
851   CNodePtr grad_accum_node = CreateNodeWithInterProcessEdgeOnPServer(accum_node_name, gradient_input,
852                                                                      gradient_input_index, total_gradient_number);
853   MS_EXCEPTION_IF_NULL(grad_accum_node);
854 
855   CNodePtr grad_mean_node = CreateGradMeanNode(grad_accum_node, total_gradient_number);
856   MS_EXCEPTION_IF_NULL(grad_mean_node);
857   return std::make_pair(grad_accum_node, grad_mean_node);
858 }
859 
860 CNodePtr ParameterServerMode::CreateGradMeanNode(const AnfNodePtr &gradient, size_t divisor) {
861   MS_EXCEPTION_IF_NULL(gradient);
862 
863   // Step 1: Create the value node of divisor. The divisor's value is worker number.
864   auto addn_abstract = gradient->abstract()->cast<abstract::AbstractTensorPtr>();
865   MS_EXCEPTION_IF_NULL(addn_abstract);
866   // Use the reciprocal of the divisor, so a Mul node is created.
867   auto divisor_tensor =
868     std::make_shared<tensor::Tensor>(1 / static_cast<double>(divisor), addn_abstract->element()->BuildType());
869   MS_EXCEPTION_IF_NULL(divisor_tensor);
870   auto divisor_value_node = NewValueNode(divisor_tensor);
871   MS_EXCEPTION_IF_NULL(divisor_value_node);
872   divisor_value_node->set_abstract(divisor_tensor->ToAbstract());
873 
874   // Step 2: Create Mul node.
875   std::vector<AnfNodePtr> real_div_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), gradient,
876                                              divisor_value_node};
877   CNodePtr grad_mean_node = func_graph_->NewCNode(real_div_inputs);
878   MS_EXCEPTION_IF_NULL(grad_mean_node);
879   grad_mean_node->set_abstract(gradient->abstract());
880   return grad_mean_node;
881 }
882 
883 std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForMakeTuple(const AnfNodePtr &input,
884                                                                            size_t total_inputs_number) {
885   MS_EXCEPTION_IF_NULL(input);
886   CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer(
887     kMakeTupleOpName, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number);
888   MS_EXCEPTION_IF_NULL(make_tuple_node);
889   // For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are
890   // supposed to be the same as the first one.
891   CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, make_tuple_node, kIndex0);
892   return std::make_pair(make_tuple_node, tuple_get_item_node);
893 }
894 
895 CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std::string &many_to_one_node_name,
896                                                                       const AnfNodePtr &real_input,
897                                                                       size_t index_of_real_input,
898                                                                       uint32_t total_inputs_number) {
899   if (index_of_real_input >= total_inputs_number) {
900     MS_LOG(EXCEPTION) << "The index of real input for " << many_to_one_node_name << " " << index_of_real_input
901                       << " is greater than or equal to the worker number " << total_inputs_number;
902   }
903 
904   // Step 1: Create multiple inputs of new node including extra nodes.
905   std::vector<AnfNodePtr> new_node_inputs;
906   new_node_inputs.resize(total_inputs_number);
907   std::vector<AnfNodePtr> mock_node_inputs = {NewValueNode(std::make_shared<Primitive>(
908     IsPrimitiveCNode(real_input, prim::kPrimUpdateState) ? kUpdateStateOpName : kVirtualNode))};
909   for (size_t i = 0; i < new_node_inputs.size(); i++) {
910     new_node_inputs[i] = func_graph_->NewCNode(mock_node_inputs);
911     MS_EXCEPTION_IF_NULL(new_node_inputs[i]);
912     new_node_inputs[i]->set_abstract(real_input->abstract());
913     new_node_inputs[i]->cast<CNodePtr>()->set_fullname_with_scope(real_input->fullname_with_scope());
914 
915     // Set operator label for new node's inputs.
916     OperatorLabel input_label = {SizeToUint(i), distributed::kEnvRoleOfWorker};
917     (void)node_labels_->insert(std::make_pair(new_node_inputs[i], input_label));
918   }
919   new_node_inputs[index_of_real_input] = real_input;
920 
921   // Step 2: Create the new node.
922   auto new_node_prim = NewValueNode(std::make_shared<Primitive>(many_to_one_node_name));
923   (void)new_node_inputs.insert(new_node_inputs.cbegin(), new_node_prim);
924   if (many_to_one_node_name == kConcatOpName) {
925     // Create axis input for concat.
926     auto axis_value = MakeValue(0L);
927     MS_EXCEPTION_IF_NULL(axis_value);
928     auto axis_value_node = NewValueNode(axis_value);
929     MS_EXCEPTION_IF_NULL(axis_value_node);
930     axis_value_node->set_abstract(axis_value->ToAbstract());
931     (void)new_node_inputs.insert(new_node_inputs.cend(), axis_value_node);
932   }
933 
934   auto new_node = func_graph_->NewCNode(new_node_inputs);
935   MS_EXCEPTION_IF_NULL(new_node);
936 
937   // Step 3: Set the new node's abstract and attrs.
938   if (many_to_one_node_name == kAddNOpName) {
939     common::AnfAlgo::SetNodeAttr("N", MakeValue(static_cast<int64_t>(total_inputs_number)), new_node);
940     common::AnfAlgo::SetNodeAttr("n", MakeValue(static_cast<int64_t>(total_inputs_number)), new_node);
941     new_node->set_abstract(real_input->abstract());
942   } else if (many_to_one_node_name == kConcatOpName) {
943     auto origin_abs = real_input->abstract()->cast<abstract::AbstractTensorPtr>();
944     MS_EXCEPTION_IF_NULL(origin_abs);
945 
946     auto new_abs = origin_abs->Clone()->cast<abstract::AbstractTensorPtr>();
947     ShapeVector new_shape = new_abs->shape()->shape();
948     new_shape[0] = new_shape[0] * static_cast<int64_t>(total_inputs_number);
949     new_abs->shape()->set_shape(new_shape);
950     new_node->set_abstract(new_abs);
951   } else if (many_to_one_node_name == kMakeTupleOpName) {
952     AbstractBasePtrList abstract_list;
953     auto first_input = new_node_inputs.begin();
954     std::advance(first_input, 1);
955     (void)std::for_each(first_input, new_node_inputs.end(),
956                         [&](const auto &input) { (void)abstract_list.emplace_back(input->abstract()); });
957     new_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
958   } else {
959     new_node->set_abstract(real_input->abstract());
960   }
961   return new_node;
962 }
963 
964 FusedInterProcessOpPairMap ParameterServerMode::FuseRpcNodesForSplitOptimizer(
965   const InterProcessOpEdgesInfo &comm_edges_of_server_optimizer) {
966   InterProcessOpPairMap comm_edges_with_same_peer;
967   for (const auto &comm_edge_info : comm_edges_of_server_optimizer) {
968     const InterProcessOpEdge &edge = comm_edge_info.first;
969     const InterProcessOpPair &node_pair = comm_edge_info.second;
970     (void)comm_edges_with_same_peer[{edge.src_label, edge.dst_label, 0}].emplace_back(node_pair);
971   }
972 
973   InterProcessOpPairMap comm_edges_segments;
974   for (auto comm_edge_info = comm_edges_with_same_peer.cbegin(); comm_edge_info != comm_edges_with_same_peer.cend();
975        ++comm_edge_info) {
976     InterProcessEdgeWithIndex edge_with_index = comm_edge_info->first;
977     const std::vector<InterProcessOpPair> &op_pair_list = comm_edge_info->second;
978     std::map<size_t, size_t> real_index_to_segment =
979       GetRealIndexToSeg(ps_optimizer_fusion_segments_, op_pair_list.size());
980     if (real_index_to_segment.empty()) {
981       comm_edges_segments[edge_with_index] = op_pair_list;
982       continue;
983     } else {
984       if (real_index_to_segment.size() != op_pair_list.size()) {
985         MS_LOG(EXCEPTION) << "Real index to segment index map is invalid: size not matched.";
986       }
987       for (size_t i = 0; i < op_pair_list.size(); i++) {
988         edge_with_index.index = real_index_to_segment[i];
989         (void)comm_edges_segments[edge_with_index].emplace_back(op_pair_list[i]);
990       }
991     }
992   }
993 
994   FusedInterProcessOpPairMap results;
995   for (auto rpc_nodes_fuse_info = comm_edges_segments.begin(); rpc_nodes_fuse_info != comm_edges_segments.end();
996        ++rpc_nodes_fuse_info) {
997     // Reorder the rpc node pairs list. Place pairs with monad inputs at the end of the list so that rpc send/recv
998     // nodes can be built.
999     std::vector<InterProcessOpPair> &inter_process_pairs = (*rpc_nodes_fuse_info).second;
1000     std::vector<InterProcessOpPair> monad_pairs;
1001     std::vector<InterProcessOpPair> no_monad_pairs;
1002     (void)std::for_each(inter_process_pairs.begin(), inter_process_pairs.end(), [&](const auto &op_pair) {
1003       if (HasAbstractMonad(std::get<1>(op_pair))) {
1004         (void)monad_pairs.emplace_back(op_pair);
1005       } else {
1006         (void)no_monad_pairs.emplace_back(op_pair);
1007       }
1008     });
1009     (void)no_monad_pairs.insert(no_monad_pairs.cend(), monad_pairs.cbegin(), monad_pairs.cend());
1010     inter_process_pairs = no_monad_pairs;
1011 
1012     std::vector<FusedInterProcessOpPair> fused_pairs;
1013     if (!common::GetEnv("fusion2").empty()) {
1014       fused_pairs = FuseCommEdges(inter_process_pairs);
1015     } else {
1016       std::vector<CNodePtr> rpc_send_nodes, rpc_recv_nodes;
1017       (void)std::for_each(inter_process_pairs.begin(), inter_process_pairs.end(),
1018                           [&rpc_send_nodes, &rpc_recv_nodes](const auto &node_pair) {
1019                             (void)rpc_send_nodes.emplace_back(std::get<0>(node_pair));
1020                             (void)rpc_recv_nodes.emplace_back(std::get<1>(node_pair));
1021                           });
1022       CNodePtr fused_send_node = FuseRpcSendNodes(rpc_send_nodes);
1023       CNodePtr fused_recv_node = FuseRpcRecvNodes(rpc_recv_nodes);
1024 
1025       for (size_t i = 0; i < inter_process_pairs.size(); i++) {
1026         FusedInterProcessOpPair fused_inter_process_pair =
1027           std::make_tuple(fused_send_node, fused_recv_node, i, std::get<2>(inter_process_pairs[i]),
1028                           std::get<3>(inter_process_pairs[i]));
1029         (void)fused_pairs.emplace_back(fused_inter_process_pair);
1030       }
1031     }
1032     results[rpc_nodes_fuse_info->first] = fused_pairs;
1033   }
1034   return results;
1035 }
1036 
1037 InterProcessOpEdgesInfo ParameterServerMode::FilterCommEdgesOfServerOptimizer(
1038   const InterProcessOpEdgesInfo &comm_edges) const {
1039   InterProcessOpEdgesInfo comm_edges_of_server_optimizer;
1040   for (const auto &edge_info : comm_edges) {
1041     if (edge_info.first.edge_label.label_name == kPSOptimizerEdgeLabel) {
1042       (void)comm_edges_of_server_optimizer.insert(edge_info);
1043     }
1044   }
1045   return comm_edges_of_server_optimizer;
1046 }
1047 
1048 FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
1049   const InterProcessOpEdgesInfo &comm_edges) const {
1050   FusedInterProcessOpPairMap results;
1051   for (const auto &edge_info : comm_edges) {
1052     if (edge_info.first.edge_label.label_name != kPSOptimizerEdgeLabel) {
1053       const InterProcessOpEdge &edge = edge_info.first;
1054       const InterProcessOpPair &node_pair = edge_info.second;
1055 
1056       // We use the hash value to make these indexed edges unique, so the index itself has no actual meaning.
1057       size_t edge_index = std::hash<std::string>{}(edge.to_string());
1058       InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
1059       FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
1060                                                               std::get<2>(node_pair), std::get<3>(node_pair));
1061       std::vector<FusedInterProcessOpPair> pair_list = {fused_op_pair};
1062       (void)results.insert(std::make_pair(edge_with_index, pair_list));
1063     }
1064   }
1065   return results;
1066 }
1067 
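// Fuse multiple RpcSend nodes into one: concatenate all non-monad inputs, merge the original abstracts
// into a tuple abstract, join the inter-process edge names into a single fused edge name, and copy the
// routing attributes (primitive target, dst ranks/roles, src/dst node names) from the first send node.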
1068 CNodePtr ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
1069   if (rpc_send_nodes.empty()) {
1070     MS_LOG(EXCEPTION) << "Rpc send node list is empty.";
1071   }
1072   std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
1073   AbstractBasePtrList abstract_list;
1074   std::string fused_inter_process_edge_name = "";
1075   for (const auto &send_node : rpc_send_nodes) {
1076     MS_EXCEPTION_IF_NULL(send_node);
1077     for (size_t i = 1; i < send_node->size(); i++) {
1078       auto input_i = send_node->inputs()[i];
1079       MS_EXCEPTION_IF_NULL(input_i);
1080       // If the input of the send node is a monad, do not pass it to the fused send node.
1081       if (HasAbstractMonad(input_i)) {
1082         continue;
1083       }
1084       (void)send_inputs.emplace_back(input_i);
1085     }
1086     (void)abstract_list.emplace_back(send_node->abstract());
1087     (void)fused_inter_process_edge_name.append(
1088       common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(send_node, kAttrInterProcessEdgeNames).front());
1089   }
1090 
1091   CNodePtr fused_send_node = func_graph_->NewCNode(send_inputs);
1092   MS_EXCEPTION_IF_NULL(fused_send_node);
1093   fused_send_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
1094   std::vector<std::string> fused_inter_process_edge_names = {fused_inter_process_edge_name};
1095   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(fused_inter_process_edge_names), fused_send_node);
1096   common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_send_nodes[0], fused_send_node);
1097   common::AnfAlgo::CopyNodeAttr(kAttrSendDstRanks, rpc_send_nodes[0], fused_send_node);
1098   common::AnfAlgo::CopyNodeAttr(kAttrSendDstRoles, rpc_send_nodes[0], fused_send_node);
1099   common::AnfAlgo::CopyNodeAttr(kAttrSendSrcNodeName, rpc_send_nodes[0], fused_send_node);
1100   common::AnfAlgo::CopyNodeAttr(kAttrSendDstNodeName, rpc_send_nodes[0], fused_send_node);
1101   return fused_send_node;
1102 }
1103 
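// Fuse multiple RpcRecv nodes into one, mirroring FuseRpcSendNodes: non-monad inputs and abstracts are
// merged, the inter-process edge names are concatenated, and an extra UMonad input is appended so the
// fused recv node can update references.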
1104 CNodePtr ParameterServerMode::FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes) {
1105   std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
1106   AbstractBasePtrList abstract_list;
1107   std::string fused_inter_process_edge_name = "";
1108   for (const auto &recv_node : rpc_recv_nodes) {
1109     MS_EXCEPTION_IF_NULL(recv_node);
1110     for (size_t i = 1; i < recv_node->size(); i++) {
1111       auto input_i = recv_node->inputs()[i];
1112       MS_EXCEPTION_IF_NULL(input_i);
1113       // If the input of the recv node is a monad, do not pass it to the fused recv node.
1114       if (HasAbstractMonad(input_i)) {
1115         continue;
1116       }
1117       (void)recv_inputs.emplace_back(input_i);
1118     }
1119     (void)abstract_list.emplace_back(recv_node->abstract());
1120     (void)fused_inter_process_edge_name.append(
1121       common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(recv_node, kAttrInterProcessEdgeNames).front());
1122   }
1123   // Add a UMonad input to the recv node so it can update references.
1124   ValuePtr monad_value = kUMonad;
1125   auto monad_input = NewValueNode(monad_value);
1126   MS_EXCEPTION_IF_NULL(monad_input);
1127   monad_input->set_abstract(monad_value->ToAbstract());
1128   recv_inputs.push_back(monad_input);
1129 
1130   CNodePtr fused_recv_node = func_graph_->NewCNode(recv_inputs);
1131   MS_EXCEPTION_IF_NULL(fused_recv_node);
1132   fused_recv_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
1133   std::vector<std::string> fused_inter_process_edge_names = {fused_inter_process_edge_name};
1134   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(fused_inter_process_edge_names), fused_recv_node);
1135   common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_recv_nodes[0], fused_recv_node);
1136   common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRanks, rpc_recv_nodes[0], fused_recv_node);
1137   common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRoles, rpc_recv_nodes[0], fused_recv_node);
1138   common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcNodeName, rpc_recv_nodes[0], fused_recv_node);
1139   common::AnfAlgo::CopyNodeAttr(kAttrRecvDstNodeName, rpc_recv_nodes[0], fused_recv_node);
1140   return fused_recv_node;
1141 }
1142 
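// Fuse communication edges by deduplicating send/recv pairs that share the same send-side input node
// (compared by fullname_with_scope). The unique pairs are fused into one RpcSend/RpcRecv pair, and
// indices_map records which output index of the fused pair each original pair maps to.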
1143 std::vector<FusedInterProcessOpPair> ParameterServerMode::FuseCommEdges(
1144   const std::vector<InterProcessOpPair> &inter_process_pairs) {
1145   std::vector<FusedInterProcessOpPair> fused_op_pairs;
1146   std::vector<CNodePtr> rpc_send_nodes, rpc_recv_nodes;
1147   std::map<size_t, size_t> indices_map;
1148   for (size_t i = 0; i < inter_process_pairs.size(); i++) {
1149     auto &op_pair = inter_process_pairs[i];
1150     auto reused_send_node =
1151       std::find_if(rpc_send_nodes.begin(), rpc_send_nodes.end(), [&op_pair](const auto &send_node_need_fuse) {
1152         CNodePtr send_node = std::get<0>(op_pair);
1153         auto node_name1 = common::AnfAlgo::GetInputNode(send_node, kIndex0)->fullname_with_scope();
1154         auto node_name2 = common::AnfAlgo::GetInputNode(send_node_need_fuse, kIndex0)->fullname_with_scope();
1155         return node_name1 == node_name2;
1156       });
1157     if (reused_send_node != rpc_send_nodes.end()) {
1158       size_t index = static_cast<size_t>(std::distance(rpc_send_nodes.begin(), reused_send_node));
1159       indices_map[i] = index;
1160     } else {
1161       (void)rpc_send_nodes.emplace_back(std::get<0>(op_pair));
1162       (void)rpc_recv_nodes.emplace_back(std::get<1>(op_pair));
1163       indices_map[i] = rpc_send_nodes.size() - 1;
1164     }
1165   }
1166 
1167   CNodePtr fused_send_node = FuseRpcSendNodes(rpc_send_nodes);
1168   CNodePtr fused_recv_node = FuseRpcRecvNodes(rpc_recv_nodes);
1169   for (size_t i = 0; i < inter_process_pairs.size(); i++) {
1170     FusedInterProcessOpPair fused_inter_process_pair =
1171       std::make_tuple(fused_send_node, fused_recv_node, indices_map[i], std::get<2>(inter_process_pairs[i]),
1172                       std::get<3>(inter_process_pairs[i]));
1173     (void)fused_op_pairs.emplace_back(fused_inter_process_pair);
1174   }
1175   return fused_op_pairs;
1176 }
1177 
1178 GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, const std::string &role)
1179     : func_graph_(func_graph),
1180       rank_id_(rank_id),
1181       role_(role),
1182       exec_mode_(nullptr),
1183       this_process_label_({rank_id, role}),
1184       node_labels_{},
1185       need_fuse_rpc_nodes_(true) {
1186   // The distributed strategy is not explicitly defined by the user. The distributed module generates the distributed
1187   // strategy and the default label according to flags set by other modules.
1188   mode_ = GenerateStrategy();
1189   default_label_ = {0, distributed::kEnvRoleOfWorker};
1190 }
1191 
1192 void EmbeddingCacheMode::PreBuildDistributedGraph() {
1193   // Only the embedding cache ops of the remote cache need to be added.
1194   if (role_ != distributed::kEnvRoleOfPServer) {
1195     return;
1196   }
1197 
1198   // 1. Add the embedding cache ops of the remote cache and build the service-side graph.
1199   AddEmbeddingCacheOps();
1200 
1201   // 2. Get node labels.
1202   MS_EXCEPTION_IF_NULL(node_labels_);
1203   node_labels_->clear();
1204 
1205   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1206   (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
1207     MS_EXCEPTION_IF_NULL(node);
1208     if (node->isa<CNode>()) {
1209       CNodePtr cnode = node->cast<CNodePtr>();
1210       MS_EXCEPTION_IF_NULL(cnode);
1211       OperatorLabel label = GetNodeLabel(cnode);
1212       (void)node_labels_->emplace(node, label);
1213     }
1214   });
1215 }
1216 
1217 void EmbeddingCacheMode::AddEmbeddingCacheOps() const {
1218   uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
1219   if (worker_num == 0) {
1220     MS_LOG(EXCEPTION) << "In embedding cache mode, worker number should be greater than 0.";
1221   }
1222 
1223   // Build service-side graph.
1224   std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
1225     std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph_, static_cast<int64_t>(rank_id_), role_,
1226                                                          worker_num);
1227   if (!embedding_cache_inserter->Run()) {
1228     MS_LOG(EXCEPTION) << "Insert ps embedding cache failed.";
1229   }
1230 }
1231 
1232 OperatorLabel EmbeddingCacheMode::GetNodeLabel(const AnfNodePtr &node) const {
1233   MS_EXCEPTION_IF_NULL(node);
1234   if (!node->isa<CNode>()) {
1235     MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
1236   }
1237 
1238   CNodePtr cnode = node->cast<CNodePtr>();
1239   auto prim_node = cnode->input(0);
1240   if (IsValueNode<Primitive>(prim_node)) {
1241     auto prim = GetValueNode<PrimitivePtr>(prim_node);
1242     MS_EXCEPTION_IF_NULL(prim);
1243     if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
1244       MS_LOG(INFO) << "CNode which has distributed split label: " << cnode->fullname_with_scope();
1245       uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId)));
1246       std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
1247       return {rank_id, ms_role};
1248     }
1249   } else {
1250     // Get the label for a call node. A 'call' node has no primitive to save attrs, so get the attrs of 'call' from the cnode.
1251     if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
1252       uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
1253       std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
1254       return {rank_id, ms_role};
1255     }
1256   }
1257   return {rank_id_, role_};
1258 }
1259 
1260 GraphSplitter::~GraphSplitter() { node_labels_.clear(); }
1261 
1262 void GraphSplitter::Run() {
1263   MS_EXCEPTION_IF_NULL(func_graph_);
1264   MS_EXCEPTION_IF_NULL(func_graph_->manager());
1265 
1266   // Step 1: Dye all the nodes of the whole func_graph_.
1267   DyeGraph();
1268   // If all nodes are on this process, there is no need to split the graph, so return.
1269   if (!NeedSplitGraph()) {
1270     MS_LOG(INFO) << "All nodes are on this process so there's no need to build and split distributed graph.";
1271     return;
1272   }
1273 
1274   // Step 2: Create exec_mode_ according to the current execution mode.
1275   CreateExecutionMode();
1276 
1277   // If this is general mode but no label is set, do not split the graph, to avoid nodes being unexpectedly optimized out.
1278   if (mode_ == distributed::DistExecutionMode::kGeneralMode && !GraphHasLabel(func_graph_)) {
1279     MS_LOG(INFO) << "This graph has no label on it in general mode. So no need to split.";
1280     return;
1281   }
1282 
1283   // Step 3: Prebuild the distributed graph before it gets split.
1284   exec_mode_->PreBuildDistributedGraph();
1285 
1286   if (!NeedSplitGraph()) {
1287     MS_LOG(INFO) << "All nodes are on this process so there's no need to build and split distributed graph.";
1288     return;
1289   }
1290 
1291   // For TupleGetItem nodes, their label should be reset for good splitting performance.
1292   ReassignTupleGetItemNodeLabel();
1293 
1294   if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
1295     // Only use ref sync mechanism when in general mode.
1296     ProcessRefNodes();
1297     // Add some control edges between different labels.
1298     AddExtraControlEdgeAcrossProcess();
1299   }
1300 
1301   // Step 4: Create inter-process operators for segments with different labels.
1302   InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
1303 
1304   need_fuse_rpc_nodes_ = !common::GetEnv(kEnvNeedFusion).empty();
1305   if (need_fuse_rpc_nodes_) {
1306     // Step 5: Fuse the rpc nodes to improve performance.
1307     const FusedInterProcessOpPairMap &fused_inter_process_op_pairs = exec_mode_->DoRpcNodeFusion(&comm_edges);
1308 
1309     // Step 6: Add dependency and eliminate extra nodes for fused rpc nodes.
1310     SplitGraph(fused_inter_process_op_pairs);
1311 
1312     // Step 7: Postbuild the graph after splitting with fused edges.
1313     exec_mode_->PostBuildDistributedGraph(fused_inter_process_op_pairs);
1314   } else {
1315     // Step 5: Generate the node segments with different labels.
1316     std::vector<SplitGraphSegment> segments = GenerateSplitSegments();
1317     // If the segment number is 0, there will be no distributed execution.
1318     if (segments.empty()) {
1319       return;
1320     }
1321 
1322     // Step 6: Split the graph and eliminate extra nodes.
1323     SplitGraph(segments, comm_edges);
1324 
1325     // Step 7: Postbuild the graph after splitting.
1326     exec_mode_->PostBuildDistributedGraph(comm_edges);
1327   }
1328   // Only eliminate the data-sync node pairs in general mode.
1329   if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
1330     EliminateDataSyncNode();
1331     EliminateControlEdgeNode();
1332   }
1333 }
1334 
1335 void GraphSplitter::DyeGraph() {
1336   MS_EXCEPTION_IF_NULL(func_graph_);
1337   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1338   (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
1339     MS_EXCEPTION_IF_NULL(node);
1340     // Mark all nodes with original label at the beginning. This means the node is supposed to be on the process with
1341     // default_label_.
1342     node_labels_[node] = default_label_;
1343     if (node->isa<CNode>()) {
1344       // For CNodes, mark them with the label passed by the frontend if there is one.
1345       CNodePtr cnode = node->cast<CNodePtr>();
1346       MS_EXCEPTION_IF_NULL(cnode);
1347       OperatorLabel label = GetSplitLabel(cnode);
1348       node_labels_[node] = label;
1349     }
1350 
1351     // If the node's label is the same as this process's, set its label to this_process_label_.
1352     if (this_process_label_.LooseEqual(node_labels_[node], mode_)) {
1353       node_labels_[node] = this_process_label_;
1354     }
1355   });
1356 }
1357 
1358 void GraphSplitter::CreateExecutionMode() {
1359   MS_EXCEPTION_IF_NULL(func_graph_);
1360   if (node_labels_.empty()) {
1361     MS_LOG(EXCEPTION) << "Must dye the original graph before creating execution mode.";
1362   }
1363   if (mode_ == distributed::DistExecutionMode::kPSMode) {
1364     exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
1365   } else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
1366     exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
1367   } else if (mode_ == distributed::DistExecutionMode::kParallelMode) {
1368     exec_mode_ = std::make_unique<ParallelMode>(func_graph_, &node_labels_, rank_id_, role_);
1369   } else if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
1370     exec_mode_ = std::make_unique<GeneralMode>(func_graph_, &node_labels_, rank_id_, role_);
1371   }
1372   MS_EXCEPTION_IF_NULL(exec_mode_);
1373 }
1374 
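// Walk the graph in topological order and group consecutive CNodes carrying the same operator label into
// one SplitGraphSegment; a new segment is started whenever the label changes.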
1375 std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
1376   MS_EXCEPTION_IF_NULL(func_graph_);
1377   auto return_node = func_graph_->get_return();
1378   MS_EXCEPTION_IF_NULL(return_node);
1379   std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
1380 
1381   std::vector<SplitGraphSegment> results;
1382   SplitGraphSegment segment;
1383   OperatorLabel last_label = this_process_label_;
1384   segment.label = last_label;
1385   for (auto &n : nodes) {
1386     if (!n->isa<CNode>()) {
1387       continue;
1388     }
1389     auto cnode_split_label = node_labels_[n];
1390     // If this node's label is not the same as the last node's, create a segment from the accumulated 'segment.nodes'.
1391     if (cnode_split_label != last_label && !segment.nodes.empty()) {
1392       (void)results.emplace_back(segment);
1393       segment.nodes.clear();
1394     }
1395     // Mark the last label.
1396     last_label = cnode_split_label;
1397     segment.label = cnode_split_label;
1398     (void)segment.nodes.emplace_back(n);
1399   }
1400 
1401   // Add the last segment.
1402   (void)results.emplace_back(segment);
1403   MS_LOG(INFO) << "Segments number with different distributed split labels is " << results.size();
1404   return results;
1405 }
1406 
1407 void GraphSplitter::ReassignTupleGetItemNodeLabel() {
1408   MS_EXCEPTION_IF_NULL(func_graph_);
1409   AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1410   for (const auto &node : all_nodes) {
1411     if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1412       node_labels_[node] = RecursiveSetTupeGetItemLabel(node->cast<CNodePtr>());
1413     }
1414   }
1415 }
1416 
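// Trace a chain of TupleGetItem nodes up to the first non-TupleGetItem input and propagate that input's
// label down the chain. Visited nodes are memoized to avoid repeated traversal.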
1417 OperatorLabel GraphSplitter::RecursiveSetTupeGetItemLabel(const CNodePtr &tuple_get_item_node) {
1418   // Return if this node has already been visited.
1419   if (visited_tuple_get_item_nodes_.count(tuple_get_item_node) != 0) {
1420     if (NodeHasLabel(tuple_get_item_node)) {
1421       return node_labels_[tuple_get_item_node];
1422     } else {
1423       MS_LOG(EXCEPTION) << "TupleGetItem node " << tuple_get_item_node->fullname_with_scope() << " has no label.";
1424     }
1425   }
1426 
1427   visited_tuple_get_item_nodes_[tuple_get_item_node] = true;
1428   auto tuple_input = common::AnfAlgo::GetInputNode(tuple_get_item_node, kIndex0);
1429   OperatorLabel tuple_get_item_label;
1430   if (IsPrimitiveCNode(tuple_input, prim::kPrimTupleGetItem)) {
1431     // If TupleGetItem's input is a TupleGetItem node, recursively trace up and get a proper input's label.
1432     tuple_get_item_label = RecursiveSetTupeGetItemLabel(tuple_input->cast<CNodePtr>());
1433     node_labels_[tuple_input] = tuple_get_item_label;
1434   } else {
1435     // Set TupleGetItem's label the same as its input so it's easier to split.
1436     tuple_get_item_label = node_labels_[tuple_input];
1437   }
1438   return tuple_get_item_label;
1439 }
1440 
1441 void GraphSplitter::ProcessRefNodes() {
1442   MS_EXCEPTION_IF_NULL(func_graph_);
1443   AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1444   // Traverse all nodes and find each node with side effects.
1445   CNodePtrList cnodes_with_side_effect = GetSideEffectNodeList(all_nodes);
1446   for (const auto &cnode : cnodes_with_side_effect) {
1447     // Filter out all ref inputs which need to be synchronized between different processes.
1448     AnfNodePtrList ref_inputs = GetRefInputs(cnode);
1449     // Get the user node (UpdateState) of the side effect node.
1450     CNodePtr update_state_node = FindNextUpdateStateNode(func_graph_, cnode);
1451     MS_EXCEPTION_IF_NULL(update_state_node);
1452 
1453     // This is the key step for keeping reference nodes correct across computing graph nodes.
1454     AddDataSyncNode(cnode, update_state_node, ref_inputs);
1455   }
1456 }
1457 
1458 void GraphSplitter::AddExtraControlEdgeAcrossProcess() { AddControlEdgeForProcessWithoutIndegree(); }
1459 
1460 InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOperators() {
1461   InterProcessOpEdgesInfo comm_edges;
1462   MS_EXCEPTION_IF_NULL(func_graph_);
1463   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1464   for (auto &node : all_nodes) {
1465     MS_EXCEPTION_IF_NULL(node);
1466     // Only CNodes are supported to be split to other processes.
1467     if (!node->isa<CNode>()) {
1468       continue;
1469     }
1470 
1471     // Generating send/recv nodes for each node's inputs will be enough.
1472     auto node_inputs_comm_edges = GenerateInterProcessOpsForNodeInputs(node);
1473     (void)comm_edges.insert(node_inputs_comm_edges.cbegin(), node_inputs_comm_edges.cend());
1474   }
1475   MS_LOG(INFO) << "The communication edge number is " << comm_edges.size();
1476   return comm_edges;
1477 }
1478 
1479 void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
1480                                const InterProcessOpEdgesInfo &comm_edges) {
1481   // Step 1: Traverse all the segments to add Depend for this process's graph.
1482   // The list of corresponding in and out degrees. In other words, the map between one segment's input send nodes and
1483   // output recv nodes.
1484   InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
1485   if (in_out_degree_list.empty()) {
1486     MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
1487     auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
1488     (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
1489     return;
1490   }
1491 
1492   // Step 2: Add dependency between communication edges on this process.
1493   AddDependencyBetweenEdges(comm_edges);
1494 
1495   // Step 3: Eliminate nodes not on this process.
1496   EliminateExtraNodes(comm_edges);
1497 }
1498 
1499 void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
1500   if (fused_inter_process_op_pairs.empty()) {
1501     MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
1502     auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
1503     (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
1504     return;
1505   }
1506 
1507   // Step 1: Replace origin nodes with recv nodes.
1508   ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);
1509 
1510   // Step 2: Connect output for send nodes.
1511   AddDependencyForSend(fused_inter_process_op_pairs);
1512 }
1513 
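// For each ref input of a side-effect node, check whether the ref is also used by nodes with a different
// label (possibly behind a Load). If so, create DataSyncSrc/DataSyncDst pairs for those labels and attach
// the dst nodes to the given UpdateState node so they are not optimized out.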
1514 void GraphSplitter::AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
1515                                     const AnfNodePtrList &ref_nodes) {
1516   MS_EXCEPTION_IF_NULL(func_graph_);
1517   MS_EXCEPTION_IF_NULL(side_effect_node);
1518   MS_EXCEPTION_IF_NULL(update_state_node);
1519 
1520   MS_EXCEPTION_IF_CHECK_FAIL(
1521     (node_labels_.count(side_effect_node) != 0),
1522     "The node label for side effect node " + side_effect_node->fullname_with_scope() + " is not set.");
1523   auto side_effect_node_label = node_labels_[side_effect_node];
1524 
1525   for (const auto &ref : ref_nodes) {
1526     std::set<OperatorLabel> diff_labels;
1527     for (const auto &user : func_graph_->manager()->node_users()[ref]) {
1528       const auto &user_node = user.first;
1529       MS_LOG(DEBUG) << "The user of ref " << ref->fullname_with_scope() << " is " << user_node->fullname_with_scope()
1530                     << ", side-effect node label: " << side_effect_node_label.to_string()
1531                     << ", user node label: " << node_labels_[user_node].to_string();
1532       if (node_labels_[user_node] != side_effect_node_label) {
1533         (void)diff_labels.insert(node_labels_[user_node]);
1534       } else {
1535         // If the user node is Load, we need to find one next user of it so the node could be correctly split.
1536         if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
1537           for (const auto &load_user : func_graph_->manager()->node_users()[user_node]) {
1538             const auto &load_user_node = load_user.first;
1539             MS_LOG(DEBUG) << "Load user is " << load_user_node
1540                           << ", label: " << node_labels_[load_user_node].to_string();
1541             if (node_labels_[load_user_node] != side_effect_node_label) {
1542               (void)diff_labels.insert(node_labels_[load_user_node]);
1543             }
1544           }
1545         }
1546       }
1547     }
1548     // If the ref is used in multiple compute graph nodes, it needs to be synchronized.
1549     if (diff_labels.empty()) {
1550       MS_LOG(INFO) << "No need to synchronize ref node " << ref->fullname_with_scope()
1551                    << " because the user nodes are on the same process.";
1552       continue;
1553     }
1554 
1555     // Create data-sync nodes and connect them to the UpdateState node.
1556     auto data_sync_node_list = CreateDataSyncNodes(side_effect_node, ref, diff_labels);
1557     for (const auto &node_pair : data_sync_node_list) {
1558       CNodePtr src_node = node_pair.first;
1559       CNodePtr dst_node = node_pair.second;
1560       func_graph_->manager()->AddEdge(update_state_node, dst_node);
1561     }
1562   }
1563 }
1564 
1565 DataSyncNodePairList GraphSplitter::CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
1566                                                         const std::set<OperatorLabel> &diff_labels) {
1567   MS_EXCEPTION_IF_NULL(side_effect_node);
1568   MS_EXCEPTION_IF_NULL(ref);
1569 
1570   DataSyncNodePairList result;
1571   for (const auto &label : diff_labels) {
1572     // Data sync src node.
1573     std::vector<AnfNodePtr> sync_src_node_inputs = {
1574       NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncSrcOpName))};
1575     (void)sync_src_node_inputs.emplace_back(ref);
1576     (void)sync_src_node_inputs.emplace_back(side_effect_node);
1577     CNodePtr sync_src_node = func_graph_->NewCNode(sync_src_node_inputs);
1578     MS_EXCEPTION_IF_NULL(sync_src_node);
1579     sync_src_node->set_abstract(ref->abstract());
1580     node_labels_[sync_src_node] = node_labels_[side_effect_node];
1581 
1582     // Data sync dst node.
1583     std::vector<AnfNodePtr> sync_dst_node_inputs = {
1584       NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncDstOpName))};
1585     (void)sync_dst_node_inputs.emplace_back(sync_src_node);
1586     CNodePtr sync_dst_node = func_graph_->NewCNode(sync_dst_node_inputs);
1587     MS_EXCEPTION_IF_NULL(sync_dst_node);
1588     auto fake_value = CreateFakeValueNode(false);
1589     MS_EXCEPTION_IF_NULL(fake_value);
1590     sync_dst_node->set_abstract(fake_value->abstract());
1591     node_labels_[sync_dst_node] = label;
1592 
1593     MS_LOG(INFO) << "Data sync pair: " << sync_src_node->fullname_with_scope() << "_"
1594                  << node_labels_[sync_src_node].to_string() << "->" << sync_dst_node->fullname_with_scope() << "_"
1595                  << label.to_string();
1596     result.push_back(std::make_pair(sync_src_node, sync_dst_node));
1597   }
1598   return result;
1599 }
1600 
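// Collect all labels in the graph and find the labels that receive no cross-label input edge. For each of
// them, create a ControlSrc->ControlDst pair, then tie the dangling ControlDst nodes to the graph output
// through MakeTuple and Depend so these extra control edges survive optimization.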
1601 void GraphSplitter::AddControlEdgeForProcessWithoutIndegree() {
1602   (void)std::for_each(node_labels_.begin(), node_labels_.end(),
1603                       [this](const auto &node_label_pair) { (void)all_labels_.insert(node_label_pair.second); });
1604 
1605   std::set<OperatorLabel> labels_has_indegree;
1606   AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1607   for (const auto &node : all_nodes) {
1608     if (!node->isa<CNode>()) {
1609       continue;
1610     }
1611     auto cnode = node->cast<CNodePtr>();
1612     for (size_t i = kIndex1; i < cnode->size(); ++i) {
1613       const auto &input = cnode->inputs().at(i);
1614       if (NodeHasLabel(input) && NodeHasLabel(cnode) && node_labels_[input] != node_labels_[cnode] &&
1615           input->isa<CNode>()) {
1616         MS_LOG(DEBUG) << "Label " << node_labels_[cnode].to_string() << " has indegree from label "
1617                       << node_labels_[input].to_string() << ", edge: " << input->fullname_with_scope() << " to "
1618                       << cnode->fullname_with_scope();
1619         (void)labels_has_indegree.insert(node_labels_[cnode]);
1620       }
1621     }
1622   }
1623 
1624   ControlEdgeNodePairList control_edge_node_pair_list;
1625   for (const OperatorLabel &label : all_labels_) {
1626     // If this label has no indegree, add extra control edge nodes.
1627     if (labels_has_indegree.count(label) == 0) {
1628       ControlEdgeNodePair control_edge_nodes = CreateControlEdgeNode(default_label_, label);
1629       (void)control_edge_node_pair_list.emplace_back(control_edge_nodes);
1630     }
1631   }
1632 
1633   if (!control_edge_node_pair_list.empty()) {
1634     // Connect the dangling control dst nodes to the output.
1635     AnfNodePtrList make_tuple_inputs;
1636     (void)std::for_each(control_edge_node_pair_list.begin(), control_edge_node_pair_list.end(),
1637                         [&make_tuple_inputs](const auto &node_pair) {
1638                           CNodePtr control_dst_node = node_pair.second;
1639                           (void)make_tuple_inputs.emplace_back(control_dst_node);
1640                         });
1641 
1642     // Make tuple for all control-edge dst nodes.
1643     MS_EXCEPTION_IF_NULL(func_graph_);
1644     auto tuple_of_control_dst_nodes = CreateMakeTupleNode(func_graph_, make_tuple_inputs);
1645     MS_EXCEPTION_IF_NULL(tuple_of_control_dst_nodes);
1646     node_labels_[tuple_of_control_dst_nodes] = default_label_;
1647 
1648     // Add dependency to the Return node so control-edge nodes won't be optimized out.
1649     AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), func_graph_->output(), tuple_of_control_dst_nodes};
1650     auto final_output_node = func_graph_->NewCNode(depend_inputs);
1651     MS_EXCEPTION_IF_NULL(final_output_node);
1652     node_labels_[final_output_node] = default_label_;
1653 
1654     final_output_node->set_abstract(func_graph_->output()->abstract());
1655     (void)func_graph_->manager()->SetEdge(func_graph_->get_return(), kIndex1, final_output_node);
1656   }
1657   return;
1658 }
1659 
1660 ControlEdgeNodePair GraphSplitter::CreateControlEdgeNode(const OperatorLabel &src_label,
1661                                                          const OperatorLabel &dst_label) {
1662   // The control src node's input is a value node. It has no practical meaning.
1663   auto fake_tensor = std::make_shared<tensor::Tensor>(1.0);
1664   MS_EXCEPTION_IF_NULL(fake_tensor);
1665   auto fake_value = NewValueNode(fake_tensor);
1666   MS_EXCEPTION_IF_NULL(fake_value);
1667   fake_value->set_abstract(fake_tensor->ToAbstract());
1668 
1669   AnfNodePtrList control_src_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kControlSrcOpName)),
1670                                        fake_value};
1671   CNodePtr control_src_node = func_graph_->NewCNode(control_src_inputs);
1672   MS_EXCEPTION_IF_NULL(control_src_node);
1673   control_src_node->set_abstract(fake_value->abstract());
1674   node_labels_[control_src_node] = src_label;
1675 
1676   // Control dst node's input is control src node.
1677   AnfNodePtrList control_dst_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kControlDstOpName)),
1678                                        control_src_node};
1679   CNodePtr control_dst_node = func_graph_->NewCNode(control_dst_inputs);
1680   MS_EXCEPTION_IF_NULL(control_dst_node);
1681   control_dst_node->set_abstract(control_src_node->abstract());
1682   node_labels_[control_dst_node] = dst_label;
1683 
1684   // At this phase, the control_dst_node is still a dangling node. We need to connect it to the output so it is not
1685   // optimized out.
1686   return std::make_pair(control_src_node, control_dst_node);
1687 }
1688 
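// Remove the temporary data-sync operators after splitting: each DataSyncSrc is replaced by a Load on its
// parameter guarded by a fresh UpdateState, and each DataSyncDst is bypassed by rewiring its users
// directly to its input.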
1689 void GraphSplitter::EliminateDataSyncNode() {
1690   MS_EXCEPTION_IF_NULL(func_graph_);
1691   AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1692   for (const auto &node : all_nodes) {
1693     MS_EXCEPTION_IF_NULL(node);
1694     if (!node->isa<CNode>()) {
1695       continue;
1696     }
1697 
1698     auto cnode = node->cast<CNodePtr>();
1699     MS_EXCEPTION_IF_NULL(cnode);
1700     if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncSrcOpName) {
1701       if (cnode->size() != kSizeThree) {
1702         MS_LOG(EXCEPTION) << "Node DataSyncSrc's input number should be 3, but got " << cnode->size();
1703       }
1704       // The first input is the parameter and the second input is the side effect node.
1705       auto param_node = cnode->inputs()[kIndex1];
1706       MS_EXCEPTION_IF_NULL(param_node);
1707       auto side_effect_node = cnode->inputs()[kIndex2];
1708       MS_EXCEPTION_IF_NULL(side_effect_node);
1709       MS_LOG(DEBUG) << "Parameter node is " << param_node->fullname_with_scope() << ", side effect node is "
1710                     << side_effect_node->fullname_with_scope();
1711 
1712       AnfNodePtrList update_state_inputs = {side_effect_node};
1713       CNodePtr update_state_node = CreateUpdateStateNode(func_graph_, update_state_inputs);
1714       MS_EXCEPTION_IF_NULL(update_state_node);
1715 
1716       // For the parameter, connect it to a 'Load' node so that the control arrow can be correctly linked.
1717       AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), param_node, update_state_node};
1718 
1719       auto load_node_replace_data_sync_src = func_graph_->NewCNode(load_inputs);
1720       MS_EXCEPTION_IF_NULL(load_node_replace_data_sync_src);
1721       load_node_replace_data_sync_src->set_abstract(cnode->abstract());
1722       (void)func_graph_->manager()->Replace(cnode, load_node_replace_data_sync_src);
1723     } else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncDstOpName) {
1724       if (cnode->size() != kSizeTwo) {
1725         MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->size();
1726       }
1727       auto input_node = cnode->inputs()[kIndex1];
1728       MS_EXCEPTION_IF_NULL(input_node);
1729 
1730       auto users = func_graph_->manager()->node_users()[cnode];
1731       for (const auto &user_pair : users) {
1732         auto user_node = user_pair.first;
1733         int input_index = user_pair.second;
1734         func_graph_->manager()->SetEdge(user_node, input_index, input_node);
1735       }
1736     }
1737   }
1738 }
1739 
1740 void GraphSplitter::EliminateControlEdgeNode() {
1741   MS_EXCEPTION_IF_NULL(func_graph_);
1742   AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1743   for (const auto &node : all_nodes) {
1744     MS_EXCEPTION_IF_NULL(node);
1745     if (!node->isa<CNode>()) {
1746       continue;
1747     }
1748 
1749     auto cnode = node->cast<CNodePtr>();
1750     MS_EXCEPTION_IF_NULL(cnode);
1751     if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlSrcOpName) {
1752       // ControlSrc->RpcSend is converted to FakeValue->RpcSend.
1753       auto fake_value_node = CreateFakeValueNode(false);
1754       MS_EXCEPTION_IF_NULL(fake_value_node);
1755       (void)func_graph_->manager()->Replace(cnode, fake_value_node);
1756     } else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlDstOpName) {
1757       if (cnode->size() != kSizeTwo) {
1758         MS_LOG(EXCEPTION) << "Node ControlDst's input number should be 2, but got " << cnode->size();
1759       }
1760       auto input_node = cnode->inputs()[kIndex1];
1761       MS_EXCEPTION_IF_NULL(input_node);
1762       (void)func_graph_->manager()->Replace(cnode, input_node);
1763     }
1764   }
1765 }
1766 
1767 void GraphSplitter::DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
1768   // Connect the send and recv nodes of every communication edge so the whole graph can be drawn without eliminating nodes.
1769   for (const auto &edge : comm_edges) {
1770     auto send_recv_pair = edge.second;
1771     auto send_node = std::get<0>(send_recv_pair);
1772     auto recv_node = std::get<1>(send_recv_pair);
1773     auto user_node = std::get<2>(send_recv_pair);
1774     auto user_node_index = std::get<3>(send_recv_pair);
1775     func_graph_->manager()->SetEdge(recv_node, 1, send_node);
1776     func_graph_->manager()->SetEdge(user_node, user_node_index, recv_node);
1777   }
1778   MS_LOG(INFO) << "Cut graph without eliminating nodes.";
1779   draw::Draw("single_node_graph.dot", func_graph_);
1780 }
1781 
1782 OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
1783   MS_EXCEPTION_IF_NULL(node);
1784   if (!node->isa<CNode>()) {
1785     MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
1786   }
1787   CNodePtr cnode = node->cast<CNodePtr>();
1788   MS_EXCEPTION_IF_NULL(cnode);
1789   auto prim_node = cnode->input(0);
1790   if (IsValueNode<Primitive>(prim_node)) {
1791     TransformPrimAttrToAttr(cnode);
1792     auto prim = GetValueNode<PrimitivePtr>(prim_node);
1793     MS_EXCEPTION_IF_NULL(prim);
1794     if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
1795       MS_LOG(INFO) << "CNode which has distributed split label: " << cnode->fullname_with_scope();
1796       uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId)));
1797       std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
1798       return {rank_id, ms_role};
1799     }
1800   } else {
1801     // Get the label for a call node. A 'call' node has no primitive to save attrs, so get the attrs of 'call' from the cnode.
1802     if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
1803       uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
1804       std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
1805       return {rank_id, ms_role};
1806     }
1807   }
1808   return default_label_;
1809 }
1810 
1811 InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(const AnfNodePtr &node) {
1812   MS_EXCEPTION_IF_NULL(func_graph_);
1813   MS_EXCEPTION_IF_NULL(node);
1814   CNodePtr cnode = node->cast<CNodePtr>();
1815   MS_EXCEPTION_IF_NULL(cnode);
1816   InterProcessOpEdgesInfo comm_edges;
1817   for (size_t i = 1; i < cnode->size(); i++) {
1818     auto input_i = cnode->inputs()[i];
1819     MS_EXCEPTION_IF_NULL(input_i);
1820 
1821     // If the input is not a cnode, or its label is the same as this node's, or the input is a 'Load' node for a
1822     // parameter, there's no need to add communication nodes.
1823     if (!input_i->isa<CNode>() || IsNodesWithSameLabel(input_i, cnode) ||
1824         common::AnfAlgo::GetCNodeName(input_i) == "Load") {
1825       if (IsOneOfRealGraphInput(func_graph_, input_i) && !IsNodesWithSameLabel(input_i, cnode)) {
1826         MS_LOG(INFO) << "The input " << input_i->fullname_with_scope() << " needs to be split.";
1827       } else {
1828         continue;
1829       }
1830     }
1831 
1832     InterProcessEdgeLabel edge_label = GenerateEdgeLabel(input_i, cnode);
1833     InterProcessOpEdge edge = {input_i, node_labels_[input_i], cnode, node_labels_[cnode], edge_label};
1834 
1835     auto send_node = CreateSendNode(func_graph_, edge);
1836     MS_EXCEPTION_IF_NULL(send_node);
1837     // The label should be the same as the node which will 'launch' the Send node.
1838     node_labels_[send_node] = edge.src_label;
1839 
1840     auto recv_node = CreateRecvNode(func_graph_, edge);
1841     MS_EXCEPTION_IF_NULL(recv_node);
1842     // The label should be the same as the node which receives the 'input'.
1843     node_labels_[recv_node] = edge.dst_label;
1844 
1845     auto comm_node_pair = std::make_tuple(send_node, recv_node, cnode, SizeToInt(i));
1846     (void)comm_edges.insert(std::make_pair(edge, comm_node_pair));
1847   }
1848   return comm_edges;
1849 }
1850 
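// The edge label is taken from kAttrInterProcessEdgeLabel on the src and dst CNodes if present. When both
// ends carry a label, they must match; otherwise the non-empty one is used, or the label is left empty.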
1851 InterProcessEdgeLabel GraphSplitter::GenerateEdgeLabel(const AnfNodePtr &src_node, const AnfNodePtr &dst_node) const {
1852   MS_EXCEPTION_IF_NULL(src_node);
1853   MS_EXCEPTION_IF_NULL(dst_node);
1854   std::string src_node_edge_label = "";
1855   std::string dst_node_edge_label = "";
1856   if (src_node->isa<CNode>()) {
1857     src_node_edge_label = common::AnfAlgo::HasNodeAttr(kAttrInterProcessEdgeLabel, src_node->cast<CNodePtr>())
1858                             ? common::AnfAlgo::GetNodeAttr<std::string>(src_node, kAttrInterProcessEdgeLabel)
1859                             : "";
1860   }
1861   if (dst_node->isa<CNode>()) {
1862     dst_node_edge_label = common::AnfAlgo::HasNodeAttr(kAttrInterProcessEdgeLabel, dst_node->cast<CNodePtr>())
1863                             ? common::AnfAlgo::GetNodeAttr<std::string>(dst_node, kAttrInterProcessEdgeLabel)
1864                             : "";
1865   }
1866   if (!src_node_edge_label.empty() && !dst_node_edge_label.empty()) {
1867     if (src_node_edge_label != dst_node_edge_label) {
1868       MS_LOG(EXCEPTION) << "The edge label names of the src node and dst node should be the same: "
1869                         << src_node->fullname_with_scope() << "->" << dst_node->fullname_with_scope();
1870     }
1871   }
1872   InterProcessEdgeLabel edge_label;
1873   if (!src_node_edge_label.empty()) {
1874     edge_label.label_name = src_node_edge_label;
1875   } else if (!dst_node_edge_label.empty()) {
1876     edge_label.label_name = dst_node_edge_label;
1877   } else {
1878     MS_LOG(DEBUG) << "Edge label is empty for " << src_node->fullname_with_scope() << "->"
1879                   << dst_node->fullname_with_scope();
1880   }
1881   return edge_label;
1882 }
1883 
1884 std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vector<AnfNodePtr> &nodes,
1885                                                                 const InterProcessOpEdgesInfo &comm_edges) {
1886   std::vector<AnfNodePtr> results;
1887   for (auto &n : nodes) {
1888     if (!n->isa<CNode>()) {
1889       continue;
1890     }
1891 
1892     CNodePtr cnode = n->cast<CNodePtr>();
1893     for (size_t i = 1; i < cnode->size(); i++) {
1894       auto input_i = cnode->inputs()[i];
1895       InterProcessOpEdge edge = {input_i, node_labels_[input_i], cnode, node_labels_[cnode]};
1896       if (comm_edges.count(edge) != 0 && edge.src_label == this_process_label_) {
1897         MS_LOG(INFO) << edge.to_string() << " is a communication edge.";
1898         auto comm_node_pair = comm_edges.at(edge);
1899         (void)results.emplace_back(std::get<0>(comm_node_pair));
1900       }
1901     }
1902   }
1903   return results;
1904 }
1905 
1906 std::vector<AnfNodePtr> GraphSplitter::FindInterProcessOutDegree(const std::vector<AnfNodePtr> &nodes,
1907                                                                  const InterProcessOpEdgesInfo &comm_edges) {
1908   std::vector<AnfNodePtr> results;
1909   for (auto &n : nodes) {
1910     if (!n->isa<CNode>()) {
1911       continue;
1912     }
1913 
1914     CNodePtr cnode = n->cast<CNodePtr>();
1915     auto users = func_graph_->manager()->node_users()[cnode];
1916     for (auto &u : users) {
1917       auto user_node = u.first->cast<CNodePtr>();
1918       InterProcessOpEdge edge = {cnode, node_labels_[cnode], user_node, node_labels_[user_node]};
1919       if (comm_edges.count(edge) != 0 && edge.dst_label == this_process_label_) {
1920         MS_LOG(INFO) << edge.to_string() << " is a communication edge.";
1921         auto comm_node_pair = comm_edges.at(edge);
1922         (void)results.emplace_back(std::get<1>(comm_node_pair));
1923       }
1924     }
1925   }
1926   return results;
1927 }
1928 
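// For every segment that is not on this process, collect the RpcSend nodes on this process that feed the
// segment (its in-degrees) and the RpcRecv nodes on this process that consume its outputs (its
// out-degrees). These pairs are later used to add Depend edges that keep the execution order consistent.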
1929 InOutDegreeList GraphSplitter::GenerateInOutDegreeList(const std::vector<SplitGraphSegment> &segments,
1930                                                        const InterProcessOpEdgesInfo &comm_edges) {
1931   MS_LOG(INFO) << "Start finding inter-process in-degrees.";
1932 
1933   InOutDegreeList in_out_degree_list;
1934   // Traverse all the segments to add Depend for this process's graph.
1935   for (const auto &segment : segments) {
1936     // If this segment is on the current process, skip it.
1937     if (segment.label == this_process_label_) {
1938       continue;
1939     }
1940     std::vector<AnfNodePtr> nodes = segment.nodes;
1941     if (nodes.empty()) {
1942       MS_LOG(EXCEPTION) << "This segment is empty.";
1943       return in_out_degree_list;
1944     }
1945 
1946     auto segment_first_node = nodes[0];
1947     if (node_labels_[segment_first_node] != segment.label) {
1948       MS_LOG(EXCEPTION) << "Node label " << node_labels_[segment_first_node].to_string()
1949                         << " is not the same as segment label " << segment.label.to_string();
1950     }
1951 
1952     // Prepare for adding Depend between in-degree and out-degree of this segment because the execution order should
1953     // be kept consistent.
1954     std::vector<AnfNodePtr> concerned_in_degree_nodes = FindInterProcessInDegree(nodes, comm_edges);
1955     std::vector<AnfNodePtr> concerned_out_degree_nodes = FindInterProcessOutDegree(nodes, comm_edges);
1956     if (!concerned_in_degree_nodes.empty() || !concerned_out_degree_nodes.empty()) {
1957       (void)in_out_degree_list.emplace_back(std::make_pair(concerned_in_degree_nodes, concerned_out_degree_nodes));
1958     }
1959   }
1960   MS_LOG(INFO) << "End finding inter-process in-degrees.";
1961   return in_out_degree_list;
1962 }
1963 
1964 void GraphSplitter::AddDependencyBetweenEdges(const InterProcessOpEdgesInfo &comm_edges) {
1965   // 'in_degree_comm_edges' is the edges with recv node on this process.
1966   InterProcessOpEdgesInfo in_degree_comm_edges;
1967   // 'out_degree_comm_edges' is the edges with send node on this process.
1968   InterProcessOpEdgesInfo out_degree_comm_edges;
1969 
1970   // Src nodes of RpcSend nodes.
1971   AnfNodePtrSet send_src_nodes;
1972   // Map of each src node to all of its RpcSend nodes.
1973   std::map<AnfNodePtr, AnfNodePtrSet> src_nodes_to_send_nodes;
1974   // This map represents which send nodes are hung. The key is the RpcSend node.
1975   std::map<AnfNodePtr, bool> is_send_node_hung;
1976   for (const auto &e : comm_edges) {
1977     const InterProcessOpEdge &edge_info = e.first;
1978     const InterProcessOpPair &op_pair = e.second;
1979 
1980     if (edge_info.src_label == this_process_label_) {
1981       const AnfNodePtr &send_src_node = edge_info.src_node;
1982       const AnfNodePtr &rpc_send_node = std::get<0>(op_pair);
1983       (void)send_src_nodes.insert(send_src_node);
1984       (void)src_nodes_to_send_nodes[send_src_node].insert(rpc_send_node);
1985       is_send_node_hung[rpc_send_node] = true;
1986 
1987       MS_LOG(DEBUG) << "Out degree edge: " << edge_info.to_string() << ". Send src node "
1988                     << send_src_node->fullname_with_scope() << " has RpcSend node "
1989                     << rpc_send_node->fullname_with_scope();
1990       out_degree_comm_edges[edge_info] = op_pair;
1991     }
1992 
1993     if (edge_info.dst_label == this_process_label_) {
1994       MS_LOG(DEBUG) << "In degree edge: " << edge_info.to_string();
1995       in_degree_comm_edges[edge_info] = op_pair;
1996     }
1997   }
1998 
1999   // This step is vital. It builds a map consisting of all dependencies to the send src nodes, which helps to
2000   // add explicit dependency edges for RpcSend and RpcRecv nodes.
2001   std::map<AnfNodePtr, AnfNodePtrSet> node_dependency = FilterDependencyToTargetNode(func_graph_, send_src_nodes);
2002   MS_LOG(INFO) << "After filtering out the dependencies, add dependency edges between RpcSend and RpcRecv nodes.";
2003 
2004   // Connect RpcSend and RpcRecv with minimal dependencies.
2005   AddSendRecvDependency(in_degree_comm_edges, send_src_nodes, src_nodes_to_send_nodes, node_dependency,
2006                         &is_send_node_hung);
2007 
2008   // Some RpcSend nodes may be hung; we need to connect these nodes to the output so they are not optimized out.
2009   AnfNodePtrList hung_nodes_list;
2010   for (const auto &is_hung : is_send_node_hung) {
2011     if (is_hung.second) {
2012       MS_LOG(INFO) << "RpcSend node: " << is_hung.first->fullname_with_scope() << " is hung.";
2013       (void)hung_nodes_list.emplace_back(is_hung.first);
2014     }
2015   }
2016   if (!hung_nodes_list.empty()) {
2017     HandleHungNodes(func_graph_, node_labels_, this_process_label_, hung_nodes_list);
2018   }
2019 }
2020 
2021 void GraphSplitter::AddDependencyBetweenSegments(const InOutDegreeList &in_out_degree_list) {
2022   MS_LOG(INFO) << "Start adding dependency between segments.";
2023   // This tuple is key to the dependency of send nodes so that they will not be optimized out in some cases.
2024   std::vector<AnfNodePtr> send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2025   for (size_t i = 0; i < in_out_degree_list.size(); i++) {
2026     auto &concerned_in_degree_nodes = in_out_degree_list[i].first;
2027     auto &concerned_out_degree_nodes = in_out_degree_list[i].second;
2028     (void)send_node_tuple_inputs.insert(send_node_tuple_inputs.cend(), concerned_in_degree_nodes.cbegin(),
2029                                         concerned_in_degree_nodes.cend());
2030     if (concerned_out_degree_nodes.empty()) {
2031       // If this is the last segment's in and out degrees and it has no out degrees, connect the send nodes to the
2032       // graph's output.
2033       if (i == in_out_degree_list.size() - 1) {
2034         auto make_tuple_node = func_graph_->NewCNode(send_node_tuple_inputs);
2035         AbstractBasePtrList abstract_list;
2036         (void)std::for_each(send_node_tuple_inputs.cbegin() + 1, send_node_tuple_inputs.cend(),
2037                             [&](const auto &input) { (void)abstract_list.emplace_back(input->abstract()); });
2038         make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
2039 
2040         // Connect fused send nodes to the output so they will not be optimized out.
2041         AnfNodePtr origin_output = func_graph_->output();
2042         if (node_labels_.count(origin_output) == 0) {
2043           MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
2044                             << " should have corresponding operator label.";
2045         }
2046 
2047         // If the output is not on this process, replace it with a fake output node.
2048         AnfNodePtr replaced_output = nullptr;
2049         if (node_labels_[origin_output] != this_process_label_) {
2050           replaced_output = CreateReplacedOutputNode(func_graph_, origin_output);
2051         } else {
2052           replaced_output = origin_output;
2053         }
2054 
2055         // Add dependency and replace.
2056         std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output, make_tuple_node};
2057         auto final_output_node = func_graph_->NewCNode(depend_inputs);
2058         MS_EXCEPTION_IF_NULL(final_output_node);
2059         final_output_node->set_abstract(replaced_output->abstract());
2060         (void)func_graph_->manager()->SetEdge(func_graph_->get_return(), 1, final_output_node);
2061       }
2062     } else {
2063       auto make_tuple_node = func_graph_->NewCNode(send_node_tuple_inputs);
2064       for (auto &recv : concerned_out_degree_nodes) {
2065         std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), recv->cast<CNodePtr>()->inputs()[1],
2066                                                 make_tuple_node};
2067         auto depend = func_graph_->NewCNode(depend_input);
2068         depend->set_abstract(recv->cast<CNodePtr>()->inputs()[1]->abstract());
2069         func_graph_->manager()->SetEdge(recv, 1, depend);
2070       }
2071       // Reset the make tuple node inputs for the next segment's in degrees.
2072       send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2073     }
2074   }
2075   MS_LOG(INFO) << "End adding dependency between segments.";
2076 }
2077 
2078 void GraphSplitter::EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edges) {
2079   MS_LOG(INFO) << "Start eliminating nodes not on this process.";
2080   // Eliminate nodes which should be launched by other processes by setting the output edge.
2081   for (auto &edge : comm_edges) {
2082     InterProcessOpPair send_recv_pair = edge.second;
2083     auto send_node = std::get<0>(send_recv_pair);
2084     auto recv_node = std::get<1>(send_recv_pair);
2085     auto user_node = std::get<2>(send_recv_pair);
2086     int user_node_index = std::get<3>(send_recv_pair);
2087 
2088     OperatorLabel send_label = node_labels_[send_node];
2089     OperatorLabel recv_label = node_labels_[recv_node];
2090     if (send_label == recv_label) {
2091       MS_LOG(EXCEPTION) << "The Send and Recv must have different label. But got Send: " << send_label.to_string()
2092                         << ", Recv: " << recv_label.to_string();
2093     }
2094 
2095     if (recv_label == this_process_label_) {
2096       func_graph_->manager()->SetEdge(user_node, user_node_index, recv_node);
2097     }
2098   }
2099   MS_LOG(INFO) << "End eliminating nodes not on this process.";
2100 }
2101 
2102 void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
2103   MS_EXCEPTION_IF_NULL(func_graph_);
2104   for (const auto &op_pair_info : fused_inter_process_op_pairs) {
2105     const OperatorLabel &send_label = op_pair_info.first.src_label;
2106     const OperatorLabel &recv_label = op_pair_info.first.dst_label;
2107     const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
2108     if (op_pairs.empty()) {
2109       MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
2110                         << recv_label.to_string();
2111     }
2112 
2113     const auto &fused_recv_node = std::get<1>(*op_pairs.begin());
2114     MS_EXCEPTION_IF_NULL(fused_recv_node);
2115 
2116     // Replace origin input with recv node.
2117     if (recv_label == this_process_label_) {
2118       for (const auto &send_recv_pair : op_pairs) {
2119         const auto &user_node = std::get<3>(send_recv_pair);
2120         int user_node_index = std::get<4>(send_recv_pair);
2121 
2122         const auto &recv_abs = fused_recv_node->abstract();
2123         MS_EXCEPTION_IF_NULL(recv_abs);
2124         // The outputs of a Recv node could be a tuple or a single tensor because it could be fused.
2125         if (recv_abs->isa<abstract::AbstractTuple>()) {
2126           int output_index = std::get<2>(send_recv_pair);
2127           CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
2128           func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
2129         } else {
2130           func_graph_->manager()->SetEdge(user_node, user_node_index, fused_recv_node);
2131         }
2132       }
2133     }
2134   }
2135 }
2136 
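// AddSendRecvDependency: adds control dependencies so that RpcSend nodes are launched before the RpcRecv nodes that
// rely on them. For each in-degree edge, the send src nodes that the Recv's destination depends on are collected,
// their RpcSend nodes are gathered into a MakeTuple, and a Depend node is inserted on the RpcRecv's data input.
// Every RpcSend covered this way is marked as not hung in 'is_send_node_hung'.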
void GraphSplitter::AddSendRecvDependency(const InterProcessOpEdgesInfo &in_degree_comm_edges,
                                          const AnfNodePtrSet &send_src_nodes,
                                          const std::map<AnfNodePtr, AnfNodePtrSet> &src_nodes_to_send_nodes,
                                          const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency,
                                          std::map<AnfNodePtr, bool> *is_send_node_hung) {
  MS_EXCEPTION_IF_NULL(is_send_node_hung);
  for (const auto &in_edge : in_degree_comm_edges) {
    const auto &rpc_recv_node = std::get<1>(in_edge.second);
    const auto &recv_dst_node = std::get<2>(in_edge.second);
    MS_LOG(DEBUG) << "Add dependency for RpcRecv node " << rpc_recv_node->fullname_with_scope()
                  << " with recv dst node " << recv_dst_node->fullname_with_scope();
    AnfNodePtrSet depended_nodes;
    for (const auto &send_src_node : send_src_nodes) {
      // Get the minimal set of send src nodes which the RpcRecv node depends on.
      if (node_dependency.count(recv_dst_node) != 0 && node_dependency.at(recv_dst_node).contains(send_src_node)) {
        depended_nodes = UpdateDependedSet(send_src_node, depended_nodes, node_dependency);
      }
    }
    MS_LOG(DEBUG) << "RpcRecv dst node " << recv_dst_node->fullname_with_scope()
                  << " depends on RpcSend src node size: " << depended_nodes.size();

    // Generate the RpcSend node list used to add a dependency on the RpcRecv node.
    AnfNodePtrList rpc_send_list;
    for (const auto &send_src_node : depended_nodes) {
      if (src_nodes_to_send_nodes.count(send_src_node) == 0) {
        MS_LOG(EXCEPTION) << "Send src node " << send_src_node->fullname_with_scope()
                          << " has no corresponding RpcSend nodes.";
      }
      const AnfNodePtrSet &rpc_send_nodes = src_nodes_to_send_nodes.at(send_src_node);
      for (const auto &rpc_send : rpc_send_nodes) {
        (*is_send_node_hung)[rpc_send] = false;
        (void)rpc_send_list.emplace_back(rpc_send);
      }
    }
    if (!rpc_send_list.empty()) {
      AnfNodePtr send_node_make_tuple = CreateMakeTupleNode(func_graph_, rpc_send_list);
      MS_EXCEPTION_IF_NULL(send_node_make_tuple);
      MS_LOG(DEBUG) << "Connect " << send_node_make_tuple->fullname_with_scope() << " with RpcRecv node "
                    << rpc_recv_node->fullname_with_scope();

      auto recv_data = rpc_recv_node->cast<CNodePtr>()->inputs()[kIndex1];
      MS_EXCEPTION_IF_NULL(recv_data);

      AnfNodePtrList depend_node_inputs = {NewValueNode(prim::kPrimDepend), recv_data, send_node_make_tuple};
      auto depend_node = func_graph_->NewCNode(depend_node_inputs);
      MS_EXCEPTION_IF_NULL(depend_node);
      depend_node->set_abstract(recv_data->abstract());
      func_graph_->manager()->SetEdge(rpc_recv_node, kIndex1, depend_node);
    }
  }
}

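// AddDependencyForSend: keeps the fused Send nodes launched by this process from being optimized out by gathering
// them into a MakeTuple node and attaching that tuple to the graph output through a Depend node. If the origin
// output is not on this process, it is first replaced with a fake output node.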
void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
  // Connect all Send nodes to a MakeTuple node.
  std::vector<AnfNodePtr> fused_send_node_tuple_inputs;
  MS_EXCEPTION_IF_NULL(func_graph_);
  for (const auto &op_pair_info : fused_inter_process_op_pairs) {
    const OperatorLabel &send_label = op_pair_info.first.src_label;
    const OperatorLabel &recv_label = op_pair_info.first.dst_label;
    const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
    if (op_pairs.empty()) {
      MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
                        << recv_label.to_string();
    }
    const auto &fused_send_node = std::get<0>(*op_pairs.begin());
    MS_EXCEPTION_IF_NULL(fused_send_node);
    // Gather all fused Send nodes launched by this process into the MakeTuple inputs.
    if (send_label == this_process_label_) {
      (void)fused_send_node_tuple_inputs.emplace_back(fused_send_node);
    }
  }
  CNodePtr fused_send_make_tuple_node = CreateMakeTupleNode(func_graph_, fused_send_node_tuple_inputs);
  MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);

  // Connect the fused Send nodes to the output so they will not be optimized out.
  AnfNodePtr origin_output = func_graph_->output();
  if (node_labels_.count(origin_output) == 0) {
    MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
                      << " should have a corresponding operator label.";
  }

  // If the output is not on this process, replace it with a fake output node.
  AnfNodePtr replaced_output = nullptr;
  if (node_labels_[origin_output] != this_process_label_) {
    replaced_output = CreateReplacedOutputNode(func_graph_, origin_output);
  } else {
    replaced_output = origin_output;
  }

  // Add the dependency and replace the graph output.
  std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output,
                                           fused_send_make_tuple_node};
  auto final_output_node = func_graph_->NewCNode(depend_inputs);
  MS_EXCEPTION_IF_NULL(final_output_node);
  final_output_node->set_abstract(replaced_output->abstract());
  (void)func_graph_->manager()->SetEdge(func_graph_->get_return(), 1, final_output_node);
}

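// IsNodesWithSameLabel: returns true if both nodes carry the same split label; throws if either node has no label
// recorded in node_labels_.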
bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  if (node_labels_.count(node1) == 0 || node_labels_.count(node2) == 0) {
    MS_LOG(EXCEPTION) << "Either 'node1': " << node1->fullname_with_scope()
                      << " or 'node2': " << node2->fullname_with_scope() << " is not marked with a split label.";
  }
  return node_labels_[node1] == node_labels_[node2];
}

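// NeedSplitGraph: the graph needs splitting as soon as any node is labeled for a process other than this one.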
bool GraphSplitter::NeedSplitGraph() const {
  return std::find_if(node_labels_.begin(), node_labels_.end(), [&](const auto &node_to_label) {
           return node_to_label.second != this_process_label_;
         }) != node_labels_.end();
}

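// NodeHasLabel: returns whether the node has been assigned a split label.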
bool GraphSplitter::NodeHasLabel(const AnfNodePtr &node) { return node_labels_.count(node) != 0; }
}  // namespace parallel
}  // namespace mindspore