1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/pass/full_micro_interleaved_order_control.h"
18 #include <memory>
19 #include <list>
20 #include <vector>
21 #include <string>
22 #include <algorithm>
23 #include <queue>
24 #include <unordered_map>
25 #include <utility>
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "include/common/utils/utils.h"
28 #include "frontend/parallel/step_parallel.h"
29
30 namespace mindspore {
31 namespace parallel {
32 namespace {
33 constexpr auto kGradientsFlag = "Gradients";
34 const size_t interleaved_size = 2;
35 const size_t node_size_two = 2;
36 const size_t node_size_three = 3;
37 using interleaved_node_pair_vector = std::vector<std::pair<size_t, std::vector<CNodePtr>>>;
IsBpropNode(const AnfNodePtr & node)38 bool IsBpropNode(const AnfNodePtr &node) {
39 MS_EXCEPTION_IF_NULL(node);
40 if (!node->isa<CNode>()) {
41 return false;
42 }
43 return node->fullname_with_scope().find(kGradientsFlag) == 0;
44 }
45
CheckCommNodeEqual(const CNodePtr comm_node1,const CNodePtr comm_node2)46 bool CheckCommNodeEqual(const CNodePtr comm_node1, const CNodePtr comm_node2) {
47 auto prim1 = GetCNodePrimitive(comm_node1);
48 auto prim2 = GetCNodePrimitive(comm_node2);
49 if (prim1->type_name() != prim2->type_name()) {
50 MS_LOG(INFO) << "Type of two comm node is not euqal";
51 return false;
52 }
53 if (!prim1->HasAttr(parallel::GROUP) || !prim2->HasAttr(parallel::GROUP)) {
54 return false;
55 }
56 auto group1 = GetValue<std::string>(prim1->GetAttr(parallel::GROUP));
57 auto group2 = GetValue<std::string>(prim2->GetAttr(parallel::GROUP));
58 if (group1 != group2) {
59 MS_LOG(INFO) << "Group of two comm node is not euqal.";
60 return false;
61 }
62 auto shape1 = dyn_cast<abstract::Shape>(comm_node1->Shape());
63 auto shape2 = dyn_cast<abstract::Shape>(comm_node2->Shape());
64 if (shape1 == nullptr || shape2 == nullptr) {
65 return false;
66 }
67 if (shape1->shape() != shape2->shape()) {
68 return false;
69 }
70 return true;
71 }
72
ExtractInterLeavedCommNode(const std::vector<CNodePtr> & origin_nodes_topological,bool is_forward,interleaved_node_pair_vector * micro_interleaved_fp_bp_node_list,int64_t pipeline_micro=-1)73 bool ExtractInterLeavedCommNode(const std::vector<CNodePtr> &origin_nodes_topological, bool is_forward,
74 interleaved_node_pair_vector *micro_interleaved_fp_bp_node_list,
75 int64_t pipeline_micro = -1) {
76 std::vector<std::pair<std::pair<size_t, size_t>, CNodePtr>> micro_interleaved_fp_bp_node_list0;
77 std::vector<std::pair<std::pair<size_t, size_t>, CNodePtr>> micro_interleaved_fp_bp_node_list1;
78 for (size_t i = 0; i < origin_nodes_topological.size(); ++i) {
79 auto cnode = origin_nodes_topological[i];
80 if (!common::AnfAlgo::IsCommunicationOp(cnode) || !cnode->HasAttr(parallel::MICRO_INTERLEAVED_FORWARD_COMM_ORDER) ||
81 !cnode->HasAttr(parallel::MICRO_INTERLEAVED_INDEX) || cnode->HasAttr(kAttrDuplicated)) {
82 continue;
83 }
84
85 if (is_forward == IsBpropNode(cnode)) {
86 continue;
87 }
88
89 if (pipeline_micro >= 0 && cnode->HasPrimalAttr(parallel::MICRO) &&
90 GetValue<int64_t>(cnode->GetPrimalAttr(parallel::MICRO)) != pipeline_micro) {
91 continue;
92 }
93 if (pipeline_micro >= 0 && !cnode->HasPrimalAttr(parallel::MICRO)) {
94 MS_LOG(INFO) << "communication cnode :" << cnode->DebugString() << " dose not contains micro info.";
95 continue;
96 }
97 size_t micro_interleaved_fp_bp_comm_order =
98 GetValue<size_t>(cnode->GetAttr(parallel::MICRO_INTERLEAVED_FORWARD_COMM_ORDER));
99 size_t micro_interleaved_index = GetValue<size_t>(cnode->GetAttr(parallel::MICRO_INTERLEAVED_INDEX));
100 if (micro_interleaved_index == 0) {
101 micro_interleaved_fp_bp_node_list0.push_back({{micro_interleaved_fp_bp_comm_order, i}, cnode});
102 } else if (micro_interleaved_index == 1) {
103 micro_interleaved_fp_bp_node_list1.push_back({{micro_interleaved_fp_bp_comm_order, i}, cnode});
104 } else {
105 MS_LOG(INFO) << "The micro interleaved num can only be 2.";
106 return false;
107 }
108 }
109 if (micro_interleaved_fp_bp_node_list0.size() != micro_interleaved_fp_bp_node_list1.size()) {
110 return false;
111 }
112 std::sort(micro_interleaved_fp_bp_node_list0.begin(), micro_interleaved_fp_bp_node_list0.end(),
113 [](auto pair1, auto pair2) { return pair1.first.first < pair2.first.first; });
114 std::sort(micro_interleaved_fp_bp_node_list1.begin(), micro_interleaved_fp_bp_node_list1.end(),
115 [](auto pair1, auto pair2) { return pair1.first.first < pair2.first.first; });
116 for (size_t i = 0; i < micro_interleaved_fp_bp_node_list0.size(); ++i) {
117 std::vector<CNodePtr> fp_bp_node_same_id;
118 if (micro_interleaved_fp_bp_node_list0[i].first.first != micro_interleaved_fp_bp_node_list1[i].first.first) {
119 return false;
120 }
121 fp_bp_node_same_id.push_back(micro_interleaved_fp_bp_node_list0[i].second);
122 fp_bp_node_same_id.push_back(micro_interleaved_fp_bp_node_list1[i].second);
123 (*micro_interleaved_fp_bp_node_list)
124 .push_back({micro_interleaved_fp_bp_node_list0[i].first.second, fp_bp_node_same_id});
125 }
126 std::sort((*micro_interleaved_fp_bp_node_list).begin(), (*micro_interleaved_fp_bp_node_list).end(),
127 [](auto pair1, auto pair2) { return pair1.first < pair2.first; });
128 return true;
129 }
130
CreateGroupForMicroInterleaved(const CNodePtr & comm_cnode,size_t micro_interleaved_index)131 void CreateGroupForMicroInterleaved(const CNodePtr &comm_cnode, size_t micro_interleaved_index) {
132 auto comm_prim = GetCNodePrimitive(comm_cnode);
133 auto group_name = GetValue<std::string>(comm_prim->GetAttr(parallel::GROUP));
134 if (group_name.find("micro_interleaved") != std::string::npos) {
135 return;
136 }
137 auto rank_ids = parallel::g_device_manager->FindRankListByHashName(group_name);
138 auto dev_list = parallel::g_device_manager->CreateDeviceListByRankList(rank_ids);
139 auto new_group_name = group_name + "_micro_interleaved_" + std::to_string(micro_interleaved_index);
140 parallel::Group cur_device_list;
141 (void)parallel::g_device_manager->CreateGroup(new_group_name, dev_list, &cur_device_list);
142 auto new_comm_prim = comm_prim->Clone();
143 (void)new_comm_prim->SetAttrs(comm_prim->attrs());
144 (void)new_comm_prim->AddAttr(parallel::GROUP, MakeValue<std::string>(new_group_name));
145 comm_cnode->set_input(0, NewValueNode(new_comm_prim));
146 }
147
InsertInterleavedNodesDepend(const FuncGraphManagerPtr & manager,const interleaved_node_pair_vector & micro_interleaved_node_list)148 void InsertInterleavedNodesDepend(const FuncGraphManagerPtr &manager,
149 const interleaved_node_pair_vector µ_interleaved_node_list) {
150 for (size_t i = 0; i < micro_interleaved_node_list.size() - 1; ++i) {
151 auto comm_node_list = micro_interleaved_node_list[i].second;
152 auto next_comm_node_list = micro_interleaved_node_list[i + 1].second;
153 auto comm_node_a = comm_node_list[0];
154 auto comm_node_b = comm_node_list[1];
155 auto next_comm_node_a = next_comm_node_list[0];
156 auto next_comm_node_b = next_comm_node_list[1];
157 if (next_comm_node_a->size() < node_size_two || !IsPrimitiveCNode(next_comm_node_a->input(1)) ||
158 comm_node_b->size() < node_size_two || !IsPrimitiveCNode(comm_node_b->input(1))) {
159 continue;
160 }
161 auto next_comm_node_a_input_node = next_comm_node_a->input(1)->cast<CNodePtr>();
162 auto comm_node_b_input_node = comm_node_b->input(1)->cast<CNodePtr>();
163 auto comm_node_a_node_users = manager->node_users()[comm_node_a];
164 auto comm_node_b_node_users = manager->node_users()[comm_node_b];
165 if (comm_node_a_node_users.empty() || comm_node_b_node_users.empty()) {
166 continue;
167 }
168 auto comm_node_a_output_node = comm_node_a_node_users.front().first;
169 auto comm_node_b_output_node = comm_node_b_node_users.front().first;
170 // comm_node_b_input -> depend -> comm_node_a_output
171 std::vector<AnfNodePtr> depend1_inputs{NewValueNode(prim::kPrimDepend), comm_node_a, comm_node_b_input_node};
172 auto depend_node1 = comm_node_a_output_node->func_graph()->NewCNode(depend1_inputs);
173 depend_node1->set_abstract(comm_node_a->abstract()->Clone());
174 depend_node1->AddAttr("micro_interleaved_depend1", MakeValue(true));
175 MS_EXCEPTION_IF_NULL(depend_node1);
176 manager->SetEdge(comm_node_a_output_node, comm_node_a_node_users.front().second, depend_node1);
177 // next_comm_node_a_input -> depend -> comm_node_b_output
178 std::vector<AnfNodePtr> depend2_inputs{NewValueNode(prim::kPrimDepend), comm_node_b, next_comm_node_a_input_node};
179 auto depend_node2 = next_comm_node_a_input_node->func_graph()->NewCNode(depend2_inputs);
180 depend_node2->AddAttr("micro_interleaved_depend2", MakeValue(true));
181 depend_node2->set_abstract(comm_node_b->abstract()->Clone());
182 MS_EXCEPTION_IF_NULL(depend_node2);
183 manager->SetEdge(comm_node_b_output_node, comm_node_b_node_users.front().second, depend_node2);
184 }
185 }
186
CreateExtraGroupForModelParallelCommNode(const std::vector<CNodePtr> & origin_nodes_topological,const interleaved_node_pair_vector & micro_interleaved_node_list)187 void CreateExtraGroupForModelParallelCommNode(const std::vector<CNodePtr> &origin_nodes_topological,
188 const interleaved_node_pair_vector µ_interleaved_node_list) {
189 std::unordered_map<std::string, size_t> unique_id_interleaved_map;
190 for (const auto &pair : micro_interleaved_node_list) {
191 auto cnode_list = pair.second;
192 CreateGroupForMicroInterleaved(cnode_list[0], 0);
193 CreateGroupForMicroInterleaved(cnode_list[1], 1);
194 if (!IsBpropNode(cnode_list[0]) && cnode_list[0]->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) {
195 auto forward_comm_node_unique_id =
196 GetValue<std::string>(cnode_list[0]->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId));
197 unique_id_interleaved_map[forward_comm_node_unique_id] = 0;
198 }
199 if (!IsBpropNode(cnode_list[1]) && cnode_list[1]->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) {
200 auto forward_comm_node_unique_id =
201 GetValue<std::string>(cnode_list[1]->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId));
202 unique_id_interleaved_map[forward_comm_node_unique_id] = 1;
203 }
204 }
205
206 if (unique_id_interleaved_map.empty()) {
207 return;
208 }
209
210 for (const auto &cnode : origin_nodes_topological) {
211 if (!cnode->HasAttr(kAttrDuplicated)) {
212 continue;
213 }
214 if (!cnode->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) {
215 continue;
216 }
217 auto duplicate_comm_node_unique_id =
218 GetValue<std::string>(cnode->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId));
219 if (unique_id_interleaved_map.find(duplicate_comm_node_unique_id) == unique_id_interleaved_map.end()) {
220 continue;
221 }
222 CreateGroupForMicroInterleaved(cnode, unique_id_interleaved_map[duplicate_comm_node_unique_id]);
223 }
224 }
225
MicroInterleavedOrderControl(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological,int pipeline_micro=-1)226 void MicroInterleavedOrderControl(const FuncGraphManagerPtr &manager,
227 const std::vector<CNodePtr> &origin_nodes_topological, int pipeline_micro = -1) {
228 // 1 order forward_node, and the node with same MICRO_INTERLEAVED_FORWARD_COMM_ORDER is the micro interleaved pair
229 // nodes.
230 interleaved_node_pair_vector micro_interleaved_forward_node_list;
231 if (!ExtractInterLeavedCommNode(origin_nodes_topological, true, µ_interleaved_forward_node_list,
232 pipeline_micro)) {
233 MS_LOG(INFO) << "Cannot match micro interleaved conditions.";
234 return;
235 }
236 interleaved_node_pair_vector micro_interleaved_backward_node_list;
237 if (!ExtractInterLeavedCommNode(origin_nodes_topological, false, µ_interleaved_backward_node_list,
238 pipeline_micro)) {
239 MS_LOG(INFO) << "Cannot match micro interleaved conditions.";
240 return;
241 }
242
243 if (micro_interleaved_forward_node_list.empty() || micro_interleaved_backward_node_list.empty()) {
244 MS_LOG(INFO) << "Cannot find micro interleaved nodes.";
245 return;
246 }
247 for (auto &pair : micro_interleaved_forward_node_list) {
248 auto cnode_list = pair.second;
249 if (!CheckCommNodeEqual(cnode_list[0], cnode_list[1])) {
250 MS_LOG(INFO) << cnode_list[0]->DebugString() << " and " << cnode_list[1]->DebugString()
251 << " not match for micro interleaved.";
252
253 return;
254 }
255 }
256 for (auto &pair : micro_interleaved_backward_node_list) {
257 auto cnode_list = pair.second;
258 if (!CheckCommNodeEqual(cnode_list[0], cnode_list[1])) {
259 MS_LOG(INFO) << cnode_list[0]->DebugString() << " and " << cnode_list[1]->fullname_with_scope()
260 << " not match for micro interleaved.";
261 return;
262 }
263 }
264 static const auto micro_interleaved_extra_comm_group = (common::GetEnv("interleaved_extra_group") == "1");
265 if (micro_interleaved_extra_comm_group) {
266 CreateExtraGroupForModelParallelCommNode(origin_nodes_topological, micro_interleaved_forward_node_list);
267 CreateExtraGroupForModelParallelCommNode(origin_nodes_topological, micro_interleaved_backward_node_list);
268 }
269 InsertInterleavedNodesDepend(manager, micro_interleaved_forward_node_list);
270 InsertInterleavedNodesDepend(manager, micro_interleaved_backward_node_list);
271 }
272
MicroInterleavedOrderControlPipeline(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological)273 void MicroInterleavedOrderControlPipeline(const FuncGraphManagerPtr &manager,
274 const std::vector<CNodePtr> &origin_nodes_topological) {
275 // 1 order forward_node, and the node with same MICRO_INTERLEAVED_FORWARD_COMM_ORDER is the micro interleaved pair
276 // nodes.
277 MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
278 size_t pipeline_micro_size = parallel::ParallelContext::GetInstance()->pipeline_micro_size();
279 MS_LOG(INFO) << "The pipeline micro size in micro interleaved is: " << pipeline_micro_size;
280 for (size_t pipeline_micro_id = 0; pipeline_micro_id < pipeline_micro_size; ++pipeline_micro_id) {
281 MicroInterleavedOrderControl(manager, origin_nodes_topological, pipeline_micro_id);
282 }
283 return;
284 }
285 } // namespace
286
FullMicroInterleavedOrderControl(const FuncGraphPtr & graph)287 void FullMicroInterleavedOrderControl(const FuncGraphPtr &graph) {
288 if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
289 parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
290 return;
291 }
292 auto context = MsContext::GetInstance();
293 MS_EXCEPTION_IF_NULL(context);
294 const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
295 if (cell_reuse) {
296 return;
297 }
298 if (!parallel::ParallelContext::GetInstance()->enable_micro_interleaved()) {
299 return;
300 }
301 if (common::GetEnv("MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION") == "1") {
302 return;
303 }
304 MS_EXCEPTION_IF_NULL(graph);
305 auto manager = graph->manager();
306 MS_EXCEPTION_IF_NULL(manager);
307 std::list<CNodePtr> orders = graph->GetOrderedCnodes();
308 std::vector<CNodePtr> origin_nodes_topological(orders.cbegin(), orders.cend());
309 if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() == 1) {
310 MicroInterleavedOrderControl(manager, origin_nodes_topological);
311 return;
312 }
313 MicroInterleavedOrderControlPipeline(manager, origin_nodes_topological);
314 }
315 } // namespace parallel
316 } // namespace mindspore
317