/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "frontend/parallel/came_parallel_handler.h"

#include <deque>
#include <algorithm>

#include "frontend/parallel/parameter_manager.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/utils/convert_utils_base.h"
#include "utils/hash_map.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/trace_base.h"
#include "include/common/utils/comm_manager.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "pipeline/jit/ps/pipeline.h"
#include "mindspore/core/utils/parallel_node_check.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "mindspore/core/ops/nn_ops.h"

namespace mindspore {
namespace parallel {
const std::string GetCNodeOpName(const CNodePtr &cnode) {
  // get the prim name of cnode
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(prim_anf_node);
  PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(node_prim);
  return node_prim->name();
}

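// Walk backward from bottom_node along the producer chain described by bwd_calls: each pair gives the expected
// op name of the current node and the input index to follow next. Returns {true, node} when the node reached at
// the end of the chain has the op name target_name; otherwise returns {false, bottom_node}.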
std::pair<bool, const CNodePtr> BackwardSearchCNode(const CNodePtr &bottom_node,
                                                    const std::vector<std::pair<std::string, size_t>> &bwd_calls,
                                                    const std::string &target_name) {
  CNodePtr target_node = bottom_node;
  for (const auto &call_param : bwd_calls) {
    const auto node_name = call_param.first;
    const auto idx = call_param.second;
    auto cnode_name = GetCNodeOpName(target_node);
    if (cnode_name != node_name) {
      MS_LOG(DEBUG) << "[CAME] backward search failed, expect node name: " << node_name << " but got " << cnode_name;
      return {false, bottom_node};
    }
    const auto &param_node = target_node->input(idx + 1);
    if (!param_node) {
      MS_LOG(DEBUG) << "[CAME] backward search failed, expect param at index: " << (idx + 1) << " but got null";
      return {false, bottom_node};
    }
    if (!param_node->isa<CNode>()) {
      MS_LOG(DEBUG) << "[CAME] param node is not a cnode!";
      return {false, bottom_node};
    }
    auto param_cnode = param_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(param_cnode);
    target_node = param_cnode;
  }
  auto cnode_name = GetCNodeOpName(target_node);
  if (cnode_name != target_name) {
    MS_LOG(DEBUG) << "[CAME] backward search failed, expect target node name: " << target_name << " but got "
                  << cnode_name;
    return {false, bottom_node};
  }
  return {true, target_node};
}

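// Breadth-first search forward from start_node through the node-user map, following users whose op names match
// fwd_calls level by level. Returns {true, candidates} with the nodes reached at the last level of fwd_calls,
// or {false, {}} if no such node is found.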
std::pair<bool, std::vector<CNodePtr>> ForwardSearchCNode(const CNodePtr &start_node,
                                                          const std::vector<std::string> &fwd_calls,
                                                          const NodeUsersMap &node_user_map) {
  if (!start_node) {
    MS_LOG(DEBUG) << "[CAME] forward search start is null!";
    return {false, {}};
  }
  if (fwd_calls.empty()) {
    MS_LOG(DEBUG) << "[CAME] forward search was given empty forward calls!";
    return {false, {}};
  }
  std::vector<CNodePtr> candidates;
  std::deque<CNodePtr> visited;
  CNodePtr cur_node = nullptr;
  uint32_t depth = 0;

  visited.push_back(start_node);
  CNodePtr last_node = visited.back();
  while (!visited.empty()) {
    if (depth == fwd_calls.size() - 1) {
      std::copy(visited.begin(), visited.end(), std::back_inserter(candidates));
      break;
    }
    cur_node = visited.front();
    MS_LOG(INFO) << "[CAME] fwd current node: " << cur_node->DebugString();
    visited.pop_front();
    auto node_set = node_user_map.at(cur_node->cast<AnfNodePtr>());
    for (auto item : node_set) {
      auto user_node = item.first;
      if (!user_node->isa<CNode>()) {
        continue;
      }
      auto user_cnode = user_node->cast<CNodePtr>();
      if (GetCNodeOpName(user_cnode) == fwd_calls[depth + 1]) {
        visited.push_back(user_cnode);
      }
    }
    if (last_node == cur_node) {
      // The current level is fully expanded; if no user matched, stop before touching an empty deque.
      if (visited.empty()) {
        break;
      }
      last_node = visited.back();
      depth++;
    }
  }

  if (candidates.empty()) {
    return {false, {}};
  } else {
    return {true, candidates};
  }
}

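// Bind the handler to one origin parameter: record the current rank and the device list of this stage, read the
// parameter's tensor layout (including whether optimizer weight shard is enabled), and locate the CAME optimizer
// state parameters that belong to it.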
CameCommHandler::CameCommHandler(ParameterPtr origin, const std::vector<AnfNodePtr> &all_parameters,
                                 const NodeUsersMap &node_user_map)
    : origin(origin), all_parameters(all_parameters), node_user_map(node_user_map) {
  CheckGlobalDeviceManager();
  cur_rank = g_device_manager->global_rank();
  full_rank_list = g_device_manager->GetDeviceListInThisStage();

  tensor_layout = origin->user_data<TensorLayout>();
  MS_EXCEPTION_IF_NULL(tensor_layout);

  auto opt_shard_group_name = tensor_layout->opt_shard_group();
  if (!opt_shard_group_name.empty()) {
    is_opt_shard = true;
  }
  MS_LOG(DEBUG) << "CAME processing parameter";
  MS_LOG(DEBUG) << "tensor shape:" << tensor_layout->tensor_shape().ToString();
  MS_LOG(DEBUG) << "slice shape:" << tensor_layout->slice_shape().ToString();

  MS_LOG(DEBUG) << "opt shard slice shape:";
  for (const auto &item : tensor_layout->opt_shard_slice_shape()) {
    MS_LOG(DEBUG) << item;
  }
  MS_LOG(DEBUG) << "opt shard group:" << tensor_layout->opt_shard_group();
  MS_LOG(DEBUG) << "opt shard step:" << tensor_layout->opt_weight_shard_step();

  MS_LOG(DEBUG) << "device arrangement:" << tensor_layout->device_arrangement().ToString();
  MS_LOG(DEBUG) << "original device arrangement:" << tensor_layout->device_arrangement_origin().ToString();

  MS_LOG(DEBUG) << "tensor map:" << tensor_layout->tensor_map().ToString();
  MS_LOG(DEBUG) << "original tensor map:" << tensor_layout->origin_tensor_map().ToString();

  FindCameParams();
}

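// Match the CAME optimizer state parameters (exp_avg_sq_row/col, exp_avg_insta_row/col and exp_avg) that belong
// to the origin parameter by name, and cache them for the later reduce-mean searches.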
void CameCommHandler::FindCameParams() {
  const std::string origin_name = origin->name();
  const std::string exp_row_name = EXP_AVG_SQ_ROW + origin_name;
  const std::string exp_col_name = EXP_AVG_SQ_COL + origin_name;
  const std::string exp_insta_row_name = EXP_AVG_INSTA_ROW + origin_name;
  const std::string exp_insta_col_name = EXP_AVG_INSTA_COL + origin_name;
  const std::string exp_avg_name = std::string(EXP_AVG) + "." + origin_name;
  const size_t param_to_find_size = 5;
  size_t cur_found_param_count = 0;
  for (const auto &param_node : all_parameters) {
    auto param = param_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    const std::string param_name = param->name();
    if (param_name == exp_row_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_row: " << param_name;
      exp_avg_sq_row = param;
      cur_found_param_count++;
    } else if (param_name == exp_col_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_sq_col: " << param_name;
      exp_avg_sq_col = param;
      cur_found_param_count++;
    } else if (param_name == exp_insta_row_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_row: " << param_name;
      exp_avg_insta_row = param;
      cur_found_param_count++;
    } else if (param_name == exp_insta_col_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg_insta_col: " << param_name;
      exp_avg_insta_col = param;
      cur_found_param_count++;
    } else if (param_name == exp_avg_name) {
      MS_LOG(DEBUG) << "[CAME] found exp_avg: " << param_name;
      exp_avg = param;
      cur_found_param_count++;
    }

    if (cur_found_param_count == param_to_find_size) {
      break;
    }
  }
  MS_LOG(INFO) << "[CAME] found params corresponding to origin param, count: " << cur_found_param_count;
}

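// Compute the optimizer-weight-shard rank group that contains `rank`: take the devices that share the same tensor
// slice (via the tensor map), split them into consecutive chunks of optimizer_weight_shard_size, and return the
// chunk that `rank` falls into.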
std::pair<Status, RankList> CameCommHandler::GetOptShardRankList(const int64_t rank) {
  DeviceMatrix temp_dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array());
  RankList group_devices;
  Shape orig_tensor_map = tensor_layout->tensor_map().array();
  if (temp_dev_matrix.GetDevicesByTensorMap(orig_tensor_map, &group_devices) != SUCCESS) {
    return {FAILED, {}};
  }
  if (group_devices.size() < 2) {
    MS_LOG(ERROR) << "get opt shard rank list with less than two group devices!";
    return {FAILED, {}};
  }

  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  MS_EXCEPTION_IF_ZERO("optimizer_weight_shard_size", optimizer_weight_shard_size);
  if ((optimizer_weight_shard_size == -1) || (optimizer_weight_shard_size > SizeToLong(group_devices.size()))) {
    MS_LOG(INFO) << "[CAME] detect optimizer_weight_shard_size = -1 or exceed max shard size, use group devices size: "
                 << group_devices.size();
    optimizer_weight_shard_size = SizeToLong(group_devices.size());
  }

  int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();

  // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
  auto rank_list =
    RankList(group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
             group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
  return std::make_pair(SUCCESS, rank_list);
}

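// Return the ranks that hold different slices of tensor dimension `dim` for device `rank`. If the dimension is
// not mapped to any device dimension (tensor map entry is -1), it is not sharded and only `rank` itself is
// returned.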
std::pair<Status, RankList> CameCommHandler::GetDimRankList(const int64_t rank, const int64_t dim) {
  DeviceMatrix dev_matrix(rank, full_rank_list, tensor_layout->device_arrangement().array());
  int64_t device_reverse_dim = tensor_layout->tensor_map().GetDimByIdx(dim);
  if (device_reverse_dim == -1) {
    return {SUCCESS, {rank}};
  }
  int64_t device_dim = SizeToLong(tensor_layout->device_arrangement().array().size()) - 1 - device_reverse_dim;
  RankList rank_list;
  if (dev_matrix.GetDevicesAlongDim(LongToUlong(device_dim), &rank_list) != SUCCESS) {
    MS_LOG(ERROR) << "Get devices along dim failed";
    return {FAILED, rank_list};
  }
  return {SUCCESS, rank_list};
}

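// When optimizer weight shard is enabled, replace every rank in rank_list with its whole optimizer-shard rank
// group and return the sorted union; otherwise return rank_list unchanged.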
RankList CameCommHandler::ExpandRankListWithOptShard(const RankList &rank_list) {
  if (!is_opt_shard) {
    return rank_list;
  }
  MS_LOG(INFO) << "opt shard is enabled, group name:" << tensor_layout->opt_shard_group();

  RankList opt_rank_list_find = g_device_manager->FindRankListByHashName(tensor_layout->opt_shard_group());
  for (const auto &opt_find_rank : opt_rank_list_find) {
    MS_LOG(INFO) << "group device member:" << opt_find_rank;
  }

  RankList expanded_list;
  for (const auto &rank : rank_list) {
    Status ret_state;
    RankList opt_shard_rank_list;
    std::tie(ret_state, opt_shard_rank_list) = GetOptShardRankList(rank);
    if (ret_state != SUCCESS) {
      MS_LOG(EXCEPTION) << "find opt shard rank list in adafactor failed";
    }
    MS_LOG(INFO) << "found opt shard rank list for rank " << rank;

    for (const auto &opt_rank : opt_shard_rank_list) {
      MS_LOG(INFO) << opt_rank;
    }
    expanded_list.insert(expanded_list.end(), opt_shard_rank_list.begin(), opt_shard_rank_list.end());
  }
  std::sort(expanded_list.begin(), expanded_list.end());
  MS_LOG(INFO) << "expand rank list with opt shard, before:";
  for (const auto &item : rank_list) {
    MS_LOG(INFO) << item;
  }
  MS_LOG(INFO) << "after:";
  for (const auto &item : expanded_list) {
    MS_LOG(INFO) << item;
  }
  return expanded_list;
}

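// Replace every rank in rank_list with the ranks along tensor dimension `dim` and return the sorted union, so
// the resulting group also covers the peers of each rank along that dimension.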
RankList CameCommHandler::ExpandRankListWithDim(const RankList &rank_list, const int64_t dim) {
  RankList expanded_list;
  for (const auto &rank : rank_list) {
    Status ret_status;
    RankList dim_rank_list;
    std::tie(ret_status, dim_rank_list) = GetDimRankList(rank, dim);
    if (ret_status != SUCCESS) {
      MS_LOG(EXCEPTION) << "find dim rank list in adafactor failed";
    }
    expanded_list.insert(expanded_list.end(), dim_rank_list.begin(), dim_rank_list.end());
  }
  std::sort(expanded_list.begin(), expanded_list.end());
  return expanded_list;
}

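// Dispatch to the matching search routine for one of the seven ReduceMean nodes in the CAME/AdaFactor update
// graph, identified by its number; returns nullptr for an unknown number.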
CNodePtr CameCommHandler::FindReduceMean(size_t number) {
  if (reduce_mean_numbers.find(number) == reduce_mean_numbers.end()) {
    MS_LOG(INFO) << "[CAME] invalid reduce mean number: " << number;
  }

  if (number == kFirstCameReduceMean) {
    return FindReduceMean1256(exp_avg_sq_row);
  } else if (number == kSecondCameReduceMean) {
    return FindReduceMean1256(exp_avg_sq_col);
  } else if (number == kThirdCameReduceMean) {
    return FindReduceMean37(exp_avg_sq_row);
  } else if (number == kForthCameReduceMean) {
    return FindReduceMean4();
  } else if (number == kFifthCameReduceMean) {
    return FindReduceMean1256(exp_avg_insta_row);
  } else if (number == kSixthCameReduceMean) {
    return FindReduceMean1256(exp_avg_insta_col);
  } else if (number == kSeventhCameReduceMean) {
    return FindReduceMean37(exp_avg_insta_row);
  } else {
    return nullptr;
  }
}

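// Find ReduceMean no.1/2/5/6: start from the Assign user of the given state parameter and walk backward along
// Assign -> Add -> Mul to the ReduceMean that produces its update.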
CNodePtr CameCommHandler::FindReduceMean1256(const ParameterPtr &param) {
  if (!param) {
    return nullptr;
  }
  MS_LOG(INFO) << "[CAME] try find reduce_mean according to " << param->name() << " Assign:";
  auto param_user_set = node_user_map.at(param->cast<AnfNodePtr>());
  for (auto &param_pair : param_user_set) {
    auto user_cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    if (IsSomePrimitive(user_cnode, ASSIGN)) {
      MS_LOG(INFO) << "[CAME] found assign node";
      // assign 1 -> add 1 -> mul 0 -> reduce_mean
      auto res = BackwardSearchCNode(user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}}, REDUCE_MEAN);
      if (res.first) {
        MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString();
        return res.second;
      }
    }
  }
  return nullptr;
}

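// Find ReduceMean no.3/7: start from each Load user of the given state parameter and search forward through the
// node users for a ReduceMean consumer; the first one found is returned.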
CNodePtr CameCommHandler::FindReduceMean37(const ParameterPtr &param) {
  if (!param) {
    return nullptr;
  }
  auto param_user_set = node_user_map.at(param->cast<AnfNodePtr>());
  MS_LOG(INFO) << "[CAME] user map size: " << param_user_set.size();
  size_t load_count = 0;
  for (auto &param_pair : param_user_set) {
    auto user_cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    if (IsSomePrimitive(user_cnode, LOAD)) {
      MS_LOG(INFO) << "[CAME] found load node";
      load_count++;
      // load -> reduce mean
      auto res = ForwardSearchCNode(user_cnode, {LOAD, REDUCE_MEAN}, node_user_map);
      if (res.first) {
        MS_LOG(INFO) << "[CAME] found reduce mean node size: " << res.second.size();
        return res.second[0];  // get the first one
      }
    }
  }
  MS_LOG(INFO) << "[CAME] found load count: " << load_count;
  return nullptr;
}

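// Find ReduceMean no.4: start from exp_avg's Assign node and walk backward along
// Assign -> Add -> Mul -> RealDiv -> Maximum -> RealDiv -> Sqrt to the underlying ReduceMean.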
CNodePtr CameCommHandler::FindReduceMean4() {
  MS_LOG(INFO) << "[CAME] try find reduce_mean no.4 according to exp_avg Assign:";
  if (!exp_avg) {
    return nullptr;
  }
  auto exp_avg_user_set = node_user_map.at(exp_avg->cast<AnfNodePtr>());
  for (auto &param_pair : exp_avg_user_set) {
    auto user_cnode = param_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    if (IsSomePrimitive(user_cnode, ASSIGN)) {
      MS_LOG(INFO) << "[CAME] found exp_avg's assign node";
      auto res = BackwardSearchCNode(
        user_cnode, {{ASSIGN, 1}, {ADD, 1}, {MUL, 0}, {REAL_DIV, 1}, {MAXIMUM, 0}, {REAL_DIV, 0}, {SQRT, 0}},
        REDUCE_MEAN);
      if (res.first) {
        MS_LOG(INFO) << "[CAME] found reduce mean node: " << res.second->DebugString();
        return res.second;
      }
    }
  }
  return nullptr;
}

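// Turn a local ReduceMean into a global mean over comm_rank_list: create an AllReduce(sum) over a communication
// group built from comm_rank_list, divide the result by the group size with a RealDiv, and replace every use of
// the original ReduceMean with the RealDiv output.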
void CameCommHandler::InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean, const RankList &comm_rank_list) {
  // construct all reduce cnode and insert to the first input
  if (!reduce_mean) {
    return;
  }
  FuncGraphPtr func_graph = reduce_mean->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  CheckGlobalDeviceManager();

  MS_LOG(INFO) << "Insert AllReduce and RealDiv to node " << reduce_mean->DebugString();
  // insert all reduce
  OperatorName allreduce_op_name = ALL_REDUCE;
  OperatorAttrs all_reduce_op_attrs;
  ValuePtr allreduce_pyop_instance = CreateOpInstance(all_reduce_op_attrs, allreduce_op_name, "came_norm_allreduce");
  std::vector<AnfNodePtr> all_reduce_input = {NewValueNode(allreduce_pyop_instance), reduce_mean};
  auto all_reduce_node = func_graph->NewCNode(all_reduce_input);
  auto all_reduce_prim = GetCNodePrimitive(all_reduce_node);
  MS_EXCEPTION_IF_NULL(all_reduce_prim);
  auto all_reduce_attrs = all_reduce_prim->attrs();
  all_reduce_attrs["op"] = MakeValue<std::string>(REDUCE_OP_SUM);

  std::string group_name = CreateCommGroupFromRankList(comm_rank_list);
  MS_LOG(INFO) << "[CAME] came allreduce opt shard group: " << group_name;
  all_reduce_attrs["group"] = MakeValue<std::string>(group_name);
  int64_t fusion_id = 0;
  all_reduce_attrs["fusion"] = MakeValue(fusion_id);
  all_reduce_prim->SetAttrs(all_reduce_attrs);
  // insert real div
  OperatorName operator_name = REAL_DIV;
  OperatorAttrs operator_attrs;

  ValuePtr pyop_instance = CreateOpInstance(operator_attrs, operator_name, "came_norm_realdiv");
  MS_EXCEPTION_IF_NULL(pyop_instance);

  size_t group_rank_size = comm_rank_list.size();
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
    static_cast<float>(group_rank_size),
    reduce_mean->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetType());
  ValuePtr scale_value = MakeValue(tensor_ptr);

  std::vector<AnfNodePtr> real_div_input = {NewValueNode(pyop_instance), all_reduce_node->cast<AnfNodePtr>(),
                                            NewValueNode(scale_value)};
  auto real_div_node = func_graph->NewCNode(real_div_input);
  manager->Replace(reduce_mean, real_div_node);
}

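// Main entry: locate the seven ReduceMean nodes, then, depending on whether the parameter slice is 1-D or
// higher-dimensional, build the rank groups along the last one or two dimensions (expanded with optimizer weight
// shard where needed) and wrap each ReduceMean with an AllReduce + RealDiv so its result is averaged across the
// devices that share the corresponding slice.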
void CameCommHandler::Process() {
  auto reduce_mean_1 = FindReduceMean(1);
  auto reduce_mean_2 = FindReduceMean(2);
  auto reduce_mean_3 = FindReduceMean(3);
  auto reduce_mean_4 = FindReduceMean(4);
  auto reduce_mean_5 = FindReduceMean(5);
  auto reduce_mean_6 = FindReduceMean(6);
  auto reduce_mean_7 = FindReduceMean(7);
  MS_LOG(INFO) << "found all reduce mean for came/adafactor";

  auto shape_size = tensor_layout->slice_shape().array().size();
  if (shape_size == 1) {
    // for shape [A], mp and opt shard may overlap on dim A.
    Status ret_status;
    RankList comm_rank_list;
    std::tie(ret_status, comm_rank_list) = GetDimRankList(cur_rank, 0);
    if (ret_status != SUCCESS) {
      MS_LOG(ERROR) << "[CAME] shape size = 1, getting rank list along dim 0 failed";
    }
    comm_rank_list = ExpandRankListWithOptShard(comm_rank_list);
    if (comm_rank_list.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list);
    }
  } else {
    Status ret_status;
    RankList comm_rank_list_along_neg_1;
    RankList comm_rank_list_along_neg_2;
    RankList comm_rank_list_along_neg_12;
    int64_t actual_dim_of_neg_1 = SizeToLong(shape_size) - 1;
    int64_t actual_dim_of_neg_2 = SizeToLong(shape_size) - 2;
    std::tie(ret_status, comm_rank_list_along_neg_1) = GetDimRankList(cur_rank, actual_dim_of_neg_1);
    if (ret_status != SUCCESS) {
      MS_LOG(ERROR) << "[CAME] shape size >= 2, getting rank list along dim -1 failed";
    }
    std::tie(ret_status, comm_rank_list_along_neg_2) = GetDimRankList(cur_rank, actual_dim_of_neg_2);
    if (ret_status != SUCCESS) {
      MS_LOG(ERROR) << "[CAME] shape size >= 2, getting rank list along dim -2 failed";
    }
    if (shape_size == kParameterDimTwo) {
      comm_rank_list_along_neg_2 = ExpandRankListWithOptShard(comm_rank_list_along_neg_2);
    }
    comm_rank_list_along_neg_12 = ExpandRankListWithDim(comm_rank_list_along_neg_2, actual_dim_of_neg_1);
    if (comm_rank_list_along_neg_1.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_1, comm_rank_list_along_neg_1);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_5, comm_rank_list_along_neg_1);
    }
    if (comm_rank_list_along_neg_2.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_2, comm_rank_list_along_neg_2);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_3, comm_rank_list_along_neg_2);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_6, comm_rank_list_along_neg_2);
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_7, comm_rank_list_along_neg_2);
    }
    if (comm_rank_list_along_neg_12.size() > 1) {
      InsertAllReduceAndRealDivToReduceMeanInput(reduce_mean_4, comm_rank_list_along_neg_12);
    }
  }
}

std::string CameCommHandler::CreateCommGroupFromRankList(const RankList &rank_list) {
  Group comm_group;
  if (g_device_manager->CreateGroup(rank_list, &comm_group) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Create comm group failed in came";
  }
  std::string group_name = comm_group.name();
  return group_name;
}

}  // namespace parallel
}  // namespace mindspore