/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/comm_op_reuse_tag.h"
#include <memory>
#include <vector>
#include <string>
#include <algorithm>
#include "mindspore/core/ops/other_ops.h"
#include "ir/func_graph.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/comm_manager.h"

namespace mindspore {
namespace opt {
namespace {
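// Returns true if the node is a CNode of one of the communication primitives listed below.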
inline bool is_comm_ops(const AnfNodePtr &node) {
  static const std::vector<PrimitivePtr> kCommunicationOpsPrim = {prim::kPrimAllReduce,
                                                                  prim::kPrimReduce,
                                                                  prim::kPrimAllGather,
                                                                  prim::kPrimReduceScatter,
                                                                  prim::kPrimCollectiveScatter,
                                                                  prim::kPrimCollectiveGather,
                                                                  prim::kPrimAlltoAll,
                                                                  prim::kPrimAllSwap,
                                                                  prim::kPrimAllToAllv,
                                                                  prim::kPrimNeighborExchange,
                                                                  prim::kPrimNeighborExchangeV2,
                                                                  prim::kPrimNeighborExchangeV2Grad,
                                                                  prim::kPrimBarrier,
                                                                  prim::kPrimBatchISendIRecv,
                                                                  prim::kPrimAlltoAllV};

  for (const auto &prim : kCommunicationOpsPrim) {
    if (IsPrimitiveCNode(node, prim)) {
      return true;
    }
  }

  return false;
}
}  // namespace

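// Tags every unfused communication op in the graph with the comm_reuse attribute and attaches the
// rank list of its communication group.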
void AddCommOpReuseTag(const FuncGraphPtr &graph) {
  if (parallel::g_device_manager == nullptr) {
    MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);

  if (!parallel::IsAutoParallelCareGraph(graph)) {
    return;
  }
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  const auto &all_nodes = manager->all_nodes();
  for (auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!is_comm_ops(node)) {
      continue;
    }
    auto comm_prim = GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(comm_prim);
    // Skip communication ops that already belong to a fusion group.
    if (comm_prim->HasAttr(parallel::FUSION) && GetValue<int64_t>(comm_prim->GetAttr(parallel::FUSION)) != 0) {
      continue;
    }
    (void)comm_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));

    std::string group_name = "";
    if (comm_prim->HasAttr(parallel::GROUP)) {
      group_name = GetValue<std::string>(comm_prim->GetAttr(parallel::GROUP));
    }
    // Resolve the group name to its rank list and attach it to the primitive.
    std::vector<unsigned int> rank_list = {};
    auto long_rank_list = parallel::g_device_manager->FindRankListByHashName(group_name);
    (void)std::transform(long_rank_list.begin(), long_rank_list.end(), std::back_inserter(rank_list),
                         [](int64_t d) -> unsigned int { return IntToUint(LongToInt(d)); });
    (void)comm_prim->AddAttr(kAttrRankList, MakeValue<std::vector<unsigned int>>(rank_list));
  }
}
}  // namespace opt
}  // namespace mindspore