1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/graph_util/graph_splitter.h"
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "include/common/debug/draw.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/common/utils/parallel_context.h"
28 #include "include/common/utils/utils.h"
29 #include "mindspore/core/utils/ms_context.h"
30 #include "ops/array_op_name.h"
31 #include "ops/framework_ops.h"
32 #include "ops/math_op_name.h"
33 #include "ops/sequence_ops.h"
34 #if defined(__linux__) && defined(WITH_BACKEND)
35 #include "include/backend/distributed/ps/ps_context.h"
36 #endif
37
38 namespace mindspore {
39 namespace parallel {
operator <(const OperatorLabel & label) const40 bool OperatorLabel::operator<(const OperatorLabel &label) const { return to_string() < label.to_string(); }
41
operator ==(const OperatorLabel & label) const42 bool OperatorLabel::operator==(const OperatorLabel &label) const { return to_string() == label.to_string(); }
43
operator !=(const OperatorLabel & label) const44 bool OperatorLabel::operator!=(const OperatorLabel &label) const { return !(*this == label); }
45
LooseEqual(const OperatorLabel & label,distributed::DistExecutionMode mode) const46 bool OperatorLabel::LooseEqual(const OperatorLabel &label, distributed::DistExecutionMode mode) const {
47 if (kLabelMatchingFuncMap.count(mode) == 0) {
48 MS_LOG(DEBUG) << "The mode " << mode << " does not need LooseEqual.";
49 return to_string() == label.to_string();
50 }
51 return kLabelMatchingFuncMap.at(mode)(label, *this);
52 }
53
to_string() const54 std::string OperatorLabel::to_string() const { return std::to_string(rank_id) + "_" + ms_role; }
55
CreateFakeValueNode(bool use_origin_node,const AnfNodePtr & origin_node,bool use_fake_shape)56 ValueNodePtr CreateFakeValueNode(bool use_origin_node, const AnfNodePtr &origin_node, bool use_fake_shape) {
57 tensor::TensorPtr fake_tensor = nullptr;
58 if (use_origin_node) {
59 MS_EXCEPTION_IF_NULL(origin_node);
60 abstract::AbstractTensorPtr origin_abstract;
61 if (origin_node->abstract()->isa<abstract::AbstractTuple>()) {
62 // Defaultly, if the origin node's output is a tuple, get the abstract of the first element.
63 auto get_one_tuple_element = origin_node->abstract()->cast<abstract::AbstractTuplePtr>()->elements()[0];
64 origin_abstract = get_one_tuple_element->cast<abstract::AbstractTensorPtr>();
65 } else {
66 origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
67 }
68 MS_EXCEPTION_IF_NULL(origin_abstract);
69 auto element = origin_abstract->element();
70 MS_EXCEPTION_IF_NULL(element);
71 auto build_type = element->BuildType();
72 MS_EXCEPTION_IF_NULL(build_type);
73 auto type_id = build_type->type_id();
74 if (use_fake_shape) {
75 // Assign send's output shape as {1};
76 ShapeVector fake_shape = {kSizeOne};
77 fake_tensor = std::make_shared<tensor::Tensor>(type_id, fake_shape);
78 } else {
79 auto shape = origin_abstract->shape();
80 MS_EXCEPTION_IF_NULL(shape);
81 fake_tensor = std::make_shared<tensor::Tensor>(type_id, shape->shape());
82 fake_tensor->set_base_shape(shape->Clone());
83 }
84 } else {
85 fake_tensor = std::make_shared<tensor::Tensor>(1.0);
86 }
87
88 MS_EXCEPTION_IF_NULL(fake_tensor);
89 auto fake_value = NewValueNode(fake_tensor);
90 MS_EXCEPTION_IF_NULL(fake_value);
91 fake_value->set_abstract(fake_tensor->ToAbstract());
92 return fake_value;
93 }
94
CreateTupleGetItemNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node_with_tuple_output,size_t item_index)95 CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node_with_tuple_output,
96 size_t item_index) {
97 MS_EXCEPTION_IF_NULL(func_graph);
98 MS_EXCEPTION_IF_NULL(node_with_tuple_output);
99 const auto &tuple_abstract = node_with_tuple_output->abstract();
100 MS_EXCEPTION_IF_NULL(tuple_abstract);
101 if (!tuple_abstract->isa<abstract::AbstractTuple>()) {
102 MS_LOG(EXCEPTION) << "Only create TupleGetItem for tuple output.";
103 }
104
105 auto item_index_value_node = NewValueNode(MakeValue(UlongToLong(item_index)));
106 MS_EXCEPTION_IF_NULL(item_index_value_node);
107
108 std::vector<AnfNodePtr> tuple_get_item_inputs = {NewValueNode(prim::kPrimTupleGetItem), node_with_tuple_output,
109 item_index_value_node};
110 CNodePtr tuple_get_item_node = func_graph->NewCNode(tuple_get_item_inputs);
111 MS_EXCEPTION_IF_NULL(tuple_get_item_node);
112 tuple_get_item_node->set_abstract(tuple_abstract->cast<abstract::AbstractTuplePtr>()->elements()[item_index]);
113 return tuple_get_item_node;
114 }
115
CreateMakeTupleNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & tuple_inputs)116 CNodePtr CreateMakeTupleNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &tuple_inputs) {
117 MS_EXCEPTION_IF_NULL(func_graph);
118 AnfNodePtrList new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
119 (void)new_make_tuple_inputs.insert(new_make_tuple_inputs.cend(), tuple_inputs.cbegin(), tuple_inputs.cend());
120 auto make_tuple_node = func_graph->NewCNode(new_make_tuple_inputs);
121 MS_EXCEPTION_IF_NULL(make_tuple_node);
122
123 // MakeTuple's abstract must consist of all inputs' abstract in case unexpected graph compiling error.
124 AbstractBasePtrList abstract_list;
125 (void)std::for_each(tuple_inputs.cbegin(), tuple_inputs.cend(),
126 [&](const auto &input) { (void)abstract_list.emplace_back(input->abstract()); });
127 make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
128 return make_tuple_node;
129 }
130
/**
 * Create a replacement for the graph's output node on processes that do not own the real output.
 * For a tuple output, builds a new MakeTuple whose elements are either the original Parameter/ValueNode
 * inputs or fake tensors with the same dtype/shape; otherwise returns a fake value node.
 * @param func_graph The graph on which replacement nodes are created.
 * @param origin_output The original output node; its abstract must not be null.
 * @return The replacement output node.
 */
AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_output);
  MS_EXCEPTION_IF_NULL(origin_output->abstract());
  if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
    auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(origin_output, kIndex0);
    auto real_output = kernel_with_index.first;
    if (!IsPrimitiveCNode(real_output, prim::kPrimMakeTuple)) {
      MS_LOG(EXCEPTION) << "Tuple output is not a MakeTuple node: " << real_output->DebugString();
    }
    auto real_output_cnode = real_output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
    // Fix: guard against a MakeTuple whose input count mismatches its tuple abstract size before
    // indexing input(i + 1) below.
    if (real_output_cnode->size() <= tuple_elements.size()) {
      MS_LOG(EXCEPTION) << "The MakeTuple node " << real_output_cnode->DebugString()
                        << " has fewer inputs than tuple element size " << tuple_elements.size() << ".";
    }
    AnfNodePtrList tuple_inputs;
    for (size_t i = kIndex0; i < tuple_elements.size(); i++) {
      // If tuple input is a ValueNode, use it as new tuple's input.
      const auto tuple_input = real_output_cnode->input(i + kSizeOne);
      MS_EXCEPTION_IF_NULL(tuple_input);
      if (tuple_input->isa<Parameter>() || tuple_input->isa<ValueNode>()) {
        MS_LOG(INFO) << "Use " << tuple_input->DebugString() << " as replaced output.";
        (void)tuple_inputs.emplace_back(tuple_input);
        continue;
      }

      const auto &element = tuple_elements[i];
      MS_EXCEPTION_IF_NULL(element);
      auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
      if (!tensor_abstract) {
        MS_LOG(EXCEPTION) << "Only support to replace tuple with all tensor elements.";
      }
      // Replace a computed element with a fake tensor of the same dtype and shape.
      auto fake_tensor = std::make_shared<tensor::Tensor>(tensor_abstract->element()->BuildType()->type_id(),
                                                          tensor_abstract->shape()->shape());
      MS_EXCEPTION_IF_NULL(fake_tensor);
      auto fake_value_node = NewValueNode(fake_tensor);
      MS_EXCEPTION_IF_NULL(fake_value_node);
      fake_value_node->set_abstract(fake_tensor->ToAbstract());
      (void)tuple_inputs.emplace_back(fake_value_node);
    }
    return CreateMakeTupleNode(func_graph, tuple_inputs);
  } else {
    return CreateFakeValueNode(true, origin_output);
  }
}
171
SetSendNodeAttr(const AnfNodePtr & send_node,const InterProcessOpEdge & inter_process_edge)172 void SetSendNodeAttr(const AnfNodePtr &send_node, const InterProcessOpEdge &inter_process_edge) {
173 const auto &send_src_node = inter_process_edge.src_node;
174 const auto &send_dst_node = inter_process_edge.dst_node;
175 MS_EXCEPTION_IF_NULL(send_src_node);
176 MS_EXCEPTION_IF_NULL(send_dst_node);
177 MS_EXCEPTION_IF_NULL(send_node);
178
179 std::string src_node_name = send_src_node->fullname_with_scope();
180 std::string dst_node_name = send_dst_node->fullname_with_scope();
181
182 // These attributes are the inter-process edge information.
183 std::vector<uint32_t> dst_ranks = {inter_process_edge.dst_label.rank_id};
184 common::AnfAlgo::SetNodeAttr(kAttrSendDstRanks, MakeValue(dst_ranks), send_node);
185 std::vector<std::string> dst_roles = {inter_process_edge.dst_label.ms_role};
186 common::AnfAlgo::SetNodeAttr(kAttrSendDstRoles, MakeValue(dst_roles), send_node);
187
188 common::AnfAlgo::SetNodeAttr(kAttrSendSrcNodeName, MakeValue(src_node_name), send_node);
189 common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(dst_node_name), send_node);
190 std::vector<std::string> inter_process_edges = {inter_process_edge.to_string()};
191 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
192
193 // Set send node to CPU for now.
194 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), send_node);
195 }
196
SetRecvNodeAttr(const AnfNodePtr & recv_node,const InterProcessOpEdge & inter_process_edge)197 void SetRecvNodeAttr(const AnfNodePtr &recv_node, const InterProcessOpEdge &inter_process_edge) {
198 const auto &recv_src_node = inter_process_edge.src_node;
199 const auto &recv_dst_node = inter_process_edge.dst_node;
200 MS_EXCEPTION_IF_NULL(recv_src_node);
201 MS_EXCEPTION_IF_NULL(recv_dst_node);
202 MS_EXCEPTION_IF_NULL(recv_node);
203
204 std::string src_node_name = recv_src_node->fullname_with_scope();
205 std::string dst_node_name = recv_dst_node->fullname_with_scope();
206
207 // These attributes are the inter-process edge information.
208 std::vector<uint32_t> src_ranks = {inter_process_edge.src_label.rank_id};
209 common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRanks, MakeValue(src_ranks), recv_node);
210 std::vector<std::string> src_roles = {inter_process_edge.src_label.ms_role};
211 common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
212
213 common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(src_node_name), recv_node);
214 common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(dst_node_name), recv_node);
215 std::vector<std::string> inter_process_edges = {inter_process_edge.to_string()};
216 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
217
218 // Set recv node to CPU for now.
219 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), recv_node);
220 }
221
CreateSendNode(const FuncGraphPtr & func_graph,const InterProcessOpEdge & inter_process_edge)222 CNodePtr CreateSendNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge) {
223 const auto &src_node = inter_process_edge.src_node;
224 const auto &dst_node = inter_process_edge.dst_node;
225 MS_EXCEPTION_IF_NULL(src_node);
226 MS_EXCEPTION_IF_NULL(dst_node);
227
228 std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
229 ValueNodePtr mock_value = nullptr;
230 if (IsPrimitiveCNode(src_node, prim::kPrimUpdateState)) {
231 mock_value = CreateFakeValueNode(false);
232 send_inputs.push_back(mock_value);
233 send_inputs.push_back(src_node);
234 } else {
235 if (src_node->abstract()->isa<abstract::AbstractTuple>()) {
236 // If src_node's output is a tuple, get the first element of the tuple as Send's input.
237 auto tuple_get_item_node = CreateTupleGetItemNode(func_graph, src_node, kIndex0);
238 send_inputs.push_back(tuple_get_item_node);
239 mock_value = CreateFakeValueNode(true, tuple_get_item_node);
240 } else {
241 send_inputs.push_back(src_node);
242 mock_value = CreateFakeValueNode(true, src_node);
243 }
244 }
245 CNodePtr send_node = func_graph->NewCNode(send_inputs);
246 MS_EXCEPTION_IF_NULL(send_node);
247 send_node->set_abstract(mock_value->abstract());
248
249 SetSendNodeAttr(send_node, inter_process_edge);
250 return send_node;
251 }
252
// Create an RpcRecv node for the given inter-process edge on the receiving side's graph.
// The Recv node's inputs and abstract depend on the kind of source node:
// - UpdateState source: a monad input mirrors the state dependency and the Recv keeps the monad abstract.
// - CNode carrying kAttrUpdateParameter + kAttrParameterInputIndex (optimizer updating a parameter):
//   the updated parameter plus a UMonad input, so the parameter is updated on the device side in the
//   heterogeneous case; the Recv takes the parameter's abstract.
// - kDataSyncSrcOpName source: the synced parameter (first input) plus a UMonad input.
// - Anything else: a fake value node with the source's real shape serves as placeholder input.
CNodePtr CreateRecvNode(const FuncGraphPtr &func_graph, const InterProcessOpEdge &inter_process_edge) {
  const auto &src_node = inter_process_edge.src_node;
  const auto &dst_node = inter_process_edge.dst_node;
  MS_EXCEPTION_IF_NULL(src_node);
  MS_EXCEPTION_IF_NULL(dst_node);

  std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
  CNodePtr recv_node = nullptr;
  AbstractBasePtr recv_node_abs = nullptr;
  if (IsPrimitiveCNode(src_node, prim::kPrimUpdateState)) {
    // Mirror the source's monad kind (U or IO) on the receiving side.
    ValuePtr monad_value = nullptr;
    if (HasAbstractUMonad(src_node)) {
      monad_value = kUMonad;
    } else if (HasAbstractIOMonad(src_node)) {
      monad_value = kIOMonad;
    } else {
      MS_LOG(EXCEPTION) << "The src_node is PrimUpdateState must have monad abstract.";
    }
    auto monad_input = NewValueNode(monad_value);
    MS_EXCEPTION_IF_NULL(monad_input);
    monad_input->set_abstract(monad_value->ToAbstract());
    recv_inputs.push_back(monad_input);
    recv_node_abs = src_node->abstract();
  } else {
    if (src_node->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, src_node->cast<CNodePtr>()) &&
        common::AnfAlgo::HasNodeAttr(kAttrParameterInputIndex, src_node->cast<CNodePtr>())) {
      // Resolve the parameter that the source optimizer updates and make it the Recv's input.
      int64_t parameter_index = common::AnfAlgo::GetNodeAttr<int64_t>(src_node, kAttrParameterInputIndex);
      auto kernel_with_index = common::AnfAlgo::VisitKernel(
        common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), parameter_index), kIndex0);
      auto param_node = kernel_with_index.first;
      recv_inputs.push_back(param_node);

      // To update the parameter on the device side in heterogeneous case, side-effect node should be added to recv's
      // input.
      ValuePtr monad_value = kUMonad;
      auto monad_input = NewValueNode(monad_value);
      MS_EXCEPTION_IF_NULL(monad_input);
      monad_input->set_abstract(monad_value->ToAbstract());
      recv_inputs.push_back(monad_input);

      recv_node_abs = param_node->abstract();
    } else if (src_node->isa<CNode>() && common::AnfAlgo::GetCNodeName(src_node) == distributed::kDataSyncSrcOpName) {
      // Data-sync source: receive into the parameter that is the source's first input.
      auto kernel_with_index =
        common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(src_node->cast<CNodePtr>(), kIndex0), kIndex0);
      auto param_node = kernel_with_index.first;
      recv_inputs.push_back(param_node);

      ValuePtr monad_value = kUMonad;
      auto monad_input = NewValueNode(monad_value);
      MS_EXCEPTION_IF_NULL(monad_input);
      monad_input->set_abstract(monad_value->ToAbstract());
      recv_inputs.push_back(monad_input);

      recv_node_abs = param_node->abstract();
    } else {
      // Use the same shape as origin node's.
      auto mock_value = CreateFakeValueNode(true, src_node, false);
      MS_EXCEPTION_IF_NULL(mock_value);
      recv_inputs.push_back(mock_value);
      recv_node_abs = src_node->abstract();
    }
  }
  recv_node = func_graph->NewCNode(recv_inputs);
  MS_EXCEPTION_IF_NULL(recv_node);
  recv_node->set_abstract(recv_node_abs);

  SetRecvNodeAttr(recv_node, inter_process_edge);
  return recv_node;
}
322
GetRealIndexToSeg(const std::vector<size_t> & split_segment,size_t real_size)323 std::map<size_t, size_t> GetRealIndexToSeg(const std::vector<size_t> &split_segment, size_t real_size) {
324 std::map<size_t, size_t> result;
325 // If split_segment is empty, return an empty map.
326 if (split_segment.empty()) {
327 return result;
328 }
329
330 // Check whether the vector of indices is valid.
331 if (!std::is_sorted(split_segment.begin(), split_segment.end())) {
332 MS_LOG(EXCEPTION) << "Indices of segments is not in a ascending order: " << split_segment;
333 }
334
335 size_t real_index = 0;
336 for (size_t seg_index = 0; seg_index < split_segment.size(); seg_index++) {
337 size_t upper_bound = split_segment[seg_index];
338 for (; real_index < real_size; real_index++) {
339 if (real_index <= upper_bound) {
340 result[real_index] = seg_index;
341 } else {
342 break;
343 }
344 }
345 }
346
347 // Map the rest of real index to a segment.
348 if (real_size > (*split_segment.rbegin()) + 1) {
349 for (; real_index < real_size; real_index++) {
350 result[real_index] = split_segment.size();
351 }
352 }
353 return result;
354 }
355
IsOneOfRealGraphInput(const FuncGraphPtr & func_graph,const AnfNodePtr & input)356 bool IsOneOfRealGraphInput(const FuncGraphPtr &func_graph, const AnfNodePtr &input) {
357 MS_EXCEPTION_IF_NULL(func_graph);
358 MS_EXCEPTION_IF_NULL(input);
359 auto all_inputs = func_graph->get_inputs();
360 return std::count(all_inputs.begin(), all_inputs.end(), input) != 0;
361 }
362
GenerateStrategy()363 distributed::DistExecutionMode GenerateStrategy() {
364 distributed::DistExecutionMode strategy;
365 bool enable_ps = false;
366 bool enable_embedding_cache = false;
367 #if defined(__linux__) && defined(WITH_BACKEND)
368 enable_ps = ps::PSContext::instance()->is_ps_mode();
369 enable_embedding_cache = ps::PSContext::instance()->cache_enable();
370 #endif
371 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
372 MS_LOG(INFO) << "Current parallel mode is " << parallel_mode;
373 bool using_parallel = (parallel_mode != parallel::kStandalone) ? true : false;
374 // The conditions' priority is: EmbeddingCache > Parameter Server > General.
375 if (enable_embedding_cache) {
376 strategy = distributed::DistExecutionMode::kEmbeddingCacheMode;
377 } else if (enable_ps) {
378 strategy = distributed::DistExecutionMode::kPSMode;
379 } else if (using_parallel) {
380 strategy = distributed::DistExecutionMode::kParallelMode;
381 } else {
382 strategy = distributed::DistExecutionMode::kGeneralMode;
383 }
384 MS_LOG(INFO) << "Generated distributed strategy is " << strategy;
385 return strategy;
386 }
387
TransformPrimAttrToAttr(const CNodePtr & cnode)388 void TransformPrimAttrToAttr(const CNodePtr &cnode) {
389 MS_EXCEPTION_IF_NULL(cnode);
390 auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
391 MS_EXCEPTION_IF_NULL(prim);
392 if (cnode->HasPrimalAttr(distributed::kOpLabelRankId)) {
393 MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'rank_id'.";
394 prim->set_attr(distributed::kOpLabelRankId, cnode->GetPrimalAttr(distributed::kOpLabelRankId));
395 }
396 if (cnode->HasPrimalAttr(distributed::kOpLabelRole)) {
397 MS_LOG(DEBUG) << cnode->fullname_with_scope() << " has primal attr 'ms_role'.";
398 prim->set_attr(distributed::kOpLabelRole, cnode->GetPrimalAttr(distributed::kOpLabelRole));
399 }
400 }
401
NodeHasLabel(const AnfNodePtr & node)402 bool NodeHasLabel(const AnfNodePtr &node) {
403 MS_EXCEPTION_IF_NULL(node);
404 if (!node->isa<CNode>()) {
405 return false;
406 }
407
408 bool has_label = false;
409 CNodePtr cnode = node->cast<CNodePtr>();
410 MS_EXCEPTION_IF_NULL(cnode);
411 auto prim_node = cnode->input(0);
412 MS_EXCEPTION_IF_NULL(prim_node);
413
414 // As long as the node has 'ms_role' and 'rank_id' attributes, we consider this node has label regardless the value of
415 // these two attributes.
416 if (IsValueNode<Primitive>(prim_node)) {
417 auto prim = GetValueNode<PrimitivePtr>(prim_node);
418 MS_EXCEPTION_IF_NULL(prim);
419 if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
420 has_label = true;
421 }
422 } else {
423 // Get label for call node, 'call' node hasn't primitive to save attrs, so get attrs of 'call' from cnode.
424 if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
425 has_label = true;
426 }
427 }
428 return has_label;
429 }
430
GraphHasLabel(const FuncGraphPtr & func_graph)431 bool GraphHasLabel(const FuncGraphPtr &func_graph) {
432 MS_EXCEPTION_IF_NULL(func_graph);
433
434 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph->get_return());
435 // If one node has label, this graph has label. Thus it needs to be split.
436 for (const auto &node : all_nodes) {
437 MS_EXCEPTION_IF_NULL(node);
438 if (NodeHasLabel(node)) {
439 return true;
440 }
441 }
442 return false;
443 }
444
GetSideEffectNodeList(const AnfNodePtrList & nodes)445 CNodePtrList GetSideEffectNodeList(const AnfNodePtrList &nodes) {
446 CNodePtrList side_effect_nodes;
447 for (const auto &node : nodes) {
448 MS_EXCEPTION_IF_NULL(node);
449 if (!node->isa<CNode>()) {
450 continue;
451 }
452 auto cnode = node->cast<CNodePtr>();
453 MS_EXCEPTION_IF_NULL(cnode);
454 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
455 MS_EXCEPTION_IF_NULL(prim);
456 if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM)) {
457 (void)side_effect_nodes.emplace_back(cnode);
458 MS_LOG(DEBUG) << "CNode with side effect mem: " << cnode->fullname_with_scope();
459 }
460 }
461 return side_effect_nodes;
462 }
463
GetRefInputs(const CNodePtr & cnode)464 AnfNodePtrList GetRefInputs(const CNodePtr &cnode) {
465 MS_EXCEPTION_IF_NULL(cnode);
466 AnfNodePtrList ref_inputs;
467 for (size_t i = kIndex1; i < cnode->size(); ++i) {
468 auto &input = cnode->inputs().at(i);
469 if (common::AnfAlgo::HasAbstractRef(input)) {
470 ref_inputs.push_back(input);
471 MS_LOG(DEBUG) << "The ref input " << input->fullname_with_scope() << " of node " << cnode->fullname_with_scope();
472 }
473 }
474 return ref_inputs;
475 }
476
// Return the first user of 'cnode' that is an UpdateState node, or nullptr if none exists.
CNodePtr FindNextUpdateStateNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  auto users = func_graph->manager()->node_users()[cnode];
  for (const auto &user : users) {
    const auto &user_node = user.first;
    MS_EXCEPTION_IF_NULL(user_node);
    if (common::AnfAlgo::GetCNodeName(user_node) == kUpdateStateOpName) {
      return user_node->cast<CNodePtr>();
    }
  }
  return nullptr;
}
490
CreateUMonadNode()491 ValueNodePtr CreateUMonadNode() {
492 ValuePtr monad_value = kUMonad;
493 auto monad_input = NewValueNode(monad_value);
494 MS_EXCEPTION_IF_NULL(monad_input);
495 monad_input->set_abstract(monad_value->ToAbstract());
496 return monad_input;
497 }
498
CreateUpdateStateNode(const FuncGraphPtr & func_graph,const AnfNodePtrList & update_state_inputs)499 CNodePtr CreateUpdateStateNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &update_state_inputs) {
500 if (update_state_inputs.empty()) {
501 MS_LOG(EXCEPTION) << "The inputs of UpdateState should not be empty.";
502 }
503 // The first input of UpdateState is an 'U'.
504 ValueNodePtr umonad_input = CreateUMonadNode();
505 MS_EXCEPTION_IF_NULL(umonad_input);
506 AnfNodePtrList inputs = {NewValueNode(prim::kPrimUpdateState), umonad_input};
507 (void)inputs.insert(inputs.end(), update_state_inputs.begin(), update_state_inputs.end());
508
509 auto update_state_node = func_graph->NewCNode(inputs);
510 MS_EXCEPTION_IF_NULL(update_state_node);
511 update_state_node->set_abstract(umonad_input->abstract());
512 return update_state_node;
513 }
514
// Build a dependency matrix restricted to 'target_nodes': for each node in the graph, record the
// subset of target nodes it (transitively) depends on. Nodes that depend on no target node do not
// appear in the returned map.
std::map<AnfNodePtr, AnfNodePtrSet> FilterDependencyToTargetNode(const FuncGraphPtr &func_graph,
                                                                 const AnfNodePtrSet &target_nodes) {
  std::map<AnfNodePtr, AnfNodePtrSet> depend_matrix;
  MS_EXCEPTION_IF_NULL(func_graph);
  auto return_node = func_graph->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  AnfNodePtrList nodes = FuncGraph::TopoSort(return_node);
  // Traverse all nodes in topo-sort order: every input is processed before its users, so each
  // node's set can be built from its inputs' already-complete sets in a single pass.
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    CNodePtr cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &inputs = cnode->inputs();
    // Traverse all inputs and only filter out inputs which is in target nodes set.
    for (const auto &input : inputs) {
      MS_EXCEPTION_IF_NULL(input);
      // If the input is stored already, this means it depends on some of target nodes, so we expand its inputs and
      // insert them.
      if (depend_matrix.count(input) != 0) {
        (void)depend_matrix[node].insert(depend_matrix[input].begin(), depend_matrix[input].end());
      }
      // If input itself is in target nodes set, insert it as well.
      if (target_nodes.contains(input)) {
        (void)depend_matrix[node].insert(input);
      }
    }
  }
  return depend_matrix;
}
547
// Insert 'new_node' into 'old_depended_set' while keeping the set minimal with respect to the
// dependency relation in 'node_dependency':
// - if an existing member already depends on 'new_node', the set is returned unchanged;
// - if 'new_node' depends on an existing member, that member is replaced by 'new_node';
// - otherwise 'new_node' is independent of every member and is simply added.
// Returns the updated copy; the input set itself is not modified.
AnfNodePtrSet UpdateDependedSet(const AnfNodePtr &new_node, const AnfNodePtrSet &old_depended_set,
                                const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency) {
  AnfNodePtrSet updated = old_depended_set;
  bool is_independent = true;
  for (const auto &stored_node : old_depended_set) {
    // If 'new_node' is already depended on by 'stored_node', no need to add 'new_node'.
    if (node_dependency.count(stored_node) != 0 && node_dependency.at(stored_node).contains(new_node)) {
      MS_LOG(DEBUG) << "Old node " << stored_node->fullname_with_scope() << " depends on "
                    << new_node->fullname_with_scope() << ". Do not update.";
      is_independent = false;
      break;
    }
    // If 'new_node' depends on 'stored_node', replace 'stored_node' with 'new_node' to keep minimal dependency.
    if (node_dependency.count(new_node) != 0 && node_dependency.at(new_node).contains(stored_node)) {
      MS_LOG(DEBUG) << "Replace old node " << stored_node->fullname_with_scope() << " with new node "
                    << new_node->fullname_with_scope();
      (void)updated.erase(stored_node);
      (void)updated.insert(new_node);
    }
  }
  if (is_independent) {
    MS_LOG(DEBUG) << "Add new node to depended set " << new_node->fullname_with_scope();
    (void)updated.insert(new_node);
  }
  return updated;
}
574
// Attach 'hung_nodes_list' (nodes with no users after splitting, e.g. standalone Send nodes) to the
// graph output through a Depend node so they are not eliminated. If the original output does not
// belong to this process (its label differs from 'process_label'), it is first replaced by a fake
// output of the same shape/dtype.
void HandleHungNodes(const FuncGraphPtr &func_graph, const NodeLabels &node_labels, OperatorLabel process_label,
                     const AnfNodePtrList &hung_nodes_list) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Gather all hung nodes into one MakeTuple so a single Depend edge keeps them all alive.
  auto make_tuple_node = CreateMakeTupleNode(func_graph, hung_nodes_list);
  MS_EXCEPTION_IF_NULL(make_tuple_node);

  const auto &origin_output = func_graph->output();
  MS_EXCEPTION_IF_NULL(origin_output);
  if (node_labels.count(origin_output) == 0) {
    MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
                      << " should have corresponding operator label.";
  }
  AnfNodePtr replaced_output = nullptr;
  if (node_labels.at(origin_output) != process_label) {
    // The real output lives on another process: substitute a fake output for this process's graph.
    replaced_output = CreateReplacedOutputNode(func_graph, origin_output);
  } else {
    replaced_output = origin_output;
  }
  MS_EXCEPTION_IF_NULL(replaced_output);

  // Add dependency and replace.
  std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output, make_tuple_node};
  auto final_output_node = func_graph->NewCNode(depend_inputs);
  MS_EXCEPTION_IF_NULL(final_output_node);
  final_output_node->set_abstract(replaced_output->abstract());
  (void)func_graph->manager()->SetEdge(func_graph->get_return(), 1, final_output_node);
}
602
PreBuildDistributedGraph()603 void ParameterServerMode::PreBuildDistributedGraph() {
604 MS_LOG(INFO) << "Start pre-building distribtued graph in Parameter Server mode.";
605 MS_EXCEPTION_IF_NULL(node_labels_);
606 ProcessForSplitOptimizer();
607 MS_LOG(INFO) << "End pre-building distribtued graph in Parameter Server mode.";
608 }
609
DoRpcNodeFusion(InterProcessOpEdgesInfo * comm_edges_ptr)610 FusedInterProcessOpPairMap ParameterServerMode::DoRpcNodeFusion(InterProcessOpEdgesInfo *comm_edges_ptr) {
611 MS_EXCEPTION_IF_NULL(comm_edges_ptr);
612
613 // The edges of server optimizers should be fused with same peers. For example, edges from Worker_0 to Server_0 will
614 // be fused by segments.
615 InterProcessOpEdgesInfo comm_edges_of_server_optimizer = FilterCommEdgesOfServerOptimizer(*comm_edges_ptr);
616 FusedInterProcessOpPairMap optimizer_fused_edges = FuseRpcNodesForSplitOptimizer(comm_edges_of_server_optimizer);
617
618 // The rest of the edges are not fused like edges for EmbeddingLookup, but the FusedInterProcessOpPairMap object
619 // should be created.
620 FusedInterProcessOpPairMap rest_edges = FilterNotServerOptimizerEdges(*comm_edges_ptr);
621 (void)optimizer_fused_edges.insert(rest_edges.cbegin(), rest_edges.cend());
622 return optimizer_fused_edges;
623 }
624
PostBuildDistributedGraph(const InterProcessOpEdgesInfo & comm_edges)625 void ParameterServerMode::PostBuildDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
626 MS_LOG(INFO) << "Start post-building distribtued graph in Parameter Server mode.";
627 MS_EXCEPTION_IF_NULL(node_labels_);
628 // Judge the node role number validation.
629 uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
630 if (worker_num == 0) {
631 MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
632 }
633 uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
634 if (server_num == 0) {
635 MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
636 }
637 // Only multiple worker scenario needs this optimizer.
638 if (worker_num < kMinGradAccumWorkerNum) {
639 return;
640 }
641
642 MS_EXCEPTION_IF_NULL(func_graph_);
643 std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList();
644
645 // Duplicate out degrees for ps optimizers because defaultly there's only one edge to the rank 0 worker.
646 for (const auto &ps_optimizer : ps_optimizer_node_list) {
647 for (const auto &edge_info : comm_edges) {
648 if (edge_info.first.src_node == ps_optimizer) {
649 // The optimizer's output should always connect to Send node which is the input of a MakeTuple node.
650 // We need to replace the MakeTuple node with a new one.
651 const auto &origin_send_node = std::get<0>(edge_info.second);
652 std::vector<AnfNodePtr> new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), origin_send_node};
653 AnfNodePtr dst_node = edge_info.first.dst_node;
654 for (uint32_t i = 1; i < worker_num; i++) {
655 OperatorLabel worker_label = {i, distributed::kEnvRoleOfWorker};
656 InterProcessOpEdge edge = {ps_optimizer, node_labels_->at(ps_optimizer), dst_node, worker_label};
657 auto duplicated_send_node = CreateSendNode(func_graph_, edge);
658 (void)node_labels_->insert(std::make_pair(duplicated_send_node, edge.src_label));
659 (void)new_make_tuple_inputs.emplace_back(duplicated_send_node);
660 }
661 auto new_make_tuple_node = func_graph_->NewCNode(new_make_tuple_inputs);
662 new_make_tuple_node->set_abstract(new_make_tuple_inputs.back()->abstract());
663 (void)func_graph_->manager()->Replace(origin_send_node, new_make_tuple_node);
664 }
665 }
666 }
667 MS_LOG(INFO) << "End post-building distribtued graph in Parameter Server mode.";
668 }
669
PostBuildDistributedGraph(const FusedInterProcessOpPairMap & fused_inter_process_op_pairs)670 void ParameterServerMode::PostBuildDistributedGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
671 MS_LOG(INFO) << "Start post-building distribtued graph in Parameter Server mode.";
672 MS_EXCEPTION_IF_NULL(node_labels_);
673 // Judge the node role number validation.
674 uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
675 if (worker_num == 0) {
676 MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
677 }
678 uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
679 if (server_num == 0) {
680 MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
681 }
682 // Only multiple worker scenario needs this optimizer.
683 if (worker_num < kMinGradAccumWorkerNum) {
684 return;
685 }
686
687 MS_EXCEPTION_IF_NULL(func_graph_);
688 std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList();
689 if (ps_optimizer_node_list.empty()) {
690 MS_LOG(INFO) << "This process has no ps optimizer on it. No need to do post building.";
691 return;
692 }
693
694 // Duplicate out degrees for ps optimizers because defaultly there's only one edge to the rank 0 worker.
695 for (const auto &op_pair_info : fused_inter_process_op_pairs) {
696 const auto &op_pairs = op_pair_info.second;
697 CNodePtr fused_send_node = std::get<0>(op_pairs[0]);
698 // Node's inputs except primtive value node.
699 std::vector<AnfNodePtr> fused_send_node_inputs = fused_send_node->inputs();
700 (void)fused_send_node_inputs.erase(fused_send_node_inputs.cbegin());
701
702 // Only handle the edge whose src_node is optimizer.
703 if (std::find_if(ps_optimizer_node_list.cbegin(), ps_optimizer_node_list.cend(), [&](const auto &ps_optimizer) {
704 return ps_optimizer.get() == fused_send_node_inputs[0].get();
705 }) == ps_optimizer_node_list.cend()) {
706 continue;
707 }
708
709 std::vector<AnfNodePtr> new_make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), fused_send_node};
710 for (uint32_t i = 1; i < worker_num; i++) {
711 std::vector<CNodePtr> new_send_nodes;
712 OperatorLabel worker_label = {i, distributed::kEnvRoleOfWorker};
713 for (size_t j = 0; j < op_pairs.size(); j++) {
714 const auto &src_node = fused_send_node_inputs[j];
715 const auto &dst_node = std::get<3>(op_pairs[j]);
716 InterProcessOpEdge edge = {src_node, node_labels_->at(src_node), dst_node, worker_label};
717 auto duplicated_send_node = CreateSendNode(func_graph_, edge);
718 MS_EXCEPTION_IF_NULL(duplicated_send_node);
719 (void)node_labels_->insert(std::make_pair(duplicated_send_node, edge.src_label));
720 (void)new_send_nodes.emplace_back(duplicated_send_node);
721 }
722 CNodePtr new_fused_send_node = FuseRpcSendNodes(new_send_nodes);
723 MS_EXCEPTION_IF_NULL(new_fused_send_node);
724 (void)new_make_tuple_inputs.emplace_back(new_fused_send_node);
725 }
726 auto new_make_tuple_node = func_graph_->NewCNode(new_make_tuple_inputs);
727 new_make_tuple_node->set_abstract(fused_send_node->abstract());
728 (void)func_graph_->manager()->Replace(fused_send_node, new_make_tuple_node);
729 }
730 MS_LOG(INFO) << "End post-building distribtued graph in Parameter Server mode.";
731 }
732
ProcessForSplitOptimizer()733 void ParameterServerMode::ProcessForSplitOptimizer() {
734 MS_EXCEPTION_IF_NULL(func_graph_);
735 std::vector<CNodePtr> ps_optimizer_node_list = FilterServerAwareOptimizerList();
736
737 // Judge the node role number validation.
738 uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
739 if (worker_num == 0) {
740 MS_LOG(EXCEPTION) << "In PS mode, worker number should be greater than 0.";
741 }
742 uint32_t server_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfServer);
743 if (server_num == 0) {
744 MS_LOG(EXCEPTION) << "In PS mode, server number should be greater than 0.";
745 }
746 // Only multiple worker scenario needs this optimizer.
747 if (worker_num < kMinGradAccumWorkerNum) {
748 return;
749 }
750
751 for (const auto &ps_optimizer : ps_optimizer_node_list) {
752 MS_EXCEPTION_IF_NULL(ps_optimizer);
753 // Load attributes for this optimizer.
754 auto gradient_index = common::AnfAlgo::HasNodeAttr(kAttrGradientInputIndex, ps_optimizer)
755 ? LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(ps_optimizer, kAttrGradientInputIndex))
756 : UINT64_MAX;
757 size_t indices_index = common::AnfAlgo::HasNodeAttr(kAttrIndicesInputIndex, ps_optimizer)
758 ? LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(ps_optimizer, kAttrIndicesInputIndex))
759 : UINT64_MAX;
760 std::string gradient_type = (common::AnfAlgo::HasNodeAttr(kAttrGradientType, ps_optimizer))
761 ? common::AnfAlgo::GetNodeAttr<std::string>(ps_optimizer, kAttrGradientType)
762 : kDenseGradient;
763 if (kGradTypeToAccumOpName.count(gradient_type) == 0) {
764 MS_LOG(EXCEPTION) << "The gradient type " << gradient_type << " is invalid.";
765 }
766
767 const std::string &opt_device_target = GetCNodeTarget(ps_optimizer);
768 for (size_t i = 0; i < common::AnfAlgo::GetInputNum(ps_optimizer); i++) {
769 auto input = common::AnfAlgo::GetInputNode(ps_optimizer, i);
770 // If the input is not a cnode, no inter-process edge is added so no node with multiple inputs should be created.
771 // Unless it's a real input.
772 if (!input->isa<CNode>()) {
773 if (IsOneOfRealGraphInput(func_graph_, input)) {
774 MS_LOG(INFO) << "The input " << i << " of optimizer " << ps_optimizer->fullname_with_scope() << ": "
775 << input->fullname_with_scope() << " is a real input from data.";
776 } else {
777 continue;
778 }
779 }
780
781 if (i == gradient_index) {
782 // Create the node to replace origin gradient which could be a RealDiv node.
783 std::pair<CNodePtr, CNodePtr> grad_accum_nodes = CreateNodesForGradAccumulation(
784 input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, gradient_type, worker_num);
785
786 const auto &accum_node = grad_accum_nodes.first;
787 const auto &real_div_node = grad_accum_nodes.second;
788 func_graph_->manager()->SetEdge(ps_optimizer, i + 1, real_div_node);
789 (void)node_labels_->insert(std::make_pair(accum_node, node_labels_->at(ps_optimizer)));
790 (void)node_labels_->insert(std::make_pair(real_div_node, node_labels_->at(ps_optimizer)));
791 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), accum_node);
792 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), real_div_node);
793 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), accum_node);
794 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), real_div_node);
795 } else if (i == indices_index) {
796 // Create the node to replace origin indices.
797 AnfNodePtr new_indices_input = CreateNodeWithInterProcessEdgeOnPServer(
798 kConcatOpName, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, worker_num);
799
800 func_graph_->manager()->SetEdge(ps_optimizer, i + 1, new_indices_input);
801 (void)node_labels_->insert(std::make_pair(new_indices_input, node_labels_->at(ps_optimizer)));
802 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), new_indices_input);
803 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), new_indices_input);
804 } else {
805 std::pair<CNodePtr, CNodePtr> make_tuple_get_item_nodes = CreateNodesForMakeTuple(input, worker_num);
806
807 auto &make_tuple_node = make_tuple_get_item_nodes.first;
808 auto &tuple_get_item_node = make_tuple_get_item_nodes.second;
809 func_graph_->manager()->SetEdge(ps_optimizer, i + 1, tuple_get_item_node);
810 (void)node_labels_->insert(std::make_pair(make_tuple_node, node_labels_->at(ps_optimizer)));
811 (void)node_labels_->insert(std::make_pair(tuple_get_item_node, node_labels_->at(ps_optimizer)));
812 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), make_tuple_node);
813 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(opt_device_target), tuple_get_item_node);
814 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), make_tuple_node);
815 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), tuple_get_item_node);
816 }
817 }
818 }
819 }
820
FilterServerAwareOptimizerList()821 std::vector<CNodePtr> ParameterServerMode::FilterServerAwareOptimizerList() {
822 MS_EXCEPTION_IF_NULL(func_graph_);
823 auto return_node = func_graph_->get_return();
824 MS_EXCEPTION_IF_NULL(return_node);
825
826 std::vector<CNodePtr> ps_optim_list;
827 std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
828 for (const auto &node : nodes) {
829 if (!node->isa<CNode>()) {
830 continue;
831 }
832 const auto &cnode = node->cast<CNodePtr>();
833 if (common::AnfAlgo::HasNodeAttr(kAttrUpdateParameter, cnode)) {
834 (void)ps_optim_list.emplace_back(cnode);
835 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeLabel, MakeValue(kPSOptimizerEdgeLabel), cnode);
836 }
837 }
838 return ps_optim_list;
839 }
840
CreateNodesForGradAccumulation(const AnfNodePtr & gradient_input,size_t gradient_input_index,const std::string & gradient_type,size_t total_gradient_number)841 std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForGradAccumulation(const AnfNodePtr &gradient_input,
842 size_t gradient_input_index,
843 const std::string &gradient_type,
844 size_t total_gradient_number) {
845 MS_EXCEPTION_IF_NULL(gradient_input);
846
847 if (kGradTypeToAccumOpName.count(gradient_type) == 0) {
848 MS_LOG(EXCEPTION) << "The gradient type " << gradient_type << " is invalid.";
849 }
850 const std::string &accum_node_name = kGradTypeToAccumOpName.at(gradient_type);
851 CNodePtr grad_accum_node = CreateNodeWithInterProcessEdgeOnPServer(accum_node_name, gradient_input,
852 gradient_input_index, total_gradient_number);
853 MS_EXCEPTION_IF_NULL(grad_accum_node);
854
855 CNodePtr grad_mean_node = CreateGradMeanNode(grad_accum_node, total_gradient_number);
856 MS_EXCEPTION_IF_NULL(grad_mean_node);
857 return std::make_pair(grad_accum_node, grad_mean_node);
858 }
859
CreateGradMeanNode(const AnfNodePtr & gradient,size_t divisor)860 CNodePtr ParameterServerMode::CreateGradMeanNode(const AnfNodePtr &gradient, size_t divisor) {
861 MS_EXCEPTION_IF_NULL(gradient);
862
863 // Step 1: Create the value node of divisor. The divisor's value is worker number.
864 auto addn_abstract = gradient->abstract()->cast<abstract::AbstractTensorPtr>();
865 MS_EXCEPTION_IF_NULL(addn_abstract);
866 // Use reciprocal of the divisor so Mul node should be created.
867 auto divisor_tensor =
868 std::make_shared<tensor::Tensor>(1 / static_cast<double>(divisor), addn_abstract->element()->BuildType());
869 MS_EXCEPTION_IF_NULL(divisor_tensor);
870 auto divisor_value_node = NewValueNode(divisor_tensor);
871 MS_EXCEPTION_IF_NULL(divisor_value_node);
872 divisor_value_node->set_abstract(divisor_tensor->ToAbstract());
873
874 // Step 2: Create Mul node.
875 std::vector<AnfNodePtr> real_div_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), gradient,
876 divisor_value_node};
877 CNodePtr grad_mean_node = func_graph_->NewCNode(real_div_inputs);
878 MS_EXCEPTION_IF_NULL(grad_mean_node);
879 grad_mean_node->set_abstract(gradient->abstract());
880 return grad_mean_node;
881 }
882
CreateNodesForMakeTuple(const AnfNodePtr & input,size_t total_inputs_number)883 std::pair<CNodePtr, CNodePtr> ParameterServerMode::CreateNodesForMakeTuple(const AnfNodePtr &input,
884 size_t total_inputs_number) {
885 MS_EXCEPTION_IF_NULL(input);
886 CNodePtr make_tuple_node = CreateNodeWithInterProcessEdgeOnPServer(
887 kMakeTupleOpName, input, (role_ == distributed::kEnvRoleOfWorker) ? rank_id_ : 0, total_inputs_number);
888 MS_EXCEPTION_IF_NULL(make_tuple_node);
889 // For MakeTuple node on Parameter Server, we get the first input as its abstract because the other inputs are
890 // supposed to be the same as the first one.
891 CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, make_tuple_node, kIndex0);
892 return std::make_pair(make_tuple_node, tuple_get_item_node);
893 }
894
/**
 * Create a 'many-to-one' node (AddN/Concat/MakeTuple/...) on the parameter server that gathers one
 * input per worker. Only the input at index_of_real_input is the real node from this graph; every
 * other slot is filled with a mock node labeled with the corresponding worker so that graph
 * splitting later turns those slots into inter-process edges.
 * @param many_to_one_node_name Name of the gathering primitive to create.
 * @param real_input The real input node on this process.
 * @param index_of_real_input Slot of the real input; must be less than total_inputs_number.
 * @param total_inputs_number Total number of gathered inputs (the worker number).
 */
CNodePtr ParameterServerMode::CreateNodeWithInterProcessEdgeOnPServer(const std::string &many_to_one_node_name,
                                                                      const AnfNodePtr &real_input,
                                                                      size_t index_of_real_input,
                                                                      uint32_t total_inputs_number) {
  if (index_of_real_input >= total_inputs_number) {
    MS_LOG(EXCEPTION) << "The index of real input for " << many_to_one_node_name << " " << index_of_real_input
                      << " is greater or equal to worker number " << total_inputs_number;
  }

  // Step 1: Create multiple inputs of new node including extra nodes.
  // Mock inputs mirror the real input's abstract and full name; an UpdateState real input gets
  // UpdateState mocks so monad semantics stay consistent.
  std::vector<AnfNodePtr> new_node_inputs;
  new_node_inputs.resize(total_inputs_number);
  std::vector<AnfNodePtr> mock_node_inputs = {NewValueNode(std::make_shared<Primitive>(
    IsPrimitiveCNode(real_input, prim::kPrimUpdateState) ? kUpdateStateOpName : kVirtualNode))};
  for (size_t i = 0; i < new_node_inputs.size(); i++) {
    new_node_inputs[i] = func_graph_->NewCNode(mock_node_inputs);
    MS_EXCEPTION_IF_NULL(new_node_inputs[i]);
    new_node_inputs[i]->set_abstract(real_input->abstract());
    new_node_inputs[i]->cast<CNodePtr>()->set_fullname_with_scope(real_input->fullname_with_scope());

    // Set operator label for new node's inputs. Slot i is labeled as coming from worker i.
    OperatorLabel input_label = {SizeToUint(i), distributed::kEnvRoleOfWorker};
    (void)node_labels_->insert(std::make_pair(new_node_inputs[i], input_label));
  }
  // The mock node at this process's slot is replaced with the real input.
  new_node_inputs[index_of_real_input] = real_input;

  // Step 2: Create the new node.
  auto new_node_prim = NewValueNode(std::make_shared<Primitive>(many_to_one_node_name));
  (void)new_node_inputs.insert(new_node_inputs.cbegin(), new_node_prim);
  if (many_to_one_node_name == kConcatOpName) {
    // Create axis input for concat: inputs are concatenated along axis 0.
    auto axis_value = MakeValue(0L);
    MS_EXCEPTION_IF_NULL(axis_value);
    auto axis_value_node = NewValueNode(axis_value);
    MS_EXCEPTION_IF_NULL(axis_value_node);
    axis_value_node->set_abstract(axis_value->ToAbstract());
    (void)new_node_inputs.insert(new_node_inputs.cend(), axis_value_node);
  }

  auto new_node = func_graph_->NewCNode(new_node_inputs);
  MS_EXCEPTION_IF_NULL(new_node);

  // Step 3: Set the new node's abstract and attrs, depending on the gathering op.
  if (many_to_one_node_name == kAddNOpName) {
    // NOTE(review): both "N" and "n" are set — presumably different consumers read different
    // spellings of the input-count attribute; confirm before removing either.
    common::AnfAlgo::SetNodeAttr("N", MakeValue(static_cast<int64_t>(total_inputs_number)), new_node);
    common::AnfAlgo::SetNodeAttr("n", MakeValue(static_cast<int64_t>(total_inputs_number)), new_node);
    new_node->set_abstract(real_input->abstract());
  } else if (many_to_one_node_name == kConcatOpName) {
    auto origin_abs = real_input->abstract()->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(origin_abs);

    // Concat along axis 0 multiplies the first dimension by the input number.
    auto new_abs = origin_abs->Clone()->cast<abstract::AbstractTensorPtr>();
    ShapeVector new_shape = new_abs->shape()->shape();
    new_shape[0] = new_shape[0] * static_cast<int64_t>(total_inputs_number);
    new_abs->shape()->set_shape(new_shape);
    new_node->set_abstract(new_abs);
  } else if (many_to_one_node_name == kMakeTupleOpName) {
    // Tuple abstract gathers the abstracts of all inputs, skipping the primitive input.
    AbstractBasePtrList abstract_list;
    auto first_input = new_node_inputs.begin();
    std::advance(first_input, 1);
    (void)std::for_each(first_input, new_node_inputs.end(),
                        [&](const auto &input) { (void)abstract_list.emplace_back(input->abstract()); });
    new_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  } else {
    new_node->set_abstract(real_input->abstract());
  }
  return new_node;
}
963
/**
 * Fuse the rpc send/recv node pairs created for server-side optimizers.
 * Pipeline: (1) group op pairs by their (src_label, dst_label) peer; (2) partition each group into
 * fusion segments according to ps_optimizer_fusion_segments_; (3) within each segment, move pairs
 * with monad recv abstracts to the end so fused send/recv nodes can be built; (4) fuse each segment
 * into one send node and one recv node, recording each pair's slot index in the fused node.
 */
FusedInterProcessOpPairMap ParameterServerMode::FuseRpcNodesForSplitOptimizer(
  const InterProcessOpEdgesInfo &comm_edges_of_server_optimizer) {
  // Group op pairs by communication peer; the index field is 0 for now and set per segment below.
  InterProcessOpPairMap comm_edges_with_same_peer;
  for (const auto &comm_edge_info : comm_edges_of_server_optimizer) {
    const InterProcessOpEdge &edge = comm_edge_info.first;
    const InterProcessOpPair &node_pair = comm_edge_info.second;
    (void)comm_edges_with_same_peer[{edge.src_label, edge.dst_label, 0}].emplace_back(node_pair);
  }

  // Partition each peer group into fusion segments.
  InterProcessOpPairMap comm_edges_segments;
  for (auto comm_edge_info = comm_edges_with_same_peer.cbegin(); comm_edge_info != comm_edges_with_same_peer.cend();
       ++comm_edge_info) {
    InterProcessEdgeWithIndex edge_with_index = comm_edge_info->first;
    const std::vector<InterProcessOpPair> &op_pair_list = comm_edge_info->second;
    std::map<size_t, size_t> real_index_to_segment =
      GetRealIndexToSeg(ps_optimizer_fusion_segments_, op_pair_list.size());
    if (real_index_to_segment.empty()) {
      // No segmentation configured: the whole peer group forms one segment.
      comm_edges_segments[edge_with_index] = op_pair_list;
      continue;
    } else {
      if (real_index_to_segment.size() != op_pair_list.size()) {
        MS_LOG(EXCEPTION) << "Real index to segment index map is invalid: size not matched.";
      }
      // Re-key each op pair with its segment index so pairs of one segment fuse together.
      for (size_t i = 0; i < op_pair_list.size(); i++) {
        edge_with_index.index = real_index_to_segment[i];
        (void)comm_edges_segments[edge_with_index].emplace_back(op_pair_list[i]);
      }
    }
  }

  FusedInterProcessOpPairMap results;
  for (auto rpc_nodes_fuse_info = comm_edges_segments.begin(); rpc_nodes_fuse_info != comm_edges_segments.end();
       ++rpc_nodes_fuse_info) {
    // Reorder the rpc node pairs list. Place monad inputs to the end of the list so that rpc send/recv nodes can be
    // built.
    std::vector<InterProcessOpPair> &inter_process_pairs = (*rpc_nodes_fuse_info).second;
    std::vector<InterProcessOpPair> monad_pairs;
    std::vector<InterProcessOpPair> no_monad_pairs;
    (void)std::for_each(inter_process_pairs.begin(), inter_process_pairs.end(), [&](const auto &op_pair) {
      if (HasAbstractMonad(std::get<1>(op_pair))) {
        (void)monad_pairs.emplace_back(op_pair);
      } else {
        (void)no_monad_pairs.emplace_back(op_pair);
      }
    });
    (void)no_monad_pairs.insert(no_monad_pairs.cend(), monad_pairs.cbegin(), monad_pairs.cend());
    inter_process_pairs = no_monad_pairs;

    std::vector<FusedInterProcessOpPair> fused_pairs;
    // NOTE(review): the "fusion2" env var enables the alternative FuseCommEdges path which
    // de-duplicates send nodes sharing a source — presumably experimental; confirm its status.
    if (!common::GetEnv("fusion2").empty()) {
      fused_pairs = FuseCommEdges(inter_process_pairs);
    } else {
      // Default path: fuse all sends of the segment into one send node and all recvs into one recv node.
      std::vector<CNodePtr> rpc_send_nodes, rpc_recv_nodes;
      (void)std::for_each(inter_process_pairs.begin(), inter_process_pairs.end(),
                          [&rpc_send_nodes, &rpc_recv_nodes](const auto &node_pair) {
                            (void)rpc_send_nodes.emplace_back(std::get<0>(node_pair));
                            (void)rpc_recv_nodes.emplace_back(std::get<1>(node_pair));
                          });
      CNodePtr fused_send_node = FuseRpcSendNodes(rpc_send_nodes);
      CNodePtr fused_recv_node = FuseRpcRecvNodes(rpc_recv_nodes);

      // Each original pair keeps its position i as its slot index inside the fused nodes.
      for (size_t i = 0; i < inter_process_pairs.size(); i++) {
        FusedInterProcessOpPair fused_inter_process_pair =
          std::make_tuple(fused_send_node, fused_recv_node, i, std::get<2>(inter_process_pairs[i]),
                          std::get<3>(inter_process_pairs[i]));
        (void)fused_pairs.emplace_back(fused_inter_process_pair);
      }
    }
    results[rpc_nodes_fuse_info->first] = fused_pairs;
  }
  return results;
}
1036
/**
 * Keep only the inter-process edges created for server-side optimizers, i.e. those labeled with
 * the ps-optimizer edge label.
 */
InterProcessOpEdgesInfo ParameterServerMode::FilterCommEdgesOfServerOptimizer(
  const InterProcessOpEdgesInfo &comm_edges) const {
  InterProcessOpEdgesInfo optimizer_edges;
  for (const auto &edge_info : comm_edges) {
    const bool is_ps_optimizer_edge = (edge_info.first.edge_label.label_name == kPSOptimizerEdgeLabel);
    if (is_ps_optimizer_edge) {
      (void)optimizer_edges.insert(edge_info);
    }
  }
  return optimizer_edges;
}
1047
FilterNotServerOptimizerEdges(const InterProcessOpEdgesInfo & comm_edges) const1048 FusedInterProcessOpPairMap ParameterServerMode::FilterNotServerOptimizerEdges(
1049 const InterProcessOpEdgesInfo &comm_edges) const {
1050 FusedInterProcessOpPairMap results;
1051 for (const auto &edge_info : comm_edges) {
1052 if (edge_info.first.edge_label.label_name != kPSOptimizerEdgeLabel) {
1053 const InterProcessOpEdge &edge = edge_info.first;
1054 const InterProcessOpPair &node_pair = edge_info.second;
1055
1056 // We use the hash value to make these edges with index unique. So this index has no actual meaning.
1057 size_t edge_index = std::hash<std::string>{}(edge.to_string());
1058 InterProcessEdgeWithIndex edge_with_index = {edge.src_label, edge.dst_label, edge_index};
1059 FusedInterProcessOpPair fused_op_pair = std::make_tuple(std::get<0>(node_pair), std::get<1>(node_pair), 0,
1060 std::get<2>(node_pair), std::get<3>(node_pair));
1061 std::vector<FusedInterProcessOpPair> pair_list = {fused_op_pair};
1062 (void)results.insert(std::make_pair(edge_with_index, pair_list));
1063 }
1064 }
1065 return results;
1066 }
1067
/**
 * Fuse multiple RpcSend nodes into one RpcSend node.
 * Non-monad inputs of all send nodes become the fused node's inputs (monad inputs are dropped), the
 * abstracts are gathered into a tuple, and the inter-process edge names are concatenated in list
 * order — the recv side must fuse its nodes in the same order for the fused edge name to match.
 * Remaining attributes are copied from the first node, so the list must be non-empty and all nodes
 * are assumed to target the same peer (callers build the list per peer group).
 */
CNodePtr ParameterServerMode::FuseRpcSendNodes(const std::vector<CNodePtr> &rpc_send_nodes) {
  if (rpc_send_nodes.empty()) {
    MS_LOG(EXCEPTION) << "Rpc send node list is empty.";
  }
  std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
  AbstractBasePtrList abstract_list;
  std::string fused_inter_process_edge_name = "";
  for (const auto &send_node : rpc_send_nodes) {
    MS_EXCEPTION_IF_NULL(send_node);
    for (size_t i = 1; i < send_node->size(); i++) {
      auto input_i = send_node->inputs()[i];
      MS_EXCEPTION_IF_NULL(input_i);
      // If the input of send is monad, do not pass it to fused send node.
      if (HasAbstractMonad(input_i)) {
        continue;
      }
      (void)send_inputs.emplace_back(input_i);
    }
    (void)abstract_list.emplace_back(send_node->abstract());
    // Concatenate the first edge name of every node to form the fused edge name.
    (void)fused_inter_process_edge_name.append(
      common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(send_node, kAttrInterProcessEdgeNames).front());
  }

  CNodePtr fused_send_node = func_graph_->NewCNode(send_inputs);
  MS_EXCEPTION_IF_NULL(fused_send_node);
  fused_send_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  std::vector<std::string> fused_inter_process_edge_names = {fused_inter_process_edge_name};
  common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(fused_inter_process_edge_names), fused_send_node);
  // Peer/target attrs are taken from the first node (assumed identical across the fused list).
  common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_send_nodes[0], fused_send_node);
  common::AnfAlgo::CopyNodeAttr(kAttrSendDstRanks, rpc_send_nodes[0], fused_send_node);
  common::AnfAlgo::CopyNodeAttr(kAttrSendDstRoles, rpc_send_nodes[0], fused_send_node);
  common::AnfAlgo::CopyNodeAttr(kAttrSendSrcNodeName, rpc_send_nodes[0], fused_send_node);
  common::AnfAlgo::CopyNodeAttr(kAttrSendDstNodeName, rpc_send_nodes[0], fused_send_node);
  return fused_send_node;
}
1103
FuseRpcRecvNodes(const std::vector<CNodePtr> & rpc_recv_nodes)1104 CNodePtr ParameterServerMode::FuseRpcRecvNodes(const std::vector<CNodePtr> &rpc_recv_nodes) {
1105 std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
1106 AbstractBasePtrList abstract_list;
1107 std::string fused_inter_process_edge_name = "";
1108 for (const auto &recv_node : rpc_recv_nodes) {
1109 MS_EXCEPTION_IF_NULL(recv_node);
1110 for (size_t i = 1; i < recv_node->size(); i++) {
1111 auto input_i = recv_node->inputs()[i];
1112 MS_EXCEPTION_IF_NULL(input_i);
1113 // If the input of recv is monad, do not pass it to fused recv node.
1114 if (HasAbstractMonad(input_i)) {
1115 continue;
1116 }
1117 (void)recv_inputs.emplace_back(input_i);
1118 }
1119 (void)abstract_list.emplace_back(recv_node->abstract());
1120 (void)fused_inter_process_edge_name.append(
1121 common::AnfAlgo::GetNodeAttr<std::vector<std::string>>(recv_node, kAttrInterProcessEdgeNames).front());
1122 }
1123 // Add umonad for recv node to update reference.
1124 ValuePtr monad_value = kUMonad;
1125 auto monad_input = NewValueNode(monad_value);
1126 MS_EXCEPTION_IF_NULL(monad_input);
1127 monad_input->set_abstract(monad_value->ToAbstract());
1128 recv_inputs.push_back(monad_input);
1129
1130 CNodePtr fused_recv_node = func_graph_->NewCNode(recv_inputs);
1131 MS_EXCEPTION_IF_NULL(fused_recv_node);
1132 fused_recv_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
1133 std::vector<std::string> fused_inter_process_edge_names = {fused_inter_process_edge_name};
1134 common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(fused_inter_process_edge_names), fused_recv_node);
1135 common::AnfAlgo::CopyNodeAttr(kAttrPrimitiveTarget, rpc_recv_nodes[0], fused_recv_node);
1136 common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRanks, rpc_recv_nodes[0], fused_recv_node);
1137 common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcRoles, rpc_recv_nodes[0], fused_recv_node);
1138 common::AnfAlgo::CopyNodeAttr(kAttrRecvSrcNodeName, rpc_recv_nodes[0], fused_recv_node);
1139 common::AnfAlgo::CopyNodeAttr(kAttrRecvDstNodeName, rpc_recv_nodes[0], fused_recv_node);
1140 return fused_recv_node;
1141 }
1142
/**
 * Fuse communication edges while de-duplicating send nodes that share the same source node: if two
 * op pairs send from a source with the same full name (matched via the send node's first input),
 * the later pair reuses the earlier pair's send/recv slot instead of adding a new one.
 * indices_map records, for each original pair i, the slot index of its data in the fused nodes.
 */
std::vector<FusedInterProcessOpPair> ParameterServerMode::FuseCommEdges(
  const std::vector<InterProcessOpPair> &inter_process_pairs) {
  std::vector<FusedInterProcessOpPair> fused_op_pairs;
  std::vector<CNodePtr> rpc_send_nodes, rpc_recv_nodes;
  std::map<size_t, size_t> indices_map;
  for (size_t i = 0; i < inter_process_pairs.size(); i++) {
    auto &op_pair = inter_process_pairs[i];
    // Look for an already collected send node whose source node full name matches this pair's.
    auto reused_send_node =
      std::find_if(rpc_send_nodes.begin(), rpc_send_nodes.end(), [&op_pair](const auto &send_node_need_fuse) {
        CNodePtr send_node = std::get<0>(op_pair);
        auto node_name1 = common::AnfAlgo::GetInputNode(send_node, kIndex0)->fullname_with_scope();
        auto node_name2 = common::AnfAlgo::GetInputNode(send_node_need_fuse, kIndex0)->fullname_with_scope();
        return node_name1 == node_name2;
      });
    if (reused_send_node != rpc_send_nodes.end()) {
      // Reuse the existing slot for this duplicate source.
      size_t index = static_cast<size_t>(std::distance(rpc_send_nodes.begin(), reused_send_node));
      indices_map[i] = index;
    } else {
      // New source: append its send/recv nodes and record the new slot.
      (void)rpc_send_nodes.emplace_back(std::get<0>(op_pair));
      (void)rpc_recv_nodes.emplace_back(std::get<1>(op_pair));
      indices_map[i] = rpc_send_nodes.size() - 1;
    }
  }

  // Fuse the de-duplicated lists, then map every original pair to its slot index.
  CNodePtr fused_send_node = FuseRpcSendNodes(rpc_send_nodes);
  CNodePtr fused_recv_node = FuseRpcRecvNodes(rpc_recv_nodes);
  for (size_t i = 0; i < inter_process_pairs.size(); i++) {
    FusedInterProcessOpPair fused_inter_process_pair =
      std::make_tuple(fused_send_node, fused_recv_node, indices_map[i], std::get<2>(inter_process_pairs[i]),
                      std::get<3>(inter_process_pairs[i]));
    (void)fused_op_pairs.emplace_back(fused_inter_process_pair);
  }
  return fused_op_pairs;
}
1177
// Construct the splitter for func_graph with this process's (rank id, role).
// exec_mode_ stays null here; it is created later (in CreateExecutionMode) once the strategy is known.
GraphSplitter::GraphSplitter(const FuncGraphPtr &func_graph, uint32_t rank_id, const std::string &role)
    : func_graph_(func_graph),
      rank_id_(rank_id),
      role_(role),
      exec_mode_(nullptr),
      this_process_label_({rank_id, role}),
      node_labels_{},
      need_fuse_rpc_nodes_(true) {
  // The distributed strategy is not explicitly defined by user. Distributed module generates the distributed strategy
  // and default label according to some flags set by other modules.
  mode_ = GenerateStrategy();
  // Nodes without an explicit label default to the rank 0 worker.
  default_label_ = {0, distributed::kEnvRoleOfWorker};
}
1191
PreBuildDistributedGraph()1192 void EmbeddingCacheMode::PreBuildDistributedGraph() {
1193 // Only need add embedding cache ops of remote cache.
1194 if (role_ != distributed::kEnvRoleOfPServer) {
1195 return;
1196 }
1197
1198 // 1. Add embedding cache ops of remote cache, and build service-side graph.
1199 AddEmbeddingCacheOps();
1200
1201 // 2. Get node labels.
1202 MS_EXCEPTION_IF_NULL(node_labels_);
1203 node_labels_->clear();
1204
1205 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1206 (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
1207 MS_EXCEPTION_IF_NULL(node);
1208 if (node->isa<CNode>()) {
1209 CNodePtr cnode = node->cast<CNodePtr>();
1210 MS_EXCEPTION_IF_NULL(cnode);
1211 OperatorLabel label = GetNodeLabel(cnode);
1212 (void)node_labels_->emplace(node, label);
1213 }
1214 });
1215 }
1216
AddEmbeddingCacheOps() const1217 void EmbeddingCacheMode::AddEmbeddingCacheOps() const {
1218 uint32_t worker_num = ClusterContext::instance()->node_num(distributed::kEnvRoleOfWorker);
1219 if (worker_num == 0) {
1220 MS_LOG(EXCEPTION) << "In embedding cache mode, worker number should be greater than 0.";
1221 }
1222
1223 // Build service-side graph.
1224 std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
1225 std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph_, static_cast<int64_t>(rank_id_), role_,
1226 worker_num);
1227 if (!embedding_cache_inserter->Run()) {
1228 MS_LOG(EXCEPTION) << "Insert ps embedding cache failed.";
1229 }
1230 }
1231
GetNodeLabel(const AnfNodePtr & node) const1232 OperatorLabel EmbeddingCacheMode::GetNodeLabel(const AnfNodePtr &node) const {
1233 MS_EXCEPTION_IF_NULL(node);
1234 if (!node->isa<CNode>()) {
1235 MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
1236 }
1237
1238 CNodePtr cnode = node->cast<CNodePtr>();
1239 auto prim_node = cnode->input(0);
1240 if (IsValueNode<Primitive>(prim_node)) {
1241 auto prim = GetValueNode<PrimitivePtr>(prim_node);
1242 MS_EXCEPTION_IF_NULL(prim);
1243 if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
1244 MS_LOG(INFO) << "CNode which has distributed split label: " << cnode->fullname_with_scope();
1245 uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId)));
1246 std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
1247 return {rank_id, ms_role};
1248 }
1249 } else {
1250 // Get label for call node, 'call' node hasn't primitive to save attrs, so get attrs of 'call' from cnode.
1251 if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
1252 uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
1253 std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
1254 return {rank_id, ms_role};
1255 }
1256 }
1257 return {rank_id_, role_};
1258 }
1259
~GraphSplitter()1260 GraphSplitter::~GraphSplitter() { node_labels_.clear(); }
1261
Run()1262 void GraphSplitter::Run() {
1263 MS_EXCEPTION_IF_NULL(func_graph_);
1264 MS_EXCEPTION_IF_NULL(func_graph_->manager());
1265
1266 // Step 1: Dye all the nodes of the whole func_graph_.
1267 DyeGraph();
1268 // If all nodes are all on this process, no need to split the graph. So return.
1269 if (!NeedSplitGraph()) {
1270 MS_LOG(INFO) << "All nodes are on this process so there's no need to build and split distributed graph.";
1271 return;
1272 }
1273
1274 // Step 2: Create exec_mode_ according to the current execution mode.
1275 CreateExecutionMode();
1276
1277 // If this is general mode but no label is set, do not split graph to avoid unexpected optimizing out.
1278 if (mode_ == distributed::DistExecutionMode::kGeneralMode && !GraphHasLabel(func_graph_)) {
1279 MS_LOG(INFO) << "This graph has no label on it in general mode. So no need to split.";
1280 return;
1281 }
1282
1283 // Step 3: Prebuild the distributed graph before it gets split.
1284 exec_mode_->PreBuildDistributedGraph();
1285
1286 if (!NeedSplitGraph()) {
1287 MS_LOG(INFO) << "All nodes are on this precoess so there's no need to build and split distributed graph.";
1288 return;
1289 }
1290
1291 // For TupleGetItem nodes, their label should be reset for good splitting performance.
1292 ReassignTupleGetItemNodeLabel();
1293
1294 if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
1295 // Only use ref sync mechanism when in general mode.
1296 ProcessRefNodes();
1297 // Add some control edges between different labels.
1298 AddExtraControlEdgeAcrossProcess();
1299 }
1300
1301 // Step 4: Create inter-process operators for segments with different labels.
1302 InterProcessOpEdgesInfo comm_edges = GenerateInterProcessOperators();
1303
1304 need_fuse_rpc_nodes_ = common::GetEnv(kEnvNeedFusion).empty() ? false : true;
1305 if (need_fuse_rpc_nodes_) {
1306 // Step 5: Fuse the rpc nodes to improve performance.
1307 const FusedInterProcessOpPairMap &fused_inter_process_op_pairs = exec_mode_->DoRpcNodeFusion(&comm_edges);
1308
1309 // Step 6: Add dependency and eliminate extra nodes for fused rpc nodes.
1310 SplitGraph(fused_inter_process_op_pairs);
1311
1312 // Step 7: Postbuild the graph after splitting with fused edges.
1313 exec_mode_->PostBuildDistributedGraph(fused_inter_process_op_pairs);
1314 } else {
1315 // Step 5: Generate the node segments with different labels.
1316 std::vector<SplitGraphSegment> segments = GenerateSplitSegments();
1317 // If the segment number is 0, there will be no distributed execution.
1318 if (segments.empty()) {
1319 return;
1320 }
1321
1322 // Step 6: Split the graph and eliminate extra nodes.
1323 SplitGraph(segments, comm_edges);
1324
1325 // Step 7: Postbuild the graph after splitting.
1326 exec_mode_->PostBuildDistributedGraph(comm_edges);
1327 }
1328 // Only eliminate the data-sync node pairs in general mode.
1329 if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
1330 EliminateDataSyncNode();
1331 EliminateControlEdgeNode();
1332 }
1333 }
1334
DyeGraph()1335 void GraphSplitter::DyeGraph() {
1336 MS_EXCEPTION_IF_NULL(func_graph_);
1337 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1338 (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
1339 MS_EXCEPTION_IF_NULL(node);
1340 // Mark all nodes with original label at the beginning. This means the node is supposed to be on the process with
1341 // default_label_.
1342 node_labels_[node] = default_label_;
1343 if (node->isa<CNode>()) {
1344 // For CNodes, mark them with the label passed by frontend if has one.
1345 CNodePtr cnode = node->cast<CNodePtr>();
1346 MS_EXCEPTION_IF_NULL(cnode);
1347 OperatorLabel label = GetSplitLabel(cnode);
1348 node_labels_[node] = label;
1349 }
1350
1351 // If the node's label is the same as this process's, set its label to this_process_label_.
1352 if (this_process_label_.LooseEqual(node_labels_[node], mode_)) {
1353 node_labels_[node] = this_process_label_;
1354 }
1355 });
1356 }
1357
CreateExecutionMode()1358 void GraphSplitter::CreateExecutionMode() {
1359 MS_EXCEPTION_IF_NULL(func_graph_);
1360 if (node_labels_.empty()) {
1361 MS_LOG(EXCEPTION) << "Must dye the original graph before creating execution mode.";
1362 }
1363 if (mode_ == distributed::DistExecutionMode::kPSMode) {
1364 exec_mode_ = std::make_unique<ParameterServerMode>(func_graph_, &node_labels_, rank_id_, role_);
1365 } else if (mode_ == distributed::DistExecutionMode::kEmbeddingCacheMode) {
1366 exec_mode_ = std::make_unique<EmbeddingCacheMode>(func_graph_, &node_labels_, rank_id_, role_);
1367 } else if (mode_ == distributed::DistExecutionMode::kParallelMode) {
1368 exec_mode_ = std::make_unique<ParallelMode>(func_graph_, &node_labels_, rank_id_, role_);
1369 } else if (mode_ == distributed::DistExecutionMode::kGeneralMode) {
1370 exec_mode_ = std::make_unique<GeneralMode>(func_graph_, &node_labels_, rank_id_, role_);
1371 }
1372 MS_EXCEPTION_IF_NULL(exec_mode_);
1373 }
1374
GenerateSplitSegments()1375 std::vector<SplitGraphSegment> GraphSplitter::GenerateSplitSegments() {
1376 MS_EXCEPTION_IF_NULL(func_graph_);
1377 auto return_node = func_graph_->get_return();
1378 MS_EXCEPTION_IF_NULL(return_node);
1379 std::vector<AnfNodePtr> nodes = FuncGraph::TopoSort(return_node);
1380
1381 std::vector<SplitGraphSegment> results;
1382 SplitGraphSegment segment;
1383 OperatorLabel last_label = this_process_label_;
1384 segment.label = last_label;
1385 for (auto &n : nodes) {
1386 if (!n->isa<CNode>()) {
1387 continue;
1388 }
1389 auto cnode_split_label = node_labels_[n];
1390 // If this node's label is not the same as last node's, create a segment from 'segment_nodes'.
1391 if (cnode_split_label != last_label && !segment.nodes.empty()) {
1392 (void)results.emplace_back(segment);
1393 segment.nodes.clear();
1394 }
1395 // Mark the last label.
1396 last_label = cnode_split_label;
1397 segment.label = cnode_split_label;
1398 (void)segment.nodes.emplace_back(n);
1399 }
1400
1401 // Add the last segment.
1402 (void)results.emplace_back(segment);
1403 MS_LOG(INFO) << "Segments number with different distributed split labels is " << results.size();
1404 return results;
1405 }
1406
ReassignTupleGetItemNodeLabel()1407 void GraphSplitter::ReassignTupleGetItemNodeLabel() {
1408 MS_EXCEPTION_IF_NULL(func_graph_);
1409 AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1410 for (const auto &node : all_nodes) {
1411 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1412 node_labels_[node] = RecursiveSetTupeGetItemLabel(node->cast<CNodePtr>());
1413 }
1414 }
1415 }
1416
RecursiveSetTupeGetItemLabel(const CNodePtr & tuple_get_item_node)1417 OperatorLabel GraphSplitter::RecursiveSetTupeGetItemLabel(const CNodePtr &tuple_get_item_node) {
1418 // Return if this node has already been visited.
1419 if (visited_tuple_get_item_nodes_.count(tuple_get_item_node) != 0) {
1420 if (NodeHasLabel(tuple_get_item_node)) {
1421 return node_labels_[tuple_get_item_node];
1422 } else {
1423 MS_LOG(EXCEPTION) << "TupeGetItem node " << tuple_get_item_node->fullname_with_scope() << " has no lebel.";
1424 }
1425 }
1426
1427 visited_tuple_get_item_nodes_[tuple_get_item_node] = true;
1428 auto tuple_input = common::AnfAlgo::GetInputNode(tuple_get_item_node, kIndex0);
1429 OperatorLabel tuple_get_item_label;
1430 if (IsPrimitiveCNode(tuple_input, prim::kPrimTupleGetItem)) {
1431 // If TupleGetItem's input is a TupleGetItem node, recursively trace up and get a proper input's label.
1432 tuple_get_item_label = RecursiveSetTupeGetItemLabel(tuple_input->cast<CNodePtr>());
1433 node_labels_[tuple_input] = tuple_get_item_label;
1434 } else {
1435 // Set TupleGetItem's label the same as its input so it's easier to split.
1436 tuple_get_item_label = node_labels_[tuple_input];
1437 }
1438 return tuple_get_item_label;
1439 }
1440
ProcessRefNodes()1441 void GraphSplitter::ProcessRefNodes() {
1442 MS_EXCEPTION_IF_NULL(func_graph_);
1443 AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1444 // Traverse all nodes and find each nodes with side effect.
1445 CNodePtrList cnodes_with_side_effect = GetSideEffectNodeList(all_nodes);
1446 for (const auto &cnode : cnodes_with_side_effect) {
1447 // Filter out all ref inputs which need to be synchronized between different processes.
1448 AnfNodePtrList ref_inputs = GetRefInputs(cnode);
1449 // Get the user node(UpdateState) of side effect node.
1450 CNodePtr update_state_node = FindNextUpdateStateNode(func_graph_, cnode);
1451 MS_EXCEPTION_IF_NULL(update_state_node);
1452
1453 // The key method to keep the correctness of reference nodes across computing graph nodes.
1454 AddDataSyncNode(cnode, update_state_node, ref_inputs);
1455 }
1456 }
1457
AddExtraControlEdgeAcrossProcess()1458 void GraphSplitter::AddExtraControlEdgeAcrossProcess() { AddControlEdgeForProcessWithoutIndegree(); }
1459
GenerateInterProcessOperators()1460 InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOperators() {
1461 InterProcessOpEdgesInfo comm_edges;
1462 MS_EXCEPTION_IF_NULL(func_graph_);
1463 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1464 for (auto &node : all_nodes) {
1465 MS_EXCEPTION_IF_NULL(node);
1466 // Only support to split CNode to other process.
1467 if (!node->isa<CNode>()) {
1468 continue;
1469 }
1470
1471 // Generating send/recv nodes for each nodes' inputs will be enough.
1472 auto node_inputs_comm_edges = GenerateInterProcessOpsForNodeInputs(node);
1473 (void)comm_edges.insert(node_inputs_comm_edges.cbegin(), node_inputs_comm_edges.cend());
1474 }
1475 MS_LOG(INFO) << "The communication edge number is " << comm_edges.size();
1476 return comm_edges;
1477 }
1478
SplitGraph(const std::vector<SplitGraphSegment> & segments,const InterProcessOpEdgesInfo & comm_edges)1479 void GraphSplitter::SplitGraph(const std::vector<SplitGraphSegment> &segments,
1480 const InterProcessOpEdgesInfo &comm_edges) {
1481 // Step 1: Traverse all the segments to add Depend for this process's graph.
1482 // The list of corresponding in and out degrees. In another word, the map between one segments' input send nodes. and
1483 // output recv nodes.
1484 InOutDegreeList in_out_degree_list = GenerateInOutDegreeList(segments, comm_edges);
1485 if (in_out_degree_list.empty()) {
1486 MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
1487 auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
1488 (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
1489 return;
1490 }
1491
1492 // Step 2: Add dependency between communication edges on this process.
1493 AddDependencyBetweenEdges(comm_edges);
1494
1495 // Step 3: Eliminate nodes not on this process.
1496 EliminateExtraNodes(comm_edges);
1497 }
1498
SplitGraph(const FusedInterProcessOpPairMap & fused_inter_process_op_pairs)1499 void GraphSplitter::SplitGraph(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
1500 if (fused_inter_process_op_pairs.empty()) {
1501 MS_LOG(WARNING) << "After splitting, this process has no graph on it. So optimize out the whole graph.";
1502 auto return_value_node = CreateReplacedOutputNode(func_graph_, func_graph_->output());
1503 (void)func_graph_->manager()->Replace(func_graph_->output(), return_value_node);
1504 return;
1505 }
1506
1507 // Step 1: Replace origin nodes with recv nodes.
1508 ReplaceOriginNodesWithRecv(fused_inter_process_op_pairs);
1509
1510 // Step 2: Connect output for send nodes.
1511 AddDependencyForSend(fused_inter_process_op_pairs);
1512 }
1513
AddDataSyncNode(const CNodePtr & side_effect_node,const CNodePtr & update_state_node,const AnfNodePtrList & ref_nodes)1514 void GraphSplitter::AddDataSyncNode(const CNodePtr &side_effect_node, const CNodePtr &update_state_node,
1515 const AnfNodePtrList &ref_nodes) {
1516 MS_EXCEPTION_IF_NULL(func_graph_);
1517 MS_EXCEPTION_IF_NULL(side_effect_node);
1518 MS_EXCEPTION_IF_NULL(update_state_node);
1519
1520 MS_EXCEPTION_IF_CHECK_FAIL(
1521 (node_labels_.count(side_effect_node) != 0),
1522 "The node label for side effect node " + side_effect_node->fullname_with_scope() + " is not set.");
1523 auto side_effect_node_label = node_labels_[side_effect_node];
1524
1525 for (const auto &ref : ref_nodes) {
1526 std::set<OperatorLabel> diff_labels;
1527 for (const auto &user : func_graph_->manager()->node_users()[ref]) {
1528 const auto &user_node = user.first;
1529 MS_LOG(DEBUG) << "The user of ref " << ref->fullname_with_scope() << " is " << user_node->fullname_with_scope()
1530 << ", side-effect node label: " << side_effect_node_label.to_string()
1531 << ", user node label: " << node_labels_[user_node].to_string();
1532 if (node_labels_[user_node] != side_effect_node_label) {
1533 (void)diff_labels.insert(node_labels_[user_node]);
1534 } else {
1535 // If the user node is Load, we need to find one next user of it so the node could be correctly split.
1536 if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
1537 for (const auto &load_user : func_graph_->manager()->node_users()[user_node]) {
1538 const auto &load_user_node = load_user.first;
1539 MS_LOG(DEBUG) << "Load user is " << load_user_node
1540 << ", label: " << node_labels_[load_user_node].to_string();
1541 if (node_labels_[load_user_node] != side_effect_node_label) {
1542 (void)diff_labels.insert(node_labels_[load_user_node]);
1543 }
1544 }
1545 }
1546 }
1547 }
1548 // If the ref is used in multiple compute graph nodes, it needs to be synchronized.
1549 if (diff_labels.empty()) {
1550 MS_LOG(INFO) << "No need to synchronize ref node " << ref->fullname_with_scope()
1551 << " because the user nodes are on the same process.";
1552 continue;
1553 }
1554
1555 // Create data-sync nodes and connect them to UpdateState node.
1556 auto data_sync_node_list = CreateDataSyncNodes(side_effect_node, ref, diff_labels);
1557 for (const auto &node_pair : data_sync_node_list) {
1558 CNodePtr src_node = node_pair.first;
1559 CNodePtr dst_node = node_pair.second;
1560 func_graph_->manager()->AddEdge(update_state_node, dst_node);
1561 }
1562 }
1563 }
1564
CreateDataSyncNodes(const CNodePtr & side_effect_node,const AnfNodePtr & ref,const std::set<OperatorLabel> & diff_labels)1565 DataSyncNodePairList GraphSplitter::CreateDataSyncNodes(const CNodePtr &side_effect_node, const AnfNodePtr &ref,
1566 const std::set<OperatorLabel> &diff_labels) {
1567 MS_EXCEPTION_IF_NULL(side_effect_node);
1568 MS_EXCEPTION_IF_NULL(ref);
1569
1570 DataSyncNodePairList result;
1571 for (const auto &label : diff_labels) {
1572 // Data sync src node.
1573 std::vector<AnfNodePtr> sync_src_node_inputs = {
1574 NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncSrcOpName))};
1575 (void)sync_src_node_inputs.emplace_back(ref);
1576 (void)sync_src_node_inputs.emplace_back(side_effect_node);
1577 CNodePtr sync_src_node = func_graph_->NewCNode(sync_src_node_inputs);
1578 MS_EXCEPTION_IF_NULL(sync_src_node);
1579 sync_src_node->set_abstract(ref->abstract());
1580 node_labels_[sync_src_node] = node_labels_[side_effect_node];
1581
1582 // Data sync dst node.
1583 std::vector<AnfNodePtr> sync_dst_node_inputs = {
1584 NewValueNode(std::make_shared<Primitive>(distributed::kDataSyncDstOpName))};
1585 (void)sync_dst_node_inputs.emplace_back(sync_src_node);
1586 CNodePtr sync_dst_node = func_graph_->NewCNode(sync_dst_node_inputs);
1587 MS_EXCEPTION_IF_NULL(sync_dst_node);
1588 auto fake_value = CreateFakeValueNode(false);
1589 MS_EXCEPTION_IF_NULL(fake_value);
1590 sync_dst_node->set_abstract(fake_value->abstract());
1591 node_labels_[sync_dst_node] = label;
1592
1593 MS_LOG(INFO) << "Data sync pair: " << sync_src_node->fullname_with_scope() << "_"
1594 << node_labels_[sync_src_node].to_string() << "->" << sync_dst_node->fullname_with_scope() << "_"
1595 << label.to_string();
1596 result.push_back(std::make_pair(sync_src_node, sync_dst_node));
1597 }
1598 return result;
1599 }
1600
AddControlEdgeForProcessWithoutIndegree()1601 void GraphSplitter::AddControlEdgeForProcessWithoutIndegree() {
1602 (void)std::for_each(node_labels_.begin(), node_labels_.end(),
1603 [this](const auto &node_label_pair) { (void)all_labels_.insert(node_label_pair.second); });
1604
1605 std::set<OperatorLabel> labels_has_indegree;
1606 AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1607 for (const auto &node : all_nodes) {
1608 if (!node->isa<CNode>()) {
1609 continue;
1610 }
1611 auto cnode = node->cast<CNodePtr>();
1612 for (size_t i = kIndex1; i < cnode->size(); ++i) {
1613 const auto &input = cnode->inputs().at(i);
1614 if (NodeHasLabel(input) && NodeHasLabel(cnode) && node_labels_[input] != node_labels_[cnode] &&
1615 input->isa<CNode>()) {
1616 MS_LOG(DEBUG) << "Label " << node_labels_[cnode].to_string() << " has indegree from label "
1617 << node_labels_[input].to_string() << ", edge: " << input->fullname_with_scope() << " to "
1618 << cnode->fullname_with_scope();
1619 (void)labels_has_indegree.insert(node_labels_[cnode]);
1620 }
1621 }
1622 }
1623
1624 ControlEdgeNodePairList control_edge_node_pair_list;
1625 for (const OperatorLabel &label : all_labels_) {
1626 // If this label has no indegree, add extra control edge nodes.
1627 if (labels_has_indegree.count(label) == 0) {
1628 ControlEdgeNodePair control_edge_nodes = CreateControlEdgeNode(default_label_, label);
1629 (void)control_edge_node_pair_list.emplace_back(control_edge_nodes);
1630 }
1631 }
1632
1633 if (!control_edge_node_pair_list.empty()) {
1634 // Connect the dangling control dst nodes to the output.
1635 AnfNodePtrList make_tuple_inputs;
1636 (void)std::for_each(control_edge_node_pair_list.begin(), control_edge_node_pair_list.end(),
1637 [&make_tuple_inputs](const auto &node_pair) {
1638 CNodePtr control_dst_node = node_pair.second;
1639 (void)make_tuple_inputs.emplace_back(control_dst_node);
1640 });
1641
1642 // Make tuple for all control-edge dst nodes.
1643 MS_EXCEPTION_IF_NULL(func_graph_);
1644 auto tuple_of_control_dst_nodes = CreateMakeTupleNode(func_graph_, make_tuple_inputs);
1645 MS_EXCEPTION_IF_NULL(tuple_of_control_dst_nodes);
1646 node_labels_[tuple_of_control_dst_nodes] = default_label_;
1647
1648 // Add dependency to the Return node so control-edge nodes won't be optimized out.
1649 AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), func_graph_->output(), tuple_of_control_dst_nodes};
1650 auto final_output_node = func_graph_->NewCNode(depend_inputs);
1651 MS_EXCEPTION_IF_NULL(final_output_node);
1652 node_labels_[final_output_node] = default_label_;
1653
1654 final_output_node->set_abstract(func_graph_->output()->abstract());
1655 (void)func_graph_->manager()->SetEdge(func_graph_->get_return(), kIndex1, final_output_node);
1656 }
1657 return;
1658 }
1659
CreateControlEdgeNode(const OperatorLabel & src_label,const OperatorLabel & dst_label)1660 ControlEdgeNodePair GraphSplitter::CreateControlEdgeNode(const OperatorLabel &src_label,
1661 const OperatorLabel &dst_label) {
1662 // Control src node's input is a value node. It has not practical meaning.
1663 auto fake_tensor = std::make_shared<tensor::Tensor>(1.0);
1664 MS_EXCEPTION_IF_NULL(fake_tensor);
1665 auto fake_value = NewValueNode(fake_tensor);
1666 MS_EXCEPTION_IF_NULL(fake_value);
1667 fake_value->set_abstract(fake_tensor->ToAbstract());
1668
1669 AnfNodePtrList control_src_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kControlSrcOpName)),
1670 fake_value};
1671 CNodePtr control_src_node = func_graph_->NewCNode(control_src_inputs);
1672 MS_EXCEPTION_IF_NULL(control_src_node);
1673 control_src_node->set_abstract(fake_value->abstract());
1674 node_labels_[control_src_node] = src_label;
1675
1676 // Control dst node's input is control src node.
1677 AnfNodePtrList control_dst_inputs = {NewValueNode(std::make_shared<Primitive>(distributed::kControlDstOpName)),
1678 control_src_node};
1679 CNodePtr control_dst_node = func_graph_->NewCNode(control_dst_inputs);
1680 MS_EXCEPTION_IF_NULL(control_dst_node);
1681 control_dst_node->set_abstract(control_src_node->abstract());
1682 node_labels_[control_dst_node] = dst_label;
1683
1684 // At this phase, the control_dst_node is still a dangling node. We need to connect it to the output to avoid
1685 // optimizing out.
1686 return std::make_pair(control_src_node, control_dst_node);
1687 }
1688
// After splitting, rewrite the temporary data-sync nodes into ordinary graph
// constructs: DataSyncSrc(param, side_effect) becomes
// Load(param, UpdateState(side_effect)), and DataSyncDst is bypassed by
// rewiring each of its users directly to its input.
void GraphSplitter::EliminateDataSyncNode() {
  MS_EXCEPTION_IF_NULL(func_graph_);
  AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
  for (const auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }

    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncSrcOpName) {
      // Expect DataSyncSrc(primitive, param, side_effect_node): 3 inputs total.
      if (cnode->size() != kSizeThree) {
        MS_LOG(EXCEPTION) << "Node DataSyncSrc's input number should be 3, but got " << cnode->size();
      }
      // The first input is parameter and the second input is side effect node.
      auto param_node = cnode->inputs()[kIndex1];
      MS_EXCEPTION_IF_NULL(param_node);
      auto side_effect_node = cnode->inputs()[kIndex2];
      MS_EXCEPTION_IF_NULL(side_effect_node);
      MS_LOG(DEBUG) << "Parameter node is " << param_node->fullname_with_scope() << ", side effect node is "
                    << side_effect_node->fullname_with_scope();

      // Wrap the side-effect node in an UpdateState so ordering is preserved.
      AnfNodePtrList update_state_inputs = {side_effect_node};
      CNodePtr update_state_node = CreateUpdateStateNode(func_graph_, update_state_inputs);
      MS_EXCEPTION_IF_NULL(update_state_node);

      // For parameter, connect it to a 'Load' node so that the control arrow could be correctly linked.
      AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), param_node, update_state_node};

      auto load_node_replace_data_sync_src = func_graph_->NewCNode(load_inputs);
      MS_EXCEPTION_IF_NULL(load_node_replace_data_sync_src);
      // The Load node stands in for the DataSyncSrc, so it keeps the same abstract.
      load_node_replace_data_sync_src->set_abstract(cnode->abstract());
      (void)func_graph_->manager()->Replace(cnode, load_node_replace_data_sync_src);
    } else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kDataSyncDstOpName) {
      // Expect DataSyncDst(primitive, input): 2 inputs total.
      if (cnode->size() != kSizeTwo) {
        MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->size();
      }
      auto input_node = cnode->inputs()[kIndex1];
      MS_EXCEPTION_IF_NULL(input_node);

      // Rewire each user of DataSyncDst to consume its input directly,
      // effectively deleting the DataSyncDst node from the graph.
      auto users = func_graph_->manager()->node_users()[cnode];
      for (const auto &user_pair : users) {
        auto user_node = user_pair.first;
        int input_index = user_pair.second;
        func_graph_->manager()->SetEdge(user_node, input_index, input_node);
      }
    }
  }
}
1739
EliminateControlEdgeNode()1740 void GraphSplitter::EliminateControlEdgeNode() {
1741 MS_EXCEPTION_IF_NULL(func_graph_);
1742 AnfNodePtrList all_nodes = DeepScopedGraphSearch(func_graph_->get_return());
1743 for (const auto &node : all_nodes) {
1744 MS_EXCEPTION_IF_NULL(node);
1745 if (!node->isa<CNode>()) {
1746 continue;
1747 }
1748
1749 auto cnode = node->cast<CNodePtr>();
1750 MS_EXCEPTION_IF_NULL(cnode);
1751 if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlSrcOpName) {
1752 // ControlSrc->RpcSend is converted to FakeValue->RpcSend.
1753 auto fake_value_node = CreateFakeValueNode(false);
1754 MS_EXCEPTION_IF_NULL(fake_value_node);
1755 (void)func_graph_->manager()->Replace(cnode, fake_value_node);
1756 } else if (common::AnfAlgo::GetCNodeName(cnode) == distributed::kControlDstOpName) {
1757 if (cnode->size() != kSizeTwo) {
1758 MS_LOG(EXCEPTION) << "Node DataSyncDst's input number should be 2, but got " << cnode->size();
1759 }
1760 auto input_node = cnode->inputs()[kIndex1];
1761 MS_EXCEPTION_IF_NULL(input_node);
1762 (void)func_graph_->manager()->Replace(cnode, input_node);
1763 }
1764 }
1765 }
1766
DumpDistributedGraph(const InterProcessOpEdgesInfo & comm_edges)1767 void GraphSplitter::DumpDistributedGraph(const InterProcessOpEdgesInfo &comm_edges) {
1768 // Traverse all the segments to add Depend for this process's graph.
1769 for (const auto &edge : comm_edges) {
1770 auto send_recv_pair = edge.second;
1771 auto send_node = std::get<0>(send_recv_pair);
1772 auto recv_node = std::get<1>(send_recv_pair);
1773 auto user_node = std::get<2>(send_recv_pair);
1774 auto user_node_index = std::get<3>(send_recv_pair);
1775 func_graph_->manager()->SetEdge(recv_node, 1, send_node);
1776 func_graph_->manager()->SetEdge(user_node, user_node_index, recv_node);
1777 }
1778 MS_LOG(INFO) << "Cut graph without eliminating nodes.";
1779 draw::Draw("single_node_graph.dot", func_graph_);
1780 }
1781
GetSplitLabel(const AnfNodePtr & node)1782 OperatorLabel GraphSplitter::GetSplitLabel(const AnfNodePtr &node) {
1783 MS_EXCEPTION_IF_NULL(node);
1784 if (!node->isa<CNode>()) {
1785 MS_LOG(EXCEPTION) << "Only CNode has distributed split label.";
1786 }
1787 CNodePtr cnode = node->cast<CNodePtr>();
1788 MS_EXCEPTION_IF_NULL(cnode);
1789 auto prim_node = cnode->input(0);
1790 if (IsValueNode<Primitive>(prim_node)) {
1791 TransformPrimAttrToAttr(cnode);
1792 auto prim = GetValueNode<PrimitivePtr>(prim_node);
1793 MS_EXCEPTION_IF_NULL(prim);
1794 if (prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole)) {
1795 MS_LOG(INFO) << "CNode which has distributed split label: " << cnode->fullname_with_scope();
1796 uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId)));
1797 std::string ms_role = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
1798 return {rank_id, ms_role};
1799 }
1800 } else {
1801 // Get label for call node, 'call' node hasn't primitive to save attrs, so get attrs of 'call' from cnode.
1802 if (cnode->HasAttr(distributed::kOpLabelRankId) && cnode->HasAttr(distributed::kOpLabelRole)) {
1803 uint32_t rank_id = static_cast<uint32_t>(GetValue<int64_t>(cnode->GetAttr(distributed::kOpLabelRankId)));
1804 std::string ms_role = GetValue<std::string>(cnode->GetAttr(distributed::kOpLabelRole));
1805 return {rank_id, ms_role};
1806 }
1807 }
1808 return default_label_;
1809 }
1810
// Create a send/recv operator pair for every input edge of 'node' that
// crosses a process boundary. Returns the map from each such edge to its
// (send, recv, user, input_index) tuple. Edges whose endpoints share a label
// are skipped, with one exception: real graph inputs with a different label
// must still be split.
InterProcessOpEdgesInfo GraphSplitter::GenerateInterProcessOpsForNodeInputs(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph_);
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  InterProcessOpEdgesInfo comm_edges;
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_i = cnode->inputs()[i];
    MS_EXCEPTION_IF_NULL(input_i);

    // If the input's not a cnode, or its label is the same as this node's, or the input is 'Load' node for parameter,
    // there's no need to add communication nodes.
    if (!input_i->isa<CNode>() || IsNodesWithSameLabel(input_i, cnode) ||
        common::AnfAlgo::GetCNodeName(input_i) == "Load") {
      // Exception to the skip rule: a real graph input with a different label
      // still needs a communication edge; everything else is skipped.
      if (IsOneOfRealGraphInput(func_graph_, input_i) && !IsNodesWithSameLabel(input_i, cnode)) {
        MS_LOG(INFO) << "The input " << input_i->fullname_with_scope() << " needs to be split.";
      } else {
        continue;
      }
    }

    // Build the edge descriptor, including any user-specified edge label.
    InterProcessEdgeLabel edge_label = GenerateEdgeLabel(input_i, cnode);
    InterProcessOpEdge edge = {input_i, node_labels_[input_i], cnode, node_labels_[cnode], edge_label};

    auto send_node = CreateSendNode(func_graph_, edge);
    MS_EXCEPTION_IF_NULL(send_node);
    // The label should be the same as the node which will 'launch' Send node.
    node_labels_[send_node] = edge.src_label;

    auto recv_node = CreateRecvNode(func_graph_, edge);
    MS_EXCEPTION_IF_NULL(recv_node);
    // The label should be the same as the node which Receives the 'input'.
    node_labels_[recv_node] = edge.dst_label;

    // Record (send, recv, consumer, input index) so the splitter can later rewire this edge.
    auto comm_node_pair = std::make_tuple(send_node, recv_node, cnode, SizeToInt(i));
    (void)comm_edges.insert(std::make_pair(edge, comm_node_pair));
  }
  return comm_edges;
}
1850
GenerateEdgeLabel(const AnfNodePtr & src_node,const AnfNodePtr & dst_node) const1851 InterProcessEdgeLabel GraphSplitter::GenerateEdgeLabel(const AnfNodePtr &src_node, const AnfNodePtr &dst_node) const {
1852 MS_EXCEPTION_IF_NULL(src_node);
1853 MS_EXCEPTION_IF_NULL(dst_node);
1854 std::string src_node_edge_label = "";
1855 std::string dst_node_edge_label = "";
1856 if (src_node->isa<CNode>()) {
1857 src_node_edge_label = common::AnfAlgo::HasNodeAttr(kAttrInterProcessEdgeLabel, src_node->cast<CNodePtr>())
1858 ? common::AnfAlgo::GetNodeAttr<std::string>(src_node, kAttrInterProcessEdgeLabel)
1859 : "";
1860 }
1861 if (dst_node->isa<CNode>()) {
1862 dst_node_edge_label = common::AnfAlgo::HasNodeAttr(kAttrInterProcessEdgeLabel, dst_node->cast<CNodePtr>())
1863 ? common::AnfAlgo::GetNodeAttr<std::string>(dst_node, kAttrInterProcessEdgeLabel)
1864 : "";
1865 }
1866 if (!src_node_edge_label.empty() && !dst_node_edge_label.empty()) {
1867 if (src_node_edge_label != dst_node_edge_label) {
1868 MS_LOG(EXCEPTION) << "The edge label name of src node and dst node should be same."
1869 << src_node->fullname_with_scope() << "->" << dst_node->fullname_with_scope();
1870 }
1871 }
1872 InterProcessEdgeLabel edge_label;
1873 if (!src_node_edge_label.empty()) {
1874 edge_label.label_name = src_node_edge_label;
1875 } else if (!dst_node_edge_label.empty()) {
1876 edge_label.label_name = dst_node_edge_label;
1877 } else {
1878 MS_LOG(DEBUG) << "Edge label is empty for " << src_node->fullname_with_scope() << "->"
1879 << dst_node->fullname_with_scope();
1880 }
1881 return edge_label;
1882 }
1883
FindInterProcessInDegree(const std::vector<AnfNodePtr> & nodes,const InterProcessOpEdgesInfo & comm_edges)1884 std::vector<AnfNodePtr> GraphSplitter::FindInterProcessInDegree(const std::vector<AnfNodePtr> &nodes,
1885 const InterProcessOpEdgesInfo &comm_edges) {
1886 std::vector<AnfNodePtr> results;
1887 for (auto &n : nodes) {
1888 if (!n->isa<CNode>()) {
1889 continue;
1890 }
1891
1892 CNodePtr cnode = n->cast<CNodePtr>();
1893 for (size_t i = 1; i < cnode->size(); i++) {
1894 auto input_i = cnode->inputs()[i];
1895 InterProcessOpEdge edge = {input_i, node_labels_[input_i], cnode, node_labels_[cnode]};
1896 if (comm_edges.count(edge) != 0 && edge.src_label == this_process_label_) {
1897 MS_LOG(INFO) << edge.to_string() << " is a communication edge.";
1898 auto comm_node_pair = comm_edges.at(edge);
1899 (void)results.emplace_back(std::get<0>(comm_node_pair));
1900 }
1901 }
1902 }
1903 return results;
1904 }
1905
FindInterProcessOutDegree(const std::vector<AnfNodePtr> & nodes,const InterProcessOpEdgesInfo & comm_edges)1906 std::vector<AnfNodePtr> GraphSplitter::FindInterProcessOutDegree(const std::vector<AnfNodePtr> &nodes,
1907 const InterProcessOpEdgesInfo &comm_edges) {
1908 std::vector<AnfNodePtr> results;
1909 for (auto &n : nodes) {
1910 if (!n->isa<CNode>()) {
1911 continue;
1912 }
1913
1914 CNodePtr cnode = n->cast<CNodePtr>();
1915 auto users = func_graph_->manager()->node_users()[cnode];
1916 for (auto &u : users) {
1917 auto user_node = u.first->cast<CNodePtr>();
1918 InterProcessOpEdge edge = {cnode, node_labels_[cnode], user_node, node_labels_[user_node]};
1919 if (comm_edges.count(edge) != 0 && edge.dst_label == this_process_label_) {
1920 MS_LOG(INFO) << edge.to_string() << " is a communication edge.";
1921 auto comm_node_pair = comm_edges.at(edge);
1922 (void)results.emplace_back(std::get<1>(comm_node_pair));
1923 }
1924 }
1925 }
1926 return results;
1927 }
1928
GenerateInOutDegreeList(const std::vector<SplitGraphSegment> & segments,const InterProcessOpEdgesInfo & comm_edges)1929 InOutDegreeList GraphSplitter::GenerateInOutDegreeList(const std::vector<SplitGraphSegment> &segments,
1930 const InterProcessOpEdgesInfo &comm_edges) {
1931 MS_LOG(INFO) << "Start finding inter-process in-degrees.";
1932
1933 InOutDegreeList in_out_degree_list;
1934 // Traverse all the segments to add Depend for this process's graph.
1935 for (const auto &segment : segments) {
1936 // If this segment should be on current process, continue.
1937 if (segment.label == this_process_label_) {
1938 continue;
1939 }
1940 std::vector<AnfNodePtr> nodes = segment.nodes;
1941 if (nodes.empty()) {
1942 MS_LOG(EXCEPTION) << "This segment is empty.";
1943 return in_out_degree_list;
1944 }
1945
1946 auto segment_first_node = nodes[0];
1947 if (node_labels_[segment_first_node] != segment.label) {
1948 MS_LOG(EXCEPTION) << "Node label " << node_labels_[segment_first_node].to_string()
1949 << " is not the same as segment label " << segment.label.to_string();
1950 }
1951
1952 // Prepare for adding Depend between in-degree and out-degree of this segment because the execution order should
1953 // be kept consistent.
1954 std::vector<AnfNodePtr> concerned_in_degree_nodes = FindInterProcessInDegree(nodes, comm_edges);
1955 std::vector<AnfNodePtr> concerned_out_degree_nodes = FindInterProcessOutDegree(nodes, comm_edges);
1956 if (!concerned_in_degree_nodes.empty() || !concerned_out_degree_nodes.empty()) {
1957 (void)in_out_degree_list.emplace_back(std::make_pair(concerned_in_degree_nodes, concerned_out_degree_nodes));
1958 }
1959 }
1960 MS_LOG(INFO) << "End finding inter-process in-degrees.";
1961 return in_out_degree_list;
1962 }
1963
AddDependencyBetweenEdges(const InterProcessOpEdgesInfo & comm_edges)1964 void GraphSplitter::AddDependencyBetweenEdges(const InterProcessOpEdgesInfo &comm_edges) {
1965 // 'in_degree_comm_edges' is the edges with recv node on this process.
1966 InterProcessOpEdgesInfo in_degree_comm_edges;
1967 // 'out_degree_comm_edges' is the edges with send node on this process.
1968 InterProcessOpEdgesInfo out_degree_comm_edges;
1969
1970 // Src nodes of RpcSend nodes.
1971 AnfNodePtrSet send_src_nodes;
1972 // Map of src nodes to its all RpcSend nodes.
1973 std::map<AnfNodePtr, AnfNodePtrSet> src_nodes_to_send_nodes;
1974 // This map represents which send nodes are hung. Key is RpcSend node.
1975 std::map<AnfNodePtr, bool> is_send_node_hung;
1976 for (const auto &e : comm_edges) {
1977 const InterProcessOpEdge &edge_info = e.first;
1978 const InterProcessOpPair &op_pair = e.second;
1979
1980 if (edge_info.src_label == this_process_label_) {
1981 const AnfNodePtr &send_src_node = edge_info.src_node;
1982 const AnfNodePtr &rpc_send_node = std::get<0>(op_pair);
1983 (void)send_src_nodes.insert(send_src_node);
1984 (void)src_nodes_to_send_nodes[send_src_node].insert(rpc_send_node);
1985 is_send_node_hung[rpc_send_node] = true;
1986
1987 MS_LOG(DEBUG) << "Out degree edge: " << edge_info.to_string() << ". Send src node "
1988 << send_src_node->fullname_with_scope() << " has RpcSend node "
1989 << rpc_send_node->fullname_with_scope();
1990 out_degree_comm_edges[edge_info] = op_pair;
1991 }
1992
1993 if (edge_info.dst_label == this_process_label_) {
1994 MS_LOG(DEBUG) << "In degree edge: " << edge_info.to_string();
1995 in_degree_comm_edges[edge_info] = op_pair;
1996 }
1997 }
1998
1999 // This step is vital. It builds a map consists of all dependencies to send src nodes, which helps to
2000 // add explicit dependency edges for RpcSend and RpcRecv nodes.
2001 std::map<AnfNodePtr, AnfNodePtrSet> node_dependency = FilterDependencyToTargetNode(func_graph_, send_src_nodes);
2002 MS_LOG(INFO) << "After filtering out the dependencies, add dependency edges between RpcSend and RpcRecv nodes.";
2003
2004 // Connect RpcSend and RpcRecv with minimal dependencies.
2005 AddSendRecvDependency(in_degree_comm_edges, send_src_nodes, src_nodes_to_send_nodes, node_dependency,
2006 &is_send_node_hung);
2007
2008 // Some RpcSend nodes may be hung, we need to connect these nodes to output in case they are optimized out.
2009 AnfNodePtrList hung_nodes_list;
2010 for (const auto &is_hung : is_send_node_hung) {
2011 if (is_hung.second) {
2012 MS_LOG(INFO) << "RpcSend node: " << is_hung.first->fullname_with_scope() << " is hung.";
2013 (void)hung_nodes_list.emplace_back(is_hung.first);
2014 }
2015 }
2016 if (!hung_nodes_list.empty()) {
2017 HandleHungNodes(func_graph_, node_labels_, this_process_label_, hung_nodes_list);
2018 }
2019 }
2020
AddDependencyBetweenSegments(const InOutDegreeList & in_out_degree_list)2021 void GraphSplitter::AddDependencyBetweenSegments(const InOutDegreeList &in_out_degree_list) {
2022 MS_LOG(INFO) << "Start adding dependency between segments.";
2023 // This tuple is key to the dependency of send nodes so that they will not be optimized out in some cases.
2024 std::vector<AnfNodePtr> send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2025 for (size_t i = 0; i < in_out_degree_list.size(); i++) {
2026 auto &concerned_in_degree_nodes = in_out_degree_list[i].first;
2027 auto &concerned_out_degree_nodes = in_out_degree_list[i].second;
2028 (void)send_node_tuple_inputs.insert(send_node_tuple_inputs.cend(), concerned_in_degree_nodes.cbegin(),
2029 concerned_in_degree_nodes.cend());
2030 if (concerned_out_degree_nodes.empty()) {
2031 // If this is the last segment's in and out degrees and has no out degrees, connect the send nodes to graph's
2032 // output.
2033 if (i == in_out_degree_list.size() - 1) {
2034 auto make_tuple_node = func_graph_->NewCNode(send_node_tuple_inputs);
2035 AbstractBasePtrList abstract_list;
2036 (void)std::for_each(send_node_tuple_inputs.cbegin() + 1, send_node_tuple_inputs.cend(),
2037 [&](const auto &input) { (void)abstract_list.emplace_back(input->abstract()); });
2038 make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
2039
2040 // Connect fused send nodes to the output so they will not be optimized out.
2041 AnfNodePtr origin_output = func_graph_->output();
2042 if (node_labels_.count(origin_output) == 0) {
2043 MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
2044 << " should have corresponding operator label.";
2045 }
2046
2047 // If the output is not on this process, replace it with a fake output node.
2048 AnfNodePtr replaced_output = nullptr;
2049 if (node_labels_[origin_output] != this_process_label_) {
2050 replaced_output = CreateReplacedOutputNode(func_graph_, origin_output);
2051 } else {
2052 replaced_output = origin_output;
2053 }
2054
2055 // Add dependency and replace.
2056 std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output, make_tuple_node};
2057 auto final_output_node = func_graph_->NewCNode(depend_inputs);
2058 MS_EXCEPTION_IF_NULL(final_output_node);
2059 final_output_node->set_abstract(replaced_output->abstract());
2060 (void)func_graph_->manager()->SetEdge(func_graph_->get_return(), 1, final_output_node);
2061 }
2062 } else {
2063 auto make_tuple_node = func_graph_->NewCNode(send_node_tuple_inputs);
2064 for (auto &recv : concerned_out_degree_nodes) {
2065 std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), recv->cast<CNodePtr>()->inputs()[1],
2066 make_tuple_node};
2067 auto depend = func_graph_->NewCNode(depend_input);
2068 depend->set_abstract(recv->cast<CNodePtr>()->inputs()[1]->abstract());
2069 func_graph_->manager()->SetEdge(recv, 1, depend);
2070 }
2071 // Reset the make tuple node inputs for next segments in degrees.
2072 send_node_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2073 }
2074 }
2075 MS_LOG(INFO) << "End adding dependency between segments.";
2076 }
2077
EliminateExtraNodes(const InterProcessOpEdgesInfo & comm_edges)2078 void GraphSplitter::EliminateExtraNodes(const InterProcessOpEdgesInfo &comm_edges) {
2079 MS_LOG(INFO) << "Start eliminating nodes not on this process.";
2080 // Eliminate nodes which should be launched by other processes by set output edge.
2081 for (auto &edge : comm_edges) {
2082 InterProcessOpPair send_recv_pair = edge.second;
2083 auto send_node = std::get<0>(send_recv_pair);
2084 auto recv_node = std::get<1>(send_recv_pair);
2085 auto user_node = std::get<2>(send_recv_pair);
2086 int user_node_index = std::get<3>(send_recv_pair);
2087
2088 OperatorLabel send_label = node_labels_[send_node];
2089 OperatorLabel recv_label = node_labels_[recv_node];
2090 if (send_label == recv_label) {
2091 MS_LOG(EXCEPTION) << "The Send and Recv must have different label. But got Send: " << send_label.to_string()
2092 << ", Recv: " << recv_label.to_string();
2093 }
2094
2095 if (recv_label == this_process_label_) {
2096 func_graph_->manager()->SetEdge(user_node, user_node_index, recv_node);
2097 }
2098 }
2099 MS_LOG(INFO) << "End eliminating nodes not on this process.";
2100 }
2101
ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap & fused_inter_process_op_pairs)2102 void GraphSplitter::ReplaceOriginNodesWithRecv(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
2103 MS_EXCEPTION_IF_NULL(func_graph_);
2104 for (const auto &op_pair_info : fused_inter_process_op_pairs) {
2105 const OperatorLabel &send_label = op_pair_info.first.src_label;
2106 const OperatorLabel &recv_label = op_pair_info.first.dst_label;
2107 const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
2108 if (op_pairs.empty()) {
2109 MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
2110 << recv_label.to_string();
2111 }
2112
2113 const auto &fused_recv_node = std::get<1>(*op_pairs.begin());
2114 MS_EXCEPTION_IF_NULL(fused_recv_node);
2115
2116 // Replace origin input with recv node.
2117 if (recv_label == this_process_label_) {
2118 for (const auto &send_recv_pair : op_pairs) {
2119 const auto &user_node = std::get<3>(send_recv_pair);
2120 int user_node_index = std::get<4>(send_recv_pair);
2121
2122 const auto &recv_abs = fused_recv_node->abstract();
2123 MS_EXCEPTION_IF_NULL(recv_abs);
2124 // The outputs of a Recv node could be a tuple or a single tensor because it could be fused.
2125 if (recv_abs->isa<abstract::AbstractTuple>()) {
2126 int output_index = std::get<2>(send_recv_pair);
2127 CNodePtr tuple_get_item_node = CreateTupleGetItemNode(func_graph_, fused_recv_node, IntToSize(output_index));
2128 func_graph_->manager()->SetEdge(user_node, user_node_index, tuple_get_item_node);
2129 } else {
2130 func_graph_->manager()->SetEdge(user_node, user_node_index, fused_recv_node);
2131 }
2132 }
2133 }
2134 }
2135 }
2136
// Adds Depend edges from RpcSend nodes to RpcRecv nodes on this process so the cross-process
// execution order implied by the original (unsplit) graph is preserved after splitting.
// 'in_degree_comm_edges': edges whose RpcRecv node is launched by this process.
// 'send_src_nodes': origin nodes on this process whose outputs are sent to other processes.
// 'src_nodes_to_send_nodes': maps each send src node to all of its RpcSend nodes.
// 'node_dependency': maps a node to the set of send src nodes it depends on
//                    (built by FilterDependencyToTargetNode).
// 'is_send_node_hung': in/out parameter; the entry of an RpcSend node is cleared to false once a
//                      dependency edge is created for it, so the caller can connect still-hung
//                      sends to the graph output.
void GraphSplitter::AddSendRecvDependency(const InterProcessOpEdgesInfo &in_degree_comm_edges,
                                          const AnfNodePtrSet &send_src_nodes,
                                          const std::map<AnfNodePtr, AnfNodePtrSet> &src_nodes_to_send_nodes,
                                          const std::map<AnfNodePtr, AnfNodePtrSet> &node_dependency,
                                          std::map<AnfNodePtr, bool> *is_send_node_hung) {
  for (const auto &in_edge : in_degree_comm_edges) {
    const auto &rpc_recv_node = std::get<1>(in_edge.second);
    const auto &recv_dst_node = std::get<2>(in_edge.second);
    MS_LOG(DEBUG) << "Add dependency for RpcRecv node " << rpc_recv_node->fullname_with_scope()
                  << " with recv dst node " << recv_dst_node->fullname_with_scope();
    // Collect the minimal set of send src nodes this Recv's dst node depends on; UpdateDependedSet
    // is expected to keep the set minimal — confirm against its implementation.
    AnfNodePtrSet depended_nodes;
    for (const auto &send_src_node : send_src_nodes) {
      // Get minimum send src nodes set which have dependencies with RpcRecv node.
      if (node_dependency.count(recv_dst_node) != 0 && node_dependency.at(recv_dst_node).contains(send_src_node)) {
        depended_nodes = UpdateDependedSet(send_src_node, depended_nodes, node_dependency);
      }
    }
    MS_LOG(DEBUG) << "RpcRecv dst node " << recv_dst_node->fullname_with_scope()
                  << " depends on RpcSend src node size: " << depended_nodes.size();

    // Generate RpcSend nodes input list to add dependency with RpcRecv Nodes.
    AnfNodePtrList rpc_send_list;
    for (const auto &send_src_node : depended_nodes) {
      if (src_nodes_to_send_nodes.count(send_src_node) == 0) {
        MS_LOG(EXCEPTION) << "Send src node " << send_src_node->fullname_with_scope()
                          << " has no corresponding RpcSend nodes.";
      }
      const AnfNodePtrSet &rpc_send_nodes = src_nodes_to_send_nodes.at(send_src_node);
      for (const auto &rpc_send : rpc_send_nodes) {
        // This RpcSend now has a dependency edge, so it is no longer hung.
        (*is_send_node_hung)[rpc_send] = false;
        (void)rpc_send_list.emplace_back(rpc_send);
      }
    }
    if (!rpc_send_list.empty()) {
      // Insert Depend(recv_data, MakeTuple(rpc_sends)) in front of the RpcRecv's data input so the
      // sends are scheduled before this receive.
      AnfNodePtr send_node_make_tuple = CreateMakeTupleNode(func_graph_, rpc_send_list);
      MS_EXCEPTION_IF_NULL(send_node_make_tuple);
      MS_LOG(DEBUG) << "Connect " << send_node_make_tuple->fullname_with_scope() << " with RpcRecv node "
                    << rpc_recv_node->fullname_with_scope();

      auto recv_data = rpc_recv_node->cast<CNodePtr>()->inputs()[kIndex1];
      MS_EXCEPTION_IF_NULL(recv_data);

      AnfNodePtrList depend_node_inputs = {NewValueNode(prim::kPrimDepend), recv_data, send_node_make_tuple};
      auto depend_node = func_graph_->NewCNode(depend_node_inputs);
      MS_EXCEPTION_IF_NULL(depend_node);
      // The Depend node forwards its first input, so it carries the data input's abstract.
      depend_node->set_abstract(recv_data->abstract());
      func_graph_->manager()->SetEdge(rpc_recv_node, kIndex1, depend_node);
    }
  }
}
2187
// Connects all fused RpcSend nodes launched by this process to the graph output through a
// MakeTuple and a Depend node, so the send nodes are not optimized out by later passes.
// 'fused_inter_process_op_pairs': fused Send/Recv pairs keyed by their inter-process edge.
void GraphSplitter::AddDependencyForSend(const FusedInterProcessOpPairMap &fused_inter_process_op_pairs) {
  // Connect all Send nodes to MakeTuple.
  std::vector<AnfNodePtr> fused_send_node_tuple_inputs;
  MS_EXCEPTION_IF_NULL(func_graph_);
  for (const auto &op_pair_info : fused_inter_process_op_pairs) {
    const OperatorLabel &send_label = op_pair_info.first.src_label;
    const OperatorLabel &recv_label = op_pair_info.first.dst_label;
    const std::vector<FusedInterProcessOpPair> &op_pairs = op_pair_info.second;
    if (op_pairs.empty()) {
      MS_LOG(EXCEPTION) << "Fused inter-process ops should not be empty for edge " << send_label.to_string() << "->"
                        << recv_label.to_string();
    }
    // All pairs of one fused edge share a single fused send node, so only the first is needed.
    const auto &fused_send_node = std::get<0>(*op_pairs.begin());
    MS_EXCEPTION_IF_NULL(fused_send_node);
    // Make tuple all fused send nodes.
    if (send_label == this_process_label_) {
      (void)fused_send_node_tuple_inputs.emplace_back(fused_send_node);
    }
  }
  // NOTE(review): if no fused send node belongs to this process, the tuple input list is empty —
  // presumably CreateMakeTupleNode handles an empty list; confirm against its implementation.
  CNodePtr fused_send_make_tuple_node = CreateMakeTupleNode(func_graph_, fused_send_node_tuple_inputs);
  MS_EXCEPTION_IF_NULL(fused_send_make_tuple_node);

  // Connect fused send nodes to the output so they will not be optimized out.
  AnfNodePtr origin_output = func_graph_->output();
  if (node_labels_.count(origin_output) == 0) {
    MS_LOG(EXCEPTION) << "The origin output node " << origin_output->fullname_with_scope()
                      << " should have corresponding operator label.";
  }

  // If the output is not on this process, replace it with a fake output node.
  AnfNodePtr replaced_output = nullptr;
  if (node_labels_[origin_output] != this_process_label_) {
    replaced_output = CreateReplacedOutputNode(func_graph_, origin_output);
  } else {
    replaced_output = origin_output;
  }

  // Add dependency and replace. Depend forwards the output value while keeping the send tuple alive.
  std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), replaced_output,
                                           fused_send_make_tuple_node};
  auto final_output_node = func_graph_->NewCNode(depend_inputs);
  MS_EXCEPTION_IF_NULL(final_output_node);
  final_output_node->set_abstract(replaced_output->abstract());
  (void)func_graph_->manager()->SetEdge(func_graph_->get_return(), 1, final_output_node);
}
2233
IsNodesWithSameLabel(const AnfNodePtr & node1,const AnfNodePtr & node2)2234 bool GraphSplitter::IsNodesWithSameLabel(const AnfNodePtr &node1, const AnfNodePtr &node2) {
2235 if (node_labels_.count(node1) == 0 || node_labels_.count(node2) == 0) {
2236 MS_LOG(EXCEPTION) << "Either 'node1': " << node1->fullname_with_scope()
2237 << " or 'node2': " << node2->fullname_with_scope() << " is not marked with split label.";
2238 }
2239 return node_labels_[node1] == node_labels_[node2];
2240 }
2241
NeedSplitGraph() const2242 bool GraphSplitter::NeedSplitGraph() const {
2243 return std::find_if(node_labels_.begin(), node_labels_.end(), [&](const auto &node_to_label) {
2244 return node_to_label.second != this_process_label_;
2245 }) != node_labels_.end();
2246 }
2247
NodeHasLabel(const AnfNodePtr & node)2248 bool GraphSplitter::NodeHasLabel(const AnfNodePtr &node) { return node_labels_.count(node) != 0; }
2249 } // namespace parallel
2250 } // namespace mindspore
2251