1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/pass/communication_op_fusion.h"
17
18 #include <memory>
19 #include <set>
20 #include <vector>
21 #include <queue>
22
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/backend/kernel_info.h"
25 #include "include/backend/optimizer/helper.h"
26 #include "include/common/utils/anfalgo.h"
27 #include "include/common/utils/parallel_context.h"
28 #include "ir/graph_utils.h"
29 #include "kernel/kernel_build_info.h"
30 #include "ops/framework_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "utils/hash_map.h"
33 #include "ir/manager.h"
34
35 namespace mindspore {
36 namespace opt {
37 namespace {
38 constexpr auto kAttrDefaultGroup = "default_group";
39 constexpr auto kAttrDefaultOp = "default_op";
40 constexpr auto kAttrCommZone = "comm_fusion_zone";
41 constexpr size_t kAlignSize = 2 << 9;
42 constexpr int64_t kDefaultThresholdMb2Byte = 262144;
43
GenerateKernelBuildInfo(const CommunicationOpInfo & communication_op_info,size_t start_index,size_t end_index)44 kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
45 size_t end_index) {
46 if (end_index >= communication_op_info.communication_op_nodes.size()) {
47 MS_LOG(EXCEPTION) << "end index out of communication_op_nodes size";
48 }
49 std::vector<std::string> inputs_device_format;
50 std::vector<std::string> outputs_device_format;
51 std::vector<TypeId> inputs_device_type;
52 std::vector<TypeId> outputs_device_type;
53 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
54 for (size_t idx = start_index; idx <= end_index; ++idx) {
55 auto cnode = communication_op_info.communication_op_nodes[idx];
56 int64_t rank_size = 1;
57 if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) &&
58 common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
59 rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
60 }
61 if (rank_size == 0) {
62 MS_LOG(EXCEPTION) << "Rank size should not be zero.";
63 }
64 MS_EXCEPTION_IF_NULL(cnode);
65 size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
66 for (size_t input_index = 0; input_index < input_num; ++input_index) {
67 inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
68 inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
69 }
70 for (int64_t rank_index = 0; rank_index < rank_size; ++rank_index) {
71 size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
72 for (size_t output_index = 0; output_index < output_num; ++output_index) {
73 outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
74 outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
75 }
76 }
77 builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
78 builder.SetProcessor(AnfAlgo::GetProcessor(cnode));
79 builder.SetKernelType(AnfAlgo::GetKernelType(cnode));
80 }
81 builder.SetInputsFormat(inputs_device_format);
82 builder.SetOutputsFormat(outputs_device_format);
83 builder.SetInputsDeviceType(inputs_device_type);
84 builder.SetOutputsDeviceType(outputs_device_type);
85 return builder.Build();
86 }
87
GetFusionGroupKey(const AnfNodePtr & node)88 std::string GetFusionGroupKey(const AnfNodePtr &node) {
89 MS_EXCEPTION_IF_NULL(node);
90 auto primitive = common::AnfAlgo::GetCNodePrimitive(node);
91 MS_EXCEPTION_IF_NULL(primitive);
92 ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
93 if (attr_fusion == nullptr) {
94 return "";
95 }
96 auto fusion = GetValue<int64_t>(attr_fusion);
97 if (fusion == 0) {
98 return "";
99 }
100 auto parallel_context = parallel::ParallelContext::GetInstance();
101 if (parallel_context->enable_fold_pipeline()) {
102 auto cnode = node->cast<CNodePtr>();
103 MS_EXCEPTION_IF_NULL(cnode);
104 auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
105 auto prim = GetCNodePrimitive(node);
106 MS_EXCEPTION_IF_NULL(prim);
107 if (cnode_name == kAllReduceOpName) {
108 if (prim->HasAttr(kAttrSegment)) {
109 auto segment_info = GetValue<int64_t>(prim->GetAttr(kAttrSegment));
110 MS_LOG(INFO) << "Cnode : " << cnode->fullname_with_scope() << ", instance_name: " << prim->instance_name()
111 << ", segment: " << segment_info;
112 fusion = segment_info + 2;
113 (void)prim->AddAttr(kAttrFusion, MakeValue(std::make_shared<Int64Imm>(fusion)));
114 MS_LOG(INFO) << "Now cnode : " << cnode->fullname_with_scope()
115 << ", fusion: " << GetValue<int64_t>(prim->GetAttr(kAttrFusion));
116 }
117 }
118 if (cnode_name == kAllGatherOpName) {
119 if (prim->HasAttr(kAttrSegment)) {
120 auto segment_info = GetValue<int64_t>(prim->GetAttr(kAttrSegment));
121 MS_LOG(INFO) << "Cnode : " << cnode->fullname_with_scope() << ", instance_name: " << prim->instance_name()
122 << ", segment: " << segment_info;
123 if (segment_info != 0) {
124 int64_t fusion_interval = 100;
125 fusion = segment_info + fusion_interval;
126 (void)prim->AddAttr(kAttrFusion, MakeValue(std::make_shared<Int64Imm>(fusion)));
127 }
128 MS_LOG(INFO) << "Cnode : " << cnode->fullname_with_scope()
129 << ", fusion: " << GetValue<int64_t>(prim->GetAttr(kAttrFusion));
130 }
131 }
132 }
133
134 std::string group = kAttrDefaultGroup;
135 ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
136 if (attr_group != nullptr) {
137 group = GetValue<std::string>(attr_group);
138 }
139 std::string op = kAttrDefaultOp;
140 ValuePtr attr_op = primitive->GetAttr(kAttrOp);
141 if (attr_op != nullptr) {
142 op = GetValue<std::string>(attr_op);
143 }
144 auto dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
145 return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
146 }
147
CheckInputs(const std::vector<AnfNodePtr> & fusion_inputs)148 void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
149 std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
150 if (inputs_set.size() < fusion_inputs.size()) {
151 MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
152 }
153 }
154
CheckSegments(size_t communication_op_node_size,const std::vector<size_t> * segment_index)155 bool CheckSegments(size_t communication_op_node_size, const std::vector<size_t> *segment_index) {
156 MS_EXCEPTION_IF_NULL(segment_index);
157 auto segments = segment_index->size();
158 if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
159 MS_LOG(EXCEPTION) << "the last segment index is invalid.";
160 }
161 for (size_t i = 0; i < segments - 1; ++i) {
162 if (segment_index->at(i) > segment_index->at(i + 1)) {
163 MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ "
164 << (i + 1) << "]=" << segment_index->at(i + 1);
165 }
166 }
167 return true;
168 }
169
GetNodeCommZoneId(const CNodePtr & cnode)170 uint32_t GetNodeCommZoneId(const CNodePtr &cnode) {
171 MS_EXCEPTION_IF_NULL(cnode);
172 if (cnode->HasAttr(kAttrCommZone)) {
173 return GetValue<uint32_t>(cnode->GetAttr(kAttrCommZone));
174 }
175 return 0;
176 }
177
// Walks the graph backwards from the return node and tags every reachable CNode
// with a kAttrCommZone id.  The zone id increments each time the traversal crosses
// a fusible communication op named `comm_op_name`, so two such ops connected by a
// dependency path (e.g. comm1 -> depend -> comm2) land in different zones and are
// never fused into one op, which would otherwise create a cycle.
void MarkCommunicationZone(const FuncGraphPtr &func_graph, const string &comm_op_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::queue<AnfNodePtr> to_visit;
  to_visit.emplace(func_graph->get_return());
  auto seen = NewSeenGeneration();
  while (!to_visit.empty()) {
    auto node = to_visit.front();
    to_visit.pop();
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto zone_id = GetNodeCommZoneId(cnode);
    for (auto &input : cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      if (!input->isa<CNode>()) {
        continue;
      }
      auto input_cnode = input->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(input_cnode);
      auto input_zone_id = GetNodeCommZoneId(input_cnode);
      auto update_zone_id = zone_id;
      // Crossing a fusible communication op opens a new zone upstream of it.
      if (common::AnfAlgo::GetCNodeName(input_cnode) == comm_op_name && common::AnfAlgo::IsFusion(input_cnode)) {
        update_zone_id += 1;
      }
      // Re-enqueue only when the input's zone must be raised or it was not yet
      // visited in this generation; zone ids only ever grow, so this terminates.
      if (input_zone_id >= update_zone_id && input->seen_ == seen) {
        continue;
      }
      input_cnode->AddAttr(kAttrCommZone, MakeValue(update_zone_id));
      to_visit.emplace(input);
      input->seen_ = seen;
    }
  }
}
214
RemoveCommunicationZone(const FuncGraphPtr & func_graph)215 void RemoveCommunicationZone(const FuncGraphPtr &func_graph) {
216 MS_EXCEPTION_IF_NULL(func_graph);
217 auto seen = NewSeenGeneration();
218 std::queue<AnfNodePtr> to_visit;
219 to_visit.emplace(func_graph->get_return());
220 while (!to_visit.empty()) {
221 auto node = to_visit.front();
222 to_visit.pop();
223 MS_EXCEPTION_IF_NULL(node);
224 if (!node->isa<CNode>()) {
225 continue;
226 }
227 auto cnode = node->cast<CNodePtr>();
228 MS_EXCEPTION_IF_NULL(cnode);
229 cnode->EraseAttr(kAttrCommZone);
230 for (auto &input : cnode->inputs()) {
231 MS_EXCEPTION_IF_NULL(input);
232 if (!input->isa<CNode>()) {
233 continue;
234 }
235 if (input->seen_ == seen) {
236 continue;
237 }
238 to_visit.emplace(input);
239 input->seen_ = seen;
240 }
241 }
242 }
243 } // namespace
244
// Decides where to split the sorted communication nodes into fusion segments and
// writes the (inclusive) end index of each segment into *segment_index.
// Policy: Send/Receive ops form a single segment; otherwise user-specified split
// indices for the group are honored (unless the parallel optimizer is enabled),
// falling back to an even split into `groups_` segments.  For pure data-parallel
// AllReduce the segments are further re-split by the size threshold.
// Returns false when there is nothing to fuse, true after segment validation.
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info,
                                             std::vector<size_t> *segment_index, const std::string &group) const {
  MS_EXCEPTION_IF_NULL(segment_index);
  size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
  MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;

  // Send/Receive are always fused into one segment covering every node.
  if (op_name_ == kSendOpName || op_name_ == kReceiveOpName) {
    if (communication_op_node_size == 0) {
      return false;
    }
    (void)segment_index->emplace_back(communication_op_node_size - 1);
    return true;
  }

  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  std::vector<uint32_t> split_indices;
  // With the parallel optimizer enabled, user split indices are ignored.
  if (!parallel_context->enable_parallel_optimizer()) {
    split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
  }

  if (!split_indices.empty()) {
    uint32_t last_index = 0;
    for (size_t i = 0; i < split_indices.size(); ++i) {
      uint32_t index = split_indices[i];
      // Indices must be strictly increasing (the first one is exempt from the check).
      if (index <= last_index && i != 0) {
        MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
      }
      // Out-of-range indices are dropped with a warning rather than failing.
      if (index >= communication_op_node_size) {
        MS_LOG(WARNING) << op_name_ << "'s split index " << index
                        << " is Greater than or equal to total gradient's number " << communication_op_node_size;
        continue;
      }
      segment_index->push_back(index);
      last_index = index;
    }
    // Guarantee the tail nodes are covered by a final segment.
    if (last_index != communication_op_node_size - 1) {
      segment_index->push_back(communication_op_node_size - 1);
    }
  } else {
    // Default: split evenly into groups_ segments (last segment absorbs remainder).
    for (size_t i = 0; i < groups_ - 1; ++i) {
      segment_index->push_back((i + 1) * (communication_op_node_size / groups_) - 1);
    }
    segment_index->push_back(communication_op_node_size - 1);
  }
  auto parallel_mode = parallel_context->parallel_mode();
  // Data-parallel AllReduce additionally re-splits each segment by tensor size.
  if (parallel_mode == parallel::kDataParallel && op_name_ == kAllReduceOpName) {
    auto threshold = parallel_context->dp_fusion_threshold_mb();
    GetAllReduceSplitSegment(communication_op_info.communication_op_nodes, threshold, segment_index);
    MS_LOG(INFO) << "The split threshold for AllReduce is " << threshold << ", the segment num is "
                 << segment_index->size();
  }
  return CheckSegments(communication_op_node_size, segment_index);
}
299
// Re-splits existing AllReduce segments so the accumulated output size of each new
// segment stays at or below `threshold` (given in MB, converted to bytes here).
// A negative threshold keeps the incoming segmentation unchanged; out-of-range
// segment ends are dropped with a warning.  *segment_index is replaced in place.
void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold,
                                                     std::vector<size_t> *segment_index) const {
  MS_EXCEPTION_IF_NULL(segment_index);
  if (threshold < 0) {
    MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy.";
    return;
  }
  // Convert MB to bytes.
  threshold *= kDefaultThresholdMb2Byte;
  std::vector<size_t> real_segment_index;
  size_t start_index = 0;
  for (auto index : *segment_index) {
    if (index >= nodes.size()) {
      MS_LOG(WARNING) << "split index is greater than or equal to total gradient's number " << nodes.size();
      continue;
    }
    size_t accumulate = 0;
    for (size_t j = start_index; j <= index; ++j) {
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(nodes[j], 0);
      // Close the current sub-segment at j when adding this tensor would exceed
      // the threshold (node j itself starts the next accumulation at zero).
      if (accumulate + tensor_size > LongToSize(threshold)) {
        real_segment_index.push_back(j);
        accumulate = 0;
      } else {
        accumulate += tensor_size;
      }
    }
    // Any leftover nodes after the last cut still end at the original boundary.
    if (accumulate != 0) {
      real_segment_index.push_back(index);
    }
    start_index = index + 1;
  }
  *segment_index = std::move(real_segment_index);
}
332
333 // Hard coded Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent
334 // cycle after AllReduce fused. It's a workaround.
335 // case 1:
336 // cnode_load = Load(%para2, cnode_u)
337 // %100 = UpdateState(cnode_u, cnode_load)
338 // ...
339 // %109 = AssignAdd(%para485, Tensor(34), %100)
340 // %110 = UpdateState(%100, xxx)
341 // will convert to:
342 // cnode_load = Load(%para2, U)
343 // ...
344 // %109 = AssignAdd(%para485, Tensor(34), cnode_u)
345 // %110 = UpdateState(cnode_u, xxx)
346 //
347 // case 2:
348 // cnode_load = Load(%para2, cnode_u)
349 // %99 = make_tuple(yyy, ..., cnode_load, ...)
350 // %100 = UpdateState(cnode_u, %99)
351 // ...
352 // %109 = AssignAdd(%para485, Tensor(34), %100)
353 // %110 = UpdateState(%100, xxx)
354 // will convert to:
355 // cnode_load = Load(%para2, U)
356 // %99 = make_tuple(yyy, ...)
357 // %100 = UpdateState(cnode_u, %99)
358 // ...
359 // %109 = AssignAdd(%para485, Tensor(34), %100)
360 // %110 = UpdateState(%100, xxx)
361 //
362 // case 3:
363 // cnode_load = Load(%para2, cnode_u)
364 // %99 = make_tuple(cnode_load)
365 // %100 = UpdateState(cnode_u, %99)
366 // ...
367 // %109 = AssignAdd(%para485, Tensor(34), %100)
368 // %110 = UpdateState(%100, xxx)
369 // will convert to:
370 // cnode_load = Load(%para2, U)
371 // ...
372 // %109 = AssignAdd(%para485, Tensor(34), cnode_u)
373 // %110 = UpdateState(cnode_u, xxx)
// Workaround (see the case descriptions above): rewrites Load(%para, cnode_u) to
// Load(%para, U) for a Load found upstream of `cnode`, then repairs the matching
// UpdateState chain, to prevent a graph cycle after AllReduce fusion.
static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) {
  const size_t monad_index = 2;
  const size_t tuple_inputs_size = 2;
  const size_t load_inputs_size = 3;
  // Find the first Load node (BFS upstream of cnode) whose monad input is a CNode.
  auto cnode_load = BroadFirstSearchFirstOf({cnode}, [&](const CNodePtr &search_cnode) {
    if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) {
      return false;
    }
    if (search_cnode->size() != load_inputs_size) {
      MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString();
    }
    return search_cnode->input(monad_index)->isa<CNode>();
  });
  if (cnode_load != nullptr) {
    // Replace the Load's CNode monad input with a constant U monad.
    auto const_u_monad = NewValueNode(kUMonad);
    const_u_monad->set_abstract(kUMonad->ToAbstract());
    const auto &cnode_u = cnode_load->input(monad_index);
    MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString();
    MS_EXCEPTION_IF_NULL(cnode->func_graph());
    MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
    auto manager = cnode->func_graph()->manager();
    manager->SetEdge(cnode_load, monad_index, const_u_monad);
    // Update the u_monad input of UpdateState from CNode U same as Load to constant U.
    CNodePtr cnode_update_state = nullptr;
    CNodePtr cnode_make_tuple = nullptr;
    const auto &cnode_load_users = manager->node_users()[cnode_load];
    for (auto &load_user : cnode_load_users) {
      // Case 2/3: the Load feeds a make_tuple that feeds UpdateState(cnode_u, ...).
      if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
        const auto &cnode_make_tuple_users = manager->node_users()[load_user.first];
        for (auto &make_tuple_user : cnode_make_tuple_users) {
          if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) {
            const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>();
            if (cnode_user->input(1) == cnode_u) {
              cnode_update_state = cnode_user;
              cnode_make_tuple = load_user.first->cast<CNodePtr>();
              break;
            }
          }
        }
        if (cnode_update_state != nullptr) {
          break;
        }
      }
      // Case 1: the Load feeds UpdateState(cnode_u, ...) directly.
      if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
        const auto &cnode_user = load_user.first->cast<CNodePtr>();
        if (cnode_user->input(1) == cnode_u) {
          cnode_update_state = cnode_user;
          break;
        }
      }
    }
    if (cnode_update_state != nullptr) {
      if (cnode_make_tuple == nullptr || cnode_make_tuple->size() == tuple_inputs_size) {
        // case 1 and case 3: Replace cnode_update_state to cnode_u;
        MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                      << " ::TO:: " << cnode_u->DebugString();
        manager->Replace(cnode_update_state, cnode_u);
      } else if (cnode_make_tuple->size() > tuple_inputs_size) {
        // case 2: remove cnode_load from cnode_make_tuple;
        MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString();
        const auto &make_tuple_inputs = cnode_make_tuple->inputs();
        AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1);
        std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
                     [cnode_load](const auto &inp) { return inp != cnode_load; });
        auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs);
        manager->Replace(cnode_make_tuple, new_cnode_make_tuple);
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                                   << " as make_tuple CNode cannot match " << cnode_make_tuple->DebugString();
      }
    }
  }
}
447
// Creates one fused communication CNode replacing communication_op_nodes
// [start_index..end_index]: gathers all member inputs, infers per-output shapes
// and dtypes (replicated rank_size times for AllGather, with dim 0 divided by
// rank_size), builds the kernel info, and copies the relevant HCCL attributes
// from the last member node.  Returns the new (not yet inserted) fused node.
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
                                                             const CommunicationOpInfo &communication_op_info,
                                                             size_t start_index, size_t end_index) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto prim = std::make_shared<Primitive>(op_name_);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
  // get all inputs of current segment
  if (end_index >= communication_op_info.communication_op_nodes.size()) {
    MS_LOG(EXCEPTION) << "End index is out of communication_op_nodes size";
  }
  std::vector<AnfNodePtr> orig_nodes;
  for (size_t idx = start_index; idx <= end_index; ++idx) {
    auto cnode = communication_op_info.communication_op_nodes[idx];
    MS_EXCEPTION_IF_NULL(cnode);
    // Rewire Load/UpdateState monads for every member except the first to avoid
    // creating a cycle once they are merged (see AdjustAllReduceInputWithLoad).
    if (idx != start_index) {
      AdjustAllReduceInputWithLoad(cnode);
    }
    auto inputs = cnode->inputs();
    (void)fusion_inputs.insert(fusion_inputs.cend(), inputs.cbegin() + 1, inputs.cend());
    (void)orig_nodes.emplace_back(cnode);
  }
  CheckInputs(fusion_inputs);
  AnfNodePtr fused_node = NewCNode(fusion_inputs, func_graph, orig_nodes);
  MS_EXCEPTION_IF_NULL(fused_node);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_node->set_kernel_info(kernel_info);
  auto final_node = communication_op_info.communication_op_nodes[end_index];
  size_t node_num = end_index - start_index + 1;
  int64_t rank_size = 1;
  if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) &&
      common::AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
    rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
  }

  if (rank_size == 0) {
    MS_LOG(EXCEPTION) << "Rank size should not be zero.";
  }
  // The fused op yields one output per member node per rank.
  size_t output_num = node_num * LongToSize(rank_size);
  std::vector<TypeId> dtypes(output_num, common::AnfAlgo::GetOutputInferDataType(final_node, 0));
  std::vector<ShapeVector> shapes;
  int64_t fusion_total_size = 0;
  for (int64_t i = 0; i < rank_size; ++i) {
    for (size_t idx = start_index; idx <= end_index; ++idx) {
      auto input_node = communication_op_info.communication_op_nodes[idx];
      MS_EXCEPTION_IF_NULL(input_node);
      auto shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
      // Each rank receives a 1/rank_size shard along dim 0 of the member output.
      if (!shape.empty()) {
        shape[0] /= rank_size;
      }
      shapes.push_back(shape);
      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
      TypeId output_type = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
      size_t type_size = GetTypeByte(TypeIdToType(output_type));
      if (type_size == 0) {
        MS_LOG(EXCEPTION) << "Divisor 'type_size' should not be 0.";
      }
      // Round the byte size up to the alignment boundary, then count in elements.
      tensor_size = (tensor_size / kAlignSize + 1) * kAlignSize / type_size;
      fusion_total_size += static_cast<int64_t>(tensor_size);
    }
  }
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
  auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
  // HCCL attributes propagated from the last member node to the fused node.
  const std::vector<std::string> kHcclFusionAttrs = {
    kAttrFusion, kAttrGroup, kAttrGroupBack,    kAttrSrTag,        kAttrDestRank,           kAttrSrcRank,
    kAttrDType,  kAttrOp,    kAttrRankSize,     kAttrGroupRankIds, kAttrReuseCommunication, kAttrSegment};
  for (const auto &attr : kHcclFusionAttrs) {
    if (common::AnfAlgo::HasNodeAttr(attr, final_node)) {
      common::AnfAlgo::CopyNodeAttr(attr, final_node, fused_node);
    }
  }
  if (common::AnfAlgo::HasNodeAttr(kAttrShape, final_node)) {
    std::vector<int64_t> fusion_total_shape{fusion_total_size};
    common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(fusion_total_shape), fused_node);
  }
  // Recomputed (duplicated) AllGather keeps its duplication mark and instance name.
  bool is_recompute =
    final_node->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(final_node->GetAttr(kAttrDuplicated));
  if (common::AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName && is_recompute) {
    auto fused_cnode = fused_node->cast<CNodePtr>();
    fused_cnode->AddAttr("duplicated", MakeValue(true));
    auto fused_prim = GetCNodePrimitive(fused_cnode);
    auto final_node_prim = GetCNodePrimitive(final_node);
    fused_prim->set_instance_name(final_node_prim->instance_name());
  }
  if (common::AnfAlgo::HasNodeAttr(kAttrNotDelayFusion, final_node)) {
    common::AnfAlgo::CopyNodeAttr(kAttrNotDelayFusion, final_node, fused_node);
  }
  return fused_node;
}
539
// Replaces each multi-node segment with one fused communication op: every original
// node becomes a TupleGetItem(fused_op, offset) so downstream users keep their
// per-tensor view.  Segments of a single node are left untouched.  Returns true
// when at least one segment was fused.
bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
                                     const std::vector<size_t> &segment_index) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  bool changed = false;
  size_t start_index = 0;
  for (size_t segment_idx = 0; segment_idx < segment_index.size(); ++segment_idx) {
    size_t end_index = segment_index.at(segment_idx);
    // Single-node segment: nothing to fuse.
    if (end_index - start_index < 1) {
      start_index = end_index + 1;
      continue;
    }
    auto kernel_graph = func_graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto graph_id = kernel_graph->graph_id();
    AnfNodePtr new_communication_op =
      CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
    AnfAlgo::SetGraphId(graph_id, new_communication_op.get());
    // replace old communication op with new communication op
    for (auto idx = start_index; idx <= end_index; ++idx) {
      std::vector<AnfNodePtr> tuple_getitem_input;
      tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
      tuple_getitem_input.push_back(new_communication_op);
      // The member's position inside the segment is its tuple output index.
      auto offset = SizeToLong(idx - start_index);
      auto index = NewValueNode(offset);
      MS_EXCEPTION_IF_NULL(index);
      auto imm = std::make_shared<Int64Imm>(idx - start_index);
      MS_EXCEPTION_IF_NULL(imm);
      auto abstract_scalar = std::make_shared<abstract::AbstractScalar>();
      MS_EXCEPTION_IF_NULL(abstract_scalar);
      index->set_abstract(abstract_scalar);
      tuple_getitem_input.push_back(index);
      AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
      MS_EXCEPTION_IF_NULL(tuple_getitem);
      auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
      MS_EXCEPTION_IF_NULL(communication_op_node_item);
      tuple_getitem->set_abstract(communication_op_node_item->abstract());
      // Keep internal-output bookkeeping of the kernel graph consistent.
      if (kernel_graph->IsInternalOutput(communication_op_node_item, 0)) {
        kernel_graph->ReplaceInternalOutput(communication_op_node_item, new_communication_op, 0, LongToSize(offset));
      }
      // Optional scheduling optimization: drop comp_comm_scheduling Depend edges
      // that pointed at the original (now fused) communication op.
      if (common::GetEnv("MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION") == "1") {
        auto &users = manager->node_users()[communication_op_node_item];
        for (auto &node : users) {
          auto cnode = node.first->cast<CNodePtr>();
          MS_EXCEPTION_IF_NULL(cnode);
          if (cnode->HasAttr("comp_comm_scheduling_depend")) {
            MS_LOG(INFO) << "Start EdgeRemove: AllReduce to comp_comm_scheduling_depend";
            if (cnode->size() <= 1 || !common::AnfAlgo::IsCommunicationOp(cnode->input(1))) {
              MS_LOG(INTERNAL_EXCEPTION) << "Input 1 of Cnode doesn't exist or is not a communication node!";
            }
            std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), cnode->input(1)->cast<CNodePtr>()};
            auto depend_node = cnode->func_graph()->NewCNode(depend_inputs);
            depend_node->set_abstract(cnode->input(1)->cast<CNodePtr>()->abstract()->Clone());
            depend_node->AddAttr("comp_comm_scheduling_depend", MakeValue(true));
            if (!manager->Replace(cnode, depend_node)) {
              MS_LOG(INTERNAL_EXCEPTION) << "Manager replace node failed";
            }
            MS_LOG(INFO) << "End EdgeRemove: AllReduce to comp_comm_scheduling_depend";
          }
        }
      }
      if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
        MS_LOG(INTERNAL_EXCEPTION) << "Manager replace node failed";
      }
    }
    start_index = end_index + 1;
    changed = true;
  }
  return changed;
}
611
Run(const FuncGraphPtr & func_graph)612 bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
613 MS_EXCEPTION_IF_NULL(func_graph);
614 auto parallel_context = parallel::ParallelContext::GetInstance();
615 MS_EXCEPTION_IF_NULL(parallel_context);
616 auto threshold = parallel_context->dp_fusion_threshold_mb();
617 if (threshold == 0) {
618 return false;
619 }
620 const float input_grad_size_num = 0.0;
621 const float input_grad_time_num = 0.0;
622 // divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
623 mindspore::HashMap<std::string, CommunicationOpInfo> candidate_groups;
624 // avoid fuse communication nodes with dependencies like comm_node1->depend->comm_node2
625 MarkCommunicationZone(func_graph, op_name_);
626 std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
627 for (auto &node : node_list) {
628 if (node != nullptr && node->isa<CNode>() && common::AnfAlgo::GetCNodeName(node) == op_name_) {
629 std::string group_name = GetFusionGroupKey(node);
630 if (group_name.empty()) {
631 continue;
632 }
633 std::string key = group_name + std::to_string(GetNodeCommZoneId(node->cast<CNodePtr>()));
634 if (candidate_groups.find(key) == candidate_groups.end()) {
635 CommunicationOpInfo communication_op_info;
636 candidate_groups[key] = communication_op_info;
637 communication_op_info.group_name = group_name;
638 }
639 candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
640 candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
641 candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
642 }
643 }
644 RemoveCommunicationZone(func_graph);
645 // split candidate group to segments according to _group class member
646 bool changed = false;
647 for (auto &it : candidate_groups) {
648 if (it.second.communication_op_nodes.size() <= 1) {
649 continue;
650 }
651 auto first_node = it.second.communication_op_nodes[0];
652 TraceGuard guard(std::make_shared<TraceOpt>(first_node->debug_info()));
653 if (common::AnfAlgo::HasNodeAttr(kAttrIndex, first_node) &&
654 common::AnfAlgo::GetNodeAttr<int64_t>(first_node, kAttrIndex) > 0) {
655 std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
656 [](const CNodePtr &a, const CNodePtr &b) {
657 return common::AnfAlgo::GetNodeAttr<int64_t>(a, kAttrIndex) <
658 common::AnfAlgo::GetNodeAttr<int64_t>(b, kAttrIndex);
659 });
660 }
661 std::vector<size_t> segment_index;
662 if (GetSplitSegments(it.second, &segment_index, it.second.group_name)) {
663 if (DoFusion(func_graph, it.second, segment_index)) {
664 changed = true;
665 }
666 }
667 }
668 return changed;
669 }
670 } // namespace opt
671 } // namespace mindspore
672