/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parameter_manager.h"

#include <cinttypes>
#include <algorithm>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <deque>
#include <functional>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/hash_map.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "pipeline/jit/ps/pipeline.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/core/ops/nn_ops.h"

namespace mindspore {
namespace parallel {
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
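// A ParameterUsersInfo (declared in the included headers) pairs a parameter's name with the parameter node itself
// and the set of (user-cnode, input-index) pairs that consume it; the finders below fill this structure for the
// RefKey case and the plain Parameter case respectively.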
static ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
  // Dealing with the RefKey case
  ParameterUsersInfo parameter_user_info;
  auto refkeys = ref_key_pair.second;
  auto cnode = ref_key_pair.first;

  auto cnode_ptr = cnode->cast<CNodePtr>();
  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
    return parameter_user_info;
  }

  if (refkeys.size() > 1) {
    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than one RefKey";
  }
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  auto cnode_func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());

  // Find the RefKey being used
  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  for (auto &candidate : candidate_set_by_refkey) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }

  // Find the corresponding Parameter being used
  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  if (parameters.size() != 1) {
    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  }
  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = parameters[0];
  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  for (auto &candidate : candidate_set_by_para) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }
  return parameter_user_info;
}

static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, const std::vector<AnfNodePtr> &all_nodes) {
  // In this case, node is a Parameter
  ParameterUsersInfo parameter_user_info;
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  auto candidate_set = node->func_graph()->manager()->node_users()[node];
  for (auto &candidate : candidate_set) {
    auto candidate_node = candidate.first;
    if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
      if (candidate.second != 1) {
        continue;
      }
      auto &node_user_map = node->func_graph()->manager()->node_users();
      auto load_node_users = node_user_map[candidate_node];
      for (auto &node_user : load_node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        std::pair<AnfNodePtr, int> child_parallel_care_node;
        if (IsSomePrimitive(cnode, UPDATESTATE) || !cnode->in_forward_flag()) {
          continue;
        }
        if (!IsSomePrimitive(cnode, MAKE_TUPLE) && (IsParallelCareNode(cnode) || IsAutoParallelCareNode(cnode))) {
          child_parallel_care_node = node_user;
        } else {
          child_parallel_care_node = BFSParallelCareNode(cnode, node_user_map, node_user.second, all_nodes);
        }
        if (child_parallel_care_node.first) {
          cnode = child_parallel_care_node.first->cast<CNodePtr>();
        } else {
          continue;
        }
        if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
          continue;
        }
        parameter_user_info.second.second.insert(child_parallel_care_node);
      }
    } else {
      auto c = candidate_node->cast<CNodePtr>();
      if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
        continue;
      }
      parameter_user_info.second.second.insert(candidate);
    }
  }
  parameter_user_info.first = node->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = node;
  return parameter_user_info;
}

static RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> refkeys;
  if (cnode->isa<CNode>()) {
    auto cnode_ptr = cnode->cast<CNodePtr>();
    auto inputs = cnode_ptr->inputs();
    for (auto &one_input : inputs) {
      if (IsValueNode<RefKey>(one_input)) {
        refkeys.push_back(one_input);
      }
    }
    if (refkeys.size() >= 1) {
      return std::make_pair(cnode, refkeys);
    }
  }
  return {nullptr, refkeys};
}

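// Dispatch on the node kind: a cnode carrying RefKey inputs is resolved through FindRefKeyNodeUsers, while a
// Parameter with a default value (i.e. a real weight) goes through FindParameterNodeUsers; anything else yields
// an empty ParameterUsersInfo.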
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &),
                                      const std::vector<AnfNodePtr> &all_nodes) {
  ParameterUsersInfo parameter_users_info;

  auto cnode_with_refkeys = CNodeWithRefKeys(node);
  if (cnode_with_refkeys.first != nullptr) {
    // the node is a ref key node
    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
  } else if (node->isa<Parameter>()) {
    auto param_ptr = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    // the node is a parameter node
    if (param_ptr->has_default()) {
      return FindParameterNodeUsers(node, all_nodes);
    }
  }

  return parameter_users_info;
}

static bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter, size_t max_depth) {
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "The recursion depth exceeds the limit of " << MAX_RECURSIVE_DEPTH << ".";
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
  auto manager = graph->manager();
  auto node_users = manager->node_users()[parameter];
  if (node_users.empty()) {
    return false;
  }
  for (auto node_user : node_users) {
    auto use_node = node_user.first->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(use_node->input(0))) {
      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    if (use_node->input(0)->isa<CNode>()) {
      auto cnode = use_node->input(0)->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
        return true;
      }
      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    return true;
  }
  return true;
}

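// A tensor-map value k names a device-matrix axis counted from the rightmost axis, so the array index is
// dev_matrix_size - k - 1. For example (illustrative values): with dev_matrix [8, 4] and tensor_map_value 0,
// dim = 2 - 0 - 1 = 1, i.e. the ranks along the "4" axis. MAP_NONE (-1) means the dimension is not split,
// so only the local rank is returned.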
static RankList GetDevListByTensorMapValue(DeviceMatrix dev_matrix, int64_t tensor_map_value, size_t dev_matrix_size) {
  RankList rank_list;
  if (tensor_map_value >= SizeToLong(dev_matrix_size) || tensor_map_value < MAP_NONE) {
    MS_LOG(ERROR) << "The size of dev_matrix is " << dev_matrix_size << ", but the tensor map value is "
                  << tensor_map_value;
    return rank_list;
  }

  if (tensor_map_value == MAP_NONE) {
    rank_list.push_back(g_device_manager->global_rank());
    return rank_list;
  }

  uint64_t dim = dev_matrix_size - LongToSize(tensor_map_value) - 1;
  if (dev_matrix.GetDevicesAlongDim(dim, &rank_list) != SUCCESS) {
    MS_LOG(ERROR) << "Get devices along dim failed";
  }

  return rank_list;
}

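// Two layouts are considered equal if they shard the same shape onto the same device lists per dimension.
// The fast path accepts identical device arrangements and tensor maps; otherwise each mapped dimension is
// compared by the concrete rank list it resolves to, so differently factored device matrices can still match.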
static bool IsSameTensorLayout(const TensorLayout &a, const TensorLayout &b) {
  if (!a.IsSameTensorShape(b)) {
    return false;
  }
  if (a.IsSameDeviceArrangement(b) && a.IsSameTensorMap(b)) {
    return true;
  }

  Shape a_tensor_map = a.tensor_map().array();
  Shape b_tensor_map = b.tensor_map().array();
  if (a_tensor_map.size() != b_tensor_map.size()) {
    return false;
  }

  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix a_dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), a.device_arrangement().array());
  DeviceMatrix b_dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), b.device_arrangement().array());
  size_t a_dev_mat_size = a.device_arrangement().array().size();
  size_t b_dev_mat_size = b.device_arrangement().array().size();

  for (size_t i = 0; i < a_tensor_map.size(); ++i) {
    if (a_tensor_map[i] == MAP_NONE && b_tensor_map[i] == MAP_NONE) {
      continue;
    }

    RankList a_dev_list_by_dim = GetDevListByTensorMapValue(a_dev_matrix, a_tensor_map[i], a_dev_mat_size);
    RankList b_dev_list_by_dim = GetDevListByTensorMapValue(b_dev_matrix, b_tensor_map[i], b_dev_mat_size);
    if (a_dev_list_by_dim.empty() || b_dev_list_by_dim.empty()) {
      MS_LOG(EXCEPTION) << "Can not get device list by tensor map value, these layouts are " << a.ToString()
                        << std::endl
                        << " and " << b.ToString();
    }

    if (a_dev_list_by_dim != b_dev_list_by_dim) {
      return false;
    }
  }

  return true;
}

bool IsSameTensorInfo(const TensorInfo &a, const TensorInfo &b) {
  return IsSameTensorLayout(a.tensor_layout(), b.tensor_layout());
}

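// A parameter shared by several operators must be sliced the same way by all of them: gather every user of
// each parameter and compare its input TensorInfo against the first user's; any mismatch is a hard error.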
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode, all_nodes);
    auto &users_set = parameter_users_info.second.second;
    if (users_set.size() <= 1) {
      continue;
    }

    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    auto &first_user = users_set.front();
    auto parameter_tensor_info = GetInputsTensorInfo(first_user);

    for (auto iter = users_set.begin() + 1; iter != users_set.end(); ++iter) {
      auto &user = *iter;
      auto user_tensor_info = GetInputsTensorInfo(user);
      if (IsSameTensorInfo(parameter_tensor_info, user_tensor_info)) {
        continue;
      } else {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but their TensorInfos differ: "
                          << parameter_tensor_info.tensor_layout().ToString() << std::endl
                          << " and " << user_tensor_info.tensor_layout().ToString();
      }
    }
  }
}

namespace {
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
  MS_EXCEPTION_IF_NULL(symbolic_key);
  auto all_upstream_node = root->manager()->node_users()[node];
  for (auto &upstream_node : all_upstream_node) {
    FuncGraphPtr fg = upstream_node.first->func_graph();
    if (symbolic_key->node()->isa<Parameter>()) {
      for (auto &param : root->parameters()) {
        if (*param == *symbolic_key->node()) {
          AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
          MS_EXCEPTION_IF_NULL(reverted_node);
          MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
          (void)fg->manager()->Replace(node, reverted_node);
          MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
        }
      }
    }
  }
}
}  // namespace

void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    // revert back SymbolicKeyInstance to embed() primitive
    if (IsValueNode<SymbolicKeyInstance>(node)) {
      RevertSymbolicKeyInstance(root, node);
      continue;
    }
  }
}

bool IsStrategySaved(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // only parameters with a default value carry param_info
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  return param_value->strategy_ckpt_saved();
}

bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

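// Parameters that no operator consumes still get a data-parallel slice of dimension 0 so their shapes stay
// consistent across ranks, e.g. (illustrative values) a [64, 16] weight with a stage device number of 8
// becomes [8, 16].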
void HandleNoUsedParameter(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  bool full_batch = ParallelContext::GetInstance()->full_batch();
  if (full_batch) {
    return;
  }

  auto dev_num = g_device_manager->stage_device_num();
  auto parameters = root->parameters();
  if (parameters.empty()) {
    MS_LOG(INFO) << "The graph has no parameters, thus no need to set parallel shape";
  } else {
    for (auto &parameter : parameters) {
      if (IsUsedParameter(root, parameter, 0)) {
        continue;
      }
      auto parameter_shape = GetNodeShape(parameter);
      if (parameter_shape.empty()) {
        continue;
      }
      Shape slice_shape = parameter_shape[0];
      if (slice_shape.empty() || slice_shape[0] < dev_num) {
        continue;
      }
      slice_shape[0] = slice_shape[0] / dev_num;
      auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
      auto abstract = parameter->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      auto abstract_cloned = abstract->Clone();
      MS_EXCEPTION_IF_NULL(abstract_cloned);
      abstract_cloned->set_shape(slice_shape_ptr);
      parameter->set_abstract(abstract_cloned);
    }
  }
}

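// GetDevicesByTensorMap collects the ranks that hold an identical slice of the parameter; the parameter counts
// as fully split when that repeated group is no larger than allow_repeat_num.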
bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() <= allow_repeat_num) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}

py::object GetPyParameterObj(const ParamInfoPtr &param_info, const std::string &obj) {
  py::object py_obj = py::cast(param_info);
  if (py::isinstance<py::none>(py_obj)) {
    return py::none();
  }
  return python_adapter::GetPyObjAttr(py_obj, obj);
}

static bool IsAccuGradObj(const py::object &py_obj) {
  auto name = python_adapter::GetPyObjAttr(py_obj, PARAM_NAME);
  if (py::isinstance<py::none>(name)) {
    return false;
  }
  if (py::cast<std::string>(name).find(ACCU_GRADS) == 0) {
    return true;
  }
  return false;
}

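// The layout handed to the Python slicing helpers is a 7-tuple:
// (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group, full_shape),
// where slice_shape is replaced by the opt-shard slice shape whenever the parameter belongs to an
// optimizer-shard group.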
void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &tensor_layout) {
  auto param_info = parameter->param_info();
  if (param_info == nullptr) {
    MS_LOG(WARNING) << "parameter: " << parameter->DebugString() << " doesn't have param_info.";
    return;
  }
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  auto phase = graph_executor->phase();
  auto py_obj = GetPyParameterObj(param_info, OBJ);
  if (py::isinstance<py::none>(py_obj)) {
    MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
    return;
  }
  if (tensor_layout == nullptr) {
    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase),
                                   py::none());
    return;
  }
  // create python layout obj
  const auto &device_arrangement = tensor_layout->device_arrangement().array();
  const auto &tensor_map = tensor_layout->tensor_map().array();
  auto slice_shape = tensor_layout->base_slice_shape().array();
  int64_t field_size = tensor_layout->get_field_size();
  bool uniform_split = tensor_layout->uniform_split();
  std::string opt_shard_group = tensor_layout->opt_shard_group();
  if (!opt_shard_group.empty()) {
    slice_shape = tensor_layout->opt_shard_slice_shape();
  }
  auto full_shape = tensor_layout->tensor_shape().array();
  py::tuple layout =
    py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group, full_shape);

  // Call Python _slice_parameter Fn to slice python parameter obj
  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);

  // handle cloned parameter, like accu_grad and optimizer param
  auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
  auto cloned_py_obj = GetPyParameterObj(param_info, CLONED_OBJ);
  if (!py::isinstance<py::none>(cloned_py_obj)) {
    if (!py::isinstance<py::list>(cloned_py_obj)) {
      MS_LOG(EXCEPTION) << "parameter: " << parameter->DebugString() << " doesn't have correct cloned obj";
    }
    auto obj_list = py::cast<py::list>(cloned_py_obj);
    for (size_t i = 0; i < obj_list.size(); ++i) {
      py::object each_cloned_obj = obj_list[i];
      auto cloned_param_slice_shape = tensor_layout->slice_shape().array();
      if (!opt_shard_group.empty()) {
        if (!IsAccuGradObj(each_cloned_obj) || grad_accumulation_shard) {
          cloned_param_slice_shape = tensor_layout->opt_shard_slice_shape();
        }
      }
      py::tuple cloned_param_layout = py::make_tuple(device_arrangement, tensor_map, cloned_param_slice_shape,
                                                     field_size, uniform_split, opt_shard_group, full_shape);
      (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, each_cloned_obj, py::str(phase),
                                     cloned_param_layout);
    }
  }
}

void SliceTensorObj(const ParameterPtr &parameter, const TensorLayoutPtr &tensor_layout, size_t rank_id) {
  auto param = parameter->default_param();
  MS_EXCEPTION_IF_NULL(param);
  auto p_tensor = param->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(p_tensor);
  if (p_tensor->DataSize() == 1) {
    MS_LOG(INFO) << "The parameter's data size is 1, no need to slice it by layout.";
    return;
  }
  if (tensor_layout == nullptr) {
    MS_LOG(INFO) << "No need to layout parameter";
    return;
  }
  // start get layout info
  const auto &device_arrangement = tensor_layout->device_arrangement().array();
  const auto &tensor_map = tensor_layout->tensor_map().array();
  auto slice_shape = tensor_layout->slice_shape().array();
  int64_t field_size = tensor_layout->get_field_size();
  bool uniform_split = tensor_layout->uniform_split();
  if (!uniform_split) {
    MS_LOG(ERROR) << "The load tensor only supports uniform split for now.";
  }
  std::string opt_shard_group = tensor_layout->opt_shard_group();
  if (!opt_shard_group.empty()) {
    slice_shape = tensor_layout->opt_shard_slice_shape();
  }
  py::tuple layout =
    py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);

  MS_LOG(INFO) << "origin p_tensor: " << p_tensor->name() << ", size: " << p_tensor->Size()
               << ", shape: " << p_tensor->shape();
  auto tensor_py = python_adapter::CastToPyObj(p_tensor);
  // Call Python _slice_tensor Fn to slice python tensor obj
  auto new_tensor_py =
    python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_TENSOR_FN_NAME, tensor_py, layout, rank_id);
  MS_LOG(INFO) << "Successfully called Python _slice_tensor Fn to slice the python tensor obj";
  auto new_tensor = new_tensor_py.cast<tensor::TensorPtr>();
  MS_LOG(INFO) << "new p_tensor: " << new_tensor->name() << ", size: " << new_tensor->Size()
               << ", shape: " << new_tensor->shape();
  parameter->set_default_param(new_tensor);
}

static void SliceCacheParameterObj(const ParameterPtr &parameter, const py::dict &layout_dict) {
  auto param_info = parameter->param_info();
  if (param_info == nullptr) {
    MS_LOG(WARNING) << "parameter: " << parameter->DebugString() << " doesn't have param_info.";
    return;
  }
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  auto phase = graph_executor->phase();
  auto py_obj = GetPyParameterObj(param_info, OBJ);
  if (py::isinstance<py::none>(py_obj)) {
    MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
    return;
  }
  auto name = parameter->name();
  if (!layout_dict.contains(name)) {
    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, INIT_OPTIMIZER_STATE_FN, py_obj, py::str(phase));
    return;
  }
  auto layout = layout_dict[py::str(name)];
  // Call Python _slice_parameter Fn to slice python parameter obj
  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);

  // handle cloned parameter, like accu_grad and optimizer param
  auto cloned_py_obj = GetPyParameterObj(param_info, CLONED_OBJ);
  if (!py::isinstance<py::none>(cloned_py_obj)) {
    if (!py::isinstance<py::list>(cloned_py_obj)) {
      MS_LOG(EXCEPTION) << "parameter: " << parameter->DebugString() << " doesn't have correct cloned obj";
    }
    auto obj_list = py::cast<py::list>(cloned_py_obj);
    for (size_t i = 0; i < obj_list.size(); ++i) {
      py::object each_cloned_obj = obj_list[i];
      (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, each_cloned_obj, py::str(phase),
                                     layout);
    }
  }
}

void InitCompileCacheParams(const pipeline::ResourcePtr &resource) {
  auto layout_dict = GetParameterLayoutFromResource(resource);
  auto graph = resource->func_graph();
  auto params = graph->parameters();
  for (auto &param : params) {
    auto param_ptr = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (!param_ptr->has_default()) {
      continue;
    }
    SliceCacheParameterObj(param_ptr, layout_dict);
  }
}

void InitPynativeNoShardParams(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    auto param_info = param_ptr->param_info();
    if (!param_info) {
      MS_LOG(DEBUG) << "Parameter:" << parameter->DebugString() << " doesn't have param_info.";
      continue;
    }
    auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
    MS_EXCEPTION_IF_NULL(graph_executor);
    auto phase = graph_executor->phase();
    auto py_obj = GetPyParameterObj(param_info, OBJ);
    if (py::isinstance<py::none>(py_obj)) {
      MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
      continue;
    }
    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, INIT_OPTIMIZER_STATE_FN, py_obj, py::str(phase));
  }
}

void AutoParallelPostProcess(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  for (auto &param : parameters) {
    if (ParameterIsCloned(param)) {
      continue;
    }
    auto layout = param->user_data<TensorLayout>();
    auto param_ptr = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (!param_ptr->has_default()) {
      continue;
    }
    SliceParameterObj(param_ptr, layout);
  }
}

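// Cloned parameters (optimizer states, accu_grads, ...) reference their origin through matching indices: the
// clone stores cloned_index and the origin lists it in be_cloned_index. Once the origin is found, its tensor
// layout and sliced shape are copied onto the clone, with special shape handling for accu_grads under
// optimizer sharding and gradient-accumulation sharding.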
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();

  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(cloned_parameter_node)) {
      continue;
    }
    auto param_value = cloned_parameter->param_info();
    if (param_value == nullptr) {
      continue;
    }
    // get the cloned index
    int64_t cloned_index = param_value->cloned_index();

    // find the be-cloned parameter
    bool found_be_cloned_parameter = false;
    ParameterPtr cloned_from_parameter = nullptr;
    AnfNodePtr cloned_from_node = nullptr;
    for (auto &be_cloned_parameter_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
      auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(be_cloned_parameter);
      if (!be_cloned_parameter->has_default()) {
        continue;
      }

      auto param_value_in = be_cloned_parameter->param_info();
      if (param_value_in == nullptr) {
        continue;
      }
      if (!param_value_in->be_cloned()) {
        continue;
      }

      // get the be-cloned index
      auto &be_cloned_index = param_value_in->be_cloned_index();
      if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
        found_be_cloned_parameter = true;
        cloned_from_parameter = be_cloned_parameter;
        cloned_from_node = be_cloned_parameter_node;
      }
    }

    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skipping it";
        continue;
      }
      auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      // from pipeline or grad accumulation
      if (param_name.find(ACCU_GRADS) != std::string::npos) {
        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
        auto opt_shard_group = tensor_layout->opt_shard_group();
        auto opt_shard_shape = cloned_from_parameter->user_data<TensorLayout>()->opt_shard_slice_shape();
        std::shared_ptr<abstract::BaseShape> parallel_shape = nullptr;
        // set opt shard shape if the pipeline sharding is set
        if (grad_accumulation_shard && !opt_shard_group.empty()) {
          parallel_shape = std::make_shared<abstract::Shape>(opt_shard_shape);
        } else {
          parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
        }
        MS_EXCEPTION_IF_NULL(parallel_shape);
        cloned_abstract->set_shape(parallel_shape);
        // in opt shard, accu_grad's shape is different from the original param's shape
        // if grad_accumulation_shard is enabled, the accu_grads will have an opt-sharded shape
        if (!grad_accumulation_shard && ParallelContext::GetInstance()->enable_parallel_optimizer()) {
          TensorLayout new_layout = *tensor_layout;
          new_layout.set_opt_shard_group("");
          tensor_layout = std::make_shared<TensorLayout>(new_layout);
        }
      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
      }
      cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
      cloned_parameter_node->set_abstract(cloned_abstract);
      // copy the fusion tag
      auto cloned_param_info = cloned_parameter->param_info();
      MS_EXCEPTION_IF_NULL(cloned_param_info);
      auto cloned_from_param_info = cloned_from_parameter->param_info();
      MS_EXCEPTION_IF_NULL(cloned_from_param_info);
      cloned_param_info->set_comm_fusion(cloned_from_param_info->comm_fusion());

      MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                   << " is cloned from parameter: " << cloned_from_parameter->name()
                   << ", clone index is: " << cloned_index;
    } else {
      MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned with clone index "
                        << cloned_index << ", but the be-cloned parameter was not found";
    }
  }
}

// For the adafactor optimizer, the relationship between the parameter and its states' shapes is as follows:
// 1) parameter: [A, B, C, D] (shape_size > 2), exp_avg_sq_row: [A, B, C], exp_avg_sq_col: [A, B, D], exp_avg_sq: [1]
//    If the parameter is opt-sharded, exp_avg_sq_row and exp_avg_sq_col need to be sharded accordingly.
// 2) parameter: [A, B] (shape_size = 2), exp_avg_sq_row: [A], exp_avg_sq_col: [B], exp_avg_sq: [1]
//    If the parameter is opt-sharded, exp_avg_sq_row needs to be sharded accordingly.
// 3) parameter: [A] (shape_size = 1), exp_avg_sq_row: [1], exp_avg_sq_col: [1], exp_avg_sq: [A]
//    If the parameter is opt-sharded, exp_avg_sq needs to be sharded accordingly.
static bool AdafactorStateIsOptShard(const std::string &opt_shard_group, size_t shape_size,
                                     const std::string &param_name, const std::string &state_name) {
  if (opt_shard_group.empty()) {
    return false;
  }

  std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
  std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
  std::string exp_avg_name = EXP_AVG_SQ + param_name;
  std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + param_name;
  std::string exp_insta_col_name = EXP_AVG_INSTA_COL + param_name;

  if (shape_size > 2 && state_name == exp_avg_name) {
    return false;
  }

  if (shape_size == 2 &&
      (state_name == exp_col_name || state_name == exp_avg_name || state_name == exp_insta_col_name)) {
    return false;
  }

  if (shape_size == 1 &&
      (state_name == exp_row_name || state_name == exp_col_name || state_name == exp_insta_row_name)) {
    return false;
  }

  MS_LOG(INFO) << "The parameter " << param_name << " is opt shard";
  return true;
}

static bool IsOriginWeight(const ParameterPtr &param) {
  std::string param_name = param->name();
  if (param_name.find(EXP_AVG) != std::string::npos) {
    return false;
  }

  auto tensor_layout = param->user_data<TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  return true;
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                            const std::string &name = ALL_REDUCE) {
  if (IsValueNode<RefKey>(node)) {
    std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
    if (param_v.size() != 1) {
      MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is "
                        << param_v.size();
    }
    auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
        name == ALL_REDUCE) {
      return std::make_pair(nullptr, true);
    }
    return std::make_pair(node, true);
  }
  return std::make_pair(nullptr, false);
}

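// A graph parameter without a default value is only a formal argument. Roughly: walk the call sites of the
// enclosing graph, take the argument fed into the matching position, peel pass-through ops (Load, Depend, Cast,
// MicroStepAllGather, and parallel-optimizer AllGather), and recurse until a parameter with a default value
// (the actual weight) is reached, or return nullptr if none exists.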
AnfNodePtr RefParameterToActualParameter(const AnfNodePtr &node) {
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  auto node_param_ptr = node->cast<ParameterPtr>();
  if (node_param_ptr->has_default()) {
    return node;
  }
  auto sub_func_graph = node_param_ptr->func_graph();
  auto call_cnodes_map = sub_func_graph->func_graph_cnodes_index();
  auto sub_graph_parameters = sub_func_graph->parameters();
  auto curr_param_iter = std::find(sub_graph_parameters.begin(), sub_graph_parameters.end(), node);
  if (curr_param_iter == sub_graph_parameters.end()) {
    MS_LOG(EXCEPTION) << "Cannot find param " << node_param_ptr->DebugString() << " in current sub_graph";
  }
  size_t curr_param_index = static_cast<size_t>(curr_param_iter - sub_graph_parameters.begin());
  for (const auto &node_pair : call_cnodes_map) {
    if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
      continue;
    }
    auto cnode = node_pair.first->first->cast<CNodePtr>();
    auto cnode_input = cnode->input(curr_param_index + 1);
    auto new_cnode = GetInputNodeWithFilter(cnode_input, [&](const CNodePtr &cnode) {
      bool filter = IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) ||
                    IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
                    IsPrimitiveCNode(cnode, prim::kPrimCast) ||
                    (IsPrimitiveCNode(cnode, prim::kPrimAllGather) &&
                     GetCNodePrimitive(cnode)->instance_name().find(PARALLEL_OPTIMIZER) != std::string::npos);
      return std::make_pair(filter, 1);
    });
    return RefParameterToActualParameter(new_cnode);
  }
  return nullptr;
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node,
                                                            const std::string &name = ALL_REDUCE) {
  if (!node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "The node is not a parameter, node:" << node->DebugString();
  }
  auto node_param_ptr = node->cast<ParameterPtr>();
  if (node_param_ptr->has_default()) {
    auto param_ptr = node->user_data<parallel::TensorLayout>();
    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
        name == ALL_REDUCE) {
      return std::make_pair(nullptr, false);
    }
    return std::make_pair(node, false);
  }
  AnfNodePtr ref_param = RefParameterToActualParameter(node);
  if (!ref_param) {
    return std::make_pair(nullptr, false);
  }
  auto ref_param_layout = ref_param->user_data<parallel::TensorLayout>();
  if (ref_param_layout && !ref_param_layout->opt_shard_group().empty() &&
      ref_param_layout->opt_shard_mirror_group().empty() && name == ALL_REDUCE) {
    return std::make_pair(nullptr, false);
  }
  return std::make_pair(ref_param, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByFuncGraph(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));

  auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
  if (pre_node) {
    return FindParameter(pre_node, pre_node->func_graph());
  }
  return std::make_pair(nullptr, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph);
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    return FindParameterByFuncGraph(node);
  }
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    for (size_t index = 0; index < cnode->size(); ++index) {
      auto res = FindParameter(cnode->input(index), func_graph);
      if (!res.first) {
        continue;
      }
      return res;
    }
  }

  // When opt shard is not fully used, both allgather and mirror are inserted.
  // Skip the allgather here and find the parameter recursively.
  if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
    return std::make_pair(nullptr, false);
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  for (size_t index = 0; index < cnode->size(); ++index) {
    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(prim);
    if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
      continue;
    }
    auto res = FindParameter(cnode->input(index), func_graph);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}

// Used for allgather and reducescatter
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                       const std::string &name) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node, name);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph, name);
  }

  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t index = 0; index < cnode->size(); ++index) {
    if (index != 1) {
      continue;
    }
    auto res = FindParameterWithAllgather(cnode->input(index), func_graph, name);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}

std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;
  for (auto &parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(parameter_node);
    auto cloned_parameter = parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(parameter_node)) {
      auto parameter_tensor_layout = cloned_parameter->user_data<TensorLayout>();
      adasum_param_map["adasum_delta_weight." + cloned_parameter->name()] = parameter_tensor_layout;
    }
  }
  return adasum_param_map;
}

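// Divides each element of a value sequence by the matching strategy element, then multiplies element 0 by
// expand_ratio. For example (illustrative values): value_seq (8, 4) with scale (2, 1) and expand_ratio 1
// yields (4, 4). Used below to rewrite StridedSlice begin/end values for adasum.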
Shape ValueSequeueScaleToShape(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  if (!value_seq->isa<ValueSequeue>()) {
    MS_LOG(EXCEPTION) << "The input is not a ValueSequeue";
  }
  std::vector<int64_t> origin_value_vector;
  if (TransValueSequeueToVector(value_seq, &origin_value_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Transform value_seq to vector failed";
  }
  if (origin_value_vector.size() > scale.size()) {
    MS_LOG(EXCEPTION) << "Cannot scale, the size of value_seq is: " << origin_value_vector.size()
                      << ", which should be less than or equal to the scale size: " << scale.size();
  }
  for (size_t i = 0; i < origin_value_vector.size(); ++i) {
    origin_value_vector[i] = origin_value_vector[i] / scale[i];
    if (i == 0) {
      origin_value_vector[i] = origin_value_vector[i] * SizeToLong(expand_ratio);
    }
  }
  return origin_value_vector;
}

ValuePtr ValueSequeueScale(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  Shape origin_value_vector = ValueSequeueScaleToShape(value_seq, scale, expand_ratio);
  if (value_seq->isa<ValueTuple>()) {
    return TransVectorToValueSequeue<ValueTuple>(origin_value_vector);
  }
  return TransVectorToValueSequeue<ValueList>(origin_value_vector);
}

void ReplaceAdaSumStridedSliceValue(const CNodePtr &stridedslice_cnode1,
                                    const std::shared_ptr<TensorLayout> &target_param_layout,
                                    size_t slice_expand_ratio) {
  auto target_param_info = std::make_shared<TensorInfo>(target_param_layout->SqueezeShape());
  Dimensions param_strategy = target_param_info->InferStrategy();
  auto new_begin1_value =
    ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(2)), param_strategy, slice_expand_ratio);
  auto new_end1_value =
    ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(3)), param_strategy, slice_expand_ratio);
  ValueNodePtr new_begin_value_node = std::make_shared<ValueNode>(new_begin1_value);
  ValueNodePtr new_end_value_node = std::make_shared<ValueNode>(new_end1_value);
  stridedslice_cnode1->set_input(2, new_begin_value_node);
  stridedslice_cnode1->set_input(3, new_end_value_node);
}

RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_layout) {
  int64_t rank = g_device_manager->global_rank();
  auto dev_shape = target_param_layout->device_arrangement().array();
  auto stage_device_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, stage_device_list, dev_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(target_param_layout->tensor_map().array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get adasum parameter origin mirror group by tensor layout failed.";
  }
  return group_devices;
}

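// Rewrites a Send/Receive primitive's group, fusion id, and peer rank for adasum, then reports the node's
// border role as four flags: {origin-first-if-forward, new-first-if-forward, origin-last-if-rollback,
// new-last-if-rollback}, derived from the rank distance to the communication peer.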
std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) {
  bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
  PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
  int64_t rank = g_device_manager->global_rank();
  if (group_devices.size() - 1 == 0) {
    MS_LOG(EXCEPTION) << "Possible division by zero: the group contains only one device.";
  }
  int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / SizeToLong(group_devices.size() - 1);
  if (adasum_rank_distance < ADASUM_MIN_DIS) {
    adasum_rank_distance = ADASUM_MIN_DIS;
  }
  size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
  int64_t fusion_id = GetValue<int64_t>(send_rec_prim->GetAttr("origin_fusion"));
  // when cutting nodes, the fusion id should change.
  int64_t new_fusion_id = fusion_id + SizeToLong(g_device_manager->DeviceNum() * (border_step + IntToSize(1)));
  send_rec_prim->set_attr(FUSION, MakeValue(new_fusion_id));
  std::vector<int64_t> group_list;
  int64_t new_dest_src_rank;
  if (rank > origin_dest_rank) {
    group_list = {origin_dest_rank, rank};
    new_dest_src_rank = 0;
  } else {
    group_list = {rank, origin_dest_rank};
    new_dest_src_rank = 1;
  }
  Group adasum_send_rec_group;
  if (g_device_manager->CreateGroup(group_list, &adasum_send_rec_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create send/receive group in adasum failed, the group is:" << group_list;
  }
  send_rec_prim->set_attr(GROUP, MakeValue(adasum_send_rec_group.name()));
  if (is_send) {
    send_rec_prim->set_attr(DEST_RANK, MakeValue(new_dest_src_rank));
  } else {
    send_rec_prim->set_attr(SRC_RANK, MakeValue(new_dest_src_rank));
  }
  int64_t rank_dis = abs(origin_dest_rank - rank);
  if (adasum_rank_distance == ADASUM_MIN_DIS) {
    return {false, false, false, false};
  }
  bool is_origin_first_node_if_forward = false;
  bool is_new_first_node_if_forward = false;
  bool is_origin_last_node_if_rollback = false;
  bool is_new_last_node_if_rollback = false;
  if (rank_dis == ADASUM_MIN_DIS) {
    is_origin_first_node_if_forward = true;
    is_origin_last_node_if_rollback = true;
  }
  if (rank_dis == adasum_rank_distance) {
    is_new_first_node_if_forward = true;
  }
  if (rank_dis == adasum_rank_distance / 2) {
    is_new_last_node_if_rollback = true;
  }
  return {is_origin_first_node_if_forward, is_new_first_node_if_forward, is_origin_last_node_if_rollback,
          is_new_last_node_if_rollback};
}

void HandleAdaSumReshape(const CNodePtr &reshape_cnode, const std::shared_ptr<TensorLayout> &target_param_layout) {
  auto slice_shape = target_param_layout->slice_shape().array();
  auto slice_shape_value = TransVectorToValueSequeue<ValueTuple>(slice_shape);
  ValueNodePtr new_slice_shape_value_node = std::make_shared<ValueNode>(slice_shape_value);
  reshape_cnode->set_input(2, new_slice_shape_value_node);
}

void RemoveAdasumRedundantNodes(const FuncGraphManagerPtr &manager,
                                std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
                                std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map,
                                std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map,
                                std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map) {
  // connect forward last node and rollback first node
  if (forward_origin_first_node_map->size() != forward_new_first_node_map->size() ||
      rollback_origin_last_node_map->size() != rollback_new_last_node_map->size()) {
    MS_LOG(EXCEPTION) << "The numbers of border nodes differ between the adasum forward and rollback processes.";
  }
  for (auto node : *forward_origin_first_node_map) {
    std::string target_param = node.first;
    CNodePtr forward_origin_first_node = node.second;
    CNodePtr forward_new_first_node = (*forward_new_first_node_map)[target_param];
    manager->SetEdge(forward_new_first_node, 1, forward_origin_first_node->input(1));
  }
  for (auto node : *rollback_origin_last_node_map) {
    std::string target_param = node.first;
    CNodePtr rollback_origin_last_node = node.second;
    CNodePtr rollback_new_last_node = (*rollback_new_last_node_map)[target_param];
    (void)manager->Replace(rollback_origin_last_node, rollback_new_last_node);
  }
}

void HandleAdasumAllReduce(const PrimitivePtr &prim, const RankList &group_devices) {
  size_t step = size_t(GetValue<int64_t>(prim->GetAttr("step")));
  std::vector<int64_t> neighbor_ids;
  int64_t adasum_rank_distance =
    (group_devices.back() - group_devices.front()) / SizeToLong((group_devices.size() - 1));
  if (adasum_rank_distance < ADASUM_MIN_DIS) {
    adasum_rank_distance = ADASUM_MIN_DIS;
  }
  size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
  MS_LOG(INFO) << "current border step is: " << border_step;
  if (step < border_step) {
    return;
  }
  int64_t rank = g_device_manager->global_rank();
  size_t double_d = size_t(IntToSize(2) << step);
  for (size_t index = 0; index < double_d; ++index) {
    int64_t node_rank = rank / ADASUM_MIN_DIS;
    int64_t neighbor_id =
      (node_rank / SizeToLong(double_d) * SizeToLong(double_d) + SizeToLong(index)) * ADASUM_MIN_DIS +
      rank % ADASUM_MIN_DIS;
    neighbor_ids.push_back(neighbor_id);
  }
  Group adasum_allreduce_group;
  if (g_device_manager->CreateGroup(neighbor_ids, &adasum_allreduce_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create allreduce group in adasum failed, the group is " << neighbor_ids;
  }
1177   auto new_group_name = MakeValue(adasum_allreduce_group.name());
1178   int64_t fusion_id = GetValue<int64_t>(prim->GetAttr("origin_fusion"));
1179   int64_t new_fusion_id = fusion_id + SizeToLong(g_device_manager->DeviceNum() * (border_step + IntToSize(1)));
1180   prim->set_attr(GROUP, new_group_name);
1181   prim->set_attr(FUSION, MakeValue(new_fusion_id));
1182 }
1183 
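// Rescale the slice values of the given StridedSlice according to the target parameter's layout and
// the slice expand ratio, then do the same for every sibling StridedSlice that consumes the output
// of the same Squeeze node, so all slices over one parameter stay consistent.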
void HandleAdasumSlice(const AnfNodePtr &stridedslice_node1, const std::shared_ptr<TensorLayout> &target_param_layout,
                       size_t slice_expand_ratio) {
  auto stridedslice_cnode1 = stridedslice_node1->cast<CNodePtr>();
  ReplaceAdaSumStridedSliceValue(stridedslice_cnode1, target_param_layout, slice_expand_ratio);
  auto squeeze_node = RealInputNode(stridedslice_cnode1, 1);
  if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
    MS_LOG(EXCEPTION) << "The input of the stridedslice node should be a squeeze node in adasum";
  }
  auto squeeze_cnode = squeeze_node->cast<CNodePtr>();
  FuncGraphManagerPtr manager = squeeze_node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[squeeze_cnode];
  for (auto &node_pair : node_set) {
    if (IsPrimitiveCNode(node_pair.first, prim::kPrimStridedSlice) && node_pair.first != stridedslice_node1) {
      CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
      ReplaceAdaSumStridedSliceValue(use_apply, target_param_layout, slice_expand_ratio);
    }
  }
}

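// Record the rollback border nodes around a Concat produced by the AdaSum rewriting. Judging by how
// IsBorderAdaSumSendReceive's result is consumed here, border_info[3] marks this concat as the new
// last node of the rollback pass, while border_info[2] walks Concat -> MakeTuple -> outer Concat to
// record the origin last node that RemoveAdasumRedundantNodes will later replace.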
void HandleAdaSumConcat(const AnfNodePtr &concat_node, const std::vector<bool> &border_info,
                        const std::string &target_param,
                        std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map,
                        std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map) {
  if (border_info[3]) {
    (*rollback_new_last_node_map)[target_param] = concat_node->cast<CNodePtr>();
  }
  if (border_info[2]) {
    auto manager = concat_node->func_graph()->manager();
    AnfNodeIndexSet concat_node_user_set = manager->node_users()[concat_node];
    for (auto &node_pair : concat_node_user_set) {
      if (IsPrimitiveCNode(node_pair.first, prim::kPrimMakeTuple)) {
        AnfNodeIndexSet make_tuple_node_user_set = manager->node_users()[node_pair.first];
        for (auto &tuple_user : make_tuple_node_user_set) {
          if (IsPrimitiveCNode(tuple_user.first, prim::kPrimConcat)) {
            (*rollback_origin_last_node_map)[target_param] = tuple_user.first->cast<CNodePtr>();
            return;
          }
        }
        return;
      }
    }
  }
}

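// Record the forward border nodes for the target parameter: the Squeeze feeding the StridedSlice is
// registered as the origin first node (border_info[0]) and/or the new first node (border_info[1]) of
// the forward pass, under the same border_info convention assumed above.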
void HandleAdaSumSqueeze(const AnfNodePtr &stridedslice_node1, const std::vector<bool> &border_info,
                         const std::string &target_param,
                         std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
                         std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map) {
  auto squeeze_node = RealInputNode(stridedslice_node1->cast<CNodePtr>(), 1);
  if (border_info[0]) {
    (*forward_origin_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
  }
  if (border_info[1]) {
    (*forward_new_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
  }
}

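// Handle the degenerate "pure model parallel" case where a parameter's AdaSum group contains a
// single rank. For a Send/Receive whose peer is exactly ADASUM_MIN_DIS away and whose input chain is
// StridedSlice -> Squeeze, every other user of the Squeeze's input (except Squeeze/UpdateState/
// MakeTuple) is re-pointed directly at that input, effectively short-circuiting the slice and
// communication chain that AdaSum would otherwise have inserted.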
void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return;
  }
  PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
  int64_t rank = g_device_manager->global_rank();
  CNodePtr cnode = node->cast<CNodePtr>();
  auto pre_cnode = RealInputNode(cnode, 1);
  int64_t rank_dis = std::abs(origin_dest_rank - rank);
  if (rank_dis == ADASUM_MIN_DIS && IsPrimitiveCNode(pre_cnode, prim::kPrimStridedSlice)) {
    auto squeeze_node = pre_cnode->cast<CNodePtr>()->input(1);
    if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
      return;
    }
    auto squeeze_input = squeeze_node->cast<CNodePtr>()->input(1);
    auto manager = squeeze_node->func_graph()->manager();
    AnfNodeIndexSet squeeze_input_node_user_set = manager->node_users()[squeeze_input];
    for (auto &squeeze_input_user : squeeze_input_node_user_set) {
      if (IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimSqueeze) ||
          IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState) ||
          IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimMakeTuple)) {
        continue;
      }
      (void)manager->Replace(squeeze_input_user.first, squeeze_input);
    }
  }
}

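// Entry point of the AdaSum graph rewriting. Every AllReduce/Reshape/Send/Receive tagged with a
// target parameter is matched against that parameter's tensor layout: pure model-parallel cases are
// short-circuited, reshape and slice values are rescaled by the rank-distance ratio, AllReduce
// groups are rebuilt per step, and forward/rollback border nodes are recorded so the redundant
// origin nodes can be removed at the end. Returns true if any node was rewritten for adasum.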
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
  std::unordered_map<std::string, CNodePtr> forward_origin_first_node_map;
  std::unordered_map<std::string, CNodePtr> forward_new_first_node_map;
  std::unordered_map<std::string, CNodePtr> rollback_origin_last_node_map;
  std::unordered_map<std::string, CNodePtr> rollback_new_last_node_map;
  bool is_adasum = false;
  for (auto &node : all_nodes) {
    bool is_allreduce = IsPrimitiveCNode(node, prim::kPrimAllReduce);
    bool is_reshape = IsPrimitiveCNode(node, prim::kPrimReshape);
    bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
    bool is_receive = IsPrimitiveCNode(node, prim::kPrimReceive);
    if (!is_allreduce && !is_reshape && !is_send && !is_receive) {
      continue;
    }
    std::string target_param;
    CNodePtr cnode = node->cast<CNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
    if (!prim->HasAttr(TARGET_PARAM)) {
      continue;
    }
    target_param = GetValue<std::string>(prim->GetAttr(TARGET_PARAM));
    auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];
    RankList group_devices = GetRankListByLayout(target_param_layout);
    // only model parallel
    if (group_devices.size() == 1) {
      HandleAdaSumPureModelParallel(node);
      continue;
    }

    int64_t adasum_rank_distance =
      (group_devices.back() - group_devices.front()) / SizeToLong(group_devices.size() - 1);
    // When the repeated dimension is the rightmost one, the parameter does not enable adasum.
    if (adasum_rank_distance == 1 && group_devices.size() < size_t(g_device_manager->stage_device_num())) {
      continue;
    }
    MS_LOG(INFO) << "Apply adasum in auto parallel, current dealing node is: " << node->fullname_with_scope();
    is_adasum = true;
    size_t slice_expand_ratio =
      LongToSize(adasum_rank_distance / ADASUM_MIN_DIS) > 0 ? LongToSize(adasum_rank_distance / ADASUM_MIN_DIS) : 1;
    if (is_reshape) {
      HandleAdaSumReshape(cnode, (*adasum_param_tensor_layout_map)[target_param]);
    }
    if (is_allreduce && prim->HasAttr("step")) {
      HandleAdasumAllReduce(prim, group_devices);
    }
    if (is_send || is_receive) {
      std::vector<bool> border_info = IsBorderAdaSumSendReceive(node, group_devices);
      if (is_receive) {
        auto target_param_info = std::make_shared<TensorInfo>(*target_param_layout);
        Dimensions param_strategy = target_param_info->InferStrategy();
        Shape new_rec_shape = ValueSequeueScaleToShape(prim->GetAttr(SHAPE), param_strategy, slice_expand_ratio);
        auto new_rec_shape_value = TransVectorToValueSequeue<ValueList>(new_rec_shape);
        prim->set_attr(SHAPE, new_rec_shape_value);
        continue;
      }
      auto stridedslice_node1 = RealInputNode(cnode, 1);
      if (IsPrimitiveCNode(stridedslice_node1, prim::kPrimConcat)) {
        HandleAdaSumConcat(stridedslice_node1, border_info, target_param, &rollback_new_last_node_map,
                           &rollback_origin_last_node_map);
        continue;
      }
      if (!IsPrimitiveCNode(stridedslice_node1, prim::kPrimStridedSlice)) {
        continue;
      }
      HandleAdasumSlice(stridedslice_node1, target_param_layout, slice_expand_ratio);
      HandleAdaSumSqueeze(stridedslice_node1, border_info, target_param, &forward_origin_first_node_map,
                          &forward_new_first_node_map);
    }
  }
  RemoveAdasumRedundantNodes(root->manager(), &forward_origin_first_node_map, &forward_new_first_node_map,
                             &rollback_origin_last_node_map, &rollback_new_last_node_map);
  return is_adasum;
}

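// Point a Mirror primitive at a new device group: a single-rank group only gets placeholder
// attributes (no communication is needed), otherwise a real group is created and the GROUP, DEV_NUM
// and GROUP_RANKS attributes are refreshed to match it.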
void ResetMirrorAttr(const PrimitivePtr &prim, const RankList &new_group) {
  if (new_group.size() == 1) {
    prim->set_attr(DEV_NUM, MakeValue<int64_t>(SizeToLong(new_group.size())));
    prim->set_attr(GROUP, MakeValue("one_rank_group"));
    prim->set_attr(GROUP_RANKS, MakeValue(std::to_string(new_group[0])));
    return;
  }
  Group adasum_mirror_group;
  if (g_device_manager->CreateGroup(new_group, &adasum_mirror_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create new mirror group failed in adasum, new group is: " << new_group;
  }
  auto new_group_name = MakeValue(adasum_mirror_group.name());
  prim->set_attr(GROUP, new_group_name);
  prim->set_attr(DEV_NUM, MakeValue<int64_t>(SizeToLong(new_group.size())));
  std::string rank_list_name = g_device_manager->FindRankListNameByHashName(adasum_mirror_group.name());
  prim->set_attr(GROUP_RANKS, MakeValue(rank_list_name));
}

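// Rewrite every Mirror op so its group matches the layout of the corresponding adasum delta weight.
// When the original group's ranks are spaced closer than ADASUM_MIN_DIS, the group is partitioned
// into contiguous sub-groups of new_group_size ranks and the sub-group containing the current rank
// is used; otherwise the mirror is reduced to the local rank only.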
void HandleMirrorInAdaSum(
  const FuncGraphPtr &root,
  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root->get_return());
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
      continue;
    }
    CNodePtr mirror_cnode = node->cast<CNodePtr>();
    auto param_node_pair = FindParameter(mirror_cnode->input(1), node->func_graph());
    if (!param_node_pair.first) {
      MS_LOG(EXCEPTION) << "The input of the mirror node is not a parameter";
    }
    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name = param_ptr->name();
    MS_LOG(INFO) << "Mirror param name is: " << param_name;
    std::string target_param = "adasum_delta_weight." + param_name;
    auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];

    // Change mirror group
    RankList group_devices = GetRankListByLayout(target_param_layout);
    int64_t rank = g_device_manager->global_rank();
    size_t group_dis = LongToSize(group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
    auto prim = GetCNodePrimitive(node);
    if (group_dis < ADASUM_MIN_DIS && group_dis > 0) {
      size_t new_group_size = size_t(ADASUM_MIN_DIS) / group_dis;
      // compute new group range
      size_t group_begin = 0;
      for (size_t group_end = new_group_size; group_end < group_devices.size() + new_group_size;
           group_end += new_group_size) {
        int64_t max_group_value =
          group_end >= group_devices.size() ? (group_devices.back() + 1) : group_devices[group_end];
        if (group_devices[group_begin] <= rank && rank < max_group_value) {
          std::vector<int64_t> new_group(group_devices.begin() + SizeToLong(group_begin),
                                         group_devices.begin() + SizeToLong(group_end));
          MS_LOG(INFO) << "Find new mirror group in adasum: " << new_group << " target_param: " << target_param;
          ResetMirrorAttr(prim, new_group);
          break;
        }
        group_begin = group_end;
      }
      continue;
    }
    ResetMirrorAttr(prim, {rank});
  }
}

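// Mark the optimizer-state parameter so its sharding strategy is saved into the strategy checkpoint.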
void SetParamInfoSaveStrategy(ParameterPtr row_col_param) {
  if (!row_col_param) {
    return;
  }
  auto param_info = row_col_param->param_info();
  if (param_info) {
    param_info->set_strategy_ckpt_saved(true);
  }
}

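// Derive sharded tensor layouts for the AdaFactor/CAME optimizer states of every origin weight: the
// four row/col states (exp_avg_sq_row/col and exp_avg_insta_row/col) plus the exp_avg_sq state,
// which is why collection stops at counts 4 and 1. Row states drop the weight's last dimension, col
// states drop the second-to-last, and exp_avg_sq is relaid out as-is (only when the weight slice is
// 1-D, as the checks below imply). Optimizer-shard info is propagated when applicable, the parameter
// abstract is updated to the sliced shape, and finally CameCommHandler inserts the communication
// CAME needs.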
void HandleCameAndAdaFactorOpt(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                               const FuncGraphManagerPtr &manager) {
  MS_LOG(INFO) << "AdaFactor or CAME optimizer processing starts";
  MS_EXCEPTION_IF_NULL(root);
  std::set<AnfNodePtr> origin_params;
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);

    if (!IsOriginWeight(param)) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      bool is_all_param_collected = (row_col_count == 4) && (exp_avg_sq_count == 1);
      if (is_all_param_collected) {
        break;
      }

      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string param_name = param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + param_name;
      std::string exp_insta_col_name = EXP_AVG_INSTA_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;
      std::set<std::string> came_param_set = {exp_row_name, exp_col_name, exp_insta_row_name, exp_insta_col_name,
                                              exp_avg_name};

      if (came_param_set.find(row_col_param_name) == came_param_set.end()) {
        continue;
      }
      origin_params.insert(param_node);
      auto tensor_layout = param->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(tensor_layout);
      auto slice_shape = tensor_layout->slice_shape().array();
      Shape opt_shard_slice_shape = slice_shape;
      if (!tensor_layout->opt_shard_group().empty()) {
        opt_shard_slice_shape = tensor_layout->opt_shard_slice_shape();
      }

      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = row_col_param_name != exp_avg_name;
      if (is_row_or_col_param && shape_size <= 1) {
        row_col_count++;
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        exp_avg_sq_count++;
        continue;
      }

      auto origin_shape = tensor_layout->tensor_shape().array();
      auto dev_mat = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name || row_col_param_name == exp_insta_row_name) {
        opt_shard_slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name || row_col_param_name == exp_insta_col_name) {
        (void)opt_shard_slice_shape.erase(opt_shard_slice_shape.cbegin() +
                                          static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.cbegin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.cbegin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
      } else {
        exp_avg_sq_count++;
      }

      TensorLayout new_tensor_layout;
      if (new_tensor_layout.InitFromVector(dev_mat, tensor_map, origin_shape) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      if (AdafactorStateIsOptShard(tensor_layout->opt_shard_group(), shape_size, param_name, row_col_param_name)) {
        new_tensor_layout.set_opt_shard_group(tensor_layout->opt_shard_group());
        new_tensor_layout.set_opt_shard_slice_shape(opt_shard_slice_shape);
      }
      SetParamInfoSaveStrategy(row_col_param);
      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(opt_shard_slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
    }
  }

  for (const auto &origin_param_node : origin_params) {
    auto inserter =
      CameCommHandler(origin_param_node->cast<ParameterPtr>(), root->parameters(), manager->node_users());
    inserter.Process();
  }
}

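// Build a tensor layout for a parameter that feeds a Reshape carrying an explicit input strategy.
// The device matrix is [dev_num / prod(input_stra), input_stra...]: the leading axis absorbs the
// devices the strategy does not use, and the descending tensor map lines each parameter dimension up
// with its strategy axis, i.e. the parameter is sharded exactly as the reshape strategy dictates.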
static std::shared_ptr<TensorLayout> GenerateTensorLayoutForParamReshapeWithStra(const AnfNodePtr &node,
                                                                                 const Dimensions &input_stra) {
  CheckGlobalDeviceManager();
  int64_t dev_num = g_device_manager->stage_device_num();
  MS_EXCEPTION_IF_ZERO("dev_num", dev_num);

  Shapes inputs_shape = GetNodeShape(node);
  Shape param_shape = inputs_shape[0];

  Shape param_dev_matrix_shape(input_stra.size() + 1, 0);
  for (size_t i = param_dev_matrix_shape.size() - 1; i > 0; i--) {
    param_dev_matrix_shape[i] = input_stra[i - 1];
  }
  param_dev_matrix_shape[0] =
    dev_num / std::accumulate(input_stra.begin(), input_stra.end(), 1, std::multiplies<int64_t>());

  TensorMap param_tensor_map;
  for (size_t i = 0; i < param_shape.size(); ++i) {
    param_tensor_map.push_back(static_cast<int64_t>(param_shape.size() - i - 1));
  }

  TensorLayout param_layout;

  if (param_layout.InitFromVector(param_dev_matrix_shape, param_tensor_map, param_shape) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Infer param-Reshape with strategy tensor layout failed.";
  }

  return std::make_shared<TensorLayout>(param_layout);
}

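// For a call CNode whose input(0) is a sub-FuncGraph, return the formal parameter of that subgraph
// bound to the call's argument at position `index` (call inputs are 1-based, hence index - 1).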
static AnfNodePtr FindParameterByCallNode(const CNodePtr &call, int64_t index) {
  MS_EXCEPTION_IF_NULL(call);
  AnfNodePtr graph_value_node = call->input(0);
  if (!IsValueNode<FuncGraph>(graph_value_node)) {
    return nullptr;
  }
  auto graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
  auto parameters = graph_sub->parameters();
  if (LongToSize(index - 1) >= parameters.size()) {
    MS_LOG(EXCEPTION) << "The index is out of range, index is: " << (index - 1) << ", vector size is "
                      << parameters.size();
  }
  return parameters[LongToSize(index - 1)];
}

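// Walk the users of a parameter (through Load nodes, subgraph calls, and Depend's first input) until
// a consumer determines a layout: a Reshape with an explicit in_strategy yields a strategy-derived
// layout, and a parallel-care operator with OperatorInfo yields its inferred input layout. Returns
// nullptr when no user pins down a layout within MAX_RECURSIVE_DEPTH.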
static std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: "
                    << MAX_RECURSIVE_DEPTH;
    return nullptr;
  }
  FuncGraphManagerPtr manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  for (auto &node_pair : node_set) {
    if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
      auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth);
      if (!layout_param) {
        continue;
      }
      return layout_param;
    }
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr) {
      continue;
    }
    auto op = use_apply->input(0);
    MS_EXCEPTION_IF_NULL(op);
    if (IsValueNode<FuncGraph>(op)) {
      auto para = FindParameterByCallNode(use_apply, node_pair.second);
      auto layout_param = FindParameterNextLayout(para, ++curr_depth);
      if (!layout_param) {
        continue;
      }
      return layout_param;
    }
    if (!IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    if (node_prim->name() == RESHAPE) {
      auto attrs_temp = node_prim->attrs();
      if (!StrategyFound(attrs_temp)) {
        continue;
      }
      StrategyPtr strategy = ExtractStrategy(attrs_temp[IN_STRATEGY]);
      Strategies stra = strategy->GetInputDim();
      Dimensions input_strategy = stra.at(0);

      auto param_layout = GenerateTensorLayoutForParamReshapeWithStra(node, input_strategy);

      return param_layout;
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      auto layout = GetInputLayoutFromCNode(node_pair, -1);
      return std::make_shared<TensorLayout>(layout);
    }
  }
  return nullptr;
}

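// Create a layout for a parameter whose layout was not determined during operator inference. Prefer
// a layout derived from the parameter's users; otherwise fall back to a data-parallel layout with
// device matrix [dev_num], sharding the first dimension across the stage devices when it is
// divisible and leaving the parameter fully replicated when it is not.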
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  // Create DataParallel tensor layout for parameter (support WideDeep).
  auto next_layout = FindParameterNextLayout(node, 0);
  if (next_layout != nullptr) {
    return next_layout;
  }
  CheckGlobalDeviceManager();
  int64_t dev_num = g_device_manager->stage_device_num();
  MS_EXCEPTION_IF_ZERO("dev_num", dev_num);
  TensorLayout input_tensor_layout;
  // create input_shape
  Shapes inputs_shape = GetNodeShape(node);
  Shape input_shape_array = inputs_shape[0];

  // create dev_matrix
  Shape dev_matrix_array = {dev_num};

  // create tensor_map
  size_t shape_size = input_shape_array.size();
  TensorMap input_tensor_map_array(shape_size, MAP_NONE);
  if ((shape_size > 0) && (input_shape_array[0] % dev_num == 0)) {
    input_tensor_map_array[0] = 0;  // shard parameter's first dimension when parameter->Reshape->Op
  }

  if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
  }
  return std::make_shared<TensorLayout>(input_tensor_layout);
}

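// For every CNode whose primitive carries the "insert_rand" attribute (communication ops excluded),
// build a Shape -> UniformReal chain over the node's first input and splice the random tensor in as
// that input, so the tagged node consumes fresh uniform random data of the matching shape. Both
// seeds of the inserted UniformReal are fixed to 0.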
// temporary method for handling random-value (UniformReal) insertion in opt graph
void InsertUniformRealForTaggedNodes(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto primitive = GetCNodePrimitive(node);
    if (primitive == nullptr) {
      continue;
    }
    if (common::AnfAlgo::IsCommunicationOp(node)) {
      continue;
    }
    if (primitive->HasAttr("insert_rand")) {
      MS_LOG(INFO) << "Insert UniformReal to node " << node->DebugString();
      std::vector<AnfNodePtr> inputShape = {NewValueNode(prim::kPrimShape), node->cast<CNodePtr>()->input(kIndex1)};
      auto inputShapeNode = node->func_graph()->NewCNode(inputShape);

      std::vector<AnfNodePtr> uniformReal = {NewValueNode(prim::kPrimUniformReal), inputShapeNode->cast<AnfNodePtr>()};
      auto uniformRealNode = node->func_graph()->NewCNode(uniformReal);

      auto uniformRealPrim = GetCNodePrimitive(uniformRealNode);
      auto attrs = uniformRealPrim->attrs();
      attrs["seed"] = MakeValue<int64_t>(0);
      attrs["seed2"] = MakeValue<int64_t>(0);
      (void)uniformRealPrim->SetAttrs(attrs);

      manager->SetEdge(node, 1, uniformRealNode);
    }
  }
}
}  // namespace parallel
}  // namespace mindspore