/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "frontend/parallel/came_parallel_handler.h"

#include <deque>
#include <algorithm>

#include "frontend/parallel/parameter_manager.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/utils/convert_utils_base.h"
#include "utils/hash_map.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "pipeline/jit/ps/pipeline.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/core/ops/nn_ops.h"

namespace mindspore {
namespace parallel {
const std::string GetCNodeOpName(const CNodePtr &cnode) {
  // get the prim name of cnode
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(node_prim);
  return node_prim->name();
}

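// Walk backwards from bottom_node through cnode inputs. For every (op name, input index) pair in
// bwd_calls, the current cnode must carry that primitive name, and the search then moves to the cnode
// at input(index + 1). Returns {true, node} if the final node's primitive name equals target_name,
// otherwise {false, bottom_node}.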
std::pair<bool, const CNodePtr> BackwardSearchCNode(const CNodePtr &bottom_node,
                                                    const std::vector<std::pair<std::string, size_t>> &bwd_calls,
                                                    const std::string &target_name) {
  CNodePtr target_node = bottom_node;
  for (const auto &call_param : bwd_calls) {
    const auto node_name = call_param.first;
    const auto idx = call_param.second;
    auto cnode_name = GetCNodeOpName(target_node);
    if (cnode_name != node_name) {
      MS_LOG(DEBUG) << "[CAME] backward search failed, expect node name: " << node_name << " but got " << cnode_name;
      return {false, bottom_node};
    }
    const auto &param_node = target_node->input(idx + 1);
    if (!param_node) {
      MS_LOG(DEBUG) << "[CAME] backward search failed, expect param at index: " << (idx + 1) << " but got null";
      return {false, bottom_node};
    }
    if (!param_node->isa<CNode>()) {
      MS_LOG(DEBUG) << "[CAME] param node is not a cnode!";
      return {false, bottom_node};
    }
    auto param_cnode = param_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(param_cnode);
    target_node = param_cnode;
  }
  auto cnode_name = GetCNodeOpName(target_node);
  if (cnode_name != target_name) {
    MS_LOG(DEBUG) << "[CAME] backward search failed, expect target node name: " << target_name << " but got "
                  << cnode_name;
    return {false, bottom_node};
  }
  return {true, target_node};
}

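// Breadth-first search forwards (towards users) from start_node along the op-name chain in fwd_calls.
// At each depth, only users whose primitive name matches the next name in fwd_calls are expanded.
// Returns {true, candidates} with all cnodes reached at the last depth, or {false, {}} if none is found.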
std::pair<bool, std::vector<CNodePtr>> ForwardSearchCNode(const CNodePtr &start_node,
                                                          const std::vector<std::string> &fwd_calls,
                                                          const NodeUsersMap &node_user_map) {
  if (!start_node) {
    MS_LOG(DEBUG) << "[CAME] forward search start is null!";
    return {false, {}};
  }
  if (fwd_calls.empty()) {
    MS_LOG(DEBUG) << "[CAME] forward search got an empty call list!";
    return {false, {}};
  }
  std::vector<CNodePtr> candidates;
  std::deque<CNodePtr> visited;
  CNodePtr cur_node = nullptr;
  uint32_t depth = 0;

  visited.push_back(start_node);
  CNodePtr last_node = visited.back();
  while (!visited.empty()) {
    if (depth == fwd_calls.size() - 1) {
      std::copy(visited.begin(), visited.end(), std::back_inserter(candidates));
      break;
    }
    cur_node = visited.front();
    MS_LOG(INFO) << "[CAME] fwd current node: " << cur_node->DebugString();
    visited.pop_front();
    auto node_set = node_user_map.at(cur_node->cast<AnfNodePtr>());
    for (const auto &item : node_set) {
      auto user_node = item.first;
      if (!user_node->isa<CNode>()) {
        continue;
      }
      auto user_cnode = user_node->cast<CNodePtr>();
      if (GetCNodeOpName(user_cnode) == fwd_calls[depth + 1]) {
        visited.push_back(user_cnode);
      }
    }
    if (last_node == cur_node) {
      if (visited.empty()) {
        // no user matched the next op name, the search fails
        break;
      }
      last_node = visited.back();
      depth++;
    }
  }

  if (candidates.empty()) {
    return {false, {}};
  } else {
    return {true, candidates};
  }
}

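// Bind the CAME communication handler to one origin parameter: record the current rank, the device
// list of this stage and the parameter's tensor layout, then locate the CAME optimizer state
// parameters (exp_avg, exp_avg_sq_row/col, exp_avg_insta_row/col) that belong to it.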
CameCommHandler::CameCommHandler(ParameterPtr origin, const std::vector<AnfNodePtr> &all_parameters,
                                 const NodeUsersMap &node_user_map)
    : origin(origin), all_parameters(all_parameters), node_user_map(node_user_map) {
  CheckGlobalDeviceManager();
  cur_rank = g_device_manager->global_rank();
  full_rank_list = g_device_manager->GetDeviceListInThisStage();

  tensor_layout = origin->user_data<TensorLayout>();
  MS_EXCEPTION_IF_NULL(tensor_layout);

  auto opt_shard_group_name = tensor_layout->opt_shard_group();
  if (!opt_shard_group_name.empty()) {
    is_opt_shard = true;
  }
  MS_LOG(DEBUG) << "CAME processing parameter";
  MS_LOG(DEBUG) << "tensor shape:" << tensor_layout->tensor_shape().ToString();
  MS_LOG(DEBUG) << "slice shape:" << tensor_layout->slice_shape().ToString();

  MS_LOG(DEBUG) << "opt shard slice shape:";
  for (const auto &item : tensor_layout->opt_shard_slice_shape()) {
    MS_LOG(DEBUG) << item;
  }
  MS_LOG(DEBUG) << "opt shard group:" << tensor_layout->opt_shard_group();
  MS_LOG(DEBUG) << "opt shard step:" << tensor_layout->opt_weight_shard_step();

  MS_LOG(DEBUG) << "device arrangement:" << tensor_layout->device_arrangement().ToString();
  MS_LOG(DEBUG) << "original device arrangement:" << tensor_layout->device_arrangement_origin().ToString();

  MS_LOG(DEBUG) << "tensor map:" << tensor_layout->tensor_map().ToString();
  MS_LOG(DEBUG) << "original tensor map:" << tensor_layout->origin_tensor_map().ToString();

  FindCameParams();
}

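// Match the CAME optimizer state parameters by name: each state parameter is named by prepending the
// corresponding prefix (exp_avg, exp_avg_sq_row/col, exp_avg_insta_row/col) to the origin parameter name.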
void CameCommHandler::FindCameParams() {
  const std::string origin_name = origin->name();
  const std::string exp_row_name = EXP_AVG_SQ_ROW + origin_name;
  const std::string exp_col_name = EXP_AVG_SQ_COL + origin_name;
  const std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + origin_name;
  const std::string exp_insta_col_name = EXP_AVG_INSTA_COL + origin_name;
  const std::string exp_avg_name = std::string(EXP_AVG) + "." + origin_name;
  const size_t param_to_find_size = 5;
  size_t cur_found_param_count = 0;
  for (const auto &param_node : all_parameters) {
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    const std::string param_name = param->name();
    if (param_name == exp_row_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_row: " << param_name;
      exp_avg_sq_row = param;
      cur_found_param_count++;
    } else if (param_name == exp_col_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_col: " << param_name;
      exp_avg_sq_col = param;
      cur_found_param_count++;
    } else if (param_name == exp_insta_row_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_row: " << param_name;
      exp_avg_insta_row = param;
      cur_found_param_count++;
    } else if (param_name == exp_insta_col_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_col: " << param_name;
      exp_avg_insta_col = param;
      cur_found_param_count++;
    } else if (param_name == exp_avg_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg: " << param_name;
      exp_avg = param;
      cur_found_param_count++;
    }

    if (cur_found_param_count == param_to_find_size) {
      break;
    }
  }
  MS_LOG(INFO) << "[CAME] found params corresponding to origin param size: " << cur_found_param_count;
}

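// Compute the optimizer-weight-shard group that the given rank belongs to: first collect all devices
// holding the same model-parallel slice, then split them into consecutive chunks of
// optimizer_weight_shard_size and return the chunk containing this rank.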
std::pair<Status, RankList> CameCommHandler::GetOptShardRankList(const int64_t rank) {
  DeviceMatrix temp_dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array());
  RankList group_devices;
  Shape orig_tensor_map = tensor_layout->tensor_map().array();
  if (temp_dev_matrix.GetDevicesByTensorMap(orig_tensor_map, &group_devices) != SUCCESS) {
    return {FAILED, {}};
  }
  if (group_devices.size() < 2) {
    MS_LOG(ERROR) << "get opt shard rank list with less than two group devices!";
    return {FAILED, {}};
  }

  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  MS_EXCEPTION_IF_ZERO("optimizer_weight_shard_size", optimizer_weight_shard_size);
  if ((optimizer_weight_shard_size == -1) || (optimizer_weight_shard_size > SizeToLong(group_devices.size()))) {
    MS_LOG(INFO) << "[CAME] detect optimizer_weight_shard_size = -1 or exceed max shard size, use group devices size: "
                 << group_devices.size();
    optimizer_weight_shard_size = SizeToLong(group_devices.size());
  }

  int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();

  // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
  auto rank_list =
    RankList(group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
             group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
  return std::make_pair(SUCCESS, rank_list);
}

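// Return the devices that shard tensor dimension `dim` for the given rank. If the dimension is not
// sharded (its tensor map entry is -1), the group only contains the rank itself.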
std::pair<Status, RankList> CameCommHandler::GetDimRankList(const int64_t rank, const int64_t dim) {
  DeviceMatrix dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array());
  int64_t device_reverse_dim = tensor_layout->tensor_map().GetDimByIdx(dim);
  if (device_reverse_dim == -1) {
    return {SUCCESS, {rank}};
  }
  int64_t device_dim = SizeToLong(tensor_layout->device_arrangement().array().size()) - 1 - device_reverse_dim;
  RankList rank_list;
  if (dev_matrix.GetDevicesAlongDim(LongToUlong(device_dim), &rank_list) != SUCCESS) {
    MS_LOG(ERROR) << "Get devices along dim failed";
    return {FAILED, rank_list};
  }
  return {SUCCESS, rank_list};
}

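// If optimizer weight sharding is enabled, replace every rank in rank_list with its
// optimizer-weight-shard group so that the resulting communication group also covers the
// optimizer-sharded dimension.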
RankList CameCommHandler::ExpandRankListWithOptShard(const RankList &rank_list) {
  if (!is_opt_shard) {
    return rank_list;
  }
  MS_LOG(INFO) << "opt shard is enabled, group name:" << tensor_layout->opt_shard_group();

  RankList opt_rank_list_find = g_device_manager->FindRankListByHashName(tensor_layout->opt_shard_group());
  for (const auto &opt_find_rank : opt_rank_list_find) {
    MS_LOG(INFO) << "group device member:" << opt_find_rank;
  }

  RankList expanded_list;
  for (const auto &rank : rank_list) {
    Status ret_state;
    RankList opt_shard_rank_list;
    std::tie(ret_state, opt_shard_rank_list) = GetOptShardRankList(rank);
    if (ret_state != SUCCESS) {
      MS_LOG(EXCEPTION) << "find opt shard rank list in adafactor failed";
    }
    MS_LOG(INFO) << "found opt shard rank list for rank " << rank;

    for (const auto &opt_rank : opt_shard_rank_list) {
      MS_LOG(INFO) << opt_rank;
    }
    expanded_list.insert(expanded_list.end(), opt_shard_rank_list.begin(), opt_shard_rank_list.end());
  }
  std::sort(expanded_list.begin(), expanded_list.end());
  MS_LOG(INFO) << "expand rank list with opt shard, before:";
  for (const auto &item : rank_list) {
    MS_LOG(INFO) << item;
  }
  MS_LOG(INFO) << "after:";
  for (const auto &item : expanded_list) {
    MS_LOG(INFO) << item;
  }
  return expanded_list;
}

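// Expand every rank in rank_list with the devices that shard tensor dimension `dim`, producing the
// sorted union of both groups.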
RankList CameCommHandler::ExpandRankListWithDim(const RankList &rank_list, const int64_t dim) {
  RankList expanded_list;
  for (const auto &rank : rank_list) {
    Status ret_status;
    RankList dim_rank_list;
    std::tie(ret_status, dim_rank_list) = GetDimRankList(rank, dim);
    if (ret_status != SUCCESS) {
      MS_LOG(EXCEPTION) << "find dim rank list in adafactor failed";
    }
    expanded_list.insert(expanded_list.end(), dim_rank_list.begin(), dim_rank_list.end());
  }
  std::sort(expanded_list.begin(), expanded_list.end());
  return expanded_list;
}

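// Locate the numbered ReduceMean of the CAME update graph by tracing from the optimizer state
// parameters: no.1/2 from the exp_avg_sq_row/col Assign chains, no.3 from the exp_avg_sq_row Load
// users, no.4 from the exp_avg Assign chain, no.5/6 from the exp_avg_insta_row/col Assign chains, and
// no.7 from the exp_avg_insta_row Load users.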
CNodePtr CameCommHandler::FindReduceMean(size_t number) {
  if (reduce_mean_numbers.find(number) == reduce_mean_numbers.end()) {
    MS_LOG(INFO) << "[CAME] invalid reduce mean number: " << number;
  }

  if (number == kFirstCameReduceMean) {
    return FindReduceMean1256(exp_avg_sq_row);
  } else if (number == kSecondCameReduceMean) {
    return FindReduceMean1256(exp_avg_sq_col);
  } else if (number == kThirdCameReduceMean) {
    return FindReduceMean37(exp_avg_sq_row);
  } else if (number == kForthCameReduceMean) {
    return FindReduceMean4();
  } else if (number == kFifthCameReduceMean) {
    return FindReduceMean1256(exp_avg_insta_row);
  } else if (number == kSixthCameReduceMean) {
    return FindReduceMean1256(exp_avg_insta_col);
  } else if (number == kSeventhCameReduceMean) {
    return FindReduceMean37(exp_avg_insta_row);
  } else {
    return nullptr;
  }
}

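// Find the ReduceMean that feeds an optimizer state update: starting from the parameter's Assign
// user, trace backwards through Assign(1) -> Add(1) -> Mul(0) to reach the ReduceMean.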
CNodePtr CameCommHandler::FindReduceMean1256(const ParameterPtr &param) {
  if (!param) {
    return nullptr;
  }
  MS_LOG(INFO) << "[CAME] try find reduce_mean according to " << param->name() << " Assign:";
  auto param_user_set = node_user_map.at(param->cast<AnfNodePtr>());
  for (auto &param_pair : param_user_set) {
    auto user_cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    if (IsSomePrimitive(user_cnode, ASSIGN)) {
      MS_LOG(INFO) << "[CAME] found assign node";
      // assign 1 -> add 1 -> mul 0 -> reduce_mean
      auto res = BackwardSearchCNode(user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}}, REDUCE_MEAN);
      if (res.first) {
        MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString();
        return res.second;
      }
    }
  }
  return nullptr;
}

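// Find the ReduceMean that consumes the parameter directly: starting from the parameter's Load user,
// search forwards among the Load's users for a ReduceMean and return the first match.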
CNodePtr CameCommHandler::FindReduceMean37(const ParameterPtr &param) {
  if (!param) {
    return nullptr;
  }
  auto param_user_set = node_user_map.at(param->cast<AnfNodePtr>());
  MS_LOG(INFO) << "[CAME] user map size: " << param_user_set.size();
  size_t load_count = 0;
  for (auto &param_pair : param_user_set) {
    auto user_cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    if (IsSomePrimitive(user_cnode, LOAD)) {
      MS_LOG(INFO) << "[CAME] found load node";
      load_count++;
      // load -> reduce mean
      auto res = ForwardSearchCNode(user_cnode, {LOAD, REDUCE_MEAN}, node_user_map);
      if (res.first) {
        MS_LOG(INFO) << "[CAME] found reduce mean node size: " << res.second.size();
        return res.second[0];  // get the first one
      }
    }
  }
  MS_LOG(INFO) << "[CAME] found load count: " << load_count;
  return nullptr;
}

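// Find ReduceMean no.4: starting from exp_avg's Assign user, trace backwards through
// Assign(1) -> Add(1) -> Mul(0) -> RealDiv(1) -> Maximum(0) -> RealDiv(0) -> Sqrt(0) to reach it.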
CNodePtr CameCommHandler::FindReduceMean4() {
  MS_LOG(INFO) << "[CAME] try find reduce_mean no.4 according to exp_avg Assign:";
  if (!exp_avg) {
    return nullptr;
  }
  auto exp_avg_user_set = node_user_map.at(exp_avg->cast<AnfNodePtr>());
  for (auto &param_pair : exp_avg_user_set) {
    auto user_cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    if (IsSomePrimitive(user_cnode, ASSIGN)) {
      MS_LOG(INFO) << "[CAME] found exp_avg's assign node";
      auto res = BackwardSearchCNode(
        user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}, {REAL_DIV, 1}, {MAXIMUM, 0}, {REAL_DIV, 0}, {SQRT, 0}},
        REDUCE_MEAN);
      if (res.first) {
        MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString();
        return res.second;
      }
    }
  }
  return nullptr;
}

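// Turn a local ReduceMean into a global mean across comm_rank_list: append an AllReduce(sum) over the
// given group after the ReduceMean, then a RealDiv by the group size, and replace all uses of the
// original ReduceMean output with the divided result.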
void CameCommHandler::InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean,
                                                                 const RankList &comm_rank_list) {
  // construct all reduce cnode and insert to the first input
  if (!reduce_mean) {
    return;
  }
  FuncGraphPtr func_graph = reduce_mean->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  CheckGlobalDeviceManager();

  MS_LOG(INFO) << "Insert AllReduce and RealDiv to node " << reduce_mean->DebugString();
  // insert all reduce
  OperatorName allreduce_op_name = ALL_REDUCE;
  OperatorAttrs all_reduce_op_attrs;
  ValuePtr allreduce_pyop_instance = CreateOpInstance(all_reduce_op_attrs, allreduce_op_name, "came_norm_allreduce");
  std::vector<AnfNodePtr> all_reduce_input = {NewValueNode(allreduce_pyop_instance), reduce_mean};
  auto all_reduce_node = func_graph->NewCNode(all_reduce_input);
  auto all_reduce_prim = GetCNodePrimitive(all_reduce_node);
  MS_EXCEPTION_IF_NULL(all_reduce_prim);
  auto all_reduce_attrs = all_reduce_prim->attrs();
  all_reduce_attrs["op"] = MakeValue<std::string>(REDUCE_OP_SUM);

  std::string group_name = CreateCommGroupFromRankList(comm_rank_list);
  MS_LOG(INFO) << "[CAME] came allreduce opt shard group: " << group_name;
  all_reduce_attrs["group"] = MakeValue<std::string>(group_name);
  int64_t fusion_id = 0;
  all_reduce_attrs["fusion"] = MakeValue(fusion_id);
  all_reduce_prim->SetAttrs(all_reduce_attrs);
  // insert real div
  OperatorName operator_name = REAL_DIV;
  OperatorAttrs operator_attrs;

  ValuePtr pyop_instance = CreateOpInstance(operator_attrs, operator_name, "came_norm_realdiv");
  MS_EXCEPTION_IF_NULL(pyop_instance);

  size_t group_rank_size = comm_rank_list.size();
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
    static_cast<float>(group_rank_size),
    reduce_mean->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetType());
  ValuePtr scale_value = MakeValue(tensor_ptr);

  std::vector<AnfNodePtr> real_div_input = {NewValueNode(pyop_instance), all_reduce_node->cast<AnfNodePtr>(),
                                            NewValueNode(scale_value)};
  auto real_div_node = func_graph->NewCNode(real_div_input);
  manager->Replace(reduce_mean, real_div_node);
}

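// Insert the communication needed to make every ReduceMean of the CAME update a global mean.
// For 1-D parameters, only ReduceMean no.4 needs an AllReduce over the dim-0 group (expanded with
// optimizer sharding). For 2-D and higher parameters, no.1/5 reduce along the last dim, no.2/3/6/7
// along the second-to-last dim (expanded with optimizer sharding for 2-D), and no.4 over both dims.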
void CameCommHandler::Process() {
  auto reduce_mean_1 = FindReduceMean(1);
  auto reduce_mean_2 = FindReduceMean(2);
  auto reduce_mean_3 = FindReduceMean(3);
  auto reduce_mean_4 = FindReduceMean(4);
  auto reduce_mean_5 = FindReduceMean(5);
  auto reduce_mean_6 = FindReduceMean(6);
  auto reduce_mean_7 = FindReduceMean(7);
  MS_LOG(INFO) << "found all reduce_mean nodes for came/adafactor";

  auto shape_size = tensor_layout->slice_shape().array().size();
  if (shape_size == 1) {
    // for shape [A], model parallel and opt shard may overlap on dim A.
    Status ret_status;
    RankList comm_rank_list;
    std::tie(ret_status, comm_rank_list) = GetDimRankList(cur_rank, 0);
    if (ret_status != SUCCESS) {
      MS_LOG(ERROR) << "[CAME] shape size = 1, getting rank list along dim 0 failed";
    }
    comm_rank_list = ExpandRankListWithOptShard(comm_rank_list);
    if (comm_rank_list.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list);
    }
  } else {
    Status ret_status;
    RankList comm_rank_list_along_neg_1;
    RankList comm_rank_list_along_neg_2;
    RankList comm_rank_list_along_neg_12;
    int64_t actual_dim_of_neg_1 = SizeToLong(shape_size) - 1;
    int64_t actual_dim_of_neg_2 = SizeToLong(shape_size) - 2;
    std::tie(ret_status, comm_rank_list_along_neg_1) = GetDimRankList(cur_rank, actual_dim_of_neg_1);
    if (ret_status != SUCCESS) {
      MS_LOG(ERROR) << "[CAME] shape size > 1, getting rank list along dim -1 failed";
    }
    std::tie(ret_status, comm_rank_list_along_neg_2) = GetDimRankList(cur_rank, actual_dim_of_neg_2);
    if (ret_status != SUCCESS) {
      MS_LOG(ERROR) << "[CAME] shape size > 1, getting rank list along dim -2 failed";
    }
    if (shape_size == kParameterDimTwo) {
      comm_rank_list_along_neg_2 = ExpandRankListWithOptShard(comm_rank_list_along_neg_2);
    }
    comm_rank_list_along_neg_12 = ExpandRankListWithDim(comm_rank_list_along_neg_2, actual_dim_of_neg_1);
    if (comm_rank_list_along_neg_1.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_1, comm_rank_list_along_neg_1);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_5, comm_rank_list_along_neg_1);
    }
    if (comm_rank_list_along_neg_2.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_2, comm_rank_list_along_neg_2);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_3, comm_rank_list_along_neg_2);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_6, comm_rank_list_along_neg_2);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_7, comm_rank_list_along_neg_2);
    }
    if (comm_rank_list_along_neg_12.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list_along_neg_12);
    }
  }
}

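// Create the communication group for the given rank list and return its group name.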
std::string CameCommHandler::CreateCommGroupFromRankList(const RankList &rank_list) {
  Group comm_group;
  if (g_device_manager->CreateGroup(rank_list, &comm_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create comm group failed in came";
  }
  std::string group_name = comm_group.name();
  return group_name;
}

}  // namespace parallel
}  // namespace mindspore