1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/step_parallel.h"
18
19 #include <inttypes.h>
20 #include <sys/time.h>
21 #include <algorithm>
22
23 #include <map>
24 #include <memory>
25 #include <set>
26 #include <string>
27 #include <unordered_map>
28 #include <utility>
29
30 #include "base/core_ops.h"
31 #include "frontend/operator/ops.h"
32 #include "frontend/optimizer/optimizer.h"
33 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
34 #include "frontend/parallel/context.h"
35 #include "frontend/parallel/device_manager.h"
36 #include "frontend/parallel/dynamic_creator.h"
37 #include "frontend/parallel/graph_util/generate_graph.h"
38 #include "frontend/parallel/graph_util/graph_info.h"
39 #include "frontend/parallel/graph_util/node_info.h"
40 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
41 #include "frontend/parallel/node_check.h"
42 #include "frontend/parallel/parameter_manager.h"
43 #include "frontend/parallel/ops_info/matmul_info.h"
44 #include "ir/param_info.h"
45 #include "ir/tensor.h"
46 #include "utils/trace_base.h"
47 #include "utils/comm_manager.h"
48 #include "utils/ms_context.h"
49 #include "utils/symbolic.h"
50 #include "mindspore/core/utils/parallel_node_check.h"
51 #if ((defined ENABLE_CPU) && (!defined _WIN32))
52 #include "ps/util.h"
53 #include "ps/ps_context.h"
54 #endif
55
using mindspore::tensor::Tensor;

namespace mindspore {
namespace parallel {
// Names of collective-communication primitives; used to recognize comm ops
// when labeling recompute behavior and when scanning the graph.
static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
// Ops that must not be treated as a loss node during loss lookup.
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
// Ops that take no tensor input, so constant-splitting must skip them.
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i)
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
66
SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input,bool accu_flag)67 void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
68 if (new_node_input.empty()) {
69 return;
70 }
71 auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
72 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
73 MS_EXCEPTION_IF_NULL(prim);
74
75 auto attrs = prim->attrs();
76 attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
77 prim->SetAttrs(attrs);
78 }
79
SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> & new_node_input,const CNodePtr & node)80 void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
81 if (new_node_input.empty()) {
82 return;
83 }
84
85 auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
86 auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
87 MS_EXCEPTION_IF_NULL(prim);
88 auto attrs = prim->attrs();
89
90 auto anf_node = node->input(0)->cast<ValueNodePtr>();
91 auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
92 MS_EXCEPTION_IF_NULL(prim_node);
93 auto node_attrs = prim_node->attrs();
94 if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
95 attrs[RECOMPUTE] = MakeValue<bool>(false);
96 prim->SetAttrs(attrs);
97 MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
98 }
99 }
100
CreateInput(const Operator & op,const AnfNodePtr & node,const std::string & instance_name)101 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
102 MS_EXCEPTION_IF_NULL(node);
103 OperatorArgs arg_forward = op.second;
104 ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name);
105 MS_EXCEPTION_IF_NULL(pyop_instance);
106 OperatorParams params = arg_forward.second;
107
108 std::vector<AnfNodePtr> new_node_input = {NewValueNode(pyop_instance), node};
109 if (!params.empty()) {
110 for (auto ¶m : params) {
111 AnfNodePtr val = NewValueNode(param.first.second);
112 MS_EXCEPTION_IF_NULL(val);
113 int64_t position = param.second;
114 (void)new_node_input.insert(new_node_input.begin() + position, val);
115 }
116 }
117
118 // if the op have 'group' attr, set the rank list name for the op
119 SetCommunicationOpGroupLabel(new_node_input);
120 return new_node_input;
121 }
122
// Build the input list for a mirror-type operator (Mirror, MirrorMiniStep,
// MirrorMicroStep, or the MiniStep/MicroStep AllGather) applied to `node`,
// which corresponds to the parameter named `weight_name`.
//
// Under gradient accumulation (grad_accumulation_step > 1) or pipeline
// parallelism (pipeline_stage_split_num > 1), the op additionally takes the
// cloned "accu_grads" parameter as a third input. If no such cloned
// parameter exists, a MirrorMiniStep op degrades to a plain Mirror (dropping
// its last ctor argument), while the other accumulation-aware ops raise.
std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
                                          const std::string &instance_name, const std::string &weight_name) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(root->manager());

  AnfNodePtr grad_accu = nullptr;
  std::string op_name = op.first;
  OperatorArgs arg_forward = op.second;

  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();

  if (grad_accumulation_step > 1 || split_stage_num > 1) {
    // Search the root graph's parameters for the cloned accumulation
    // parameter whose name contains both the weight name and ACCU_GRADS.
    auto parameters = root->parameters();
    bool find_grad_accu_node = false;
    for (auto &param : parameters) {
      if (!ParameterIsCloned(param)) {
        continue;
      }

      auto param_ptr = param->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (param_ptr->name().find(weight_name) != std::string::npos &&
          param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
        find_grad_accu_node = true;
        grad_accu = param;
        MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
        break;
      }
    }

    if (!find_grad_accu_node) {
      if (op_name == MIRROR_MINI_STEP_OPERATOR) {
        // Degrade to a plain mirror: drop the ctor argument that only the
        // mini-step variant takes.
        op_name = MIRROR_OPERATOR;
        arg_forward.first.pop_back();
      } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR ||
                 op_name == MICRO_STEP_ALL_GATHER) {
        MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name;
      }
    }
  }

  ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
  MS_EXCEPTION_IF_NULL(pyop_instance);
  OperatorParams params = arg_forward.second;

  std::vector<AnfNodePtr> new_node_input;
  if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
      op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) {
    // Accumulation-aware ops take the accu_grads parameter as a third input.
    new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
    MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
  } else {
    new_node_input = {NewValueNode(pyop_instance), node};
  }

  if (!params.empty()) {
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      MS_EXCEPTION_IF_NULL(val);
      int64_t position = param.second;
      (void)new_node_input.insert(new_node_input.begin() + position, val);
    }
  }

  // if the op have 'group' attr, set the rank list name for the op
  SetCommunicationOpGroupLabel(new_node_input);
  // gradient accumulation
  if (grad_accumulation_step > 1) {
    SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
  }
  return new_node_input;
}
196
// Insert a new CNode built from `op` between `pre_node` and `node`: the new
// node consumes `pre_node` and becomes input `index` of `node` (via SetEdge).
// When `root` and `param_name` are both given, the inputs are built as a
// mirror-type op via CreateMirrorInput; otherwise via CreateInput.
// The new node inherits `node`'s scope; it is marked as a forward node unless
// the instance name indicates sens-splitting, and marked non-recomputed when
// the instance name carries the NOT_RECOMPUTE tag.
void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
                const FuncGraphPtr &func_graph, const std::string &instance_name, const std::string &param_name = "",
                const FuncGraphPtr &root = nullptr) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  ScopePtr scope = node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input;
  if (root && !param_name.empty()) {
    node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
  } else {
    node_input = CreateInput(op, pre_node, instance_name);
  }
  CNodePtr new_node = func_graph->NewCNode(node_input);
  MS_EXCEPTION_IF_NULL(new_node);
  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
    new_node->set_in_forward_flag(true);  // mark forward flag
  }
  auto new_node_value = node_input[0]->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(new_node_value);
  PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
  new_node_prim->set_instance_name(instance_name);
  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
    new_node_prim->set_attr("recompute", MakeValue(false));
  }
  new_node->set_scope(scope);
  node_input[0]->set_scope(scope);
  // Rewire: node's input `index` now points at the new node.
  manager->SetEdge(node, SizeToInt(index), new_node);
  MS_LOG(INFO) << "Insert " << instance_name << " success";
}
229
// Replace pre_node with pre_node->op
// Build a new CNode applying `op` to `pre_node` and substitute it for
// `pre_node` everywhere (manager->Replace), returning the new node.
// Mirror-input construction, scope propagation, forward-flag and recompute
// labeling follow the same rules as InsertNode above.
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
                            const std::string &instance_name, const std::string &param_name = "",
                            const FuncGraphPtr &root = nullptr) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  ScopePtr scope = pre_node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input;
  if (root && !param_name.empty()) {
    node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
  } else {
    node_input = CreateInput(op, pre_node, instance_name);
  }
  CNodePtr new_node = func_graph->NewCNode(node_input);
  MS_EXCEPTION_IF_NULL(new_node);
  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
    new_node->set_in_forward_flag(true);  // mark forward flag
  }
  auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
  new_node_prim->set_instance_name(instance_name);
  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
    new_node_prim->set_attr("recompute", MakeValue(false));
  }
  new_node->set_scope(scope);
  node_input[0]->set_scope(scope);
  // All users of pre_node now consume the new node instead.
  manager->Replace(pre_node, new_node);
  MS_LOG(INFO) << "Insert " << instance_name << " success";
  return new_node;
}
262
// Insert the forward communication operators `forward_op` after `node`.
// If `node`'s (single) user is a TupleGetItem, the ops are inserted after
// that getter instead, so they apply to the extracted element; more than one
// user in that case is unsupported and raises.
// `forward_op` is taken by value because it is reversed locally so that the
// ops end up applied in their original order after repeated Replace calls.
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // step1:get graph manager distribute_operator
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto uses_set = manager->node_users()[node];
  CNodePtr node_to_insert = node;
  for (auto &uses_pair : uses_set) {
    auto uses_cnode = uses_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(uses_cnode);
    if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
      break;
    }
    PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
    MS_EXCEPTION_IF_NULL(value_node_prim);
    if (value_node_prim->name() == prim::kTupleGetItem) {
      if (uses_set.size() > 1) {
        MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size();
      }
      node_to_insert = uses_cnode;
    }
  }
  MS_EXCEPTION_IF_NULL(node_to_insert);
  std::reverse(forward_op.begin(), forward_op.end());

  // step2:traverse op_list and insert node
  for (size_t index = 0; index < forward_op.size(); ++index) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
    // Forward comm ops inherit the "do not recompute" label if set.
    SetAllReduceRecomputeFlag(forward_input, node_to_insert);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    forward_node->set_scope(scope);
    forward_node->set_in_forward_flag(true);
    forward_input[0]->set_scope(scope);
    (void)manager->Replace(node_to_insert, forward_node);  // using Replace function to insert node
  }
}
306
InsertMakeTuple(const AnfNodePtr & prev,uint64_t num,const FuncGraphPtr & func_graph)307 CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint64_t num, const FuncGraphPtr &func_graph) {
308 MS_EXCEPTION_IF_NULL(prev);
309 MS_EXCEPTION_IF_NULL(func_graph);
310 std::vector<AnfNodePtr> make_tuple_inputs;
311 make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
312 for (uint64_t i = 0; i < num; i++) {
313 std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev,
314 CreatInt64Imm(UlongToLong(i))};
315 auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs);
316 MS_EXCEPTION_IF_NULL(tuple_get_item);
317 make_tuple_inputs.push_back(tuple_get_item);
318 }
319 auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
320 MS_EXCEPTION_IF_NULL(make_tuple);
321 FuncGraphManagerPtr manager = func_graph->manager();
322 MS_EXCEPTION_IF_NULL(manager);
323 (void)manager->Replace(prev, make_tuple);
324 return make_tuple;
325 }
326
// Insert the chain of redistribution operators described by
// `redistribution_oplist_ptr` on input `pos` of `node` (whose producer is
// `pre_node`). The op list's first vector holds the operators, the second
// per-op output info; each op whose output-info flag is set additionally gets
// a MakeTuple wrapper over its outputs.
// An op is labeled NOT_RECOMPUTE when both surrounding primitives carry
// RECOMPUTE_COMM_OP == false and the op is a communication op.
void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
                          const FuncGraphPtr &func_graph, int64_t pos, const CNodePtr &pre_node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(pre_node);
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) {
    MS_LOG(EXCEPTION) << "size of OperatorVector and OutPutInfoVector must be the same!";
  }
  for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) {
    if (pos >= SizeToLong(node->inputs().size())) {
      MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size";
    }
    // Create new node
    // Re-read the input each iteration: the previous insertion changed it.
    AnfNodePtr target_node = node->input(LongToSize(pos));
    MS_EXCEPTION_IF_NULL(target_node);
    // Create instance_name
    auto op = (redistribution_oplist_ptr->first)[index];
    std::string op_name = (redistribution_oplist_ptr->first)[index].first;
    std::string instance_name_base = REDISTRIBUTION_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
    auto prim_out = GetCNodePrimitive(node);
    auto prim_in = GetCNodePrimitive(pre_node);
    if (prim_out != nullptr && prim_in != nullptr) {
      auto prim_out_attr = prim_out->attrs();
      auto prim_in_attr = prim_in->attrs();
      if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end() &&
          !GetValue<bool>(prim_out_attr[RECOMPUTE_COMM_OP]) &&
          prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end() &&
          !GetValue<bool>(prim_in_attr[RECOMPUTE_COMM_OP]) &&
          COMMUNICATION_OPS.find(op_name) != COMMUNICATION_OPS.end()) {
        MS_LOG(INFO) << "The redistribution node would not be recomputed.";
        instance_name = instance_name + "_" + NOT_RECOMPUTE;
      }
    }
    InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name);
    if ((redistribution_oplist_ptr->second)[index].first) {
      target_node = node->input(LongToSize(pos));
      MS_EXCEPTION_IF_NULL(target_node);
      (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph);
    }
  }
}
371
InsertGetTensorSliceOp(const Operator & op,const CNodePtr & node,const FuncGraphPtr & func_graph,int64_t pos,const std::string & instance_name)372 void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int64_t pos,
373 const std::string &instance_name) {
374 if (func_graph == nullptr) {
375 MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name;
376 }
377
378 FuncGraphManagerPtr manager = func_graph->manager();
379 MS_EXCEPTION_IF_NULL(manager);
380 if (pos >= SizeToLong(node->inputs().size())) {
381 MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is "
382 << instance_name;
383 }
384 // Create new node
385 AnfNodePtr pre_node = node->input(LongToSize(pos));
386 MS_EXCEPTION_IF_NULL(pre_node);
387 InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
388 }
389
GetTensorInLayout(const CNodePtr & middle_node,const PrimitivePtr & middle_prim,const OperatorInfoPtr & distribute_operator)390 TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim,
391 const OperatorInfoPtr &distribute_operator) {
392 TensorInfo tensorinfo_in;
393 if (middle_prim->name() == prim::kTupleGetItem) {
394 auto value_node = middle_node->input(2)->cast<ValueNodePtr>();
395 MS_EXCEPTION_IF_NULL(value_node);
396 size_t index_s = LongToSize(GetValue<int64_t>(value_node->value()));
397 if (index_s >= distribute_operator->outputs_tensor_info().size()) {
398 MS_LOG(EXCEPTION) << "The index out of range, index: " << index_s
399 << ", vector size: " << distribute_operator->outputs_tensor_info().size();
400 }
401 tensorinfo_in = distribute_operator->outputs_tensor_info()[index_s];
402 } else {
403 if (distribute_operator->outputs_tensor_info().empty()) {
404 MS_LOG(EXCEPTION) << "The outputs tensor info is empty";
405 }
406 tensorinfo_in = distribute_operator->outputs_tensor_info()[0];
407 }
408 return tensorinfo_in.tensor_layout();
409 }
410
GetPrimName(const CNodePtr & node)411 std::string GetPrimName(const CNodePtr &node) {
412 auto prim = GetCNodePrimitive(node);
413 MS_EXCEPTION_IF_NULL(prim);
414 return prim->name();
415 }
416
GetDistributeOperator(const CNodePtr & node)417 OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
418 MS_EXCEPTION_IF_NULL(node);
419 if (!IsParallelCareNode(node)) {
420 return nullptr;
421 }
422 OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
423 if (distribute_operator == nullptr) {
424 MS_LOG(EXCEPTION) << "Distribute operator is nullptr, the prim is " << GetPrimName(node);
425 }
426 return distribute_operator;
427 }
428
// Compute and insert the redistribution ops needed between `middle_node`
// (whose layout comes from `distribute_operator`) and its user `next_node`
// (node_pair.first), at the user's input slot node_pair.second.
// `tensor_redistribution` is taken by value intentionally: Init() mutates it.
void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const OperatorInfoPtr &distribute_operator,
                    const CNodePtr &middle_node, int64_t index, TensorRedistribution tensor_redistribution,
                    const CNodePtr &pre_node) {
  FuncGraphPtr func_graph = middle_node->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "Redistribution:get graph failed";
  }
  CNodePtr next_node = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(next_node);
  auto middle_value = middle_node->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(middle_value);
  PrimitivePtr middle_prim = middle_value->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(middle_prim);
  OperatorInfoPtr next_distribute_operator = GetDistributeOperator(next_node);
  if (next_distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
  }
  RankList dev_list = distribute_operator->stage_device_list();
  std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
  MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
  MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString();
  // extract tensor layout in and out
  if (distribute_operator->outputs_tensor_info().empty()) {
    MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
    return;
  }

  // index is the 1-based input position on next_node; input tensor infos are
  // 0-based, hence index - 1.
  if (LongToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
    MS_LOG(WARNING) << "The index is out of range, the index is " << (index - 1) << ", the vector size is "
                    << next_distribute_operator->inputs_tensor_info().size() << "next operator name is "
                    << next_distribute_operator->name();
    return;
  }
  TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
  TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
  TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
  if (IsPrimitiveCNode(middle_node, prim::kPrimReceive)) {
    // A pipeline Receive carries its layout directly as user data.
    tensorlayout_in = *(middle_node->user_data<TensorLayout>());
  }
  if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
    MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name;
    MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
                  << next_node->ToString();
    DumpGraph(func_graph, "redistribution_error");
    MS_LOG(EXCEPTION) << "Failure:tensor_redistribution init failed";
  }
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:InferTensorRedistribution failed";
  }
  MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size();
  if (!redistribution_oplist_ptr->first.empty()) {
    // insert node before next node
    InsertRedistribution(redistribution_oplist_ptr, next_node, func_graph, node_pair.second, pre_node);
  }
}
485
StrategyFound(std::unordered_map<std::string,ValuePtr> attrs)486 bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs) {
487 auto iter = attrs.find(STRATEGY);
488 return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
489 }
490
HasStrategy(const FuncGraphPtr & root)491 bool HasStrategy(const FuncGraphPtr &root) {
492 AnfNodePtr ret = root->get_return();
493 MS_EXCEPTION_IF_NULL(ret);
494 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
495
496 for (auto &node : all_nodes) {
497 auto cnode = node->cast<CNodePtr>();
498 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
499 continue;
500 }
501
502 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
503 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
504 auto attrs = prim->attrs();
505 if (StrategyFound(attrs)) {
506 return true;
507 }
508 }
509
510 return false;
511 }
512
IsCommunicationOp(const PrimitivePtr & prim)513 bool IsCommunicationOp(const PrimitivePtr &prim) {
514 MS_EXCEPTION_IF_NULL(prim);
515 return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end());
516 }
517
FindCommunicationOp(const std::vector<AnfNodePtr> & all_nodes)518 bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
519 for (auto &node : all_nodes) {
520 MS_EXCEPTION_IF_NULL(node);
521 if (!node->isa<CNode>()) {
522 continue;
523 }
524 auto cnode = node->cast<CNodePtr>();
525 if (!IsValueNode<Primitive>(cnode->input(0))) {
526 continue;
527 }
528 ValueNodePtr prim_value_node = cnode->input(0)->cast<ValueNodePtr>();
529 MS_EXCEPTION_IF_NULL(prim_value_node);
530 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_value_node);
531 MS_EXCEPTION_IF_NULL(prim);
532
533 if (IsCommunicationOp(prim) && cnode->in_forward_flag()) {
534 MS_EXCEPTION_IF_NULL(prim_value_node->scope());
535 MS_LOG(INFO) << "The graph contain communication op: " << prim->name() << ", scope name is "
536 << prim_value_node->scope()->name();
537 return true;
538 }
539 }
540 return false;
541 }
542
// Walk forward from `node` to every consumer, inserting redistribution ops
// where a parallel-care consumer with operator info is reached; recurses
// through intermediate nodes (non-primitive calls, non-care nodes) so the
// redistribution lands at the true consumer. `insert_node` tracks the node
// after which ops should be placed; it advances past TupleGetItem getters.
void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
                        const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) {
  MS_EXCEPTION_IF_NULL(node->func_graph());
  FuncGraphManagerPtr manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  CNodePtr insert_node_new;

  // Pipeline Send nodes terminate the walk.
  if (IsPrimitiveCNode(node, prim::kPrimSend)) {
    return;
  }
  if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
    MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
    return;
  }
  if (IsValueNode<Primitive>(node->input(0))) {
    auto current_value = node->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(current_value);
    PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(current_prim);
    // After a TupleGetItem, insert relative to the getter itself.
    insert_node_new = ((current_prim->name() == prim::kTupleGetItem) ? node : insert_node);
  } else {
    insert_node_new = insert_node;
  }
  MS_EXCEPTION_IF_NULL(insert_node_new);
  for (auto &node_pair : node_set) {
    CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(use_cnode);
    if (!IsValueNode<Primitive>(use_cnode->input(0))) {
      StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node);
    } else {
      ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(prim_anf_node);
      PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
      MS_EXCEPTION_IF_NULL(node_prim);
      // Skip Depend's control inputs (only input 1 carries data) and
      // UpdateState users entirely.
      if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == UPDATESTATE) {
        continue;
      }
      if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
        Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
                       pre_node);
      } else {
        StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node);
      }
    }
  }
}
590
// Split a full tensor `node` consumed by `next_node` at input slot `index`
// into the local slice prescribed by the consumer's input tensor layout, by
// inserting a _GetTensorSlice op (plus any sub-operator slices the operator
// info declares). Scalars and shape-[1] tensors are left untouched.
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int64_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(next_node);
  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  // If the shape of tensor is [] or [1], no need to split it.
  Shapes shapes = GetNodeShape(node);
  if (shapes.size() != 1) {
    MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name()
                      << ": GetNodeShape for tensor_node, output size is not 1";
  }
  Shape shape = shapes[0];
  std::string shape_str = ShapeToString(shape);
  if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) {
    MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str
                 << ", no need to split it.";
    return;
  }

  MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str;

  // extract tensor layout
  // index is the 1-based input position; input tensor infos are 0-based.
  if (LongToSize(index - 1) >= op_info->inputs_tensor_info().size()) {
    MS_LOG(EXCEPTION) << "The index is out of range, index is " << (index - 1) << ", vector size is "
                      << op_info->inputs_tensor_info().size();
  }
  TensorInfo tensor_info = op_info->inputs_tensor_info()[LongToSize(index - 1)];
  TensorLayout tensor_layout = tensor_info.tensor_layout();

  // Use _GetTensorSlice operator to split the tensor
  FuncGraphPtr func_graph = next_node->func_graph();  // only cnode can get the graph
  MS_EXCEPTION_IF_NULL(func_graph);
  Operator op = CreateGetTensorSliceOp(tensor_layout);
  InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
  if (!op_info->sub_ops().empty()) {
    auto sub_ops = op_info->sub_ops();
    for (size_t i = 0; i < sub_ops.size(); i++) {
      if (!sub_ops.at(i).empty()) {
        InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
      }
    }
  }
}
635
SplitTensorList(const AnfNodePtr & node,const CNodePtr & next_node,int index)636 void SplitTensorList(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
637 MS_EXCEPTION_IF_NULL(node);
638 MS_EXCEPTION_IF_NULL(next_node);
639 if (next_node->inputs().size() != 2 || index != 1) {
640 MS_LOG(INFO) << next_node->fullname_with_scope() << " Inputs must have only one input, get "
641 << (next_node->inputs().size() - 1) << " index should be 1, get " << index;
642 return;
643 }
644 OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
645 MS_EXCEPTION_IF_NULL(op_info);
646
647 std::vector<ValuePtr> inputs_values;
648 if (IsValueNode<ValueList>(node)) {
649 inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
650 } else {
651 inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
652 }
653 if (inputs_values.size() != op_info->inputs_tensor_info().size()) {
654 MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << ", is not equal to inputs shape size "
655 << op_info->inputs_tensor_info().size();
656 }
657 std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
658 FuncGraphPtr func_graph = next_node->func_graph();
659 MS_EXCEPTION_IF_NULL(func_graph);
660 FuncGraphManagerPtr manager = func_graph->manager();
661 MS_EXCEPTION_IF_NULL(manager);
662 ScopePtr scope = next_node->scope();
663 MS_EXCEPTION_IF_NULL(scope);
664 for (size_t i = 0; i < inputs_values.size(); ++i) {
665 auto value_ptr = inputs_values[i];
666 auto tensor = value_ptr->cast<tensor::TensorPtr>();
667 MS_EXCEPTION_IF_NULL(tensor);
668 TensorInfo tensor_info = op_info->inputs_tensor_info()[i];
669 TensorLayout tensor_layout = tensor_info.tensor_layout();
670 auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
671 Operator op = CreateGetTensorSliceOp(tensor_layout);
672 std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
673 CNodePtr new_node = func_graph->NewCNode(node_input);
674 new_node->set_in_forward_flag(true);
675 auto new_node_value = node_input[0]->cast<ValueNodePtr>();
676 MS_EXCEPTION_IF_NULL(new_node_value);
677 PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
678 new_node_prim->set_instance_name(SPLIT_TENSOR);
679 new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
680 new_node->set_scope(scope);
681 node_input[0]->set_scope(scope);
682 make_tuple_inputs.push_back(new_node);
683 }
684 CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
685 manager->Replace(node, make_tuple);
686 }
687
StepSplitTensor(const AnfNodePtr & node,const FuncGraphManagerPtr & manager)688 void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
689 MS_EXCEPTION_IF_NULL(node);
690 MS_EXCEPTION_IF_NULL(manager);
691 AnfNodeIndexSet node_set = manager->node_users()[node];
692 for (auto &node_pair : node_set) {
693 CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
694 if (use_cnode == nullptr || !IsValueNode<Primitive>(use_cnode->input(0))) {
695 continue;
696 }
697 ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
698 MS_EXCEPTION_IF_NULL(prim_anf_node);
699 PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
700 MS_EXCEPTION_IF_NULL(use_cnode_prim);
701 if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
702 NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
703 continue;
704 }
705 if (IsParallelCareNode(use_cnode)) {
706 if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
707 SplitTensorList(node, use_cnode, node_pair.second);
708 } else {
709 SplitTensor(node, use_cnode, node_pair.second);
710 }
711 }
712 }
713 }
714
// Replace `node` with the operator sequence `replace_op` produced by its
// OperatorInfo (e.g. a Reshape redistribution chain), preserving scope and
// propagating the operator info / primal attributes onto the final op.
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
  // step1:get graph manager distribute_operator
  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
  }
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
  }
  // step2:traverse op_list and insert node
  // NOTE: both lists are traversed in reverse; together with the Replace calls
  // below this builds the replacement chain in the originally listed order.
  std::reverse(replace_op.begin(), replace_op.end());
  auto replace_op_info = distribute_operator->replace_op_info();
  std::reverse(replace_op_info.begin(), replace_op_info.end());
  if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) {
    MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!";
  }
  bool replace_op_info_flag = !replace_op_info.empty();
  for (size_t index = 0; index < replace_op.size(); ++index) {
    std::string instance_name = CreateInstanceName(node, index);
    std::vector<AnfNodePtr> replace_input;
    if (index != replace_op.size() - 1) {
      replace_input = CreateInput(replace_op[index], node, instance_name);
    } else {
      // The op at the last index takes over the original node's own inputs.
      replace_input = ReplaceOpInput(replace_op[index], instance_name, node);
    }
    CNodePtr replace_node = func_graph->NewCNode(replace_input);
    MS_EXCEPTION_IF_NULL(replace_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    replace_node->set_scope(scope);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
    PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
    SetUserAttrs(origin_prim->attrs(), prim);
    auto origin_prim_attrs = origin_prim->attrs();
    // Honor the user's "do not recompute communication op" request on any
    // communication op produced by the replacement.
    if (origin_prim_attrs.find(RECOMPUTE_COMM_OP) != origin_prim_attrs.end() &&
        !GetValue<bool>(origin_prim_attrs[RECOMPUTE_COMM_OP]) &&
        COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()) {
      MS_LOG(INFO) << "The redistribution node in reshape would not be recomputed.";
      prim->set_attr("recompute", MakeValue(false));
    }
    if (index == replace_op.size() - 1) {
      // Only the op that inherits the original inputs also inherits the
      // operator info and primal attributes.
      replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
      replace_node->set_primal_attrs(node->primal_attrs());
    }
    replace_node->set_in_forward_flag(true);
    replace_input[0]->set_scope(scope);
    if (replace_op_info_flag && replace_op_info[index].first) {
      // This op's output must be wrapped into a MakeTuple of the given arity.
      auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph);
      new_cnode->set_primal_attrs(node->primal_attrs());
      (void)manager->Replace(node, new_cnode);  // using Replace function to insert node
    } else {
      (void)manager->Replace(node, replace_node);  // using Replace function to insert node
    }
  }
  MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name();
}
774
StepReplaceGraph(const ReplaceGraphPtr & replace_graph,const CNodePtr & node)775 void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) {
776 MS_EXCEPTION_IF_NULL(replace_graph);
777 MS_EXCEPTION_IF_NULL(node);
778 MS_EXCEPTION_IF_NULL(replace_graph->second);
779 FuncGraphPtr func_graph = node->func_graph();
780 MS_EXCEPTION_IF_NULL(func_graph);
781 FuncGraphManagerPtr manager = func_graph->manager();
782 if (manager == nullptr) {
783 MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
784 }
785 // Solve the input order
786 // For example input_node:{segment_sum:1, segment_sum:2, gahter:2}
787 // The Original code here will bind the all operations to the first inputs of these operatos
788 // However, the segment_sum operation needs two inputs, To solve this
789 // We maintain a dict to count the times of the same operations,
790 // and bind the inputs according to the times of the op appears.
791 std::unordered_map<AnfNodePtr, int> input_map = {};
792 static int appear_count = 0;
793 for (auto &replace_input : replace_graph->first) {
794 auto pre_node = node->input(LongToSize(replace_input.second));
795
796 auto it = input_map.find(replace_input.first);
797 if (it != input_map.end()) {
798 appear_count = 1 + it->second;
799 } else {
800 appear_count = 1;
801 }
802 input_map[replace_input.first] = appear_count;
803 manager->SetEdge(replace_input.first, appear_count, pre_node);
804 }
805 // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
806 auto replace_output = replace_graph->second->cast<CNodePtr>();
807 MS_EXCEPTION_IF_NULL(replace_output);
808 replace_output->set_primal_attrs(node->primal_attrs());
809 (void)manager->Replace(node, replace_output);
810 }
811
GetTupleGetItemIndex(const CNodePtr & cnode)812 int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
813 MS_EXCEPTION_IF_NULL(cnode);
814 if (cnode->inputs().size() != 3) {
815 MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3";
816 }
817
818 if (!cnode->input(2)->isa<ValueNode>()) {
819 MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node";
820 }
821
822 ValuePtr tuple_index_value = GetValueNode(cnode->input(2));
823 MS_EXCEPTION_IF_NULL(tuple_index_value);
824 if (!tuple_index_value->isa<Int64Imm>()) {
825 MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32";
826 }
827 return tuple_index_value->cast<Int64ImmPtr>()->value();
828 }
829
InsertVirtualDivOp(const VirtualDivOp & virtual_div_op,const CNodePtr & node)830 void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
831 MS_EXCEPTION_IF_NULL(node);
832 size_t node_size = node->inputs().size();
833 FuncGraphPtr func_graph = node->func_graph();
834 MS_EXCEPTION_IF_NULL(func_graph);
835 FuncGraphManagerPtr manager = func_graph->manager();
836 MS_EXCEPTION_IF_NULL(manager);
837
838 if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
839 MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
840 node_size = 2;
841 }
842
843 for (size_t index = 1; index < node_size; ++index) {
844 AnfNodePtr input = node->input(index);
845 MS_EXCEPTION_IF_NULL(input);
846 // if it is not a tensor, continue
847 if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
848 MS_LOG(INFO) << "insert div op: the index " << index << " is not tensor, skip";
849 continue;
850 }
851
852 for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) {
853 std::string instance_name = CreateInstanceName(node, pos);
854 InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name);
855 }
856 MS_LOG(INFO) << "insert div op for input index " << index << " of node";
857 }
858 }
859
// Insert a VirtualOutput op in front of each last forward node found in the
// graph (eval/predict graphs have no loss, so outputs get an explicit
// parallel boundary op), copying the wrapped node's abstract/shape onto it.
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  vector<std::string> last_forward_node_ids;
  vector<size_t> last_indexs;
  FindLastNodesUniqueId(root, &last_forward_node_ids, &last_indexs);
  MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  for (auto &node : all_nodes) {
    // here insert virtualoutput node
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
    if (last_node_iter == last_forward_node_ids.end()) {
      continue;
    }
    // The same node id may appear several times in the last-node list (one
    // entry per recorded input index), so scan the whole list.
    for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
      if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
        continue;
      }
      MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
                   << cnode->input(last_indexs[last_node_index])->fullname_with_scope();
      if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
        // Move past the TupleGetItem: hook the VirtualOutput on its first user instead.
        FuncGraphManagerPtr manager = cnode->func_graph()->manager();
        MS_EXCEPTION_IF_NULL(manager);
        auto node_pair = manager->node_users()[cnode].front();
        if (!node_pair.first->isa<CNode>()) {
          MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
        }
        cnode = node_pair.first->cast<CNodePtr>();
        last_indexs[last_node_index] = IntToSize(node_pair.second);
      }
      auto pre_node = cnode->input(last_indexs[last_node_index]);
      Shapes shape_outputs = GetNodeShape(pre_node);
      // Empty shape (scalar) outputs are left alone.
      if (shape_outputs[0].empty()) {
        continue;
      }
      FuncGraphPtr func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      OperatorParams params;
      OperatorAttrs attrs;
      OperatorArgs args = std::make_pair(attrs, params);
      Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
      InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
      // Give the freshly inserted VirtualOutput the same abstract and shape
      // as the node it wraps.
      auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
      AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
      std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
      virtual_output_abstract->set_shape(virtual_output_shape);
      virtual_output_node->set_abstract(virtual_output_abstract);
    }
  }
}
911
FindParameterByValueNode(const AnfNodePtr & node,const FuncGraphPtr & func_graph)912 static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
913 if (IsValueNode<RefKey>(node)) {
914 std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
915 if (param_v.size() != 1) {
916 MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
917 << param_v.size();
918 }
919 auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
920 if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
921 return std::make_pair(nullptr, true);
922 }
923 return std::make_pair(node, true);
924 }
925 return std::make_pair(nullptr, false);
926 }
927
FindParameterByParameter(const AnfNodePtr & node)928 static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node) {
929 auto param_ptr = node->user_data<parallel::TensorLayout>();
930 if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
931 return std::make_pair(nullptr, false);
932 }
933 return std::make_pair(node, false);
934 }
935
// Only used for InsertMirrorOps
// Recursively search upwards from `node` for the parameter feeding it.
// Returns {parameter, is_ref_key}; {nullptr, ...} means no parameter was
// found (or the found parameter needs no mirror, e.g. fully opt-sharded).
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph);
  }

  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    // Not a primitive call (e.g. a func-graph call): search every input.
    for (size_t index = 0; index < cnode->inputs().size(); ++index) {
      auto res = FindParameter(cnode->input(index), func_graph);
      if (!res.first) {
        continue;
      }
      return res;
    }
  }

  // When not fully use opt shard, allgather and mirror would be both inserted.
  // Skip allgather here and find parameter recursively.
  if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
    return std::make_pair(nullptr, false);
  }

  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(prim);
    // For Depend/Load/AllGather only input 1 (the data edge) can lead to the parameter.
    if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
      continue;
    }
    auto res = FindParameter(cnode->input(index), func_graph);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}
984
985 // only used for FindCNode
SkipTrivialNodesMoveDown(const FuncGraphManagerPtr & manager,CNodePtr node)986 CNodePtr SkipTrivialNodesMoveDown(const FuncGraphManagerPtr &manager, CNodePtr node) {
987 MS_EXCEPTION_IF_NULL(node);
988 while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) {
989 node = manager->node_users()[node].begin()->first->cast<CNodePtr>();
990 }
991 return node;
992 }
993
// Search the users of `anode` for a CNode whose primitive is `name` and which
// takes `anode` as its first input. Returns {found_in_same_graph, cnode};
// with the parallel optimizer enabled the search looks past trivial nodes and
// inserted AllGathers (recursing with max_depth as the guard).
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph,
                                    size_t max_depth) {
  MS_EXCEPTION_IF_NULL(anode);
  MS_EXCEPTION_IF_NULL(anode->func_graph());
  FuncGraphManagerPtr manager = anode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
  }
  AnfNodeIndexSet node_set = manager->node_users()[anode];
  bool result = false;
  CNodePtr cnode_return = nullptr;
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
      // Trivial nodes (e.g. Load) may sit between the parameter and the real user.
      use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    // Only a match on the first input (index 1) in the same graph counts.
    if (node_prim->name() == name && node_pair.second == 1) {
      if (use_apply->func_graph() == func_graph) {
        result = true;
        cnode_return = use_apply;
        MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph";
        continue;
      }
      MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph";
    }
    if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(use_apply)) {
      // An AllGather inserted by the parallel optimizer may shadow the target; look past it.
      return FindCNode(node_pair.first, name, func_graph, max_depth + 1);
    }
  }
  return std::make_pair(result, cnode_return);
}
1033
// Decide whether the mirror op must be inserted before the Cast feeding
// node->input(index) rather than after it, so gradients are synchronized in
// float32 precision.
bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
  // only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true
  if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
    return false;
  }
  auto pre_node = node->input(index);
  MS_EXCEPTION_IF_NULL(pre_node);
  auto cnode = pre_node->cast<CNodePtr>();
  if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  // With the parallel optimizer an AllGather may sit between node and the Cast.
  if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(cnode)) {
    pre_node = cnode->input(1);
  }
  if (!IsPrimitiveCNode(pre_node, prim::kPrimCast)) {
    return false;
  }
  auto node_type = pre_node->Type();
  MS_EXCEPTION_IF_NULL(node_type);
  if (!node_type->isa<mindspore::TensorType>()) {
    MS_LOG(EXCEPTION) << "Unknown type.";
  }
  auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(input_element_type);
  auto type_id = input_element_type->type_id();

  // Cast output already float32 -> mirror can stay after the Cast.
  return (type_id != kNumberTypeFloat32);
}
1062
CheckInsertMirrorOps(const MirrorOps & mirror_ops,const CNodePtr & node,size_t node_size)1063 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
1064 if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1065 return true;
1066 }
1067 if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
1068 MS_LOG(INFO) << "Input is ValueList, skip it.";
1069 return false;
1070 }
1071
1072 if ((node->inputs().size() == 2) &&
1073 (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
1074 MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
1075 return false;
1076 }
1077
1078 if (mirror_ops.size() != node_size - 1) {
1079 MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
1080 << (node_size - 1);
1081 }
1082 return true;
1083 }
1084
1085 // only used for InsertMirrorOps
SkipTrivialNodesMoveUp(CNodePtr node)1086 CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) {
1087 MS_EXCEPTION_IF_NULL(node);
1088 while (!IsSomePrimitive(node, LOAD)) {
1089 if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
1090 node = node->input(1)->cast<CNodePtr>();
1091 }
1092 }
1093 auto prev_node = node->input(1)->cast<CNodePtr>();
1094 if (prev_node != nullptr) {
1095 if (IsSomePrimitive(prev_node, DEPEND)) {
1096 auto prev_prev_node = prev_node->input(1)->cast<CNodePtr>();
1097 if (IsSomePrimitive(node, LOAD)) {
1098 node = prev_prev_node;
1099 MS_LOG(INFO) << "Moving to the Load node before Depend node.";
1100 }
1101 }
1102 }
1103 return node;
1104 }
1105
MirrorOpName()1106 std::string MirrorOpName() {
1107 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1108 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1109 std::string mirror_op_name;
1110 if (grad_accumulation_step > 1) {
1111 mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
1112 } else if (split_stage_num > 1) {
1113 mirror_op_name = MIRROR_MICRO_STEP_OPERATOR;
1114 } else {
1115 mirror_op_name = MIRROR_OPERATOR;
1116 }
1117 return mirror_op_name;
1118 }
1119
// Insert gradient-mirror communication ops for every parameter input of
// `node`. Existing mirrors in the same graph are shared; fully opt-sharded
// parameters are skipped; the mirror may be placed before a Cast (fp32 sync)
// or directly before the node.
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->inputs().size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Monad inputs carry no data; exclude them from the mirror count.
  for (auto input : node->inputs()) {
    if (HasAbstractMonad(input)) {
      node_size--;
    }
  }

  if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) {
    return;
  }

  for (size_t index = 1; index < node_size; ++index) {
    OperatorVector backward_op = mirror_ops[index - 1];
    if (IsPrimitiveCNode(node, prim::kPrimSend)) {
      // Send nodes record which parameter they carry via a primal attribute.
      auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
      backward_op = mirror_ops[IntToSize(param_index)];
    }
    if (backward_op.empty()) {
      continue;
    }
    std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(node->input(index), func_graph);
    if (!param_node_pair.first) {
      continue;
    }

    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name;
    bool is_shared_param = false;
    if (param_ptr) {
      param_name = param_ptr->name();
      // Parameters that take no gradient need no mirror.
      if (!param_ptr->param_info() || !param_ptr->param_info()->requires_grad()) {
        MS_LOG(INFO) << param_name << " do not need gradient. Skip inserting mirror.";
        continue;
      }
      std::string opt_shard_mirror_group;
      if (param_ptr->user_data<TensorLayout>()) {
        opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
        is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
      }
      if (!opt_shard_mirror_group.empty()) {
        // mirror ops is covered in not fully use opt shard case
        backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
      }
    }
    // not a RefKey
    std::string mirror_op_name = MirrorOpName();
    AnfNodePtr pre_node = node->input(index);
    if (!param_node_pair.second) {
      auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph, 0);
      // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead
      if (next_cnode.first) {
        MS_EXCEPTION_IF_NULL(next_cnode.second);
        // assume Load is inserted next to parameter
        // skip Load moving up and insert mirror next to the parameter
        if (pre_node->cast<CNodePtr>()) {
          CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>());
          manager->SetEdge(load_node, 1, next_cnode.second);
        } else {
          manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
        }
        MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
                     << " and share the mirror.";
        continue;
      }
    }
    // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp
    // only one MirrorOp in backward_op
    if (backward_op.size() != 1) {
      MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
    }
    auto op = backward_op[0];
    if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) {
      // assume Load is inserted next to parameter
      // skip Load moving up and insert mirror next to the parameter
      CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>());
      InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
      auto comm_op = load_node->input(1)->cast<CNodePtr>();
      // add fusion flag
      AddCommOpFusionType(comm_op, param_node_pair.first);
      MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
                   << " and insert mirror before Load";
      AddCommOpParamFlag(comm_op);
      continue;
    }
    InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
    MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
                 << " and insert mirror before the node";
    auto comm_op = node->input(index)->cast<CNodePtr>();
    // add fusion flag
    // pipeline mirror would not be set, which should be supported later
    AddCommOpFusionType(comm_op, param_node_pair.first);
    AddCommOpParamFlag(comm_op);
  }
}
1220
BackwardCommunication(const FuncGraphPtr & root,const OperatorInfoPtr & distribute_operator,const CNodePtr & node,const std::vector<std::pair<CNodePtr,LossNodeInfo>> & sens_loss_pairs)1221 void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
1222 const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
1223 MS_EXCEPTION_IF_NULL(distribute_operator);
1224 MS_EXCEPTION_IF_NULL(node);
1225
1226 if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
1227 return;
1228 }
1229 bool is_loss_cnode =
1230 std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
1231 [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
1232
1233 MirrorOps mirror_ops = distribute_operator->mirror_ops();
1234 VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
1235 // insert mirror op
1236 if (!mirror_ops.empty()) {
1237 MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
1238 InsertMirrorOps(root, mirror_ops, node);
1239 }
1240 // insert virtual div op
1241 if (!virtual_div_op.empty() && is_loss_cnode && IsLastStage()) {
1242 MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
1243 InsertVirtualDivOp(virtual_div_op, node);
1244 }
1245 }
1246
// Map a primitive name to its distributed OperatorInfo class name:
// strip one leading underscore (if any) and append the "Info" suffix.
std::string GetDisOpName(const std::string &prim_name) {
  const bool has_leading_underscore = !prim_name.empty() && prim_name.front() == '_';
  const std::string base = has_leading_underscore ? prim_name.substr(1) : prim_name;
  return base + "Info";
}
1254
// Create the distributed OperatorInfo registered under the name derived from
// `name` (e.g. "MatMul" -> "MatMulInfo"), with the given attrs and
// input/output shape lists. Returns nullptr when the shape list is malformed
// or no matching info class is registered.
OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs,
                                       const std::vector<Shapes> &shape_list) {
  // shape_list must be {input shapes, output shapes}.
  if (shape_list.size() != 2) {
    MS_LOG(ERROR) << "The size of shape list is not 2";
    return nullptr;
  }
  if (name.length() == 0) {
    MS_LOG(EXCEPTION) << "Length of name is zero!";
  }
  std::string distribute_opname = GetDisOpName(name);
  if (name == GATHERV2) {
    // GatherV2 defaults to the "...PInfo" implementation unless the user
    // explicitly requested data parallel via the data_parallel attribute.
    distribute_opname = name + "PInfo";
    auto data_parallel_iter = attrs.find(DATA_PARALLEL);
    if (data_parallel_iter != attrs.end()) {
      MS_EXCEPTION_IF_NULL(data_parallel_iter->second);
      if (!data_parallel_iter->second->isa<BoolImm>()) {
        MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool.";
      }
      bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value();
      if (data_parallel) {
        distribute_opname = name + "Info";
      }
    }
  }
  OperatorInfoPtr operator_ =
    (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
  if (operator_ == nullptr) {
    MS_LOG(INFO) << "Create " << name << " failed";
    return nullptr;
  }
  // Append the global running counter so every created operator gets a unique name.
  std::string origin_name = operator_->name();
  operator_->set_name(origin_name + std::to_string(TOTAL_OPS));
  MS_LOG(INFO) << "Successfully created operator " << origin_name;
  ++TOTAL_OPS;
  return operator_;
}
1291
OperatorInstance(const PrimitivePtr & prim,const PrimitiveAttrs & attrs,const std::vector<Shapes> & shape_list)1292 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
1293 const std::vector<Shapes> &shape_list) {
1294 MS_EXCEPTION_IF_NULL(prim);
1295 OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
1296 if (operator_ == nullptr) {
1297 if (IsInBatchParallelBlackList(prim)) {
1298 MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
1299 }
1300 MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
1301 operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
1302 MS_EXCEPTION_IF_NULL(operator_);
1303 }
1304 return operator_;
1305 }
1306
NewOperatorInstance(const PrimitivePtr & prim,const PrimitiveAttrs & attrs,std::vector<Shapes> shape_list)1307 OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
1308 std::vector<Shapes> shape_list) {
1309 OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
1310 for (size_t i = 0; i < shape_list[0].size(); ++i) {
1311 MS_LOG(INFO) << "No: " << i << " input's shape: " << ShapeToString(shape_list[0][i]);
1312 }
1313 return operator_;
1314 }
1315
ExtractStrategy(const ValuePtr & stra)1316 StrategyPtr ExtractStrategy(const ValuePtr &stra) {
1317 ValueTuplePtr var = stra->cast<ValueTuplePtr>();
1318 StrategyPtr strategyPtr;
1319 int64_t stage_id = g_device_manager->stage_id();
1320
1321 MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
1322 if (var == nullptr) {
1323 MS_LOG(EXCEPTION) << "Strategy value is nullptr";
1324 }
1325 if (var->size() > 0) {
1326 std::vector<ValuePtr> elements = var->value();
1327 Strategys strategy;
1328 for (uint64_t index = 0; index < elements.size(); ++index) {
1329 Dimensions dim;
1330 if (elements[index]->isa<ValueSequeue>()) {
1331 ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
1332 std::vector<ValuePtr> value_vector = value_tuple->value();
1333 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
1334 [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
1335 strategy.push_back(dim);
1336 } else {
1337 MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
1338 }
1339 }
1340 if (strategy.empty()) {
1341 MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
1342 }
1343 strategyPtr = NewStrategy(stage_id, strategy);
1344 }
1345
1346 return strategyPtr;
1347 }
1348
GetRefKeyNodeShape(const AnfNodePtr & node,const FuncGraphPtr & func_graph)1349 Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
1350 MS_EXCEPTION_IF_NULL(node);
1351 MS_EXCEPTION_IF_NULL(func_graph);
1352
1353 std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(node, func_graph);
1354 if (parameters.size() != 1) {
1355 MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
1356 }
1357
1358 Shapes input_shapes;
1359 input_shapes = GetNodeShape(parameters[0]);
1360 if (input_shapes.size() != 1) {
1361 MS_LOG(EXCEPTION) << "Get input shape failed";
1362 }
1363
1364 MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]);
1365 return input_shapes;
1366 }
1367
// Collect {input shapes, output shapes} for `node`, resolving RefKey inputs
// through their parameters (and recording them in g_RefMap), and skipping
// monad inputs and non-tensor inputs.
std::vector<Shapes> ExtractShape(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shape_inputs, shape_outputs;
  std::vector<Shapes> shape_all;
  std::vector<AnfNodePtr> all_inputs = node->inputs();

  size_t inputs_size = all_inputs.size();
  for (size_t i = 1; i < inputs_size; ++i) {
    Shapes input_shapes;
    AnfNodePtr input = all_inputs[i];
    if (HasAbstractMonad(input)) {
      continue;
    }
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      // Remember which node/input consumed this parameter via RefKey.
      std::pair<AnfNodePtr, int64_t> node_pair = std::make_pair(node, SizeToLong(i));
      g_RefMap[parameters[0]] = node_pair;
      input_shapes = GetRefKeyNodeShape(input, func_graph);
    } else if (input->isa<CNode>() || IsValueNode<Tensor>(input) || input->isa<Parameter>() ||
               ((IsValueNode<ValueList>(input) || IsValueNode<ValueTuple>(input)) && (inputs_size == 2))) {
      input_shapes = GetNodeShape(input);
    } else {
      // Non-tensor inputs (scalars, strings, ...) contribute no shape.
      continue;
    }
    if (input_shapes.size() != 1) {
      // A multi-shape single input (e.g. the tuple input of Concat) becomes
      // the whole input shape list.
      if (inputs_size == 2) {  // like concat
        shape_inputs = input_shapes;
        break;
      } else {
        MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
      }
    }
    shape_inputs.push_back(input_shapes[0]);
  }
  shape_all.push_back(shape_inputs);
  // extract out shape
  shape_outputs = GetNodeShape(node);
  shape_all.push_back(shape_outputs);
  return shape_all;
}
1413
FindParallelCareNode(const AnfNodePtr & node,int32_t recursion_num)1414 std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) {
1415 if (recursion_num >= RECURSION_LIMIT) {
1416 return std::make_pair(nullptr, 0);
1417 }
1418
1419 MS_EXCEPTION_IF_NULL(node);
1420 FuncGraphPtr func_graph = node->func_graph();
1421 MS_EXCEPTION_IF_NULL(func_graph);
1422 FuncGraphManagerPtr manager = func_graph->manager();
1423 MS_EXCEPTION_IF_NULL(manager);
1424 AnfNodeIndexSet node_set = manager->node_users()[node];
1425 for (auto &node_pair : node_set) {
1426 CNodePtr cnode = node_pair.first->cast<CNodePtr>();
1427 MS_EXCEPTION_IF_NULL(cnode);
1428 if (!IsValueNode<Primitive>(cnode->input(0))) {
1429 continue;
1430 }
1431 ValueNodePtr prim_node_anf = cnode->input(0)->cast<ValueNodePtr>();
1432 MS_EXCEPTION_IF_NULL(prim_node_anf);
1433 PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
1434 MS_EXCEPTION_IF_NULL(node_prim);
1435 if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) ||
1436 IsPrimitiveCNode(cnode, prim::kPrimSend)) {
1437 continue;
1438 }
1439 if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
1440 return node_pair;
1441 } else {
1442 auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1);
1443 if (tmp_pair.first != nullptr) {
1444 return tmp_pair;
1445 }
1446 }
1447 }
1448 return std::make_pair(nullptr, 0);
1449 }
1450
FindSubGraph(const FuncGraphPtr & graph,const AnfNodePtr & parameter)1451 std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) {
1452 MS_EXCEPTION_IF_NULL(graph);
1453 MS_EXCEPTION_IF_NULL(parameter);
1454 FuncGraphManagerPtr manager = graph->manager();
1455 MS_EXCEPTION_IF_NULL(manager);
1456 std::pair<AnfNodePtr, int64_t> prim_anf_node_pair = FindParallelCareNode(parameter, 0);
1457 if (prim_anf_node_pair.first != nullptr) {
1458 return prim_anf_node_pair;
1459 } else {
1460 AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
1461 for (auto ¶m_pair : param_sub_set) {
1462 CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
1463 AnfNodePtr graph_value_node;
1464 if (param_cnode->input(0)->isa<CNode>()) {
1465 graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
1466 } else {
1467 graph_value_node = param_cnode->input(0);
1468 }
1469 if (!IsValueNode<FuncGraph>(graph_value_node)) {
1470 continue;
1471 }
1472 FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
1473 auto parameters = graph_sub->parameters();
1474 if (LongToSize(param_pair.second - 1) >= parameters.size()) {
1475 MS_LOG(EXCEPTION) << "The index is out of range, index is: " << (param_pair.second - 1) << ", vector size is "
1476 << parameters.size();
1477 }
1478 std::pair<AnfNodePtr, int64_t> res = FindSubGraph(graph_sub, parameters[LongToSize(param_pair.second - 1)]);
1479 if (res.first != nullptr) {
1480 return res;
1481 }
1482 }
1483 }
1484 return std::make_pair(nullptr, 0);
1485 }
1486
// If the (possibly Load-wrapped) user `cnode` is a Cast whose output type is
// no longer fp32 (i.e. a downcast such as fp32 -> fp16), return that Cast so
// the AllGather can be inserted after it; otherwise return nullptr.
CNodePtr InsertAllGatherAfterCast(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // skip Load moving down and assume it only has one node user
  CNodePtr candidate = cnode;
  if (IsSomePrimitive(candidate, LOAD)) {
    candidate = manager->node_users()[cnode].begin()->first->cast<CNodePtr>();
  }
  if (!IsSomePrimitive(candidate, CAST)) {
    return nullptr;
  }
  auto node_type = candidate->Type();
  MS_EXCEPTION_IF_NULL(node_type);
  if (!node_type->isa<mindspore::TensorType>()) {
    MS_LOG(EXCEPTION) << "Unknown type.";
  }
  auto element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(element_type);
  // a Cast whose result is still fp32 is not a downcast: no benefit
  return (element_type->type_id() == kNumberTypeFloat32) ? nullptr : candidate;
}
1517
// Insert an AllGather (or its MiniStep/MicroStep variant, selected via
// op_name) between the optimizer-sharded parameter `node` and its user `res`,
// so the full weight is reassembled in the forward pass.
// When the parameter is not shared and the user is a downcast Cast (possibly
// behind a Load), the AllGather is placed after the Cast instead.
static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
                              const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
  MS_EXCEPTION_IF_NULL(res.first);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = res.first->cast<CNodePtr>();
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(cnode_prim);
  Operator op;
  CNodePtr allgather;
  auto param_name = node->cast<ParameterPtr>()->name();
  // pick the op flavor: gradient accumulation -> MiniStep, pipeline -> MicroStep
  if (op_name == MINI_STEP_ALL_GATHER) {
    op = CreateMiniStepAllGatherOp(group);
  } else if (op_name == MICRO_STEP_ALL_GATHER) {
    op = CreateMicroStepAllGatherOp(group);
  } else {
    op = CreateAllGatherOp(group);
  }
  CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
  if (!is_shared_param && cast_node) {
    // gather after the Cast: communicates the smaller (downcast) tensor
    allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
    MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
  } else {
    InsertNode(op, cnode, IntToSize(res.second), node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name,
               root);
    // InsertNode rewired the edge; the new AllGather now sits at that input slot
    allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>();
    MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name;
  }
  // add fusion flag
  AddCommOpFusionType(allgather, node);
  // add gradients mean
  AddCommOpMeanFlag(allgather);
}
1554
// Apply the parallel optimizer to a sharded parameter: insert an AllGather
// between the parameter and its first forward-pass user, then make every
// further forward user share that same AllGather.
// An empty opt_shard_group means the parameter is not sharded: nothing to do.
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
                                    const std::string &opt_shard_group) {
  if (opt_shard_group.empty()) {
    return;
  }
  FuncGraphManagerPtr manager = root->manager();
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(manager);
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
  // choose the AllGather flavor from the training configuration
  std::string op_name;
  if (grad_accumulation_step > 1) {
    op_name = MINI_STEP_ALL_GATHER;
  } else if (split_stage_num > 1) {
    op_name = MICRO_STEP_ALL_GATHER;
  } else {
    op_name = ALL_GATHER;
  }
  auto param_sub_set = manager->node_users()[parameter];
  bool insert_flag = false;  // set once the first AllGather has been inserted
  for (auto &param_pair : param_sub_set) {
    auto cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // only forward-pass users matter; Receive/Depend edges are left untouched
    if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
        !IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
      OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
      if (distribute_operator == nullptr) {
        MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr";
      } else if (IntToSize(param_pair.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
        MS_LOG(EXCEPTION) << "The index is out of range, index is " << (param_pair.second - 1) << ", vector size is "
                          << distribute_operator->inputs_tensor_info().size();
      }
      if (insert_flag) {
        // if there are multiple node users, they share one same allgather
        auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph(), 0);
        if (next_cnode.first) {
          manager->SetEdge(cnode, param_pair.second, next_cnode.second);
          MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
                       << GetPrimName(cnode);
        } else {
          MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users.";
        }
      } else {
        // insert allgather operator between shard parameter and cnode
        auto param_ptr = parameter->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(param_ptr);
        bool is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name, is_shared_param);
        insert_flag = true;
      }
    }
  }
}
1608
GetOptShardGroup(const AnfNodePtr & parameter,TensorLayout * const tensor_layout,const OperatorInfoPtr & distribute_operator)1609 static std::string GetOptShardGroup(const AnfNodePtr ¶meter, TensorLayout *const tensor_layout,
1610 const OperatorInfoPtr &distribute_operator) {
1611 std::string opt_shard_group;
1612 if (!ParameterRequireGrad(parameter)) {
1613 // only trainable parameters need parallel optimizer
1614 MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
1615 } else if (parameter->cast<ParameterPtr>()->param_info() &&
1616 !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
1617 MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
1618 } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
1619 // get the shard tensor slice shape if the weight is repeated on devices
1620 // and the shape of the first dimension could be divided
1621 // apply parallel optimizer on parameters
1622 // create communication group for allgather operator
1623 std::vector<Group> dev_group;
1624 if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
1625 !dev_group.empty()) {
1626 opt_shard_group = dev_group[0].name();
1627 MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
1628 } else {
1629 MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
1630 }
1631 } else {
1632 MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
1633 << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
1634 }
1635 return opt_shard_group;
1636 }
1637
SetSharedParameterFlag(const FuncGraphPtr & root,const AnfNodePtr & parameter)1638 void SetSharedParameterFlag(const FuncGraphPtr &root, const AnfNodePtr ¶meter) {
1639 MS_EXCEPTION_IF_NULL(root);
1640 MS_EXCEPTION_IF_NULL(parameter);
1641 FuncGraphManagerPtr manager = root->manager();
1642 MS_EXCEPTION_IF_NULL(manager);
1643 auto parameter_ptr = parameter->cast<ParameterPtr>();
1644 if (!parameter_ptr) {
1645 MS_LOG(INFO) << parameter->ToString() << " is not a parameter";
1646 return;
1647 }
1648 auto param_sub_set = manager->node_users()[parameter];
1649 int32_t users_count = 0;
1650 for (auto ¶m_pair : param_sub_set) {
1651 auto cnode = param_pair.first->cast<CNodePtr>();
1652 MS_EXCEPTION_IF_NULL(cnode);
1653 if (cnode->in_forward_flag()) users_count++;
1654 }
1655 if (users_count > 1) {
1656 auto tensor_layout = parameter_ptr->user_data<TensorLayout>();
1657 tensor_layout->set_is_shared_param(true);
1658 MS_LOG(WARNING) << "There are multiple users for " << parameter->ToString()
1659 << ". Mixed precision optimization is not valid here.";
1660 }
1661 }
1662
1663 // When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
// Shrink `parameter`'s abstract shape to the slice this device holds, based
// on the tensor layout of its consumer (`res` = {consumer cnode, input idx}).
// Also attaches the TensorLayout to the parameter as user data.
// Returns the optimizer-shard group name; non-empty means the parallel
// optimizer is applied on this parameter.
std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
  MS_EXCEPTION_IF_NULL(parameter);
  AbstractBasePtr abstract = parameter->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
  CNodePtr cnode = res.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
  }
  // res.second is the 1-based input position of the parameter in cnode
  if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
    MS_LOG(EXCEPTION) << "The index is out of range, index is " << (res.second - 1) << ", vector size is "
                      << distribute_operator->inputs_tensor_info().size();
  }
  TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
  TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
  Shape slice_shape = tensor_layout.slice_shape().array();
  std::string opt_shard_group;
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
  if (enable_parallel_optimizer) {
    opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
  }
  if (!opt_shard_group.empty()) {
    // optimizer sharding further splits the slice across the opt-shard group
    slice_shape = tensor_layout.opt_shard_slice_shape();
  }
  MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
               << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
  std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
  MS_EXCEPTION_IF_NULL(parallel_shape);
  // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
  auto cloned_abstract = abstract->Clone();
  MS_EXCEPTION_IF_NULL(cloned_abstract);
  cloned_abstract->set_shape(parallel_shape);
  parameter->set_abstract(cloned_abstract);
  ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(parameter_ptr);
  parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  return opt_shard_group;
}
1705
CoverSliceShape(const FuncGraphPtr & root)1706 void CoverSliceShape(const FuncGraphPtr &root) {
1707 MS_EXCEPTION_IF_NULL(root);
1708 auto parameters = root->parameters();
1709 for (auto ¶meter : parameters) {
1710 MS_EXCEPTION_IF_NULL(parameter->Shape());
1711 auto iter = g_RefMap.find(parameter);
1712 if (iter != g_RefMap.end()) {
1713 std::string group = SetParallelShape(parameter, g_RefMap[parameter]);
1714 // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1715 SetSharedParameterFlag(root, parameter);
1716 ApplyParallelOptOnParam(root, parameter, group);
1717 continue;
1718 }
1719 std::pair<AnfNodePtr, int64_t> res = FindSubGraph(root, parameter);
1720 if (res.first == nullptr) {
1721 MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
1722 } else {
1723 std::string group = SetParallelShape(parameter, res);
1724 // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1725 SetSharedParameterFlag(root, parameter);
1726 ApplyParallelOptOnParam(root, parameter, group);
1727 MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
1728 }
1729 }
1730 g_RefMap.clear();
1731 }
1732
SetVirtualDatasetStrategy(const CNodePtr & node)1733 void SetVirtualDatasetStrategy(const CNodePtr &node) {
1734 MS_EXCEPTION_IF_NULL(node);
1735 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1736 bool full_batch = ParallelContext::GetInstance()->full_batch();
1737
1738 PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
1739 MS_EXCEPTION_IF_NULL(prim);
1740 if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
1741 CheckGlobalDeviceManager();
1742 auto attrs_temp = prim->attrs();
1743 if (!ParallelContext::GetInstance()->dataset_strategy().empty() && prim->name() == VIRTUAL_DATA_SET) {
1744 std::vector<ValuePtr> elements;
1745 auto dataset_strategy = ParallelContext::GetInstance()->dataset_strategy();
1746 (void)std::transform(dataset_strategy.begin(), dataset_strategy.end(), std::back_inserter(elements),
1747 [](auto input_stra) { return MakeValue(input_stra); });
1748 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1749 attrs_temp[STRATEGY] = strategy;
1750 (void)prim->SetAttrs(attrs_temp);
1751 return;
1752 }
1753 int64_t dev_num;
1754 if (full_batch) {
1755 dev_num = 1;
1756 } else {
1757 dev_num = g_device_manager->stage_device_num();
1758 }
1759 if (dev_num == 0) {
1760 MS_LOG(EXCEPTION) << "Device Num must be larger than 0, but got 0.";
1761 }
1762 std::vector<Shapes> shape_list = ExtractShape(node);
1763 if (shape_list.empty()) {
1764 MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
1765 }
1766 std::vector<ValuePtr> elements;
1767 for (size_t i = 0; i < shape_list[0].size(); i++) {
1768 if (shape_list[0][i].empty()) {
1769 MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
1770 }
1771 Dimensions input_strategy;
1772 if (!shape_list[0][i].empty() && shape_list[0][i][0] % dev_num == 0) {
1773 input_strategy.push_back(dev_num);
1774 } else if (!shape_list[0][i].empty()) {
1775 input_strategy.push_back(1);
1776 }
1777 for (size_t j = 1; j < shape_list[0][i].size(); j++) {
1778 input_strategy.push_back(1);
1779 }
1780 elements.push_back(MakeValue(input_strategy));
1781 }
1782 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1783 attrs_temp[STRATEGY] = strategy;
1784 (void)prim->SetAttrs(attrs_temp);
1785 }
1786 }
1787
1788 // find previous parallel care node's next node.
FindPreNodes(const AnfNodePtr & node,vector<std::string> * unique_ids,vector<size_t> * indexes,size_t curr_depth)1789 bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes, size_t curr_depth) {
1790 if (curr_depth > MAX_RECURSIVE_DEPTH) {
1791 MS_LOG(WARNING) << "When find the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
1792 return false;
1793 }
1794 MS_EXCEPTION_IF_NULL(unique_ids);
1795 MS_EXCEPTION_IF_NULL(indexes);
1796 if (!node->isa<CNode>()) {
1797 return false;
1798 }
1799 CNodePtr pre_cnode = node->cast<CNodePtr>();
1800 if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
1801 return false;
1802 }
1803 bool find = false;
1804 for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
1805 auto next_node = pre_cnode->inputs()[index];
1806 if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
1807 return false;
1808 }
1809 CNodePtr cnode = next_node->cast<CNodePtr>();
1810 if (!IsValueNode<Primitive>(cnode->input(0))) {
1811 return false;
1812 }
1813 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1814 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
1815 if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
1816 unique_ids->push_back(pre_cnode->UniqueId());
1817 indexes->push_back(index);
1818 find = true;
1819 continue;
1820 }
1821 if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) {
1822 find = true;
1823 continue;
1824 }
1825 }
1826 return find;
1827 }
1828
FindLastNodesUniqueId(const FuncGraphPtr & root,std::vector<std::string> * unique_ids,std::vector<size_t> * indexes)1829 void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
1830 std::vector<size_t> *indexes) {
1831 MS_EXCEPTION_IF_NULL(unique_ids);
1832 CNodePtr cnode = root->get_return();
1833 if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
1834 MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
1835 }
1836 }
1837
GenerateBatchParallelStrategy(const OperatorInfoPtr operator_,const PrimitivePtr prim)1838 StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
1839 MS_EXCEPTION_IF_NULL(operator_);
1840 MS_EXCEPTION_IF_NULL(prim);
1841 StrategyPtr strategyPtr;
1842 std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
1843 MS_EXCEPTION_IF_NULL(strategy_v_ptr);
1844 strategyPtr = NewStrategy(0, *strategy_v_ptr);
1845 std::vector<ValuePtr> elements;
1846 for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
1847 elements.push_back(MakeValue((*strategy_v_ptr)[i]));
1848 }
1849 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1850 // display the strategy generated by batch parallel
1851 auto attrs = prim->attrs();
1852 attrs[GEN_STRATEGY] = strategy;
1853 (void)prim->SetAttrs(attrs);
1854 MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << attrs[GEN_STRATEGY]->ToString();
1855 return strategyPtr;
1856 }
1857
CheckExtractInfomation(const CNodePtr & cnode)1858 static bool CheckExtractInfomation(const CNodePtr &cnode) {
1859 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
1860 return false;
1861 }
1862
1863 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1864 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1865 if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
1866 return false;
1867 }
1868
1869 if (!IsParallelCareNode(cnode)) {
1870 return false;
1871 }
1872 return true;
1873 }
1874
// For every parallel-care cnode, build its OperatorInfo (shapes, input
// values, output dtype) and initialize it with a sharding strategy.
// Strategy resolution order: primal attr STRATEGY > attrs STRATEGY >
// strategy checkpoint > generated batch-parallel strategy.
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
  // load strategy map from checkpoint
  StrategyMap stra_map;
  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
      (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
  }

  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!CheckExtractInfomation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
      continue;
    }

    SetVirtualDatasetStrategy(cnode);
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);

    auto attrs = prim->attrs();
    MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();

    std::vector<Shapes> shape_list = ExtractShape(cnode);
    if (shape_list.empty()) {
      MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
    }
    OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
    MS_EXCEPTION_IF_NULL(operator_);

    // collect constant input values (nullptr placeholder for non-constants)
    auto &inputs = cnode->inputs();
    std::vector<ValuePtr> input_value;
    for (size_t index = 1; index < inputs.size(); ++index) {
      if (inputs[index]->isa<ValueNode>()) {
        input_value.push_back(GetValueNode(inputs[index]));
        continue;
      }
      input_value.emplace_back(nullptr);
    }
    StrategyPtr strategyPtr = nullptr;
    (*operator_).set_input_value(input_value);
    (*operator_).set_outputs_dtype(cnode->Type());
    (*operator_).set_cnode(cnode);
    if (prim->name() == RESHAPE) {
      // Reshape gets no strategy here; its layout is inferred from neighbors
      cnode->set_user_data<OperatorInfo>(operator_);
      continue;
    }
    // load strategy checkpoint
    // key of strategy map
    std::string strategy_key_name = "";
    auto param_names = NodeParameterName(cnode, -1, 0);
    if (!param_names.empty()) {
      strategy_key_name = prim->name() + "_" + param_names[0].first;
    }
    bool load_strategy_from_ckpt =
      StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
    if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(STRATEGY)) {
      // no strategy anywhere: fall back to batch parallel
      MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
                   << " is empty, using batch parallel";
      strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
    } else if (cnode->HasPrimalAttr(STRATEGY)) {
      strategyPtr = ExtractStrategy(cnode->GetPrimalAttr(STRATEGY));
    } else if (StrategyFound(attrs)) {
      strategyPtr = ExtractStrategy(attrs[STRATEGY]);
    } else {
      strategyPtr = stra_map[strategy_key_name];
    }

    MS_EXCEPTION_IF_NULL(strategyPtr);
    if (operator_->Init(strategyPtr) == FAILED) {
      MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"
                        << " trace: " << trace::DumpSourceLines(cnode);
    }
    cnode->set_user_data<OperatorInfo>(operator_);
  }
}
1949
GetInputLayoutFromCNode(const std::pair<AnfNodePtr,int64_t> & node_pair)1950 TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair) {
1951 CNodePtr cnode = node_pair.first->cast<CNodePtr>();
1952 MS_EXCEPTION_IF_NULL(cnode);
1953 OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
1954 MS_EXCEPTION_IF_NULL(distribute_operator);
1955 int64_t index = node_pair.second;
1956 if (index > SizeToLong(distribute_operator->inputs_tensor_info().size())) {
1957 MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << (index - 1)
1958 << ", the vector size is " << distribute_operator->inputs_tensor_info().size();
1959 }
1960 TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(index - 1)];
1961 TensorLayout tensorlayout_in = tensorinfo_in.tensor_layout();
1962 return tensorlayout_in;
1963 }
1964
1965 // if reshape's output connect to several primitive, return the first layout found
// if reshape's output connect to several primitive, return the first layout found
// Search forward through `cnode`'s users for the first parallel-care node and
// return the input layout it expects from cnode; sets *next_is_reshape when a
// Reshape user is encountered. Returns nullptr when no consumer has a layout.
std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
      // don't descend through Reshape; just report that one follows
      *next_is_reshape = true;
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      // only Depend's first input is a real data edge
      continue;
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
      *next_is_reshape = false;
      auto layout = GetInputLayoutFromCNode(node_pair);
      return std::make_shared<TensorLayout>(layout);
    }
    MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
                  << " " << use_apply->has_user_data<OperatorInfo>();

    // not a care node: keep searching through this user's users
    auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
    if (layout_ptr) {
      return layout_ptr;
    }
  }
  MS_LOG(WARNING) << "FindNextLayout return nullptr, if reshape is not the last primitive, there must be some error";
  return nullptr;
}
2006
GetOutputLayoutFromCNode(const CNodePtr & cnode,size_t output_index)2007 std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
2008 MS_EXCEPTION_IF_NULL(cnode);
2009 OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2010 MS_EXCEPTION_IF_NULL(distribute_operator);
2011 if (distribute_operator->outputs_tensor_info().size() <= output_index) {
2012 MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size()
2013 << ", must be greater than output_index " << output_index;
2014 }
2015 TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
2016 TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
2017 return std::make_shared<TensorLayout>(tensorlayout_out);
2018 }
2019
FindPrevParallelCareNodeLayout(const AnfNodePtr & node,size_t output_index)2020 std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) {
2021 if (!node->isa<CNode>()) {
2022 return nullptr;
2023 }
2024 CNodePtr cnode = node->cast<CNodePtr>();
2025 if (!IsValueNode<Primitive>(cnode->input(0))) {
2026 return nullptr;
2027 }
2028 if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
2029 auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
2030 if (!layout_ptr) {
2031 MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
2032 }
2033 return layout_ptr;
2034 }
2035 return nullptr;
2036 }
2037
FindParameterNextLayout(const AnfNodePtr & node,size_t curr_depth)2038 std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) {
2039 if (curr_depth > MAX_RECURSIVE_DEPTH) {
2040 MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: "
2041 << MAX_RECURSIVE_DEPTH;
2042 return nullptr;
2043 }
2044 FuncGraphManagerPtr manager = node->func_graph()->manager();
2045 MS_EXCEPTION_IF_NULL(manager);
2046 AnfNodeIndexSet node_set = manager->node_users()[node];
2047 for (auto &node_pair : node_set) {
2048 if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
2049 auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth);
2050 if (!layout_param) {
2051 continue;
2052 }
2053 return layout_param;
2054 }
2055 CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
2056 if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
2057 continue;
2058 }
2059 ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
2060 MS_EXCEPTION_IF_NULL(prim_anf_node);
2061 PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
2062 MS_EXCEPTION_IF_NULL(node_prim);
2063 if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
2064 continue;
2065 }
2066 if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
2067 auto layout = GetInputLayoutFromCNode(node_pair);
2068 return std::make_shared<TensorLayout>(layout);
2069 }
2070 }
2071 return nullptr;
2072 }
2073
CreateParameterLayout(const AnfNodePtr & node)2074 std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
2075 // Create DataParallel tensor layout for parameter(support WideDeep).
2076 auto next_layout = FindParameterNextLayout(node, 0);
2077 if (next_layout != nullptr) {
2078 return next_layout;
2079 }
2080 CheckGlobalDeviceManager();
2081 int64_t dev_num = g_device_manager->stage_device_num();
2082 TensorLayout input_tensor_layout;
2083 // create input_shape
2084 Shapes inputs_shape = GetNodeShape(node);
2085 Shape input_shape_array = inputs_shape[0];
2086 if (input_shape_array.empty()) {
2087 MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter.";
2088 }
2089 // create tensor_map
2090 size_t shape_size = input_shape_array.size();
2091 TensorMap input_tensor_map_array(SizeToLong(shape_size) - 1, -1);
2092 input_tensor_map_array.insert(input_tensor_map_array.begin(), 0);
2093 // create dev_matrix
2094 Shape dev_matrix_array = {dev_num};
2095 if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
2096 MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
2097 }
2098 return std::make_shared<TensorLayout>(input_tensor_layout);
2099 }
2100
// Infer the redistribution operator list that converts the sens (gradient
// seed) tensor from a stand-alone, fully replicated layout into the given
// loss layout. Returns nullptr when sens is a scalar (nothing to
// redistribute); raises on layout-construction or init failure.
RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
  MS_EXCEPTION_IF_NULL(node);
  TensorRedistribution tensor_redistribution;
  // create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
  CheckGlobalDeviceManager();
  int64_t dev_num = g_device_manager->stage_device_num();
  TensorLayout stand_alone_layout;
  Shapes inputs_shape = GetNodeShape(node);
  if (inputs_shape.empty()) {
    MS_LOG(EXCEPTION) << "InferSensRedistribution failed cause inputs shape is empty.";
  }
  Shape input_shape_array = inputs_shape[0];
  if (input_shape_array.empty()) {
    // Scalar sens: no dimension to split, so no redistribution is required.
    MS_LOG(INFO) << "No need to redistribution for sens.";
    return nullptr;
  }
  // TensorMap: every dimension maps to -1, i.e. fully replicated.
  TensorMap stand_alone_tensor_map_array(SizeToLong(input_shape_array.size()), -1);
  // Dev_matrix: a single axis containing all devices of this stage.
  Shape dev_matrix_array = {dev_num};
  if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
    MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
  }

  // Infer Redistribution op list for stand alone and loss layout.
  RankList dev_list = g_device_manager->GetDeviceListInThisStage();
  if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
  }
  RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
  MS_EXCEPTION_IF_NULL(sens_redistribution_list);

  return sens_redistribution_list;
}
2135
// Walk backwards from `node` to find the tensor layout produced by the
// nearest layout-determining producer (used to fix a Reshape's input layout).
// Returns nullptr when no layout can be determined from here.
std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
  // A parameter gets a consumer-derived or default data-parallel layout.
  if (node->isa<Parameter>()) {
    return CreateParameterLayout(node);
  }
  if (!node->isa<CNode>()) {
    return nullptr;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
  // Receive carries the layout attached by the pipeline-split pass.
  if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return cnode->user_data<TensorLayout>();
  }
  // A parallel-care node (other than Reshape itself) exposes its output
  // layout directly from its OperatorInfo.
  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() &&
      !IsPrimitiveCNode(node, prim::kPrimReshape)) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
    }
    return layout_ptr;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  // tuple_getitem: take the layout of the selected output of its producer.
  if (prim->name() == prim::kTupleGetItem) {
    auto tuple_index = GetTupleGetItemIndex(cnode);
    auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
    if (!layout_ptr) {
      MS_LOG(EXCEPTION)
        << " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a parallel care node "
           "before tuple_getitem!";
    }
    return layout_ptr;
  }
  // Otherwise recurse into the inputs; for Depend only the real data input
  // (input 1) is considered. The first layout found wins.
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    auto layout_ptr = FindPrevLayout(cnode->inputs()[index]);
    if (!layout_ptr) {
      continue;
    }
    return layout_ptr;
  }
  MS_LOG(WARNING) << "FindPrevLayout return nullptr, if reshape is not the first primitive, there must be some error";
  return nullptr;
}
2183
// Initialize every Reshape operator in `all_nodes`: infer its input layout
// from its producer and its output layout from its consumer, then Init the
// OperatorInfo. Reshape has no strategy of its own; a user-set strategy is
// an error.
void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
    if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
    if (operator_info == nullptr) {
      MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
    }
    if (prim->name() != RESHAPE) {
      continue;
    }
    // Reshape must not carry a user-specified parallel strategy.
    auto attrs = prim->attrs();
    if (StrategyFound(attrs)) {
      MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
    }
    MS_ASSERT(cnode->inputs().size() == 3);
    // Input layout comes from the producer of input(1).
    auto prev_layout_ptr = FindPrevLayout(cnode->input(1));
    if (prev_layout_ptr) {
      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
    }
    // Output layout comes from the consumer; if the next op is another
    // Reshape, reuse the input layout instead.
    bool is_next_reshape = false;
    auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape);
    if (next_layout_ptr) {
      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
    } else if (is_next_reshape && prev_layout_ptr != nullptr) {
      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info_ptr->SetOutputLayout(*prev_layout_ptr);
    }
    if (operator_info->Init(nullptr) == FAILED) {
      MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed";
    }
  }
}
2227
// Strip Depend wrappers (and Casts that carry no OperatorInfo) between the
// graph's return and the real loss node, by following the first real input.
// Gives up and returns nullptr once the recursion depth limit is exceeded.
CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the max recursive depth: "
                    << MAX_RECURSIVE_DEPTH;
    return nullptr;
  }
  // Handle return->depend->loss
  const bool is_wrapper = IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
                          (IsPrimitiveCNode(cnode, prim::kPrimCast) && !cnode->has_user_data<OperatorInfo>());
  if (!is_wrapper) {
    return cnode;
  }
  auto wrapped = cnode->input(1)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(wrapped);
  return HandleDependLoss(wrapped, curr_depth + 1);
}
2243
FindLossCNode(const FuncGraphPtr & func_graph,size_t max_depth)2244 LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
2245 if (max_depth > MAX_RECURSIVE_DEPTH) {
2246 MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
2247 }
2248 LossNodeInfo loss_node_info;
2249 MS_EXCEPTION_IF_NULL(func_graph);
2250 CNodePtr return_node = func_graph->get_return();
2251 MS_EXCEPTION_IF_NULL(return_node);
2252 if (return_node->size() < 2) {
2253 MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2";
2254 }
2255 AnfNodePtr pre_node = return_node->input(1);
2256 MS_EXCEPTION_IF_NULL(pre_node);
2257 auto pre_cnode = pre_node->cast<CNodePtr>();
2258 pre_cnode = HandleDependLoss(pre_cnode, 0);
2259 if (pre_cnode->input(0)->isa<CNode>()) {
2260 auto switch_cnode = pre_cnode->input(0)->cast<CNodePtr>();
2261 if (IsPrimitiveCNode(switch_cnode, prim::kPrimSwitch)) {
2262 MS_EXCEPTION_IF_NULL(switch_cnode);
2263 auto switch_graph = GetValueNode<FuncGraphPtr>(switch_cnode->input(2));
2264 return FindLossCNode(switch_graph, max_depth + 1);
2265 }
2266 }
2267 if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
2268 return loss_node_info;
2269 }
2270 if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
2271 MS_LOG(DEBUG) << "pre_cnode:" << pre_cnode->ToString();
2272 return loss_node_info;
2273 }
2274 auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
2275 // notice: the GetNext op has not input
2276 if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
2277 MS_LOG(INFO) << "The loss is: " << current_prim->name();
2278 loss_node_info.loss_node = pre_cnode;
2279 return loss_node_info;
2280 }
2281
2282 // size of common cnode is larger than 1
2283 if (pre_cnode->size() < 2) {
2284 MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
2285 }
2286
2287 // return -> tuple_getitem -> loss
2288 if (current_prim->name() == prim::kTupleGetItem) {
2289 auto tuple_index = GetTupleGetItemIndex(pre_cnode);
2290 AnfNodePtr pre_pre_node = pre_cnode->input(1);
2291 MS_EXCEPTION_IF_NULL(pre_pre_node);
2292
2293 auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
2294 loss_node_info.has_tuple_getitem = true;
2295 loss_node_info.dout_index = tuple_index;
2296 loss_node_info.loss_node = pre_pre_cnode;
2297 return loss_node_info;
2298 }
2299
2300 // return -> make_tuple
2301 if (current_prim->name() == MAKE_TUPLE) {
2302 MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
2303 return loss_node_info;
2304 }
2305
2306 // return -> loss
2307 loss_node_info.loss_node = pre_cnode;
2308 MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
2309 return loss_node_info;
2310 }
2311
// Get the tensor layout that the gradient seed (sens) of the loss output must
// follow. Returns an empty vector for "invalid" loss ops (in
// INVALID_LOSS_OPS), for which sens splitting is skipped.
TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
  TensorLayouts ret;
  auto loss_cnode = node_info.loss_node;
  MS_EXCEPTION_IF_NULL(loss_cnode);

  ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim);
  if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
    MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
    return ret;
  }

  OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(operator_info);
  TensorInfo loss_grad_tensor_info;
  size_t op_output_size = operator_info->outputs_tensor_info().size();
  MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is "
               << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is "
               << node_info.dout_index;

  // dout_index must address an existing output of the loss operator.
  if ((op_output_size == 0) || (op_output_size <= LongToSize(node_info.dout_index))) {
    MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size;
  }

  // A multi-output loss must be consumed through tuple_getitem.
  if (!node_info.has_tuple_getitem && (op_output_size > 1)) {
    MS_LOG(EXCEPTION) << "Currently, it is not supported that the sens is a tuple.";
  }

  loss_grad_tensor_info = operator_info->outputs_tensor_info()[LongToSize(node_info.dout_index)];
  ret.push_back(loss_grad_tensor_info.tensor_layout());
  return ret;
}
2346
// Split the sens (gradient seed) tensor, input(1) of `grad_sens_node`, so it
// matches the sharded loss layout:
//  - scalar / [1] sens: no split; only attach the layout to a Parameter sens.
//  - Parameter sens: rewrite its abstract to the slice shape and attach layout.
//  - CNode sens: insert redistribution ops computed by InferSensRedistribution.
//  - Tensor (value node) sens: insert a _GetTensorSlice op.
void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
  MS_EXCEPTION_IF_NULL(grad_sens_node);
  if (grad_sens_node->size() <= 1) {
    MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
  }
  AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
  MS_EXCEPTION_IF_NULL(sens_tensor_node);
  Shapes sens_shapes = GetNodeShape(sens_tensor_node);
  if (sens_shapes.size() != 1) {
    MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
  }
  // If the shape of sens tensor is [] or [1], no need to split it.
  Shape sens_shape = sens_shapes[0];
  if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) {
    if (sens_tensor_node->isa<Parameter>()) {
      auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
      MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
      sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
    }
    MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
    return;
  }
  // Sens shape must equal the (full) loss output shape before splitting.
  auto loss_shape = loss_grad_layout.tensor_shape().array();
  if (loss_shape != sens_shape) {
    MS_LOG(EXCEPTION) << "The shape of sens is not equal to loss output, it is unsupported now. Sens shape is "
                      << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape);
  }
  MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it.";

  if (!IsValueNode<Tensor>(sens_tensor_node)) {
    if (sens_tensor_node->isa<Parameter>()) {
      // Parameter sens: shrink its abstract to the slice shape and record the
      // layout so later passes slice the actual data accordingly.
      MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
      AbstractBasePtr abstract = sens_tensor_node->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      auto slice_shape = loss_grad_layout.slice_shape().array();
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      // Clone rather than mutate: the original abstract may be shared.
      auto cloned_abstract = abstract->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      cloned_abstract->set_shape(parallel_shape);
      sens_tensor_node->set_abstract(cloned_abstract);
      auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
      sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
      return;
    }
    if (sens_tensor_node->isa<CNode>()) {
      // CNode sens: redistribute from replicated layout to the loss layout.
      auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
      if (op_list_ptr == nullptr) {
        return;
      }
      auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
      auto func_graph = grad_sens_node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode);
      return;
    }
    MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter or CNode, it is unsupported now.";
  }

  // Use _GetTensorSlice operator to split the sens tensor
  FuncGraphPtr func_graph = grad_sens_node->func_graph();  // only cnode can get the graph
  MS_EXCEPTION_IF_NULL(func_graph);
  Operator op = CreateGetTensorSliceOp(loss_grad_layout);
  InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
}
2412
InsertForwardOps(const OperatorInfoPtr & distribute_operator,const CNodePtr & cnode)2413 void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
2414 MS_EXCEPTION_IF_NULL(distribute_operator);
2415 MS_EXCEPTION_IF_NULL(cnode);
2416 if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
2417 return;
2418 }
2419 OperatorVector forward_op = distribute_operator->forward_op();
2420 if (!forward_op.empty()) {
2421 MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
2422 ForwardCommunication(forward_op, cnode);
2423 }
2424 }
2425
StepReplace(const OperatorInfoPtr & distribute_operator,const CNodePtr & cnode)2426 void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
2427 MS_EXCEPTION_IF_NULL(distribute_operator);
2428 MS_EXCEPTION_IF_NULL(cnode);
2429 // StepReplaceOp
2430 OperatorVector replace_op = distribute_operator->replace_op();
2431 if (!replace_op.empty()) {
2432 MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString();
2433 StepReplaceOp(replace_op, cnode);
2434 }
2435
2436 // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore.
2437 ReplaceGraphPtr replace_graph = distribute_operator->replace_graph(cnode);
2438 if (!replace_op.empty() && replace_graph) {
2439 MS_LOG(EXCEPTION) << "Only one of replace_op or replace_op can be used";
2440 }
2441 if (replace_graph) {
2442 MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString();
2443 StepReplaceGraph(replace_graph, cnode);
2444 }
2445 }
2446
FindForwardGraphByRootNodes(const AnfNodeSet & root_all_nodes)2447 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
2448 // J->CNode->Graph
2449 std::set<FuncGraphPtr> graph_set;
2450 for (auto &node : root_all_nodes) {
2451 MS_EXCEPTION_IF_NULL(node);
2452 if (!node->isa<CNode>()) {
2453 continue;
2454 }
2455
2456 auto cnode = node->cast<CNodePtr>();
2457 if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
2458 continue;
2459 }
2460 auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2461 if (expect_j_prim->name() != J) {
2462 continue;
2463 }
2464 if (IsValueNode<FuncGraph>(cnode->input(1))) {
2465 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
2466 MS_LOG(DEBUG) << "Find the forward graph success";
2467 graph_set.insert(graph);
2468 auto manager = graph->manager();
2469 MS_EXCEPTION_IF_NULL(manager);
2470 auto graph_used = manager->func_graphs_used_total(graph);
2471 for (auto &sub_graph : graph_used) {
2472 graph_set.insert(sub_graph);
2473 }
2474 }
2475 }
2476 return graph_set;
2477 }
2478
StepSplitSens(const std::pair<CNodePtr,LossNodeInfo> & sens_loss_pair)2479 void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
2480 CNodePtr sens_node = sens_loss_pair.first;
2481 auto loss_node = sens_loss_pair.second;
2482 auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
2483 if (!loss_grad_layout.empty()) {
2484 SplitSens(sens_node, loss_grad_layout[0]);
2485 }
2486 }
2487
2488 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
// Collect (sens_node, loss_info) pairs from the root graph by matching the
// chain cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J): J's argument
// is the forward graph, inside which the loss node is located.
std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
  for (auto &node : root->nodes()) {
    if (!node->isa<CNode>()) {
      continue;
    }

    // cnode(sens)-->cnode(tuple_getitem)
    auto sens_cnode = node->cast<CNodePtr>();
    AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
    MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
    if (!expect_tuple_getitem->isa<CNode>()) {
      continue;
    }

    auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kTupleGetItem)) {
      continue;
    }

    // cnode(sens)-->cnode(tuple_getitem)-->cnode
    AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
    MS_EXCEPTION_IF_NULL(expect_anonymous);
    if (!expect_anonymous->isa<CNode>()) {
      continue;
    }

    // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
    auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
    AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
    MS_EXCEPTION_IF_NULL(expect_j);
    if (!expect_j->isa<CNode>()) {
      continue;
    }
    auto expect_j_cnode = expect_j->cast<CNodePtr>();
    if (!IsSomePrimitive(expect_j_cnode, J)) {
      continue;
    }

    // J's first argument must be the forward graph containing the loss.
    if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
      MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
    }
    auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
    auto loss_node_info = FindLossCNode(func_graph, 0);
    if (loss_node_info.loss_node == nullptr) {
      MS_LOG(WARNING) << "Can not find the loss cnode";
      continue;
    }
    std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
    sens_loss_pairs.push_back(sens_loss_pair);
  }
  return sens_loss_pairs;
}
2543
// Insert all parallel communication into the graph: split sens tensors,
// insert forward/redistribution/backward ops for every parallel-care node,
// split constant tensors, and finally apply each operator's replacement
// (replace_op / replace_graph) in a second pass.
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                           const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(manager);
  TensorRedistribution tensor_redistribution;

  std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
  // A backward pass exists iff at least one sens/loss pair was found.
  bool has_backward = !sens_loss_pairs.empty();
  // split sens must before inserting the operators.
  for (auto &pair : sens_loss_pairs) {
    // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
    // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
    if (IsLastStage()) {
      StepSplitSens(pair);
    }
  }

  // First pass: insert forward, redistribution and backward ops.
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      // the make_tuple is parallel care node, but it may have not operator info
      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
        continue;
      }

      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
      MS_EXCEPTION_IF_NULL(distribute_operator);

      // skip Send Receive
      if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
        // insert forward ops
        InsertForwardOps(distribute_operator, cnode);

        // insert redistribution ops
        StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
      }
      // insert backward ops
      if (has_backward) {
        BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
      }

      distribute_operator->ReplaceNodeInputOrAttrs();
    } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
      // Constant tensors/lists may also need to be sliced per device.
      StepSplitTensor(node, manager);
    }
  }

  // Second pass: apply replace_op / replace_graph after all insertions, so
  // the replacements see the final node structure.
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) ||
          IsSomePrimitive(cnode, SEND)) {
        continue;
      }

      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
      MS_EXCEPTION_IF_NULL(distribute_operator);
      // StepReplace
      StepReplace(distribute_operator, cnode);
    }
  }
}
2608
IsCohesiveNode(const CNodePtr & cnode)2609 bool IsCohesiveNode(const CNodePtr &cnode) {
2610 return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
2611 IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) ||
2612 IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather);
2613 }
2614
NodeParameterName(const CNodePtr & node,int64_t index,size_t curr_depth)2615 ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
2616 if (curr_depth > MAX_RECURSIVE_DEPTH) {
2617 MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
2618 << MAX_RECURSIVE_DEPTH;
2619 return {};
2620 }
2621 std::vector<AnfNodePtr> node_inputs{node->inputs()};
2622 ParameterMap param_names;
2623 for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
2624 int64_t idx = index > i ? index : i;
2625 auto input = node_inputs[LongToSize(i)];
2626 if (input->isa<Parameter>()) {
2627 auto input_parameter = input->cast<ParameterPtr>();
2628 if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
2629 (void)param_names.emplace_back(std::make_pair(input_parameter->name(), input_parameter));
2630 }
2631 } else if (input->isa<CNode>()) {
2632 CNodePtr cnode = input->cast<CNodePtr>();
2633 if (!IsValueNode<Primitive>(cnode->input(0))) {
2634 continue;
2635 }
2636 if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) {
2637 auto input_param_names = NodeParameterName(cnode, idx, 0);
2638 param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end());
2639 }
2640 }
2641 }
2642 return param_names;
2643 }
2644
// Returns true when `name` refers to a gather-type operator info, i.e. it
// contains one of the substrings "GatherPInfo", "SparseGatherV2Info" or
// "EmbeddingLookupInfo".
bool IsGatherPInfo(const std::string &name) {
  // static const: build the table once instead of on every call, and iterate
  // by const reference instead of copying each std::string (the original
  // range-for copied every element).
  static const std::vector<std::string> kGatherPInfoNames = {"GatherPInfo", "SparseGatherV2Info",
                                                             "EmbeddingLookupInfo"};
  return std::any_of(kGatherPInfoNames.begin(), kGatherPInfoNames.end(), [&name](const std::string &info_name) {
    return name.find(info_name) != std::string::npos;
  });
}
2654
// Save the parallel strategy checkpoint: for every operator that consumes
// trainable parameters, record its strategy (keyed by "prim_param"), every
// parameter's tensor layout, manual-split shapes for gather-type ops, and the
// layouts of cloned parameters; then persist via StrategyCheckpoint.
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
  StrategyMap stra_map;
  TensorInfoMap tensor_info_map;
  ManualShapeMap manual_shape_map;
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    auto cnode = node->cast<CNodePtr>();
    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    // Only operators fed by trainable parameters are checkpointed.
    auto param_names = NodeParameterName(cnode, -1, 0);
    if (param_names.empty()) {
      continue;
    }
    string param_name = param_names[0].first;
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
    if (operator_info) {
      // Reshape has no strategy of its own; skip it.
      if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
        continue;
      }
      // Key the strategy by primitive name + first parameter name.
      std::string stratey_key_name = prim->name() + "_" + param_name;
      stra_map[stratey_key_name] = operator_info->strategy();
      for (auto param_name_pair : param_names) {
        tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
      }
      // Gather-type ops may be manually split; record their split shapes and
      // index offsets so the checkpoint can be restored exactly.
      if (IsGatherPInfo(operator_info->name())) {
        auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
        auto param_split_shapes = gatherv2_info->param_split_shapes();
        auto index_offsets = gatherv2_info->index_offsets();
        if (param_split_shapes.size() != index_offsets.size()) {
          MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
        }
        std::vector<std::pair<int64_t, int64_t>> manual_shape;
        for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
          manual_shape.emplace_back(std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
        }
        manual_shape_map[param_name] = manual_shape;
      }
    }
  }
  // Cloned parameters (e.g. optimizer copies) also carry layouts to save.
  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(cloned_parameter_node)) {
      continue;
    }
    std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
    auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
    if (cloned_param_layout == nullptr) {
      continue;
    }
    tensor_info_map[cloned_param_name] = cloned_param_layout;
  }
  if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
  }
}
2716
SetForwardFlag(const std::vector<AnfNodePtr> & all_nodes)2717 void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
2718 for (auto &node : all_nodes) {
2719 MS_EXCEPTION_IF_NULL(node);
2720 if (!node->isa<CNode>()) {
2721 continue;
2722 }
2723 auto cnode = node->cast<CNodePtr>();
2724 if (!IsValueNode<Primitive>(cnode->input(0))) {
2725 continue;
2726 }
2727
2728 // CNode is globally unique.
2729 MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << ".";
2730 cnode->set_in_forward_flag(true);
2731 }
2732 }
2733
SetForwardFlag(const AnfNodeSet & all_nodes)2734 void SetForwardFlag(const AnfNodeSet &all_nodes) {
2735 for (auto &node : all_nodes) {
2736 MS_EXCEPTION_IF_NULL(node);
2737 if (!node->isa<CNode>()) {
2738 continue;
2739 }
2740 auto cnode = node->cast<CNodePtr>();
2741 if (!IsValueNode<Primitive>(cnode->input(0))) {
2742 continue;
2743 }
2744
2745 // CNode is globally unique.
2746 cnode->set_in_forward_flag(true);
2747 }
2748 }
2749
ForwardGraph(const FuncGraphPtr & root)2750 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
2751 MS_EXCEPTION_IF_NULL(root);
2752 const auto &all_nodes = root->nodes();
2753 std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
2754 return graph_set;
2755 }
2756
FindRootForwardCNode(const FuncGraphPtr & graph,const AnfNodeSet & all_nodes)2757 std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
2758 MS_EXCEPTION_IF_NULL(graph);
2759 std::vector<AnfNodePtr> root_forward_nodes;
2760 auto loss_cnode = FindLossCNode(graph, 0).loss_node;
2761 if (loss_cnode == nullptr) {
2762 MS_LOG(WARNING) << "Can not find the loss cnode";
2763 return root_forward_nodes;
2764 }
2765
2766 auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
2767 for (auto &node : all_nodes) {
2768 MS_EXCEPTION_IF_NULL(node);
2769 if (!node->isa<CNode>()) {
2770 continue;
2771 }
2772 auto cnode = node->cast<CNodePtr>();
2773 auto root_node_id = node->UniqueIdThroughCopy();
2774 if (loss_cnode_id == root_node_id) {
2775 root_forward_nodes = DeepLinkedGraphSearch(cnode);
2776 break;
2777 }
2778 }
2779 return root_forward_nodes;
2780 }
2781
InsertShapeOp(const CNodePtr & node,const AnfNodePtr & pre_node,const FuncGraphPtr & root)2782 void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
2783 // shape op doesn't have params and attrs.
2784 OperatorParams params;
2785 OperatorAttrs attrs;
2786 auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>();
2787 MS_EXCEPTION_IF_NULL(shape_value);
2788 auto shape = shape_value->value();
2789 if (shape.empty()) {
2790 return;
2791 }
2792 OperatorArgs args = std::make_pair(attrs, params);
2793 Operator op = std::make_pair(SHAPE_OP, args);
2794 InsertNode(op, node, 2, pre_node, root, "shape");
2795 }
2796
FindGrad(const CNodePtr & cnode,size_t curr_depth)2797 static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
2798 if (curr_depth > MAX_RECURSIVE_DEPTH) {
2799 MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
2800 return nullptr;
2801 }
2802 for (auto &node : cnode->inputs()) {
2803 if (!node->isa<CNode>()) {
2804 continue;
2805 }
2806 if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
2807 return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
2808 } else {
2809 return node;
2810 }
2811 }
2812 return nullptr;
2813 }
2814
HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> & all_nodes)2815 void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
2816 // If root graph has reshape op. Find the corresponding parameter.
2817 // Reshape's shape is the shape of the parameter.
2818 auto executor = pipeline::GraphExecutorPy::GetInstance();
2819 for (auto &node : all_nodes) {
2820 if (!node->isa<CNode>()) {
2821 continue;
2822 }
2823 auto cnode = node->cast<CNodePtr>();
2824 if (!IsValueNode<Primitive>(cnode->input(0)) || cnode == nullptr) {
2825 continue;
2826 }
2827 if (cnode->in_forward_flag()) {
2828 // Save strategy in executor
2829 OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>();
2830 if (op_info) {
2831 auto stra_ptr = op_info->strategy();
2832 if (stra_ptr) {
2833 auto strategy = stra_ptr->GetInputDim();
2834 // fullname with scope should be found in step parallel end ir
2835 executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy);
2836 }
2837 }
2838 continue;
2839 }
2840
2841 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2842 if (prim->name() != RESHAPE) {
2843 continue;
2844 }
2845 auto root = node->func_graph();
2846 auto grad_node = FindGrad(cnode, 0);
2847 if (grad_node) {
2848 InsertShapeOp(cnode, grad_node, root);
2849 }
2850 }
2851 }
2852
MarkForwardCNode(const FuncGraphPtr & root)2853 void MarkForwardCNode(const FuncGraphPtr &root) {
2854 MS_EXCEPTION_IF_NULL(root);
2855 auto all_nodes = root->nodes();
2856 auto graph_set = FindForwardGraphByRootNodes(all_nodes);
2857
2858 if (graph_set.empty()) {
2859 MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
2860 SetForwardFlag(all_nodes);
2861 } else {
2862 for (auto &func_graph : graph_set) {
2863 MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
2864 auto return_node = func_graph->get_return();
2865 MS_EXCEPTION_IF_NULL(return_node);
2866 auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
2867 SetForwardFlag(all_dfs_nodes);
2868 auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes);
2869 if (root_forward_nodes.empty()) {
2870 continue;
2871 }
2872 // Mark forward flag for the nodes in root graph.
2873 SetForwardFlag(root_forward_nodes);
2874 }
2875 }
2876 }
2877
GetCommInfo()2878 CommInfo GetCommInfo() {
2879 int64_t device_num = ParallelContext::GetInstance()->device_num();
2880 int64_t global_rank = ParallelContext::GetInstance()->global_rank();
2881 auto ms_context = MsContext::GetInstance();
2882 MS_EXCEPTION_IF_NULL(ms_context);
2883 std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
2884 std::string world_group;
2885 std::string communication_backend;
2886 if (backend == kAscendDevice || backend == kDavinciDevice) {
2887 world_group = HCCL_WORLD_GROUP;
2888 communication_backend = HCCL_BACKEND;
2889 } else if (backend == kGPUDevice) {
2890 world_group = NCCL_WORLD_GROUP;
2891 communication_backend = NCCL_BACKEND;
2892 } else {
2893 MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
2894 }
2895 uint32_t world_rank_size = 0;
2896 if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
2897 MS_LOG(EXCEPTION) << "Get rank size failed";
2898 }
2899
2900 if (!ParallelContext::GetInstance()->device_num_is_set()) {
2901 device_num = UintToInt(world_rank_size);
2902 MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
2903 }
2904 #if defined(ENABLE_GPU)
2905 if (ParallelContext::GetInstance()->device_num_is_set() && backend == kGPUDevice) {
2906 if (world_rank_size != device_num) {
2907 MS_LOG(EXCEPTION) << "The device_num " << device_num
2908 << " set in the context is not consist with the word group size " << world_rank_size;
2909 }
2910 }
2911 #endif
2912
2913 uint32_t rank_id = 0;
2914 if (!ParallelContext::GetInstance()->global_rank_is_set()) {
2915 if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
2916 MS_LOG(EXCEPTION) << "Get rank id failed";
2917 }
2918 global_rank = UintToInt(rank_id);
2919 MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
2920 }
2921 CommInfo comm_info{device_num, global_rank, world_group, communication_backend};
2922 return comm_info;
2923 }
2924
ParallelInit()2925 Status ParallelInit() {
2926 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
2927 int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
2928 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
2929 if (split_stage_num <= 0) {
2930 MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
2931 return FAILED;
2932 }
2933 auto comm_info = GetCommInfo();
2934 int64_t device_num = comm_info.device_num;
2935 int64_t global_rank = comm_info.global_rank;
2936 if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
2937 MS_LOG(ERROR) << "Invalid device num " << device_num;
2938 return FAILED;
2939 }
2940
2941 // the device_num maybe get from communication interface
2942 if (device_num % split_stage_num != 0) {
2943 MS_LOG(ERROR) << "Device num " << device_num << " can't be divided by stage num " << split_stage_num;
2944 return FAILED;
2945 }
2946
2947 if ((global_rank < 0) || (global_rank >= device_num)) {
2948 MS_LOG(ERROR) << "Global rank " << global_rank << " is out of range, the device num is " << device_num;
2949 return FAILED;
2950 }
2951
2952 std::vector<int64_t> stages;
2953 for (int i = 0; i < split_stage_num; i++) {
2954 stages.push_back(device_num / split_stage_num);
2955 }
2956
2957 if ((split_stage_num > 1) && (parallel_mode != SEMI_AUTO_PARALLEL)) {
2958 MS_LOG(ERROR) << "To enable the pipeline parallel, please set the parallel mode to " << SEMI_AUTO_PARALLEL;
2959 return FAILED;
2960 }
2961
2962 if (!InitDevice(device_num, global_rank, comm_info.communication_backend, stages)) {
2963 MS_LOG(ERROR) << "Init device failed";
2964 return FAILED;
2965 }
2966
2967 MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
2968 << ", communication_backend: " << comm_info.communication_backend
2969 << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
2970 << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
2971
2972 return SUCCESS;
2973 }
2974
HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> & all_nodes)2975 void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes) {
2976 for (auto &node : all_nodes) {
2977 if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) {
2978 continue;
2979 }
2980
2981 auto cnode = node->cast<CNodePtr>();
2982 MS_EXCEPTION_IF_NULL(cnode);
2983 if (!cnode->in_forward_flag()) {
2984 continue;
2985 }
2986
2987 FuncGraphManagerPtr manager = cnode->func_graph()->manager();
2988 MS_EXCEPTION_IF_NULL(manager);
2989 std::string op_type = AnfNodeIsPrimitive(node, MAKE_TUPLE) ? MAKE_TUPLE : MAKE_LIST;
2990
2991 auto &make_tuple_list_user = manager->node_users()[cnode];
2992 if (make_tuple_list_user.size() != 1) {
2993 MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user must be 1, but got " << make_tuple_list_user.size();
2994 }
2995 CNodePtr make_tuple_list_next_cnode = make_tuple_list_user.front().first->cast<CNodePtr>();
2996 MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode);
2997
2998 std::string make_tuple__list_user_prim_name = GetPrimName(make_tuple_list_next_cnode);
2999 if (!IsParallelCareNode(make_tuple_list_next_cnode)) {
3000 MS_LOG(INFO) << "The " << op_type << "'s user is " << make_tuple__list_user_prim_name
3001 << ", no need to set operator info";
3002 continue;
3003 }
3004 if (make_tuple_list_next_cnode->inputs().size() != 2) {
3005 MS_LOG(EXCEPTION) << "Now the " << op_type << "'s user only support 1 input, but got "
3006 << (make_tuple_list_next_cnode->inputs().size() - 1);
3007 }
3008
3009 MS_LOG(INFO) << "Set the " << op_type << "'s operator info, and the op name is " << make_tuple__list_user_prim_name;
3010 OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_list_next_cnode);
3011 MS_EXCEPTION_IF_NULL(op_info);
3012 cnode->set_user_data<OperatorInfo>(op_info);
3013 }
3014 }
3015
CreateGroupsByCkptFile(const std::string & file)3016 bool CreateGroupsByCkptFile(const std::string &file) {
3017 GroupInfoMap group_info_map;
3018 if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
3019 return false;
3020 }
3021
3022 if (CreateGroups(group_info_map) != SUCCESS) {
3023 return false;
3024 }
3025 MS_LOG(INFO) << "Create groups by checkpoint file success";
3026 return true;
3027 }
3028
ReorderForPipelineSplit(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager,int64_t pipeline_stages)3029 void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages) {
3030 if (!root->has_flag(BACKWARD) && pipeline_stages > 1) {
3031 root->set_flag(BACKWARD, true);
3032 if (root->has_flag(TRAINING)) {
3033 Reorder(root);
3034 } else {
3035 ReorderForPredict(root, manager);
3036 }
3037 }
3038 }
3039
IsInsertVirtualOutput(const FuncGraphPtr & root)3040 bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
3041 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
3042 auto comm_info = GetCommInfo();
3043 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
3044 int64_t per_stage_device_num = comm_info.device_num / split_stage_num;
3045 int64_t current_stage = comm_info.global_rank / per_stage_device_num;
3046 MS_LOG(INFO) << "The current stage is: " << current_stage;
3047 if (!root->has_flag(TRAINING) && !ParallelContext::GetInstance()->dataset_strategy().empty()) {
3048 MS_LOG(WARNING) << "In eval/predict net, the output parallel strategy would not follow "
3049 "the input parallel strategy when using context.set_auto_parallel_context(dataset_strategy)"
3050 " to configure the input strategy.";
3051 }
3052 return (!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
3053 current_stage == split_stage_num - 1);
3054 }
3055
// Main entry of the (semi-)auto-parallel pass: extracts strategies, rewrites
// the graph with sliced shapes and communication ops, and saves strategy
// checkpoints. Returns true iff the graph was changed.
// NOTE: the passes below are order-dependent; do not reorder them.
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  // Parameter-server server/scheduler roles never run the parallel pass.
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return false;
  }
#endif
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(optimizer);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  pipeline::ResourceBasePtr res = optimizer->resource();
  MS_EXCEPTION_IF_NULL(res);
  FuncGraphManagerPtr manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
  // assume no change to graph
  bool changes = false;
  // control whether use model_parallel mode
  // Skip when: the graph is not flagged for auto-parallel, the mode is neither
  // auto_parallel nor semi_auto_parallel, or this pass already ran once.
  if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
      (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
    if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
      // Warn (once) if user strategies exist but the mode ignores them.
      if (HasStrategy(root)) {
        MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
                     << ", set_strategy() only valid in [semi_]auto_parallel.";
      }
      root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
    }
    // Pipeline reordering still applies even when the rest of the pass is skipped.
    ReorderForPipelineSplit(root, manager, pipeline_stages);

    return changes;
  }

  // Time the whole pass for the final INFO log.
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);

  MS_LOG(INFO) << "Now entering step parallel";
  DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  // Collect all nodes in topological order (search returns reverse order).
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  std::reverse(all_nodes.begin(), all_nodes.end());
  // In AUTO_PARALLEL mode, init/extraction has already been done by the
  // strategy-searching step; only semi-auto goes through this branch.
  if (parallel_mode != AUTO_PARALLEL) {
    TOTAL_OPS = 0;
    // With pipeline parallel (stages > 1) device init happens elsewhere.
    if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
      MS_LOG(EXCEPTION) << "Parallel init failed";
    }

    if (pipeline_stages > 1) {
      // Pipeline-specific preprocessing: micro-batch handling and
      // stage-boundary begin/end markers.
      HandleMicroBatch(all_nodes, manager);
      ParameterStartNode(all_nodes, manager);
      LastStageEndNode(all_nodes, manager, root);
    }

    // mark the forward cnodes, parallel only care these nodes
    MarkForwardCNode(root);

    // User graphs must not already contain communication ops.
    if (FindCommunicationOp(all_nodes)) {
      MS_LOG(EXCEPTION) << "The graph contain communication op";
    }

    if (IsInsertVirtualOutput(root)) {
      InsertVirtualOutput(root, all_nodes);
      // The graph changed: re-collect the node list in topological order.
      AnfNodePtr ret_after = root->get_return();
      MS_EXCEPTION_IF_NULL(ret_after);
      all_nodes = DeepScopedGraphSearch(ret_after);
      std::reverse(all_nodes.begin(), all_nodes.end());
    }

    // extract shape and strategy, set operator_info
    ExtractInformation(all_nodes);
    ReshapeInit(all_nodes);
  }

  HandleRootReshapeAndSaveStrategy(all_nodes);

  HandleForwardMakeTupleAndMakeList(all_nodes);

  // if the input or parameter has multiple users, check whether its split strategies are consistent.
  CheckParameterSplit(all_nodes);

  HandleSymbolicKeyInstance(root, all_nodes);

  // cover Parallel shape
  CoverSliceShape(root);

  // handle input is not used
  HandleNoUsedParameter(root);

  // set the shape for optimizer's clone tensor
  SetClonedTensorShapeForOptimizer(root);

  HandleAdaFactorOpt(root);

  // save strategy as checkpoint for multi-train
  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
    CheckpointStrategy(all_nodes, root);
  }
  // ForwardCommunication BackwardCommunication TensorRedistribution
  ParallelCommunication(root, all_nodes, manager);

  if (pipeline_stages > 1) {
    // Post-processing for pipeline parallel gradient accumulation.
    AddVirtualAssignAdd(root);
    HandleReceiveParam(root, all_nodes);
  }

  // Persist the communication groups created above, if configured.
  auto group_info = g_device_manager->group_info();
  if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
      StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Save group info failed";
  }

  // handle full split parammeters in grad accumulation, do not contain optimizer-sharding's parameter
  HandleFullySplitParameters(root);

  DumpGraph(root, std::string(STEP_PARALLEL_END));

  // step parallel only run once
  root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  res->results()[pipeline::kStepParallelGraph] = root;

  // in auto parallel mode, no need to check if stategies set
  root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);

  (void)gettimeofday(&end_time, nullptr);
  // Elapsed time in microseconds (seconds scaled, then usec delta added).
  uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);

  MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us";
  return changes;
}
3186
3187 // Needed by rec_parser
ExtractInputsTensorName(const CNodePtr & node)3188 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node) {
3189 std::vector<std::string> name_inputs;
3190 std::vector<AnfNodePtr> all_inputs = node->inputs();
3191 std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
3192
3193 std::string node_id = node->UniqueId();
3194 name_inputs.push_back(node_id);
3195 for (auto &input : node_inputs) {
3196 std::string name = input->UniqueId();
3197 name_inputs.push_back(name);
3198 }
3199
3200 return name_inputs;
3201 }
3202 } // namespace parallel
3203 } // namespace mindspore
3204