1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/pass/communication_op_fusion.h"
17
18 #include <vector>
19 #include <set>
20 #include <memory>
21 #include <unordered_map>
22
23 #include "ir/graph_utils.h"
24 #include "base/core_ops.h"
25 #include "runtime/device/kernel_info.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "backend/kernel_compiler/kernel_build_info.h"
28 #include "frontend/parallel/context.h"
29
30 namespace mindspore {
31 namespace opt {
32 namespace {
// Fallback values used in the fusion-group key when the op carries no
// "group" / "op" attribute.
constexpr auto kAttrDefaultGroup = "default_group";
constexpr auto kAttrDefaultOp = "default_op";
// 2 << 9 == 1024: byte alignment applied to every tensor block when sizing
// the fused communication buffer.
constexpr size_t kAlignSize = 2 << 9;
36
GenerateKernelBuildInfo(const CommunicationOpInfo & communication_op_info,size_t start_index,size_t end_index)37 kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
38 size_t end_index) {
39 if (end_index >= communication_op_info.communication_op_nodes.size()) {
40 MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
41 }
42 std::vector<std::string> inputs_device_format;
43 std::vector<std::string> outputs_device_format;
44 std::vector<TypeId> inputs_device_type;
45 std::vector<TypeId> outputs_device_type;
46 std::vector<std::vector<size_t>> outputs_shape;
47 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
48 for (size_t idx = start_index; idx <= end_index; ++idx) {
49 auto cnode = communication_op_info.communication_op_nodes[idx];
50 int64_t rank_size = 1;
51 if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
52 rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
53 }
54 size_t rank_size_t = LongToSize(rank_size);
55 if (rank_size_t == 0) {
56 MS_LOG(EXCEPTION) << "Rank size should not be zero.";
57 }
58 MS_EXCEPTION_IF_NULL(cnode);
59 size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
60 for (size_t input_index = 0; input_index < input_num; ++input_index) {
61 inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
62 inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
63 }
64 for (size_t rank_index = 0; rank_index < rank_size_t; ++rank_index) {
65 size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
66 for (size_t output_index = 0; output_index < output_num; ++output_index) {
67 outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
68 outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
69 std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
70 if (!shape.empty()) {
71 shape[0] /= rank_size_t;
72 }
73 outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
74 }
75 }
76 builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
77 builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
78 builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
79 }
80 builder.SetInputsFormat(inputs_device_format);
81 builder.SetOutputsFormat(outputs_device_format);
82 builder.SetInputsDeviceType(inputs_device_type);
83 builder.SetOutputsDeviceType(outputs_device_type);
84 return builder.Build();
85 }
86
GetFusionGroupKey(const AnfNodePtr & node)87 std::string GetFusionGroupKey(const AnfNodePtr &node) {
88 auto primitive = AnfAlgo::GetCNodePrimitive(node);
89 MS_EXCEPTION_IF_NULL(primitive);
90 ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
91 if (attr_fusion == nullptr) {
92 return "";
93 }
94 auto fusion = GetValue<int64_t>(attr_fusion);
95 if (fusion == 0) {
96 return "";
97 }
98 std::string group = kAttrDefaultGroup;
99 ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
100 if (attr_group != nullptr) {
101 group = GetValue<std::string>(attr_group);
102 }
103 std::string op = kAttrDefaultOp;
104 ValuePtr attr_op = primitive->GetAttr(kAttrOp);
105 if (attr_op != nullptr) {
106 op = GetValue<std::string>(attr_op);
107 }
108 auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
109 return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
110 }
111
CheckInputs(const std::vector<AnfNodePtr> & fusion_inputs)112 void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
113 std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
114 if (inputs_set.size() < fusion_inputs.size()) {
115 MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
116 }
117 }
118
CheckSegments(size_t segments,size_t communication_op_node_size,const std::vector<size_t> * segment_index)119 bool CheckSegments(size_t segments, size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
120 MS_EXCEPTION_IF_NULL(segment_index);
121 if (segments >= communication_op_node_size) {
122 MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
123 << ", communication_op_node_size=" << communication_op_node_size;
124 return false;
125 }
126 if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
127 MS_LOG(EXCEPTION) << "the last segment index is invalid.";
128 }
129 for (size_t i = 0; i < segments - 1; ++i) {
130 if (segment_index->at(i) > segment_index->at(i + 1)) {
131 MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ "
132 << (i + 1) << "]=" << segment_index->at(i + 1);
133 }
134 }
135 return true;
136 }
137 } // namespace
138
GetSplitSegments(const CommunicationOpInfo & communication_op_info,size_t * segment_num,std::vector<size_t> * segment_index,const std::string & group) const139 bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
140 std::vector<size_t> *segment_index, const std::string &group) const {
141 MS_EXCEPTION_IF_NULL(segment_num);
142 MS_EXCEPTION_IF_NULL(segment_index);
143 size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
144 MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;
145
146 if (op_name_ == kHcomSendOpName || op_name_ == kReceiveOpName) {
147 *segment_num = 1;
148 if (communication_op_node_size == 0) {
149 return false;
150 }
151 (void)segment_index->emplace_back(communication_op_node_size - 1);
152 return true;
153 }
154
155 auto parallel_context = parallel::ParallelContext::GetInstance();
156 MS_EXCEPTION_IF_NULL(parallel_context);
157 std::vector<uint32_t> split_indices;
158 if (!parallel_context->enable_parallel_optimizer()) {
159 split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
160 }
161
162 size_t segments = 0;
163 if (!split_indices.empty()) {
164 uint32_t last_index = 0;
165 for (size_t i = 0; i < split_indices.size(); ++i) {
166 uint32_t index = split_indices[i];
167 if (index <= last_index && i != 0) {
168 MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
169 }
170 if (index >= communication_op_node_size) {
171 MS_LOG(WARNING) << op_name_ << "'s split index " << index
172 << " is Greater than or equal to total gradient's number " << communication_op_node_size;
173 continue;
174 }
175 segment_index->push_back(index);
176 last_index = index;
177 segments++;
178 }
179 if (last_index != communication_op_node_size - 1) {
180 segment_index->push_back(communication_op_node_size - 1);
181 segments++;
182 }
183 } else {
184 segments = groups_;
185 for (size_t i = 0; i < segments - 1; ++i) {
186 segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
187 }
188 segment_index->push_back(communication_op_node_size - 1);
189 }
190
191 *segment_num = segments;
192 return CheckSegments(segments, communication_op_node_size, segment_index);
193 }
194
195 // Hard coded Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent
196 // cycle after AllReduce fused. It's a workaround.
197 // case 1:
198 // cnode_load = Load(%para2, cnode_u)
199 // %100 = UpdateState(cnode_u, cnode_load)
200 // ...
201 // %109 = AssignAdd(%para485, Tensor(34), %100)
202 // %110 = UpdateState(%100, xxx)
203 // will convert to:
204 // cnode_load = Load(%para2, U)
205 // ...
206 // %109 = AssignAdd(%para485, Tensor(34), cnode_u)
207 // %110 = UpdateState(cnode_u, xxx)
208 //
209 // case 2:
210 // cnode_load = Load(%para2, cnode_u)
211 // %99 = make_tuple(yyy, ..., cnode_load, ...)
212 // %100 = UpdateState(cnode_u, %99)
213 // ...
214 // %109 = AssignAdd(%para485, Tensor(34), %100)
215 // %110 = UpdateState(%100, xxx)
216 // will convert to:
217 // cnode_load = Load(%para2, U)
218 // %99 = make_tuple(yyy, ...)
219 // %100 = UpdateState(cnode_u, %99)
220 // ...
221 // %109 = AssignAdd(%para485, Tensor(34), %100)
222 // %110 = UpdateState(%100, xxx)
223 //
224 // case 3:
225 // cnode_load = Load(%para2, cnode_u)
226 // %99 = make_tuple(cnode_load)
227 // %100 = UpdateState(cnode_u, %99)
228 // ...
229 // %109 = AssignAdd(%para485, Tensor(34), %100)
230 // %110 = UpdateState(%100, xxx)
231 // will convert to:
232 // cnode_load = Load(%para2, U)
233 // ...
234 // %109 = AssignAdd(%para485, Tensor(34), cnode_u)
235 // %110 = UpdateState(cnode_u, xxx)
// Workaround described in the case analysis above: find the first reachable
// Load(%para, cnode_u) feeding `cnode`, rewrite it to Load(%para, U), and
// patch the matching UpdateState so that fusing AllReduce cannot introduce a
// cycle in the graph.
static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) {
  const size_t monad_index = 2;        // input slot of the monad in Load/UpdateState
  const size_t tuple_inputs_size = 2;  // make_tuple with exactly one element (case 3)
  const size_t load_inputs_size = 3;   // Load has (primitive, parameter, monad)
  // BFS from `cnode` for the first Load whose monad input is a CNode,
  // i.e. the Load(%para, cnode_u) pattern of the cases above.
  auto cnode_load = BroadFirstSearchFirstOf({cnode}, [](const CNodePtr &search_cnode) {
    if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) {
      return false;
    }
    if (search_cnode->inputs().size() != load_inputs_size) {
      MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString();
    }
    return search_cnode->input(monad_index)->isa<CNode>();
  });
  if (cnode_load != nullptr) {
    // Replace the Load's CNode monad input with a constant U monad.
    auto const_u_monad = NewValueNode(kUMonad);
    const_u_monad->set_abstract(kUMonad->ToAbstract());
    const auto &cnode_u = cnode_load->input(monad_index);
    MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString();
    MS_EXCEPTION_IF_NULL(cnode->func_graph());
    MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
    auto manager = cnode->func_graph()->manager();
    manager->SetEdge(cnode_load, monad_index, const_u_monad);
    // Update the u_monad input of UpdateState from CNode U same as Load to constant U.
    CNodePtr cnode_update_state = nullptr;
    CNodePtr cnode_make_tuple = nullptr;
    const auto &cnode_load_users = manager->node_users()[cnode_load];
    for (auto &load_user : cnode_load_users) {
      // Cases 2/3: the Load feeds a make_tuple that feeds UpdateState(cnode_u, ...).
      if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
        const auto &cnode_make_tuple_users = manager->node_users()[load_user.first];
        for (auto &make_tuple_user : cnode_make_tuple_users) {
          if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) {
            const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>();
            if (cnode_user->input(1) == cnode_u) {
              cnode_update_state = cnode_user;
              cnode_make_tuple = load_user.first->cast<CNodePtr>();
              break;
            }
          }
        }
        if (cnode_update_state != nullptr) {
          break;
        }
      }
      // Case 1: the Load feeds UpdateState(cnode_u, cnode_load) directly.
      if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
        const auto &cnode_user = load_user.first->cast<CNodePtr>();
        if (cnode_user->input(1) == cnode_u) {
          cnode_update_state = cnode_user;
          break;
        }
      }
    }
    if (cnode_update_state != nullptr) {
      if (cnode_make_tuple == nullptr || cnode_make_tuple->inputs().size() == tuple_inputs_size) {
        // case 1 and case 3: Replace cnode_update_state to cnode_u;
        MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                      << " ::TO:: " << cnode_u->DebugString();
        manager->Replace(cnode_update_state, cnode_u);
      } else if (cnode_make_tuple->inputs().size() > tuple_inputs_size) {
        // case 2: remove cnode_load from cnode_make_tuple;
        MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString();
        const auto &make_tuple_inputs = cnode_make_tuple->inputs();
        AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1);
        std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
                     [cnode_load](const auto &inp) { return inp != cnode_load; });
        auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs);
        manager->Replace(cnode_make_tuple, new_cnode_make_tuple);
      } else {
        MS_LOG(EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                          << " as make_tuple CNode cannot match " << cnode_make_tuple->DebugString();
      }
    }
  }
}
309
CreateFusedCommunicationOp(const FuncGraphPtr & func_graph,const CommunicationOpInfo & communication_op_info,size_t start_index,size_t end_index) const310 AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
311 const CommunicationOpInfo &communication_op_info,
312 size_t start_index, size_t end_index) const {
313 MS_EXCEPTION_IF_NULL(func_graph);
314 auto prim = std::make_shared<Primitive>(op_name_);
315 MS_EXCEPTION_IF_NULL(prim);
316 std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
317 // get all inputs of current segment
318 if (end_index >= communication_op_info.communication_op_nodes.size()) {
319 MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
320 }
321 for (size_t idx = start_index; idx <= end_index; ++idx) {
322 auto cnode = communication_op_info.communication_op_nodes[idx];
323 MS_EXCEPTION_IF_NULL(cnode);
324 if (idx != start_index) {
325 AdjustAllReduceInputWithLoad(cnode);
326 }
327 fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
328 }
329 CheckInputs(fusion_inputs);
330 AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
331 MS_EXCEPTION_IF_NULL(fused_node);
332 auto kernel_info = std::make_shared<device::KernelInfo>();
333 MS_EXCEPTION_IF_NULL(kernel_info);
334 fused_node->set_kernel_info(kernel_info);
335 auto final_node = communication_op_info.communication_op_nodes[end_index];
336 size_t node_num = end_index - start_index + 1;
337 int64_t rank_size = 1;
338 if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
339 rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
340 }
341 size_t rank_size_t = LongToSize(rank_size);
342 if (rank_size_t == 0) {
343 MS_LOG(EXCEPTION) << "Rank size should not be zero.";
344 }
345 size_t output_num = node_num * rank_size_t;
346 std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
347 std::vector<std::vector<size_t>> shapes;
348 int64_t fusion_total_size = 0;
349 for (size_t i = 0; i < rank_size_t; ++i) {
350 for (size_t idx = start_index; idx <= end_index; ++idx) {
351 auto input_node = communication_op_info.communication_op_nodes[idx];
352 MS_EXCEPTION_IF_NULL(input_node);
353 std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(input_node, 0);
354 if (!shape.empty()) {
355 shape[0] /= rank_size_t;
356 }
357 shapes.push_back(shape);
358 size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
359 TypeId output_type = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
360 size_t type_size = GetTypeByte(TypeIdToType(output_type));
361 if (type_size == 0) {
362 MS_LOG(EXCEPTION) << "Divisor 'type_size' should not be 0.";
363 }
364 tensor_size = (tensor_size / kAlignSize + 1) * kAlignSize / type_size;
365 fusion_total_size += static_cast<int64_t>(tensor_size);
366 }
367 }
368 AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
369 auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
370 AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
371 const std::vector<std::string> kHcclFusionAttrs = {kAttrFusion, kAttrGroup, kAttrGroupBack,
372 kAttrSrTag, kAttrDestRank, kAttrSrcRank,
373 kAttrDType, kAttrOp, kAttrRankSize};
374 for (const auto &attr : kHcclFusionAttrs) {
375 if (AnfAlgo::HasNodeAttr(attr, final_node)) {
376 AnfAlgo::CopyNodeAttr(attr, final_node, fused_node);
377 }
378 }
379 if (AnfAlgo::HasNodeAttr(kAttrShape, final_node)) {
380 std::vector<int64_t> fusion_total_shape{fusion_total_size};
381 AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(fusion_total_shape), fused_node);
382 }
383 bool is_recompute =
384 final_node->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(final_node->GetAttr(kAttrDuplicated));
385 if (AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName && is_recompute) {
386 auto fused_cnode = fused_node->cast<CNodePtr>();
387 fused_cnode->AddAttr("duplicated", MakeValue(true));
388 auto fused_prim = GetCNodePrimitive(fused_cnode);
389 auto final_node_prim = GetCNodePrimitive(final_node);
390 fused_prim->set_instance_name(final_node_prim->instance_name());
391 }
392 return fused_node;
393 }
394
DoFusion(const FuncGraphPtr & func_graph,const CommunicationOpInfo & communication_op_info,size_t segment_num,const std::vector<size_t> & segment_index) const395 bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
396 size_t segment_num, const std::vector<size_t> &segment_index) const {
397 MS_EXCEPTION_IF_NULL(func_graph);
398 auto manager = func_graph->manager();
399 MS_EXCEPTION_IF_NULL(manager);
400 bool changed = false;
401 size_t start_index = 0;
402 for (size_t segment_idx = 0; segment_idx < segment_num; ++segment_idx) {
403 size_t end_index = segment_index.at(segment_idx);
404 if (end_index - start_index < 1) {
405 start_index = end_index + 1;
406 continue;
407 }
408 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
409 MS_EXCEPTION_IF_NULL(kernel_graph);
410 auto graph_id = kernel_graph->graph_id();
411 AnfNodePtr new_communication_op =
412 CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
413 AnfAlgo::SetGraphId(graph_id, new_communication_op.get());
414 // replace old communication op with new communication op
415 for (auto idx = start_index; idx <= end_index; ++idx) {
416 std::vector<AnfNodePtr> tuple_getitem_input;
417 tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
418 tuple_getitem_input.push_back(new_communication_op);
419 auto offset = SizeToLong(idx - start_index);
420 auto index = NewValueNode(offset);
421 MS_EXCEPTION_IF_NULL(index);
422 auto imm = std::make_shared<Int64Imm>(idx - start_index);
423 MS_EXCEPTION_IF_NULL(imm);
424 auto abstract_scalar = std::make_shared<abstract::AbstractScalar>();
425 MS_EXCEPTION_IF_NULL(abstract_scalar);
426 index->set_abstract(abstract_scalar);
427 tuple_getitem_input.push_back(index);
428 AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
429 MS_EXCEPTION_IF_NULL(tuple_getitem);
430 auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
431 MS_EXCEPTION_IF_NULL(communication_op_node_item);
432 tuple_getitem->set_abstract(communication_op_node_item->abstract());
433 if (kernel_graph->IsInternalOutput(communication_op_node_item, 0)) {
434 kernel_graph->ReplaceInternalOutput(communication_op_node_item, new_communication_op, 0, LongToSize(offset));
435 }
436 if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
437 MS_LOG(EXCEPTION) << "manager replace node failed";
438 }
439 }
440 start_index = end_index + 1;
441 changed = true;
442 }
443 return changed;
444 }
445
Run(const FuncGraphPtr & func_graph)446 bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
447 MS_EXCEPTION_IF_NULL(func_graph);
448 const float input_grad_size_num = 0.0;
449 const float input_grad_time_num = 0.0;
450 // divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
451 std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
452 std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
453 for (auto &node : node_list) {
454 if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
455 std::string key = GetFusionGroupKey(node);
456 if (key.empty()) {
457 continue;
458 }
459 if (candidate_groups.find(key) == candidate_groups.end()) {
460 CommunicationOpInfo communication_op_info;
461 candidate_groups[key] = communication_op_info;
462 }
463 candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
464 candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
465 candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
466 }
467 }
468 // split candidate group to segments according to _group class member
469 bool changed = false;
470 for (auto &it : candidate_groups) {
471 if (it.second.communication_op_nodes.size() <= 1) {
472 continue;
473 }
474 auto first_node = it.second.communication_op_nodes[0];
475 TraceGuard guard(std::make_shared<TraceOpt>(first_node->debug_info()));
476 if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int64_t>(first_node, kAttrIndex) > 0) {
477 std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
478 [](const CNodePtr &a, const CNodePtr &b) {
479 return AnfAlgo::GetNodeAttr<int64_t>(a, kAttrIndex) <
480 AnfAlgo::GetNodeAttr<int64_t>(b, kAttrIndex);
481 });
482 }
483 size_t segment_num = 0;
484 std::vector<size_t> segment_index;
485 if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
486 if (DoFusion(func_graph, it.second, segment_num, segment_index)) {
487 changed = true;
488 }
489 }
490 }
491 return changed;
492 }
493 } // namespace opt
494 } // namespace mindspore
495