/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_assigned_parallel.h"

#include <cinttypes>
#include <ctime>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
#include "frontend/parallel/auto_parallel/edge_costmodel.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "utils/parallel_node_check.h"

namespace mindspore {
namespace parallel {
// l_RefMap: if input i of a CNode B is a RefKey[Parameter C], the map holds one entry
// with key C and value (B, i).
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> l_RefMap;

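// Return the tensor layout of the given output of an operator that already has OperatorInfo attached.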
static std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  if (distribute_operator->outputs_tensor_info().size() <= output_index) {
    MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->outputs_tensor_info().size()
                      << ", must be greater than output_index " << output_index;
  }
  TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
  TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
  return std::make_shared<TensorLayout>(tensorlayout_out);
}

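// If the node is a parallel-care CNode with OperatorInfo, return the layout of its output_index-th output.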
static std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) {
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
    }
    return layout_ptr;
  }
  return nullptr;
}

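// Search backwards from the given node for the layout produced by the previous parallel-care operator,
// Parameter or Receive node.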
static std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
  if (node->isa<Parameter>()) {
    return CreateParameterLayout(node);
  }
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return cnode->user_data<TensorLayout>();
  }
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
      !IsPrimitiveCNode(node, prim::kPrimReshape)) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
    }
    return layout_ptr;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == prim::kPrimTupleGetItem->name()) {
    auto tuple_index = GetTupleGetItemIndex(cnode);
    auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure:FindPrevLayout failed, there is a tuple_getitem before reshape, but there does "
                           "not exist a parallel care node before the tuple_getitem!";
    }
    return layout_ptr;
  }
  for (size_t index = 0; index < cnode->size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    auto layout_ptr = FindPrevLayout(cnode->inputs()[index]);
    if (!layout_ptr) {
      continue;
    }
    return layout_ptr;
  }
  MS_LOG(WARNING) << "FindPrevLayout returned nullptr; if reshape is not the first primitive, there must be some error";
  return nullptr;
}

// if reshape's output connects to several primitives, return the first layout found
static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape,
                                                    int make_tuple_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    auto use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
      *next_is_reshape = true;
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimDepend) && node_pair.second != 1) {
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
      make_tuple_index = node_pair.second;
      return FindNextLayout(use_apply, next_is_reshape, make_tuple_index);
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>() &&
        IsSomePrimitiveList(use_apply, SUPPORT_NEW_SHAPEBASE_OPS)) {
      MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString() << ", in support new shapebase ops";
      *next_is_reshape = false;
      auto layout = GetInputLayoutFromCNode(node_pair, make_tuple_index);
      return std::make_shared<TensorLayout>(layout);
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      if (make_tuple_index != -1) {
        node_pair.second = make_tuple_index;
      }
      MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString();
      *next_is_reshape = false;
      auto layout = GetInputLayoutFromCNode(node_pair, -1);
      return std::make_shared<TensorLayout>(layout);
    }
    MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << " " << IsParallelCareNode(use_apply)
                  << " " << use_apply->has_user_data<OperatorInfo>();

    auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, -1);
    if (layout_ptr) {
      return layout_ptr;
    }
  }
  MS_LOG(WARNING) << "FindNextLayout returned nullptr; if reshape is not the last primitive, there must be some error";
  return nullptr;
}

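// Create a ValueNode holding an AllGather primitive configured for the given communication group.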
AnfNodePtr NewAllGatherNode(const std::string &name, const std::string &group) {
  auto prim = std::make_shared<Primitive>(name);
  ValuePtr attr0_value = MakeValue(group);
  prim->AddAttr(GROUP, attr0_value);
  prim->AddAttr("fusion", MakeValue(static_cast<int64_t>(0)));
  prim->AddAttr("mean_flag", MakeValue(false));
  prim->AddAttr("no_eliminate", MakeValue(true));
  std::vector<unsigned int> rank_list = {};
  auto long_rank_list = parallel::g_device_manager->FindRankListByHashName(group);
  (void)std::transform(long_rank_list.begin(), long_rank_list.end(), std::back_inserter(rank_list),
                       [](int64_t d) -> unsigned int { return IntToUint(LongToInt(d)); });

  prim->AddAttr(kAttrRankSize, MakeValue(static_cast<int64_t>(rank_list.size())));
  auto node = NewValueNode(prim);
  return node;
}

// Rewrite ops(x, ...) into ops(AllReduce(x), ...): insert an AllReduce on the node's first input.
static void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group,
                                       const std::string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  FuncGraphPtr func_graph = node->func_graph();
  size_t index = 1;
  MS_EXCEPTION_IF_NULL(func_graph);
  Operator allreduce_op = CreateAllReduceOp(REDUCE_OP_SUM, group);

  // Insert it as the input of the node
  AnfNodePtr input = node->input(index);
  MS_EXCEPTION_IF_NULL(input);
  // if the input is not a tensor, skip it
  if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
    return;
  }
  InsertNode(allreduce_op, node, index, node->input(index), func_graph, instance_name);
}

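// Match the pattern Add(MatMul(x, w), b). When the MatMul in_strategy shards the reduced dimension
// (k-axis), insert an AllReduce on the Add's first input and, when that MatMul input is a Reshape,
// divide the second dimension of its target shape by the device number.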
bool InsertAllReduceOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = devices;
  if (device_num <= 1) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_add = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_add, prim::kPrimAdd->name())) {
      continue;
    }
    AnfNodePtr expect_matmul = expect_add->input(1);
    MS_EXCEPTION_IF_NULL(expect_matmul);
    if (!expect_matmul->isa<CNode>()) {
      continue;
    }
    auto expect_matmul_cnode = expect_matmul->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_matmul_cnode, prim::kPrimMatMul->name())) {
      continue;
    }
    auto matmul_prim = GetCNodePrimitive(expect_matmul_cnode);
    MS_EXCEPTION_IF_NULL(matmul_prim);
    if (matmul_prim->HasAttr(IN_STRATEGY)) {
      auto matmul_stra = matmul_prim->GetAttr(IN_STRATEGY);
      if (matmul_stra == nullptr) {
        continue;
      }
      auto matmul_var = GetValue<std::vector<Shape>>(matmul_stra);
      if (matmul_var.size() > 1) {
        Dimensions sub_a_strategy = matmul_var.at(0);
        Dimensions sub_b_strategy = matmul_var.at(1);
        if (sub_a_strategy.size() == 2 && sub_b_strategy.size() == 2 && sub_a_strategy[1] == sub_b_strategy[0] &&
            sub_a_strategy[1] > 1) {
          MS_LOG(INFO) << "Here should insert AllReduce ops";
          InsertAllReduceToNodeInput(expect_add, HCCL_WORLD_GROUP, PARALLEL_GLOBALNORM);
          AnfNodePtr expect_reshape = expect_matmul_cnode->input(1);
          if (!expect_reshape->isa<CNode>()) {
            continue;
          }
          auto expect_reshape_cnode = expect_reshape->cast<CNodePtr>();
          if (!IsSomePrimitive(expect_reshape_cnode, prim::kPrimReshape->name())) {
            continue;
          }
          Shape origin_dst_shape =
            GetValue<std::vector<int64_t>>(expect_reshape_cnode->input(2)->cast<ValueNodePtr>()->value());
          if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
            continue;
          }
          Shape new_dst_shape;
          new_dst_shape.push_back(origin_dst_shape[0]);
          new_dst_shape.push_back(origin_dst_shape[1] / device_num);
          for (auto s : new_dst_shape) {
            MS_LOG(INFO) << "new_dst_shape: " << s;
          }

          expect_reshape_cnode->set_input(2, NewValueNode(MakeValue(new_dst_shape)));

          auto reshape_node_abstract = expect_reshape_cnode->abstract()->Clone();
          std::shared_ptr<abstract::BaseShape> output_shape = std::make_shared<abstract::Shape>(new_dst_shape);
          reshape_node_abstract->set_shape(output_shape);
          MS_LOG(INFO) << "new_dst_shape: " << reshape_node_abstract->ToString();
          expect_reshape_cnode->set_abstract(reshape_node_abstract);
        }
      }
    }
  }
  return true;
}

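// Same as InsertAllReduceOps, but for the FFN pattern Add(BatchMatMul(x, w), b): if the BatchMatMul
// in_strategy shards the reduced dimension, insert an AllReduce before the Add.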
bool InsertAllReduceOpsForFFN(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root,
                              const size_t devices) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_add = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_add, prim::kPrimAdd->name())) {
      continue;
    }
    AnfNodePtr expect_batchmatmul = expect_add->input(1);
    MS_EXCEPTION_IF_NULL(expect_batchmatmul);
    if (!expect_batchmatmul->isa<CNode>()) {
      continue;
    }
    auto expect_batchmatmul_cnode = expect_batchmatmul->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_batchmatmul_cnode, prim::kPrimBatchMatMul->name())) {
      continue;
    }
    auto batchmatmul_prim = GetCNodePrimitive(expect_batchmatmul_cnode);
    MS_EXCEPTION_IF_NULL(batchmatmul_prim);
    if (batchmatmul_prim->HasAttr(IN_STRATEGY)) {
      auto batchmatmul_stra = batchmatmul_prim->GetAttr(IN_STRATEGY);
      if (batchmatmul_stra == nullptr) {
        continue;
      }
      auto batchmatmul_var = GetValue<std::vector<Shape>>(batchmatmul_stra);
      if (batchmatmul_var.size() > 1) {
        Dimensions sub_a_strategy = batchmatmul_var.at(0);
        Dimensions sub_b_strategy = batchmatmul_var.at(1);
        if (sub_a_strategy.size() == 4 && sub_b_strategy.size() == 3 && sub_a_strategy[3] == sub_b_strategy[1] &&
            sub_a_strategy[3] > 1) {
          MS_LOG(INFO) << "Here should insert AllReduce ops";
          InsertAllReduceToNodeInput(expect_add, HCCL_WORLD_GROUP, PARALLEL_GLOBALNORM);
        }
      }
    }
  }
  return true;
}

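// If the node is a Reshape, divide the sharded dimension of its constant target shape (dimension 2 of a
// 4-D constant shape, or the constant third element of a MakeTuple shape) by the device number.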
void ChangeReshape(const AnfNodePtr &node, const size_t devices) {
  int64_t device_num = devices;
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  auto expect_reshape_cnode = node->cast<CNodePtr>();
  if (!IsSomePrimitive(expect_reshape_cnode, prim::kPrimReshape->name())) {
    return;
  }
  auto reshape_node_input = expect_reshape_cnode->input(2);
  if (reshape_node_input == nullptr) {
    return;
  }
  MS_LOG(INFO) << "find reshape ops: " << expect_reshape_cnode->DebugString();
  if (reshape_node_input->isa<ValueNode>()) {
    Shape origin_dst_shape = GetValue<std::vector<int64_t>>(reshape_node_input->cast<ValueNodePtr>()->value());
    if (origin_dst_shape.size() != 4) {
      return;
    }
    if (origin_dst_shape[2] % device_num != 0) {
      return;
    }
    Shape new_dst_shape;
    new_dst_shape.push_back(origin_dst_shape[0]);
    new_dst_shape.push_back(origin_dst_shape[1]);
    new_dst_shape.push_back(origin_dst_shape[2] / device_num);
    new_dst_shape.push_back(origin_dst_shape[3]);
    for (auto s : new_dst_shape) {
      MS_LOG(INFO) << "reshape new_dst_shape: " << s;
    }
    expect_reshape_cnode->set_input(2, NewValueNode(MakeValue(new_dst_shape)));
    auto reshape_node_abstract = expect_reshape_cnode->abstract()->Clone();
    std::shared_ptr<abstract::BaseShape> output_shape = std::make_shared<abstract::Shape>(new_dst_shape);
    reshape_node_abstract->set_shape(output_shape);
    MS_LOG(INFO) << "new_dst_shape: " << reshape_node_abstract->ToString();
    expect_reshape_cnode->set_abstract(reshape_node_abstract);
  } else if (reshape_node_input->isa<CNode>()) {
    auto expect_maketuple_cnode = reshape_node_input->cast<CNodePtr>();
    MS_LOG(INFO) << "Before modify reshape maketuple: " << expect_maketuple_cnode->DebugString();
    if (!IsSomePrimitive(expect_maketuple_cnode, prim::kPrimMakeTuple->name())) {
      return;
    }
    auto maketuple_node_input = expect_maketuple_cnode->input(3);
    if (maketuple_node_input == nullptr) {
      return;
    }
    if (!maketuple_node_input->isa<ValueNode>()) {
      return;
    }
    int64_t origin_value = GetValue<int64_t>(maketuple_node_input->cast<ValueNodePtr>()->value());
    if (origin_value % device_num == 0 && !expect_maketuple_cnode->HasAttr("has_modifyed")) {
      int64_t new_value = origin_value / device_num;
      expect_maketuple_cnode->set_input(3, NewValueNode(MakeValue(new_value)));
      expect_maketuple_cnode->AddAttr("has_modifyed", MakeValue(true));
      MS_LOG(INFO) << "After modify reshape maketuple: " << expect_maketuple_cnode->DebugString();
    }
  }
}

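// For every Transpose that carries a non-trivial in_strategy, adjust the target shape of the Reshape
// feeding it (see ChangeReshape) so that the reshaped tensor matches the sharded shape.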
bool ModifyReshapeOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = devices;
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_transpose = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_transpose, prim::kPrimTranspose->name())) {
      continue;
    }
    auto transpose_prim = GetCNodePrimitive(expect_transpose);
    MS_EXCEPTION_IF_NULL(transpose_prim);
    if (!transpose_prim->HasAttr(IN_STRATEGY)) {
      continue;
    }
    auto transpose_stra = transpose_prim->GetAttr(IN_STRATEGY);
    if (transpose_stra == nullptr) {
      continue;
    }
    auto transpose_var = GetValue<std::vector<Shape>>(transpose_stra);
    if (transpose_var.size() > 0) {
      Dimensions sub_strategy = transpose_var.at(0);
      bool all_ones = std::all_of(sub_strategy.begin(), sub_strategy.end(), [](int64_t i) { return i == 1; });
      if (all_ones) {
        continue;
      }
    }
    AnfNodePtr expect_reshape = expect_transpose->input(1);
    ChangeReshape(expect_reshape, device_num);
  }
  return true;
}

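// For MakeTuple(TupleGetItem, TupleGetItem, const): divide the constant last element by the device
// number so the downstream shape matches the sharded tensors.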
bool ModifyMakeTupleOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = devices;
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_maketuple = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_maketuple, prim::kPrimMakeTuple->name())) {
      continue;
    }
    if (expect_maketuple->size() != 4) {
      continue;
    }
    if (expect_maketuple->input(1)->isa<CNode>() && expect_maketuple->input(2)->isa<CNode>() &&
        expect_maketuple->input(3)->isa<ValueNode>()) {
      if (IsSomePrimitive(expect_maketuple->input(1)->cast<CNodePtr>(), prim::kPrimTupleGetItem->name()) &&
          IsSomePrimitive(expect_maketuple->input(2)->cast<CNodePtr>(), prim::kPrimTupleGetItem->name())) {
        auto maketuple_node_input = expect_maketuple->input(3);
        int64_t origin_value = GetValue<int64_t>(maketuple_node_input->cast<ValueNodePtr>()->value());
        if (origin_value % device_num == 0) {
          int64_t new_value = origin_value / device_num;
          expect_maketuple->set_input(3, NewValueNode(MakeValue(new_value)));
          MS_LOG(INFO) << "After modify MakeTuple, the shape is: " << expect_maketuple->DebugString();
        }
      }
    }
  }
  return true;
}

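// For the pattern Reshape(Cast(Softmax(x)), shape): divide the second dimension of the constant target
// shape by the device number and update the Reshape abstract accordingly.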
bool ModifySoftmaxReshapeOps(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root, const size_t devices) {
  int64_t device_num = devices;
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto expect_reshape = node->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_reshape, prim::kPrimReshape->name())) {
      continue;
    }

    AnfNodePtr expect_cast = expect_reshape->input(1);
    MS_EXCEPTION_IF_NULL(expect_cast);
    if (!expect_cast->isa<CNode>()) {
      continue;
    }
    auto expect_cast_cnode = expect_cast->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_cast_cnode, "Cast")) {
      continue;
    }

    auto expect_softmax = expect_cast_cnode->input(1);
    MS_EXCEPTION_IF_NULL(expect_softmax);
    if (!expect_softmax->isa<CNode>()) {
      continue;
    }
    auto expect_softmax_cnode = expect_softmax->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_softmax_cnode, "Softmax")) {
      continue;
    }
    auto reshape_node_input = expect_reshape->input(2);
    if (reshape_node_input == nullptr) {
      continue;
    }
    if (!reshape_node_input->isa<ValueNode>()) {
      continue;
    }
    Shape origin_dst_shape = GetValue<std::vector<int64_t>>(reshape_node_input->cast<ValueNodePtr>()->value());
    if (origin_dst_shape.size() != 4) {
      continue;
    }
    if (origin_dst_shape[1] % device_num != 0) {
      continue;
    }
    Shape new_dst_shape;
    new_dst_shape.push_back(origin_dst_shape[0]);
    new_dst_shape.push_back(origin_dst_shape[1] / device_num);
    new_dst_shape.push_back(origin_dst_shape[2]);
    new_dst_shape.push_back(origin_dst_shape[3]);
    for (auto s : new_dst_shape) {
      MS_LOG(INFO) << "reshape new_dst_shape: " << s;
    }

    expect_reshape->set_input(2, NewValueNode(MakeValue(new_dst_shape)));

    auto reshape_node_abstract = expect_reshape->abstract()->Clone();
    std::shared_ptr<abstract::BaseShape> output_shape = std::make_shared<abstract::Shape>(new_dst_shape);
    reshape_node_abstract->set_shape(output_shape);
    MS_LOG(INFO) << "new_dst_shape: " << reshape_node_abstract->ToString();
    expect_reshape->set_abstract(reshape_node_abstract);
  }
  return true;
}

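// A CNode needs strategy extraction only if it is a primitive call, is not MakeTuple/MakeList/Receive,
// and is a parallel-care node.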
static bool CheckExtractInformation(const CNodePtr &cnode) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }

  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
    return false;
  }
  if (!IsParallelCareNode(cnode)) {
    return false;
  }
  return true;
}

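// Build l_RefMap: for every parameter with a default value, record the (user CNode, input index) pair
// that will later be used to derive the parameter's sliced shape.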
void InitRefMap(const FuncGraphPtr &root) {
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto node_list = TopoSort(root->get_return());
  for (auto &node : node_list) {
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }

    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
      continue;
    }
    if (IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimUpdateState) ||
        IsPrimitiveCNode(node, prim::kPrimDepend)) {
      continue;
    }
    std::vector<AnfNodePtr> all_inputs = cnode->inputs();
    size_t inputs_size = all_inputs.size();
    for (size_t i = 1; i < inputs_size; ++i) {
      AnfNodePtr input = all_inputs[i];
      if (HasAbstractMonad(input)) {
        continue;
      }
      if (input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) {
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(func_graph);
        auto param_node = input->cast<ParameterPtr>();
        std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(cnode, SizeToLong(i));
        if (IsInTrivialNodeList(cnode) || IsSomePrimitive(cnode, prim::kPrimLoad->name())) {
          auto &node_users = manager->node_users();
          auto iter = node_users.find(node);
          if (iter == node_users.end()) {
            MS_LOG(ERROR) << "Can not find the node that uses the parameter.";
            continue;
          }
          auto &node_set = iter->second;
          const auto node_set_back = node_set.back().first->cast<CNodePtr>();
          if (node_set_back != nullptr && IsSomePrimitive(node_set_back, prim::kPrimMakeTuple->name())) {
            l_RefMap[param_node] = node_set.front();
          } else {
            l_RefMap[param_node] = node_set.back();
          }
        } else {
          l_RefMap[param_node] = node_pair;
        }
      }
    }
  }
}

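// Set the parameter's abstract to the slice shape given by the consuming operator's input tensor layout,
// and slice the parameter data in place when direct_split is enabled.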
static void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res, size_t rank_id) {
  MS_LOG(INFO) << "Begin set parallel shape";
  // check null for the parameter and its shape
  MS_EXCEPTION_IF_NULL(parameter);
  auto param_shape = parameter->Shape();
  MS_EXCEPTION_IF_NULL(param_shape);

  CNodePtr cnode = res.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "node " << cnode->DebugString() << " 's distribute_operator is nullptr";
  }
  if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
    MS_LOG(EXCEPTION) << "The parameter index is not in inputs_tensor_info. index = " << (res.second - 1)
                      << ", inputs_tensor_info size = " << distribute_operator->inputs_tensor_info().size();
  }
  TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
  TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
  Shape slice_shape = tensor_layout.slice_shape().array();

  AbstractBasePtr abstract = parameter->abstract();
  if (abstract == nullptr) {
    MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract is nullptr";
  }

  AbstractBasePtr cloned_abstract = abstract->Clone();
  if (cloned_abstract == nullptr) {
    MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract clone failed";
  }

  cloned_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
  parameter->set_abstract(cloned_abstract);
  ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();

  MS_EXCEPTION_IF_NULL(parameter_ptr);
  MS_LOG(INFO) << "Begin split parameters";
  parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  if (ParallelContext::GetInstance()->direct_split() && parameter_ptr->has_default()) {
    auto layout = parameter_ptr->user_data<TensorLayout>();
    MS_LOG(INFO) << "parameter: " << parameter->ToString() << " " << parameter->Shape()->ToString()
                 << ", default_param: " << parameter_ptr->default_param() << ", layout: " << layout->ToString();
    SliceTensorObj(parameter_ptr, layout, rank_id);
  }
}

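// For every graph parameter recorded in l_RefMap, set its sliced shape and shared-parameter flag,
// then clear the map.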
static void DoParameterSliceShape(const FuncGraphPtr &root, size_t rank_id) {
  MS_EXCEPTION_IF_NULL(root);
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    MS_EXCEPTION_IF_NULL(parameter->Shape());
    auto iter = l_RefMap.find(parameter);
    if (iter != l_RefMap.cend()) {
      MS_LOG(INFO) << "SetParallelShape for parameter: " << parameter->ToString();
      SetParallelShape(parameter, iter->second, rank_id);
      SetSharedParameterFlag(root, parameter);
      continue;
    }
  }
  l_RefMap.clear();
}

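// Parse a strategy value tuple, replace every sharded dimension that does not equal the device number
// with the device number, store the result as a primal attr on the CNode, and return the new StrategyPtr.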
StrategyPtr ExtractAndModifyStrategy(const CNodePtr &cnode, const std::string &attr_name, const ValuePtr &stra) {
  if (stra == nullptr) {
    return nullptr;
  }
  auto var = stra->cast<ValueTuplePtr>();
  if (var == nullptr) {
    return nullptr;
  }

  StrategyPtr strategyPtr;
  int64_t stage_id = g_device_manager->stage_id();
  MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
  int64_t device_num = g_device_manager->DeviceNum();
  MS_LOG(INFO) << "Extract information: device_num " << device_num;
  if (var->size() > 0) {
    std::vector<ValuePtr> elements = var->value();
    Strategies strategy;
    for (uint64_t index = 0; index < elements.size(); ++index) {
      Dimensions dim;
      if (elements[index]->isa<ValueSequence>()) {
        auto value_tuple = elements[index]->cast<ValueTuplePtr>();
        std::vector<ValuePtr> value_vector = value_tuple->value();
        (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
                             [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
        for (size_t i = 0; i < dim.size(); i++) {
          if (dim[i] > 1 && dim[i] != device_num) {
            dim[i] = device_num;
          }
        }
        strategy.push_back(dim);
      } else {
        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
      }
    }
    if (strategy.empty()) {
      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
    }
    cnode->AddPrimalAttr(attr_name, MakeValue(strategy));
    strategyPtr = NewStrategy(stage_id, strategy);
    MS_LOG(INFO) << "Extract information: new strategy " << cnode->GetPrimalAttr(attr_name)->ToString();
  }
  return strategyPtr;
}

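// Resolve the in/out strategy for one operator (from primal attrs, primitive attrs, batch-parallel or
// stand-alone generation) and initialize its OperatorInfo with it.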
static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &prim, const OperatorInfoPtr &op_info) {
  StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
  auto attrs = prim->attrs();

  // load strategy map from checkpoint
  StrategyMap stra_map;

  std::string strategy_key_name = "";
  auto param_names = NodeParameterName(cnode, -1, 0);
  if (!param_names.empty()) {
    strategy_key_name = prim->name() + "_" + param_names[0].first;
  }
  if (!prim->HasAttr(STAND_ALONE)) {
    if ((!StrategyFound(attrs) && !cnode->HasPrimalAttr(IN_STRATEGY)) || prim->HasAttr(BATCH_PARALLEL)) {
      MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
                   << " is empty, using batch parallel";
      in_strategy = GenerateBatchParallelStrategy(op_info, prim);
    } else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
      in_strategy = ExtractAndModifyStrategy(cnode, IN_STRATEGY, cnode->GetPrimalAttr(IN_STRATEGY));
      out_strategy = ExtractAndModifyStrategy(cnode, OUT_STRATEGY, cnode->GetPrimalAttr(OUT_STRATEGY));
    } else if (StrategyFound(attrs)) {
      in_strategy = ExtractAndModifyStrategy(cnode, IN_STRATEGY, attrs[IN_STRATEGY]);
      out_strategy = ExtractAndModifyStrategy(cnode, OUT_STRATEGY, attrs[OUT_STRATEGY]);
    } else {
      in_strategy = stra_map[strategy_key_name];
    }
  } else {
    in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
  }

  MS_EXCEPTION_IF_NULL(in_strategy);
  if (op_info->Init(in_strategy, out_strategy) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed" << trace::DumpSourceLines(cnode);
  }
}

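// Create OperatorInfo for every parallel-care CNode in the graph, mark it as assigned-parallel,
// extract its strategy, and attach the OperatorInfo to the node.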
void ExtractGraphInformation(const std::vector<AnfNodePtr> &all_nodes) {
  MS_LOG(INFO) << "ExtractInformation";
  SetStridedSliceSplitStrategy(all_nodes);
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!CheckExtractInformation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("PadV3")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("StridedSlice")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("Sort")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("Less")) ||
        IsPrimitiveCNode(node, std::make_shared<Primitive>("Range"))) {
      continue;
    }

    SetVirtualDatasetStrategy(cnode);
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);

    OperatorInfoPtr operator_ = CreateOperatorInfo(cnode);
    MS_EXCEPTION_IF_NULL(operator_);
    operator_->set_assigned_parallel(true);

    if (prim->name() == RESHAPE) {
      cnode->set_user_data<OperatorInfo>(operator_);
      continue;
    }

    ExtractStrategyAndInit(cnode, prim, operator_);
    cnode->set_user_data<OperatorInfo>(operator_);
  }
}

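// Splice a replacement subgraph in place of the given node: connect the node's inputs to the
// replacement's virtual inputs and replace the node with the replacement's output.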
static void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(replace_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(replace_graph->second);
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
  }
  mindspore::HashMap<AnfNodePtr, int> input_map = {};
  static int appear_count = 0;
  for (auto &replace_input : replace_graph->first) {
    auto pre_node = node->input(LongToSize(replace_input.second));

    auto it = input_map.find(replace_input.first);
    if (it != input_map.end()) {
      appear_count = 1 + it->second;
    } else {
      appear_count = 1;
    }
    auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
    size_t inputs_size = replace_input_cnode->size();
    while (IntToSize(appear_count) < inputs_size && replace_input_cnode->input(appear_count)->func_graph() != nullptr) {
      ++appear_count;
    }
    if (IntToSize(appear_count) >= inputs_size) {
      MS_LOG(EXCEPTION) << "No replaceable virtual_input_node";
    }
    input_map[replace_input.first] = appear_count;
    replace_input_cnode->set_in_forward_flag(true);
    manager->SetEdge(replace_input.first, appear_count, pre_node);
  }
  // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
  auto replace_output = replace_graph->second->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(replace_output);
  replace_output->set_in_forward_flag(true);
  replace_output->set_primal_attrs(node->primal_attrs());
  (void)manager->Replace(node, replace_output);
}

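// Replace every Gather CNode whose OperatorInfo provides a replacement graph with that graph.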
static void ReplaceGatherOps(const std::vector<AnfNodePtr> &all_nodes, const size_t devices) {
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, prim::kPrimGather->name())) {
        continue;
      }
      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
      MS_EXCEPTION_IF_NULL(distribute_operator);
      auto replace_op = distribute_operator->replace_op();
      // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore.
      auto replace_graph = distribute_operator->replace_graph(cnode);
      if (!replace_op.empty() && replace_graph) {
        MS_LOG(EXCEPTION) << "Only one of replace_op and replace_graph can be used";
      }
      if (replace_graph) {
        MS_LOG(INFO) << "StepReplaceGraph " << cnode->DebugString();
        StepReplaceGraph(replace_graph, cnode);
      }
    }
  }
}

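// When the graph returns a MatMul result directly, rebuild the tail as MatMul -> Reshape/AllGather (or
// AllGather -> Split -> Concat) so that the full, unsharded output is returned on every device.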
static void FixReturnRedistribution(const FuncGraphPtr &root, const size_t devices) {
  MS_LOG(INFO) << "FixReturnRedistribution";
  CNodePtr ret = root->get_return();
  AnfNodePtr expect_matmul = ret->input(1);
  MS_EXCEPTION_IF_NULL(expect_matmul);
  if (!expect_matmul->isa<CNode>()) {
    return;
  }
  auto expect_matmul_node = expect_matmul->cast<CNodePtr>();
  if (!IsSomePrimitive(expect_matmul_node, prim::kPrimMatMul->name())) {
    return;
  }
  Shapes return_input_shapes = GetNodeShape(ret);
  MS_LOG(INFO) << "return_input_shapes size " << return_input_shapes.size();
  if (return_input_shapes.size() == 1) {
    MS_LOG(INFO) << "return_input_shapes: " << return_input_shapes[0][0] << " " << return_input_shapes[0][1];
    GenerateGraph gen_g = GenerateGraph(expect_matmul->cast<CNodePtr>()->attrs());
    if (gen_g.Init(ret) != SUCCESS) {
      MS_LOG(ERROR) << "MatMul->Return GenerateGraph Init failed";
    }

    Attr transpose_a_attr = std::make_pair(TRANSPOSE_A, MakeValue(false));
    Attr transpose_b_attr = std::make_pair(TRANSPOSE_B, MakeValue(true));
    OperatorAttrs matmul_attrs = {transpose_a_attr, transpose_b_attr};
    auto matmul = gen_g.PushBack({gen_g.NewOpInst(prim::kPrimMatMul->name(), matmul_attrs), gen_g.virtual_input_node(),
                                  gen_g.virtual_input_node()});

    if (return_input_shapes[0][0] == 1) {
      auto des_shape = return_input_shapes[0];
      auto des_size = return_input_shapes[0][1];
      auto origin_size = des_size / devices;
      Shape origin_shape;
      origin_shape.push_back(origin_size);
      ConstructOperator constructor;
      constructor.UpdateTensorShape(origin_shape);

      auto reshape = gen_g.PushBack({gen_g.NewOpInst(prim::kPrimReshape->name()), matmul, CreateTuple(origin_shape)});
      auto allgather = gen_g.PushBack({NewAllGatherNode(ALL_GATHER, HCCL_WORLD_GROUP), reshape});
      auto reshape2 = gen_g.PushBack({gen_g.NewOpInst(prim::kPrimReshape->name()), allgather, CreateTuple(des_shape)});
      std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(matmul, 1), std::make_pair(matmul, 2)};
      auto replace_graph = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
        std::make_pair(input_nodes, reshape2));
      MS_LOG(INFO) << "StepReplaceGraph " << expect_matmul->ToString();
      StepReplaceGraph(replace_graph, expect_matmul->cast<CNodePtr>());
      return;
    } else {
      auto allgather = gen_g.PushBack({NewAllGatherNode(ALL_GATHER, HCCL_WORLD_GROUP), matmul});
      // split
      int64_t split_count = devices;
      Attr split_axis_attr = std::make_pair(AXIS, MakeValue(0));
      Attr split_count_attr = std::make_pair(OUTPUT_NUM, MakeValue(split_count));
      OperatorAttrs split_attrs = {split_axis_attr, split_count_attr};
      auto split = gen_g.PushBack({gen_g.NewOpInst(SPLIT, split_attrs), allgather});

      // tuple get item and make tuple
      std::vector<AnfNodePtr> maketuple_inputs;
      maketuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
      for (int64_t i = 0; i < split_count; ++i) {
        auto tuple_get_item = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), split, CreatInt64Imm(i)});
        maketuple_inputs.push_back(tuple_get_item);
      }
      auto maketuple = gen_g.PushBack(maketuple_inputs);

      // concat
      Attr concat_axis_attr = std::make_pair(AXIS, MakeValue(1));
      OperatorAttrs concat_attrs = {concat_axis_attr};
      auto concat = gen_g.PushBack({gen_g.NewOpInst(CONCAT, concat_attrs), maketuple});

      std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(matmul, 1), std::make_pair(matmul, 2)};
      auto replace_graph = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
        std::make_pair(input_nodes, concat));
      MS_LOG(INFO) << "StepReplaceGraph " << expect_matmul->DebugString();
      StepReplaceGraph(replace_graph, expect_matmul->cast<CNodePtr>());
      return;
    }
  }
  return;
}

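// Entry point of the assigned-parallel pass: initialize the parallel context, optionally run the
// recursive-programming strategy search (sapp), extract strategies, insert AllReduce ops, patch the
// Reshape/MakeTuple/Softmax shapes, replace Gather ops, fix the return redistribution and slice the
// parameters.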
bool StepAssignedParallel(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, size_t device_num,
                          size_t rank_id, bool sapp) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  // control whether the model-parallel mode is used; only 1 to 8 devices are supported
  if (device_num == 0 || device_num > 8) {
    MS_LOG(EXCEPTION) << "Error: device_num must be in the range [1, 8], but got " << device_num;
    return false;
  }

  MSLogTime msTime;
  msTime.Start();
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpGraph(root, std::string("step_assigned_parallel_begin"));
  }
#endif
  MS_LOG(INFO) << "Now entering step assigned parallel";
  TOTAL_OPS = 0;
  AnfNodePtr ret = root->get_return();
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

  if (ParallelInit(rank_id, device_num) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed";
  }

  MarkForwardCNode(root);

  if (sapp) {
    CostModelContext::GetInstance()->set_rp_matmul_mem_coef(1);
    if (ParallelStrategyRecSearch(all_nodes, root, rank_id, device_num) != SUCCESS) {
      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
    }
    root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  }

  InitRefMap(root);
  // extract shape and strategy, set operator_info
  ExtractGraphInformation(all_nodes);

  MS_LOG(INFO) << "Now Assigned insert AllReduce ops";

  if (!InsertAllReduceOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Assigned insert AllReduce ops failed.";
  }
  if (!InsertAllReduceOpsForFFN(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Assigned insert AllReduce ops for FFN failed.";
  }
  if (!ModifyReshapeOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Modify Reshape ops failed.";
  }
  if (!ModifyMakeTupleOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Modify MakeTuple ops failed.";
  }
  if (!ModifySoftmaxReshapeOps(all_nodes, root, device_num)) {
    MS_LOG(EXCEPTION) << "Modify Softmax Reshape ops failed.";
  }

  ReplaceGatherOps(all_nodes, device_num);
  FixReturnRedistribution(root, device_num);
  DoParameterSliceShape(root, rank_id);
#ifdef ENABLE_DUMP_IR
  if (context->CanDump(kIntroductory)) {
    DumpGraph(root, std::string("step_assigned_parallel_end"));
  }
#endif

  msTime.End();
  uint64_t time = msTime.GetRunTimeUS();

  MS_LOG(INFO) << "Now leaving step assigned parallel, used time: " << time << " us";

  return true;
}

}  // namespace parallel
}  // namespace mindspore