/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parameter_manager.h"

#include <inttypes.h>
#include <sys/time.h>
#include <algorithm>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>

#include "base/core_ops.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {
static ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
  // Dealing with the RefKey case
  ParameterUsersInfo parameter_user_info;
  auto refkeys = ref_key_pair.second;
  auto cnode = ref_key_pair.first;

  auto cnode_ptr = cnode->cast<CNodePtr>();
  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
    return parameter_user_info;
  }

  if (refkeys.size() > 1) {
    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than one RefKey";
  }
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  auto cnode_func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());

  // Find the users of the RefKey
  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  for (auto &candidate : candidate_set_by_refkey) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }

  // Find the users of the corresponding Parameter
  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  if (parameters.size() != 1) {
    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  }
  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = parameters[0];
  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  for (auto &candidate : candidate_set_by_para) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }
  return parameter_user_info;
}

static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node) {
  // In this case, node is a Parameter
  ParameterUsersInfo parameter_user_info;
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  auto candidate_set = node->func_graph()->manager()->node_users()[node];
  for (auto &candidate : candidate_set) {
    auto candidate_node = candidate.first;
    if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
      if (candidate.second != 1) {
        continue;
      }
      auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node];
      for (auto &node_user : load_node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
          continue;
        }
        parameter_user_info.second.second.insert(node_user);
      }
    } else {
      auto c = candidate_node->cast<CNodePtr>();
      if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
        continue;
      }
      parameter_user_info.second.second.insert(candidate);
    }
  }
  parameter_user_info.first = node->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = node;
  return parameter_user_info;
}

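// Load note (illustrative, not in the original file): under auto-monad a
// parameter is often consumed through Load(param, UMonad) rather than
// directly. The branch above follows such Load nodes one hop and records the
// Load's users instead, but only when the parameter is the Load's first
// input (candidate.second == 1). Users without OperatorInfo and RECEIVE
// nodes are skipped; pipeline communication nodes are handled elsewhere.
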
static RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> refkeys;
  if (cnode->isa<CNode>()) {
    auto cnode_ptr = cnode->cast<CNodePtr>();
    auto inputs = cnode_ptr->inputs();
    for (auto &one_input : inputs) {
      if (IsValueNode<RefKey>(one_input)) {
        refkeys.push_back(one_input);
      }
    }
    if (refkeys.size() >= 1) {
      return std::make_pair(cnode, refkeys);
    }
  }
  return {nullptr, refkeys};
}

ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) {
  ParameterUsersInfo parameter_users_info;

  auto cnode_with_refkeys = CNodeWithRefKeys(node);
  if (cnode_with_refkeys.first != nullptr) {
    // the node is a ref key node
    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
  } else if (node->isa<Parameter>()) {
    // the node is a parameter node
    return FindParameterNodeUsers(node);
  }

  return parameter_users_info;
}

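// Result-shape sketch (illustrative, names hypothetical): ParameterUsersInfo
// is a pair of (parameter name, (parameter node, user set)). For a parameter
// "fc1.weight" feeding two MatMul nodes as their second operand, the result
// would look like:
//
//   parameter_users_info.first         == "fc1.weight"
//   parameter_users_info.second.first  == the Parameter AnfNode
//   parameter_users_info.second.second == {(MatMul1, 2), (MatMul2, 2)}
//
// where the integer in each pair is the input index at the user CNode. Only
// users accepted by IsCareNode (e.g. IsParallelCareNode) are collected.
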
static bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter, size_t max_depth) {
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
  auto manager = graph->manager();
  auto node_users = manager->node_users()[parameter];
  if (node_users.empty()) {
    return false;
  }
  for (auto node_user : node_users) {
    auto use_node = node_user.first->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(use_node->input(0))) {
      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    if (use_node->input(0)->isa<CNode>()) {
      auto cnode = use_node->input(0)->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
        return true;
      }
      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    return true;
  }
  return true;
}

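// Traversal note (illustrative): when a user is a call to a sub-FuncGraph,
// argument k of the call site (CNode input k, with input 0 being the callee)
// binds to parameters()[k - 1] of the callee, hence the
// IntToSize(node_user.second - 1) indexing; the same mapping is applied when
// the callee is wrapped in J (the autodiff primitive). Any other concrete
// user means the parameter is really used, so the check returns true.
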
static RankList GetGroupByTensorInfo(const TensorInfo &tensor_info) {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  RankList stage_device_list = g_device_manager->GetDeviceListInThisStage();
  Shape dev_matrix_shape = tensor_info.tensor_layout().device_arrangement().array();
  Shape tensor_map = tensor_info.tensor_layout().tensor_map().array();

  DeviceMatrix dev_matrix(rank, stage_device_list, dev_matrix_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get devices by tensor map failed";
  }

  std::sort(group_devices.begin(), group_devices.end());
  return group_devices;
}

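// Worked example (illustrative): with 8 devices, device arrangement [2, 4]
// and tensor map [1, -1], the first tensor dimension is split across the
// dev-matrix axis of size 2 (axes are numbered from the right, starting at
// 0) and the second dimension is replicated. For rank 0,
// GetDevicesByTensorMap returns the ranks holding the same slice, i.e. the
// group along the unmapped axis of size 4: {0, 1, 2, 3}. A fully mapped
// layout such as tensor map [1, 0] would instead yield the singleton {0}.
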
static ParameterSliceInfo GetParameterSliceInfo(const std::pair<AnfNodePtr, int64_t> &param_info) {
  auto user_cnode = param_info.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(user_cnode);
  auto user_input_index = param_info.second;
  OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);

  TensorInfo tensor_info;
  if (IsPrimitiveCNode(user_cnode, prim::kPrimSend)) {
    auto param_index = IntToSize(GetValue<int>(user_cnode->GetPrimalAttr(PARAM_INDEX)));
    tensor_info = op_info->inputs_tensor_info()[param_index];
  } else {
    size_t input_tensor_info_size = op_info->inputs_tensor_info().size();
    if (SizeToLong(input_tensor_info_size) <= user_input_index - 1) {
      MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size
                        << ", but the index is " << (user_input_index - 1);
    }
    tensor_info = op_info->inputs_tensor_info()[LongToSize(user_input_index - 1)];
  }

  ParameterSliceInfo parameter_slice_info;
  parameter_slice_info.slice_shape = tensor_info.slice_shape();
  parameter_slice_info.group_ranks = GetGroupByTensorInfo(tensor_info);
  MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << (user_input_index - 1)
                << ", the slice shape is " << tensor_info.slice_shape() << ", the origin shape is "
                << tensor_info.shape() << ", the group rank list is " << parameter_slice_info.group_ranks;
  return parameter_slice_info;
}

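// Index note (illustrative): CNode inputs count the primitive as input 0, so
// a parameter arriving as the k-th operand corresponds to
// inputs_tensor_info()[k - 1]; hence the (user_input_index - 1) above. Send
// is special-cased because for pipeline sends the parameter's position is
// recorded in the PARAM_INDEX primal attribute rather than in the edge index.
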
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
    auto &users_set = parameter_users_info.second.second;
    if (users_set.size() <= 1) {
      continue;
    }

    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    auto &first_user = users_set.front();
    ParameterSliceInfo parameter_slice_info = GetParameterSliceInfo(first_user);
    Shape first_user_slice_shape = parameter_slice_info.slice_shape;
    RankList first_user_group_list = parameter_slice_info.group_ranks;

    for (auto iter = users_set.begin() + 1; iter != users_set.end(); ++iter) {
      auto &user = *iter;
      ParameterSliceInfo user_slice_info = GetParameterSliceInfo(user);
      Shape user_slice_shape = user_slice_info.slice_shape;
      RankList user_group_list = user_slice_info.group_ranks;
      if (first_user_slice_shape != user_slice_shape) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the slice shapes are different";
      }

      if (ParallelContext::GetInstance()->pipeline_stage_split_num() == 1 && first_user_group_list != user_group_list) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the group rank lists are different, "
                          << "the group rank list for the first user is " << first_user_group_list
                          << ", and the group rank list for this user is " << user_group_list;
      }
    }
  }
}

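// Example of what this check catches (illustrative strategies): if
// "fc1.weight" feeds one MatMul whose sharding gives it a slice shape of
// [4, 16] and another MatMul whose sharding gives it [8, 8], the two users
// disagree on how the parameter is split and the first EXCEPTION fires. The
// group-rank comparison is only enforced without pipeline parallelism
// (pipeline_stage_split_num == 1), since with pipelining the users of a
// shared parameter may legitimately live on different stages.
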
namespace {
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
  MS_EXCEPTION_IF_NULL(symbolic_key);
  auto all_upstream_node = root->manager()->node_users()[node];
  for (auto &upstream_node : all_upstream_node) {
    FuncGraphPtr fg = upstream_node.first->func_graph();
    if (symbolic_key->node()->isa<Parameter>()) {
      for (auto &param : root->parameters()) {
        if (*param == *symbolic_key->node()) {
          AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
          MS_EXCEPTION_IF_NULL(reverted_node);
          MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
          (void)fg->manager()->Replace(node, reverted_node);
          MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
        }
      }
    }
  }
}
}  // namespace

void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    // revert SymbolicKeyInstance back to the embed() primitive
    if (IsValueNode<SymbolicKeyInstance>(node)) {
      RevertSymbolicKeyInstance(root, node);
      continue;
    }
  }
}

bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // check whether the parameter is a cloned one
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

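// Background (illustrative): optimizer states are created on the Python side
// by Parameter.clone(), e.g. a Momentum optimizer derives "moments.fc1.weight"
// from "fc1.weight", and clone() records cloned()/cloned_index() in the
// ParamInfo. This predicate only reads that bookkeeping; matching a clone to
// its source parameter happens in SetClonedTensorShapeForOptimizer below.
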
void HandleNoUsedParameter(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  bool full_batch = ParallelContext::GetInstance()->full_batch();
  if (full_batch) {
    return;
  }

  // In grad accumulation mode with a dynamic learning rate, the optimizer holds parameters (such as global_step)
  // that are unused by the first graph but used by the second one, so their shapes must not be changed.
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if (grad_accumulation_step > 1) {
    MS_LOG(INFO) << "In grad accumulation mode, do not handle unused parameters";
    return;
  }

  auto dev_num = g_device_manager->stage_device_num();
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    if (IsUsedParameter(root, parameter, 0)) {
      continue;
    }
    auto parameter_shape = GetNodeShape(parameter);
    if (parameter_shape.empty()) {
      continue;
    }
    Shape slice_shape = parameter_shape[0];
    if (slice_shape.empty()) {
      continue;
    }
    slice_shape[0] = slice_shape[0] / dev_num;
    auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract = parameter->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    auto abstract_cloned = abstract->Clone();
    MS_EXCEPTION_IF_NULL(abstract_cloned);
    abstract_cloned->set_shape(slice_shape_ptr);
    parameter->set_abstract(abstract_cloned);
  }
}

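// Shape arithmetic sketch (illustrative numbers): for an unused parameter of
// shape [32, 16] on a stage with dev_num == 8, the loop above rewrites the
// abstract shape to [4, 16]; only dimension 0 is divided by the stage device
// number, so each rank keeps a 1/dev_num slice of an otherwise untouched
// parameter.
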
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}

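// Illustrative case: with device arrangement [8] and tensor map [0], every
// device-matrix axis is consumed by the tensor map, so each rank owns a
// distinct slice and the replication group contains only the rank itself:
// group size 1, i.e. fully split. A replicated dimension (tensor map [-1])
// would enlarge the group and make the predicate return false.
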
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
                                          const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
  auto cnode = node_user.first->cast<CNodePtr>();
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(WARNING) << cnode->DebugString() << ": cannot insert the grad accumulation node for the fully split param";
    return;
  }
  OperatorAttrs attrs;
  auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->SetEdge(cnode, node_user.second, virtual_node);
}

void HandleFullySplitParameters(const FuncGraphPtr &root) {
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
    return;
  }

  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);

    if (!IsFullySplitParameter(param_ptr)) {
      continue;
    }

    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
    if (!accu_parameter) {
      continue;  // some parameters need no handling, e.g. the accumulation parameter itself or the learning rate
    }

    auto node_users = node_users_map[parameter];
    for (auto &user : node_users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->in_forward_flag()) {
        continue;
      }
      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
      MS_LOG(INFO) << "Insert the fully split assign add node for " << param_ptr->name();
      break;  // insert only once, even if the parameter has multiple users
    }
  }
}

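// Rewiring sketch (illustrative): for a fully split parameter `w` with an
// accumulation buffer `accu_grads.w`, one forward user
//
//   user_cnode(..., w, ...)
//
// is rewritten to
//
//   user_cnode(..., _VirtualAdd(w, accu_grads.w), ...)
//
// The _VirtualAdd node is a virtual op (assumed here to act as an identity
// in the forward pass) that ties the parameter's gradient to its
// accumulation buffer, which matters because a fully split parameter's
// gradient exists on no other rank to be aggregated with.
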
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(cloned_parameter_node)) {
      continue;
    }
    auto param_value = cloned_parameter->param_info();
    if (param_value == nullptr) {
      continue;
    }
    // get the cloned index
    int64_t cloned_index = param_value->cloned_index();

    // find the parameter this one was cloned from
    bool found_be_cloned_parameter = false;
    ParameterPtr cloned_from_parameter = nullptr;
    AnfNodePtr cloned_from_node = nullptr;
    for (auto &be_cloned_parameter_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
      auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(be_cloned_parameter);
      if (!be_cloned_parameter->has_default()) {
        continue;
      }

      auto param_value_in = be_cloned_parameter->param_info();
      if (param_value_in == nullptr) {
        continue;
      }
      if (!param_value_in->be_cloned()) {
        continue;
      }

      // get the indices under which this parameter was cloned
      auto &be_cloned_index = param_value_in->be_cloned_index();
      if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
        found_be_cloned_parameter = true;
        cloned_from_parameter = be_cloned_parameter;
        cloned_from_node = be_cloned_parameter_node;
      }
    }

    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for the cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skip it";
        continue;
      }
      auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      // from pipeline or grad accumulation
      if (param_name.find(ACCU_GRADS) != std::string::npos) {
        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
        std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
        MS_EXCEPTION_IF_NULL(parallel_shape);
        cloned_abstract->set_shape(parallel_shape);
        // under optimizer sharding, accu_grad's shape differs from the original parameter's shape
        if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
          TensorLayout new_layout = *tensor_layout;
          new_layout.set_opt_shard_group("");
          tensor_layout = std::make_shared<TensorLayout>(new_layout);
        }
      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
      }
      cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
      cloned_parameter_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                   << " is cloned, the source parameter is: " << cloned_from_parameter->name()
                   << ", clone index is: " << cloned_index;
    } else {
      MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, clone index is "
                        << cloned_index << ", but the source parameter was not found";
    }
  }
}

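// Pairing example (illustrative names): for "fc1.weight" sharded to a slice
// shape of [4, 16], its optimizer clone "moments.fc1.weight" receives the
// same TensorLayout and the same (already sliced) abstract shape. A
// grad-accumulation clone "accu_grads.fc1.weight" is instead shaped from the
// layout's slice_shape() directly, and with the parallel optimizer enabled
// its opt_shard_group is cleared, since the accumulation buffer keeps the
// full (non-optimizer-sharded) slice.
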
void HandleAdaFactorOpt(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    std::string param_name = param->name();
    if (param_name.find(EXP_AVG) != std::string::npos) {
      continue;
    }

    auto tensor_layout = param->user_data<TensorLayout>();
    if (tensor_layout == nullptr) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;

      if ((row_col_param_name != exp_row_name) && (row_col_param_name != exp_col_name) &&
          (row_col_param_name != exp_avg_name)) {
        continue;
      }

      auto slice_shape = tensor_layout->slice_shape().array();
      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = (row_col_param_name == exp_row_name) || (row_col_param_name == exp_col_name);
      if (is_row_or_col_param && shape_size <= 1) {
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        continue;
      }

      auto origin_shape = tensor_layout->tensor_shape().array();
      auto dev_mat = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name) {
        slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name) {
        (void)slice_shape.erase(slice_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
      } else {
        exp_avg_sq_count++;
      }

      TensorLayout new_tensor_layout;
      if (new_tensor_layout.InitFromVector(dev_mat, tensor_map, origin_shape) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
      MS_LOG(INFO) << "Set the slice shape for " << row_col_param_name << ", origin shape is " << origin_shape
                   << ", new slice shape is " << slice_shape;

      if (row_col_count == 2 || exp_avg_sq_count == 1) {
        break;
      }
    }
  }
}
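
// Factored-statistics example (illustrative shapes): for a 2-D parameter of
// origin shape [1024, 512], AdaFactor keeps a row statistic exp_avg_sq_row of
// shape [1024] (last dimension dropped) and a column statistic exp_avg_sq_col
// of shape [512] (second-to-last dimension dropped); 1-D parameters keep a
// full exp_avg_sq instead. The loop above rebuilds a matching TensorLayout
// for each statistic by erasing the corresponding entry from the slice
// shape, origin shape and tensor map of the source parameter's layout.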
}  // namespace parallel
}  // namespace mindspore