/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/pass/communication_op_fusion.h"

#include <vector>
#include <set>
#include <memory>
#include <unordered_map>

#include "ir/graph_utils.h"
#include "base/core_ops.h"
#include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "frontend/parallel/context.h"

namespace mindspore {
namespace opt {
namespace {
constexpr auto kAttrDefaultGroup = "default_group";
constexpr auto kAttrDefaultOp = "default_op";
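// 2 << 9 == 1024: fused tensor buffers are padded up to 1024-byte boundaries
// in CreateFusedCommunicationOp below.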
constexpr size_t kAlignSize = 2 << 9;

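// Build the kernel build info for a fused node by concatenating the input and
// output formats and device types of every communication op in
// [start_index, end_index].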
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
                                                   size_t end_index) {
  if (end_index >= communication_op_info.communication_op_nodes.size()) {
    MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
  }
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::vector<size_t>> outputs_shape;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t idx = start_index; idx <= end_index; ++idx) {
    auto cnode = communication_op_info.communication_op_nodes[idx];
    MS_EXCEPTION_IF_NULL(cnode);
    int64_t rank_size = 1;
    if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
      rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
    }
    size_t rank_size_t = LongToSize(rank_size);
    if (rank_size_t == 0) {
      MS_LOG(EXCEPTION) << "Rank size should not be zero.";
    }
    size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
      inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
    }
    for (size_t rank_index = 0; rank_index < rank_size_t; ++rank_index) {
      size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
      for (size_t output_index = 0; output_index < output_num; ++output_index) {
        outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
        outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
        std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
        if (!shape.empty()) {
          shape[0] /= rank_size_t;
        }
        outputs_shape.push_back(shape);
      }
    }
    builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
    builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
    builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
  }
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}

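// Compose the fusion-group key from the node's group, op, and fusion attrs
// plus the inferred input data type; nodes fuse only within the same key.
// An empty key (no fusion attr, or fusion == 0) excludes the node from fusion.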
std::string GetFusionGroupKey(const AnfNodePtr &node) {
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
  if (attr_fusion == nullptr) {
    return "";
  }
  auto fusion = GetValue<int64_t>(attr_fusion);
  if (fusion == 0) {
    return "";
  }
  std::string group = kAttrDefaultGroup;
  ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
  if (attr_group != nullptr) {
    group = GetValue<std::string>(attr_group);
  }
  std::string op = kAttrDefaultOp;
  ValuePtr attr_op = primitive->GetAttr(kAttrOp);
  if (attr_op != nullptr) {
    op = GetValue<std::string>(attr_op);
  }
  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
}

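// Reject segments in which two communication ops share the same input node.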
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
  std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
  if (inputs_set.size() < fusion_inputs.size()) {
    MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
  }
}

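// Validate a split plan: there must be fewer segments than nodes, the last
// segment must end at the last node, and segment boundaries must not decrease.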
bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
  MS_EXCEPTION_IF_NULL(segment_index);
  if (segments >= communication_op_node_size) {
    MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
                 << ", communication_op_node_size=" << communication_op_node_size;
    return false;
  }
  if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
    MS_LOG(EXCEPTION) << "the last segment index is invalid.";
  }
  for (size_t i = 0; i < segments - 1; ++i) {
    if (segment_index->at(i) > segment_index->at(i + 1)) {
      MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index["
                        << (i + 1) << "]=" << segment_index->at(i + 1);
    }
  }
  return true;
}
}  // namespace

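// Decide where to cut the ordered communication ops into fusion segments:
// Send/Receive collapse into a single segment; otherwise use the user-provided
// all-reduce split indices for the group, or divide evenly into groups_ parts.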
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
                                             std::vector<size_t> *segment_index, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(segment_num);
  MS_EXCEPTION_IF_NULL(segment_index);
  size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
  MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;

  if (op_name_ == kHcomSendOpName || op_name_ == kReceiveOpName) {
    *segment_num = 1;
    if (communication_op_node_size == 0) {
      return false;
    }
    (void)segment_index->emplace_back(communication_op_node_size - 1);
    return true;
  }

  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  std::vector<uint32_t> split_indices;
  if (!parallel_context->enable_parallel_optimizer()) {
    split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
  }

  size_t segments = 0;
  if (!split_indices.empty()) {
    uint32_t last_index = 0;
    for (size_t i = 0; i < split_indices.size(); ++i) {
      uint32_t index = split_indices[i];
      if (index <= last_index && i != 0) {
        MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
      }
      if (index >= communication_op_node_size) {
        MS_LOG(WARNING) << op_name_ << "'s split index " << index
                        << " is greater than or equal to the total number of gradients " << communication_op_node_size;
        continue;
      }
      segment_index->push_back(index);
      last_index = index;
      segments++;
    }
    if (last_index != communication_op_node_size - 1) {
      segment_index->push_back(communication_op_node_size - 1);
      segments++;
    }
  } else {
    segments = groups_;
    for (size_t i = 0; i < segments - 1; ++i) {
      segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
    }
    segment_index->push_back(communication_op_node_size - 1);
  }

  *segment_num = segments;
  return CheckSegments(segments, communication_op_node_size, segment_index);
}

// Hard-code Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent a cycle
// after the AllReduce ops are fused. It's a workaround.
// case 1:
// cnode_load = Load(%para2, cnode_u)
// %100 = UpdateState(cnode_u, cnode_load)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
// will convert to:
// cnode_load = Load(%para2, U)
// ...
// %109 = AssignAdd(%para485, Tensor(34), cnode_u)
// %110 = UpdateState(cnode_u, xxx)
//
// case 2:
// cnode_load = Load(%para2, cnode_u)
// %99 = make_tuple(yyy, ..., cnode_load, ...)
// %100 = UpdateState(cnode_u, %99)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
// will convert to:
// cnode_load = Load(%para2, U)
// %99 = make_tuple(yyy, ...)
// %100 = UpdateState(cnode_u, %99)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
//
// case 3:
// cnode_load = Load(%para2, cnode_u)
// %99 = make_tuple(cnode_load)
// %100 = UpdateState(cnode_u, %99)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
// will convert to:
// cnode_load = Load(%para2, U)
// ...
// %109 = AssignAdd(%para485, Tensor(34), cnode_u)
// %110 = UpdateState(cnode_u, xxx)
static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) {
  const size_t monad_index = 2;
  const size_t tuple_inputs_size = 2;
  const size_t load_inputs_size = 3;
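  // Breadth-first search starting from this communication op for the first
  // Load whose monad input is itself a CNode, i.e. Load(%para, cnode_u)
  // rather than Load(%para, U).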
  auto cnode_load = BroadFirstSearchFirstOf({cnode}, [](const CNodePtr &search_cnode) {
    if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) {
      return false;
    }
    if (search_cnode->inputs().size() != load_inputs_size) {
      MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString();
    }
    return search_cnode->input(monad_index)->isa<CNode>();
  });
  if (cnode_load != nullptr) {
    auto const_u_monad = NewValueNode(kUMonad);
    const_u_monad->set_abstract(kUMonad->ToAbstract());
    const auto &cnode_u = cnode_load->input(monad_index);
    MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString();
    MS_EXCEPTION_IF_NULL(cnode->func_graph());
    MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
    auto manager = cnode->func_graph()->manager();
    manager->SetEdge(cnode_load, monad_index, const_u_monad);
    // Find the UpdateState that takes the same CNode U as its monad input and
    // consumes the Load, either directly or through a make_tuple.
    CNodePtr cnode_update_state = nullptr;
    CNodePtr cnode_make_tuple = nullptr;
    const auto &cnode_load_users = manager->node_users()[cnode_load];
    for (auto &load_user : cnode_load_users) {
      if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
        const auto &cnode_make_tuple_users = manager->node_users()[load_user.first];
        for (auto &make_tuple_user : cnode_make_tuple_users) {
          if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) {
            const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>();
            if (cnode_user->input(1) == cnode_u) {
              cnode_update_state = cnode_user;
              cnode_make_tuple = load_user.first->cast<CNodePtr>();
              break;
            }
          }
        }
        if (cnode_update_state != nullptr) {
          break;
        }
      }
      if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
        const auto &cnode_user = load_user.first->cast<CNodePtr>();
        if (cnode_user->input(1) == cnode_u) {
          cnode_update_state = cnode_user;
          break;
        }
      }
    }
    if (cnode_update_state != nullptr) {
      if (cnode_make_tuple == nullptr || cnode_make_tuple->inputs().size() == tuple_inputs_size) {
        // case 1 and case 3: replace cnode_update_state with cnode_u;
        MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                      << " ::TO:: " << cnode_u->DebugString();
        manager->Replace(cnode_update_state, cnode_u);
      } else if (cnode_make_tuple->inputs().size() > tuple_inputs_size) {
        // case 2: remove cnode_load from cnode_make_tuple;
        MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString();
        const auto &make_tuple_inputs = cnode_make_tuple->inputs();
        AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1);
        std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
                     [cnode_load](const auto &inp) { return inp != cnode_load; });
        auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs);
        manager->Replace(cnode_make_tuple, new_cnode_make_tuple);
      } else {
        MS_LOG(EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                          << " as make_tuple CNode cannot match " << cnode_make_tuple->DebugString();
      }
    }
  }
}

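// Build one fused communication op whose inputs are the concatenated inputs of
// all ops in [start_index, end_index], and copy the attrs, inferred output
// types/shapes, and kernel build info from the member ops onto it.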
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
                                                             const CommunicationOpInfo &communication_op_info,
                                                             size_t start_index, size_t end_index) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto prim = std::make_shared<Primitive>(op_name_);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
  // get all inputs of current segment
  if (end_index >= communication_op_info.communication_op_nodes.size()) {
    MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
  }
  for (size_t idx = start_index; idx <= end_index; ++idx) {
    auto cnode = communication_op_info.communication_op_nodes[idx];
    MS_EXCEPTION_IF_NULL(cnode);
    if (idx != start_index) {
      AdjustAllReduceInputWithLoad(cnode);
    }
    fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  }
  CheckInputs(fusion_inputs);
  AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
  MS_EXCEPTION_IF_NULL(fused_node);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_node->set_kernel_info(kernel_info);
  auto final_node = communication_op_info.communication_op_nodes[end_index];
  size_t node_num = end_index - start_index + 1;
  int64_t rank_size = 1;
  if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
    rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
  }
  size_t rank_size_t = LongToSize(rank_size);
  if (rank_size_t == 0) {
    MS_LOG(EXCEPTION) << "Rank size should not be zero.";
  }
  size_t output_num = node_num * rank_size_t;
  std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
  std::vector<std::vector<size_t>> shapes;
  int64_t fusion_total_size = 0;
  for (size_t i = 0; i < rank_size_t; ++i) {
    for (size_t idx = start_index; idx <= end_index; ++idx) {
      auto input_node = communication_op_info.communication_op_nodes[idx];
      MS_EXCEPTION_IF_NULL(input_node);
      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(input_node, 0);
      if (!shape.empty()) {
        shape[0] /= rank_size_t;
      }
      shapes.push_back(shape);
      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
      TypeId output_type = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
      size_t type_size = GetTypeByte(TypeIdToType(output_type));
      if (type_size == 0) {
        MS_LOG(EXCEPTION) << "Divisor 'type_size' should not be 0.";
      }
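      // Pad the tensor's byte size up to the next kAlignSize (1024-byte)
      // boundary (a full extra block is added even when already aligned),
      // then divide by the element size to express it as an element count.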
      tensor_size = (tensor_size / kAlignSize + 1) * kAlignSize / type_size;
      fusion_total_size += static_cast<int64_t>(tensor_size);
    }
  }
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
  auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
  const std::vector<std::string> kHcclFusionAttrs = {kAttrFusion, kAttrGroup,    kAttrGroupBack,
                                                     kAttrSrTag,  kAttrDestRank, kAttrSrcRank,
                                                     kAttrDType,  kAttrOp,       kAttrRankSize};
  for (const auto &attr : kHcclFusionAttrs) {
    if (AnfAlgo::HasNodeAttr(attr, final_node)) {
      AnfAlgo::CopyNodeAttr(attr, final_node, fused_node);
    }
  }
  if (AnfAlgo::HasNodeAttr(kAttrShape, final_node)) {
    std::vector<int64_t> fusion_total_shape{fusion_total_size};
    AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(fusion_total_shape), fused_node);
  }
  bool is_recompute =
    final_node->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(final_node->GetAttr(kAttrDuplicated));
  if (AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName && is_recompute) {
    auto fused_cnode = fused_node->cast<CNodePtr>();
    fused_cnode->AddAttr("duplicated", MakeValue(true));
    auto fused_prim = GetCNodePrimitive(fused_cnode);
    auto final_node_prim = GetCNodePrimitive(final_node);
    fused_prim->set_instance_name(final_node_prim->instance_name());
  }
  return fused_node;
}

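// Replace each segment of communication ops with one fused op, rewiring every
// original op's users to TupleGetItem(fused_op, offset). Single-op segments
// are left unchanged.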
bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
                                     size_t segment_num, const std::vector<size_t> &segment_index) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  bool changed = false;
  size_t start_index = 0;
  for (size_t segment_idx = 0; segment_idx < segment_num; ++segment_idx) {
    size_t end_index = segment_index.at(segment_idx);
    if (end_index - start_index < 1) {
      start_index = end_index + 1;
      continue;
    }
    auto kernel_graph = func_graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto graph_id = kernel_graph->graph_id();
    AnfNodePtr new_communication_op =
      CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
    AnfAlgo::SetGraphId(graph_id, new_communication_op.get());
    // replace old communication op with new communication op
    for (auto idx = start_index; idx <= end_index; ++idx) {
      std::vector<AnfNodePtr> tuple_getitem_input;
      tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
      tuple_getitem_input.push_back(new_communication_op);
      auto offset = SizeToLong(idx - start_index);
      auto index = NewValueNode(offset);
      MS_EXCEPTION_IF_NULL(index);
      auto imm = std::make_shared<Int64Imm>(idx - start_index);
      MS_EXCEPTION_IF_NULL(imm);
      auto abstract_scalar = std::make_shared<abstract::AbstractScalar>();
      MS_EXCEPTION_IF_NULL(abstract_scalar);
      index->set_abstract(abstract_scalar);
      tuple_getitem_input.push_back(index);
      AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
      MS_EXCEPTION_IF_NULL(communication_op_node_item);
      tuple_getitem->set_abstract(communication_op_node_item->abstract());
      if (kernel_graph->IsInternalOutput(communication_op_node_item, 0)) {
        kernel_graph->ReplaceInternalOutput(communication_op_node_item, new_communication_op, 0, LongToSize(offset));
      }
      if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
        MS_LOG(EXCEPTION) << "manager replace node failed";
      }
    }
    start_index = end_index + 1;
    changed = true;
  }
  return changed;
}

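// Pass entry point: collect all communication ops matching op_name_, bucket
// them by fusion-group key, sort each bucket by the index attr when present,
// then split every bucket into segments and fuse each one.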
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const float input_grad_size_num = 0.0;
  const float input_grad_time_num = 0.0;
  // divide candidates into fusion groups keyed by the same (group, op, fusion, dtype) attrs; fusion == 0 means no fusion
  std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
      std::string key = GetFusionGroupKey(node);
      if (key.empty()) {
        continue;
      }
      if (candidate_groups.find(key) == candidate_groups.end()) {
        CommunicationOpInfo communication_op_info;
        candidate_groups[key] = communication_op_info;
      }
      candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
      candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
      candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
    }
  }
  // split each candidate group into segments according to the groups_ class member
  bool changed = false;
  for (auto &it : candidate_groups) {
    if (it.second.communication_op_nodes.size() <= 1) {
      continue;
    }
    auto first_node = it.second.communication_op_nodes[0];
    TraceGuard guard(std::make_shared<TraceOpt>(first_node->debug_info()));
    if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int64_t>(first_node, kAttrIndex) > 0) {
      std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
                       [](const CNodePtr &a, const CNodePtr &b) {
                         return AnfAlgo::GetNodeAttr<int64_t>(a, kAttrIndex) <
                                AnfAlgo::GetNodeAttr<int64_t>(b, kAttrIndex);
                       });
    }
    size_t segment_num = 0;
    std::vector<size_t> segment_index;
    if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
      if (DoFusion(func_graph, it.second, segment_num, segment_index)) {
        changed = true;
      }
    }
  }
  return changed;
}
}  // namespace opt
}  // namespace mindspore