/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_assigned_parallel.h"

#include <cinttypes>
#include <ctime>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
#include "frontend/parallel/auto_parallel/edge_costmodel.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
// l_RefMap: if input i of CNode B is a RefKey[Parameter C], the map stores
// one entry with key C and value (B, i).
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> l_RefMap;

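// Returns the tensor layout of the given output of a CNode, taken from the
// OperatorInfo attached to the node. Throws if output_index is out of range.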
static std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  if (distribute_operator->outputs_tensor_info().size() <= output_index) {
    MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->outputs_tensor_info().size()
                      << ", must be greater than output_index " << output_index;
  }
  TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
  TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
  return std::make_shared<TensorLayout>(tensorlayout_out);
}

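// Returns the output layout of `node` if it is a parallel-care CNode with
// OperatorInfo attached; otherwise returns nullptr.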
static std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) {
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure: GetOutputLayoutFromCNode failed";
    }
    return layout_ptr;
  }
  return nullptr;
}

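// Walks backwards from `node` to find the layout feeding a Reshape: parameters
// get a freshly created layout, Receive nodes carry their own, and other nodes
// are searched recursively through their inputs.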
static std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
  if (node->isa<Parameter>()) {
    return CreateParameterLayout(node);
  }
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return cnode->user_data<TensorLayout>();
  }
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
      !IsPrimitiveCNode(node, prim::kPrimReshape)) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure: GetOutputLayoutFromCNode failed";
    }
    return layout_ptr;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == prim::kPrimTupleGetItem->name()) {
    auto tuple_index = GetTupleGetItemIndex(cnode);
    auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure: FindPrevLayout failed, tuple_getitem before reshape, but there does not exist a "
                           "parallel care node before tuple_getitem!";
    }
    return layout_ptr;
  }
  for (size_t index = 0; index < cnode->size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    auto layout_ptr = FindPrevLayout(cnode->inputs()[index]);
    if (!layout_ptr) {
      continue;
    }
    return layout_ptr;
  }
  MS_LOG(WARNING) << "FindPrevLayout returned nullptr; if reshape is not the first primitive, there must be some error";
  return nullptr;
}

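// Walks forwards through the users of `cnode` to find the layout that the
// next parallel-care operator expects for its input.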
// If reshape's output connects to several primitives, return the first layout found.
static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape,
                                                    int make_tuple_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    auto use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
      *next_is_reshape = true;
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimDepend) && node_pair.second != 1) {
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
      make_tuple_index = node_pair.second;
      return FindNextLayout(use_apply, next_is_reshape, make_tuple_index);
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>() &&
        IsSomePrimitiveList(use_apply, SUPPORT_NEW_SHAPEBASE_OPS)) {
      MS_LOG(INFO) << "FindNextLayout success, node " << use_apply->DebugString() << " is in support new shapebase ops";
      *next_is_reshape = false;
      auto layout = GetInputLayoutFromCNode(node_pair, make_tuple_index);
      return std::make_shared<TensorLayout>(layout);
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      if (make_tuple_index != -1) {
        node_pair.second = make_tuple_index;
      }
      MS_LOG(INFO) << "FindNextLayout success, node " << use_apply->DebugString();
      *next_is_reshape = false;
      auto layout = GetInputLayoutFromCNode(node_pair, -1);
      return std::make_shared<TensorLayout>(layout);
    }
    MS_LOG(DEBUG) << "FindNextLayout failed, node " << use_apply->DebugString() << "  " << IsParallelCareNode(use_apply)
                  << "   " << use_apply->has_user_data<OperatorInfo>();

    auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, -1);
    if (layout_ptr) {
      return layout_ptr;
    }
  }
  MS_LOG(WARNING) << "FindNextLayout returned nullptr; if reshape is not the last primitive, there must be some error";
  return nullptr;
}

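// Builds a new AllGather ValueNode with the given primitive name and
// communication group; the rank size is derived from the group's rank list.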
AnfNodePtr NewAllGatherNode(const std::string &name, const std::string &group) {
  std::shared_ptr<Primitive> prim;
  prim = std::make_shared<Primitive>(name);
  ValuePtr attr0_value = MakeValue(group);
  prim->AddAttr(GROUP, attr0_value);
  prim->AddAttr("fusion", MakeValue(static_cast<int64_t>(0)));
  prim->AddAttr("mean_flag", MakeValue(false));
  prim->AddAttr("no_eliminate", MakeValue(true));
  std::vector<unsigned int> rank_list = {};
  auto long_rank_list = parallel::g_device_manager->FindRankListByHashName(group);
  (void)std::transform(long_rank_list.begin(), long_rank_list.end(), std::back_inserter(rank_list),
                       [](int64_t d) -> unsigned int { return IntToUint(LongToInt(d)); });

  prim->AddAttr(kAttrRankSize, MakeValue(static_cast<int64_t>(rank_list.size())));
  auto node = NewValueNode(prim);
  return node;
}

// Rewrites the pattern `ops` into `AllReduce -> ops` by inserting an AllReduce on the node's first input.
static void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group,
                                       const std::string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  FuncGraphPtr func_graph = node->func_graph();
  size_t index = 1;
  MS_EXCEPTION_IF_NULL(func_graph);
  Operator allreduce_op = CreateAllReduceOp(REDUCE_OP_SUM, group);

  // Insert it as the input of the node
  AnfNodePtr input = node->input(index);
  MS_EXCEPTION_IF_NULL(input);
  // If the input is not a tensor (or carries a monad), do nothing.
  if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
    return;
  }
  InsertNode(allreduce_op, node, index, node->input(index), func_graph, instance_name);
}

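// Scans for Add(MatMul(Reshape(...), ...), ...) patterns whose MatMul strategy
// splits the reduction dimension, inserts an AllReduce before the Add, and
// divides the Reshape target shape's second dimension by the device number.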
bool InsertAllReduceOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = SizeToLong(devices);
  if (device_num <= 1) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_add = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_add, prim::kPrimAdd->name())) {
      continue;
    }
    AnfNodePtr expect_matmul = expect_add->input(1);
    MS_EXCEPTION_IF_NULL(expect_matmul);
    if (!expect_matmul->isa<CNode>()) {
      continue;
    }
    auto expect_matmul_cnode = expect_matmul->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_matmul_cnode, prim::kPrimMatMul->name())) {
      continue;
    }
    auto matmul_prim = GetCNodePrimitive(expect_matmul_cnode);
    MS_EXCEPTION_IF_NULL(matmul_prim);
    if (matmul_prim->HasAttr(IN_STRATEGY)) {
      auto matmul_stra = matmul_prim->GetAttr(IN_STRATEGY);
      if (matmul_stra == nullptr) {
        continue;
      }
      auto matmul_var = GetValue<std::vector<Shape>>(matmul_stra);
      if (matmul_var.size() > 0) {
        Dimensions sub_a_strategy = matmul_var.at(0);
        Dimensions sub_b_strategy = matmul_var.at(1);
        if (sub_a_strategy.size() == 2 && sub_b_strategy.size() == 2 && sub_a_strategy[1] == sub_b_strategy[0] &&
            sub_a_strategy[1] > 1) {
          MS_LOG(INFO) << "Here should insert AllReduce ops:";
          InsertAllReduceToNodeInput(expect_add, HCCL_WORLD_GROUP, PARALLEL_GLOBALNORM);
          AnfNodePtr expect_reshape = expect_matmul_cnode->input(1);
          if (!expect_reshape->isa<CNode>()) {
            continue;
          }
          auto expect_reshape_cnode = expect_reshape->cast<CNodePtr>();
          if (!IsSomePrimitive(expect_reshape_cnode, prim::kPrimReshape->name())) {
            continue;
          }
          Shape origin_dst_shape =
            GetValue<std::vector<int64_t>>(expect_reshape_cnode->input(2)->cast<ValueNodePtr>()->value());
          if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
            continue;
          }
          Shape new_dst_shape;
          new_dst_shape.push_back(origin_dst_shape[0]);
          new_dst_shape.push_back(origin_dst_shape[1] / device_num);
          for (auto s : new_dst_shape) {
            MS_LOG(INFO) << "new_dst_shape: " << s;
          }

          expect_reshape_cnode->set_input(2, NewValueNode(MakeValue(new_dst_shape)));

          auto reshape_node_abstract = expect_reshape_cnode->abstract()->Clone();
          std::shared_ptr<abstract::BaseShape> output_shape = std::make_shared<abstract::Shape>(new_dst_shape);
          reshape_node_abstract->set_shape(output_shape);
          MS_LOG(INFO) << "new_dst_shape: " << reshape_node_abstract->ToString();
          expect_reshape_cnode->set_abstract(reshape_node_abstract);
        }
      }
    }
  }
  return true;
}

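// Same pattern matching as InsertAllReduceOps, but for the FFN case:
// Add(BatchMatMul(...), ...) with a strategy that splits the reduction axis.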
bool InsertAllReduceOpsForFFN(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root,
                              const size_t devices) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_add = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_add, prim::kPrimAdd->name())) {
      continue;
    }
    AnfNodePtr expect_batchmatmul = expect_add->input(1);
    MS_EXCEPTION_IF_NULL(expect_batchmatmul);
    if (!expect_batchmatmul->isa<CNode>()) {
      continue;
    }
    auto expect_batchmatmul_cnode = expect_batchmatmul->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_batchmatmul_cnode, prim::kPrimBatchMatMul->name())) {
      continue;
    }
    auto batchmatmul_prim = GetCNodePrimitive(expect_batchmatmul_cnode);
    MS_EXCEPTION_IF_NULL(batchmatmul_prim);
    if (batchmatmul_prim->HasAttr(IN_STRATEGY)) {
      auto batchmatmul_stra = batchmatmul_prim->GetAttr(IN_STRATEGY);
      if (batchmatmul_stra == nullptr) {
        continue;
      }
      auto batchmatmul_var = GetValue<std::vector<Shape>>(batchmatmul_stra);
      if (batchmatmul_var.size() > 0) {
        Dimensions sub_a_strategy = batchmatmul_var.at(0);
        Dimensions sub_b_strategy = batchmatmul_var.at(1);
        if (sub_a_strategy.size() == 4 && sub_b_strategy.size() == 3 && sub_a_strategy[3] == sub_b_strategy[1] &&
            sub_a_strategy[3] > 1) {
          MS_LOG(INFO) << "Here should insert AllReduce ops:";
          InsertAllReduceToNodeInput(expect_add, HCCL_WORLD_GROUP, PARALLEL_GLOBALNORM);
        }
      }
    }
  }
  return true;
}

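// Divides one dimension of a Reshape target shape by the device number.
// Handles both a constant shape (ValueNode, third dimension) and a MakeTuple
// shape whose third element is a constant.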
void ChangeReshape(const AnfNodePtr &node, const size_t devices) {
  int64_t device_num = SizeToLong(devices);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  auto expect_reshape_cnode = node->cast<CNodePtr>();
  if (!IsSomePrimitive(expect_reshape_cnode, prim::kPrimReshape->name())) {
    return;
  }
  auto reshape_node_input = expect_reshape_cnode->input(2);
  if (reshape_node_input == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Found reshape op: " << expect_reshape_cnode->DebugString();
  if (reshape_node_input->isa<ValueNode>()) {
    Shape origin_dst_shape = GetValue<std::vector<int64_t>>(reshape_node_input->cast<ValueNodePtr>()->value());
    if (origin_dst_shape.size() != 4) {
      return;
    }
    if (origin_dst_shape[2] % device_num != 0) {
      return;
    }
    Shape new_dst_shape;
    new_dst_shape.push_back(origin_dst_shape[0]);
    new_dst_shape.push_back(origin_dst_shape[1]);
    new_dst_shape.push_back(origin_dst_shape[2] / device_num);
    new_dst_shape.push_back(origin_dst_shape[3]);
    for (auto s : new_dst_shape) {
      MS_LOG(INFO) << "reshape new_dst_shape: " << s;
    }
    expect_reshape_cnode->set_input(2, NewValueNode(MakeValue(new_dst_shape)));
    auto reshape_node_abstract = expect_reshape_cnode->abstract()->Clone();
    std::shared_ptr<abstract::BaseShape> output_shape = std::make_shared<abstract::Shape>(new_dst_shape);
    reshape_node_abstract->set_shape(output_shape);
    MS_LOG(INFO) << "new_dst_shape: " << reshape_node_abstract->ToString();
    expect_reshape_cnode->set_abstract(reshape_node_abstract);

  } else if (reshape_node_input->isa<CNode>()) {
    auto expect_maketuple_cnode = reshape_node_input->cast<CNodePtr>();
    MS_LOG(INFO) << "Before modifying reshape maketuple: " << expect_maketuple_cnode->DebugString();
    if (!IsSomePrimitive(expect_maketuple_cnode, prim::kPrimMakeTuple->name())) {
      return;
    }
    auto maketuple_node_input = expect_maketuple_cnode->input(3);
    if (maketuple_node_input == nullptr) {
      return;
    }
    if (!maketuple_node_input->isa<ValueNode>()) {
      return;
    }
    int64_t origin_value = GetValue<int64_t>(maketuple_node_input->cast<ValueNodePtr>()->value());
    if (origin_value % device_num == 0 && !expect_maketuple_cnode->HasAttr("has_modified")) {
      int64_t new_value = origin_value / device_num;
      expect_maketuple_cnode->set_input(3, NewValueNode(MakeValue(new_value)));
      expect_maketuple_cnode->AddAttr("has_modified", MakeValue(true));
      MS_LOG(INFO) << "After modifying reshape maketuple: " << expect_maketuple_cnode->DebugString();
    }
  }
}

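// For every Transpose whose strategy actually splits a dimension, rewrites the
// Reshape feeding it via ChangeReshape so slice shapes stay consistent.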
bool ModifyReshapeOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = SizeToLong(devices);
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_transpose = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_transpose, prim::kPrimTranspose->name())) {
      continue;
    }
    auto transpose_prim = GetCNodePrimitive(expect_transpose);
    MS_EXCEPTION_IF_NULL(transpose_prim);
    if (!transpose_prim->HasAttr(IN_STRATEGY)) {
      continue;
    }
    auto transpose_stra = transpose_prim->GetAttr(IN_STRATEGY);
    if (transpose_stra == nullptr) {
      continue;
    }
    auto transpose_var = GetValue<std::vector<Shape>>(transpose_stra);
    if (transpose_var.size() > 0) {
      Dimensions sub_strategy = transpose_var.at(0);
      bool all_ones = std::all_of(sub_strategy.begin(), sub_strategy.end(), [](int64_t i) { return i == 1; });
      if (all_ones) {
        continue;
      }
    }
    AnfNodePtr expect_reshape = expect_transpose->input(1);
    ChangeReshape(expect_reshape, device_num);
  }
  return true;
}

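// For MakeTuple nodes of the form (TupleGetItem, TupleGetItem, const), divides
// the constant third element by the device number when it is divisible.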
bool ModifyMakeTupleOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = SizeToLong(devices);
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_maketuple = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_maketuple, prim::kPrimMakeTuple->name())) {
      continue;
    }
    if (expect_maketuple->size() != 4) {
      continue;
    }
    if (expect_maketuple->input(1)->isa<CNode>() && expect_maketuple->input(2)->isa<CNode>() &&
        expect_maketuple->input(3)->isa<ValueNode>()) {
      if (IsSomePrimitive(expect_maketuple->input(1)->cast<CNodePtr>(), prim::kPrimTupleGetItem->name()) &&
          IsSomePrimitive(expect_maketuple->input(2)->cast<CNodePtr>(), prim::kPrimTupleGetItem->name())) {
        auto maketuple_node_input = expect_maketuple->input(3);
        int64_t origin_value = GetValue<int64_t>(maketuple_node_input->cast<ValueNodePtr>()->value());
        if (origin_value % device_num == 0) {
          int64_t new_value = origin_value / device_num;
          expect_maketuple->set_input(3, NewValueNode(MakeValue(new_value)));
          MS_LOG(INFO) << "After modifying MakeTuple, the shape is: " << expect_maketuple->DebugString();
        }
      }
    }
  }
  return true;
}

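// For the Softmax -> Cast -> Reshape pattern, divides the second dimension of
// the Reshape target shape by the device number and updates the abstract.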
bool ModifySoftmaxReshapeOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = SizeToLong(devices);
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_reshape = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_reshape, prim::kPrimReshape->name())) {
      continue;
    }

    AnfNodePtr expect_cast = expect_reshape->input(1);
    MS_EXCEPTION_IF_NULL(expect_cast);
    if (!expect_cast->isa<CNode>()) {
      continue;
    }
    auto expect_cast_cnode = expect_cast->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_cast_cnode, "Cast")) {
      continue;
    }

    auto expect_softmax = expect_cast_cnode->input(1);
    MS_EXCEPTION_IF_NULL(expect_softmax);
    if (!expect_softmax->isa<CNode>()) {
      continue;
    }
    auto expect_softmax_cnode = expect_softmax->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_softmax_cnode, "Softmax")) {
      continue;
    }
    auto reshape_node_input = expect_reshape->input(2);
    if (reshape_node_input == nullptr) {
      continue;
    }
    if (!reshape_node_input->isa<ValueNode>()) {
      continue;
    }
    Shape origin_dst_shape = GetValue<std::vector<int64_t>>(reshape_node_input->cast<ValueNodePtr>()->value());
    if (origin_dst_shape.size() != 4) {
      continue;
    }
    if (origin_dst_shape[1] % device_num != 0) {
      continue;
    }
    Shape new_dst_shape;
    new_dst_shape.push_back(origin_dst_shape[0]);
    new_dst_shape.push_back(origin_dst_shape[1] / device_num);
    new_dst_shape.push_back(origin_dst_shape[2]);
    new_dst_shape.push_back(origin_dst_shape[3]);
    for (auto s : new_dst_shape) {
      MS_LOG(INFO) << "reshape new_dst_shape: " << s;
    }

    expect_reshape->set_input(2, NewValueNode(MakeValue(new_dst_shape)));

    auto reshape_node_abstract = expect_reshape->abstract()->Clone();
    std::shared_ptr<abstract::BaseShape> output_shape = std::make_shared<abstract::Shape>(new_dst_shape);
    reshape_node_abstract->set_shape(output_shape);
    MS_LOG(INFO) << "new_dst_shape: " << reshape_node_abstract->ToString();
    expect_reshape->set_abstract(reshape_node_abstract);
  }
  return true;
}

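// Returns true when the CNode is a parallel-care operator whose strategy and
// operator information should be extracted (skips MakeTuple/MakeList/Receive).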
static bool CheckExtractInformation(const CNodePtr &cnode) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }

  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
    return false;
  }
  if (!IsParallelCareNode(cnode)) {
    return false;
  }
  return true;
}

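// Fills l_RefMap: for every parameter with a default value used by a CNode,
// records the (user node, input index) pair that later determines the
// parameter's sliced shape.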
void InitRefMap(const FuncGraphPtr &root) {
  auto manager = root->manager();
  auto node_list = TopoSort(root->get_return());
  for (auto &node : node_list) {
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }

    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
      continue;
    }
    if (IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimUpdateState) ||
        IsPrimitiveCNode(node, prim::kPrimDepend)) {
      continue;
    }
    std::vector<AnfNodePtr> all_inputs = cnode->inputs();
    size_t inputs_size = all_inputs.size();
    for (size_t i = 1; i < inputs_size; ++i) {
      AnfNodePtr input = all_inputs[i];
      if (HasAbstractMonad(input)) {
        continue;
      }
      if (input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) {
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(func_graph);
        auto param_node = input->cast<ParameterPtr>();
        std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(cnode, SizeToLong(i));
        if (IsInTrivialNodeList(cnode) || IsSomePrimitive(cnode, prim::kPrimLoad->name())) {
          auto &node_users = manager->node_users();
          auto iter = node_users.find(node);
          if (iter == node_users.end()) {
            MS_LOG(ERROR) << "Cannot find the node that uses the parameter.";
          }
          auto &node_set = iter->second;
          const auto node_set_back = node_set.back().first->cast<CNodePtr>();
          if (node_set_back != nullptr && IsSomePrimitive(node_set_back, prim::kPrimMakeTuple->name())) {
            l_RefMap[param_node] = node_set.front();
          } else {
            l_RefMap[param_node] = node_set.back();
          }
        } else {
          l_RefMap[param_node] = node_pair;
        }
      }
    }
  }
}

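// Sets the parameter's abstract shape to the slice shape dictated by the
// consuming operator's tensor layout; with direct_split enabled, the default
// tensor data is also sliced for the local rank.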
static void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res, size_t rank_id) {
  MS_LOG(INFO) << "Begin set parallel shape";
  // check null for param and cnode
  MS_EXCEPTION_IF_NULL(parameter);
  auto param_shape = parameter->Shape();
  MS_EXCEPTION_IF_NULL(param_shape);

  CNodePtr cnode = res.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "node " << cnode->DebugString() << " 's distribute_operator is nullptr";
  }
  if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
    MS_LOG(EXCEPTION) << "The parameter index is not in inputs_tensor_info. index = " << (res.second - 1)
                      << ", inputs_tensor_info size = " << distribute_operator->inputs_tensor_info().size();
  }
  TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
  TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
  Shape slice_shape = tensor_layout.slice_shape().array();

  AbstractBasePtr abstract = parameter->abstract();
  if (abstract == nullptr) {
    MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract is nullptr";
  }

  AbstractBasePtr cloned_abstract = abstract->Clone();
  if (cloned_abstract == nullptr) {
    MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract clone failed";
  }

  cloned_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
  parameter->set_abstract(cloned_abstract);
  ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();

  MS_EXCEPTION_IF_NULL(parameter_ptr);
  MS_LOG(INFO) << "Begin split parameters";
  parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  if (ParallelContext::GetInstance()->direct_split() && parameter_ptr->has_default()) {
    auto layout = parameter_ptr->user_data<TensorLayout>();
    MS_LOG(INFO) << "parameter: " << parameter->ToString() << ", shape: " << parameter->Shape()->ToString()
                 << ", default_param: " << parameter_ptr->default_param() << ", layout: " << layout->ToString();
    SliceTensorObj(parameter_ptr, layout, rank_id);
  }
}

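// Applies SetParallelShape to every root parameter recorded in l_RefMap, then
// clears the map.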
static void DoParameterSliceShape(const FuncGraphPtr &root, size_t rank_id) {
  MS_EXCEPTION_IF_NULL(root);
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    MS_EXCEPTION_IF_NULL(parameter->Shape());
    auto iter = l_RefMap.find(parameter);
    if (iter != l_RefMap.cend()) {
      MS_LOG(INFO) << "SetParallelShape for parameter: " << parameter->ToString();
      SetParallelShape(parameter, iter->second, rank_id);
      SetSharedParameterFlag(root, parameter);
    }
  }
  l_RefMap.clear();
}

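// Parses a strategy value tuple into a StrategyPtr, forcing every split
// dimension greater than 1 to equal the full device number, and stores the
// modified strategy back on the node as a primal attribute.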
StrategyPtr ExtractAndModifyStrategy(const CNodePtr &cnode, const std::string &attr_name, const ValuePtr &stra) {
  if (stra == nullptr) {
    return nullptr;
  }
  auto var = stra->cast<ValueTuplePtr>();
  if (var == nullptr) {
    return nullptr;
  }

  StrategyPtr strategyPtr;
  int64_t stage_id = g_device_manager->stage_id();
  MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
  int64_t device_num = g_device_manager->DeviceNum();
  MS_LOG(INFO) << "Extract information: device_num " << device_num;
  if (var->size() > 0) {
    std::vector<ValuePtr> elements = var->value();
    Strategies strategy;
    for (uint64_t index = 0; index < elements.size(); ++index) {
      Dimensions dim;
      if (elements[index]->isa<ValueSequence>()) {
        auto value_tuple = elements[index]->cast<ValueTuplePtr>();
        std::vector<ValuePtr> value_vector = value_tuple->value();
        (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
                             [](const ValuePtr &value) { return GetValue<int64_t>(value); });
        for (size_t i = 0; i < dim.size(); i++) {
          if (dim[i] > 1 && dim[i] != device_num) {
            dim[i] = device_num;
          }
        }
        strategy.push_back(dim);
      } else {
        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
      }
    }
    if (strategy.empty()) {
      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
    }
    cnode->AddPrimalAttr(attr_name, MakeValue(strategy));
    strategyPtr = NewStrategy(stage_id, strategy);
    MS_LOG(INFO) << "Extract information: new strategy " << cnode->GetPrimalAttr(attr_name)->ToString();
  }
  return strategyPtr;
}

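// Determines the in/out strategies for a node (batch-parallel, primal attr,
// prim attr, checkpoint map, or stand-alone) and initializes its OperatorInfo.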
static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &prim, const OperatorInfoPtr &op_info) {
  StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
  auto attrs = prim->attrs();

  // load strategy map from checkpoint
  StrategyMap stra_map;

  std::string strategy_key_name = "";
  auto param_names = NodeParameterName(cnode, -1, 0);
  if (!param_names.empty()) {
    strategy_key_name = prim->name() + "_" + param_names[0].first;
  }
  if (!prim->HasAttr(STAND_ALONE)) {
    if ((!StrategyFound(attrs) && !cnode->HasPrimalAttr(IN_STRATEGY)) || prim->HasAttr(BATCH_PARALLEL)) {
      MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
                   << " is empty, using batch parallel";
      in_strategy = GenerateBatchParallelStrategy(op_info, prim);
    } else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
      in_strategy = ExtractAndModifyStrategy(cnode, IN_STRATEGY, cnode->GetPrimalAttr(IN_STRATEGY));
      out_strategy = ExtractAndModifyStrategy(cnode, OUT_STRATEGY, cnode->GetPrimalAttr(OUT_STRATEGY));
    } else if (StrategyFound(attrs)) {
      in_strategy = ExtractAndModifyStrategy(cnode, IN_STRATEGY, attrs[IN_STRATEGY]);
      out_strategy = ExtractAndModifyStrategy(cnode, OUT_STRATEGY, attrs[OUT_STRATEGY]);
    } else {
      in_strategy = stra_map[strategy_key_name];
    }
  } else {
    in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
  }

  MS_EXCEPTION_IF_NULL(in_strategy);
  if (op_info->Init(in_strategy, out_strategy) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " init failed" << trace::DumpSourceLines(cnode);
  }
}

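// Creates OperatorInfo for every parallel-care node in the graph and extracts
// its strategy; a handful of ops are skipped, and Reshape is initialized later.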
void ExtractGraphInformation(const std::vector<AnfNodePtr> &all_nodes) {
  MS_LOG(INFO) << "ExtractInformation";
  SetStridedSliceSplitStrategy(all_nodes);
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!CheckExtractInformation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("PadV3")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("StridedSlice")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("Sort")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("Less")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("Range"))) {
      continue;
    }

    SetVirtualDatasetStrategy(cnode);
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);

    OperatorInfoPtr operator_ = CreateOperatorInfo(cnode);
    MS_EXCEPTION_IF_NULL(operator_);
    operator_->set_assigned_parallel(true);

    if (prim->name() == RESHAPE) {
      cnode->set_user_data<OperatorInfo>(operator_);
      continue;
    }

    ExtractStrategyAndInit(cnode, prim, operator_);
    cnode->set_user_data<OperatorInfo>(operator_);
  }
}

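// Splices a generated replacement subgraph into the main graph: rewires each
// virtual input to the corresponding real input of `node`, then replaces the
// node with the subgraph's output.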
static void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(replace_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(replace_graph->second);
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr";
  }
  mindspore::HashMap<AnfNodePtr, int> input_map = {};
  static int appear_count = 0;
  for (auto &replace_input : replace_graph->first) {
    auto pre_node = node->input(LongToSize(replace_input.second));

    auto it = input_map.find(replace_input.first);
    if (it != input_map.end()) {
      appear_count = 1 + it->second;
    } else {
      appear_count = 1;
    }
    auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
    size_t inputs_size = replace_input_cnode->size();
    while (IntToSize(appear_count) < inputs_size && replace_input_cnode->input(appear_count)->func_graph() != nullptr) {
      ++appear_count;
    }
    if (IntToSize(appear_count) >= inputs_size) {
      MS_LOG(EXCEPTION) << "No replaceable virtual_input_node";
    }
    input_map[replace_input.first] = appear_count;
    replace_input_cnode->set_in_forward_flag(true);
    manager->SetEdge(replace_input.first, appear_count, pre_node);
  }
  // "(void)manager->Replace(replace_graph->first, pre_node);" cannot be called here
  auto replace_output = replace_graph->second->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(replace_output);
  replace_output->set_in_forward_flag(true);
  replace_output->set_primal_attrs(node->primal_attrs());
  (void)manager->Replace(node, replace_output);
}

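// Replaces Gather nodes with the replacement graph supplied by their
// OperatorInfo, if any.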
static void ReplaceGatherOps(const std::vector<AnfNodePtr> &all_nodes, const size_t devices) {
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, prim::kPrimGather->name())) {
        continue;
      }
      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
      MS_EXCEPTION_IF_NULL(distribute_operator);
      auto replace_op = distribute_operator->replace_op();
      // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore.
      auto replace_graph = distribute_operator->replace_graph(cnode);
      if (!replace_op.empty() && replace_graph) {
        MS_LOG(EXCEPTION) << "Only one of replace_op or replace_graph can be used";
      }
      if (replace_graph) {
        MS_LOG(INFO) << "StepReplaceGraph " << cnode->DebugString();
        StepReplaceGraph(replace_graph, cnode);
      }
    }
  }
}

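// When the graph returns a MatMul result that is sharded across devices,
// rebuilds the output with AllGather (plus Reshape, or Split/Concat) so the
// returned tensor has the full, unsharded shape.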
static void FixReturnRedistribution(const FuncGraphPtr &root, const size_t devices) {
  MS_LOG(INFO) << "FixReturnRedistribution";
  CNodePtr ret = root->get_return();
  AnfNodePtr expect_matmul = ret->input(1);
  MS_EXCEPTION_IF_NULL(expect_matmul);
  if (!expect_matmul->isa<CNode>()) {
    return;
  }
  auto expect_matmul_node = expect_matmul->cast<CNodePtr>();
  if (!IsSomePrimitive(expect_matmul_node, prim::kPrimMatMul->name())) {
    return;
  }
  Shapes return_input_shapes = GetNodeShape(ret);
  MS_LOG(INFO) << "return_input_shapes size " << return_input_shapes.size();
  if (return_input_shapes.size() == 1) {
    MS_LOG(INFO) << "return_input_shapes: " << return_input_shapes[0][0] << " " << return_input_shapes[0][1];
    GenerateGraph gen_g = GenerateGraph(expect_matmul->cast<CNodePtr>()->attrs());
    if (gen_g.Init(ret) != SUCCESS) {
      MS_LOG(ERROR) << "MatMul->Return GenerateGraph Init failed";
    }

    Attr transpose_a_attr = std::make_pair(TRANSPOSE_A, MakeValue(false));
    Attr transpose_b_attr = std::make_pair(TRANSPOSE_B, MakeValue(true));
    OperatorAttrs matmul_attrs = {transpose_a_attr, transpose_b_attr};
    auto matmul = gen_g.PushBack({gen_g.NewOpInst(prim::kPrimMatMul->name(), matmul_attrs), gen_g.virtual_input_node(),
                                  gen_g.virtual_input_node()});

    if (return_input_shapes[0][0] == 1) {
      auto des_shape = return_input_shapes[0];
      auto des_size = return_input_shapes[0][1];
      auto origin_size = des_size / devices;
      Shape origin_shape;
      origin_shape.push_back(origin_size);
      ConstructOperator constructor;
      constructor.UpdateTensorShape(origin_shape);

      auto reshape = gen_g.PushBack({gen_g.NewOpInst(prim::kPrimReshape->name()), matmul, CreateTuple(origin_shape)});
      auto allgather = gen_g.PushBack({NewAllGatherNode(ALL_GATHER, HCCL_WORLD_GROUP), reshape});
      auto reshape2 = gen_g.PushBack({gen_g.NewOpInst(prim::kPrimReshape->name()), allgather, CreateTuple(des_shape)});
      std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(matmul, 1), std::make_pair(matmul, 2)};
      auto replace_graph = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
        std::make_pair(input_nodes, reshape2));
      MS_LOG(INFO) << "StepReplaceGraph " << expect_matmul->ToString();
      StepReplaceGraph(replace_graph, expect_matmul->cast<CNodePtr>());
      return;
    } else {
      auto allgather = gen_g.PushBack({NewAllGatherNode(ALL_GATHER, HCCL_WORLD_GROUP), matmul});
      // split
      int64_t split_count = SizeToLong(devices);
      Attr split_axis_attr = std::make_pair(AXIS, MakeValue(0));
      Attr split_count_attr = std::make_pair(OUTPUT_NUM, MakeValue(split_count));
      OperatorAttrs split_attrs = {split_axis_attr, split_count_attr};
      auto split = gen_g.PushBack({gen_g.NewOpInst(SPLIT, split_attrs), allgather});

      // tuple get item and make tuple
      std::vector<AnfNodePtr> maketuple_inputs;
      maketuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
      for (int64_t i = 0; i < split_count; ++i) {
        auto tuple_get_item = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), split, CreatInt64Imm(i)});
        maketuple_inputs.push_back(tuple_get_item);
      }
      auto maketuple = gen_g.PushBack(maketuple_inputs);

      // concat
      Attr concat_axis_attr = std::make_pair(AXIS, MakeValue(1));
      OperatorAttrs concat_attrs = {concat_axis_attr};
      auto concat = gen_g.PushBack({gen_g.NewOpInst(CONCAT, concat_attrs), maketuple});

      std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(matmul, 1), std::make_pair(matmul, 2)};
      auto replace_graph = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
        std::make_pair(input_nodes, concat));
      MS_LOG(INFO) << "StepReplaceGraph " << expect_matmul->DebugString();
      StepReplaceGraph(replace_graph, expect_matmul->cast<CNodePtr>());
      return;
    }
  }
}

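// Entry point of the assigned-parallel pass: initializes the parallel context,
// optionally runs the recursive-programming strategy search (sapp), extracts
// strategies, inserts AllReduce ops, fixes reshapes, and slices parameters.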
bool StepAssignedParallel(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, size_t device_num,
                          size_t rank_id, bool sapp) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  // control whether use model_parallel mode
  if (device_num == 0 || device_num > 8) {
    MS_LOG(EXCEPTION) << "Error: device_num is 0 or greater than 8.";
  }

  MSLogTime msTime;
  msTime.Start();
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpGraph(root, std::string("step_assigned_parallel_begin"));
  }
#endif
  MS_LOG(INFO) << "Now entering step assigned parallel";
  TOTAL_OPS = 0;
  AnfNodePtr ret = root->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  if (ParallelInit(rank_id, device_num) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed";
  }

  MarkForwardCNode(root);

  if (sapp) {
    CostModelContext::GetInstance()->set_rp_matmul_mem_coef(1);
    if (ParallelStrategyRecSearch(all_nodes, root, rank_id, device_num) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
    }
    root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  }

  InitRefMap(root);
  // extract shape and strategy, set operator_info
  ExtractGraphInformation(all_nodes);

  MS_LOG(INFO) << "Now inserting AllReduce ops for assigned parallel";

  if (!InsertAllReduceOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Assigned insert AllReduce ops failed.";
  }
  if (!InsertAllReduceOpsForFFN(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Assigned insert AllReduce ops for FFN failed.";
  }
  if (!ModifyReshapeOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Modify Reshape ops failed.";
  }
  if (!ModifyMakeTupleOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Modify MakeTuple ops failed.";
  }
  if (!ModifySoftmaxReshapeOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Modify Softmax Reshape ops failed.";
  }

  ReplaceGatherOps(all_nodes, device_num);
  FixReturnRedistribution(root, device_num);
  DoParameterSliceShape(root, rank_id);
#ifdef ENABLE_DUMP_IR
  if (context->CanDump(kIntroductory)) {
    DumpGraph(root, std::string("step_assigned_parallel_end"));
  }
#endif

  msTime.End();
  uint64_t time = msTime.GetRunTimeUS();

  MS_LOG(INFO) << "Now leaving step assigned parallel, used time: " << time << " us";

  return true;
}

}  // namespace parallel
}  // namespace mindspore