/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/pass/overlap_grad_comm.h"
#include <memory>
#include <vector>
#include <list>
#include <set>
#include <unordered_map>
#include <algorithm>
#include <string>
#include <queue>
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/pass/pass_utils.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "include/common/utils/utils.h"

namespace mindspore {
namespace parallel {
namespace {
const size_t loop_count = 1000;

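// Collect the CNodes that dw_matmul relies on by breadth-first search over its first and second inputs,
// skipping recomputed (kAttrDuplicated) nodes; the collected list is capped by loop_count.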
std::vector<CNodePtr> GetDwRelyNodes(const CNodePtr &dw_matmul) {
  // The second input is the recomputed node.
  std::vector<CNodePtr> rely_nodes;
  std::queue<CNodePtr> cnode_queue;
  std::set<AnfNodePtr> visited;
  if (dw_matmul->input(kIndex2)->isa<CNode>()) {
    cnode_queue.push(dw_matmul->input(kIndex2)->cast<CNodePtr>());
  }
  if (dw_matmul->input(kIndex1)->isa<CNode>()) {
    cnode_queue.push(dw_matmul->input(kIndex1)->cast<CNodePtr>());
  }
  while (!cnode_queue.empty()) {
    auto queue_front = cnode_queue.front();
    cnode_queue.pop();
    for (size_t i = 1; i < queue_front->size(); ++i) {
      if (visited.find(queue_front->input(i)) != visited.end()) {
        continue;
      }
      (void)visited.insert(queue_front->input(i));
      if (!IsPrimitiveCNode(queue_front->input(i))) {
        continue;
      }
      auto input_cnode = queue_front->input(i)->cast<CNodePtr>();
      cnode_queue.push(input_cnode);
      if (input_cnode->HasAttr(kAttrDuplicated)) {
        continue;
      }
      rely_nodes.push_back(input_cnode);
      if (rely_nodes.size() > loop_count) {
        break;
      }
    }
  }
  return rely_nodes;
}

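// Serialize consecutive dw matmuls: each dw matmul's first input is wrapped in a Depend on the previous
// dw matmul, so the dw computations execute in list order. Skipped when MS_CTX_GRAD_COMM_OVERLAP is set.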
void InsertDwMatmulDepend(const FuncGraphPtr &backward_graph, const std::vector<CNodePtr> &dw_matmul_list) {
  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_COMM_OVERLAP)) {
    return;
  }
  auto manager = backward_graph->manager();
  for (size_t i = 0; i + 1 < dw_matmul_list.size(); ++i) {
    auto cur_dw_matmul = dw_matmul_list[i];
    auto next_dw_matmul = dw_matmul_list[i + 1];
    std::vector<AnfNodePtr> depend4_inputs{NewValueNode(prim::kPrimDepend), next_dw_matmul->input(kIndex1),
                                           cur_dw_matmul};
    auto depend_node4 = backward_graph->NewCNode(depend4_inputs);
    depend_node4->set_abstract(next_dw_matmul->input(kIndex1)->abstract()->Clone());
    depend_node4->AddAttr("grad_comm_depend4", MakeValue(true));
    manager->SetEdge(next_dw_matmul, kIndex1, depend_node4);
  }
}

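// Match every dx matmul with its gradient-communication ops and insert Depend edges:
//   depend1 makes the dx matmul wait for the grad comm op's input, so the communication can be issued
//   before the dx computation and overlap with it;
//   depend2 makes consumers of the comm output wait for the dx matmul, bounding the overlap window.
// Consecutive comm ops and dw matmuls are then serialized via depend3 / depend4 below.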
void InsertDependForDxAndGradComm(const FuncGraphPtr &backward_graph, const std::vector<CNodePtr> &dx_matmul_list,
                                  const std::unordered_map<CNodePtr, CNodePtr> &backward_matmul_dx_dw_map,
                                  const std::unordered_map<CNodePtr, std::vector<CNodePtr>> &dx_grad_comm_map) {
  auto manager = backward_graph->manager();
  std::vector<CNodePtr> matched_dx_list;
  // There are two comm nodes when optimizer sharding is not full.
  std::vector<std::vector<CNodePtr>> grad_comm_list;
  std::vector<CNodePtr> dw_matmul_list;
  for (size_t i = 0; i < dx_matmul_list.size(); ++i) {
    auto cur_dx_matmul = dx_matmul_list[i];
    auto dw_matmul = backward_matmul_dx_dw_map.at(cur_dx_matmul);
    auto grad_comm = dx_grad_comm_map.at(cur_dx_matmul);
    // Check whether the inputs of dw_matmul rely on a candidate dx matmul; if so, skip it and take the
    // next dx matmul that dw_matmul does not depend on.
    auto dw_rely_nodes = GetDwRelyNodes(dw_matmul);
    CNodePtr dx_matmul = nullptr;
    for (size_t j = i; j < dx_matmul_list.size(); ++j) {
      if (std::find(matched_dx_list.begin(), matched_dx_list.end(), dx_matmul_list[j]) != matched_dx_list.end()) {
        continue;
      }
      if (std::find(dw_rely_nodes.begin(), dw_rely_nodes.end(), dx_matmul_list[j]) != dw_rely_nodes.end()) {
        continue;
      }
      dx_matmul = dx_matmul_list[j];
      matched_dx_list.push_back(dx_matmul_list[j]);
      break;
    }
    if (!dx_matmul) {
      continue;
    }
    if (grad_comm.size() > SIZE_TWO) {
      continue;
    }
    if (grad_comm.size() == SIZE_TWO) {
      // Keep the ReduceScatter as the last comm op in the pair.
      if (IsPrimitiveCNode(grad_comm.front(), prim::kPrimReduceScatter)) {
        auto tmp = grad_comm.front();
        grad_comm[kIndex0] = grad_comm[kIndex1];
        grad_comm[kIndex1] = tmp;
      }
    }
    // Insert depend nodes.
    MS_LOG(INFO) << "insert depend for comm node:" << grad_comm.front()->fullname_with_scope()
                 << ", unique id:" << AnfNodeInfo(grad_comm.front())
                 << ", dx_matmul: " << dx_matmul->fullname_with_scope() << ", unique id:" << AnfNodeInfo(dx_matmul)
                 << ", dw_matmul: " << dw_matmul->fullname_with_scope() << ", unique id:" << AnfNodeInfo(dw_matmul);
    // grad comm -> dx_matmul
    auto grad_comm_input = grad_comm.front()->input(kIndex1);
    auto dx_matmul_input = dx_matmul->input(kIndex1);
    std::vector<AnfNodePtr> depend1_inputs{NewValueNode(prim::kPrimDepend), dx_matmul_input, grad_comm_input};
    auto depend_node1 = backward_graph->NewCNode(depend1_inputs);
    depend_node1->set_abstract(dx_matmul_input->abstract()->Clone());
    depend_node1->AddAttr("grad_comm_depend1", MakeValue(true));
    manager->SetEdge(dx_matmul, kIndex1, depend_node1);
    // dx_matmul -> grad comm output
    auto comm_output_users = manager->node_users()[grad_comm.back()];
    for (const auto &comm_output_pair : comm_output_users) {
      if (!IsPrimitiveCNode(comm_output_pair.first)) {
        continue;
      }
      if (IsPrimitiveCNode(comm_output_pair.first, prim::kPrimDepend) && comm_output_pair.second == kIndex2) {
        continue;
      }
      std::vector<AnfNodePtr> depend2_inputs{NewValueNode(prim::kPrimDepend), grad_comm.back(), dx_matmul};
      auto depend_node2 = backward_graph->NewCNode(depend2_inputs);
      depend_node2->set_abstract(grad_comm.back()->abstract()->Clone());
      depend_node2->AddAttr("grad_comm_depend2", MakeValue(true));
      manager->SetEdge(comm_output_pair.first, comm_output_pair.second, depend_node2);
    }
    grad_comm_list.push_back(grad_comm);
    dw_matmul_list.push_back(dw_matmul);
  }
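  // depend3: serialize consecutive grad comm ops by making the next comm op's input wait for a user of
  // the previous comm op.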
  for (size_t i = 0; i + 1 < grad_comm_list.size(); ++i) {
    auto grad_comm_node = grad_comm_list[i].back();
    auto next_grad_comm_node = grad_comm_list[i + 1].front();
    auto grad_comm_node_users = manager->node_users()[grad_comm_node];
    if (grad_comm_node_users.empty()) {
      continue;
    }
    auto grad_comm_node_user = grad_comm_node_users.front().first;
    std::vector<AnfNodePtr> depend3_inputs{NewValueNode(prim::kPrimDepend), next_grad_comm_node->input(kIndex1),
                                           grad_comm_node_user};
    auto depend_node3 = backward_graph->NewCNode(depend3_inputs);
    depend_node3->set_abstract(next_grad_comm_node->input(kIndex1)->abstract()->Clone());
    depend_node3->AddAttr("grad_comm_depend3", MakeValue(true));
    manager->SetEdge(next_grad_comm_node, kIndex1, depend_node3);
  }
  InsertDwMatmulDepend(backward_graph, dw_matmul_list);
}

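// Build a map from each dx matmul to its gradient-communication ops (AllReduce / ReduceScatter) by matching
// the kPrimalAttrMirrorUserId primal attribute; when the candidates carry a MICRO attribute, the dx matmul
// with the largest micro index wins. With fine grained micro interleaved enabled, the comm ops are shared by
// every dx matmul that maps to the same dw matmul.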
std::unordered_map<CNodePtr, std::vector<CNodePtr>> ExtractDxGradCommMap(
  const std::vector<CNodePtr> &origin_nodes_topological,
  const std::unordered_map<CNodePtr, CNodePtr> &backward_matmul_dx_dw_map) {
  std::unordered_map<CNodePtr, std::vector<CNodePtr>> backward_matmul_dx_grad_comm_map;
  std::vector<CNodePtr> backward_matmul_dx_grad_comm_vector;
  for (const auto &node : origin_nodes_topological) {
    if (!node->HasPrimalAttr(kPrimalAttrMirrorUserId) || !IsSomePrimitiveList(node, {ALL_REDUCE, REDUCE_SCATTER})) {
      continue;
    }
    auto user_id = GetValue<std::string>(node->GetPrimalAttr(kPrimalAttrMirrorUserId));
    CNodePtr matched_dx_node = nullptr;
    int64_t pre_micro = -1;
    for (const auto &key_node : backward_matmul_dx_dw_map) {
      if (!key_node.first->HasPrimalAttr(kPrimalAttrMirrorUserId)) {
        continue;
      }
      auto key_node_user_id = GetValue<std::string>(key_node.first->GetPrimalAttr(kPrimalAttrMirrorUserId));
      if (key_node_user_id != user_id) {
        continue;
      }
      if (!key_node.first->HasPrimalAttr(MICRO)) {
        matched_dx_node = key_node.first;
        break;
      }
      auto micro = GetValue<int64_t>(key_node.first->GetPrimalAttr(MICRO));
      if (micro > pre_micro) {
        pre_micro = micro;
        matched_dx_node = key_node.first;
      }
    }
    if (!matched_dx_node) {
      MS_LOG(INFO) << "cannot match comm node:" << node->fullname_with_scope() << ", id:" << AnfNodeInfo(node);
      continue;
    }
    backward_matmul_dx_grad_comm_map[matched_dx_node].push_back(node);
    backward_matmul_dx_grad_comm_vector.push_back(matched_dx_node);
  }
  if (!parallel::ParallelContext::GetInstance()->enable_fine_grained_micro_interleaved()) {
    return backward_matmul_dx_grad_comm_map;
  }
  MS_LOG(INFO) << "Enabled fine grained micro interleaved.";
  // One grad comm op may match multiple dx nodes.
  for (const auto &dx_cnode : backward_matmul_dx_grad_comm_vector) {
    auto dw_cnode = backward_matmul_dx_dw_map.at(dx_cnode);
    for (const auto &dx_dw : backward_matmul_dx_dw_map) {
      if (dx_dw.second != dw_cnode) {
        continue;
      }
      if (dx_dw.first == dx_cnode) {
        continue;
      }
      backward_matmul_dx_grad_comm_map[dx_dw.first] = backward_matmul_dx_grad_comm_map[dx_cnode];
    }
  }
  return backward_matmul_dx_grad_comm_map;
}

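// Rewrite the backward graph: build the dx->dw matmul map, match the gradient comm ops to the dx matmuls,
// and insert the Depend edges that overlap the communication with the dx computation.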
void OverlapDxAndGradComm(const FuncGraphPtr &backward_graph) {
  std::list<CNodePtr> backward_orders = backward_graph->GetOrderedCnodes();
  std::vector<CNodePtr> backward_origin_nodes_topological(backward_orders.cbegin(), backward_orders.cend());
  std::unordered_map<CNodePtr, CNodePtr> backward_matmul_dx_dw_map;
  ExtractBackwardMatMul(backward_origin_nodes_topological, &backward_matmul_dx_dw_map);
  ExtendDxDwMap(backward_origin_nodes_topological, &backward_matmul_dx_dw_map);
  auto dx_grad_comm_map = ExtractDxGradCommMap(backward_origin_nodes_topological, backward_matmul_dx_dw_map);
  std::vector<CNodePtr> dx_matmul_list;
  for (const auto &dx_matmul : backward_origin_nodes_topological) {
    if (!IsPrimitiveCNode(dx_matmul, prim::kPrimMatMul) && !IsPrimitiveCNode(dx_matmul, prim::kPrimMatMulV2)) {
      continue;
    }
    if (dx_grad_comm_map.count(dx_matmul) == 0) {
      continue;
    }
    if (dx_matmul->HasAttr(INTERLEAVED_OVERLAP_MATMUL)) {
      continue;
    }
    dx_matmul_list.push_back(dx_matmul);
  }
  InsertDependForDxAndGradComm(backward_graph, dx_matmul_list, backward_matmul_dx_dw_map, dx_grad_comm_map);
}

}  // namespace

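// Pass entry: runs only on Ascend 910-series targets when MS_CTX_ENABLE_GRAD_COMM_OPT is enabled and
// cell reuse (lazy_inline) is off; it then applies the overlap rewrite to the backward graph.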
void OverlapGradComm(const FuncGraphPtr &graph) {
  if (parallel::g_device_manager == nullptr) {
    MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
    return;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto soc_version = ms_context->ascend_soc_version();
  if (soc_version != "ascend910" && soc_version != "ascend910b" && soc_version != "ascend910c") {
    return;
  }
  auto is_enable = ms_context->get_param<bool>(MS_CTX_ENABLE_GRAD_COMM_OPT);
  if (!is_enable) {
    return;
  }
  const auto cell_reuse = ms_context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
  if (cell_reuse) {
    MS_LOG(WARNING) << "Currently, grad communication overlap is not supported in lazy_inline mode.";
    return;
  }
  auto manager = graph->manager();
  FuncGraphPtr backward_graph = graph;
  for (const auto &each_graph : manager->func_graphs()) {
    if (IsCellReuseForwardGraph(each_graph)) {
      auto forward_graph = each_graph;
      // Need to use the inlined backward graph.
      backward_graph = GetCellReuseBackwardGraph(forward_graph);
      if (backward_graph == nullptr) {
        MS_LOG(WARNING)
          << "Failed to find backward cell reuse graph, skip pass 'overlap_gradmatmul_and_gradallreduce'.";
        return;
      }
      break;
    }
  }
  OverlapDxAndGradComm(backward_graph);
}
}  // namespace parallel
}  // namespace mindspore