/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parameter_manager.h"

#include <cinttypes>
#include <algorithm>

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <deque>
#include <functional>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/hash_map.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "pipeline/jit/ps/pipeline.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/core/ops/nn_ops.h"

namespace mindspore {
namespace parallel {
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
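// ParameterUsersInfo pairs a parameter's name with the parameter node itself and the set of
// (user cnode, input index) pairs that consume it. The helpers below fill this structure for
// the RefKey case and for the plain Parameter case respectively.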
static ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) {
  // Dealing with the RefKey case
  ParameterUsersInfo parameter_user_info;
  auto refkeys = ref_key_pair.second;
  auto cnode = ref_key_pair.first;

  auto cnode_ptr = cnode->cast<CNodePtr>();
  if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) {
    return parameter_user_info;
  }

  if (refkeys.size() > 1) {
    MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than one RefKey";
  }
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  auto cnode_func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());

  // Find the RefKey being used
  auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
  for (auto &candidate : candidate_set_by_refkey) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }

  // Find the corresponding Parameter being used
  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
  if (parameters.size() != 1) {
    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
  }
  parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = parameters[0];
  auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
  for (auto &candidate : candidate_set_by_para) {
    auto candidate_node = candidate.first;
    auto c = candidate_node->cast<CNodePtr>();
    if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
      continue;
    }
    parameter_user_info.second.second.insert(candidate);
  }
  return parameter_user_info;
}

static ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, const std::vector<AnfNodePtr> &all_nodes) {
  // In this case, node is a Parameter
  ParameterUsersInfo parameter_user_info;
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
  auto candidate_set = node->func_graph()->manager()->node_users()[node];
  for (auto &candidate : candidate_set) {
    auto candidate_node = candidate.first;
    if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) {
      if (candidate.second != 1) {
        continue;
      }
      auto &node_user_map = node->func_graph()->manager()->node_users();
      auto load_node_users = node_user_map[candidate_node];
      for (auto &node_user : load_node_users) {
        auto cnode = node_user.first->cast<CNodePtr>();
        std::pair<AnfNodePtr, int> child_parallel_care_node;
        if (IsSomePrimitive(cnode, UPDATESTATE) || !cnode->in_forward_flag()) {
          continue;
        }
        if (!IsSomePrimitive(cnode, MAKE_TUPLE) && (IsParallelCareNode(cnode) || IsAutoParallelCareNode(cnode))) {
          child_parallel_care_node = node_user;
        } else {
          child_parallel_care_node = BFSParallelCareNode(cnode, node_user_map, node_user.second, all_nodes);
        }
        if (child_parallel_care_node.first) {
          cnode = child_parallel_care_node.first->cast<CNodePtr>();
        } else {
          continue;
        }
        if (cnode == nullptr || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
          continue;
        }
        parameter_user_info.second.second.insert(child_parallel_care_node);
      }
    } else {
      auto c = candidate_node->cast<CNodePtr>();
      if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
        continue;
      }
      parameter_user_info.second.second.insert(candidate);
    }
  }
  parameter_user_info.first = node->cast<ParameterPtr>()->name();
  parameter_user_info.second.first = node;
  return parameter_user_info;
}

static RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> refkeys;
  if (cnode->isa<CNode>()) {
    auto cnode_ptr = cnode->cast<CNodePtr>();
    auto inputs = cnode_ptr->inputs();
    for (auto &one_input : inputs) {
      if (IsValueNode<RefKey>(one_input)) {
        refkeys.push_back(one_input);
      }
    }
    if (!refkeys.empty()) {
      return std::make_pair(cnode, refkeys);
    }
  }
  return {nullptr, refkeys};
}

ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &),
                                      const std::vector<AnfNodePtr> &all_nodes) {
  ParameterUsersInfo parameter_users_info;

  auto cnode_with_refkeys = CNodeWithRefKeys(node);
  if (cnode_with_refkeys.first != nullptr) {
    // the node is a ref key node
    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
  } else if (node->isa<Parameter>()) {
    auto param_ptr = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    // the node is a parameter node
    if (param_ptr->has_default()) {
      return FindParameterNodeUsers(node, all_nodes);
    }
  }

  return parameter_users_info;
}

static bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter, size_t max_depth) {
  if (max_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(EXCEPTION) << "Recursive call depth exceeds " << MAX_RECURSIVE_DEPTH << ".";
  }
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
  auto manager = graph->manager();
  auto node_users = manager->node_users()[parameter];
  if (node_users.empty()) {
    return false;
  }
  for (auto node_user : node_users) {
    auto use_node = node_user.first->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(use_node->input(0))) {
      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    if (use_node->input(0)->isa<CNode>()) {
      auto cnode = use_node->input(0)->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
        return true;
      }
      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[IntToSize(node_user.second - 1)];
      return IsUsedParameter(graph_sub, parameter_sub, max_depth + 1);
    }
    return true;
  }
  return true;
}

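// Maps a single tensor-map value to the list of ranks holding the same slice along that dimension.
// For example (a sketch of the index arithmetic below): with a device matrix [2, 4] (size 2) and
// tensor_map_value 0, the device-matrix dim is 2 - 0 - 1 = 1, so the ranks along the last dim are
// returned; MAP_NONE means the dimension is not sharded, so only the local rank is returned.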
static RankList GetDevListByTensorMapValue(DeviceMatrix dev_matrix, int64_t tensor_map_value, size_t dev_matrix_size) {
  RankList rank_list;
  if (tensor_map_value >= SizeToLong(dev_matrix_size) || tensor_map_value < MAP_NONE) {
    MS_LOG(ERROR) << "The size of dev_matrix is " << dev_matrix_size << ", but the tensor map value is "
                  << tensor_map_value;
    return rank_list;
  }

  if (tensor_map_value == MAP_NONE) {
    rank_list.push_back(g_device_manager->global_rank());
    return rank_list;
  }

  uint64_t dim = dev_matrix_size - LongToSize(tensor_map_value) - 1;
  if (dev_matrix.GetDevicesAlongDim(dim, &rank_list) != SUCCESS) {
    MS_LOG(ERROR) << "Get devices along dim failed";
  }

  return rank_list;
}

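// Two layouts can produce identical shards even when their device arrangements or tensor maps
// differ textually. As a fallback after the fast equality checks, compare the concrete device
// list of every tensor dimension and treat the layouts as equal only if all of them match.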
static bool IsSameTensorLayout(const TensorLayout &a, const TensorLayout &b) {
  if (!a.IsSameTensorShape(b)) {
    return false;
  }
  if (a.IsSameDeviceArrangement(b) && a.IsSameTensorMap(b)) {
    return true;
  }

  Shape a_tensor_map = a.tensor_map().array();
  Shape b_tensor_map = b.tensor_map().array();
  if (a_tensor_map.size() != b_tensor_map.size()) {
    return false;
  }

  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix a_dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), a.device_arrangement().array());
  DeviceMatrix b_dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), b.device_arrangement().array());
  size_t a_dev_mat_size = a.device_arrangement().array().size();
  size_t b_dev_mat_size = b.device_arrangement().array().size();

  for (size_t i = 0; i < a_tensor_map.size(); ++i) {
    if (a_tensor_map[i] == MAP_NONE && b_tensor_map[i] == MAP_NONE) {
      continue;
    }

    RankList a_dev_list_by_dim = GetDevListByTensorMapValue(a_dev_matrix, a_tensor_map[i], a_dev_mat_size);
    RankList b_dev_list_by_dim = GetDevListByTensorMapValue(b_dev_matrix, b_tensor_map[i], b_dev_mat_size);
    if (a_dev_list_by_dim.empty() || b_dev_list_by_dim.empty()) {
      MS_LOG(EXCEPTION) << "Cannot get device list by tensor map value, these layouts are " << a.ToString()
                        << std::endl
                        << " and " << b.ToString();
    }

    if (a_dev_list_by_dim != b_dev_list_by_dim) {
      return false;
    }
  }

  return true;
}

bool IsSameTensorInfo(const TensorInfo &a, const TensorInfo &b) {
  return IsSameTensorLayout(a.tensor_layout(), b.tensor_layout());
}

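// A parameter consumed by several parallel-care operators must be given the same slice layout
// by all of them; otherwise the slices would conflict, so an exception is raised below.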
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode, all_nodes);
    auto &users_set = parameter_users_info.second.second;
    if (users_set.size() <= 1) {
      continue;
    }

    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    auto &first_user = users_set.front();
    auto parameter_tensor_info = GetInputsTensorInfo(first_user);

    for (auto iter = users_set.begin() + 1; iter != users_set.end(); ++iter) {
      auto &user = *iter;
      auto user_tensor_info = GetInputsTensorInfo(user);
      if (IsSameTensorInfo(parameter_tensor_info, user_tensor_info)) {
        continue;
      } else {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the TensorInfo are different, they are "
                          << parameter_tensor_info.tensor_layout().ToString() << std::endl
                          << " and " << user_tensor_info.tensor_layout().ToString();
      }
    }
  }
}

namespace {
void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
  MS_EXCEPTION_IF_NULL(symbolic_key);
  auto all_upstream_node = root->manager()->node_users()[node];
  for (auto &upstream_node : all_upstream_node) {
    FuncGraphPtr fg = upstream_node.first->func_graph();
    if (symbolic_key->node()->isa<Parameter>()) {
      for (auto &param : root->parameters()) {
        if (*param == *symbolic_key->node()) {
          AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
          MS_EXCEPTION_IF_NULL(reverted_node);
          MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
          (void)fg->manager()->Replace(node, reverted_node);
          MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
        }
      }
    }
  }
}
}  // namespace

void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &node : all_nodes) {
    // revert back SymbolicKeyInstance to embed() primitive
    if (IsValueNode<SymbolicKeyInstance>(node)) {
      RevertSymbolicKeyInstance(root, node);
      continue;
    }
  }
}

bool IsStrategySaved(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  return param_value->strategy_ckpt_saved();
}

bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

void HandleNoUsedParameter(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  bool full_batch = ParallelContext::GetInstance()->full_batch();
  if (full_batch) {
    return;
  }

  auto dev_num = g_device_manager->stage_device_num();
  auto parameters = root->parameters();
  if (parameters.empty()) {
    MS_LOG(INFO) << "No parameters in the graph, thus no need to set parallel shape";
  } else {
    for (auto &parameter : parameters) {
      if (IsUsedParameter(root, parameter, 0)) {
        continue;
      }
      auto parameter_shape = GetNodeShape(parameter);
      if (parameter_shape.empty()) {
        continue;
      }
      Shape slice_shape = parameter_shape[0];
      if (slice_shape.empty() || slice_shape[0] < dev_num) {
        continue;
      }
      slice_shape[0] = slice_shape[0] / dev_num;
      auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
      auto abstract = parameter->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      auto abstract_cloned = abstract->Clone();
      MS_EXCEPTION_IF_NULL(abstract_cloned);
      abstract_cloned->set_shape(slice_shape_ptr);
      parameter->set_abstract(abstract_cloned);
    }
  }
}

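// group_devices computed below is the list of ranks that hold an identical slice of the
// parameter; if its size does not exceed allow_repeat_num, the parameter counts as fully split.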
bool IsFullySplitParameter(const ParameterPtr &param_ptr, size_t allow_repeat_num) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() <= allow_repeat_num) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}

py::object GetPyParameterObj(const ParamInfoPtr &param_info, const std::string &obj) {
  py::object py_obj = py::cast(param_info);
  if (py::isinstance<py::none>(py_obj)) {
    return py::none();
  }
  return python_adapter::GetPyObjAttr(py_obj, obj);
}

static bool IsAccuGradObj(const py::object &py_obj) {
  auto name = python_adapter::GetPyObjAttr(py_obj, PARAM_NAME);
  if (py::isinstance<py::none>(name)) {
    return false;
  }
  if (py::cast<std::string>(name).find(ACCU_GRADS) == 0) {
    return true;
  }
  return false;
}

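// The layout tuple handed to Python mirrors the TensorLayout fields in this order:
// (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group,
// full_shape); the Python slicing function uses it to cut the parameter data on the host.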
void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &tensor_layout) {
  auto param_info = parameter->param_info();
  if (param_info == nullptr) {
    MS_LOG(WARNING) << "parameter: " << parameter->DebugString() << " doesn't have param_info.";
    return;
  }
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  auto phase = graph_executor->phase();
  auto py_obj = GetPyParameterObj(param_info, OBJ);
  if (py::isinstance<py::none>(py_obj)) {
    MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
    return;
  }
  if (tensor_layout == nullptr) {
    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase),
                                   py::none());
    return;
  }
  // create python layout obj
  const auto &device_arrangement = tensor_layout->device_arrangement().array();
  const auto &tensor_map = tensor_layout->tensor_map().array();
  auto slice_shape = tensor_layout->base_slice_shape().array();
  int64_t field_size = tensor_layout->get_field_size();
  bool uniform_split = tensor_layout->uniform_split();
  std::string opt_shard_group = tensor_layout->opt_shard_group();
  if (!opt_shard_group.empty()) {
    slice_shape = tensor_layout->opt_shard_slice_shape();
  }
  auto full_shape = tensor_layout->tensor_shape().array();
  py::tuple layout =
    py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group, full_shape);

  // Call Python _slice_parameter Fn to slice python parameter obj
  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);

  // handle cloned parameter, like accu_grad and optimizer param
  auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
  auto cloned_py_obj = GetPyParameterObj(param_info, CLONED_OBJ);
  if (!py::isinstance<py::none>(cloned_py_obj)) {
    if (!py::isinstance<py::list>(cloned_py_obj)) {
      MS_LOG(EXCEPTION) << "parameter: " << parameter->DebugString() << " doesn't have a correct cloned obj";
    }
    auto obj_list = py::cast<py::list>(cloned_py_obj);
    for (size_t i = 0; i < obj_list.size(); ++i) {
      py::object each_cloned_obj = obj_list[i];
      auto cloned_param_slice_shape = tensor_layout->slice_shape().array();
      if (!opt_shard_group.empty()) {
        if (!IsAccuGradObj(each_cloned_obj) || grad_accumulation_shard) {
          cloned_param_slice_shape = tensor_layout->opt_shard_slice_shape();
        }
      }
      py::tuple cloned_param_layout = py::make_tuple(device_arrangement, tensor_map, cloned_param_slice_shape,
                                                     field_size, uniform_split, opt_shard_group, full_shape);
      (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, each_cloned_obj, py::str(phase),
                                     cloned_param_layout);
    }
  }
}

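// Unlike SliceParameterObj above, which slices through the Python parameter object, this
// variant slices the parameter's materialized default tensor directly and writes the sliced
// tensor back via set_default_param; judging by the log below, it serves the tensor-load path.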
void SliceTensorObj(const ParameterPtr &parameter, const TensorLayoutPtr &tensor_layout, size_t rank_id) {
  auto param = parameter->default_param();
  MS_EXCEPTION_IF_NULL(param);
  auto p_tensor = param->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(p_tensor);
  if (p_tensor->DataSize() == 1) {
    MS_LOG(INFO) << "The parameter's data size is 1, no need to slice it.";
    return;
  }
  if (tensor_layout == nullptr) {
    MS_LOG(INFO) << "No tensor layout, no need to slice the parameter";
    return;
  }
  // start get layout info
  const auto &device_arrangement = tensor_layout->device_arrangement().array();
  const auto &tensor_map = tensor_layout->tensor_map().array();
  auto slice_shape = tensor_layout->slice_shape().array();
  int64_t field_size = tensor_layout->get_field_size();
  bool uniform_split = tensor_layout->uniform_split();
  if (!uniform_split) {
    MS_LOG(ERROR) << "Loading a tensor currently supports only uniform split.";
  }
  std::string opt_shard_group = tensor_layout->opt_shard_group();
  if (!opt_shard_group.empty()) {
    slice_shape = tensor_layout->opt_shard_slice_shape();
  }
  py::tuple layout =
    py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group);

  MS_LOG(INFO) << "origin p_tensor: " << p_tensor->name() << ", size: " << p_tensor->Size()
               << ", shape: " << p_tensor->shape();
  auto tensor_py = python_adapter::CastToPyObj(p_tensor);
  // Call Python _slice_tensor Fn to slice python tensor obj
  auto new_tensor_py =
    python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_TENSOR_FN_NAME, tensor_py, layout, rank_id);
  MS_LOG(INFO) << "Successfully called Python _slice_tensor Fn to slice the python tensor obj";
  auto new_tensor = new_tensor_py.cast<tensor::TensorPtr>();
  MS_LOG(INFO) << "new p_tensor: " << new_tensor->name() << ", size: " << new_tensor->Size()
               << ", shape: " << new_tensor->shape();
  parameter->set_default_param(new_tensor);
}

static void SliceCacheParameterObj(const ParameterPtr &parameter, const py::dict &layout_dict) {
  auto param_info = parameter->param_info();
  if (param_info == nullptr) {
    MS_LOG(WARNING) << "parameter: " << parameter->DebugString() << " doesn't have param_info.";
    return;
  }
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  auto phase = graph_executor->phase();
  auto py_obj = GetPyParameterObj(param_info, OBJ);
  if (py::isinstance<py::none>(py_obj)) {
    MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
    return;
  }
  auto name = parameter->name();
  if (!layout_dict.contains(name)) {
    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, INIT_OPTIMIZER_STATE_FN, py_obj, py::str(phase));
    return;
  }
  auto layout = layout_dict[py::str(name)];
  // Call Python _slice_parameter Fn to slice python parameter obj
  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);

  // handle cloned parameter, like accu_grad and optimizer param
  auto cloned_py_obj = GetPyParameterObj(param_info, CLONED_OBJ);
  if (!py::isinstance<py::none>(cloned_py_obj)) {
    if (!py::isinstance<py::list>(cloned_py_obj)) {
      MS_LOG(EXCEPTION) << "parameter: " << parameter->DebugString() << " doesn't have a correct cloned obj";
    }
    auto obj_list = py::cast<py::list>(cloned_py_obj);
    for (size_t i = 0; i < obj_list.size(); ++i) {
      py::object each_cloned_obj = obj_list[i];
      (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, each_cloned_obj, py::str(phase),
                                     layout);
    }
  }
}

void InitCompileCacheParams(const pipeline::ResourcePtr &resource) {
  auto layout_dict = GetParameterLayoutFromResource(resource);
  auto graph = resource->func_graph();
  auto params = graph->parameters();
  for (auto &param : params) {
    auto param_ptr = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (!param_ptr->has_default()) {
      continue;
    }
    SliceCacheParameterObj(param_ptr, layout_dict);
  }
}

void InitPynativeNoShardParams(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    auto param_info = param_ptr->param_info();
    if (!param_info) {
      MS_LOG(DEBUG) << "Parameter:" << parameter->DebugString() << " doesn't have param_info.";
      continue;
    }
    auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
    MS_EXCEPTION_IF_NULL(graph_executor);
    auto phase = graph_executor->phase();
    auto py_obj = GetPyParameterObj(param_info, OBJ);
    if (py::isinstance<py::none>(py_obj)) {
      MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
      continue;
    }
    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, INIT_OPTIMIZER_STATE_FN, py_obj, py::str(phase));
  }
}

void AutoParallelPostProcess(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  for (auto &param : parameters) {
    if (ParameterIsCloned(param)) {
      continue;
    }
    auto layout = param->user_data<TensorLayout>();
    auto param_ptr = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (!param_ptr->has_default()) {
      continue;
    }
    SliceParameterObj(param_ptr, layout);
  }
}

void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();

  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(cloned_parameter_node)) {
      continue;
    }
    auto param_value = cloned_parameter->param_info();
    if (param_value == nullptr) {
      continue;
    }
    // get the cloned index
    int64_t cloned_index = param_value->cloned_index();

    // find the be cloned parameter
    bool found_be_cloned_parameter = false;
    ParameterPtr cloned_from_parameter = nullptr;
    AnfNodePtr cloned_from_node = nullptr;
    for (auto &be_cloned_parameter_node : root->parameters()) {
      MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
      auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(be_cloned_parameter);
      if (!be_cloned_parameter->has_default()) {
        continue;
      }

      auto param_value_in = be_cloned_parameter->param_info();
      if (param_value_in == nullptr) {
        continue;
      }
      if (!param_value_in->be_cloned()) {
        continue;
      }

      // get the be cloned index
      auto &be_cloned_index = param_value_in->be_cloned_index();
      if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
        found_be_cloned_parameter = true;
        cloned_from_parameter = be_cloned_parameter;
        cloned_from_node = be_cloned_parameter_node;
      }
    }

    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skipping it";
        continue;
      }
      auto tensor_layout = cloned_from_parameter->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      // from pipeline or grad accumulation
      if (param_name.find(ACCU_GRADS) != std::string::npos) {
        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
        auto opt_shard_group = tensor_layout->opt_shard_group();
        auto opt_shard_shape = cloned_from_parameter->user_data<TensorLayout>()->opt_shard_slice_shape();
        std::shared_ptr<abstract::BaseShape> parallel_shape = nullptr;
        // set opt shard shape if the pipeline sharding is set
        if (grad_accumulation_shard && !opt_shard_group.empty()) {
          parallel_shape = std::make_shared<abstract::Shape>(opt_shard_shape);
        } else {
          parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
        }
        MS_EXCEPTION_IF_NULL(parallel_shape);
        cloned_abstract->set_shape(parallel_shape);
        // in opt shard, accu_grad's shape is different from the original param's shape
        // if grad_accumulation_shard is enabled, the accu_grads will have an opt-sharded shape
        if (!grad_accumulation_shard && ParallelContext::GetInstance()->enable_parallel_optimizer()) {
          TensorLayout new_layout = *tensor_layout;
          new_layout.set_opt_shard_group("");
          tensor_layout = std::make_shared<TensorLayout>(new_layout);
        }
      } else {
        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
      }
      cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
      cloned_parameter_node->set_abstract(cloned_abstract);
      // copy the fusion tag
      auto cloned_param_info = cloned_parameter->param_info();
      MS_EXCEPTION_IF_NULL(cloned_param_info);
      auto cloned_from_param_info = cloned_from_parameter->param_info();
      MS_EXCEPTION_IF_NULL(cloned_from_param_info);
      cloned_param_info->set_comm_fusion(cloned_from_param_info->comm_fusion());

      MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                   << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
                   << ", clone index is: " << cloned_index;
    } else {
      MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is "
                        << cloned_index << ", but failed to find the be-cloned parameter";
    }
  }
}

// For the adafactor optimizer, the relationship between a parameter's shape and its states' shapes is as follows:
// 1) parameter: [A, B, C, D] (shape_size > 2), exp_avg_sq_row: [A, B, C], exp_avg_sq_col: [A, B, D], exp_avg_sq: [1]
//    If the parameter is opt shard, the exp_avg_sq_row and exp_avg_sq_col need to be sharded accordingly.
// 2) parameter: [A, B] (shape_size = 2), exp_avg_sq_row: [A], exp_avg_sq_col: [B], exp_avg_sq: [1]
//    If the parameter is opt shard, the exp_avg_sq_row needs to be sharded accordingly.
// 3) parameter: [A] (shape_size = 1), exp_avg_sq_row: [1], exp_avg_sq_col: [1], exp_avg_sq: [A]
//    If the parameter is opt shard, the exp_avg_sq needs to be sharded accordingly.
static bool AdafactorStateIsOptShard(const std::string &opt_shard_group, size_t shape_size,
                                     const std::string &param_name, const std::string &state_name) {
  if (opt_shard_group.empty()) {
    return false;
  }

  std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
  std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
  std::string exp_avg_name = EXP_AVG_SQ + param_name;
  std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + param_name;
  std::string exp_insta_col_name = EXP_AVG_INSTA_COL + param_name;

  if (shape_size > 2 && state_name == exp_avg_name) {
    return false;
  }

  if (shape_size == 2 &&
      (state_name == exp_col_name || state_name == exp_avg_name || state_name == exp_insta_col_name)) {
    return false;
  }

  if (shape_size == 1 &&
      (state_name == exp_row_name || state_name == exp_col_name || state_name == exp_insta_row_name)) {
    return false;
  }

  MS_LOG(INFO) << "The parameter " << param_name << " is opt-sharded";
  return true;
}

static bool IsOriginWeight(const ParameterPtr &param) {
  std::string param_name = param->name();
  if (param_name.find(EXP_AVG) != std::string::npos) {
    return false;
  }

  auto tensor_layout = param->user_data<TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  return true;
}

static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                            const std::string &name = ALL_REDUCE) {
  if (IsValueNode<RefKey>(node)) {
    std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
    if (param_v.size() != 1) {
      MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, the returned vector size must be 1, but it is "
                        << param_v.size();
    }
    auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
        name == ALL_REDUCE) {
      return std::make_pair(nullptr, true);
    }
    return std::make_pair(node, true);
  }
  return std::make_pair(nullptr, false);
}

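// A sub-graph parameter without a default value is only a formal argument. Walk the call sites
// of the sub-graph, follow the corresponding call input while skipping pass-through nodes
// (Load/Depend/Cast/MicroStepAllGather and parallel-optimizer AllGather), and recurse until the
// actual top-level parameter that owns the data is reached.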
AnfNodePtr RefParameterToActualParameter(const AnfNodePtr &node) {
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  auto node_param_ptr = node->cast<ParameterPtr>();
  if (node_param_ptr->has_default()) {
    return node;
  }
  auto sub_func_graph = node_param_ptr->func_graph();
  auto call_cnodes_map = sub_func_graph->func_graph_cnodes_index();
  auto sub_graph_parameters = sub_func_graph->parameters();
  auto curr_param_iter = std::find(sub_graph_parameters.begin(), sub_graph_parameters.end(), node);
  if (curr_param_iter == sub_graph_parameters.end()) {
    MS_LOG(EXCEPTION) << "Cannot find param " << node_param_ptr->DebugString() << " in current sub_graph";
  }
  size_t curr_param_index = static_cast<size_t>(curr_param_iter - sub_graph_parameters.begin());
  for (const auto &node_pair : call_cnodes_map) {
    if (!node_pair.first->first->isa<CNode>() || node_pair.first->second > 0) {
      continue;
    }
    auto cnode = node_pair.first->first->cast<CNodePtr>();
    auto cnode_input = cnode->input(curr_param_index + 1);
    auto new_cnode = GetInputNodeWithFilter(cnode_input, [&](const CNodePtr &cnode) {
      bool filter = IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather) ||
                    IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimDepend) ||
                    IsPrimitiveCNode(cnode, prim::kPrimCast) ||
                    (IsPrimitiveCNode(cnode, prim::kPrimAllGather) &&
                     GetCNodePrimitive(cnode)->instance_name().find(PARALLEL_OPTIMIZER) != std::string::npos);
      return std::make_pair(filter, 1);
    });
    return RefParameterToActualParameter(new_cnode);
  }
  return nullptr;
}

static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node,
                                                            const std::string &name = ALL_REDUCE) {
  if (!node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "The node is not a parameter, node:" << node->DebugString();
  }
  auto node_param_ptr = node->cast<ParameterPtr>();
  if (node_param_ptr->has_default()) {
    auto param_ptr = node->user_data<parallel::TensorLayout>();
    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty() &&
        name == ALL_REDUCE) {
      return std::make_pair(nullptr, false);
    }
    return std::make_pair(node, false);
  }
  AnfNodePtr ref_param = RefParameterToActualParameter(node);
  if (!ref_param) {
    return std::make_pair(nullptr, false);
  }
  auto ref_param_layout = ref_param->user_data<parallel::TensorLayout>();
  if (ref_param_layout && !ref_param_layout->opt_shard_group().empty() &&
      ref_param_layout->opt_shard_mirror_group().empty() && name == ALL_REDUCE) {
    return std::make_pair(nullptr, false);
  }
  return std::make_pair(ref_param, false);
}

static std::pair<AnfNodePtr, bool> FindParameterByFuncGraph(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));

  auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
  if (pre_node) {
    return FindParameter(pre_node, pre_node->func_graph());
  }
  return std::make_pair(nullptr, false);
}

// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph);
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    return FindParameterByFuncGraph(node);
  }
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    for (size_t index = 0; index < cnode->size(); ++index) {
      auto res = FindParameter(cnode->input(index), func_graph);
      if (!res.first) {
        continue;
      }
      return res;
    }
  }

  // When not fully use opt shard, allgather and mirror would be both inserted.
  // Skip allgather here and find parameter recursively.
  if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
    return std::make_pair(nullptr, false);
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  for (size_t index = 0; index < cnode->size(); ++index) {
    PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(prim);
    if ((prim->name() == DEPEND || prim->name() == LOAD || IsInAllGatherNodeList(cnode)) && index != 1) {
      continue;
    }
    auto res = FindParameter(cnode->input(index), func_graph);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}

// Used for allgather and reducescatter
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                       const std::string &name) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
    return std::make_pair(nullptr, false);
  }

  if (node->isa<Parameter>()) {
    return FindParameterByParameter(node, name);
  }

  if (node->isa<ValueNode>()) {
    return FindParameterByValueNode(node, func_graph, name);
  }

  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t index = 0; index < cnode->size(); ++index) {
    if (index != 1) {
      continue;
    }
    auto res = FindParameterWithAllgather(cnode->input(index), func_graph, name);
    if (!res.first) {
      continue;
    }
    return res;
  }
  return std::make_pair(nullptr, false);
}

std::unordered_map<std::string, std::shared_ptr<TensorLayout>> AdaSumParamTensorLayout(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> adasum_param_map;
  for (auto &parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(parameter_node);
    auto cloned_parameter = parameter_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(cloned_parameter);

    if (!ParameterIsCloned(parameter_node)) {
      auto parameter_tensor_layout = cloned_parameter->user_data<TensorLayout>();
      adasum_param_map["adasum_delta_weight." + cloned_parameter->name()] = parameter_tensor_layout;
    }
  }
  return adasum_param_map;
}

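// Scales a begin/end value sequence of a StridedSlice by the per-dimension strategy. For
// example, with strategy (2, 1) and expand_ratio 1, an end tuple (8, 4) becomes (4, 4); the
// first dimension is additionally multiplied by expand_ratio for the adasum slice expansion.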
Shape ValueSequeueScaleToShape(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  if (!value_seq->isa<ValueSequeue>()) {
    MS_LOG(EXCEPTION) << "The input is not a value sequence";
  }
  std::vector<int64_t> origin_value_vector;
  if (TransValueSequeueToVector(value_seq, &origin_value_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Transform value_seq to vector failed";
  }
  if (origin_value_vector.size() > scale.size()) {
    MS_LOG(EXCEPTION) << "Cannot scale, the size of value_seq is: " << origin_value_vector.size()
                      << ", which should be no larger than the size of scale, which is: " << scale.size();
  }
  for (size_t i = 0; i < origin_value_vector.size(); ++i) {
    origin_value_vector[i] = origin_value_vector[i] / scale[i];
    if (i == 0) {
      origin_value_vector[i] = origin_value_vector[i] * SizeToLong(expand_ratio);
    }
  }
  return origin_value_vector;
}

ValuePtr ValueSequeueScale(const ValuePtr &value_seq, const Shape &scale, size_t expand_ratio = 1) {
  Shape origin_value_vector = ValueSequeueScaleToShape(value_seq, scale, expand_ratio);
  if (value_seq->isa<ValueTuple>()) {
    return TransVectorToValueSequeue<ValueTuple>(origin_value_vector);
  }
  return TransVectorToValueSequeue<ValueList>(origin_value_vector);
}

void ReplaceAdaSumStridedSliceValue(const CNodePtr &stridedslice_cnode1,
                                    const std::shared_ptr<TensorLayout> &target_param_layout,
                                    size_t slice_expand_ratio) {
  auto target_param_info = std::make_shared<TensorInfo>(target_param_layout->SqueezeShape());
  Dimensions param_strategy = target_param_info->InferStrategy();
  auto new_begin1_value =
    ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(2)), param_strategy, slice_expand_ratio);
  auto new_end1_value =
    ValueSequeueScale(GetValueNode(stridedslice_cnode1->input(3)), param_strategy, slice_expand_ratio);
  ValueNodePtr new_begin_value_node = std::make_shared<ValueNode>(new_begin1_value);
  ValueNodePtr new_end_value_node = std::make_shared<ValueNode>(new_end1_value);
  stridedslice_cnode1->set_input(2, new_begin_value_node);
  stridedslice_cnode1->set_input(3, new_end_value_node);
}

RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_layout) {
  int64_t rank = g_device_manager->global_rank();
  auto dev_shape = target_param_layout->device_arrangement().array();
  auto stage_device_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, stage_device_list, dev_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(target_param_layout->tensor_map().array(), &group_devices) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Get adasum parameter origin mirror group by tensor layout failed.";
  }
  return group_devices;
}

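// Returns four border flags for a Send/Receive node: {is_origin_first_node_if_forward,
// is_new_first_node_if_forward, is_origin_last_node_if_rollback, is_new_last_node_if_rollback},
// decided by the rank distance of the communication pair. As a side effect it also rewrites the
// node's group, fusion id and dest/src rank to fit the adasum topology.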
std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) {
  bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
  PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
  int64_t rank = g_device_manager->global_rank();
  if (group_devices.size() == 1) {
    MS_LOG(EXCEPTION) << "Division by zero: the group only contains one device.";
  }
  int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / SizeToLong(group_devices.size() - 1);
  if (adasum_rank_distance < ADASUM_MIN_DIS) {
    adasum_rank_distance = ADASUM_MIN_DIS;
  }
  size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
  int64_t fusion_id = GetValue<int64_t>(send_rec_prim->GetAttr("origin_fusion"));
  // when cutting nodes, the fusion id should change.
  int64_t new_fusion_id = fusion_id + SizeToLong(g_device_manager->DeviceNum() * (border_step + IntToSize(1)));
  send_rec_prim->set_attr(FUSION, MakeValue(new_fusion_id));
  std::vector<int64_t> group_list;
  int64_t new_dest_src_rank;
  if (rank > origin_dest_rank) {
    group_list = {origin_dest_rank, rank};
    new_dest_src_rank = 0;
  } else {
    group_list = {rank, origin_dest_rank};
    new_dest_src_rank = 1;
  }
  Group adasum_send_rec_group;
  if (g_device_manager->CreateGroup(group_list, &adasum_send_rec_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create send/receive group in adasum failed, the group is:" << group_list;
  }
  send_rec_prim->set_attr(GROUP, MakeValue(adasum_send_rec_group.name()));
  if (is_send) {
    send_rec_prim->set_attr(DEST_RANK, MakeValue(new_dest_src_rank));
  } else {
    send_rec_prim->set_attr(SRC_RANK, MakeValue(new_dest_src_rank));
  }
  int64_t rank_dis = abs(origin_dest_rank - rank);
  if (adasum_rank_distance == ADASUM_MIN_DIS) {
    return {false, false, false, false};
  }
  bool is_origin_first_node_if_forward = false;
  bool is_new_first_node_if_forward = false;
  bool is_origin_last_node_if_rollback = false;
  bool is_new_last_node_if_rollback = false;
  if (rank_dis == ADASUM_MIN_DIS) {
    is_origin_first_node_if_forward = true;
    is_origin_last_node_if_rollback = true;
  }
  if (rank_dis == adasum_rank_distance) {
    is_new_first_node_if_forward = true;
  }
  if (rank_dis == adasum_rank_distance / 2) {
    is_new_last_node_if_rollback = true;
  }
  return {is_origin_first_node_if_forward, is_new_first_node_if_forward, is_origin_last_node_if_rollback,
          is_new_last_node_if_rollback};
}

void HandleAdaSumReshape(const CNodePtr &reshape_cnode, const std::shared_ptr<TensorLayout> &target_param_layout) {
  auto slice_shape = target_param_layout->slice_shape().array();
  auto slice_shape_value = TransVectorToValueSequeue<ValueTuple>(slice_shape);
  ValueNodePtr new_slice_shape_value_node = std::make_shared<ValueNode>(slice_shape_value);
  reshape_cnode->set_input(2, new_slice_shape_value_node);
}

void RemoveAdasumRedundantNodes(const FuncGraphManagerPtr &manager,
                                std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
                                std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map,
                                std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map,
                                std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map) {
  // connect forward last node and rollback first node
  if (forward_origin_first_node_map->size() != forward_new_first_node_map->size() ||
      rollback_origin_last_node_map->size() != rollback_new_last_node_map->size()) {
    MS_LOG(EXCEPTION) << "The numbers of border nodes in the adasum forward and rollback processes do not match.";
  }
  for (auto node : *forward_origin_first_node_map) {
    std::string target_param = node.first;
    CNodePtr forward_origin_first_node = node.second;
    CNodePtr forward_new_first_node = (*forward_new_first_node_map)[target_param];
    manager->SetEdge(forward_new_first_node, 1, forward_origin_first_node->input(1));
  }
  for (auto node : *rollback_origin_last_node_map) {
    std::string target_param = node.first;
    CNodePtr rollback_origin_last_node = node.second;
    CNodePtr rollback_new_last_node = (*rollback_new_last_node_map)[target_param];
    (void)manager->Replace(rollback_origin_last_node, rollback_new_last_node);
  }
}

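// For step >= border_step, rebuild the allreduce group: node indices (one node spans
// ADASUM_MIN_DIS consecutive ranks) are grouped in blocks of (2 << step) while each rank keeps
// its offset inside its node, which yields the doubling communication pattern of adasum; the
// fusion id is shifted into a fresh range so regrouped allreduces are fused separately.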
void HandleAdasumAllReduce(const PrimitivePtr &prim, const RankList &group_devices) {
  size_t step = size_t(GetValue<int64_t>(prim->GetAttr("step")));
  std::vector<int64_t> neighbor_ids;
  int64_t adasum_rank_distance =
    (group_devices.back() - group_devices.front()) / SizeToLong((group_devices.size() - 1));
  if (adasum_rank_distance < ADASUM_MIN_DIS) {
    adasum_rank_distance = ADASUM_MIN_DIS;
  }
  size_t border_step = size_t(log2(adasum_rank_distance / ADASUM_MIN_DIS));
  MS_LOG(INFO) << "current border step is: " << border_step;
  if (step < border_step) {
    return;
  }
  int64_t rank = g_device_manager->global_rank();
  size_t double_d = size_t(IntToSize(2) << step);
  for (size_t index = 0; index < double_d; ++index) {
    int64_t node_rank = rank / ADASUM_MIN_DIS;
    int64_t neighbor_id =
      (node_rank / SizeToLong(double_d) * SizeToLong(double_d) + SizeToLong(index)) * ADASUM_MIN_DIS +
      rank % ADASUM_MIN_DIS;
    neighbor_ids.push_back(neighbor_id);
  }
  Group adasum_allreduce_group;
  if (g_device_manager->CreateGroup(neighbor_ids, &adasum_allreduce_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create allreduce group in adasum failed, the group is " << neighbor_ids;
  }
  auto new_group_name = MakeValue(adasum_allreduce_group.name());
  int64_t fusion_id = GetValue<int64_t>(prim->GetAttr("origin_fusion"));
  int64_t new_fusion_id = fusion_id + SizeToLong(g_device_manager->DeviceNum() * (border_step + IntToSize(1)));
  prim->set_attr(GROUP, new_group_name);
  prim->set_attr(FUSION, MakeValue(new_fusion_id));
}

void HandleAdasumSlice(const AnfNodePtr &stridedslice_node1, const std::shared_ptr<TensorLayout> &target_param_layout,
                       size_t slice_expand_ratio) {
  auto stridedslice_cnode1 = stridedslice_node1->cast<CNodePtr>();
  ReplaceAdaSumStridedSliceValue(stridedslice_cnode1, target_param_layout, slice_expand_ratio);
  auto squeeze_node = RealInputNode(stridedslice_cnode1, 1);
  if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
    MS_LOG(EXCEPTION) << "The StridedSlice input node should be a Squeeze in adasum";
  }
  auto squeeze_cnode = squeeze_node->cast<CNodePtr>();
  FuncGraphManagerPtr manager = squeeze_node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[squeeze_cnode];
  for (auto &node_pair : node_set) {
    if (IsPrimitiveCNode(node_pair.first, prim::kPrimStridedSlice) && node_pair.first != stridedslice_node1) {
      CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
      ReplaceAdaSumStridedSliceValue(use_apply, target_param_layout, slice_expand_ratio);
    }
  }
}

void HandleAdaSumConcat(const AnfNodePtr &concat_node, const std::vector<bool> &border_info,
                        const std::string &target_param,
                        std::unordered_map<std::string, CNodePtr> *rollback_new_last_node_map,
                        std::unordered_map<std::string, CNodePtr> *rollback_origin_last_node_map) {
  if (border_info[3]) {
    (*rollback_new_last_node_map)[target_param] = concat_node->cast<CNodePtr>();
  }
  if (border_info[2]) {
    auto manager = concat_node->func_graph()->manager();
    AnfNodeIndexSet concat_node_user_set = manager->node_users()[concat_node];
    for (auto &node_pair : concat_node_user_set) {
      if (IsPrimitiveCNode(node_pair.first, prim::kPrimMakeTuple)) {
        AnfNodeIndexSet make_tuple_node_user_set = manager->node_users()[node_pair.first];
        for (auto &tuple_user : make_tuple_node_user_set) {
          if (IsPrimitiveCNode(tuple_user.first, prim::kPrimConcat)) {
            (*rollback_origin_last_node_map)[target_param] = tuple_user.first->cast<CNodePtr>();
            return;
          }
        }
        return;
      }
    }
  }
}

void HandleAdaSumSqueeze(const AnfNodePtr &stridedslice_node1, const std::vector<bool> &border_info,
                         const std::string &target_param,
                         std::unordered_map<std::string, CNodePtr> *forward_origin_first_node_map,
                         std::unordered_map<std::string, CNodePtr> *forward_new_first_node_map) {
  auto squeeze_node = RealInputNode(stridedslice_node1->cast<CNodePtr>(), 1);
  if (border_info[0]) {
    (*forward_origin_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
  }
  if (border_info[1]) {
    (*forward_new_first_node_map)[target_param] = squeeze_node->cast<CNodePtr>();
  }
}

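// When a parameter's mirror group contains a single device, adasum degenerates to pure model
// parallelism and the Send/Receive pair is redundant: each consumer of the squeeze input (other
// than Squeeze/UpdateState/MakeTuple) is replaced by the squeeze input itself, bypassing the
// redundant slice/communication chain.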
void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return;
  }
  PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(send_rec_prim);
  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
  int64_t rank = g_device_manager->global_rank();
  CNodePtr cnode = node->cast<CNodePtr>();
  auto pre_cnode = RealInputNode(cnode, 1);
  int64_t rank_dis = abs(origin_dest_rank - rank);
  if (rank_dis == ADASUM_MIN_DIS && IsPrimitiveCNode(pre_cnode, prim::kPrimStridedSlice)) {
    auto squeeze_node = pre_cnode->cast<CNodePtr>()->input(1);
    if (!IsPrimitiveCNode(squeeze_node, prim::kPrimSqueeze)) {
      return;
    }
    auto squeeze_input = squeeze_node->cast<CNodePtr>()->input(1);
    auto manager = squeeze_node->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(manager);
    AnfNodeIndexSet squeeze_input_node_user_set = manager->node_users()[squeeze_input];
    for (auto &squeeze_input_user : squeeze_input_node_user_set) {
      if (IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimSqueeze) ||
          IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState) ||
          IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimMakeTuple)) {
        continue;
      }
      (void)manager->Replace(squeeze_input_user.first, squeeze_input);
    }
  }
}

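// Walk every AllReduce/Reshape/Send/Receive tagged with a target_param attribute and patch it for
// AdaSum: Reshape shapes, AllReduce groups, Receive shapes, and the Concat/StridedSlice/Squeeze
// chains on the send side are rescaled to the parameter's tensor layout. As a worked example of
// the distance arithmetic below: for group_devices {0, 4, 8, 12}, adasum_rank_distance is
// (12 - 0) / 3 = 4, and with an assumed ADASUM_MIN_DIS of 8 the slice_expand_ratio falls back to 1
// because 4 / 8 == 0. Returns true if any node was rewritten, i.e. AdaSum is actually in effect.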
bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
  std::unordered_map<std::string, CNodePtr> forward_origin_first_node_map;
  std::unordered_map<std::string, CNodePtr> forward_new_first_node_map;
  std::unordered_map<std::string, CNodePtr> rollback_origin_last_node_map;
  std::unordered_map<std::string, CNodePtr> rollback_new_last_node_map;
  bool is_adasum = false;
  for (auto &node : all_nodes) {
    bool is_allreduce = IsPrimitiveCNode(node, prim::kPrimAllReduce);
    bool is_reshape = IsPrimitiveCNode(node, prim::kPrimReshape);
    bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
    bool is_receive = IsPrimitiveCNode(node, prim::kPrimReceive);
    if (!is_allreduce && !is_reshape && !is_send && !is_receive) {
      continue;
    }
    std::string target_param;
    CNodePtr cnode = node->cast<CNodePtr>();
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
    MS_EXCEPTION_IF_NULL(prim);
    if (!prim->HasAttr(TARGET_PARAM)) {
      continue;
    }
    target_param = GetValue<std::string>(prim->GetAttr(TARGET_PARAM));
    auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];
    RankList group_devices = GetRankListByLayout(target_param_layout);
    // The parameter is purely model parallel.
    if (group_devices.size() == 1) {
      HandleAdaSumPureModelParallel(node);
      continue;
    }

    int64_t adasum_rank_distance =
      (group_devices.back() - group_devices.front()) / SizeToLong(group_devices.size() - 1);
    // When the repeated-calculation dim is the rightmost one, the parameter does not enable AdaSum.
    if (adasum_rank_distance == 1 && group_devices.size() < size_t(g_device_manager->stage_device_num())) {
      continue;
    }
    MS_LOG(INFO) << "Apply AdaSum in auto parallel, current node is: " << node->fullname_with_scope();
    is_adasum = true;
    size_t slice_expand_ratio =
      LongToSize(adasum_rank_distance / ADASUM_MIN_DIS) > 0 ? LongToSize(adasum_rank_distance / ADASUM_MIN_DIS) : 1;
    if (is_reshape) {
      HandleAdaSumReshape(cnode, (*adasum_param_tensor_layout_map)[target_param]);
    }
    if (is_allreduce && prim->HasAttr("step")) {
      HandleAdasumAllReduce(prim, group_devices);
    }
    if (is_send || is_receive) {
      std::vector<bool> border_info = IsBorderAdaSumSendReceive(node, group_devices);
      if (is_receive) {
        auto target_param_info = std::make_shared<TensorInfo>(*target_param_layout);
        Dimensions param_strategy = target_param_info->InferStrategy();
        Shape new_rec_shape = ValueSequeueScaleToShape(prim->GetAttr(SHAPE), param_strategy, slice_expand_ratio);
        auto new_rec_shape_value = TransVectorToValueSequeue<ValueList>(new_rec_shape);
        prim->set_attr(SHAPE, new_rec_shape_value);
        continue;
      }
      auto stridedslice_node1 = RealInputNode(cnode, 1);
      if (IsPrimitiveCNode(stridedslice_node1, prim::kPrimConcat)) {
        HandleAdaSumConcat(stridedslice_node1, border_info, target_param, &rollback_new_last_node_map,
                           &rollback_origin_last_node_map);
        continue;
      }
      if (!IsPrimitiveCNode(stridedslice_node1, prim::kPrimStridedSlice)) {
        continue;
      }
      HandleAdasumSlice(stridedslice_node1, target_param_layout, slice_expand_ratio);
      HandleAdaSumSqueeze(stridedslice_node1, border_info, target_param, &forward_origin_first_node_map,
                          &forward_new_first_node_map);
    }
  }
  RemoveAdasumRedundantNodes(root->manager(), &forward_origin_first_node_map, &forward_new_first_node_map,
                             &rollback_origin_last_node_map, &rollback_new_last_node_map);
  return is_adasum;
}

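// Point a Mirror operator at a freshly created device group. A single-rank group degenerates to
// the literal "one_rank_group" name with dev_num 1, in which case no group needs to be created.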
void ResetMirrorAttr(const PrimitivePtr &prim, const RankList &new_group) {
  if (new_group.size() == 1) {
    prim->set_attr(DEV_NUM, MakeValue<int64_t>(SizeToLong(new_group.size())));
    prim->set_attr(GROUP, MakeValue("one_rank_group"));
    prim->set_attr(GROUP_RANKS, MakeValue(std::to_string(new_group[0])));
    return;
  }
  Group adasum_mirror_group;
  if (g_device_manager->CreateGroup(new_group, &adasum_mirror_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create new mirror group failed in AdaSum, new group is: " << new_group;
  }
  auto new_group_name = MakeValue(adasum_mirror_group.name());
  prim->set_attr(GROUP, new_group_name);
  prim->set_attr(DEV_NUM, MakeValue<int64_t>(SizeToLong(new_group.size())));
  std::string rank_list_name = g_device_manager->FindRankListNameByHashName(adasum_mirror_group.name());
  prim->set_attr(GROUP_RANKS, MakeValue(rank_list_name));
}

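// Rewrite every Mirror over a parameter that participates in AdaSum ("adasum_delta_weight." +
// param name). Devices whose original mirror group is denser than ADASUM_MIN_DIS are re-bucketed
// into consecutive sub-groups so the mirror only spans ranks that still hold the same slice. As an
// illustration, assuming ADASUM_MIN_DIS is 8 and group_devices is {0, 1, ..., 15}: group_dis is 1,
// new_group_size is 8, and rank 9 lands in the sub-group {8, ..., 15}. If the original group is
// already at least ADASUM_MIN_DIS apart, the mirror collapses to the local rank alone.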
void HandleMirrorInAdaSum(
  const FuncGraphPtr &root,
  std::unordered_map<std::string, std::shared_ptr<TensorLayout>> *adasum_param_tensor_layout_map) {
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root->get_return());
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimMirror)) {
      continue;
    }
    CNodePtr mirror_cnode = node->cast<CNodePtr>();
    auto param_node_pair = FindParameter(mirror_cnode->input(1), node->func_graph());
    if (!param_node_pair.first) {
      MS_LOG(EXCEPTION) << "The input of the mirror is not a parameter";
    }
    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name = param_ptr->name();
    MS_LOG(INFO) << "Mirror param name is: " << param_name;
    std::string target_param = "adasum_delta_weight." + param_name;
    auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];

    // Change the mirror group.
    RankList group_devices = GetRankListByLayout(target_param_layout);
    int64_t rank = g_device_manager->global_rank();
    size_t group_dis = LongToSize(group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
    auto prim = GetCNodePrimitive(node);
    if (group_dis < ADASUM_MIN_DIS && group_dis > 0) {
      size_t new_group_size = size_t(ADASUM_MIN_DIS) / group_dis;
      // Compute the new group range.
      size_t group_begin = 0;
      for (size_t group_end = new_group_size; group_end < group_devices.size() + new_group_size;
           group_end += new_group_size) {
        int64_t max_group_value =
          group_end >= group_devices.size() ? (group_devices.back() + 1) : group_devices[group_end];
        if (group_devices[group_begin] <= rank && rank < max_group_value) {
          // Clamp so the iterator never runs past the end when the group size is not a multiple of new_group_size.
          size_t real_group_end = std::min(group_end, group_devices.size());
          std::vector<int64_t> new_group(group_devices.begin() + SizeToLong(group_begin),
                                         group_devices.begin() + SizeToLong(real_group_end));
          MS_LOG(INFO) << "Find new mirror group in AdaSum: " << new_group << " target_param: " << target_param;
          ResetMirrorAttr(prim, new_group);
          break;
        }
        group_begin = group_end;
      }
      continue;
    }
    ResetMirrorAttr(prim, {rank});
  }
}

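// Mark an optimizer-state parameter so that its re-derived layout is saved into the strategy
// checkpoint; a null parameter or missing param_info is silently ignored.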
void SetParamInfoSaveStrategy(ParameterPtr row_col_param) {
  if (!row_col_param) {
    return;
  }
  auto param_info = row_col_param->param_info();
  if (param_info) {
    param_info->set_strategy_ckpt_saved(true);
  }
}

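// AdaFactor and CAME keep factored second-moment statistics per weight: exp_avg_sq_row /
// exp_avg_sq_col (and, for CAME, the exp_avg_insta_row / exp_avg_insta_col variants) drop one
// dimension of the weight, while exp_avg_sq keeps the full shape. Those states are created before
// parallel splitting, so their layouts are re-derived here from the owning weight's layout: "row"
// states drop the last dimension of shape and tensor map, "col" states drop the second-to-last,
// and the optimizer-shard slice shape is propagated when the weight is optimizer sharded. The
// 4/1 counters below stop the inner scan early once the four row/col states and the one
// exp_avg_sq state of a weight have been handled.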
void HandleCameAndAdaFactorOpt(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                               const FuncGraphManagerPtr &manager) {
  MS_LOG(INFO) << "Adafactor or Came optimizer process start";
  MS_EXCEPTION_IF_NULL(root);
  std::set<AnfNodePtr> origin_params;
  for (auto &param_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(param_node);
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);

    if (!IsOriginWeight(param)) {
      continue;
    }

    int64_t row_col_count = 0;
    int64_t exp_avg_sq_count = 0;
    for (auto &row_col_node : root->parameters()) {
      bool is_all_param_collected = (row_col_count == 4) && (exp_avg_sq_count == 1);
      if (is_all_param_collected) {
        break;
      }

      MS_EXCEPTION_IF_NULL(row_col_node);
      auto row_col_param = row_col_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(row_col_param);
      std::string row_col_param_name = row_col_param->name();
      std::string param_name = param->name();
      std::string exp_row_name = EXP_AVG_SQ_ROW + param_name;
      std::string exp_col_name = EXP_AVG_SQ_COL + param_name;
      std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + param_name;
      std::string exp_insta_col_name = EXP_AVG_INSTA_COL + param_name;
      std::string exp_avg_name = EXP_AVG_SQ + param_name;
      std::set<std::string> came_param_set = {exp_row_name, exp_col_name, exp_insta_row_name, exp_insta_col_name,
                                              exp_avg_name};

      if (came_param_set.find(row_col_param_name) == came_param_set.end()) {
        continue;
      }
      origin_params.insert(param_node);
      auto tensor_layout = param->user_data<TensorLayout>();
      MS_EXCEPTION_IF_NULL(tensor_layout);
      auto slice_shape = tensor_layout->slice_shape().array();
      Shape opt_shard_slice_shape = slice_shape;
      if (!tensor_layout->opt_shard_group().empty()) {
        opt_shard_slice_shape = tensor_layout->opt_shard_slice_shape();
      }

      auto shape_size = slice_shape.size();
      bool is_row_or_col_param = row_col_param_name != exp_avg_name;
      if (is_row_or_col_param && shape_size <= 1) {
        row_col_count++;
        continue;
      }

      if (row_col_param_name == exp_avg_name && shape_size != 1) {
        exp_avg_sq_count++;
        continue;
      }

      auto origin_shape = tensor_layout->tensor_shape().array();
      auto dev_mat = tensor_layout->device_arrangement().array();
      auto tensor_map = tensor_layout->tensor_map().array();

      if (row_col_param_name == exp_row_name || row_col_param_name == exp_insta_row_name) {
        opt_shard_slice_shape.pop_back();
        origin_shape.pop_back();
        tensor_map.pop_back();
        row_col_count++;
      } else if (row_col_param_name == exp_col_name || row_col_param_name == exp_insta_col_name) {
        (void)opt_shard_slice_shape.erase(opt_shard_slice_shape.cbegin() +
                                          static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)origin_shape.erase(origin_shape.cbegin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        (void)tensor_map.erase(tensor_map.cbegin() + static_cast<different_type>(SECOND_FROM_END(shape_size)));
        row_col_count++;
      } else {
        exp_avg_sq_count++;
      }

      TensorLayout new_tensor_layout;
      if (new_tensor_layout.InitFromVector(dev_mat, tensor_map, origin_shape) != SUCCESS) {
        MS_LOG(EXCEPTION) << "Init tensor layout failed";
      }

      if (AdafactorStateIsOptShard(tensor_layout->opt_shard_group(), shape_size, param_name, row_col_param_name)) {
        new_tensor_layout.set_opt_shard_group(tensor_layout->opt_shard_group());
        new_tensor_layout.set_opt_shard_slice_shape(opt_shard_slice_shape);
      }
      SetParamInfoSaveStrategy(row_col_param);
      auto cloned_abstract = row_col_node->abstract()->Clone();
      MS_EXCEPTION_IF_NULL(cloned_abstract);
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(opt_shard_slice_shape);
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      row_col_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(new_tensor_layout));
      row_col_node->set_abstract(cloned_abstract);
    }
  }

  for (const auto &origin_param_node : origin_params) {
    auto inserter =
      CameCommHandler(origin_param_node->cast<ParameterPtr>(), root->parameters(), manager->node_users());
    inserter.Process();
  }
}

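// Derive a tensor layout for a parameter that feeds a Reshape carrying an in_strategy: the Reshape
// input strategy becomes the low dimensions of the device matrix, and the leftover devices become
// the highest dimension. A worked example: with dev_num = 16 and input_stra = {2, 4}, the device
// matrix is {16 / (2 * 4), 2, 4} = {2, 2, 4}, and a 2-D parameter gets the identity tensor map
// {1, 0}.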
static std::shared_ptr<TensorLayout> GenerateTensorLayoutForParamReshapeWithStra(const AnfNodePtr &node,
                                                                                 const Dimensions &input_stra) {
  CheckGlobalDeviceManager();
  int64_t dev_num = g_device_manager->stage_device_num();
  MS_EXCEPTION_IF_ZERO("dev_num", dev_num);

  Shapes inputs_shape = GetNodeShape(node);
  Shape param_shape = inputs_shape[0];

  Shape param_dev_matrix_shape(input_stra.size() + 1, 0);
  for (size_t i = param_dev_matrix_shape.size() - 1; i > 0; i--) {
    param_dev_matrix_shape[i] = input_stra[i - 1];
  }
  param_dev_matrix_shape[0] =
    dev_num / std::accumulate(input_stra.begin(), input_stra.end(), int64_t(1), std::multiplies<int64_t>());

  TensorMap param_tensor_map;
  for (size_t i = 0; i < param_shape.size(); ++i) {
    param_tensor_map.push_back(static_cast<int64_t>(param_shape.size() - i - 1));
  }

  TensorLayout param_layout;

  if (param_layout.InitFromVector(param_dev_matrix_shape, param_tensor_map, param_shape) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Failed to infer the tensor layout of a parameter feeding a Reshape with strategy.";
  }

  return std::make_shared<TensorLayout>(param_layout);
}

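// Map the index-th argument of a subgraph call back to the corresponding formal parameter of the
// called FuncGraph; returns nullptr when the callee is not a FuncGraph value node. input(0) is the
// callee, so argument index n corresponds to parameters[n - 1].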
static AnfNodePtr FindParameterByCallNode(const CNodePtr &call, int64_t index) {
  MS_EXCEPTION_IF_NULL(call);
  AnfNodePtr graph_value_node = call->input(0);
  if (!IsValueNode<FuncGraph>(graph_value_node)) {
    return nullptr;
  }
  auto graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
  auto parameters = graph_sub->parameters();
  if (LongToSize(index - 1) >= parameters.size()) {
    MS_LOG(EXCEPTION) << "The index is out of range, index is: " << (index - 1) << ", vector size is "
                      << parameters.size();
  }
  return parameters[LongToSize(index - 1)];
}

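// Search forward through a parameter's users (across Load, Depend and subgraph calls, bounded by
// MAX_RECURSIVE_DEPTH) for the first node that determines a layout: either a Reshape with an
// explicit in_strategy, or a parallel-care operator whose OperatorInfo already carries the input
// layout. Returns nullptr when no user pins a layout down.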
static std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: "
                    << MAX_RECURSIVE_DEPTH;
    return nullptr;
  }
  FuncGraphManagerPtr manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  for (auto &node_pair : node_set) {
    if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
      auto layout_param = FindParameterNextLayout(node_pair.first, curr_depth + 1);
      if (!layout_param) {
        continue;
      }
      return layout_param;
    }
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr) {
      continue;
    }
    auto op = use_apply->input(0);
    MS_EXCEPTION_IF_NULL(op);
    if (IsValueNode<FuncGraph>(op)) {
      auto para = FindParameterByCallNode(use_apply, node_pair.second);
      auto layout_param = FindParameterNextLayout(para, curr_depth + 1);
      if (!layout_param) {
        continue;
      }
      return layout_param;
    }
    if (!IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    if (node_prim->name() == RESHAPE) {
      auto attrs_temp = node_prim->attrs();
      if (!StrategyFound(attrs_temp)) {
        continue;
      }
      StrategyPtr strategy = ExtractStrategy(attrs_temp[IN_STRATEGY]);
      Strategies stra = strategy->GetInputDim();
      Dimensions input_strategy = stra.at(0);

      auto param_layout = GenerateTensorLayoutForParamReshapeWithStra(node, input_strategy);

      return param_layout;
    }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      auto layout = GetInputLayoutFromCNode(node_pair, -1);
      return std::make_shared<TensorLayout>(layout);
    }
  }
  return nullptr;
}

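// Build a layout for a parameter that has no OperatorInfo of its own: prefer the layout implied by
// its first layout-determining user; otherwise fall back to a data-parallel layout that shards
// only the first dimension, and only when that dimension divides evenly by the stage device
// number.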
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  // Create a DataParallel tensor layout for the parameter (supports WideDeep).
  auto next_layout = FindParameterNextLayout(node, 0);
  if (next_layout != nullptr) {
    return next_layout;
  }
  CheckGlobalDeviceManager();
  int64_t dev_num = g_device_manager->stage_device_num();
  MS_EXCEPTION_IF_ZERO("dev_num", dev_num);
  TensorLayout input_tensor_layout;
  // Create the input shape.
  Shapes inputs_shape = GetNodeShape(node);
  Shape input_shape_array = inputs_shape[0];

  // Create the device matrix.
  Shape dev_matrix_array = {dev_num};

  // Create the tensor map.
  size_t shape_size = input_shape_array.size();
  TensorMap input_tensor_map_array(shape_size, MAP_NONE);
  if ((shape_size > 0) && (input_shape_array[0] % dev_num == 0)) {
    input_tensor_map_array[0] = 0;  // Shard the parameter's first dimension when parameter->Reshape->Op.
  }

  if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
  }
  return std::make_shared<TensorLayout>(input_tensor_layout);
}

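// The rewrite applied below, sketched for a tagged node n with first data input x (the attribute
// name "insert_rand" and the zeroed seed/seed2 come straight from the code):
//
//   n(x, ...)  ==>  n(UniformReal(Shape(x)), ...)
//
// i.e. the data input is replaced by uniformly distributed random values of the same shape.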
// Temporary method for handling StandardNormal insertion in the opt graph.
void InsertUniformRealForTaggedNodes(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto primitive = GetCNodePrimitive(node);
    if (primitive == nullptr) {
      continue;
    }
    if (common::AnfAlgo::IsCommunicationOp(node)) {
      continue;
    }
    if (primitive->HasAttr("insert_rand")) {
      MS_LOG(INFO) << "Insert UniformReal to node " << node->DebugString();
      std::vector<AnfNodePtr> inputShape = {NewValueNode(prim::kPrimShape), node->cast<CNodePtr>()->input(kIndex1)};
      auto inputShapeNode = node->func_graph()->NewCNode(inputShape);

      std::vector<AnfNodePtr> uniformReal = {NewValueNode(prim::kPrimUniformReal), inputShapeNode->cast<AnfNodePtr>()};
      auto uniformRealNode = node->func_graph()->NewCNode(uniformReal);

      auto uniformRealPrim = GetCNodePrimitive(uniformRealNode);
      MS_EXCEPTION_IF_NULL(uniformRealPrim);
      auto attrs = uniformRealPrim->attrs();
      attrs["seed"] = MakeValue<int64_t>(0);
      attrs["seed2"] = MakeValue<int64_t>(0);
      (void)uniformRealPrim->SetAttrs(attrs);

      manager->SetEdge(node, 1, uniformRealNode);
    }
  }
}
}  // namespace parallel
}  // namespace mindspore