/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parameter_manager.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {
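// Collect the users of a parameter that is referenced through a RefKey: gather the care-nodes that use the RefKey
// value node itself and the care-nodes that use the corresponding Parameter node.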
static ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
  // Dealing with the RefKey case
  ParameterUsersInfo parameter_user_info;
  auto refkeys = ref_key_pair.second;
  auto cnode = ref_key_pair.first;

  auto cnode_ptr = cnode->cast<CNodePtr>();
  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
    return parameter_user_info;
  }

  if (refkeys.size() > 1) {
    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKeys";
  }
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  auto cnode_func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());

  // Find the RefKey being used
  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  for (auto &candidate : candidate_set_by_refkey) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }

  // Find the corresponding Parameter being used
  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  if (parameters.size() != 1) {
    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  }
  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = parameters[0];
  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  for (auto &candidate : candidate_set_by_para) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }
  return parameter_user_info;
}

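// Collect the users of a Parameter node, looking through Load nodes and skipping Receive nodes and nodes that
// carry no OperatorInfo.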
static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node) {
  // In this case, node is a Parameter
  ParameterUsersInfo parameter_user_info;
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  auto candidate_set = node->func_graph()->manager()->node_users()[node];
  for (auto &candidate : candidate_set) {
    auto candidate_node = candidate.first;
    if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
      if (candidate.second != 1) {
        continue;
      }
      auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node];
      for (auto &node_user : load_node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
          continue;
        }
        parameter_user_info.second.second.insert(node_user);
      }
    } else {
      auto c = candidate_node->cast<CNodePtr>();
      if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
        continue;
      }
      parameter_user_info.second.second.insert(candidate);
    }
  }
  parameter_user_info.first = node->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = node;
  return parameter_user_info;
}

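// Return the cnode together with its RefKey inputs; return {nullptr, {}} if the node is not a cnode or has no
// RefKey input.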
static RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> refkeys;
  if (cnode->isa<CNode>()) {
    auto cnode_ptr = cnode->cast<CNodePtr>();
    auto inputs = cnode_ptr->inputs();
    for (auto &one_input : inputs) {
      if (IsValueNode<RefKey>(one_input)) {
        refkeys.push_back(one_input);
      }
    }
    if (refkeys.size() >= 1) {
      return std::make_pair(cnode, refkeys);
    }
  }
  return {nullptr, refkeys};
}

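// Find the users of a parameter, dispatching to the RefKey handling or the Parameter handling according to the
// node type.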
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
  ParameterUsersInfo parameter_users_info;

  auto cnode_with_refkeys = CNodeWithRefKeys(node);
  if (cnode_with_refkeys.first != nullptr) {
    // the node is a ref key node
    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
  } else if (node->isa<Parameter>()) {
    // the node is a parameter node
    return FindParameterNodeUsers(node);
  }

  return parameter_users_info;
}

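// Recursively check whether the parameter is actually used, following calls into sub-graphs (including graphs
// wrapped by the J primitive).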
static bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter, size_t max_depth) {
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
  auto manager = graph->manager();
  auto node_users = manager->node_users()[parameter];
  if (node_users.empty()) {
    return false;
  }
  for (auto node_user : node_users) {
    auto use_node = node_user.first->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(use_node->input(0))) {
      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    if (use_node->input(0)->isa<CNode>()) {
      auto cnode = use_node->input(0)->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
        return true;
      }
      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    return true;
  }
  return true;
}

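// Compute the sorted rank list of the devices that hold the same slice of the tensor, based on its device
// arrangement and tensor map.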
static RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
  Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
  Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();

  DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
  }

  std::sort(group_devices.begin(), group_devices.end());
  return group_devices;
}

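// For one user (cnode, input index) of a parameter, return the slice shape and the group rank list that the
// user's operator info requires for that input.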
static ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  TensorInfo tensor_info;
  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
    auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
    tensor_info = op_info->inputs_tensor_info()[param_index];
  } else {
    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
                        << ", but the index is " << (user_input_index - 1);
    }
    tensor_info = op_info->inputs_tensor_info()[LongToSize(user_input_index - 1)];
  }

  ParameterSliceInfo parameter_slice_info;
  parameter_slice_info.slice_shape = tensor_info.slice_shape();
  parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
  MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << (user_input_index - 1)
                << ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
                << tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
  return parameter_slice_info;
}

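// A parameter that has multiple users must be split in the same way by all of them: the slice shapes must match,
// and when there is only one pipeline stage the group rank lists must match as well.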
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
    auto &users_set = parameter_users_info.second.second;
    if (users_set.size() <= 1) {
      continue;
    }

    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    auto &first_user = users_set.front();
    ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
    Shape first_user_slice_shape = parameter_slice_info.slice_shape;
    RankList first_user_group_list = parameter_slice_info.group_ranks;

    for (auto iter = users_set.begin() + 1; iter != users_set.end(); ++iter) {
      auto &user = *iter;
      ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
      Shape user_slice_shape = user_slice_info.slice_shape;
      RankList user_group_list = user_slice_info.group_ranks;
      if (first_user_slice_shape != user_slice_shape) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the slice shapes are different";
      }

      if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the group rank lists are different, "
                          << "the group rank list for the first user is " << first_user_group_list
                          << ", and the group rank list for this user is " << user_group_list;
      }
    }
  }
}

namespace {
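// Replace a SymbolicKeyInstance value node that refers to a root parameter with an embed(param) cnode.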
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
  MS_EXCEPTION_IF_NULL(symbolic_key);
  auto all_upstream_node = root->manager()->node_users()[node];
  for (auto &upstream_node : all_upstream_node) {
    FuncGraphPtr fg = upstream_node.first->func_graph();
    if (symbolic_key->node()->isa<Parameter>()) {
      for (auto &param : root->parameters()) {
        if (*param == *symbolic_key->node()) {
          AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
          MS_EXCEPTION_IF_NULL(reverted_node);
          MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
          (void)fg->manager()->Replace(node, reverted_node);
          MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
        }
      }
    }
  }
}
}  // namespace

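// Revert all SymbolicKeyInstance value nodes in the graph back to embed() primitives.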
void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    // revert back SymbolicKeyInstance to embed() primitive
    if (IsValueNode<SymbolicKeyInstance>(node)) {
      RevertSymbolicKeyInstance(root, node);
      continue;
    }
  }
}

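// A cloned parameter is a parameter with a default value whose param_info is marked as cloned.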
bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

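// For parameters that are not used by any node, shrink the first dimension of their shape by the stage device
// number; skipped in full-batch and grad-accumulation modes.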
void HandleNoUsedParameter(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  bool full_batch = ParallelContext::GetInstance()->full_batch();
  if (full_batch) {
    return;
  }

  // in grad accumulation mode, if dynamic lr is used, the optimizer has some parameters that are not used by the
  // first graph but are used by the second graph (such as global_step), so their shapes cannot be changed
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if (grad_accumulation_step > 1) {
    MS_LOG(INFO) << "In grad accumulation mode, do not handle no used parameters";
    return;
  }

  auto dev_num = g_device_manager->stage_device_num();
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    if (IsUsedParameter(root, parameter, 0)) {
      continue;
    }
    auto parameter_shape = GetNodeShape(parameter);
    if (parameter_shape.empty()) {
      continue;
    }
    Shape slice_shape = parameter_shape[0];
    if (slice_shape.empty()) {
      continue;
    }
    slice_shape[0] = slice_shape[0] / dev_num;
    auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract = parameter->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    auto abstract_cloned = abstract->Clone();
    MS_EXCEPTION_IF_NULL(abstract_cloned);
    abstract_cloned->set_shape(slice_shape_ptr);
    parameter->set_abstract(abstract_cloned);
  }
}

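// A parameter is fully split if its tensor layout maps it to a device group of size one, i.e. no device holds a
// duplicated slice.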
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}

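// Insert a _VirtualAdd node between the user cnode and the fully split parameter input so that the gradient is
// accumulated into accu_parameter.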
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
                                          const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
  auto cnode = node_user.first->cast<CNodePtr>();
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(WARNING) << cnode->DebugString() << " can not insert fully split param grad accumulation node";
    return;
  }
  OperatorAttrs attrs;
  auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->SetEdge(cnode, node_user.second, virtual_node);
}

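// In grad accumulation mode, insert the virtual accumulation node for every fully split parameter that has a
// matching gradient accumulation parameter.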
void HandleFullySplitParameters(const FuncGraphPtr &root) {
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
    return;
  }

  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);

    if (!IsFullySplitParameter(param_ptr)) {
      continue;
    }

    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
    if (!accu_parameter) {
      continue;  // some parameters do not need to be handled, such as the accumulation parameter itself or lr
    }

    auto node_users = node_users_map[parameter];
    for (auto &user : node_users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->in_forward_flag()) {
        continue;
      }
      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
      MS_LOG(INFO) << "Insert full split assign add node for " << param_ptr->name();
      break;  // only need to insert once, even if the parameter has many users
    }
  }
}

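// Copy the tensor layout and the (sliced) shape from the original parameter to each of its cloned parameters
// (optimizer states, accu_grads, and so on).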
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(cloned_parameter_node)) {
      continue;
    }
    auto param_value = cloned_parameter->param_info();
    if (param_value == nullptr) {
      continue;
    }
    // get the cloned index
    int64_t cloned_index = param_value->cloned_index();

    // find the be cloned parameter
    bool found_be_cloned_parameter = false;
    ParameterPtr cloned_from_parameter = nullptr;
    AnfNodePtr cloned_from_node = nullptr;
    for (auto &be_cloned_parameter_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
      auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(be_cloned_parameter);
      if (!be_cloned_parameter->has_default()) {
        continue;
      }

      auto param_value_in = be_cloned_parameter->param_info();
      if (param_value_in == nullptr) {
        continue;
      }
      if (!param_value_in->be_cloned()) {
        continue;
      }

      // get the be cloned index
      auto &be_cloned_index = param_value_in->be_cloned_index();
      if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
        found_be_cloned_parameter = true;
        cloned_from_parameter = be_cloned_parameter;
        cloned_from_node = be_cloned_parameter_node;
      }
    }

    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skip it";
        continue;
      }
      auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      // from pipeline or grad accumulation
      if (param_name.find(ACCU_GRADS) != std::string::npos) {
        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
        std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
        MS_EXCEPTION_IF_NULL(parallel_shape);
        cloned_abstract->set_shape(parallel_shape);
        // in opt shard, accu_grad's shape is different from the original param's shape
        if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
          TensorLayout new_layout = *tensor_layout;
          new_layout.set_opt_shard_group("");
          tensor_layout = std::make_shared<TensorLayout>(new_layout);
        }
      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
      }
      cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
      cloned_parameter_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                   << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
                   << ", clone index is: " << cloned_index;
    } else {
      MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is "
                        << cloned_index << ", but the be cloned parameter is not found";
    }
  }
}

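// For AdaFactor optimizer states (exp_avg_sq_row, exp_avg_sq_col and exp_avg_sq), derive their tensor layouts and
// slice shapes from the layout of the corresponding parameter.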
void HandleAdaFactorOpt(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    std::string param_name = param->name();
    if (param_name.find(EXP_AVG) != std::string::npos) {
      continue;
    }

    auto tensor_layout = param->user_data<TensorLayout>();
    if (tensor_layout == nullptr) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;

      if ((row_col_param_name != exp_row_name) && (row_col_param_name != exp_col_name) &&
          (row_col_param_name != exp_avg_name)) {
        continue;
      }

      auto slice_shape = tensor_layout->slice_shape().array();
      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = (row_col_param_name == exp_row_name) || (row_col_param_name == exp_col_name);
      if (is_row_or_col_param && shape_size <= 1) {
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        continue;
      }

      auto origin_shape = tensor_layout->tensor_shape().array();
      auto dev_mat = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name) {
        slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name) {
        (void)slice_shape.erase(slice_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
      } else {
        exp_avg_sq_count++;
      }

      TensorLayout new_tensor_layout;
      if (new_tensor_layout.InitFromVector(dev_mat, tensor_map, origin_shape) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "Set the slice shape for " << row_col_param_name << ", origin shape is " << origin_shape
                   << ", new slice shape is " << slice_shape;

      if (row_col_count == 2 || exp_avg_sq_count == 1) {
        break;
      }
    }
  }
}
}  // namespace parallel
}  // namespace mindspore