1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
18 #include <memory>
19 #include <queue>
20 #include <string>
21 #include <functional>
22 #include <utility>
23 #include <vector>
24 #include <unordered_map>
25 #include "mindspore/core/ops/other_ops.h"
26 #include "utils/hash_set.h"
27 #include "ir/func_graph.h"
28 #include "frontend/parallel/costmodel_context.h"
29 #include "frontend/parallel/graph_util/node_info.h"
30 #include "frontend/parallel/status.h"
31 #include "frontend/parallel/parameter_manager.h"
32 #include "frontend/parallel/step_parallel.h"
33 #include "utils/log_adapter.h"
34
35 namespace mindspore {
36 namespace parallel {
SetMirrorFusion(const CNodePtr & mirror_cnode,int64_t fusion,const std::string & parameter_name)37 void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string ¶meter_name) {
38 MS_EXCEPTION_IF_NULL(mirror_cnode);
39 MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
40 auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
41 (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared<Int64Imm>(fusion)));
42 (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
43 }
44
AdjustRelatedFusionNode(const CNodePtr & ret,const std::unordered_map<std::string,CNodePtr> & comm_node_map)45 void AdjustRelatedFusionNode(const CNodePtr &ret, const std::unordered_map<std::string, CNodePtr> &comm_node_map) {
46 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
47 for (const auto &related_node : all_nodes) {
48 if (!IsPrimitiveCNode(related_node)) {
49 continue;
50 }
51 auto related_cnode = related_node->cast<CNodePtr>();
52 if (!related_cnode->HasAttr(kRelatedCommNodeId)) {
53 continue;
54 }
55 auto related_comm_node_id = GetValue<std::string>(related_cnode->GetAttr(kRelatedCommNodeId));
56 if (comm_node_map.find(related_comm_node_id) == comm_node_map.end()) {
57 continue;
58 }
59 auto comm_cnode = comm_node_map.at(related_comm_node_id);
60 if (!IsPrimitiveCNode(comm_cnode)) {
61 continue;
62 }
63 auto node_prim = GetValueNode<PrimitivePtr>(comm_cnode->input(0));
64 if (!node_prim->HasAttr(FUSION)) {
65 continue;
66 }
67 if (!related_cnode->HasPrimalAttr(kRelatedNodeId) || !related_cnode->HasPrimalAttr(kRelatedFusionKey)) {
68 continue;
69 }
70 auto related_fusion_key = GetValue<std::string>(related_cnode->GetPrimalAttr(kRelatedFusionKey));
71 auto fusion_id_pos = related_fusion_key.rfind("_");
72 if (fusion_id_pos != std::string::npos) {
73 auto sub_str = related_fusion_key.substr(0, fusion_id_pos);
74 auto auto_fusion_id = GetValue<int64_t>(node_prim->GetAttr(FUSION));
75 auto new_related_fusion_key = sub_str + "_" + std::to_string(auto_fusion_id);
76 MS_LOG(INFO) << "replace related fusion key to: " << new_related_fusion_key;
77 related_cnode->AddPrimalAttr(kRelatedFusionKey, MakeValue<std::string>(new_related_fusion_key));
78 }
79 }
80 }
81
SetFusionBySize(const CNodePtr & ret,int64_t threshold,const PrimitivePtr & primp) const82 Status AllCommFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) const {
83 auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
84 auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
85 auto temp = threshold;
86 int64_t fusion = 1;
87 bool init = true;
88 std::string parameter_name;
89 std::string name;
90 std::unordered_map<std::string, CNodePtr> comm_node_map;
91 for (auto &node : todo) {
92 auto cnode = node->cast<CNodePtr>();
93 if (cnode->input(1)->Shape() == nullptr) {
94 continue;
95 }
96 auto input_shapes = GetNodeShape(cnode->input(1));
97 int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
98 FuncGraphPtr func_graph = cnode->func_graph();
99 if (IsPrimitiveEquals(primp, prim::kPrimMirror)) {
100 name = ALL_REDUCE;
101 std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
102 if (!param_node_pair.first) {
103 continue;
104 }
105 parameter_name = ParameterName(param_node_pair.first);
106 }
107
108 if (IsPrimitiveEquals(primp, prim::kPrimMicroStepAllGather) || IsPrimitiveEquals(primp, prim::kPrimAllGather)) {
109 name = ALL_GATHER;
110 if (!cnode->input(0) || !cnode->input(1)) {
111 continue;
112 }
113 PrimitivePtr primp1 = GetValueNode<PrimitivePtr>(cnode->input(0));
114 if (!primp1->HasAttr(RECOMPUTE) || GetValue<bool>(primp1->GetAttr(RECOMPUTE))) {
115 continue;
116 }
117 std::pair<AnfNodePtr, bool> param_node_pair = FindParameterWithAllgather(cnode->input(1), func_graph, name);
118 if (!param_node_pair.first) {
119 continue;
120 }
121 parameter_name = ParameterName(param_node_pair.first);
122 }
123
124 if (init || input_size < temp) {
125 temp -= input_size;
126 init = false;
127 } else {
128 temp = threshold;
129 fusion++;
130 }
131 SetMirrorFusion(cnode, fusion, parameter_name);
132 comm_node_map[cnode->UniqueId()] = cnode;
133 }
134 AdjustRelatedFusionNode(ret, comm_node_map);
135 MS_LOG(INFO) << name << " fusion by size succeed.";
136 return SUCCESS;
137 }
138
SetFusionBySizeReduceScatter(const CNodePtr & ret,int64_t threshold,const PrimitivePtr & primp) const139 Status AllCommFusion::SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold,
140 const PrimitivePtr &primp) const {
141 auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
142 auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
143 auto temp = threshold;
144 int64_t fusion = 1;
145 bool init = true;
146 std::unordered_map<std::string, CNodePtr> comm_node_map;
147 for (auto &node : todo) {
148 auto cnode = node->cast<CNodePtr>();
149 if (cnode->input(1) == nullptr) {
150 continue;
151 }
152 FuncGraphPtr func_graph = cnode->func_graph();
153 std::pair<AnfNodePtr, bool> param_node_pair =
154 FindParameterWithAllgather(cnode->input(1), func_graph, REDUCE_SCATTER);
155 if (!param_node_pair.first) {
156 continue;
157 }
158 auto parameter_name = ParameterName(param_node_pair.first);
159 auto input_shapes = GetNodeShape(param_node_pair.first);
160 int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
161 if (init || input_size < temp) {
162 temp -= input_size;
163 init = false;
164 } else {
165 temp = threshold;
166 fusion++;
167 }
168 SetMirrorFusion(cnode, fusion, parameter_name);
169 comm_node_map[cnode->UniqueId()] = cnode;
170 }
171 AdjustRelatedFusionNode(ret, comm_node_map);
172 MS_LOG(INFO) << "Reduce_Scatter fusion by size succeed.";
173 return SUCCESS;
174 }
175
ProcessCommOpsFusion(const CNodePtr & ret,const std::string & comm_name)176 Status AllCommFusion::ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name) {
177 if (ret == nullptr) {
178 MS_LOG(ERROR) << "ret is nullptr.";
179 return FAILED;
180 }
181 ret_ = ret;
182 root_graph_ = ret_->func_graph();
183 MS_EXCEPTION_IF_NULL(root_graph_);
184 auto graph_set = ForwardGraph(root_graph_);
185 if (graph_set.size() > 1) {
186 MS_LOG(INFO) << comm_name << " fusion don't support multiple subgraphs now.";
187 return SUCCESS;
188 }
189 auto forward_graph = *(graph_set.begin());
190 MS_EXCEPTION_IF_NULL(forward_graph);
191 forward_ret_ = forward_graph->get_return();
192 MS_EXCEPTION_IF_NULL(forward_ret_);
193 if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
194 MS_LOG(ERROR) << comm_name << "Graph set_head_cnode failed.";
195 return FAILED;
196 }
197 int64_t threshold = 0;
198 if (comm_name == ALL_REDUCE) {
199 threshold = ParallelContext::GetInstance()->fusion_threshold_mb();
200 } else if (comm_name == ALL_GATHER) {
201 threshold = ParallelContext::GetInstance()->allgather_fusion_threshold_mb();
202 } else if (comm_name == REDUCE_SCATTER) {
203 threshold = ParallelContext::GetInstance()->reducescatter_fusion_threshold_mb();
204 } else {
205 MS_LOG(ERROR) << " Comm Ops must be ALL_REDUCE, ALL_GATHER or REDUCE_SCATTER, but got " << comm_name;
206 }
207 threshold *= DEFAULT_THRESHOLD_MB_TO_BYTE;
208 if (threshold <= 0) {
209 MS_LOG(ERROR) << "The threshold of" << comm_name << "fusion must be larger than 0, but got " << threshold << ".";
210 return FAILED;
211 }
212 if (comm_name == REDUCE_SCATTER) {
213 (void)SetFusionBySizeReduceScatter(ret, threshold, prim::kPrimVirtualAssignAdd);
214 }
215 if (comm_name == ALL_REDUCE) {
216 (void)SetFusionBySize(ret, threshold, prim::kPrimMirror);
217 }
218 if (comm_name == ALL_GATHER) {
219 (void)SetFusionBySize(ret, threshold, prim::kPrimMicroStepAllGather);
220 (void)SetFusionBySize(ret, threshold, prim::kPrimAllGather);
221 }
222
223 return SUCCESS;
224 }
225 } // namespace parallel
226 } // namespace mindspore
227