/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/step_parallel.h"

#include <cinttypes>
#include <algorithm>
#include <chrono>
#include <map>
#include <unordered_map>
#include <memory>
#include <set>
#include <string>
#include <queue>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/hash_map.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/ops_info/gather_info.h"
#include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/tensor_layout/prime_generator.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/fold_pipeline_split_utils.h"
#include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
#include "frontend/parallel/graph_util/grad_accumulation_utils.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/silent_check/silent_check.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/ops_info/matmul_info.h"
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include "frontend/parallel/tensor_layout/tensor_transform.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/parallel_optimizer/opt_param_mgr.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"

#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif

using mindspore::tensor::Tensor;

namespace mindspore {
namespace parallel {
static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL, STANDARD_NORMAL};
const uint32_t MAX_BFS_DEPTH = 7;
const char kSilentCheckEnvEnable[] = "1";

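// If the consumer node's primitive disables communication recompute
// (recompute_comm_op=false), mark the newly created forward communication op
// with recompute=false as well.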
static void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
  if (new_node_input.empty()) {
    return;
  }

  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto attrs = prim->attrs();

  auto anf_node = node->input(0)->cast<ValueNodePtr>();
  auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
  MS_EXCEPTION_IF_NULL(prim_node);
  auto node_attrs = prim_node->attrs();
  if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
    attrs[RECOMPUTE] = MakeValue<bool>(false);
    (void)prim->SetAttrs(attrs);
    MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
  }
}

// Replace pre_node with pre_node->op
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
                            const std::string &instance_name, const std::string &param_name = "",
                            const FuncGraphPtr &root = nullptr) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  ScopePtr scope = pre_node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input;
  if (root && !param_name.empty()) {
    node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
  } else {
    node_input = CreateInput(op, pre_node, instance_name);
  }
  CNodePtr new_node = func_graph->NewCNode(node_input);
  MS_EXCEPTION_IF_NULL(new_node);
  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
    new_node->set_in_forward_flag(true);  // mark forward flag
  }
  auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
  new_node_prim->set_instance_name(instance_name);
  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
    new_node_prim->set_attr("recompute", MakeValue(false));
  } else if (instance_name.find(RECOMPUTE) != std::string::npos) {
    new_node_prim->set_attr("recompute", MakeValue(true));
  }
  new_node->set_scope(scope);
  node_input[0]->set_scope(scope);
  (void)manager->Replace(pre_node, new_node);
  MS_LOG(INFO) << "Insert " << instance_name << " success";
  return new_node;
}

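// Insert forward communication ops for an operator with a tuple output (e.g. GMM):
// the communication op is attached behind each TupleGetItem user instead of the tuple
// itself.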
void ForwardCommunicationForMultiOut(OperatorVector forward_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // step1: get graph manager and distribute_operator
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto uses_set = manager->node_users()[node];
  // For GMM, the output is always consumed by TupleGetItem, so find the real users of GMM.
  std::vector<CNodePtr> node_to_insert = {};
  for (auto &uses_pair : uses_set) {
    auto uses_cnode = uses_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(uses_cnode);
    if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
      break;
    }
    PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
    MS_EXCEPTION_IF_NULL(value_node_prim);
    if (value_node_prim->name() == prim::kPrimTupleGetItem->name()) {
      node_to_insert.push_back(uses_cnode);
    }
  }
  if (node_to_insert.empty()) {
    MS_LOG(ERROR) << "The output of " << node->DebugString()
                  << " does not have a TupleGetItem node. Forward communication can not be inserted, so the "
                     "correctness of the current op can not be ensured.";
    return;
  }
  std::reverse(forward_op.begin(), forward_op.end());

  // step2: traverse op_list and insert node
  for (size_t index = 0; index < forward_op.size(); ++index) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert[index], instance_name);
    SetAllReduceRecomputeFlag(forward_input, node_to_insert[index]);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    forward_node->set_scope(scope);
    forward_node->set_in_forward_flag(true);
    forward_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(forward_node->UniqueId()));
    if (node_to_insert[index]->HasPrimalAttr(MICRO)) {
      forward_node->AddPrimalAttr(MICRO, node_to_insert[index]->GetPrimalAttr(MICRO));
    }
    forward_input[0]->set_scope(scope);
    (void)manager->Replace(node_to_insert[index], forward_node);  // using Replace function to insert node
  }
}

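// Insert the forward communication ops required by the operator's parallel strategy
// after `node`; tuple outputs are delegated to ForwardCommunicationForMultiOut.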
void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (dyn_cast<abstract::SequenceShape>(node->Shape()) != nullptr) {
    // For ops like GMM that have multiple outputs
    MS_LOG(INFO) << "The input node " << node->DebugString()
                 << " has multiple outputs, enter ForwardCommunicationForMultiOut";
    ForwardCommunicationForMultiOut(forward_op, node);
    return;
  }
  // step1: get graph manager and distribute_operator
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto uses_set = manager->node_users()[node];
  CNodePtr node_to_insert = node;
  for (auto &uses_pair : uses_set) {
    auto uses_cnode = uses_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(uses_cnode);
    if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
      break;
    }
    PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
    MS_EXCEPTION_IF_NULL(value_node_prim);
    if (value_node_prim->name() == prim::kPrimTupleGetItem->name()) {
      if (uses_set.size() > 1) {
        MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size();
      }
      node_to_insert = uses_cnode;
    }
  }
  MS_EXCEPTION_IF_NULL(node_to_insert);
  std::reverse(forward_op.begin(), forward_op.end());

  // step2: traverse op_list and insert node
  for (size_t index = 0; index < forward_op.size(); ++index) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
    SetAllReduceRecomputeFlag(forward_input, node_to_insert);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    forward_node->set_scope(scope);
    forward_node->set_in_forward_flag(true);
    forward_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(forward_node->UniqueId()));
    if (node_to_insert->HasPrimalAttr(MICRO)) {
      forward_node->AddPrimalAttr(MICRO, node_to_insert->GetPrimalAttr(MICRO));
    }
    forward_input[0]->set_scope(scope);
    (void)manager->Replace(node_to_insert, forward_node);  // using Replace function to insert node
  }
}

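// Rebuild a tuple from the `num` outputs of `prev`: create TupleGetItem(prev, i) for
// each output, pack them with MakeTuple, and redirect prev's users to the new tuple.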
static CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint64_t num, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(prev);
  MS_EXCEPTION_IF_NULL(func_graph);
  ScopeGuard scope_guard(prev->scope());
  std::vector<AnfNodePtr> make_tuple_inputs;
  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (uint64_t i = 0; i < num; i++) {
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev,
                                                  CreatInt64Imm(UlongToLong(i))};
    auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs);
    MS_EXCEPTION_IF_NULL(tuple_get_item);
    make_tuple_inputs.push_back(tuple_get_item);
  }
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  (void)manager->Replace(prev, make_tuple);
  return make_tuple;
}

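// Insert the redistribution operator list in front of input `pos` of `node`, one op at
// a time; ops flagged in the output-info list additionally get a MakeTuple inserted to
// restore the tuple structure.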
static void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
                                 const FuncGraphPtr &func_graph, int64_t pos, const CNodePtr &pre_node,
                                 const TensorRedistributionPtr &tensor_redistribution) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(pre_node);
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) {
    MS_LOG(EXCEPTION) << "The size of OperatorVector and OutPutInfoVector must be the same!";
  }

  for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) {
    if (pos >= SizeToLong(node->size())) {
      MS_LOG(EXCEPTION) << "InsertRedistribution: pos can't be larger than the size of node's inputs";
    }
    // Create new node
    AnfNodePtr target_node = node->input(LongToSize(pos));
    MS_EXCEPTION_IF_NULL(target_node);
    // Create instance_name
    auto op = (redistribution_oplist_ptr->first)[index];
    std::string op_name = (redistribution_oplist_ptr->first)[index].first;
    std::string instance_name_base = REDISTRIBUTION_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
    auto prim_out = GetCNodePrimitive(node);
    auto prim_in = GetCNodePrimitive(pre_node);
    if (prim_out != nullptr && prim_in != nullptr) {
      auto prim_out_attr = prim_out->attrs();
      auto prim_in_attr = prim_in->attrs();
      std::string recompute_str = "";
      if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end()) {
        recompute_str = GetValue<bool>(prim_out_attr[RECOMPUTE_COMM_OP]) ? RECOMPUTE : NOT_RECOMPUTE;
      }
      if (recompute_str.empty() && prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end()) {
        recompute_str = GetValue<bool>(prim_in_attr[RECOMPUTE_COMM_OP]) ? RECOMPUTE : NOT_RECOMPUTE;
      }
      instance_name = instance_name + "_" + recompute_str;
    }
    InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name, "", nullptr, tensor_redistribution);
    if ((redistribution_oplist_ptr->second)[index].first) {
      target_node = node->input(LongToSize(pos));
      MS_EXCEPTION_IF_NULL(target_node);
      (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph);
    }
  }
}

static void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph,
                                   int64_t pos, const std::string &instance_name) {
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name;
  }

  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (pos >= SizeToLong(node->size())) {
    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than the size of node's inputs, "
                         "the instance name is "
                      << instance_name;
  }
  // Create new node
  AnfNodePtr pre_node = node->input(LongToSize(pos));
  MS_EXCEPTION_IF_NULL(pre_node);
  InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
}

TensorLayout GetTensorInLayoutForNewShape(const AnfNodePtr &pre_node, std::vector<int> get_item_index) {
  TensorLayout tensorinfo_in_layout;
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto distribute_operator = GetDistributeOperator(pre_cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  TensorInfoBasePtr tensorinfo_in;
  auto tensor_info_pos = get_item_index.front();
  get_item_index.erase(get_item_index.begin());
  if (tensor_info_pos != -1) {
    if (tensor_info_pos >= SizeToInt(distribute_operator->outputs_tensor_info_new().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range. Node: " << pre_node->DebugString()
                        << " index: " << tensor_info_pos << " outputs_tensor_info_new's size: "
                        << distribute_operator->outputs_tensor_info_new().size();
    }
    tensorinfo_in = distribute_operator->outputs_tensor_info_new()[IntToSize(tensor_info_pos)];
  } else {
    tensorinfo_in = distribute_operator->outputs_tensor_info_new()[0];
  }
  for (const auto &index : get_item_index) {
    tensorinfo_in = tensorinfo_in->GetElement(IntToLong(index));
  }
  tensorinfo_in_layout = tensorinfo_in->GetValue().tensor_layout();
  return tensorinfo_in_layout;
}

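// Get the output tensor layout of pre_node; get_item_index is the TupleGetItem index
// path from pre_node's output to the consumed element (-1 means the whole output).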
TensorLayout GetTensorInLayout(const AnfNodePtr &pre_node, std::vector<int> get_item_index) {
  TensorLayout tensorinfo_in_layout;
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto distribute_operator = GetDistributeOperator(pre_cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  if (!distribute_operator->outputs_tensor_info_new().empty()) {
    return GetTensorInLayoutForNewShape(pre_node, get_item_index);
  }
  if (get_item_index.size() != 1) {
    // If outputs_tensor_info_new is absent, the outputs have exactly one tensor info,
    // so get_item_index must hold exactly one value.
    MS_LOG(EXCEPTION) << "The get_item_index size is not 1, the size is " << get_item_index.size();
  }
  if (get_item_index[get_item_index.size() - 1] != -1) {
    if (get_item_index[get_item_index.size() - 1] >= SizeToInt(distribute_operator->outputs_tensor_info().size())) {
      MS_LOG(EXCEPTION) << "The index is out of range. Node: " << pre_node->DebugString()
                        << " index: " << get_item_index
                        << " outputs_tensor_info's size: " << distribute_operator->outputs_tensor_info().size();
    }
    auto tensorinfo_in =
      distribute_operator->outputs_tensor_info()[IntToSize(get_item_index[get_item_index.size() - 1])];
    tensorinfo_in_layout = tensorinfo_in.tensor_layout();
  } else {
    if (distribute_operator->outputs_tensor_info().empty()) {
      MS_LOG(EXCEPTION) << "The outputs tensor info is empty. Node:" << pre_node->DebugString();
    }
    auto tensorinfo_in = distribute_operator->outputs_tensor_info()[0];
    tensorinfo_in_layout = tensorinfo_in.tensor_layout();
  }
  return tensorinfo_in_layout;
}

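// Obtain the layout expected by the next operator for the redistributed input.
// node_pair.second is the input-index path into next_cnode; when the consumer is a
// func-graph call whose parameter carries OperatorInfo, that info's output tensor info
// is used instead of the input tensor info.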
Status ObtainOutputTensorLayout(const OperatorInfoPtr &next_distribute_operator,
                                const std::pair<AnfNodePtr, std::vector<int>> &node_pair, const CNodePtr &next_cnode,
                                const bool &using_func_param_op_info, TensorLayout *tensorlayout_out) {
  bool next_dist_op_has_tuple = !next_distribute_operator->inputs_tensor_info_new().empty();
  if (next_dist_op_has_tuple) {
    auto next_inputs_tensor_info = using_func_param_op_info ? next_distribute_operator->outputs_tensor_info_new()
                                                            : next_distribute_operator->inputs_tensor_info_new();
    auto it = std::find_if(node_pair.second.begin(), node_pair.second.end(), [&](const auto &input_idx) {
      return LongToSize(input_idx - 1) >= next_inputs_tensor_info.size();
    });
    if (it != node_pair.second.end()) {
      MS_LOG(INFO) << "The index is out of range, the index is " << (*it - 1) << ", the vector size is "
                   << next_inputs_tensor_info.size() << ", next node is " << next_cnode->DebugString();
      return FAILED;
    }
    auto tensorinfo_out_ptr = next_inputs_tensor_info[LongToSize(node_pair.second[0] - 1)];
    if (tensorinfo_out_ptr->is_list()) {
      for (size_t i = 1; i < node_pair.second.size(); ++i) {
        tensorinfo_out_ptr = tensorinfo_out_ptr->GetElement(LongToSize(node_pair.second[i] - 1));
      }
    }
    TensorInfo tensorinfo_out = tensorinfo_out_ptr->GetValue();
    *tensorlayout_out = tensorinfo_out.tensor_layout();
    return SUCCESS;
  }
  auto next_inputs_tensor_info = using_func_param_op_info ? next_distribute_operator->outputs_tensor_info()
                                                          : next_distribute_operator->inputs_tensor_info();
  size_t out_layout_index = LongToSize(node_pair.second[node_pair.second.size() - 1] - 1);
  if (out_layout_index >= next_inputs_tensor_info.size()) {
    MS_LOG(INFO) << "The index is out of range, the index is " << out_layout_index << ", the vector size is "
                 << next_inputs_tensor_info.size() << ", next node is " << next_cnode->DebugString();
    return FAILED;
  }
  TensorInfo tensorinfo_out = next_inputs_tensor_info[out_layout_index];
  *tensorlayout_out = tensorinfo_out.tensor_layout();
  return SUCCESS;
}

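// Micro interleaved redistribution: split the producer with VirtualConverterBegin,
// insert one redistribution op list per virtual graph, merge the results with
// VirtualConverterEnd, and convert the interleaved AllGathers to Concat afterwards.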
void InsertRedistributionForMicroInterleaved(const TensorRedistributionPtr &tensor_redistribution,
                                             const std::pair<AnfNodePtr, int64_t> &node_pair,
                                             const FuncGraphPtr &func_graph, const CNodePtr &attr_cnode,
                                             const CNodePtr &real_pre_node) {
  auto redistribution_oplist_ptr_vector = tensor_redistribution->InferTensorRedistributionOperatorVirtualGraphs();
  auto next_cnode = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(next_cnode);
  auto next_cnode_index = node_pair.second;
  // create VirtualConverterBeginNode
  MS_EXCEPTION_IF_NULL(real_pre_node);
  auto virtual_converter_begin =
    CreateVirtualConverterBeginNode(real_pre_node, redistribution_oplist_ptr_vector.size());
  std::vector<CNodePtr> tuple_get_item_vector;
  for (size_t i = 0; i < redistribution_oplist_ptr_vector.size(); ++i) {
    if (redistribution_oplist_ptr_vector[i]->first.empty()) {
      return;
    }
    // create tuple_get_item
    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), virtual_converter_begin,
                                                  CreatInt64Imm(UlongToLong(i))};
    auto tuple_get_item_cnode = func_graph->NewCNode(tuple_get_item_inputs);
    tuple_get_item_vector.push_back(tuple_get_item_cnode);
  }
  // create VirtualConverterEndNode
  auto virtual_converter_end = CreateVirtualConverterEndNode(func_graph, tuple_get_item_vector);
  auto manager = func_graph->manager();
  (void)manager->SetEdge(next_cnode, next_cnode_index, virtual_converter_end);
  // add recompute_comm_op attrs
  auto prim_out = GetCNodePrimitive(next_cnode);
  if (prim_out != nullptr && prim_out->HasAttr(RECOMPUTE_COMM_OP)) {
    auto out_recompute_comm_op_attr = prim_out->GetAttr(RECOMPUTE_COMM_OP);
    auto virtual_converter_end_prim = GetCNodePrimitive(virtual_converter_end);
    virtual_converter_end_prim->AddAttr(RECOMPUTE_COMM_OP, out_recompute_comm_op_attr);
  }
  std::vector<std::vector<std::vector<int64_t>>> ag_group_ranks_vectors;

  for (size_t i = 0; i < redistribution_oplist_ptr_vector.size(); ++i) {
    auto redistribution_oplist_ptr = redistribution_oplist_ptr_vector[i];
    if (!tensor_redistribution->IsAssembledStaticShape()) {
      redistribution_oplist_ptr = TensorTransform::GetInstance()->OptimizeTensorRedistributionOperatorList(
        redistribution_oplist_ptr, tensor_redistribution->input_shape());
    }
    // Get allgather group_ranks attr in redistribution_oplist_ptr
    std::vector<std::vector<int64_t>> ag_group_ranks_vector;
    for (size_t findex = 0; findex < (redistribution_oplist_ptr->first).size(); ++findex) {
      // Create instance_name
      auto index = (redistribution_oplist_ptr->first).size() - 1 - findex;
      auto op = (redistribution_oplist_ptr->first)[index];
      std::string op_name = (redistribution_oplist_ptr->first)[index].first;
      if (op_name == ALL_GATHER) {
        auto group_ranks_attr = (redistribution_oplist_ptr->first)[index].second.first[1].second;
        auto group_ranks = GetValue<std::vector<int64_t>>(group_ranks_attr);
        ag_group_ranks_vector.push_back(group_ranks);
      }
    }
    ag_group_ranks_vectors.push_back(ag_group_ranks_vector);
    InsertRedistribution(redistribution_oplist_ptr, virtual_converter_end, func_graph, i + 1, attr_cnode,
                         tensor_redistribution);
  }
  ConvertInterleaveAllGatherToConcat(func_graph, virtual_converter_end, ag_group_ranks_vectors);
}

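// Insert a tensor redistribution between pre_node (producer) and node_pair.first
// (consumer): infer the input and output tensor layouts, init the TensorRedistribution,
// and insert the inferred operator list in front of the consumer's input. Micro
// interleaved layouts take a dedicated path.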
static void Redistribution(const std::pair<AnfNodePtr, std::vector<int>> &node_pair, const AnfNodePtr &pre_node,
                           const std::vector<int> &get_item_index) {
  MS_LOG(DEBUG) << "Do Redistribution for " << node_pair.first->fullname_with_scope();
  auto next_cnode = node_pair.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(next_cnode);
  auto func_graph = next_cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto distribute_operator = GetDistributeOperator(pre_cnode);
  MS_EXCEPTION_IF_NULL(distribute_operator);
  auto dev_list = distribute_operator->stage_device_list();
  OperatorInfoPtr next_distribute_operator;
  bool using_func_param_op_info = false;
  if (IsValueNode<FuncGraph>(next_cnode->input(0))) {
    auto fg = GetValueNode<FuncGraphPtr>(next_cnode->input(0));
    auto fg_parameters = fg->parameters();
    auto param = fg_parameters[IntToSize(node_pair.second[node_pair.second.size() - 1] - 1)];
    if (param->has_user_data<OperatorInfo>()) {
      MS_LOG(INFO) << "Func call node:" << next_cnode->DebugString() << " has operator info.";
      next_distribute_operator = param->user_data<OperatorInfo>();
      using_func_param_op_info = true;
    } else {
      next_distribute_operator = GetDistributeOperator(next_cnode);
    }
  } else {
    next_distribute_operator = GetDistributeOperator(next_cnode);
  }
  MS_EXCEPTION_IF_NULL(next_distribute_operator);

  auto tensor_redistribution = next_distribute_operator->CreateTensorRedistribution();
  tensor_redistribution->SetPreAndNextCNode(pre_cnode, next_cnode);
  MS_LOG(DEBUG) << "Redistribution for pre_node: " << pre_cnode->DebugString()
                << " next_node: " << next_cnode->DebugString();

  // extract tensor layout in and out
  if (distribute_operator->outputs_tensor_info().empty() && distribute_operator->outputs_tensor_info_new().empty()) {
    MS_LOG(WARNING) << "pre_node's tensorinfo_in is empty, operator name is " << distribute_operator->name();
    return;
  }
  TensorLayout tensorlayout_out;
  auto status = ObtainOutputTensorLayout(next_distribute_operator, node_pair, next_cnode, using_func_param_op_info,
                                         &tensorlayout_out);
  if (status != SUCCESS) {
    return;
  }
  TensorLayout tensorlayout_in = GetTensorInLayout(pre_node, get_item_index);
  if (IsPrimitiveCNode(pre_node, prim::kPrimReceive)) {
    tensorlayout_in = *(pre_node->user_data<TensorLayout>());
  }

  if (tensor_redistribution->Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
    MS_LOG(ERROR) << "Redistribution: pre_node " << pre_cnode->DebugString() << " next_node "
                  << next_cnode->DebugString();
    DumpGraph(func_graph, "redistribution_error");
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed";
  }
  if (tensorlayout_in.GetVirtualRank().size() > 1 || tensorlayout_out.GetVirtualRank().size() > 1) {
    auto real_pre_node = next_cnode->input(node_pair.second[node_pair.second.size() - 1])->cast<CNodePtr>();
    InsertRedistributionForMicroInterleaved(tensor_redistribution,
                                            {node_pair.first, node_pair.second[node_pair.second.size() - 1]},
                                            func_graph, pre_cnode, real_pre_node);
    return;
  }
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution->InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Infer tensor redistribution failed.";
  }
  if (!tensor_redistribution->IsAssembledStaticShape()) {
    redistribution_oplist_ptr = TensorTransform::GetInstance()->OptimizeTensorRedistributionOperatorList(
      redistribution_oplist_ptr, tensor_redistribution->input_shape());
  }

  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: InferTensorRedistribution failed";
  }
  MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size();
  if (!redistribution_oplist_ptr->first.empty()) {
    // the last one is the pos of the node in the maketuple
    tensor_redistribution->CreateAssembledDynamicMapping(next_cnode, pre_cnode, func_graph,
                                                         node_pair.second[node_pair.second.size() - 1]);
    // insert node before next node
    InsertRedistribution(redistribution_oplist_ptr, next_cnode, func_graph,
                         node_pair.second[node_pair.second.size() - 1], pre_cnode, tensor_redistribution);
  }
  // Rollback to dynamic shape.
  if (tensor_redistribution->IsAssembledStaticShape() &&
      tensor_redistribution->ResetLayoutTransfer() != Status::SUCCESS) {
    MS_LOG(WARNING) << "Failed to reset layout transfer.";
  }
}

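// For one cnode, collect the producer (pre_nodes) and all layout-consuming users
// (next_nodes), then run Redistribution on every pair. When a consumer is a func-graph
// call in another graph, the producer's OperatorInfo is attached to the called graph's
// parameter.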
static void StepRedistribution(const CNodePtr &cnode, const NodeUsersMap &node_users_map) {
  MS_LOG(DEBUG) << "Do StepRedistribution for " << cnode->fullname_with_scope();
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // In pipeline parallel mode, redistribution is inserted after receive, not send.
  if (IsPrimitiveCNode(cnode, prim::kPrimSend) || IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) ||
      IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    return;
  }
  // Find Redistribution next_nodes
  // next_node.first.second = (pos in next node's inputs (no -1 needed), pos in tuple (-1 needed))
  std::vector<std::pair<std::pair<AnfNodePtr, std::vector<int>>, std::vector<int>>> next_nodes;
  RedistributionNextNode(cnode, manager, node_users_map, {-1}, -1, &next_nodes);
  if (next_nodes.empty()) {
    return;
  }

  // Find Redistribution pre_nodes
  std::vector<AnfNodePtr> pre_nodes;
  RedistributionPreNode(cnode, manager, &pre_nodes);
  if (pre_nodes.size() > 1) {
    MS_LOG(EXCEPTION) << "Redistribution with multiple pre_nodes is not supported.";
  }

  // Insert Redistribution nodes between pre_nodes and next_nodes
  for (auto &pre_node : pre_nodes) {
    for (auto &next_node : next_nodes) {
      MS_LOG(INFO) << "===========Do Redistribution start============" << std::endl
                   << pre_node->fullname_with_scope() << "->" << next_node.first.first->fullname_with_scope() << "("
                   << next_node.first.second << ")";
      Redistribution(next_node.first, pre_node, next_node.second);
      MS_LOG(INFO) << "===========Do Redistribution end ============";
    }
    for (const auto &next_node : next_nodes) {
      if (!next_node.first.first->has_user_data(FUNC_PARAM)) {
        continue;
      }
      if (pre_node->func_graph() == next_node.first.first->func_graph()) {
        continue;
      }
      auto param = next_node.first.first->user_data<AnfNode>(FUNC_PARAM);
      auto distribute_operator = GetDistributeOperator(pre_node->cast<CNodePtr>());
      param->set_user_data<OperatorInfo>(distribute_operator);
      break;
    }
  }
}

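// Split a constant tensor input of a parallel-care node: look up the consumer's input
// tensor layout and insert a _GetTensorSlice op at input `index` so each rank keeps
// only its slice.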
static void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int64_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(next_node);
  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
  if (!op_info) {
    return;
  }

  if (op_info->name().find(FILLV2) != std::string::npos) {
    MS_LOG(INFO) << "FillV2 operator info no need to split tensor";
    return;
  }

  if (op_info->name().find(STAND_ALONE) != std::string::npos) {
    MS_LOG(INFO) << "Stand alone operator info no need to split tensor";
    return;
  }

  // If the shape of tensor is [] or [1], no need to split it.
  Shapes shapes = GetNodeShape(node);
  if (shapes.size() != 1) {
    MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name()
                      << ": GetNodeShape for tensor_node, output size is not 1";
  }
  Shape shape = shapes[0];
  std::string shape_str = ShapeToString(shape);
  if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) {
    MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str
                 << ", no need to split it.";
    return;
  }

  MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str;

  // extract tensor layout
  TensorLayout tensor_layout;
  auto inputs_info_size = op_info->inputs_tensor_info_new().empty() ? op_info->inputs_tensor_info().size()
                                                                    : op_info->inputs_tensor_info_new().size();
  if (LongToSize(index - 1) >= inputs_info_size) {
    if (IsIgnoreSplitTensor(next_node, index - 1)) {
      MS_LOG(INFO) << op_info->name() << ": no need to split tensor for index " << (index - 1);
      return;
    }
    MS_LOG(EXCEPTION) << op_info->name() << ": The index is out of range, index is " << (index - 1)
                      << ", vector size is " << inputs_info_size;
  }
  if (op_info->inputs_tensor_info_new().empty()) {
    TensorInfo tensor_info = op_info->inputs_tensor_info()[LongToSize(index - 1)];
    tensor_layout = tensor_info.tensor_layout();
  } else {
    auto tensor_info = op_info->inputs_tensor_info_new()[LongToSize(index - 1)];
    tensor_layout = tensor_info->GetValue().tensor_layout();
  }

  // Use _GetTensorSlice operator to split the tensor
  FuncGraphPtr func_graph = next_node->func_graph();  // only cnode can get the graph
  MS_EXCEPTION_IF_NULL(func_graph);
  Operator op = CreateGetTensorSliceOp(tensor_layout);
  InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
  if (!op_info->sub_ops().empty()) {
    auto sub_ops = op_info->sub_ops();
    for (size_t i = 0; i < sub_ops.size(); i++) {
      if (!sub_ops.at(i).empty()) {
        InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
      }
    }
  }
}

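// Split a ValueList/ValueTuple input element by element: slice each element with its
// own layout, then pack the slices into a MakeTuple that replaces the value node.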
static void SplitTensorList(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(next_node);
  if (((next_node->size() != kSizeTwo) && !IsSomePrimitiveList(next_node, SUPPORT_NEW_SHAPEBASE_OPS)) || index != 1) {
    MS_LOG(INFO) << next_node->fullname_with_scope() << " must have exactly one input, but got "
                 << (next_node->size() - 1) << "; index should be 1, but got " << index;
    return;
  }
  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  std::vector<ValuePtr> inputs_values;
  if (IsValueNode<ValueList>(node)) {
    inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
  } else {
    inputs_values = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  FuncGraphPtr func_graph = next_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (op_info->inputs_tensor_info_new().empty()) {
    if (inputs_values.size() != op_info->inputs_tensor_info().size()) {
      MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << ", is not equal to inputs shape size "
                        << op_info->inputs_tensor_info().size();
    }
    ScopePtr scope = next_node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    for (size_t i = 0; i < inputs_values.size(); ++i) {
      auto value_ptr = inputs_values[i];
      auto tensor = value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      TensorInfo tensor_info = op_info->inputs_tensor_info()[i];
      TensorLayout tensor_layout = tensor_info.tensor_layout();
      auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
      Operator op = CreateGetTensorSliceOp(tensor_layout);
      std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
      CNodePtr new_node = func_graph->NewCNode(node_input);
      new_node->set_in_forward_flag(true);
      auto new_node_value = node_input[0]->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(new_node_value);
      PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
      new_node_prim->set_instance_name(SPLIT_TENSOR);
      new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
      new_node->set_scope(scope);
      node_input[0]->set_scope(scope);
      make_tuple_inputs.push_back(new_node);
    }
  } else {
    if (inputs_values.size() != op_info->inputs_tensor_info_new()[index - 1]->size()) {
      MS_LOG(EXCEPTION) << "The inputs size " << inputs_values.size() << ", is not equal to inputs shape size "
                        << op_info->inputs_tensor_info_new()[index - 1]->size();
    }
    auto corresponding_tensor_info = op_info->inputs_tensor_info_new()[index - 1];
    ScopePtr scope = next_node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    for (size_t i = 0; i < inputs_values.size(); ++i) {
      auto value_ptr = inputs_values[i];
      auto tensor = value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      TensorInfo tensor_info = corresponding_tensor_info->GetElement(SizeToLong(i))->GetValue();
      TensorLayout tensor_layout = tensor_info.tensor_layout();
      auto value_node = NewValueNode(value_ptr)->cast<AnfNodePtr>();
      Operator op = CreateGetTensorSliceOp(tensor_layout);
      std::vector<AnfNodePtr> node_input = CreateInput(op, value_node, SPLIT_TENSOR);
      CNodePtr new_node = func_graph->NewCNode(node_input);
      new_node->set_in_forward_flag(true);
      auto new_node_value = node_input[0]->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(new_node_value);
      PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
      new_node_prim->set_instance_name(SPLIT_TENSOR);
      new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
      new_node->set_scope(scope);
      node_input[0]->set_scope(scope);
      make_tuple_inputs.push_back(new_node);
    }
  }
  CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
  (void)manager->Replace(node, make_tuple);
}

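// For every parallel-care user of the value node, dispatch to SplitTensor or
// SplitTensorList to replace the full value with the local slice.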
static void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  for (auto &node_pair : node_set) {
    CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
    if (use_cnode == nullptr || !IsValueNode<Primitive>(use_cnode->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(use_cnode_prim);
    if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
        NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
      continue;
    }
    if (IsParallelCareNode(use_cnode)) {
      if (IsPrimitiveCNode(use_cnode, prim::kPrimReceive)) {
        continue;
      }
      if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
        SplitTensorList(node, use_cnode, node_pair.second);
      } else {
        SplitTensor(node, use_cnode, node_pair.second);
      }
    }
  }
}

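// Replace `node` with the operator chain from its parallel strategy (replace_op):
// intermediate ops are stacked on top of the node's inputs and the last op inherits
// the node's OperatorInfo and primal attrs.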
static void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
  MS_LOG(INFO) << "Start StepReplaceOp for " << node->fullname_with_scope();
  // step1: get graph manager and distribute_operator
  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since distribute_operator is nullptr";
  }
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr";
  }

  // When reshape(bool), insert cast in the begin and end of op_list to avoid AllGather(bool).
  auto reshape_type_str = node->abstract()->BuildType()->ToString();
  auto replace_op_info = distribute_operator->replace_op_info();
  if (IsPrimitiveCNode(node, prim::kPrimReshape) && reshape_type_str.find(BOOL) != std::string::npos) {
    auto cast_int = CreateCastOp(kInt32);
    auto cast_bool = CreateCastOp(kBool);
    (void)replace_op.insert(replace_op.cbegin(), cast_int);
    (void)replace_op.insert(replace_op.cend(), cast_bool);
    (void)replace_op_info.insert(replace_op_info.cbegin(), {false, 1});
    (void)replace_op_info.insert(replace_op_info.cend(), {false, 1});
  }

  // step2: traverse op_list and insert node
  std::reverse(replace_op.begin(), replace_op.end());
  std::reverse(replace_op_info.begin(), replace_op_info.end());
  if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) {
    MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!";
  }
  bool replace_op_info_flag = !replace_op_info.empty();
  for (size_t index = 0; index < replace_op.size(); ++index) {
    std::string instance_name = CreateInstanceName(node, index);
    std::string full_inst_name = std::string(REDISTRIBUTION_OP) + "_" + instance_name;
    std::vector<AnfNodePtr> replace_input;
    if (index != replace_op.size() - 1) {
      replace_input = CreateInput(replace_op[index], node, full_inst_name, node);
    } else {
      replace_input = ReplaceOpInput(replace_op[index], full_inst_name, node);
    }
    CNodePtr replace_node = func_graph->NewCNode(replace_input);
    MS_EXCEPTION_IF_NULL(replace_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
    replace_node->set_scope(scope);
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
    PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
    SetUserAttrs(origin_prim->attrs(), prim);
    auto origin_prim_attrs = origin_prim->attrs();
    if (origin_prim_attrs.find(RECOMPUTE_COMM_OP) != origin_prim_attrs.end()) {
      auto do_recompute = GetValue<bool>(origin_prim_attrs[RECOMPUTE_COMM_OP]);
      MS_LOG(INFO) << "Set the recompute attr of the redistribution node in reshape to " << do_recompute << ".";
      prim->set_attr(RECOMPUTE, MakeValue(do_recompute));
    }
    if (prim->name() == GET_NEXT && origin_prim_attrs.find(SYMBOLS) != origin_prim_attrs.end()) {
      prim->set_attr(SYMBOLS, origin_prim_attrs[SYMBOLS]);
    }
    if (index == replace_op.size() - 1) {
      replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
      replace_node->set_primal_attrs(node->primal_attrs());
    }
    replace_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(replace_node->UniqueId()));
    if (node->HasPrimalAttr(MICRO)) {
      replace_node->AddPrimalAttr(MICRO, node->GetPrimalAttr(MICRO));
    }
    replace_node->set_in_forward_flag(true);
    replace_input[0]->set_scope(scope);
    if (replace_op_info_flag && replace_op_info[index].first) {
      auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph);
      new_cnode->set_primal_attrs(node->primal_attrs());
      (void)manager->Replace(node, new_cnode);  // using Replace function to insert node
    } else {
      (void)manager->Replace(node, replace_node);  // using Replace function to insert node
    }
  }
  MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name();
}

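// Replace `node` with a pre-built replacement subgraph: rewire each recorded input
// edge, counting repeated occurrences of the same consumer so multi-input ops such as
// segment_sum are bound to the right slots, then replace the node with the subgraph
// output.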
static void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node,
                             const OperatorInfoPtr &op_info) {
  MS_EXCEPTION_IF_NULL(replace_graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(replace_graph->second);
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr";
  }
  // Solve the input order.
  // For example, input_node: {segment_sum:1, segment_sum:2, gather:2}.
  // The original code here bound all operations to the first input of these operators.
  // However, the segment_sum operation needs two inputs. To solve this,
  // we maintain a dict to count how many times the same operation appears,
  // and bind the inputs according to that count.
  mindspore::HashMap<AnfNodePtr, int> input_map = {};
  static int appear_count = 0;
  for (auto &replace_input : replace_graph->first) {
    auto pre_node = node->input(LongToSize(replace_input.second));

    auto it = input_map.find(replace_input.first);
    if (it != input_map.end()) {
      appear_count = 1 + it->second;
    } else {
      appear_count = 1;
    }
    auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
    replace_input_cnode->set_user_data<OperatorInfo>(op_info);
    size_t inputs_size = replace_input_cnode->size();
    while (IntToSize(appear_count) < inputs_size &&
           replace_input_cnode->input(appear_count)->func_graph() != nullptr) {
      ++appear_count;
    }
    if (IntToSize(appear_count) >= inputs_size) {
      MS_LOG(EXCEPTION) << "No replaceable virtual_input_node";
    }
    input_map[replace_input.first] = appear_count;
    replace_input_cnode->set_in_forward_flag(true);
    manager->SetEdge(replace_input.first, appear_count, pre_node);
  }
  // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called here
  auto replace_output = replace_graph->second->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(replace_output);
  replace_output->set_in_forward_flag(true);
  replace_output->set_primal_attrs(node->primal_attrs());
  (void)manager->Replace(node, replace_output);
}

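// Insert the given virtual div ops in front of every tensor input of `node`.
// For DropoutDoMask, only input[0] gets the div.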
static void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
    MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
    node_size = 2;
  }

  for (size_t index = 1; index < node_size; ++index) {
    AnfNodePtr input = node->input(index);
    MS_EXCEPTION_IF_NULL(input);
    // if it is not a tensor, continue
    if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
      MS_LOG(INFO) << "insert div op: the input at index " << index << " is not a tensor, skip";
      continue;
    }

    for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) {
      std::string instance_name = CreateInstanceName(node, pos);
      InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name);
    }
    MS_LOG(INFO) << "insert div op for input index " << index << " of node";
  }
}

static void InsertRealDivOpToNodeInput(const CNodePtr &node, int64_t scale, const string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  if (scale == 0) {
    MS_LOG(EXCEPTION) << "Find the scale value is 0, you should check the mirror operator's group size.";
  }
  size_t node_size = node->size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  // instantiate the real div operator
  Operator div_op = CreateDivOp(LongToFloat(scale));

  // Insert it as the input of the node
  for (size_t index = 1; index < node_size; ++index) {
    AnfNodePtr input = node->input(index);
    MS_EXCEPTION_IF_NULL(input);
    // if it is not a tensor, continue
    if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
      continue;
    }
    InsertNode(div_op, node, index, node->input(index), func_graph, instance_name);
  }
}

static void InsertAllReduceToNodeInput(const CNodePtr &node, const std::string &group,
                                       const std::string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->size();
  FuncGraphPtr func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  // instantiate the AllReduce operator
  CheckGlobalDeviceManager();
  Operator allreduce_op = CreateAllReduceOp(REDUCE_OP_SUM, group);

  // Insert it as the input of the node
  for (size_t index = 1; index < node_size; ++index) {
    AnfNodePtr input = node->input(index);
    MS_EXCEPTION_IF_NULL(input);
    // if it is not a tensor, continue
    if ((!input->isa<CNode>() && !input->isa<Parameter>()) || HasAbstractMonad(input)) {
      continue;
    }

    InsertNode(allreduce_op, node, index, node->input(index), func_graph, instance_name);
  }
}

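// Return the func graph wrapped by a Shard node if one exists among all_nodes;
// otherwise return the root graph (PyNative shard mode).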
static FuncGraphPtr PynativeParallelGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  FuncGraphPtr real_graph = root;
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      continue;
    }
    auto expect_shard_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (expect_shard_prim->name() != SHARD) {
      continue;
    }
    real_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
  }
  return real_graph;
}

// find previous parallel care node's next node.
static bool FindPreNodes(const AnfNodePtr &node, std::vector<std::string> *unique_ids, std::vector<size_t> *indexes,
                         size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
    return false;
  }
  MS_EXCEPTION_IF_NULL(unique_ids);
  MS_EXCEPTION_IF_NULL(indexes);
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr pre_cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(pre_cnode->input(0))) {
    return false;
  }
  bool find = false;
  for (size_t index = 1; index < pre_cnode->size(); ++index) {
    if (IsPrimitiveCNode(pre_cnode, prim::kPrimDepend) && index > 1) {
      // For Depend, only the first input will be output.
      break;
    }
    auto next_node = pre_cnode->inputs()[index];
    if (!next_node->isa<CNode>() || next_node->isa<Parameter>()) {
      return false;
    }
    CNodePtr cnode = next_node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0))) {
      return false;
    }
    if (IsParallelCareNode(cnode) && !IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) &&
        !IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
      unique_ids->push_back(pre_cnode->UniqueId());
      indexes->push_back(index);
      find = true;
      continue;
    }
    if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) {
      find = true;
    }
  }
  return find;
}

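// Append a _VirtualOutput op to the real graph's output (or to every element of a
// MakeTuple output) and fix up the abstract shape, so the output layout can be derived.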
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  auto real_graph = PynativeParallelGraph(root, all_nodes);
  auto out_pair = GetRealKernelNode(real_graph->output(), -1, nullptr, false);
  auto out_node = out_pair.first;
  MS_EXCEPTION_IF_NULL(out_node);
  OperatorParams params;
  OperatorAttrs attrs;
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
  if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
    auto tuple = out_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple);
    for (size_t i = 1; i < tuple->size(); ++i) {
      auto cur_input = tuple->input(i);
      Shapes shape_outputs = GetNodeShape(cur_input);
      if (shape_outputs[0].empty()) {
        continue;
      }
      InsertNode(op, tuple, i, cur_input, tuple->func_graph(), VIRTUAL_OUTPUT);
      auto virtual_output_abstract = cur_input->abstract()->Clone();
      std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
      virtual_output_abstract->set_shape(virtual_output_shape);
      auto virtual_output_node = tuple->input(i);
      virtual_output_node->set_abstract(virtual_output_abstract);
    }
  } else {
    Shapes shape_outputs = GetNodeShape(out_node);
    if (shape_outputs[0].empty() || out_node->isa<Parameter>()) {
      return;
    }
    auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);
    auto cur_graph = out_node->cast<CNodePtr>()->func_graph();
    MS_EXCEPTION_IF_NULL(cur_graph);
    auto new_node = cur_graph->NewCNode(node_input);
    auto manager = cur_graph->manager();
    (void)manager->Replace(out_node, new_node);
    auto virtual_output_abstract = out_node->abstract()->Clone();
    std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
    virtual_output_abstract->set_shape(virtual_output_shape);
    new_node->set_abstract(virtual_output_abstract);
  }
}

bool InsertMirrorBeforeCast(const CNodePtr &node, size_t index) {
  // Return true (insert the mirror before the Cast) only if the pre node is a Cast and
  // either gradient_fp32_sync is enabled or the Cast output type is float32.
  bool is_gradient_fp32_sync = ParallelContext::GetInstance()->gradient_fp32_sync();
  auto pre_node = node->input(index);
  MS_EXCEPTION_IF_NULL(pre_node);
  auto cnode = pre_node->cast<CNodePtr>();
  if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (ParallelContext::GetInstance()->enable_parallel_optimizer() && IsInAllGatherNodeList(cnode)) {
    pre_node = cnode->input(1);
  }
  if (!IsPrimitiveCNode(pre_node, prim::kPrimCast)) {
    return false;
  }
  auto node_type = pre_node->Type();
  MS_EXCEPTION_IF_NULL(node_type);
  if (!node_type->isa<mindspore::TensorType>()) {
    MS_LOG(EXCEPTION) << "Unknown type.";
  }
  auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(input_element_type);
  auto type_id = input_element_type->type_id();
  if (!is_gradient_fp32_sync && type_id != kNumberTypeFloat32) {
    return false;
  }

  return true;
}

CheckInsertMirrorOps(const MirrorOps & mirror_ops,const CNodePtr & node)1152 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
1153 if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1154 return true;
1155 }
1156 constexpr size_t kSingleArgCNodeSize = 2;
1157 if ((node->size() == kSingleArgCNodeSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
1158 (IsValueNode<ValueSequence>(node->input(1)))) {
1159 MS_LOG(INFO) << "Input is ValueList, skip it.";
1160 return false;
1161 }
1162
1163 if ((node->size() == kSingleArgCNodeSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
1164 (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
1165 MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
1166 return false;
1167 }
1168 return true;
1169 }
1170
1171 // only used for InsertMirrorOps
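// Walks up through Load / trivial ops / AllGather variants until the next input is a Parameter or a
// MicroStepAllGather, and returns that node; any other node on the path is treated as abnormal.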
1172 static CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) {
1173 MS_EXCEPTION_IF_NULL(node);
1174 while (true) {
1175 if (IsPrimitiveCNode(node, prim::kPrimLoad) || IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) {
1176 if (IsPrimitiveCNode(node->input(1), prim::kPrimMicroStepAllGather)) {
1177 return node;
1178 }
1179 if (node->input(1)->isa<Parameter>()) {
1180 return node;
1181 }
1182 node = node->input(1)->cast<CNodePtr>();
1183 } else {
1184 MS_LOG(EXCEPTION) << "The node " << node->fullname_with_scope()
1185 << " is a abnormal node in inserting mirror node.";
1186 }
1187 }
1188 }
1189
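// When the parameter is optimizer-sharded but not fully, its TensorLayout records an opt-shard mirror
// group; rebuild the mirror ops over that group so the remaining replicas still synchronize gradients.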
1190 static void CreateMirrorForParam(const ParameterPtr param_ptr, OperatorVector *backward_op, bool *is_shared_param) {
1191 std::string opt_shard_mirror_group;
1192 if (param_ptr->user_data<TensorLayout>()) {
1193 opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
1194 *is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
1195 }
1196 if (!opt_shard_mirror_group.empty()) {
1197 // mirror ops are still needed when the optimizer shard does not fully cover the parameter
1198 uint32_t group_rank_size = 0;
1199 if (!CommManager::GetInstance().GetRankSize(opt_shard_mirror_group, &group_rank_size)) {
1200 MS_LOG(EXCEPTION) << "Got the group size from the group " << opt_shard_mirror_group << " failed";
1201 }
1202 *backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(group_rank_size));
1203 }
1204 }
1205
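// Core mirror-insertion routine. For every input that carries a non-empty mirror op: find the source
// parameter; if a mirror for it already exists in this graph, re-use it; otherwise insert a new mirror,
// placing it next to the parameter (before Load/Cast) for shared parameters or fp32-sync casts, and tag
// the new comm op with fusion information.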
1206 static void DoInsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
1207 FuncGraphPtr func_graph = node->func_graph();
1208 MS_EXCEPTION_IF_NULL(func_graph);
1209 FuncGraphManagerPtr manager = func_graph->manager();
1210 MS_EXCEPTION_IF_NULL(manager);
1211 auto mirror_size = mirror_ops.size();
1212 if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1213 mirror_size = 1;
1214 }
1215
1216 for (size_t index = 1; index <= mirror_size; ++index) {
1217 OperatorVector backward_op = mirror_ops[index - 1];
1218 if (IsPrimitiveCNode(node, prim::kPrimSend)) {
1219 auto param_index = GetValue<int>(node->GetPrimalAttr(PARAM_INDEX));
1220 backward_op = mirror_ops[IntToSize(param_index)];
1221 }
1222 if (backward_op.empty()) {
1223 continue;
1224 }
1225 std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(node->input(index), func_graph);
1226 if (!param_node_pair.first) {
1227 continue;
1228 }
1229
1230 auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
1231 std::string param_name;
1232 bool is_shared_param = false;
1233 if (param_ptr) {
1234 param_name = param_ptr->name();
1235 if (!param_ptr->param_info() || !param_ptr->param_info()->requires_grad()) {
1236 MS_LOG(INFO) << param_name << " does not need gradient. Skip inserting mirror.";
1237 continue;
1238 }
1239 CreateMirrorForParam(param_ptr, &backward_op, &is_shared_param);
1240 }
1241 // not a RefKey
1242 std::string mirror_op_name = MirrorOpName();
1243 AnfNodePtr pre_node = node->input(index);
1244 if (!param_node_pair.second) {
1245 auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph, 0);
1246 // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead
1247 if (next_cnode.first) {
1248 MS_EXCEPTION_IF_NULL(next_cnode.second);
1249 // assume Load is inserted next to parameter
1250 // skip Load moving up and insert mirror next to the parameter
1251 if (pre_node->cast<CNodePtr>()) {
1252 CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>());
1253 manager->SetEdge(load_node, 1, next_cnode.second);
1254 } else {
1255 manager->SetEdge(node, static_cast<int>(index), next_cnode.second);
1256 }
1257 MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1258 << " and share the mirror.";
1259 AddNodeMirrorInfo(node->cast<CNodePtr>(), param_name);
1260 continue;
1261 }
1262 }
1263 // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp
1264 // only one MirrorOp in backward_op
1265 if (backward_op.size() != 1) {
1266 MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
1267 }
1268 auto op = backward_op[0];
1269 if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param ||
1270 IsPrimitiveCNode(pre_node, prim::kPrimMirrorSilentCheck))) {
1271 // assume Load is inserted next to parameter
1272 // skip Load moving up and insert mirror next to the parameter
1273 CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>());
1274 InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root);
1275 auto comm_op = load_node->input(1)->cast<CNodePtr>();
1276 // add fusion flag
1277 auto fusion_id = AddCommOpFusionType(comm_op, param_node_pair.first);
1278 MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1279 << " and insert mirror before Load";
1280 AddCommOpParamFlag(comm_op);
1281 AddNodeFusionInfo(node, comm_op, "all_reduce", param_name, fusion_id);
1282 continue;
1283 }
1284 InsertNode(op, node, index, pre_node, func_graph, mirror_op_name, param_name, root);
1285 MS_LOG(INFO) << "Find parameter " << param_name << " for node " << GetPrimName(node->cast<CNodePtr>())
1286 << " and insert mirror before the node";
1287 auto comm_op = node->input(index)->cast<CNodePtr>();
1288 // add fusion flag
1289 // pipeline mirror would not be set, which should be supported later
1290 auto fusion_id = AddCommOpFusionType(comm_op, param_node_pair.first);
1291 AddCommOpParamFlag(comm_op);
1292 AddNodeFusionInfo(node, comm_op, "all_reduce", param_name, fusion_id);
1293 }
1294 }
1295
1296 static void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
1297 MS_EXCEPTION_IF_NULL(node);
1298 if (!CheckInsertMirrorOps(mirror_ops, node)) {
1299 return;
1300 }
1301
1302 DoInsertMirrorOps(root, mirror_ops, node);
1303 }
1304
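// Inserts the backward communication of one forward node: mirror ops for gradient synchronization of
// its parameter inputs, and a VirtualDiv after the loss node on the last pipeline stage so the gradient
// is divided by the data-parallel degree.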
1305 static void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator,
1306 const CNodePtr &node,
1307 const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
1308 MS_EXCEPTION_IF_NULL(distribute_operator);
1309 MS_EXCEPTION_IF_NULL(node);
1310
1311 if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
1312 return;
1313 }
1314 bool is_loss_cnode =
1315 std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
1316 [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
1317
1318 MirrorOps mirror_ops = distribute_operator->mirror_ops();
1319 VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
1320 // insert mirror op
1321 if (!mirror_ops.empty()) {
1322 MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
1323 InsertMirrorOps(root, mirror_ops, node);
1324 }
1325 // insert virtual div op
1326 if (!virtual_div_op.empty() && is_loss_cnode && IsLastStage()) {
1327 MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
1328 InsertVirtualDivOp(virtual_div_op, node);
1329 }
1330 }
1331
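// Recursively searches the users of `node` (up to RECURSION_LIMIT levels) for the first parallel-care
// CNode that carries an OperatorInfo, skipping non-data edges such as Depend's control input,
// UpdateState, Send and Receive.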
1332 static std::pair<AnfNodePtr, int64_t> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) {
1333 if (recursion_num >= RECURSION_LIMIT) {
1334 return std::make_pair(nullptr, 0);
1335 }
1336
1337 MS_EXCEPTION_IF_NULL(node);
1338 FuncGraphPtr func_graph = node->func_graph();
1339 MS_EXCEPTION_IF_NULL(func_graph);
1340 FuncGraphManagerPtr manager = func_graph->manager();
1341 MS_EXCEPTION_IF_NULL(manager);
1342 AnfNodeIndexSet node_set = manager->node_users()[node];
1343 for (auto &node_pair : node_set) {
1344 CNodePtr cnode = node_pair.first->cast<CNodePtr>();
1345 MS_EXCEPTION_IF_NULL(cnode);
1346 if (!IsValueNode<Primitive>(cnode->input(0))) {
1347 continue;
1348 }
1349 if (IsPrimitiveCNode(cnode, prim::kPrimMirrorSilentCheck) && node_pair.second != 1) {
1350 continue;
1351 }
1352 ValueNodePtr prim_node_anf = cnode->input(0)->cast<ValueNodePtr>();
1353 MS_EXCEPTION_IF_NULL(prim_node_anf);
1354 PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
1355 MS_EXCEPTION_IF_NULL(node_prim);
1356 if ((node_prim->name() == DEPEND && node_pair.second != 1) || IsPrimitiveCNode(cnode, prim::kPrimReceive) ||
1357 IsPrimitiveCNode(cnode, prim::kPrimSend)) {
1358 continue;
1359 }
1360 if (node_prim->name() == UPDATESTATE && node_pair.second > 0) {
1361 continue;
1362 }
1363 if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
1364 return node_pair;
1365 } else {
1366 auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1);
1367 if (tmp_pair.first != nullptr) {
1368 return tmp_pair;
1369 }
1370 }
1371 }
1372 return std::make_pair(nullptr, 0);
1373 }
1374
1375 static std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
1376 MS_EXCEPTION_IF_NULL(graph);
1377 MS_EXCEPTION_IF_NULL(parameter);
1378 FuncGraphManagerPtr manager = graph->manager();
1379 MS_EXCEPTION_IF_NULL(manager);
1380 std::pair<AnfNodePtr, int64_t> prim_anf_node_pair = FindParallelCareNode(parameter, 0);
1381 if (prim_anf_node_pair.first != nullptr) {
1382 return prim_anf_node_pair;
1383 } else {
1384 AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
1385 for (auto &param_pair : param_sub_set) {
1386 CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
1387 AnfNodePtr graph_value_node;
1388 if (param_cnode->input(0)->isa<CNode>()) {
1389 graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
1390 } else {
1391 graph_value_node = param_cnode->input(0);
1392 }
1393 if (!IsValueNode<FuncGraph>(graph_value_node)) {
1394 continue;
1395 }
1396 FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
1397 auto parameters = graph_sub->parameters();
1398 if (LongToSize(param_pair.second - 1) >= parameters.size()) {
1399 MS_LOG(EXCEPTION) << "The index is out of range, index is: " << (param_pair.second - 1) << ", vector size is "
1400 << parameters.size();
1401 }
1402 std::pair<AnfNodePtr, int64_t> res = FindSubGraph(graph_sub, parameters[LongToSize(param_pair.second - 1)]);
1403 if (res.first != nullptr) {
1404 return res;
1405 }
1406 }
1407 }
1408 return std::make_pair(nullptr, 0);
1409 }
1410
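// Under pipeline parallelism, if this edge is consumed by (Load ->) Cast whose output dtype is not
// float32, return the Cast node so the parallel-optimizer AllGather can be placed after it; gathering
// the lower-precision tensor reduces communication volume.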
1411 static CNodePtr InsertAllGatherAfterCast(const std::pair<AnfNodePtr, int> &node_pair) {
1412 if (ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1) {
1413 return nullptr;
1414 }
1415 auto cnode = node_pair.first->cast<CNodePtr>();
1416 MS_EXCEPTION_IF_NULL(cnode);
1417 auto graph = cnode->func_graph();
1418 MS_EXCEPTION_IF_NULL(graph);
1419 auto manager = graph->manager();
1420 MS_EXCEPTION_IF_NULL(manager);
1421 // skip Load moving down and assume it only has one node user
1422 CNodePtr res = cnode;
1423 if (IsSomePrimitive(res, LOAD)) {
1424 res = manager->node_users()[cnode].begin()->first->cast<CNodePtr>();
1425 }
1426 // return the Cast node only if it casts away from float32 (e.g. fp32 -> fp16)
1427 if (!IsSomePrimitive(res, CAST)) {
1428 return nullptr;
1429 }
1430 auto node_type = res->Type();
1431 MS_EXCEPTION_IF_NULL(node_type);
1432 if (!node_type->isa<mindspore::TensorType>()) {
1433 MS_LOG(EXCEPTION) << "Unknown type.";
1434 }
1435 auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
1436 MS_EXCEPTION_IF_NULL(input_element_type);
1437 auto type_id = input_element_type->type_id();
1438
1439 if (type_id != kNumberTypeFloat32) {
1440 return res;
1441 } else {
1442 return nullptr;
1443 }
1444 }
1445
1446 void AddAllGatherAttrs(const CNodePtr &allgather, const CNodePtr &cnode, const AnfNodePtr &node,
1447 const std::string &op_name, bool add_accu, bool is_with_mirror, bool grad_accumulation_shard) {
1448 // add fusion flag
1449 auto fusion_id = AddCommOpFusionType(allgather, node);
1450 auto param_ptr = node->cast<ParameterPtr>();
1451 auto param_name = param_ptr->name();
1452 AddNodeFusionInfo(cnode, allgather, "reduce_scatter", param_name, fusion_id);
1453 // add gradients mean
1454 AddCommOpMeanFlag(allgather);
1455 AddCNodePrimAttr(allgather, "with_mirror_operator", MakeValue<bool>(is_with_mirror));
1456 if (op_name == MICRO_STEP_ALL_GATHER) {
1457 // When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step
1458 // so no need to do backward for the micro_step_allgather
1459 AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue<bool>(!grad_accumulation_shard));
1460 } else if (op_name == MINI_STEP_ALL_GATHER) {
1461 // We need to manually set add_accu to false if its parent node is MirrorMiniStep
1462 AddCNodePrimAttr(allgather, ADD_ACCU, MakeValue<bool>(!add_accu && !is_with_mirror));
1463 AddCNodePrimAttr(allgather, DO_MIRROR, MakeValue<bool>(!grad_accumulation_shard || !add_accu));
1464 }
1465 }
1466
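// Inserts an AllGather (or MicroStepAllGather) between an optimizer-sharded parameter and its user.
// If the consumer casts the parameter to fp16 anyway, the gather is placed after the Cast, or an fp16
// Cast is created first, to shrink the gathered tensor; attributes are then added via AddAllGatherAttrs.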
1467 static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
1468 const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
1469 MS_EXCEPTION_IF_NULL(res.first);
1470 MS_EXCEPTION_IF_NULL(node);
1471 bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
1472 auto cnode = res.first->cast<CNodePtr>();
1473 auto graph = cnode->func_graph();
1474 MS_EXCEPTION_IF_NULL(graph);
1475 auto manager = graph->manager();
1476 MS_EXCEPTION_IF_NULL(manager);
1477 Operator op;
1478 CNodePtr allgather;
1479 auto param_name = node->cast<ParameterPtr>()->name();
1480 if (op_name == MICRO_STEP_ALL_GATHER) {
1481 op = CreateMicroStepAllGatherOp(group);
1482 } else {
1483 op = CreateAllGatherOp(group);
1484 }
1485 CNodePtr cast_node = InsertAllGatherAfterCast(res);
1486 auto param_ptr = node->cast<ParameterPtr>();
1487 MS_EXCEPTION_IF_NULL(param_ptr);
1488 bool is_with_mirror = false;
1489 if (param_ptr->user_data<TensorLayout>()) {
1490 auto opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
1491 is_with_mirror = !opt_shard_mirror_group.empty();
1492 if (!param_ptr->param_info()->parallel_optimizer()) {
1493 auto mirror_group = mirror_group_list(param_ptr->user_data<TensorLayout>());
1494 is_with_mirror = mirror_group.size() > 1;
1495 }
1496 }
1497 if (!is_shared_param && cast_node) {
1498 allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
1499 MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
1500 } else {
1501 auto pre_node = node;
1502 AnfNodePtr pre_node_ = node;
1503 auto &node_user_map = manager->node_users();
1504 TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(res, node_user_map);
1505 if (next_node_dtype) {
1506 MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving"
1507 << " communication.";
1508 pre_node_ = CreateFP16Cast(cnode, pre_node, next_node_dtype);
1509 }
1510 InsertNode(op, cnode, IntToSize(res.second), pre_node_, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name,
1511 root);
1512 allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>();
1513 MS_LOG(INFO) << "Parallel optimizer is applied before " << cnode->DebugString() << " for " << param_name;
1514 }
1515 bool add_accu = root->has_flag(kAccumulation);
1516 AddAllGatherAttrs(allgather, cnode, node, op_name, add_accu, is_with_mirror, grad_accumulation_shard);
1517 }
1518
1519 bool IsForwardCNode(const CNodePtr &cnode) {
1520 if (cnode->in_forward_flag()) {
1521 return true;
1522 }
1523 if (cnode->input(0) && IsValueNode<FuncGraph>(cnode->input(0))) {
1524 auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
1525 auto orders = func_graph->GetOrderedCnodes();
1526 return std::any_of(orders.begin(), orders.end(), [](const auto &c_node) { return c_node->in_forward_flag(); });
1527 }
1528 return false;
1529 }
1530
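// Applies the parallel-optimizer AllGather to every forward user of `parameter`; the first user triggers
// the actual insertion, and every further user is re-wired to share that same AllGather node.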
1531 void InsertParallelOpt(const FuncGraphPtr &root, const AnfNodePtr &parameter, const std::string &opt_shard_group,
1532 const std::string &op_name) {
1533 // insert all gather
1534 FuncGraphManagerPtr manager = root->manager();
1535 MS_EXCEPTION_IF_NULL(manager);
1536 auto param_sub_set = manager->node_users()[parameter];
1537 bool insert_flag = false;
1538 for (auto &param_pair : param_sub_set) {
1539 auto cnode = param_pair.first->cast<CNodePtr>();
1540 MS_EXCEPTION_IF_NULL(cnode);
1541 if (IsForwardCNode(cnode) && !IsPrimitiveCNode(cnode, prim::kPrimReceive) &&
1542 !(IsPrimitiveCNode(cnode, prim::kPrimDepend) && param_pair.second == INDEX_TWO)) {
1543 if (insert_flag) {
1544 // if there are multiple node users, they share the same allgather
1545 auto next_cnode = FindCNode(parameter, op_name, cnode->func_graph(), 0);
1546 if (next_cnode.first) {
1547 manager->SetEdge(cnode, param_pair.second, next_cnode.second);
1548 auto param_ptr = parameter->cast<ParameterPtr>();
1549 MS_EXCEPTION_IF_NULL(param_ptr);
1550 AddNodeMirrorInfo(cnode, param_ptr->name());
1551 MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and "
1552 << GetPrimName(cnode);
1553 } else {
1554 MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users.";
1555 }
1556 } else {
1557 // insert allgather operator between shard parameter and cnode
1558 auto param_ptr = parameter->cast<ParameterPtr>();
1559 MS_EXCEPTION_IF_NULL(param_ptr);
1560 bool is_shared_param = param_ptr->user_data<TensorLayout>()->is_shared_param();
1561 InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name, is_shared_param);
1562 insert_flag = true;
1563 }
1564 }
1565 }
1566 }
1567
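// Per-parameter entry of optimizer sharding: cloned parameters and non-training graphs are skipped, and
// under pipeline parallelism or gradient accumulation the MicroStepAllGather variant is used so the
// gather can run once per micro step.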
1568 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
1569 const std::string &opt_shard_group) {
1570 auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer();
1571 if (!enable_opt_shard) {
1572 return;
1573 }
1574 MS_EXCEPTION_IF_NULL(parameter);
1575 if (ParameterIsCloned(parameter)) {
1576 return;
1577 }
1578
1579 int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1580 if (opt_shard_group.empty() &&
1581 (split_stage_num <= 1 || !ParameterRequireGrad(parameter) || !root->has_flag(kTraining))) {
1582 return;
1583 }
1584
1585 // set all gather type
1586 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1587 std::string op_name = ALL_GATHER;
1588 if (root->has_flag(kTraining)) {
1589 if ((grad_accumulation_step > 1 || split_stage_num > 1) && ParameterRequireGrad(parameter)) {
1590 op_name = MICRO_STEP_ALL_GATHER;
1591 }
1592 }
1593
1594 // insert all gather
1595 InsertParallelOpt(root, parameter, opt_shard_group, op_name);
1596 }
1597
1598 // When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
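// It slices the parameter's abstract shape to the slice shape taken from the consuming operator's input
// tensor layout (or to opt_shard_slice_shape when optimizer sharding applies) and attaches the layout to
// the parameter as user data.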
1599 static std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res,
1600 const FuncGraphPtr &root, const int &idx) {
1601 // check null for param and cnode
1602 MS_EXCEPTION_IF_NULL(parameter);
1603 auto param_shape = parameter->Shape();
1604
1605 MS_EXCEPTION_IF_NULL(param_shape);
1606
1607 CNodePtr cnode = res.first->cast<CNodePtr>();
1608 MS_EXCEPTION_IF_NULL(cnode);
1609
1610 // get slice_shape
1611 OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
1612 if (distribute_operator == nullptr) {
1613 MS_LOG(EXCEPTION) << "node " << cnode->ToString() << " 's distribute_operator is nullptr";
1614 }
1615 TensorLayout tensor_layout;
1616 if (distribute_operator->inputs_tensor_info_new().empty()) {
1617 if (LongToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
1618 MS_LOG(EXCEPTION) << "The parameter index is not in inputs_tensor_info. index = " << (res.second - 1)
1619 << ", inputs_tensor_info size = " << distribute_operator->inputs_tensor_info().size();
1620 }
1621 TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[LongToSize(res.second - 1)];
1622 tensor_layout = tensorinfo_in.tensor_layout();
1623 } else {
1624 TensorInfoBasePtr tensorinfo_in;
1625 if (idx == -1) {
1626 tensorinfo_in = distribute_operator->inputs_tensor_info_new()[LongToSize(res.second - 1)];
1627 } else {
1628 // idx != -1, input is maketuple
1629 tensorinfo_in = distribute_operator->inputs_tensor_info_new()[LongToSize(idx)];
1630 }
1631 if (tensorinfo_in->is_list()) {
1632 if (idx == -1) {
1633 MS_LOG(EXCEPTION) << "The input of " << distribute_operator->name() << " is a list, but idx is -1.";
1634 }
1635 tensor_layout = tensorinfo_in->GetElement(res.second - 1)->GetValue().tensor_layout();
1636 } else {
1637 tensor_layout = tensorinfo_in->GetValue().tensor_layout();
1638 }
1639 }
1640 Shape slice_shape = tensor_layout.base_slice_shape().array();
1641
1642 // generate shard group
1643 std::string opt_shard_group;
1644 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1645 bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
1646 if (enable_parallel_optimizer) {
1647 std::unique_ptr<OptParamMgr> apOptParamMgr = createOptParamMgr(root);
1648 opt_shard_group = apOptParamMgr->ShardOptGroup(parameter, &tensor_layout, distribute_operator);
1649 // set the shape of parameter to sliced shape
1650 if (!opt_shard_group.empty()) {
1651 slice_shape = tensor_layout.opt_shard_slice_shape();
1652 }
1653 MS_LOG(INFO) << "the shape of " << parameter->ToString() << "(original: " << param_shape->ToString() << ")"
1654 << " will be sliced into " << MakeValue(slice_shape)->ToString() << " in op "
1655 << distribute_operator->name();
1656 }
1657
1658 AbstractBasePtr abstract = parameter->abstract();
1659 if (abstract == nullptr) {
1660 MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract is nullptr";
1661 }
1662
1663 AbstractBasePtr cloned_abstract = abstract->Clone();
1664 if (cloned_abstract == nullptr) {
1665 MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << ": abstract clone failed";
1666 }
1667
1668 cloned_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
1669 parameter->set_abstract(cloned_abstract);
1670 ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
1671 MS_EXCEPTION_IF_NULL(parameter_ptr);
1672 if (tensor_layout.IsInterleavedParallel()) {
1673 MS_LOG(EXCEPTION) << "parameter " << parameter->ToString() << " can not set to interleaved parallel";
1674 }
1675 parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
1676 if (ParallelContext::GetInstance()->direct_split() && parameter_ptr->has_default()) {
1677 auto layout = parameter_ptr->user_data<TensorLayout>();
1678 MS_LOG(INFO) << "parameter: " << parameter->ToString() << parameter->Shape()->ToString()
1679 << "parameter_ptr->default_param()" << parameter_ptr->default_param() << "LAYOUT"
1680 << layout->ToString();
1681 SliceTensorObj(parameter_ptr, layout);
1682 }
1683 return opt_shard_group;
1684 }
1685
1686 int ObtainActualInputIdxForSupportedOps(const AnfNodeIndexSet &node_set) {
1687 int idx = 0;
1688 for (const auto &node_pair : node_set) {
1689 auto use_cnode = node_pair.first->cast<CNodePtr>();
1690 if (IsSomePrimitiveList(use_cnode, SUPPORT_NEW_SHAPEBASE_OPS)) {
1691 idx = node_pair.second;
1692 }
1693 }
1694 return idx;
1695 }
1696
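// For every parameter of the root graph: locate its parallel-care consumer (directly, through sub-graph
// calls, or via the g_RefMap ref table), slice the parameter shape accordingly, and apply the parallel
// optimizer when a shard group was generated.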
1697 static void CoverSliceShape(const FuncGraphPtr &root) {
1698 MS_EXCEPTION_IF_NULL(root);
1699 auto parameters = root->parameters();
1700 FuncGraphManagerPtr manager = root->manager();
1701 MS_EXCEPTION_IF_NULL(manager);
1702 const auto &node_users_map = manager->node_users();
1703 for (auto &parameter : parameters) {
1704 MS_EXCEPTION_IF_NULL(parameter->Shape());
1705 auto iter = g_RefMap.find(parameter);
1706 if (iter != g_RefMap.cend()) {
1707 auto node_set = node_users_map.at(g_RefMap[parameter].first);
1708 auto idx = ObtainActualInputIdxForSupportedOps(node_set);
1709 std::string group = SetParallelShape(parameter, g_RefMap[parameter], root, idx - 1);
1710 // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1711 SetSharedParameterFlag(root, parameter);
1712 ApplyParallelOptOnParam(root, parameter, group);
1713 continue;
1714 }
1715
1716 std::pair<AnfNodePtr, int64_t> res = FindSubGraph(root, parameter);
1717 if (res.first == nullptr) {
1718 MS_LOG(INFO) << "Parameter " << parameter->ToString() << " is not in graph, thus no need to set parallel shape";
1719 if (parameter->has_user_data<TensorLayout>()) {
1720 auto param_abstract = parameter->abstract()->Clone();
1721 auto tensor_layout = parameter->user_data<TensorLayout>();
1722 Shape slice_shape = tensor_layout->base_slice_shape().array();
1723 param_abstract->set_shape(std::make_shared<abstract::Shape>(slice_shape));
1724 parameter->set_abstract(param_abstract);
1725 }
1726 } else {
1727 auto node_set = node_users_map.at(res.first);
1728 auto idx = ObtainActualInputIdxForSupportedOps(node_set);
1729 std::string group = SetParallelShape(parameter, res, root, idx - 1);
1730 // find all forward nodes that use parameter in graphs and insert allgather if group is not empty
1731 SetSharedParameterFlag(root, parameter);
1732 ApplyParallelOptOnParam(root, parameter, group);
1733 MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
1734 }
1735 }
1736 g_RefMap.clear();
1737 }
1738
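// While the parallel passes run, the actual_seq_qlen / actual_seq_kvlen inputs of FlashAttentionScore
// must be tensors: an existing TensorToTuple is folded away, otherwise a TupleToTensor is inserted.
// PostProcess below restores the original tuple form afterwards.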
1739 static void PreProcessActualSeqLenInputForFlashAttentionScore(const FuncGraphPtr &root,
1740 const std::vector<AnfNodePtr> &all_nodes) {
1741 auto manager = root->manager();
1742 MS_EXCEPTION_IF_NULL(manager);
1743 for (auto node : all_nodes) {
1744 if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
1745 auto fa_cnode = node->cast<CNodePtr>();
1746 MS_EXCEPTION_IF_NULL(fa_cnode);
1747 auto fa_inputs = fa_cnode->inputs();
1748 for (size_t index = ops::kFlashAttentionScoreInputActualSeqQlenIndex;
1749 index <= ops::kFlashAttentionScoreInputActualSeqKVlenIndex; ++index) {
1750 auto input = fa_inputs.at(index + 1);
1751 if (IsValueNode<None>(input)) {
1752 continue;
1753 }
1754 // Transfer Tuple to Tensor
1755 if (IsPrimitiveCNode(input, prim::kPrimTensorToTuple)) {
1756 // Eliminate TensorToTuple
1757 manager->SetEdge(fa_cnode, index + 1, input->cast<CNodePtr>()->input(kIndex1));
1758 MS_LOG(DEBUG) << "Eliminate TensorToTuple for " << fa_cnode->fullname_with_scope() << ", index is "
1759 << index + 1;
1760 } else {
1761 auto dtype = NewValueNode(MakeValue<int64_t>(kInt64->type_id()));
1762 dtype->set_abstract(abstract::FromValue((int64_t)(kInt64->type_id())));
1763 auto tuple_to_tensor_cnode =
1764 fa_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleToTensor), input, dtype});
1765 auto abs = GenerateAbsByOpInfer(GetCNodePrimitive(tuple_to_tensor_cnode), {input, dtype});
1766 tuple_to_tensor_cnode->set_abstract(abs);
1767 manager->SetEdge(fa_cnode, index + 1, tuple_to_tensor_cnode);
1768 MS_LOG(DEBUG) << "Insert TupleToTensor for " << fa_cnode->fullname_with_scope() << ", index is " << index + 1;
1769 }
1770 }
1771 }
1772 }
1773 }
1774
1775 static void PostProcessActualSeqLenInputForFlashAttentionScore(const FuncGraphPtr &root,
1776 const std::vector<AnfNodePtr> &all_nodes) {
1777 auto manager = root->manager();
1778 MS_EXCEPTION_IF_NULL(manager);
1779 for (auto node : all_nodes) {
1780 if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
1781 auto fa_cnode = node->cast<CNodePtr>();
1782 MS_EXCEPTION_IF_NULL(fa_cnode);
1783 auto fa_inputs = fa_cnode->inputs();
1784 for (size_t index = ops::kFlashAttentionScoreInputActualSeqQlenIndex;
1785 index <= ops::kFlashAttentionScoreInputActualSeqKVlenIndex; ++index) {
1786 auto input = fa_inputs.at(index + 1);
1787 auto input_abs = input->abstract();
1788 if (IsValueNode<None>(input)) {
1789 continue;
1790 }
1791
1792 if (IsPrimitiveCNode(input, prim::kPrimTupleToTensor)) {
1793 // Eliminate TupleToTensor
1794 manager->SetEdge(fa_cnode, index + 1, input->cast<CNodePtr>()->input(kIndex1));
1795 MS_LOG(DEBUG) << "Eliminate TensorToTuple for " << fa_cnode->fullname_with_scope() << ", index is "
1796 << index + 1;
1797 } else {
1798 // Transfer Tensor to Tuple
1799 auto tensor_to_tuple_cnode =
1800 fa_cnode->func_graph()->NewCNode({NewValueNode(prim::kPrimTensorToTuple), input});
1801 manager->SetEdge(fa_cnode, index + 1, tensor_to_tuple_cnode);
1802 MS_LOG(DEBUG) << "Insert TensorToTuple for " << fa_cnode->fullname_with_scope() << ", index is " << index + 1;
1803 }
1804 }
1805 }
1806 }
1807 }
1808
1809 ValuePtr ObtainStrategyForNewShapes(const ShapeBasePtr &shape, const int64_t &dev_num) {
1810 ValuePtr stra_value_ptr;
1811 if (shape->is_list()) {
1812 std::vector<ValuePtr> elements;
1813 for (size_t i = 0; i < shape->size(); ++i) {
1814 auto value_stra = ObtainStrategyForNewShapes(shape->GetElement(SizeToLong(i)), dev_num);
1815 elements.emplace_back(value_stra);
1816 }
1817 stra_value_ptr = std::make_shared<ValueTuple>(elements);
1818 } else {
1819 Dimensions stra;
1820 stra.push_back(dev_num);
1821 for (size_t j = 1; j < shape->size(); ++j) {
1822 stra.push_back(1);
1823 }
1824 stra_value_ptr = MakeValue(stra);
1825 }
1826 return stra_value_ptr;
1827 }
1828
1829 void ObtainElementsForStrategyNewShape(const std::vector<NewShapes> &new_shape_list, const int64_t &dev_num,
1830 std::vector<ValuePtr> *elements) {
1831 for (size_t i = 0; i < new_shape_list[0].size(); i++) {
1832 if (new_shape_list[0][i]->empty()) {
1833 (void)elements->emplace_back(MakeValue(Dimensions()));
1834 continue;
1835 }
1836 auto input_strategy = ObtainStrategyForNewShapes(new_shape_list[0][i], dev_num);
1837 (void)elements->emplace_back(MakeValue(input_strategy));
1838 }
1839 }
1840
1841 void ObtainElementsForStrategy(const std::vector<Shapes> &shape_list, const int64_t &dev_num,
1842 std::vector<ValuePtr> *elements) {
1843 for (size_t i = 0; i < shape_list[0].size(); i++) {
1844 if (shape_list[0][i].empty()) {
1845 (void)elements->emplace_back(MakeValue(Dimensions()));
1846 continue;
1847 }
1848 Dimensions input_strategy;
1849 input_strategy.push_back(dev_num);
1850 if (shape_list[0][i][0] > 0 && shape_list[0][i][0] % dev_num != 0) {
1851 MS_LOG(EXCEPTION) << "The shapes of dataset is " << shape_list[0]
1852 << ", the batch dim can not be evenly div by dev_num " << dev_num;
1853 }
1854 for (size_t j = 1; j < shape_list[0][i].size(); j++) {
1855 input_strategy.push_back(1);
1856 }
1857 (void)elements->emplace_back(MakeValue(input_strategy));
1858 }
1859 }
1860
1861 std::pair<std::vector<Shapes>, std::vector<NewShapes>> ObtainShape(const CNodePtr &node) {
1862 std::vector<Shapes> shape_list;
1863 std::vector<NewShapes> new_shape_list;
1864 if (HasSupportedValueSequence(node)) {
1865 new_shape_list = ExtractNewShape(node);
1866 } else {
1867 shape_list = ExtractShape(node);
1868 }
1869 return std::make_pair(shape_list, new_shape_list);
1870 }
1871
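// Attaches an in_strategy to VirtualDataset / VirtualOutput nodes. A user-configured dataset_strategy is
// used directly; otherwise a batch-parallel strategy splitting only the first dimension is generated,
// e.g. with a stage of 8 devices an input of shape (32, 224) gets the strategy (8, 1).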
1872 void SetVirtualDatasetStrategy(const CNodePtr &node) {
1873 MS_EXCEPTION_IF_NULL(node);
1874 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
1875 bool full_batch = ParallelContext::GetInstance()->full_batch();
1876
1877 PrimitivePtr prim = GetValueNode<PrimitivePtr>(node->input(0));
1878 MS_EXCEPTION_IF_NULL(prim);
1879 if (prim->name() == VIRTUAL_DATA_SET || prim->name() == VIRTUAL_OUTPUT) {
1880 CheckGlobalDeviceManager();
1881 auto attrs_temp = prim->attrs();
1882 if (!ParallelContext::GetInstance()->dataset_strategy().empty() && prim->name() == VIRTUAL_DATA_SET) {
1883 std::vector<ValuePtr> elements;
1884 auto dataset_strategy = ParallelContext::GetInstance()->dataset_strategy();
1885 (void)std::transform(dataset_strategy.begin(), dataset_strategy.end(), std::back_inserter(elements),
1886 [](auto input_stra) { return MakeValue(input_stra); });
1887 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1888 attrs_temp[IN_STRATEGY] = strategy;
1889 (void)prim->SetAttrs(attrs_temp);
1890 if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) {
1891 ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);
1892 MS_LOG(INFO) << "dataset repeat dim is right";
1893 }
1894 return;
1895 }
1896 int64_t dev_num;
1897 if (full_batch) {
1898 dev_num = 1;
1899 } else {
1900 dev_num = g_device_manager->stage_device_num();
1901 }
1902 if (dev_num == 0) {
1903 MS_LOG(EXCEPTION) << "Device Num must be larger than 0, but got 0.";
1904 }
1905 std::vector<Shapes> shape_list;
1906 std::vector<NewShapes> new_shape_list;
1907 if (InDynamicGraph(node)) {
1908 shape_list = ExtractRealDivisor(node);
1909 MS_LOG(INFO) << "The node is in dynamic shape graph, the real divisor is " << ShapesToString(shape_list[0]);
1910 } else {
1911 std::tie(shape_list, new_shape_list) = ObtainShape(node);
1912 }
1913 if (shape_list.empty() && new_shape_list.empty()) {
1914 MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
1915 }
1916 std::vector<ValuePtr> elements;
1917 if (new_shape_list.empty()) {
1918 ObtainElementsForStrategy(shape_list, dev_num, &elements);
1919 } else {
1920 ObtainElementsForStrategyNewShape(new_shape_list, dev_num, &elements);
1921 }
1922 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
1923 attrs_temp[IN_STRATEGY] = strategy;
1924 (void)prim->SetAttrs(attrs_temp);
1925 }
1926 }
1927
1928 static bool CheckExtractInformation(const CNodePtr &cnode) {
1929 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
1930 return false;
1931 }
1932
1933 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
1934 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
1935 if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
1936 return false;
1937 }
1938
1939 return IsParallelCareNode(cnode);
1940 }
1941
1942 StrategyPtr GenerateStandAloneStra(const OperatorInfoPtr &op_info) {
1943 StrategyPtr in_strategy;
1944 if (op_info->inputs_shape_new().empty()) {
1945 in_strategy = GenerateStandAloneStrategy(op_info->inputs_shape());
1946 } else {
1947 in_strategy = GenerateStandAloneStrategyForNewShapes(op_info->inputs_shape_new());
1948 }
1949 return in_strategy;
1950 }
1951
1952 void CheckStrategyAndShape(const StrategyPtr &in_strategy, const OperatorInfoPtr &op_info) {
1953 MS_EXCEPTION_IF_NULL(in_strategy);
1954 auto has_tuple_stra = in_strategy->HasTupleInTupleStrategy();
1955 auto has_new_shape = !op_info->inputs_shape_new().empty();
1956 if (has_tuple_stra != has_new_shape) {
1957 MS_LOG(EXCEPTION)
1958 << "One of the strategy and the input shape has a tuple-in-tuple input, but the other does not; in_strategy is "
1959 << has_tuple_stra << ", input shape is " << has_new_shape;
1960 }
1961 }
1962
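// Resolves the strategy of one operator and initializes its OperatorInfo. Unless a user-configured
// layout is present: a STAND_ALONE attr yields a stand-alone strategy; BATCH_PARALLEL or no strategy
// found anywhere yields batch parallel; otherwise the primal attr, then the primitive attr, then the
// strategy checkpoint is used.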
1963 static void ExtractStrategyAndInit(const CNodePtr &cnode, const PrimitivePtr &prim, const OperatorInfoPtr &op_info) {
1964 StrategyPtr in_strategy = nullptr, out_strategy = nullptr;
1965 auto attrs = prim->attrs();
1966
1967 // load strategy map from checkpoint
1968 StrategyMap stra_map;
1969 if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
1970 (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
1971 MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
1972 }
1973
1974 std::string strategy_key_name = "";
1975 auto param_names = NodeParameterName(cnode, -1, 0);
1976 if (!param_names.empty()) {
1977 strategy_key_name = prim->name() + "_" + param_names[0].first;
1978 }
1979 std::vector<std::shared_ptr<TensorLayout>> in_tensor_layouts;
1980 std::vector<std::shared_ptr<TensorLayout>> out_tensor_layouts;
1981 if (ExtractUserConfigLayout(attrs, op_info->inputs_shape(), op_info->outputs_shape(), &in_tensor_layouts,
1982 &out_tensor_layouts) != SUCCESS) {
1983 MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " extract configured layout failed"
1984 << trace::DumpSourceLines(cnode);
1985 }
1986 if (in_tensor_layouts.empty() && out_tensor_layouts.empty()) {
1987 bool load_strategy_from_ckpt =
1988 StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
1989 if (!prim->HasAttr(STAND_ALONE)) {
1990 if (((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) ||
1991 prim->HasAttr(BATCH_PARALLEL)) {
1992 MS_LOG(INFO) << "ExtractInformation: the strategy of node " << cnode->ToString() << " prim " << prim->name()
1993 << " is empty, using batch parallel";
1994 in_strategy = GenerateBatchParallelStrategy(op_info, prim);
1995 } else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
1996 in_strategy = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
1997 out_strategy = ExtractStrategy(cnode->GetPrimalAttr(OUT_STRATEGY));
1998 } else if (StrategyFound(attrs)) {
1999 in_strategy = ExtractStrategy(attrs[IN_STRATEGY]);
2000 out_strategy = ExtractStrategy(attrs[OUT_STRATEGY]);
2001 } else {
2002 in_strategy = stra_map[strategy_key_name];
2003 }
2004 } else {
2005 in_strategy = GenerateStandAloneStra(op_info);
2006 }
2007 CheckStrategyAndShape(in_strategy, op_info);
2008 }
2009 if (op_info->Init(in_strategy, out_strategy, in_tensor_layouts, out_tensor_layouts) == FAILED) {
2010 MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed" << trace::DumpSourceLines(cnode);
2011 }
2012 }
2013
2014 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
2015 SetStridedSliceSplitStrategy(all_nodes);
2016 for (auto &node : all_nodes) {
2017 auto cnode = node->cast<CNodePtr>();
2018 if (!CheckExtractInformation(cnode) || IsPrimitiveCNode(node, prim::kPrimSend)) {
2019 continue;
2020 }
2021
2022 SetVirtualDatasetStrategy(cnode);
2023 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2024 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
2025
2026 OperatorInfoPtr operator_ = CreateOperatorInfo(cnode);
2027 MS_EXCEPTION_IF_NULL(operator_);
2028
2029 if (prim->name() == RESHAPE) {
2030 cnode->set_user_data<OperatorInfo>(operator_);
2031 continue;
2032 }
2033
2034 ExtractStrategyAndInit(cnode, prim, operator_);
2035 cnode->set_user_data<OperatorInfo>(operator_);
2036 }
2037 }
2038
2039 // if reshape's output connects to several primitives, return the first layout found
2040 static std::shared_ptr<TensorLayout> FindNextLayout(const AnfNodePtr &cnode, bool *next_is_reshape,
2041 mindspore::HashSet<AnfNodePtr> *visit, int make_tuple_index,
2042 int tuple_get_index,
2043 const std::shared_ptr<TensorLayout> &pre_layout) {
2044 MS_EXCEPTION_IF_NULL(cnode);
2045 MS_EXCEPTION_IF_NULL(next_is_reshape);
2046 MS_EXCEPTION_IF_NULL(visit);
2047 MS_EXCEPTION_IF_NULL(cnode->func_graph());
2048 FuncGraphManagerPtr manager = cnode->func_graph()->manager();
2049 MS_EXCEPTION_IF_NULL(manager);
2050 AnfNodeIndexSet node_set = manager->node_users()[cnode];
2051 for (auto &node_pair : node_set) {
2052 auto use_apply = node_pair.first->cast<CNodePtr>();
2053 if (visit->find(use_apply) != visit->end()) {
2054 continue;
2055 }
2056 (void)(visit->insert(use_apply));
2057
2058 if (IsPrimitiveCNode(use_apply, prim::kPrimPrint) || IsPrimitiveCNode(use_apply, prim::kPrimTensorDump)) {
2059 return pre_layout;
2060 }
2061
2062 if (IsValueNode<FuncGraph>(use_apply->input(0))) {
2063 auto fg = GetValueNode<FuncGraphPtr>(use_apply->input(0));
2064 MS_EXCEPTION_IF_NULL(fg);
2065 auto fg_parameters = fg->parameters();
2066 auto param = fg_parameters[IntToSize(node_pair.second - 1)];
2067 auto next_layout = FindNextLayout(param, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2068 if (next_layout != nullptr) {
2069 return next_layout;
2070 }
2071 }
2072
2073 if (IsPrimitiveCNode(use_apply, prim::kPrimReturn)) {
2074 auto fg = use_apply->func_graph();
2075 auto fg_map = fg->func_graph_cnodes_index();
2076 for (auto &fg_use : fg_map) {
2077 auto fg_node = fg_use.first->first->cast<CNodePtr>();
2078 MS_EXCEPTION_IF_NULL(fg_node);
2079 auto next_layout =
2080 FindNextLayout(fg_node, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2081 if (next_layout != nullptr) {
2082 return next_layout;
2083 }
2084 }
2085 }
2086
2087 if (IsPrimitiveCNode(use_apply, prim::kPrimTupleGetItem)) {
2088 auto temp = LongToInt(GetTupleGetItemIndex(use_apply));
2089 if (temp != make_tuple_index - 1 && make_tuple_index > 0) {
2090 continue;
2091 }
2092 temp = make_tuple_index > 0 ? -1 : temp;
2093 auto next_layout = FindNextLayout(use_apply, next_is_reshape, visit, temp, -1, pre_layout);
2094 if (next_layout != nullptr) {
2095 return next_layout;
2096 }
2097 }
2098
2099 if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
2100 continue;
2101 }
2102 if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
2103 *next_is_reshape = true;
2104 continue;
2105 }
2106 if (IsOneOfPrimitiveCNode(use_apply, {prim::kPrimDepend, prim::kPrimUpdateState}) && node_pair.second != 1) {
2107 continue;
2108 }
2109 if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
2110 make_tuple_index = node_pair.second;
2111 auto next_layout =
2112 FindNextLayout(use_apply, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2113 if (next_layout != nullptr) {
2114 return next_layout;
2115 }
2116 }
2117 if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>() &&
2118 IsSomePrimitiveList(use_apply, SUPPORT_NEW_SHAPEBASE_OPS)) {
2119 MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString() << ", in support new shapebase ops";
2120 *next_is_reshape = false;
2121 auto layout = GetInputLayoutFromCNode(node_pair, make_tuple_index);
2122 return std::make_shared<TensorLayout>(layout);
2123 }
2124 if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
2125 if (make_tuple_index > 0) {
2126 node_pair.second = make_tuple_index;
2127 }
2128 MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString();
2129 *next_is_reshape = false;
2130 auto layout = GetInputLayoutFromCNode(node_pair, -1);
2131 return std::make_shared<TensorLayout>(layout);
2132 }
2133 MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << " " << IsParallelCareNode(use_apply)
2134 << " " << use_apply->has_user_data<OperatorInfo>();
2135
2136 auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, visit, make_tuple_index, tuple_get_index, pre_layout);
2137 if (layout_ptr) {
2138 return layout_ptr;
2139 }
2140 }
2141 return nullptr;
2142 }
2143
2144 static std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
2145 MS_EXCEPTION_IF_NULL(cnode);
2146 OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2147 MS_EXCEPTION_IF_NULL(distribute_operator);
2148 TensorLayout tensorlayout_out;
2149 if (distribute_operator->outputs_tensor_info_new().empty()) {
2150 if (distribute_operator->outputs_tensor_info().size() <= output_index) {
2151 MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->outputs_tensor_info().size()
2152 << ", must be greater than output_index " << output_index;
2153 }
2154 TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
2155 tensorlayout_out = tensorinfo_out.tensor_layout();
2156 } else {
2157 if (distribute_operator->outputs_tensor_info_new().size() <= output_index) {
2158 MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->outputs_tensor_info_new().size()
2159 << ", must be greater than output_index " << output_index;
2160 }
2161 auto tensorinfo_out = distribute_operator->outputs_tensor_info_new()[output_index];
2162 if (tensorinfo_out->is_list()) {
2163 MS_LOG(EXCEPTION) << "For " << cnode->DebugString() << ": the " << output_index
2164 << " out tensorinfo is a list, which does not support yet";
2165 }
2166 tensorlayout_out = tensorinfo_out->GetValue().tensor_layout();
2167 }
2168 return std::make_shared<TensorLayout>(tensorlayout_out);
2169 }
2170
2171 static std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) {
2172 if (!node->isa<CNode>()) {
2173 return nullptr;
2174 }
2175 CNodePtr cnode = node->cast<CNodePtr>();
2176 if (!IsValueNode<Primitive>(cnode->input(0))) {
2177 return nullptr;
2178 }
2179 if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
2180 auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
2181 if (!layout_ptr) {
2182 MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
2183 }
2184 return layout_ptr;
2185 }
2186 return nullptr;
2187 }
2188
2189 static RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
2190 MS_EXCEPTION_IF_NULL(node);
2191 TensorRedistribution tensor_redistribution;
2192 // create a stand-alone layout: TensorMap: [all -1], dev_matrix: [dev_num].
2193 CheckGlobalDeviceManager();
2194 int64_t dev_num = g_device_manager->stage_device_num();
2195 TensorLayout stand_alone_layout;
2196 Shapes inputs_shape = GetNodeShape(node);
2197 if (inputs_shape.empty()) {
2198 MS_LOG(EXCEPTION) << "InferSensRedistribution failed cause inputs shape is empty.";
2199 }
2200 Shape input_shape_array = inputs_shape[0];
2201 if (input_shape_array.empty()) {
2202 MS_LOG(INFO) << "No need to redistribution for sens.";
2203 return nullptr;
2204 }
2205 // TensorMap
2206 TensorMap stand_alone_tensor_map_array(SizeToLong(input_shape_array.size()), -1);
2207 // Dev_matrix
2208 Shape dev_matrix_array = {dev_num};
2209 if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
2210 MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
2211 }
2212
2213 // Infer Redistribution op list for stand alone and loss layout.
2214 RankList dev_list = g_device_manager->GetDeviceListInThisStage();
2215 if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
2216 MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
2217 }
2218 RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
2219 MS_EXCEPTION_IF_NULL(sens_redistribution_list);
2220
2221 return sens_redistribution_list;
2222 }
2223
2224 // reshape1 ---> depend ---> call @sub_graph(x, y, z)
2225 // sub_graph(x, y, z): reshape2(y)
2226 // find the reshape1 through y
2227 static AnfNodePtr RefParameterToActualNode(const AnfNodePtr &node) {
2228 if (!node->isa<Parameter>()) {
2229 return nullptr;
2230 }
2231 auto node_param_ptr = node->cast<ParameterPtr>();
2232 if (node_param_ptr->has_default()) {
2233 return node;
2234 }
2235 auto sub_func_graph = node_param_ptr->func_graph();
2236 auto call_cnodes_map = sub_func_graph->func_graph_cnodes_index();
2237 auto sub_graph_parameters = sub_func_graph->parameters();
2238 auto curr_param_iter = std::find(sub_graph_parameters.begin(), sub_graph_parameters.end(), node);
2239 if (curr_param_iter == sub_graph_parameters.end()) {
2240 MS_LOG(EXCEPTION) << "Cannot find param " << node_param_ptr->DebugString() << " in current sub_graph";
2241 }
2242 size_t curr_param_index = static_cast<size_t>(curr_param_iter - sub_graph_parameters.begin());
2243 for (const auto &node_pair : call_cnodes_map) {
2244 if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
2245 continue;
2246 }
2247 auto cnode = node_pair.first->first->cast<CNodePtr>();
2248 auto cnode_input = cnode->input(curr_param_index + 1);
2249 auto pre_cnode = GetInputNodeWithFilter(cnode_input, [&](const CNodePtr &cnode) {
2250 bool filter = IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
2251 IsPrimitiveCNode(cnode, prim::kPrimDepend);
2252 return std::make_pair(filter, 1);
2253 });
2254 if (pre_cnode) {
2255 return pre_cnode;
2256 }
2257 }
2258 return nullptr;
2259 }
2260
2261 static bool IsCommonOp(const AnfNodePtr &node) {
2262 CNodePtr cnode = node->cast<CNodePtr>();
2263 bool is_comm_op =
2264 IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>() && !IsPrimitiveCNode(node, prim::kPrimReshape);
2265 return is_comm_op;
2266 }
2267
2268 static std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node, bool *is_input_param) {
2269 if (node->isa<Parameter>()) {
2270 auto node_param_ptr = node->cast<ParameterPtr>();
2271 if (node_param_ptr->has_default()) {
2272 // The strategy of Reshape is assigned to this parameter only when the real input of Reshape is a
2273 // parameter with a default value.
2274 *is_input_param = true;
2275 return CreateParameterLayout(node);
2276 }
2277
2278 // the node is parameter of sub-graph
2279 auto actual_node = RefParameterToActualNode(node);
2280 if (actual_node) {
2281 return FindPrevLayout(actual_node, is_input_param);
2282 }
2283 return nullptr;
2284 }
2285 if (!node->isa<CNode>()) {
2286 return nullptr;
2287 }
2288 CNodePtr cnode = node->cast<CNodePtr>();
2289 if (IsValueNode<FuncGraph>(cnode->input(0))) {
2290 auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
2291 auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
2292 if (!pre_node) {
2293 return nullptr;
2294 }
2295 return FindPrevLayout(pre_node, is_input_param);
2296 }
2297 if (!IsValueNode<Primitive>(cnode->input(0))) {
2298 return nullptr;
2299 }
2300 if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
2301 return cnode->user_data<TensorLayout>();
2302 }
2303 if (IsCommonOp(node)) {
2304 auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
2305 if (!layout_ptr) {
2306 MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
2307 }
2308 return layout_ptr;
2309 }
2310 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2311 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
2312 if (prim->name() == prim::kPrimTupleGetItem->name()) {
2313 auto tuple_index = GetTupleGetItemIndex(cnode);
2314 auto tuple_getitem_input = cnode->input(1)->cast<CNodePtr>();
2315 if (IsValueNode<FuncGraph>(tuple_getitem_input->input(0))) {
2316 auto fg = GetValueNode<FuncGraphPtr>(tuple_getitem_input->input(0));
2317 auto pre_node = GetRealKernelNode(fg->output(), tuple_index, nullptr).first;
2318 if (!pre_node) {
2319 return nullptr;
2320 }
2321 return FindPrevLayout(pre_node, is_input_param);
2322 }
2323 auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), LongToSize(tuple_index));
2324 if (!layout_ptr) {
2325 MS_LOG(EXCEPTION) << " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a "
2326 "parallel care node "
2327 "before tuple_getitem!";
2328 }
2329 return layout_ptr;
2330 }
2331 for (size_t index = 0; index < cnode->size(); ++index) {
2332 if (prim->name() == DEPEND && index != 1) {
2333 continue;
2334 }
2335 auto layout_ptr = FindPrevLayout(cnode->inputs()[index], is_input_param);
2336 if (!layout_ptr) {
2337 continue;
2338 }
2339 return layout_ptr;
2340 }
2341 return nullptr;
2342 }
2343
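// Reshape carries no strategy of its own: its input layout comes from the previous parallel-care op
// (FindPrevLayout) and its output layout from the next one (FindNextLayout); Init then derives the
// tensor redistribution between the two layouts.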
2344 static void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
2345 MS_LOG(DEBUG) << "=============Do ReshapeInit start=============";
2346 for (auto &node : all_nodes) {
2347 auto cnode = node->cast<CNodePtr>();
2348 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2349 continue;
2350 }
2351 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
2352 if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
2353 continue;
2354 }
2355 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
2356 MS_EXCEPTION_IF_NULL(prim);
2357 OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
2358 if (operator_info == nullptr) {
2359 MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
2360 }
2361 if (prim->name() != RESHAPE) {
2362 continue;
2363 }
2364
2365 bool is_input_param = false;
2366 auto prev_layout_ptr = FindPrevLayout(cnode->input(1), &is_input_param);
2367 if (prev_layout_ptr) {
2368 auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2369 reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
2370 } else {
2371 MS_LOG(WARNING)
2372 << "FindPrevLayout returned nullptr; if reshape is not the first primitive, there must be some error";
2373 }
2374 auto attrs = prim->attrs();
2375 if (StrategyFound(attrs) && !is_input_param) {
2376 MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!";
2377 }
2378 MS_ASSERT(cnode->size() == RESHAPE_INPUT_SIZE);
2379
2380 bool is_next_reshape = false;
2381 mindspore::HashSet<AnfNodePtr> visit;
2382 auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape, &visit, -1, -1, prev_layout_ptr);
2383 if (next_layout_ptr == nullptr) {
2384 std::string is_reshape = is_next_reshape ? "true" : "false";
2385 MS_LOG(WARNING) << "FindNextLayout for " << cnode->fullname_with_scope()
2386 << " returned nullptr, and is_next_reshape is " << is_reshape
2387 << ". If reshape is not the last primitive, there must be some error.";
2388 }
2389 if (next_layout_ptr) {
2390 auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2391 reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
2392 } else if (is_next_reshape && prev_layout_ptr != nullptr) {
2393 auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2394 reshape_info_ptr->SetOutputLayout(*prev_layout_ptr);
2395 }
2396 if (operator_info->Init(nullptr, nullptr) == FAILED) {
2397 MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed";
2398 }
2399 }
2400 MS_LOG(DEBUG) << "=============Do ReshapeInit end=============";
2401 }
2402
2403 static CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
2404 if (curr_depth > MAX_RECURSIVE_DEPTH) {
2405 MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the maximum recursion depth: "
2406 << MAX_RECURSIVE_DEPTH;
2407 return nullptr;
2408 }
2409 // Handle return->depend->loss
2410 if (IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
2411 (IsPrimitiveCNode(cnode, prim::kPrimCast) && !cnode->has_user_data<OperatorInfo>())) {
2412 auto depend_before = cnode->input(1)->cast<CNodePtr>();
2413 MS_EXCEPTION_IF_NULL(depend_before);
2414 return HandleDependLoss(depend_before, ++curr_depth);
2415 }
2416 return cnode;
2417 }
2418
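// Walk back from the return node of a forward graph to locate the loss cnode,
// recognizing the patterns return->loss, return->tuple_getitem->loss and
// return->make_tuple (in which case no single loss node is reported).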
2419 static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
2420 LossNodeInfo loss_node_info;
2421 MS_EXCEPTION_IF_NULL(func_graph);
2422 CNodePtr return_node = func_graph->get_return();
2423 MS_EXCEPTION_IF_NULL(return_node);
2424 if (return_node->size() < 2) {
2425 MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2";
2426 }
2427 auto pre_node_pair = GetRealKernelNode(return_node->input(1), -1, nullptr);
2428 auto pre_node = pre_node_pair.first;
2429 MS_EXCEPTION_IF_NULL(pre_node);
2430 auto pre_cnode = pre_node->cast<CNodePtr>();
2431
2432 if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
2433 MS_LOG(DEBUG) << "The node before return is not a cnode with a primitive, can not find the loss node";
2434 return loss_node_info;
2435 }
2439 auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
2440 // notice: the GetNext op has no input
2441 if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
2442 MS_LOG(INFO) << "The loss is: " << current_prim->name();
2443 loss_node_info.loss_node = pre_cnode;
2444 return loss_node_info;
2445 }
2446
2447 // return -> tuple_getitem -> loss
2448 if (pre_node_pair.second != -1) {
2449 loss_node_info.has_tuple_getitem = true;
2450 loss_node_info.dout_index = pre_node_pair.second;
2451 loss_node_info.loss_node = pre_cnode;
2452 return loss_node_info;
2453 }
2454
2455 // return -> make_tuple
2456 if (current_prim->name() == MAKE_TUPLE) {
2457 return loss_node_info;
2458 }
2459
2460 // return -> loss
2461 loss_node_info.loss_node = pre_cnode;
2462 MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
2463 return loss_node_info;
2464 }
2465
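// Fetch the tensor layout of the loss output selected by dout_index; this
// layout decides how the sens (gradient seed) tensor will be split.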
2466 static TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
2467 TensorLayouts ret;
2468 auto loss_cnode = node_info.loss_node;
2469 MS_EXCEPTION_IF_NULL(loss_cnode);
2470
2471 ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
2472 MS_EXCEPTION_IF_NULL(prim_anf_node);
2473 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
2474 MS_EXCEPTION_IF_NULL(prim);
2475 if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
2476 MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
2477 return ret;
2478 }
2479
2480 OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
2481 if (!operator_info) {
2482 return ret;
2483 }
2485 TensorInfo loss_grad_tensor_info;
2486 size_t op_output_size = operator_info->outputs_tensor_info().size();
2487 MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", has_tuple_getitem is "
2488 << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is "
2489 << node_info.dout_index;
2490
2491 if ((op_output_size == 0) || (op_output_size <= LongToSize(node_info.dout_index))) {
2492 MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size;
2493 }
2494
2495 if (!node_info.has_tuple_getitem && (op_output_size > 1)) {
2496 MS_LOG(EXCEPTION) << "Currently, it is not supported that the sens is a tuple.";
2497 }
2498
2499 loss_grad_tensor_info = operator_info->outputs_tensor_info()[LongToSize(node_info.dout_index)];
2500 ret.push_back(loss_grad_tensor_info.tensor_layout());
2501 return ret;
2502 }
2503
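// Split the sens tensor according to the loss gradient layout: scalar or [1]
// shaped sens needs no splitting; a Parameter sens only records the layout and
// slice shape; a CNode sens is handled via redistribution; a ValueNode sens is
// sliced with the _GetTensorSlice operator.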
2504 static void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
2505 MS_EXCEPTION_IF_NULL(grad_sens_node);
2506 if (grad_sens_node->size() <= 1) {
2507 MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
2508 }
2509 AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
2510 MS_EXCEPTION_IF_NULL(sens_tensor_node);
2511 Shapes sens_shapes = GetNodeShape(sens_tensor_node);
2512 if (sens_shapes.size() != 1) {
2513 MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
2514 }
2515 // If the shape of sens tensor is [] or [1], no need to split it.
2516 Shape sens_shape = sens_shapes[0];
2517 if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) {
2518 if (sens_tensor_node->isa<Parameter>()) {
2519 auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
2520 MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
2521 sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
2522 }
2523 MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
2524 return;
2525 }
2526 auto loss_shape = loss_grad_layout.tensor_shape().array();
2527 auto loss_tensor_map = loss_grad_layout.tensor_map_before();
2528 bool multi_split = std::any_of(loss_tensor_map.begin(), loss_tensor_map.end(),
2529 [](const auto &tensor_map) { return tensor_map.size() != 1; });
2530 if ((loss_shape != sens_shape) && !multi_split) {
2531 MS_LOG(EXCEPTION) << "The shape of sens is not equal to that of the loss output, it is unsupported now. Sens shape is "
2532 << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape);
2533 }
2534 MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it.";
2535
2536 if (!IsValueNode<Tensor>(sens_tensor_node)) {
2537 if (sens_tensor_node->isa<Parameter>()) {
2538 MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
2539 AbstractBasePtr abstract = sens_tensor_node->abstract();
2540 MS_EXCEPTION_IF_NULL(abstract);
2541 auto slice_shape = loss_grad_layout.slice_shape().array();
2542 std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
2543 MS_EXCEPTION_IF_NULL(parallel_shape);
2544 auto cloned_abstract = abstract->Clone();
2545 MS_EXCEPTION_IF_NULL(cloned_abstract);
2546 cloned_abstract->set_shape(parallel_shape);
2547 sens_tensor_node->set_abstract(cloned_abstract);
2548 auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
2549 sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
2550 return;
2551 }
2552 bool is_dynamic = InDynamicGraph(sens_tensor_node->cast<CNodePtr>());
2553 if (sens_tensor_node->isa<CNode>() && !is_dynamic) {
2554 auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
2555 if (op_list_ptr == nullptr) {
2556 return;
2557 }
2558 auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
2559 auto func_graph = grad_sens_node->func_graph();
2560 MS_EXCEPTION_IF_NULL(func_graph);
2561 TensorRedistributionPtr tensor_redistribution = std::make_shared<TensorRedistribution>();
2562 InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode, tensor_redistribution);
2563 return;
2564 }
2565 if (is_dynamic) {
2566 return;
2567 }
2568 MS_LOG(EXCEPTION) << "The type of the sens node is not Tensor, Parameter or CNode, it is unsupported now.";
2569 }
2570
2571 // Use _GetTensorSlice operator to split the sens tensor
2572 FuncGraphPtr func_graph = grad_sens_node->func_graph(); // only cnode can get the graph
2573 MS_EXCEPTION_IF_NULL(func_graph);
2574 Operator op = CreateGetTensorSliceOp(loss_grad_layout);
2575 InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
2576 }
2577
2578 static void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
2579 MS_EXCEPTION_IF_NULL(distribute_operator);
2580 MS_EXCEPTION_IF_NULL(cnode);
2581 if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) {
2582 return;
2583 }
2584 OperatorVector forward_op = distribute_operator->forward_op();
2585 // For GMM, its MakeTuple will inherit its op info,
2586 // which would lead to inserting an AllReduce for the MakeTuple.
2587 if (!forward_op.empty() && !IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
2588 MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name();
2589 ForwardCommunication(forward_op, cnode);
2590 }
2591 }
2592
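// Replace operators after communication insertion. An operator provides either
// a replace_op list or a replace_graph, never both; interleaved Reshape
// additionally gets micro-interleaved redistribution inserted.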
2593 static void StepReplace(const std::vector<AnfNodePtr> &all_nodes) {
2594 for (auto &node : all_nodes) {
2595 MS_EXCEPTION_IF_NULL(node);
2596 if (node->isa<CNode>()) {
2597 auto cnode = node->cast<CNodePtr>();
2598 if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) ||
2599 IsSomePrimitive(cnode, SEND)) {
2600 continue;
2601 }
2602
2603 OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
2604 // StepReplace
2605 MS_EXCEPTION_IF_NULL(distribute_operator);
2606 auto replace_op = distribute_operator->replace_op();
2607 if (!replace_op.empty()) {
2608 MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString();
2609 StepReplaceOp(replace_op, cnode);
2610 }
2611
2612 // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore.
2613 auto replace_graph = distribute_operator->replace_graph(cnode);
2614 if (!replace_op.empty() && replace_graph) {
2615 MS_LOG(EXCEPTION) << "Only one of replace_op or replace_graph can be used";
2616 }
2617 if (replace_graph) {
2618 MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString();
2619 StepReplaceGraph(replace_graph, cnode, distribute_operator);
2620 }
2621 if (distribute_operator->name().find(RESHAPEINFO) != std::string::npos) {
2622 auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(distribute_operator);
2623 if (!reshape_info->InterleavedParallel()) {
2624 continue;
2625 }
2626 auto reshape_redis = reshape_info->ReshapeRedistribution();
2627 InsertRedistributionForMicroInterleaved(reshape_redis, {cnode, 1}, cnode->func_graph(), cnode,
2628 cnode->input(kIndex1)->cast<CNodePtr>());
2629 if (!IsPrimitiveCNode(cnode->input(kIndex1), prim::kPrimVirtualConverterEnd)) {
2630 continue;
2631 }
2632 auto virtual_converter_end = cnode->input(kIndex1)->cast<CNodePtr>();
2633 auto func_graph = cnode->func_graph();
2634 MS_EXCEPTION_IF_NULL(func_graph);
2635 auto manager = func_graph->manager();
2636 MS_EXCEPTION_IF_NULL(manager);
2637 manager->Replace(cnode, virtual_converter_end);
2638 }
2639 }
2640 }
2641 }
2642
2643 static void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
2644 CNodePtr sens_node = sens_loss_pair.first;
2645 auto loss_node = sens_loss_pair.second;
2646 auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
2647 if (!loss_grad_layout.empty()) {
2648 SplitSens(sens_node, loss_grad_layout[0]);
2649 }
2650 }
2651
2652 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
2653 static std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
2654 MS_EXCEPTION_IF_NULL(root);
2655 std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
2656 for (auto &node : root->nodes()) {
2657 if (!node->isa<CNode>()) {
2658 continue;
2659 }
2660
2661 // cnode(sens)-->cnode(tuple_getitem)
2662 auto sens_cnode = node->cast<CNodePtr>();
2663 AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
2664 MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
2665 if (!expect_tuple_getitem->isa<CNode>()) {
2666 continue;
2667 }
2668
2669 auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
2670 if (!IsSomePrimitive(expect_tuple_getitem_cnode, prim::kPrimTupleGetItem->name())) {
2671 continue;
2672 }
2673
2674 // cnode(sens)-->cnode(tuple_getitem)-->cnode
2675 AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
2676 MS_EXCEPTION_IF_NULL(expect_anonymous);
2677 if (!expect_anonymous->isa<CNode>()) {
2678 continue;
2679 }
2680
2681 // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
2682 auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
2683 AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
2684 MS_EXCEPTION_IF_NULL(expect_j);
2685 if (!expect_j->isa<CNode>()) {
2686 continue;
2687 }
2688 auto expect_j_cnode = expect_j->cast<CNodePtr>();
2689 if (!IsSomePrimitive(expect_j_cnode, J)) {
2690 continue;
2691 }
2692
2693 if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
2694 MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
2695 }
2696 auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
2697 auto loss_node_info = FindLossCNode(func_graph);
2698 if (loss_node_info.loss_node == nullptr) {
2699 MS_LOG(WARNING) << "Can not find the loss cnode";
2700 continue;
2701 }
2702 std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
2703 sens_loss_pairs.push_back(sens_loss_pair);
2704 }
2705 return sens_loss_pairs;
2706 }
2707
2708 static void HandleSens(const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
2709 // Splitting sens must be done before inserting the operators.
2710 for (auto &pair : sens_loss_pairs) {
2711 // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
2712 // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
2713 if (IsLastStage()) {
2714 StepSplitSens(pair);
2715 }
2716 }
2717 return;
2718 }
2719
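// The main insertion pass: handle sens first, then for every parallel-care
// cnode insert forward communication, redistribution and backward (mirror)
// operators, split const tensors, and finally run StepReplace.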
2720 static void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
2721 const FuncGraphManagerPtr &manager) {
2722 MS_EXCEPTION_IF_NULL(root);
2723 MS_EXCEPTION_IF_NULL(manager);
2724
2725 std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
2726 auto has_backward = HasBackward(root);
2727 // Splitting sens must be done before inserting the operators.
2728 HandleSens(sens_loss_pairs);
2729
2730 const auto &node_users_map = manager->node_users();
2731 for (auto &node : all_nodes) {
2732 MS_EXCEPTION_IF_NULL(node);
2733 if (node->isa<CNode>()) {
2734 auto cnode = node->cast<CNodePtr>();
2735 if (IsValueNode<FuncGraph>(cnode->input(0))) {
2736 StepRedistribution(cnode, node_users_map);
2737 continue;
2738 }
2739 // the make_tuple is a parallel care node, but it may not have operator info
2740 if ((!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) && !IsControlFlowNode(cnode)) {
2741 continue;
2742 }
2743 OperatorInfoPtr distribute_operator = nullptr;
2744 if (!IsControlFlowNode(cnode)) {
2745 distribute_operator = GetDistributeOperator(cnode);
2746 MS_EXCEPTION_IF_NULL(distribute_operator);
2747 }
2748
2749 // skip Send Receive
2750 auto parallel_context = parallel::ParallelContext::GetInstance();
2751 MS_EXCEPTION_IF_NULL(parallel_context);
2752 auto is_pp_interleave = parallel_context->pipeline_interleave();
2753 if (!cnode->HasPrimalAttr(PIPELINE_PARAM) || is_pp_interleave) {
2754 // insert forward ops
2755 if (!IsControlFlowNode(cnode)) {
2756 InsertForwardOps(distribute_operator, cnode);
2757 }
2758
2759 // insert redistribution ops
2760 StepRedistribution(cnode, node_users_map);
2761 }
2762 // insert backward ops
2763 if (!IsControlFlowNode(cnode) && (has_backward || IsPynativeParallel())) {
2764 BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
2765 }
2766 if (!IsControlFlowNode(cnode)) {
2767 distribute_operator->ReplaceNodeInputOrAttrs();
2768 }
2769 } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
2770 StepSplitTensor(node, manager);
2771 }
2772 }
2773 // StepReplace
2774 StepReplace(all_nodes);
2775 }
2776
2777 static bool IsGatherInfo(const std::string &name) {
2778 std::vector<std::string> gather_info_names = {"GatherInfo", "SparseGatherV2Info", "EmbeddingLookupInfo"};
2779 for (std::string info_name : gather_info_names) {
2780 if (name.find(info_name) != std::string::npos) {
2781 return true;
2782 }
2783 }
2784 return false;
2785 }
2786
2787 void AssignStrategyMap(const StrategyPtr &stra, const std::string &strategy_key_name, StrategyMap *stra_map) {
2788 if (stra) {
2789 (*stra_map)[strategy_key_name] = stra;
2790 } else {
2791 Strategies new_stra_v;
2792 StrategyPtr new_stra = std::make_shared<Strategy>(g_device_manager->stage_id(), new_stra_v);
2793 (*stra_map)[strategy_key_name] = new_stra;
2794 }
2795 }
2796
2797 void AssignManualShapeMapForGather(const OperatorInfoPtr &operator_info, const std::string &param_name,
2798 ManualShapeMap *manual_shape_map) {
2799 if (IsGatherInfo(operator_info->name())) {
2800 auto gather_info = std::dynamic_pointer_cast<GatherInfo>(operator_info);
2801 auto param_split_shapes = gather_info->param_split_shapes();
2802 auto index_offsets = gather_info->index_offsets();
2803 if (param_split_shapes.size() != index_offsets.size()) {
2804 MS_LOG(EXCEPTION) << "In manual split, the lengths of param_split_shapes and index_offsets should be the same.";
2805 }
2806 std::vector<std::pair<int64_t, int64_t>> manual_shape;
2807 for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
2808 (void)manual_shape.emplace_back(std::make_pair(param_split_shapes[LongToSize(i)], index_offsets[LongToSize(i)]));
2809 }
2810 (*manual_shape_map)[param_name] = manual_shape;
2811 }
2812 }
2813
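// Persist the strategies, tensor layouts and gather manual-split shapes of
// parameter-related operators into the strategy checkpoint for multi-train.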
2814 static void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
2815 if (!StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
2816 return;
2817 }
2818
2819 StrategyMap stra_map;
2820 TensorInfoMap tensor_info_map;
2821 ManualShapeMap manual_shape_map;
2822 for (auto &node : all_nodes) {
2823 MS_EXCEPTION_IF_NULL(node);
2824 auto cnode = node->cast<CNodePtr>();
2825 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
2826 continue;
2827 }
2828 auto param_names = NodeParameterName(cnode, -1, 0);
2829 if (param_names.empty()) {
2830 continue;
2831 }
2832 string param_name = param_names[0].first;
2833 PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2834 MS_EXCEPTION_IF_NULL(prim);
2835 OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
2836 if (operator_info) {
2837 std::string strategy_key_name = prim->name() + "_" + param_name;
2838 StrategyPtr stra;
2839 if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
2840 auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
2841 stra = reshape_info->get_input_shard_strategy();
2842 if (stra == nullptr) {
2843 MS_LOG(INFO) << "Reshape has no input strategy, skip it";
2844 continue;
2845 }
2846 } else {
2847 stra = operator_info->strategy();
2848 }
2849 AssignStrategyMap(stra, strategy_key_name, &stra_map);
2850
2851 for (auto param_name_pair : param_names) {
2852 tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
2853 }
2854 AssignManualShapeMapForGather(operator_info, param_name, &manual_shape_map);
2855 }
2856 }
2857 for (auto &cloned_parameter_node : root->parameters()) {
2858 MS_EXCEPTION_IF_NULL(cloned_parameter_node);
2859 auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
2860 MS_EXCEPTION_IF_NULL(cloned_parameter);
2861
2862 if (!ParameterIsCloned(cloned_parameter_node) && !IsStrategySaved(cloned_parameter_node)) {
2863 continue;
2864 }
2865 std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
2866 auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
2867 if (cloned_param_layout == nullptr) {
2868 continue;
2869 }
2870 tensor_info_map[cloned_param_name] = cloned_param_layout;
2871 }
2872 if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, manual_shape_map) != SUCCESS) {
2873 MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
2874 }
2875 }
2876
2877 static void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
2878 for (auto &node : all_nodes) {
2879 MS_EXCEPTION_IF_NULL(node);
2880 if (!node->isa<CNode>()) {
2881 continue;
2882 }
2883 auto cnode = node->cast<CNodePtr>();
2884 if (!IsValueNode<Primitive>(cnode->input(0))) {
2885 continue;
2886 }
2887
2888 // CNode is globally unique.
2889 MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << ".";
2890 cnode->set_in_forward_flag(true);
2891 }
2892 }
2893
2894 static void SetForwardFlag(const AnfNodeSet &all_nodes) {
2895 for (auto &node : all_nodes) {
2896 MS_EXCEPTION_IF_NULL(node);
2897 if (!node->isa<CNode>()) {
2898 continue;
2899 }
2900 auto cnode = node->cast<CNodePtr>();
2901 if (!IsValueNode<Primitive>(cnode->input(0))) {
2902 continue;
2903 }
2904
2905 // CNode is globally unique.
2906 cnode->set_in_forward_flag(true);
2907 }
2908 }
2909
2910 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
2911 MS_EXCEPTION_IF_NULL(root);
2912 auto ret = root->get_return();
2913 MS_EXCEPTION_IF_NULL(ret);
2914 auto all_nodes = TopoSort(ret, SuccDeeperSimple);
2915 std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
2916 return graph_set;
2917 }
2918
2919 static std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph,
2920 const std::vector<AnfNodePtr> &all_nodes) {
2921 MS_EXCEPTION_IF_NULL(graph);
2922 std::vector<AnfNodePtr> root_forward_nodes;
2923 auto loss_cnode = FindLossCNode(graph).loss_node;
2924 if (loss_cnode == nullptr) {
2925 return root_forward_nodes;
2926 }
2927
2928 auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
2929 for (auto &node : all_nodes) {
2930 MS_EXCEPTION_IF_NULL(node);
2931 if (!node->isa<CNode>()) {
2932 continue;
2933 }
2934 auto cnode = node->cast<CNodePtr>();
2935 auto root_node_id = node->UniqueIdThroughCopy();
2936 if (loss_cnode_id == root_node_id) {
2937 root_forward_nodes = DeepLinkedGraphSearch(cnode);
2938 break;
2939 }
2940 }
2941 return root_forward_nodes;
2942 }
2943
2944 static void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
2945 // shape op doesn't have params and attrs.
2946 OperatorParams params;
2947 OperatorAttrs attrs;
2948 auto shape_value = GetValueNode(node->input(2))->cast<ValueSequencePtr>();
2949 MS_EXCEPTION_IF_NULL(shape_value);
2950 auto shape = shape_value->value();
2951 if (shape.empty()) {
2952 return;
2953 }
2954 OperatorArgs args = std::make_pair(attrs, params);
2955 Operator op = std::make_pair(SHAPE_OP, args);
2956 InsertNode(op, node, 2, pre_node, root, "shape");
2957 }
2958
2959 static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) {
2960 if (curr_depth > MAX_RECURSIVE_DEPTH) {
2961 MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH;
2962 return nullptr;
2963 }
2964 for (auto &node : cnode->inputs()) {
2965 if (!node->isa<CNode>()) {
2966 continue;
2967 }
2968 if (!IsPrimitiveCNode(node, prim::kPrimEnvironGet)) {
2969 return FindGrad(node->cast<CNodePtr>(), ++curr_depth);
2970 } else {
2971 return node;
2972 }
2973 }
2974 return nullptr;
2975 }
2976
2977 static void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
2978 // If root graph has reshape op. Find the corresponding parameter.
2979 // Reshape's shape is the shape of the parameter.
2980 auto executor = pipeline::GraphExecutorPy::GetInstance();
2981 for (auto &node : all_nodes) {
2982 if (!node->isa<CNode>()) {
2983 continue;
2984 }
2985 auto cnode = node->cast<CNodePtr>();
2986 if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
2987 continue;
2988 }
2989 if (cnode->in_forward_flag()) {
2990 // Save strategy in executor
2991 OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>();
2992 if (op_info) {
2993 auto stra_ptr = op_info->strategy();
2994 if (stra_ptr) {
2995 auto strategy = stra_ptr->GetInputDim();
2996 // fullname with scope should be found in step parallel end ir
2997 executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy);
2998 }
2999 }
3000 continue;
3001 }
3002
3003 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
3004 if (prim->name() != RESHAPE) {
3005 continue;
3006 }
3007
3008 Shape origin_dst_shape = GetValue<std::vector<int64_t>>(cnode->input(2)->cast<ValueNodePtr>()->value());
3009 if (origin_dst_shape.size() == 1 && origin_dst_shape[0] == -1) {
3010 continue;
3011 }
3012 auto root = node->func_graph();
3013 auto grad_node = FindGrad(cnode, 0);
3014 if (grad_node) {
3015 InsertShapeOp(cnode, grad_node, root);
3016 }
3017 }
3018 }
3019
3020 void MarkForwardCNode(const FuncGraphPtr &root) {
3021 MS_EXCEPTION_IF_NULL(root);
3022 auto ret = root->get_return();
3023 MS_EXCEPTION_IF_NULL(ret);
3024 auto all_nodes = TopoSort(ret, SuccDeeperSimple);
3025 auto graph_set = FindForwardGraphByRootNodes(all_nodes);
3026
3027 if (graph_set.empty()) {
3028 MS_LOG(INFO) << "Cannot find the forward graph, so mark the ops in the root graph";
3029 auto fgs = root->manager()->func_graphs();
3030 for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
3031 SetForwardFlag((*fg)->nodes());
3032 }
3033 } else {
3034 for (auto func_graph = graph_set.cbegin(); func_graph != graph_set.cend(); ++func_graph) {
3035 MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
3036 auto return_node = (*func_graph)->get_return();
3037 MS_EXCEPTION_IF_NULL(return_node);
3038 auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
3039 SetForwardFlag(all_dfs_nodes);
3040 auto root_forward_nodes = FindRootForwardCNode(*func_graph, all_nodes);
3041 if (root_forward_nodes.empty()) {
3042 continue;
3043 }
3044 // Mark forward flag for the nodes in root graph.
3045 SetForwardFlag(root_forward_nodes);
3046 }
3047 }
3048 }
3049
3050 OperatorInfoPtr set_make_list_for_ifa(CNodePtr make_list, const CNodePtr &next_node) {
3051 ValueNodePtr anf_node = next_node->input(0)->cast<ValueNodePtr>();
3052 if (!anf_node) {
3053 return nullptr;
3054 }
3055 PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
3056 if (!prim) {
3057 return nullptr;
3058 }
3059 if (prim->name() != INCRE_FLASH_ATTENTION) {
3060 return nullptr;
3061 }
3062
3063 int kv_index = 1;
3064 OperatorInfoPtr operator_make_list = CreateOperatorInfo(make_list);
3065 auto make_list_prim = GetValueNode<PrimitivePtr>(make_list->input(0));
3066 if (make_list_prim->HasAttr(STAND_ALONE)) {
3067 (void)make_list_prim->DelAttr(STAND_ALONE);
3068 }
3069 OperatorInfoPtr next_operator = next_node->user_data<OperatorInfo>();
3070 StrategyPtr next_node_strategy = next_operator->strategy();
3071 Strategies key_value_strategies;
3072 Dimensions key_value_dim = next_node_strategy->GetInputDim().at(kv_index);
3073 key_value_strategies.push_back(key_value_dim);
3074 auto make_list_stage = next_node_strategy->GetInputStage();
3075 auto make_list_new_in_stra = NewStrategy(make_list_stage, key_value_strategies);
3076 operator_make_list->set_strategy(make_list_new_in_stra);
3077
3078 std::vector<TensorInfo> kv_in_tensor_info(1, next_operator->inputs_tensor_info()[kv_index]);
3079 operator_make_list->set_inputs_tensor_info(kv_in_tensor_info);
3080 return operator_make_list;
3081 }
3082
3083 static void HandleForwardMakeTupleAndMakeList(const std::vector<AnfNodePtr> &all_nodes) {
3084 for (auto &node : all_nodes) {
3085 if (!AnfNodeIsPrimitive(node, MAKE_TUPLE) && !AnfNodeIsPrimitive(node, MAKE_LIST)) {
3086 continue;
3087 }
3088
3089 auto cnode = node->cast<CNodePtr>();
3090 MS_EXCEPTION_IF_NULL(cnode);
3091 if (!cnode->in_forward_flag()) {
3092 continue;
3093 }
3094
3095 FuncGraphManagerPtr manager = cnode->func_graph()->manager();
3096 MS_EXCEPTION_IF_NULL(manager);
3097
3098 // MakeTuple has multiple users, each user's TensorInfo must be the same.
3099 auto make_tuple_list_next_node = CheckMakeTupleSplit(node, manager);
3100 if (make_tuple_list_next_node == nullptr) {
3101 continue;
3102 }
3103 auto make_tuple_list_next_cnode = make_tuple_list_next_node->cast<CNodePtr>();
3104 MS_EXCEPTION_IF_NULL(make_tuple_list_next_cnode);
3105 if (!IsSomePrimitiveList(make_tuple_list_next_cnode, INPUT_IS_TUPLE_OR_LIST_OPS)) {
3106 continue;
3107 }
3108
3109 OperatorInfoPtr op_info = set_make_list_for_ifa(cnode, make_tuple_list_next_cnode);
3110 if (op_info == nullptr) {
3111 op_info = GetDistributeOperator(make_tuple_list_next_cnode);
3112 }
3113 MS_EXCEPTION_IF_NULL(op_info);
3114 cnode->set_user_data<OperatorInfo>(op_info);
3115 }
3116 }
3117
3118 bool CreateGroupsByCkptFile(const std::string &file) {
3119 GroupInfoMap group_info_map;
3120 if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
3121 return false;
3122 }
3123
3124 if (CreateGroups(group_info_map) != SUCCESS) {
3125 return false;
3126 }
3127 MS_LOG(INFO) << "Create groups by checkpoint file success";
3128 return true;
3129 }
3130
3131 static void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager,
3132 int64_t pipeline_stages) {
3133 auto parallel_context = parallel::ParallelContext::GetInstance();
3134 MS_EXCEPTION_IF_NULL(parallel_context);
3135 auto is_pp_interleave = parallel_context->pipeline_interleave();
3136 if (is_pp_interleave) {
3137 return;
3138 }
3139 if (!root->has_flag(kSkipAutoParallelCompile) && !root->has_flag(BACKWARD) && pipeline_stages > 1) {
3140 root->set_flag(BACKWARD, true);
3141 if (IsTraining(manager)) {
3142 if (parallel_context->enable_fold_pipeline()) {
3143 MS_LOG(INFO) << "Begin Fold Pipeline Reorder. ";
3144 FoldPipelineReorder(root);
3145 } else {
3146 Reorder(root);
3147 }
3148 } else {
3149 ReorderForPredict(root, manager);
3150 }
3151 }
3152 }
3153
3154 static void ReorderForGradAccumulation(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
3155 if (!root->has_flag(kSkipAutoParallelCompile) && !root->has_flag(BACKWARD) &&
3156 ParallelContext::GetInstance()->grad_accumulation_step() > 1) {
3157 root->set_flag(BACKWARD, true);
3158 auto context = MsContext::GetInstance();
3159 MS_EXCEPTION_IF_NULL(context);
3160 const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
3161 DumpGraph(root, "before_reorder");
3162 if (IsTraining(manager)) {
3163 if (cell_reuse) {
3164 TagMicroBatchBpEndInCellShare(root, manager);
3165 }
3166 std::unordered_map<int64_t, std::vector<CNodePtr>> forward_start;
3167 std::unordered_map<int64_t, std::vector<CNodePtr>> backward_end;
3168 ExtractMicroBatchBorderNodes(root, &forward_start, &backward_end);
3169 ReorderGradAccumulation(root, forward_start, backward_end);
3170 DumpGraph(root, "after_reorder");
3171 } else {
3172 MS_LOG(EXCEPTION) << "Predict with grad accumulation is not supported currently";
3173 }
3174 }
3175 }
3176
3177 static void HandleDataParallel() {
3178 std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
3179 if (parallel_mode == kDataParallel) {
3180 auto group_info_save_path = common::GetEnv("GROUP_INFO_FILE");
3181 if (!group_info_save_path.empty()) {
3182 std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info;
3183 int64_t device_num = GetCommInfo().device_num;
3184 RankList comm_group;
3185 for (size_t i = 0; i < LongToSize(device_num); ++i) {
3186 comm_group.push_back(SizeToLong(i));
3187 }
3188 ParallelContext::GetInstance()->set_group_ckpt_save_file(group_info_save_path);
3189 if (StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info, comm_group) != SUCCESS) {
3190 MS_LOG(EXCEPTION) << "Save group info failed";
3191 }
3192 }
3193 }
3194 }
3195
3196 static void MicroBatchPreProcess(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager,
3197 const std::vector<AnfNodePtr> &all_nodes) {
3198 auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3199 if (pipeline_stages > 1) {
3200 HandleMicroBatch(all_nodes, manager);
3201 ParameterStartNode(all_nodes, manager);
3202 LastStageEndNode(all_nodes, manager, root);
3203 return;
3204 }
3205 TagMicroBatchStart(manager, all_nodes);
3206 TagMicroBatchEnd(manager, all_nodes);
3207 auto context = MsContext::GetInstance();
3208 MS_EXCEPTION_IF_NULL(context);
3209 const auto no_cell_reuse = context->CellReuseLevel() == CellReuseLevel::kNoCellReuse;
3210 bool enable_grad_accu = ParallelContext::GetInstance()->grad_accumulation_step() > 1;
3211 if (no_cell_reuse && enable_grad_accu) {
3212 TagMicroBatchBpEndPrim(root);
3213 TagMicroBatchBpEnd(root);
3214 }
3215 }
3216
3217 static void MicroBatchPostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
3218 auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3219 if (pipeline_stages > 1) {
3220 AddVirtualAssignAdd(root);
3221 HandleReceiveParam(root);
3222 LabelGenMaskMicro(root);
3223 return;
3224 }
3225 if (ParallelContext::GetInstance()->grad_accumulation_step() > 1) {
3226 AddVirtualAssignAdd(root);
3227 LabelGenMaskMicro(root);
3228 }
3229 }
3230
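// Starting from the expand_dims node of the global-norm pattern, search up to
// the sqrt node and insert an AllReduce over the current stage (and another
// one between stages in pipeline mode) so the norm covers all devices.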
3231 static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
3232 auto cnode = res_node->cast<CNodePtr>();
3233 auto graphs = res_node->func_graph();
3234 MS_EXCEPTION_IF_NULL(graphs);
3235 auto manager = graphs->manager();
3236 MS_EXCEPTION_IF_NULL(manager);
3237 auto &node_user_map = manager->node_users();
3238 if (!IsSomePrimitive(cnode, EXPAND_DIMS)) {
3239 MS_LOG(ERROR) << "Expected the operator expand_dims, but found " << GetPrimName(cnode)
3240 << ". This may cause the calculation of the global norm to be incorrect";
3241 return;
3242 }
3243 auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3244 auto find_node = res_node;
3245 uint32_t limits = 0;
3246 while (!IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT) && limits < MAX_BFS_DEPTH) {
3247 auto users = node_user_map.at(find_node);
3248 if (users.empty()) {
3249 return;
3250 }
3251 find_node = users.front().first;
3252 ++limits;
3253 }
3254 if (!find_node || !IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT)) {
3255 return;
3256 }
3257 auto anf_node = find_node->cast<CNodePtr>();
3258 if (anf_node->size() > 1 && IsSomePrimitive(anf_node->input(1)->cast<CNodePtr>(), ALL_REDUCE)) {
3259 return;
3260 }
3261 auto sqrt_node = find_node;
3262 auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
3263 Group cur_stage_device_list;
3264 if (g_device_manager->CreateGroup(cur_stage_rank_list, &cur_stage_device_list) != SUCCESS) {
3265 MS_LOG(EXCEPTION) << "Create the communication group for allreduce in calculating global norm failed, "
3266 "the rank_list is: "
3267 << cur_stage_rank_list;
3268 }
3269 InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), cur_stage_device_list.name(), PARALLEL_GLOBALNORM);
3270 MS_LOG(INFO) << "Insert the AllReduce for the global norm value within the stage succeeded.";
3271 if (pipeline_stages > 1) {
3272 MS_LOG(INFO) << "Insert the AllReduce for the global norm value between stages.";
3273 auto ranks_between_stages = g_device_manager->GetDeviceListBetweenStage();
3274 Group group_between_stages;
3275 if (g_device_manager->CreateGroup(ranks_between_stages, &group_between_stages) != SUCCESS) {
3276 MS_LOG(EXCEPTION) << "Create the communication group for allreduce in calculating global norm "
3277 "with pipeline parallel failed, the rank_list is: "
3278 << cur_stage_rank_list;
3279 }
3280 InsertAllReduceToNodeInput(sqrt_node->cast<CNodePtr>(), group_between_stages.name(), PARALLEL_GLOBALNORM_BETWEEN);
3281 }
3282 }
3283
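// Breadth-first search, bounded by `limits` levels, from a gradient user node
// towards the ExpandDims node that carries the GRAD_SCALE attribute; only the
// ops of the global-norm computation are allowed on the path.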
3284 static AnfNodePtr FindExpandDimsWithGradScale(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map,
3285 uint32_t limits) {
3286 std::queue<AnfNodePtr> visited;
3287 AnfNodePtr queue_node = nullptr;
3288 CNodePtr cnode = nullptr;
3289 AnfNodePtr last_node = nullptr;
3290 uint32_t depth = 0;
3291 if (!node_ptr) {
3292 return nullptr;
3293 }
3294 visited.push(node_ptr);
3295 while (!visited.empty()) {
3296 queue_node = visited.front();
3297 visited.pop();
3298 cnode = queue_node->cast<CNodePtr>();
3299 // MAKE_TUPLE will not appear after the load in the forward graph
3300 if (IsSomePrimitive(cnode, EXPAND_DIMS)) {
3301 auto value = GetAttrsFromAnfNode(queue_node, GRAD_SCALE);
3302 if (!value || !GetValue<bool>(value)) {
3303 continue;
3304 }
3305 return queue_node;
3306 }
3307 if (!IsSomePrimitiveList(
3308 cnode, {ENVIRONGET, MUL, SQUARE, REDUCE_SUM, EXPAND_DIMS, DEPEND, CAST, REF_TO_EMBED, EMBED, LOAD})) {
3309 continue;
3310 }
3311 auto node_set = node_users_map.at(queue_node);
3312 for (auto &node_user : node_set) {
3313 visited.push(node_user.first);
3314 }
3315 if (!last_node || last_node == queue_node) {
3316 if (++depth == limits) {
3317 break;
3318 }
3319 last_node = visited.back();
3320 }
3321 }
3322 return nullptr;
3323 }
3324
3325 static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter,
3326 uint32_t dev_num) {
3327 auto params_user_set = node_user_map.at(parameter);
3328 for (auto &param_pair : params_user_set) {
3329 auto cnode = param_pair.first->cast<CNodePtr>();
3330 MS_EXCEPTION_IF_NULL(cnode);
3331 if (cnode->in_forward_flag()) {
3332 continue;
3333 }
3334 constexpr size_t bfs_depth = 10;
3335 auto expand_dims_node = FindExpandDimsWithGradScale(cnode, node_user_map, bfs_depth);
3336 if (!expand_dims_node) {
3337 continue;
3338 }
3339 auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
3340 if (!value || !GetValue<bool>(value)) {
3341 continue;
3342 }
3343 if (dev_num > 0) {
3344 InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
3345 MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->fullname_with_scope()
3346 << " succeeded!";
3347 }
3348 // If already inserted allreduce, the pattern will not be matched and thus no allreduce will be inserted.
3349 InsertAllReduceForNormValue(expand_dims_node);
3350 }
3351 }
3352
3353 static AnfNodePtr GetMirrorOp(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter) {
3354 auto params_user_set = node_user_map.at(parameter);
3355 for (auto &param_pair : params_user_set) {
3356 auto cnode = param_pair.first->cast<CNodePtr>();
3357 std::vector<AnfNodePtr> candidate = {cnode};
3358 if (!cnode->in_forward_flag()) {
3359 continue;
3360 }
3361 while (IsInTrivialNodeList(cnode) || IsSomePrimitive(cnode, LOAD) ||
3362 IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) || IsPrimitiveCNode(cnode, prim::kPrimAllGather)) {
3363 auto load_users = node_user_map.at(cnode);
3364 cnode = node_user_map.at(cnode).front().first->cast<CNodePtr>();
3365 MS_EXCEPTION_IF_NULL(cnode);
3366 (void)std::transform(load_users.begin(), load_users.end(), std::back_inserter(candidate),
3367 [](const auto &v) { return v.first; });
3368 }
3369 for (auto &node : candidate) {
3370 auto local_cnode = node->cast<CNodePtr>();
3371 if (!IsPrimitiveCNode(local_cnode, prim::kPrimMirror) &&
3372 !IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMicroStep) &&
3373 !IsPrimitiveCNode(local_cnode, prim::kPrimMirrorMiniStep)) {
3374 continue;
3375 }
3376 return node;
3377 }
3378 }
3379 return nullptr;
3380 }
3381
3382 static void HandleGlobalNormScale(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
3383 auto parameters = root->parameters();
3384 const auto &node_user_map = manager->node_users();
3385 MS_LOG(INFO) << "Start to process the global norm";
3386
3387 for (auto &parameter : parameters) {
3388 int64_t dev_num = 0;
3389 if (!ParameterRequireGrad(parameter)) {
3390 continue;
3391 }
3392 auto mirror_node = GetMirrorOp(node_user_map, parameter);
3393 auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
3394 if (device_num_ptr && device_num_ptr->isa<Int64Imm>()) {
3395 dev_num = GetValue<int64_t>(device_num_ptr);
3396 }
3397 InsertDivAndAllReduceForNorm(node_user_map, parameter, LongToUint(dev_num));
3398 }
3399 }
3400
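// If a MirrorMicroStep operator lives on a weight inside a called sub-graph
// and all call sites pass the same argument, hoist the mirror out to the
// caller so it runs once in the root graph.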
3401 static void MoveMicroMirrorOutCallFunc(const FuncGraphPtr &root) {
3402 AnfNodePtr ret_after = root->get_return();
3403 MS_EXCEPTION_IF_NULL(ret_after);
3404 auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
3405 auto manager = root->manager();
3406 for (const auto &node : all_nodes) {
3407 if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep)) {
3408 continue;
3409 }
3410 auto micro_mirror = node->cast<CNodePtr>();
3411 auto param_anf_node = GetInputNodeWithFilter(micro_mirror, [&](const CNodePtr &cnode) {
3412 bool filter = IsPrimitiveCNode(cnode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
3413 IsPrimitiveCNode(cnode, prim::kPrimDepend);
3414 return std::make_pair(filter, 1);
3415 });
3416 if (!param_anf_node->isa<Parameter>()) {
3417 continue;
3418 }
3419 auto param = param_anf_node->cast<ParameterPtr>();
3420 if (param->has_default()) {
3421 continue;
3422 }
3423 auto sub_func_graph = param_anf_node->func_graph();
3424 auto call_cnodes_map = sub_func_graph->func_graph_cnodes_index();
3425 auto sub_graph_parameters = sub_func_graph->parameters();
3426 auto curr_param_iter = std::find(sub_graph_parameters.begin(), sub_graph_parameters.end(), param_anf_node);
3427 if (curr_param_iter == sub_graph_parameters.end()) {
3428 MS_LOG(EXCEPTION) << "Cannot find param " << param_anf_node->DebugString() << " in current sub_graph";
3429 }
3430 size_t curr_param_index = static_cast<size_t>(curr_param_iter - sub_graph_parameters.begin());
3431 AnfNodePtr call_nodes_common_param_input = nullptr;
3432 FuncGraphPtr call_nodes_func_graph = nullptr;
3433 for (const auto &node_pair : call_cnodes_map) {
3434 if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
3435 continue;
3436 }
3437 auto cnode = node_pair.first->first->cast<CNodePtr>();
3438 call_nodes_func_graph = cnode->func_graph();
3439 auto cnode_input = cnode->input(curr_param_index + 1);
3440 if (!call_nodes_common_param_input) {
3441 call_nodes_common_param_input = cnode_input;
3442 }
3443 if (call_nodes_common_param_input != cnode_input) {
3444 call_nodes_common_param_input = nullptr;
3445 break;
3446 }
3447 }
3448 if (!call_nodes_common_param_input || !call_nodes_func_graph) {
3449 continue;
3450 }
3451 // Insert new MicroMirror in root func
3452 if (!IsPrimitiveCNode(call_nodes_common_param_input, prim::kPrimMirrorMicroStep)) {
3453 auto new_mirror_node =
3454 NewMicroMirrorPrimByMicroMirror(call_nodes_func_graph, micro_mirror, call_nodes_common_param_input);
3455 for (const auto &node_pair : call_cnodes_map) {
3456 if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
3457 continue;
3458 }
3459 manager->SetEdge(node_pair.first->first, curr_param_index + 1, new_mirror_node);
3460 }
3461 }
3462
3463 // Remove MicroMirror in call_func
3464 (void)manager->Replace(micro_mirror, micro_mirror->input(kIndex1));
3465 }
3466 }
3467
3468 static void MergeMicroMirrorForSharedParameter(const FuncGraphPtr &root) {
3469 AnfNodePtr ret_after = root->get_return();
3470 MS_EXCEPTION_IF_NULL(ret_after);
3471 auto all_nodes = TopoSort(ret_after, SuccDeeperSimple);
3472 auto manager = root->manager();
3473 std::unordered_map<ParameterPtr, std::vector<CNodePtr>> param_mirror_map;
3474 for (const auto &node : all_nodes) {
3475 if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep)) {
3476 continue;
3477 }
3478 auto micro_mirror = node->cast<CNodePtr>();
3479 auto param_anf_node = GetInputNodeWithFilter(micro_mirror, [&](const CNodePtr &cnode) {
3480 bool filter = IsPrimitiveCNode(cnode, prim::kPrimMirrorMicroStep) || IsPrimitiveCNode(cnode, prim::kPrimLoad) ||
3481 IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
3482 IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather);
3483 return std::make_pair(filter, 1);
3484 });
3485 if (!param_anf_node->isa<Parameter>()) {
3486 continue;
3487 }
3488 auto param = param_anf_node->cast<ParameterPtr>();
3489 param_mirror_map[param].push_back(micro_mirror);
3490 }
3491 for (const auto &parm_pair : param_mirror_map) {
3492 if (parm_pair.second.size() <= 1) {
3493 continue;
3494 }
3495 MS_LOG(INFO) << "Parameter " << parm_pair.first->name() << " still has multi mirror user, merge those mirror.";
3496 auto mirror0 = parm_pair.second.front();
3497 for (size_t i = 1; i < parm_pair.second.size(); ++i) {
3498 (void)manager->Replace(parm_pair.second[i], mirror0);
3499 }
3500 }
3501 }
3502
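// Broadcast a tuple output between pipeline stages: each element is extracted
// with TupleGetItem, chained with Depend to keep ordering and prevent CSE,
// reduced with a non-fused AllReduce, then repacked into a MakeTuple.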
3503 static void BroadcastMultiOutputs(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, const Group &group) {
3504 auto output = root->get_return()->input(1)->cast<CNodePtr>();
3505 auto output_abstract = output->abstract();
3506 MS_EXCEPTION_IF_NULL(output_abstract);
3507 auto abstract_tuple = output_abstract->cast<abstract::AbstractTuplePtr>();
3508 MS_EXCEPTION_IF_NULL(abstract_tuple);
3509 auto abstract_list = abstract_tuple->elements();
3510
3511 AnfNodePtrList make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
3512 for (size_t i = 0; i < abstract_list.size(); i++) {
3513 auto abstract = abstract_list[i];
3514 MS_EXCEPTION_IF_NULL(abstract);
3515
3516 // TupleGetItem
3517 auto idx = NewValueNode(SizeToLong(i));
3518 CNodePtr tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, idx});
3519 MS_EXCEPTION_IF_NULL(tuple_getitem);
3520 tuple_getitem->set_abstract(abstract);
3521
3522 // Depend: prevent disorder and CSE
3523 if (i > 0) {
3524 tuple_getitem = root->NewCNode({NewValueNode(prim::kPrimDepend), tuple_getitem, make_tuple_input[i]});
3525 MS_EXCEPTION_IF_NULL(tuple_getitem);
3526 tuple_getitem->set_abstract(abstract);
3527 }
3528
3529 // Allreduce
3530 CNodePtr allreduce = root->NewCNode({NewValueNode(prim::kPrimAllReduce), tuple_getitem});
3531 MS_EXCEPTION_IF_NULL(allreduce);
3532 allreduce->set_abstract(abstract);
3533 common::AnfAlgo::SetNodeAttr(OP, MakeValue(REDUCE_OP_SUM), allreduce);
3534 common::AnfAlgo::SetNodeAttr(GROUP, MakeValue(group.name()), allreduce);
3535 // Disable GE allreduce fusion.
3536 common::AnfAlgo::SetNodeAttr(FUSION, MakeValue(static_cast<int64_t>(0)), allreduce);
3537
3538 make_tuple_input.push_back(allreduce);
3539 }
3540
3541 CNodePtr make_tuple_node = root->NewCNode(make_tuple_input);
3542 MS_EXCEPTION_IF_NULL(make_tuple_node);
3543 make_tuple_node->set_abstract(abstract_tuple);
3544 (void)manager->Replace(output, make_tuple_node);
3545 }
3546
3547 static void BroadcastLastResult(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
3548 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
3549 auto pipeline_result_broadcast = parallel::ParallelContext::GetInstance()->pipeline_result_broadcast();
3550 if (IsTraining(manager) || stage_num <= 1 || !pipeline_result_broadcast) {
3551 return;
3552 }
3553
3554 std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage();
3555 Group group;
3556 if (g_device_manager->CreateGroup(rank_list, &group) != SUCCESS) {
3557 MS_LOG(EXCEPTION) << "Create communication group between all pipeline stages failed, the rank_list is: "
3558 << rank_list;
3559 }
3560
3561 auto return_node = root->get_return();
3562 const auto &abstract = return_node->abstract();
3563 if (abstract->isa<abstract::AbstractTuple>()) {
3564 return BroadcastMultiOutputs(root, manager, group);
3565 }
3566
3567 InsertAllReduceToNodeInput(return_node, group.name(), PARALLEL_RESULT_BROADCAST);
3568 return_node->input(1)->set_abstract(abstract);
3569 }
3570
3571 static void RecordFlopsOriginShape(const FuncGraphManagerPtr &mng) {
3572 for (const auto &each_graph : mng->func_graphs()) {
3573 std::list<CNodePtr> graph_orders = each_graph->GetOrderedCnodes();
3574 std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
3575 for (const auto &node : origin_nodes_topological) {
3576 if (IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimBatchMatMul) ||
3577 IsPrimitiveCNode(node, prim::kPrimMatMul)) {
3578 node->AddPrimalAttr(kAttrOriginOutputShape, MakeValue(node->abstract()->GetShapeTrack()->GetShapeVector()));
3579 node->AddPrimalAttr(
3580 kAttrOriginInputShapes,
3581 MakeValue<std::vector<ShapeVector>>({node->input(kIndex1)->abstract()->GetShapeTrack()->GetShapeVector(),
3582 node->input(kIndex2)->abstract()->GetShapeTrack()->GetShapeVector()}));
3583 } else if (IsPrimitiveCNode(node, prim::kPrimFlashAttentionScore)) {
3584 node->AddPrimalAttr(
3585 kAttrOriginInputShapes,
3586 MakeValue<std::vector<ShapeVector>>({node->input(kIndex1)->abstract()->GetShapeTrack()->GetShapeVector(),
3587 node->input(kIndex2)->abstract()->GetShapeTrack()->GetShapeVector()}));
3588 }
3589 }
3590 }
3591 }
3592
3593 bool IsVirtualDatasetDynamicShape(const FuncGraphPtr &func_graph) {
3594 MS_EXCEPTION_IF_NULL(func_graph);
3595 auto all_nodes = TopoSort(func_graph->get_return());
3596 for (const auto &node : all_nodes) {
3597 if (!node->isa<CNode>()) {
3598 continue;
3599 }
3600 auto cnode = node->cast<CNodePtr>();
3601 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
3602 if (prim == nullptr) {
3603 continue;
3604 }
3606 if (prim->name() == VIRTUAL_DATA_SET) {
3607 MS_LOG(INFO) << "VIRTUAL_DATA_SET: " << cnode->DebugString();
3608 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
3609 auto input_node = cnode->input(i);
3610 auto base_shape = input_node->Shape();
3611 MS_EXCEPTION_IF_NULL(base_shape);
3612 std::vector<int64_t> shape_vec = base_shape->GetShapeVector();
3613 MS_LOG(INFO) << "VIRTUAL_DATA_SET: " << node->fullname_with_scope() << ", shape:" << shape_vec;
3614 if (std::find(shape_vec.begin(), shape_vec.end(), -1) != shape_vec.end()) {
3615 return true;
3616 }
3617 }
3618 }
3619 }
3620 return false;
3621 }
3622
3623 static void HandleSilentCheck(const FuncGraphPtr &root, const FuncGraphManagerPtr &mng) {
3624 auto env = common::GetEnv(NPU_ASD_ENABLE);
3625 if (env != kSilentCheckEnvEnable) {
3626 return;
3627 }
3628 auto sdc = std::make_shared<SilentCheck>(root, mng);
3629 if (sdc == nullptr) {
3630 MS_LOG(EXCEPTION) << "Create the SilentCheck instance failed, got nullptr";
3631 }
3632 sdc->GetLossScale();
3633 sdc->ModifySilentCheckOps();
3634 }
3635
3636 static void ParallelPartProcess(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root,
3637 const FuncGraphManagerPtr &manager) {
3638 ReshapeInit(all_nodes);
3639
3640 SetCastForParamNotRecompute(all_nodes);
3641
3642 HandleRootReshapeAndSaveStrategy(all_nodes);
3643
3644 HandleForwardMakeTupleAndMakeList(all_nodes);
3645
3646 // if the input or parameter has multiple users, check whether its split strategies are consistent.
3647 CheckParameterSplit(all_nodes);
3648
3649 HandleSymbolicKeyInstance(root, all_nodes);
3650
3651 // cover Parallel shape
3652 CoverSliceShape(root);
3653
3654 // handle input is not used
3655 HandleNoUsedParameter(root);
3656
3657 // set the shape for optimizer's clone tensor
3658 SetClonedTensorShapeForOptimizer(root);
3659
3660 HandleCameAndAdaFactorOpt(root, all_nodes, manager);
3661
3662 InsertUniformRealForTaggedNodes(manager, all_nodes);
3663
3664 auto adasum_param_tensor_layout_map = AdaSumParamTensorLayout(root);
3665 bool is_apply_adasum = HandleAdaSum(root, all_nodes, &adasum_param_tensor_layout_map);
3666
3667 if (MergeEntireShapeForDynamic(root) != Status::SUCCESS) {
3668 MS_LOG(EXCEPTION) << "Merge entire shape for dynamic shape failed.";
3669 }
3670
3671 auto parallel_context = parallel::ParallelContext::GetInstance();
3672 MS_EXCEPTION_IF_NULL(parallel_context);
3673 auto is_pp_interleave = parallel_context->pipeline_interleave();
3674 std::shared_ptr<PipelinePostProcess> pipeline_processor;
3675 auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
3676 if (pipeline_stages > 1 && is_pp_interleave) {
3677 pipeline_processor =
3678 std::make_shared<PipelinePostProcess>(manager, g_device_manager->stage_id(), pipeline_stages, root);
3679 pipeline_processor->Init(all_nodes);
3680 pipeline_processor->ModifySendRecvAttr(all_nodes);
3681 }
3682 // ForwardCommunication BackwardCommunication TensorRedistribution
3683 ParallelCommunication(root, all_nodes, manager);
3684 SplitNotParallelCareOpsInterleaved(root);
3685 EraseVirtualConverter(root);
3686 if (is_apply_adasum) {
3687 HandleMirrorInAdaSum(root, &adasum_param_tensor_layout_map);
3688 }
3689
3690 if (pipeline_stages > 1 && is_pp_interleave) {
3691 MS_EXCEPTION_IF_NULL(pipeline_processor);
3692 pipeline_processor->GraphPartition(all_nodes);
3693 pipeline_processor->ElimGraphStage();
3694 pipeline_processor->ModifyParameterList();
3695 }
3696
3697 // save strategy as checkpoint for multi-train
3698 auto all_nodes_after_pp = TopoSort(root->get_return(), SuccDeeperSimple);
3699 if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
3700 CheckpointStrategy(all_nodes_after_pp, root);
3701 }
3702 auto comm_group = FindCommonMirrorGroup(root);
3703 StrategyCheckpoint::GetInstance().set_common_mirror_group(comm_group);
3704 MoveMicroMirrorOutCallFunc(root);
3705 HandleGlobalNormScale(root, manager);
3706 if (pipeline_stages > 1 && is_pp_interleave) {
3707 pipeline_processor->HandleSendParam();
3708 MarkForwardCNode(root);
3709 }
3710 MergeMicroMirrorForSharedParameter(root);
3711 // Insert TensorToTuple for FlashAttentionScore if input actual_seq_len is tensor
3712 PostProcessActualSeqLenInputForFlashAttentionScore(root, all_nodes);
3713 return;
3714 }
3715
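// Entry of the step-parallel pass: prepare the resource and manager, split
// micro batches, mark forward cnodes, extract strategies and operator info,
// then run ParallelPartProcess followed by the post-processing steps above.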
3716 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
    return false;
  }
#endif
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  HandleDataParallel();
  FuncGraphManagerPtr manager;
  pipeline::ResourceBasePtr res;
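  // Reuse the optimizer's resource and manager when available; otherwise create a standalone resource.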
  if (optimizer == nullptr) {
    manager = root->manager();
    res = std::make_shared<pipeline::Resource>();
    res->set_manager(manager);
  } else {
    res = optimizer->resource();
    MS_EXCEPTION_IF_NULL(res);
    manager = res->manager();
  }

  MS_EXCEPTION_IF_NULL(manager);
  auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (IsTraining(manager)) {
    root->set_flag(kTraining, true);
  }
  // Assume no change to the graph.
  bool changes = false;
  // Decide whether the model-parallel transformation applies to this graph.
  if (!IsAutoParallelCareGraph(root) || (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY)) || HasNestedMetaFg(root)) {
    if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
      MS_LOG(INFO) << "Strategies would be ignored in " << parallel_mode
                   << ", shard() is only valid in [semi_]auto_parallel.";
      root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
    }
    ReorderForPipelineSplit(root, manager, pipeline_stages);
    ReorderForGradAccumulation(root, manager);
    return changes;
  }

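  // The graph is handled in [semi_]auto_parallel mode: transform it, timing the
  // whole pass for the log emitted at the end.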
  MSLogTime msTime;
  msTime.Start();
  DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));
  RecordFlopsOriginShape(manager);
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
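  // Collect every node reachable from the return node; the deep search result is
  // reversed to obtain a roughly forward-ordered node list.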
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  std::reverse(all_nodes.begin(), all_nodes.end());
  bool merged = MergeConcatSlice(all_nodes, manager);
  if (merged) {
    all_nodes = TopoSort(ret, SuccDeeperSimple);
  }
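  // For non-pipeline, non-auto-parallel graphs, the communication context must be initialized here.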
  if (pipeline_stages <= 1 && parallel_mode != kAutoParallel && ParallelInit() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed.";
  }

  // Insert TupleToTensor for FlashAttentionScore if the actual_seq_len input is a tuple.
  PreProcessActualSeqLenInputForFlashAttentionScore(root, all_nodes);

  MicroBatchPreProcess(root, manager, all_nodes);
  // Mark the forward cnodes; parallel only cares about these nodes.
  MarkForwardCNode(root);
  HandleSilentCheck(root, manager);
  // Tag the dynamic-shape func graphs.
  TagDynamicShapeFuncGraph(root);
  UpdateMicroBatchInterleavedStatus(all_nodes);
  if (parallel_mode != kAutoParallel) {
    TOTAL_OPS = 0;
    ExceptionIfHasCommunicationOp(all_nodes);

    if (IsInsertVirtualOutput(root)) {
      InsertVirtualOutput(root, all_nodes);
      AnfNodePtr ret_after = root->get_return();
      MS_EXCEPTION_IF_NULL(ret_after);
      all_nodes = TopoSort(ret_after, SuccDeeperSimple);
    }

    // Extract the shapes and strategies, and set the operator_info.
    ExtractInformation(all_nodes);
  }

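  // The core transformation, shared by the semi-auto and auto parallel modes.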
  ParallelPartProcess(all_nodes, root, manager);
  BroadcastLastResult(root, manager);
  MicroBatchPostProcess(root, all_nodes);
  UpdateParamSymbolicShape(root);
  DumpGraph(root, std::string(STEP_PARALLEL_END));

  // Step parallel only runs once.
  root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
  // Keep all func graphs for parallel before saving the result.
  SetReserved(root);
  res->SetResult(pipeline::kStepParallelGraph, root);

  // In auto parallel mode, there is no need to check whether strategies are set.
  root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);

  msTime.End();
  uint64_t time = msTime.GetRunTimeUS();
  MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us";
  return changes;
}
}  // namespace parallel
}  // namespace mindspore