• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/pass/communication_op_fusion.h"
17 
18 #include <memory>
19 #include <set>
20 #include <vector>
21 #include <queue>
22 
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/backend/kernel_info.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/common/utils/parallel_context.h"
28 #include "ir/graph_utils.h"
29 #include "kernel/kernel_build_info.h"
30 #include "ops/framework_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "utils/hash_map.h"
33 #include "ir/manager.h"
34 
35 namespace mindspore {
36 namespace opt {
37 namespace {
38 constexpr auto kAttrDefaultGroup = "default_group";
39 constexpr auto kAttrDefaultOp = "default_op";
40 constexpr auto kAttrCommZone = "comm_fusion_zone";
41 constexpr size_t kAlignSize = 2 << 9;
42 constexpr int64_t kDefaultThresholdMb2Byte = 262144;
43 
GenerateKernelBuildInfo(const CommunicationOpInfo & communication_op_info,size_t start_index,size_t end_index)44 kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
45                                                    size_t end_index) {
46   if (end_index >= communication_op_info.communication_op_nodes.size()) {
47     MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
48   }
49   std::vector<std::string> inputs_device_format;
50   std::vector<std::string> outputs_device_format;
51   std::vector<TypeId> inputs_device_type;
52   std::vector<TypeId> outputs_device_type;
53   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
54   for (size_t idx = start_index; idx <= end_index; ++idx) {
55     auto cnode = communication_op_info.communication_op_nodes[idx];
56     int64_t rank_size = 1;
57     if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) &&
58         common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
59       rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
60     }
61     if (rank_size == 0) {
62       MS_LOG(EXCEPTION) << "Rank size should not be zero.";
63     }
64     MS_EXCEPTION_IF_NULL(cnode);
65     size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
66     for (size_t input_index = 0; input_index < input_num; ++input_index) {
67       inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
68       inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
69     }
70     for (int64_t rank_index = 0; rank_index < rank_size; ++rank_index) {
71       size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
72       for (size_t output_index = 0; output_index < output_num; ++output_index) {
73         outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
74         outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
75       }
76     }
77     builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
78     builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
79     builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
80   }
81   builder.SetInputsFormat(inputs_device_format);
82   builder.SetOutputsFormat(outputs_device_format);
83   builder.SetInputsDeviceType(inputs_device_type);
84   builder.SetOutputsDeviceType(outputs_device_type);
85   return builder.Build();
86 }
87 
// Compute the bucketing key used to decide which communication ops may be
// fused together: "<group><op><fusion><dtype>". An empty return value means
// the node must not be fused (it carries no kAttrFusion, or fusion == 0).
// NOTE: under fold-pipeline this function also REWRITES kAttrFusion on the
// node's primitive as a side effect (fusion id derived from kAttrSegment).
std::string GetFusionGroupKey(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto primitive = common::AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
  if (attr_fusion == nullptr) {
    return "";
  }
  auto fusion = GetValue<int64_t>(attr_fusion);
  if (fusion == 0) {
    // fusion == 0 is the explicit "do not fuse" marker.
    return "";
  }
  auto parallel_context = parallel::ParallelContext::GetInstance();
  if (parallel_context->enable_fold_pipeline()) {
    // Fold-pipeline mode: remap the fusion id from the node's pipeline
    // segment so ops belonging to different segments land in different
    // fusion buckets.
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
    auto prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);
    if (cnode_name == kAllReduceOpName) {
      if (prim->HasAttr(kAttrSegment)) {
        auto segment_info = GetValue<int64_t>(prim->GetAttr(kAttrSegment));
        MS_LOG(INFO) << "Cnode : " << cnode->fullname_with_scope() << ", instance_name: " << prim->instance_name()
                     << ", segment: " << segment_info;
        // AllReduce: fusion id = segment + 2 (the +2 offset presumably keeps
        // ids 0/1 reserved -- TODO confirm against the pipeline design).
        fusion = segment_info + 2;
        (void)prim->AddAttr(kAttrFusion, MakeValue(std::make_shared<Int64Imm>(fusion)));
        MS_LOG(INFO) << "Now cnode : " << cnode->fullname_with_scope()
                     << ", fusion: " << GetValue<int64_t>(prim->GetAttr(kAttrFusion));
      }
    }
    if (cnode_name == kAllGatherOpName) {
      if (prim->HasAttr(kAttrSegment)) {
        auto segment_info = GetValue<int64_t>(prim->GetAttr(kAttrSegment));
        MS_LOG(INFO) << "Cnode : " << cnode->fullname_with_scope() << ", instance_name: " << prim->instance_name()
                     << ", segment: " << segment_info;
        if (segment_info != 0) {
          // AllGather of non-zero segments: offset by 100 so these buckets
          // never collide with the AllReduce buckets computed above.
          int64_t fusion_interval = 100;
          fusion = segment_info + fusion_interval;
          (void)prim->AddAttr(kAttrFusion, MakeValue(std::make_shared<Int64Imm>(fusion)));
        }
        MS_LOG(INFO) << "Cnode : " << cnode->fullname_with_scope()
                     << ", fusion: " << GetValue<int64_t>(prim->GetAttr(kAttrFusion));
      }
    }
  }

  // Fall back to the "default_group"/"default_op" placeholders when the node
  // carries no explicit group/op attribute.
  std::string group = kAttrDefaultGroup;
  ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
  if (attr_group != nullptr) {
    group = GetValue<std::string>(attr_group);
  }
  std::string op = kAttrDefaultOp;
  ValuePtr attr_op = primitive->GetAttr(kAttrOp);
  if (attr_op != nullptr) {
    op = GetValue<std::string>(attr_op);
  }
  // Include the input dtype in the key so ops with different element types
  // are never fused into one kernel.
  auto dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
}
147 
CheckInputs(const std::vector<AnfNodePtr> & fusion_inputs)148 void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
149   std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
150   if (inputs_set.size() < fusion_inputs.size()) {
151     MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
152   }
153 }
154 
CheckSegments(size_t communication_op_node_size,const std::vector<size_t> * segment_index)155 bool CheckSegments(size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
156   MS_EXCEPTION_IF_NULL(segment_index);
157   auto segments = segment_index->size();
158   if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
159     MS_LOG(EXCEPTION) << "the last segment index is invalid.";
160   }
161   for (size_t i = 0; i < segments - 1; ++i) {
162     if (segment_index->at(i) > segment_index->at(i + 1)) {
163       MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ "
164                         << (i + 1) << "]=" << segment_index->at(i + 1);
165     }
166   }
167   return true;
168 }
169 
GetNodeCommZoneId(const CNodePtr & cnode)170 uint32_t GetNodeCommZoneId(const CNodePtr &cnode) {
171   MS_EXCEPTION_IF_NULL(cnode);
172   if (cnode->HasAttr(kAttrCommZone)) {
173     return GetValue<uint32_t>(cnode->GetAttr(kAttrCommZone));
174   }
175   return 0;
176 }
177 
// Walk the graph backwards (BFS from the return node towards inputs) and tag
// every CNode with a "comm zone" id via kAttrCommZone. The id grows by one
// each time the walk crosses a fusible communication op named comm_op_name,
// so two such comm ops separated by a dependency chain end up in different
// zones and are kept in separate fusion buckets (fusing them could create a
// cycle).
void MarkCommunicationZone(const FuncGraphPtr &func_graph, const string &comm_op_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::queue<AnfNodePtr> to_visit;
  to_visit.emplace(func_graph->get_return());
  auto seen = NewSeenGeneration();
  while (!to_visit.empty()) {
    auto node = to_visit.front();
    to_visit.pop();
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto zone_id = GetNodeCommZoneId(cnode);
    for (auto &input : cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      if (!input->isa<CNode>()) {
        continue;
      }
      auto input_cnode = input->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(input_cnode);
      auto input_zone_id = GetNodeCommZoneId(input_cnode);
      auto update_zone_id = zone_id;
      // Crossing a fusible comm op of the target type starts a new zone for
      // everything upstream of it.
      if (common::AnfAlgo::GetCNodeName(input_cnode) == comm_op_name && common::AnfAlgo::IsFusion(input_cnode)) {
        update_zone_id += 1;
      }
      // Skip only when the node was already visited this generation AND its
      // stored id is already >= the candidate: a node's final zone id is the
      // maximum over all paths from the return node, so a larger candidate
      // forces a re-visit.
      if (input_zone_id >= update_zone_id && input->seen_ == seen) {
        continue;
      }
      input_cnode->AddAttr(kAttrCommZone, MakeValue(update_zone_id));
      to_visit.emplace(input);
      input->seen_ = seen;
    }
  }
}
214 
RemoveCommunicationZone(const FuncGraphPtr & func_graph)215 void RemoveCommunicationZone(const FuncGraphPtr &func_graph) {
216   MS_EXCEPTION_IF_NULL(func_graph);
217   auto seen = NewSeenGeneration();
218   std::queue<AnfNodePtr> to_visit;
219   to_visit.emplace(func_graph->get_return());
220   while (!to_visit.empty()) {
221     auto node = to_visit.front();
222     to_visit.pop();
223     MS_EXCEPTION_IF_NULL(node);
224     if (!node->isa<CNode>()) {
225       continue;
226     }
227     auto cnode = node->cast<CNodePtr>();
228     MS_EXCEPTION_IF_NULL(cnode);
229     cnode->EraseAttr(kAttrCommZone);
230     for (auto &input : cnode->inputs()) {
231       MS_EXCEPTION_IF_NULL(input);
232       if (!input->isa<CNode>()) {
233         continue;
234       }
235       if (input->seen_ == seen) {
236         continue;
237       }
238       to_visit.emplace(input);
239       input->seen_ = seen;
240     }
241   }
242 }
243 }  // namespace
244 
// Decide where to cut the (sorted) communication op list into fusion
// segments. Each value pushed to segment_index is the INCLUSIVE index of the
// last op of one segment. Returns false only when there is nothing to split;
// otherwise validates the result via CheckSegments (which throws on error).
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info,
                                             std::vector<size_t> *segment_index, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(segment_index);
  size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
  MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;

  // Send/Receive ops are always fused into one single segment.
  if (op_name_ == kSendOpName || op_name_ == kReceiveOpName) {
    if (communication_op_node_size == 0) {
      return false;
    }
    (void)segment_index->emplace_back(communication_op_node_size - 1);
    return true;
  }

  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  std::vector<uint32_t> split_indices;
  // User-configured split indices are ignored when the parallel optimizer is
  // enabled (it has its own partitioning).
  if (!parallel_context->enable_parallel_optimizer()) {
    split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
  }

  if (!split_indices.empty()) {
    // Honor explicit user indices: they must be strictly increasing (after
    // the first), and out-of-range values are dropped with a warning.
    uint32_t last_index = 0;
    for (size_t i = 0; i < split_indices.size(); ++i) {
      uint32_t index = split_indices[i];
      if (index <= last_index && i != 0) {
        MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
      }
      if (index >= communication_op_node_size) {
        MS_LOG(WARNING) << op_name_ << "'s split index " << index
                        << " is Greater than or equal to total gradient's number " << communication_op_node_size;
        continue;
      }
      segment_index->push_back(index);
      last_index = index;
    }
    // Ensure the final segment always reaches the last op.
    if (last_index != communication_op_node_size - 1) {
      segment_index->push_back(communication_op_node_size - 1);
    }
  } else {
    // No explicit indices: split evenly into `groups_` segments.
    for (size_t i = 0; i < groups_ - 1; ++i) {
      segment_index->push_back((i + 1) * (communication_op_node_size / groups_) - 1);
    }
    segment_index->push_back(communication_op_node_size - 1);
  }
  // Pure data-parallel AllReduce: re-cut segments by a data-volume threshold
  // so each fused op stays under the configured size.
  auto parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) {
    auto threshold = parallel_context->dp_fusion_threshold_mb();
    GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index);
    MS_LOG(INFO) << "The split threshold for AllReduce is " << threshold << ", the segment num is "
                 << segment_index->size();
  }
  return CheckSegments(communication_op_node_size, segment_index);
}
299 
// Re-split the existing segments so the accumulated tensor size within each
// resulting segment stays under `threshold`. A negative threshold keeps the
// incoming split untouched (default fusion strategy).
void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold,
                                                     std::vector<size_t> *segment_index) const {
  MS_EXCEPTION_IF_NULL(segment_index);
  if (threshold < 0) {
    MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy.";
    return;
  }
  // Scale the configured threshold to bytes. NOTE(review): the multiplier is
  // 262144 (2^18), i.e. 0.25 * 1024 * 1024 -- confirm the intended unit.
  threshold *= kDefaultThresholdMb2Byte;
  std::vector<size_t> real_segment_index;
  size_t start_index = 0;
  for (auto index : *segment_index) {
    if (index >= nodes.size()) {
      MS_LOG(WARNING) << "split index is greater than or equal to total gradient's number " << nodes.size();
      continue;
    }
    // Greedily accumulate tensor sizes within [start_index, index]; cut a new
    // segment boundary whenever adding the next tensor would cross the
    // threshold (the crossing tensor closes the segment).
    size_t accumulate = 0;
    for (size_t j = start_index; j <= index; ++j) {
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(nodes[j], 0);
      if (accumulate + tensor_size > LongToSize(threshold)) {
        real_segment_index.push_back(j);
        accumulate = 0;
      } else {
        accumulate += tensor_size;
      }
    }
    // Flush the remainder so the original segment boundary is preserved.
    if (accumulate != 0) {
      real_segment_index.push_back(index);
    }
    start_index = index + 1;
  }
  *segment_index = std::move(real_segment_index);
}
332 
333 // Hard coded Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent
334 // cycle after AllReduce fused. It's a workaround.
335 // case 1:
336 // cnode_load = Load(%para2, cnode_u)
337 // %100 = UpdateState(cnode_u, cnode_load)
338 // ...
339 // %109 = AssignAdd(%para485, Tensor(34), %100)
340 // %110 = UpdateState(%100, xxx)
341 // will convert to:
342 // cnode_load = Load(%para2, U)
343 // ...
344 // %109 = AssignAdd(%para485, Tensor(34), cnode_u)
345 // %110 = UpdateState(cnode_u, xxx)
346 //
347 // case 2:
348 // cnode_load = Load(%para2, cnode_u)
349 // %99 = make_tuple(yyy, ..., cnode_load, ...)
350 // %100 = UpdateState(cnode_u, %99)
351 // ...
352 // %109 = AssignAdd(%para485, Tensor(34), %100)
353 // %110 = UpdateState(%100, xxx)
354 // will convert to:
355 // cnode_load = Load(%para2, U)
356 // %99 = make_tuple(yyy, ...)
357 // %100 = UpdateState(cnode_u, %99)
358 // ...
359 // %109 = AssignAdd(%para485, Tensor(34), %100)
360 // %110 = UpdateState(%100, xxx)
361 //
362 // case 3:
363 // cnode_load = Load(%para2, cnode_u)
364 // %99 = make_tuple(cnode_load)
365 // %100 = UpdateState(cnode_u, %99)
366 // ...
367 // %109 = AssignAdd(%para485, Tensor(34), %100)
368 // %110 = UpdateState(%100, xxx)
369 // will convert to:
370 // cnode_load = Load(%para2, U)
371 // ...
372 // %109 = AssignAdd(%para485, Tensor(34), cnode_u)
373 // %110 = UpdateState(cnode_u, xxx)
// Rewrite a reachable Load(%para, cnode_u) input of `cnode` to use the
// constant U monad instead of a CNode monad, then patch the matching
// UpdateState / make_tuple users (three cases, see the comment block above).
// This breaks the monad chain that would otherwise form a cycle once the
// AllReduce ops are fused. Workaround; mutates the graph via the manager.
static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) {
  const size_t monad_index = 2;        // Load(prim, %param, monad): monad is input 2
  const size_t tuple_inputs_size = 2;  // make_tuple(prim, single_element)
  const size_t load_inputs_size = 3;   // Load has exactly 3 inputs
  // Find the first reachable Load whose monad input is still a CNode
  // (i.e. not already the constant U monad).
  auto cnode_load = BroadFirstSearchFirstOf({cnode}, [&](const CNodePtr &search_cnode) {
    if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) {
      return false;
    }
    if (search_cnode->size() != load_inputs_size) {
      MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString();
    }
    return search_cnode->input(monad_index)->isa<CNode>();
  });
  if (cnode_load != nullptr) {
    auto const_u_monad = NewValueNode(kUMonad);
    const_u_monad->set_abstract(kUMonad->ToAbstract());
    const auto &cnode_u = cnode_load->input(monad_index);
    MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString();
    MS_EXCEPTION_IF_NULL(cnode->func_graph());
    MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
    auto manager = cnode->func_graph()->manager();
    // Point the Load at the constant U monad.
    manager->SetEdge(cnode_load, monad_index, const_u_monad);
    // Update the u_monad input of UpdateState from CNode U same as Load to constant U.
    CNodePtr cnode_update_state = nullptr;
    CNodePtr cnode_make_tuple = nullptr;
    const auto &cnode_load_users = manager->node_users()[cnode_load];
    for (auto &load_user : cnode_load_users) {
      // Cases 2/3: Load feeds a make_tuple that feeds UpdateState(cnode_u, ...).
      if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
        const auto &cnode_make_tuple_users = manager->node_users()[load_user.first];
        for (auto &make_tuple_user : cnode_make_tuple_users) {
          if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) {
            const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>();
            if (cnode_user->input(1) == cnode_u) {
              cnode_update_state = cnode_user;
              cnode_make_tuple = load_user.first->cast<CNodePtr>();
              break;
            }
          }
        }
        if (cnode_update_state != nullptr) {
          break;
        }
      }
      // Case 1: Load feeds UpdateState(cnode_u, cnode_load) directly.
      if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
        const auto &cnode_user = load_user.first->cast<CNodePtr>();
        if (cnode_user->input(1) == cnode_u) {
          cnode_update_state = cnode_user;
          break;
        }
      }
    }
    if (cnode_update_state != nullptr) {
      if (cnode_make_tuple == nullptr || cnode_make_tuple->size() == tuple_inputs_size) {
        // case 1 and case 3: Replace cnode_update_state to cnode_u;
        MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                      << " ::TO:: " << cnode_u->DebugString();
        manager->Replace(cnode_update_state, cnode_u);
      } else if (cnode_make_tuple->size() > tuple_inputs_size) {
        // case 2: remove cnode_load from cnode_make_tuple;
        MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString();
        const auto &make_tuple_inputs = cnode_make_tuple->inputs();
        AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1);
        std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
                     [cnode_load](const auto &inp) { return inp != cnode_load; });
        auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs);
        manager->Replace(cnode_make_tuple, new_cnode_make_tuple);
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                                   << " as make_tuple CNode cannot match " << cnode_make_tuple->DebugString();
      }
    }
  }
}
447 
// Build one fused communication CNode covering the segment
// [start_index, end_index]: concatenates all per-node inputs, infers the
// fused output types/shapes (replicated rank_size times for AllGather), and
// copies the relevant HCCL attributes from the segment's last node.
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
                                                             const CommunicationOpInfo &communication_op_info,
                                                             size_t start_index, size_t end_index) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto prim = std::make_shared<Primitive>(op_name_);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
  // get all inputs of current segment
  if (end_index >= communication_op_info.communication_op_nodes.size()) {
    MS_LOG(EXCEPTION) << "End index is out of communication_op_nodes size";
  }
  std::vector<AnfNodePtr> orig_nodes;
  for (size_t idx = start_index; idx <= end_index; ++idx) {
    auto cnode = communication_op_info.communication_op_nodes[idx];
    MS_EXCEPTION_IF_NULL(cnode);
    // All nodes after the first get their Load-monad inputs rewritten to the
    // constant U monad, otherwise fusion could introduce a cycle.
    if (idx != start_index) {
      AdjustAllReduceInputWithLoad(cnode);
    }
    // Skip inputs[0] (the primitive); append the node's real inputs.
    auto inputs = cnode->inputs();
    (void)fusion_inputs.insert(fusion_inputs.cend(), inputs.cbegin() + 1, inputs.cend());
    (void)orig_nodes.emplace_back(cnode);
  }
  CheckInputs(fusion_inputs);
  AnfNodePtr fused_node = NewCNode(fusion_inputs, func_graph, orig_nodes);
  MS_EXCEPTION_IF_NULL(fused_node);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_node->set_kernel_info(kernel_info);
  auto final_node = communication_op_info.communication_op_nodes[end_index];
  size_t node_num = end_index - start_index + 1;
  // AllGather multiplies the output count by rank_size; every other op keeps
  // one output per fused node.
  int64_t rank_size = 1;
  if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) &&
      common::AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
    rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
  }

  if (rank_size == 0) {
    MS_LOG(EXCEPTION) << "Rank size should not be zero.";
  }
  size_t output_num = node_num * LongToSize(rank_size);
  // All fused outputs share the dtype of the last node's first output.
  std::vector<TypeId> dtypes(output_num, common::AnfAlgo::GetOutputInferDataType(final_node, 0));
  std::vector<ShapeVector> shapes;
  int64_t fusion_total_size = 0;
  for (int64_t i = 0; i < rank_size; ++i) {
    for (size_t idx = start_index; idx <= end_index; ++idx) {
      auto input_node = communication_op_info.communication_op_nodes[idx];
      MS_EXCEPTION_IF_NULL(input_node);
      auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
      // Per-rank slice: divide the leading dimension by rank_size.
      if (!shape.empty()) {
        shape[0] /= rank_size;
      }
      shapes.push_back(shape);
      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
      TypeId output_type = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
      size_t type_size = GetTypeByte(TypeIdToType(output_type));
      if (type_size == 0) {
        MS_LOG(EXCEPTION) << "Divisor 'type_size' should not be 0.";
      }
      // Round the tensor up to the next kAlignSize boundary, expressed in
      // element count (bytes / type_size).
      tensor_size = (tensor_size / kAlignSize + 1) * kAlignSize / type_size;
      fusion_total_size += static_cast<int64_t>(tensor_size);
    }
  }
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
  auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
  // Propagate the HCCL-relevant attributes from the segment's last node.
  const std::vector<std::string> kHcclFusionAttrs = {
    kAttrFusion, kAttrGroup, kAttrGroupBack, kAttrSrTag,        kAttrDestRank,           kAttrSrcRank,
    kAttrDType,  kAttrOp,    kAttrRankSize,  kAttrGroupRankIds, kAttrReuseCommunication, kAttrSegment};
  for (const auto &attr : kHcclFusionAttrs) {
    if (common::AnfAlgo::HasNodeAttr(attr, final_node)) {
      common::AnfAlgo::CopyNodeAttr(attr, final_node, fused_node);
    }
  }
  if (common::AnfAlgo::HasNodeAttr(kAttrShape, final_node)) {
    std::vector<int64_t> fusion_total_shape{fusion_total_size};
    common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(fusion_total_shape), fused_node);
  }
  // Recomputed (duplicated) AllGather keeps its marker and instance name on
  // the fused node.
  bool is_recompute =
    final_node->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(final_node->GetAttr(kAttrDuplicated));
  if (common::AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName && is_recompute) {
    auto fused_cnode = fused_node->cast<CNodePtr>();
    fused_cnode->AddAttr("duplicated", MakeValue(true));
    auto fused_prim = GetCNodePrimitive(fused_cnode);
    auto final_node_prim = GetCNodePrimitive(final_node);
    fused_prim->set_instance_name(final_node_prim->instance_name());
  }
  if (common::AnfAlgo::HasNodeAttr(kAttrNotDelayFusion, final_node)) {
    common::AnfAlgo::CopyNodeAttr(kAttrNotDelayFusion, final_node, fused_node);
  }
  return fused_node;
}
539 
DoFusion(const FuncGraphPtr & func_graph,const CommunicationOpInfo & communication_op_info,const std::vector<size_t> & segment_index) const540 bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
541                                      const std::vector<size_t> &segment_index) const {
542   MS_EXCEPTION_IF_NULL(func_graph);
543   auto manager = func_graph->manager();
544   MS_EXCEPTION_IF_NULL(manager);
545   bool changed = false;
546   size_t start_index = 0;
547   for (size_t segment_idx = 0; segment_idx < segment_index.size(); ++segment_idx) {
548     size_t end_index = segment_index.at(segment_idx);
549     if (end_index - start_index < 1) {
550       start_index = end_index + 1;
551       continue;
552     }
553     auto kernel_graph = func_graph->cast<KernelGraphPtr>();
554     MS_EXCEPTION_IF_NULL(kernel_graph);
555     auto graph_id = kernel_graph->graph_id();
556     AnfNodePtr new_communication_op =
557       CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
558     AnfAlgo::SetGraphId(graph_id, new_communication_op.get());
559     // replace old communication op with new communication op
560     for (auto idx = start_index; idx <= end_index; ++idx) {
561       std::vector<AnfNodePtr> tuple_getitem_input;
562       tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
563       tuple_getitem_input.push_back(new_communication_op);
564       auto offset = SizeToLong(idx - start_index);
565       auto index = NewValueNode(offset);
566       MS_EXCEPTION_IF_NULL(index);
567       auto imm = std::make_shared<Int64Imm>(idx - start_index);
568       MS_EXCEPTION_IF_NULL(imm);
569       auto abstract_scalar = std::make_shared<abstract::AbstractScalar>();
570       MS_EXCEPTION_IF_NULL(abstract_scalar);
571       index->set_abstract(abstract_scalar);
572       tuple_getitem_input.push_back(index);
573       AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
574       MS_EXCEPTION_IF_NULL(tuple_getitem);
575       auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
576       MS_EXCEPTION_IF_NULL(communication_op_node_item);
577       tuple_getitem->set_abstract(communication_op_node_item->abstract());
578       if (kernel_graph->IsInternalOutput(communication_op_node_item, 0)) {
579         kernel_graph->ReplaceInternalOutput(communication_op_node_item, new_communication_op, 0, LongToSize(offset));
580       }
581       if (common::GetEnv("MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION") == "1") {
582         auto &users = manager->node_users()[communication_op_node_item];
583         for (auto &node : users) {
584           auto cnode = node.first->cast<CNodePtr>();
585           MS_EXCEPTION_IF_NULL(cnode);
586           if (cnode->HasAttr("comp_comm_scheduling_depend")) {
587             MS_LOG(INFO) << "Start EdgeRemove: AllReduce to comp_comm_scheduling_depend";
588             if (cnode->size() <= 1 || !common::AnfAlgo::IsCommunicationOp(cnode->input(1))) {
589               MS_LOG(INTERNAL_EXCEPTION) << "Input 1 of Cnode doesn't exist or is not a communication node!";
590             }
591             std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), cnode->input(1)->cast<CNodePtr>()};
592             auto depend_node = cnode->func_graph()->NewCNode(depend_inputs);
593             depend_node->set_abstract(cnode->input(1)->cast<CNodePtr>()->abstract()->Clone());
594             depend_node->AddAttr("comp_comm_scheduling_depend", MakeValue(true));
595             if (!manager->Replace(cnode, depend_node)) {
596               MS_LOG(INTERNAL_EXCEPTION) << "Manager replace node failed";
597             }
598             MS_LOG(INFO) << "End EdgeRemove: AllReduce to comp_comm_scheduling_depend";
599           }
600         }
601       }
602       if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
603         MS_LOG(INTERNAL_EXCEPTION) << "Manager replace node failed";
604       }
605     }
606     start_index = end_index + 1;
607     changed = true;
608   }
609   return changed;
610 }
611 
Run(const FuncGraphPtr & func_graph)612 bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
613   MS_EXCEPTION_IF_NULL(func_graph);
614   auto parallel_context = parallel::ParallelContext::GetInstance();
615   MS_EXCEPTION_IF_NULL(parallel_context);
616   auto threshold = parallel_context->dp_fusion_threshold_mb();
617   if (threshold == 0) {
618     return false;
619   }
620   const float input_grad_size_num = 0.0;
621   const float input_grad_time_num = 0.0;
622   // divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
623   mindspore::HashMap<std::string, CommunicationOpInfo> candidate_groups;
624   // avoid fuse communication nodes with dependencies like comm_node1->depend->comm_node2
625   MarkCommunicationZone(func_graph, op_name_);
626   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
627   for (auto &node : node_list) {
628     if (node != nullptr && node->isa<CNode>() && common::AnfAlgo::GetCNodeName(node) == op_name_) {
629       std::string group_name = GetFusionGroupKey(node);
630       if (group_name.empty()) {
631         continue;
632       }
633       std::string key = group_name + std::to_string(GetNodeCommZoneId(node->cast<CNodePtr>()));
634       if (candidate_groups.find(key) == candidate_groups.end()) {
635         CommunicationOpInfo communication_op_info;
636         candidate_groups[key] = communication_op_info;
637         communication_op_info.group_name = group_name;
638       }
639       candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
640       candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
641       candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
642     }
643   }
644   RemoveCommunicationZone(func_graph);
645   // split candidate group to segments according to _group class member
646   bool changed = false;
647   for (auto &it : candidate_groups) {
648     if (it.second.communication_op_nodes.size() <= 1) {
649       continue;
650     }
651     auto first_node = it.second.communication_op_nodes[0];
652     TraceGuard guard(std::make_shared<TraceOpt>(first_node->debug_info()));
653     if (common::AnfAlgo::HasNodeAttr(kAttrIndex, first_node) &&
654         common::AnfAlgo::GetNodeAttr<int64_t>(first_node, kAttrIndex) > 0) {
655       std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
656                        [](const CNodePtr &a, const CNodePtr &b) {
657                          return common::AnfAlgo::GetNodeAttr<int64_t>(a, kAttrIndex) <
658                                 common::AnfAlgo::GetNodeAttr<int64_t>(b, kAttrIndex);
659                        });
660     }
661     std::vector<size_t> segment_index;
662     if (GetSplitSegments(it.second, &segment_index, it.second.group_name)) {
663       if (DoFusion(func_graph, it.second, segment_index)) {
664         changed = true;
665       }
666     }
667   }
668   return changed;
669 }
670 }  // namespace opt
671 }  // namespace mindspore
672