• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pass/full_micro_interleaved_order_control.h"
18 #include <memory>
19 #include <list>
20 #include <vector>
21 #include <string>
22 #include <algorithm>
23 #include <queue>
24 #include <unordered_map>
25 #include <utility>
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "include/common/utils/utils.h"
28 #include "frontend/parallel/step_parallel.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 namespace {
33 constexpr auto kGradientsFlag = "Gradients";
34 const size_t interleaved_size = 2;
35 const size_t node_size_two = 2;
36 const size_t node_size_three = 3;
37 using interleaved_node_pair_vector = std::vector<std::pair<size_t, std::vector<CNodePtr>>>;
IsBpropNode(const AnfNodePtr & node)38 bool IsBpropNode(const AnfNodePtr &node) {
39   MS_EXCEPTION_IF_NULL(node);
40   if (!node->isa<CNode>()) {
41     return false;
42   }
43   return node->fullname_with_scope().find(kGradientsFlag) == 0;
44 }
45 
CheckCommNodeEqual(const CNodePtr comm_node1,const CNodePtr comm_node2)46 bool CheckCommNodeEqual(const CNodePtr comm_node1, const CNodePtr comm_node2) {
47   auto prim1 = GetCNodePrimitive(comm_node1);
48   auto prim2 = GetCNodePrimitive(comm_node2);
49   if (prim1->type_name() != prim2->type_name()) {
50     MS_LOG(INFO) << "Type of two comm node is not euqal";
51     return false;
52   }
53   if (!prim1->HasAttr(parallel::GROUP) || !prim2->HasAttr(parallel::GROUP)) {
54     return false;
55   }
56   auto group1 = GetValue<std::string>(prim1->GetAttr(parallel::GROUP));
57   auto group2 = GetValue<std::string>(prim2->GetAttr(parallel::GROUP));
58   if (group1 != group2) {
59     MS_LOG(INFO) << "Group of two comm node is not euqal.";
60     return false;
61   }
62   auto shape1 = dyn_cast<abstract::Shape>(comm_node1->Shape());
63   auto shape2 = dyn_cast<abstract::Shape>(comm_node2->Shape());
64   if (shape1 == nullptr || shape2 == nullptr) {
65     return false;
66   }
67   if (shape1->shape() != shape2->shape()) {
68     return false;
69   }
70   return true;
71 }
72 
ExtractInterLeavedCommNode(const std::vector<CNodePtr> & origin_nodes_topological,bool is_forward,interleaved_node_pair_vector * micro_interleaved_fp_bp_node_list,int64_t pipeline_micro=-1)73 bool ExtractInterLeavedCommNode(const std::vector<CNodePtr> &origin_nodes_topological, bool is_forward,
74                                 interleaved_node_pair_vector *micro_interleaved_fp_bp_node_list,
75                                 int64_t pipeline_micro = -1) {
76   std::vector<std::pair<std::pair<size_t, size_t>, CNodePtr>> micro_interleaved_fp_bp_node_list0;
77   std::vector<std::pair<std::pair<size_t, size_t>, CNodePtr>> micro_interleaved_fp_bp_node_list1;
78   for (size_t i = 0; i < origin_nodes_topological.size(); ++i) {
79     auto cnode = origin_nodes_topological[i];
80     if (!common::AnfAlgo::IsCommunicationOp(cnode) || !cnode->HasAttr(parallel::MICRO_INTERLEAVED_FORWARD_COMM_ORDER) ||
81         !cnode->HasAttr(parallel::MICRO_INTERLEAVED_INDEX) || cnode->HasAttr(kAttrDuplicated)) {
82       continue;
83     }
84 
85     if (is_forward == IsBpropNode(cnode)) {
86       continue;
87     }
88 
89     if (pipeline_micro >= 0 && cnode->HasPrimalAttr(parallel::MICRO) &&
90         GetValue<int64_t>(cnode->GetPrimalAttr(parallel::MICRO)) != pipeline_micro) {
91       continue;
92     }
93     if (pipeline_micro >= 0 && !cnode->HasPrimalAttr(parallel::MICRO)) {
94       MS_LOG(INFO) << "communication cnode :" << cnode->DebugString() << " dose not contains micro info.";
95       continue;
96     }
97     size_t micro_interleaved_fp_bp_comm_order =
98       GetValue<size_t>(cnode->GetAttr(parallel::MICRO_INTERLEAVED_FORWARD_COMM_ORDER));
99     size_t micro_interleaved_index = GetValue<size_t>(cnode->GetAttr(parallel::MICRO_INTERLEAVED_INDEX));
100     if (micro_interleaved_index == 0) {
101       micro_interleaved_fp_bp_node_list0.push_back({{micro_interleaved_fp_bp_comm_order, i}, cnode});
102     } else if (micro_interleaved_index == 1) {
103       micro_interleaved_fp_bp_node_list1.push_back({{micro_interleaved_fp_bp_comm_order, i}, cnode});
104     } else {
105       MS_LOG(INFO) << "The micro interleaved num can only be 2.";
106       return false;
107     }
108   }
109   if (micro_interleaved_fp_bp_node_list0.size() != micro_interleaved_fp_bp_node_list1.size()) {
110     return false;
111   }
112   std::sort(micro_interleaved_fp_bp_node_list0.begin(), micro_interleaved_fp_bp_node_list0.end(),
113             [](auto pair1, auto pair2) { return pair1.first.first < pair2.first.first; });
114   std::sort(micro_interleaved_fp_bp_node_list1.begin(), micro_interleaved_fp_bp_node_list1.end(),
115             [](auto pair1, auto pair2) { return pair1.first.first < pair2.first.first; });
116   for (size_t i = 0; i < micro_interleaved_fp_bp_node_list0.size(); ++i) {
117     std::vector<CNodePtr> fp_bp_node_same_id;
118     if (micro_interleaved_fp_bp_node_list0[i].first.first != micro_interleaved_fp_bp_node_list1[i].first.first) {
119       return false;
120     }
121     fp_bp_node_same_id.push_back(micro_interleaved_fp_bp_node_list0[i].second);
122     fp_bp_node_same_id.push_back(micro_interleaved_fp_bp_node_list1[i].second);
123     (*micro_interleaved_fp_bp_node_list)
124       .push_back({micro_interleaved_fp_bp_node_list0[i].first.second, fp_bp_node_same_id});
125   }
126   std::sort((*micro_interleaved_fp_bp_node_list).begin(), (*micro_interleaved_fp_bp_node_list).end(),
127             [](auto pair1, auto pair2) { return pair1.first < pair2.first; });
128   return true;
129 }
130 
CreateGroupForMicroInterleaved(const CNodePtr & comm_cnode,size_t micro_interleaved_index)131 void CreateGroupForMicroInterleaved(const CNodePtr &comm_cnode, size_t micro_interleaved_index) {
132   auto comm_prim = GetCNodePrimitive(comm_cnode);
133   auto group_name = GetValue<std::string>(comm_prim->GetAttr(parallel::GROUP));
134   if (group_name.find("micro_interleaved") != std::string::npos) {
135     return;
136   }
137   auto rank_ids = parallel::g_device_manager->FindRankListByHashName(group_name);
138   auto dev_list = parallel::g_device_manager->CreateDeviceListByRankList(rank_ids);
139   auto new_group_name = group_name + "_micro_interleaved_" + std::to_string(micro_interleaved_index);
140   parallel::Group cur_device_list;
141   (void)parallel::g_device_manager->CreateGroup(new_group_name, dev_list, &cur_device_list);
142   auto new_comm_prim = comm_prim->Clone();
143   (void)new_comm_prim->SetAttrs(comm_prim->attrs());
144   (void)new_comm_prim->AddAttr(parallel::GROUP, MakeValue<std::string>(new_group_name));
145   comm_cnode->set_input(0, NewValueNode(new_comm_prim));
146 }
147 
InsertInterleavedNodesDepend(const FuncGraphManagerPtr & manager,const interleaved_node_pair_vector & micro_interleaved_node_list)148 void InsertInterleavedNodesDepend(const FuncGraphManagerPtr &manager,
149                                   const interleaved_node_pair_vector &micro_interleaved_node_list) {
150   for (size_t i = 0; i < micro_interleaved_node_list.size() - 1; ++i) {
151     auto comm_node_list = micro_interleaved_node_list[i].second;
152     auto next_comm_node_list = micro_interleaved_node_list[i + 1].second;
153     auto comm_node_a = comm_node_list[0];
154     auto comm_node_b = comm_node_list[1];
155     auto next_comm_node_a = next_comm_node_list[0];
156     auto next_comm_node_b = next_comm_node_list[1];
157     if (next_comm_node_a->size() < node_size_two || !IsPrimitiveCNode(next_comm_node_a->input(1)) ||
158         comm_node_b->size() < node_size_two || !IsPrimitiveCNode(comm_node_b->input(1))) {
159       continue;
160     }
161     auto next_comm_node_a_input_node = next_comm_node_a->input(1)->cast<CNodePtr>();
162     auto comm_node_b_input_node = comm_node_b->input(1)->cast<CNodePtr>();
163     auto comm_node_a_node_users = manager->node_users()[comm_node_a];
164     auto comm_node_b_node_users = manager->node_users()[comm_node_b];
165     if (comm_node_a_node_users.empty() || comm_node_b_node_users.empty()) {
166       continue;
167     }
168     auto comm_node_a_output_node = comm_node_a_node_users.front().first;
169     auto comm_node_b_output_node = comm_node_b_node_users.front().first;
170     // comm_node_b_input -> depend -> comm_node_a_output
171     std::vector<AnfNodePtr> depend1_inputs{NewValueNode(prim::kPrimDepend), comm_node_a, comm_node_b_input_node};
172     auto depend_node1 = comm_node_a_output_node->func_graph()->NewCNode(depend1_inputs);
173     depend_node1->set_abstract(comm_node_a->abstract()->Clone());
174     depend_node1->AddAttr("micro_interleaved_depend1", MakeValue(true));
175     MS_EXCEPTION_IF_NULL(depend_node1);
176     manager->SetEdge(comm_node_a_output_node, comm_node_a_node_users.front().second, depend_node1);
177     // next_comm_node_a_input -> depend -> comm_node_b_output
178     std::vector<AnfNodePtr> depend2_inputs{NewValueNode(prim::kPrimDepend), comm_node_b, next_comm_node_a_input_node};
179     auto depend_node2 = next_comm_node_a_input_node->func_graph()->NewCNode(depend2_inputs);
180     depend_node2->AddAttr("micro_interleaved_depend2", MakeValue(true));
181     depend_node2->set_abstract(comm_node_b->abstract()->Clone());
182     MS_EXCEPTION_IF_NULL(depend_node2);
183     manager->SetEdge(comm_node_b_output_node, comm_node_b_node_users.front().second, depend_node2);
184   }
185 }
186 
CreateExtraGroupForModelParallelCommNode(const std::vector<CNodePtr> & origin_nodes_topological,const interleaved_node_pair_vector & micro_interleaved_node_list)187 void CreateExtraGroupForModelParallelCommNode(const std::vector<CNodePtr> &origin_nodes_topological,
188                                               const interleaved_node_pair_vector &micro_interleaved_node_list) {
189   std::unordered_map<std::string, size_t> unique_id_interleaved_map;
190   for (const auto &pair : micro_interleaved_node_list) {
191     auto cnode_list = pair.second;
192     CreateGroupForMicroInterleaved(cnode_list[0], 0);
193     CreateGroupForMicroInterleaved(cnode_list[1], 1);
194     if (!IsBpropNode(cnode_list[0]) && cnode_list[0]->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) {
195       auto forward_comm_node_unique_id =
196         GetValue<std::string>(cnode_list[0]->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId));
197       unique_id_interleaved_map[forward_comm_node_unique_id] = 0;
198     }
199     if (!IsBpropNode(cnode_list[1]) && cnode_list[1]->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) {
200       auto forward_comm_node_unique_id =
201         GetValue<std::string>(cnode_list[1]->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId));
202       unique_id_interleaved_map[forward_comm_node_unique_id] = 1;
203     }
204   }
205 
206   if (unique_id_interleaved_map.empty()) {
207     return;
208   }
209 
210   for (const auto &cnode : origin_nodes_topological) {
211     if (!cnode->HasAttr(kAttrDuplicated)) {
212       continue;
213     }
214     if (!cnode->HasPrimalAttr(kPrimalAttrForwardCommNodeUniqueId)) {
215       continue;
216     }
217     auto duplicate_comm_node_unique_id =
218       GetValue<std::string>(cnode->GetPrimalAttr(kPrimalAttrForwardCommNodeUniqueId));
219     if (unique_id_interleaved_map.find(duplicate_comm_node_unique_id) == unique_id_interleaved_map.end()) {
220       continue;
221     }
222     CreateGroupForMicroInterleaved(cnode, unique_id_interleaved_map[duplicate_comm_node_unique_id]);
223   }
224 }
225 
MicroInterleavedOrderControl(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological,int pipeline_micro=-1)226 void MicroInterleavedOrderControl(const FuncGraphManagerPtr &manager,
227                                   const std::vector<CNodePtr> &origin_nodes_topological, int pipeline_micro = -1) {
228   // 1 order forward_node, and the node with same MICRO_INTERLEAVED_FORWARD_COMM_ORDER is the micro interleaved pair
229   // nodes.
230   interleaved_node_pair_vector micro_interleaved_forward_node_list;
231   if (!ExtractInterLeavedCommNode(origin_nodes_topological, true, &micro_interleaved_forward_node_list,
232                                   pipeline_micro)) {
233     MS_LOG(INFO) << "Cannot match micro interleaved conditions.";
234     return;
235   }
236   interleaved_node_pair_vector micro_interleaved_backward_node_list;
237   if (!ExtractInterLeavedCommNode(origin_nodes_topological, false, &micro_interleaved_backward_node_list,
238                                   pipeline_micro)) {
239     MS_LOG(INFO) << "Cannot match micro interleaved conditions.";
240     return;
241   }
242 
243   if (micro_interleaved_forward_node_list.empty() || micro_interleaved_backward_node_list.empty()) {
244     MS_LOG(INFO) << "Cannot find micro interleaved nodes.";
245     return;
246   }
247   for (auto &pair : micro_interleaved_forward_node_list) {
248     auto cnode_list = pair.second;
249     if (!CheckCommNodeEqual(cnode_list[0], cnode_list[1])) {
250       MS_LOG(INFO) << cnode_list[0]->DebugString() << " and " << cnode_list[1]->DebugString()
251                    << " not match for micro interleaved.";
252 
253       return;
254     }
255   }
256   for (auto &pair : micro_interleaved_backward_node_list) {
257     auto cnode_list = pair.second;
258     if (!CheckCommNodeEqual(cnode_list[0], cnode_list[1])) {
259       MS_LOG(INFO) << cnode_list[0]->DebugString() << " and " << cnode_list[1]->fullname_with_scope()
260                    << " not match for micro interleaved.";
261       return;
262     }
263   }
264   static const auto micro_interleaved_extra_comm_group = (common::GetEnv("interleaved_extra_group") == "1");
265   if (micro_interleaved_extra_comm_group) {
266     CreateExtraGroupForModelParallelCommNode(origin_nodes_topological, micro_interleaved_forward_node_list);
267     CreateExtraGroupForModelParallelCommNode(origin_nodes_topological, micro_interleaved_backward_node_list);
268   }
269   InsertInterleavedNodesDepend(manager, micro_interleaved_forward_node_list);
270   InsertInterleavedNodesDepend(manager, micro_interleaved_backward_node_list);
271 }
272 
MicroInterleavedOrderControlPipeline(const FuncGraphManagerPtr & manager,const std::vector<CNodePtr> & origin_nodes_topological)273 void MicroInterleavedOrderControlPipeline(const FuncGraphManagerPtr &manager,
274                                           const std::vector<CNodePtr> &origin_nodes_topological) {
275   // 1 order forward_node, and the node with same MICRO_INTERLEAVED_FORWARD_COMM_ORDER is the micro interleaved pair
276   // nodes.
277   MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
278   size_t pipeline_micro_size = parallel::ParallelContext::GetInstance()->pipeline_micro_size();
279   MS_LOG(INFO) << "The pipeline micro size in micro interleaved is: " << pipeline_micro_size;
280   for (size_t pipeline_micro_id = 0; pipeline_micro_id < pipeline_micro_size; ++pipeline_micro_id) {
281     MicroInterleavedOrderControl(manager, origin_nodes_topological, pipeline_micro_id);
282   }
283   return;
284 }
285 }  // namespace
286 
FullMicroInterleavedOrderControl(const FuncGraphPtr & graph)287 void FullMicroInterleavedOrderControl(const FuncGraphPtr &graph) {
288   if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
289       parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
290     return;
291   }
292   auto context = MsContext::GetInstance();
293   MS_EXCEPTION_IF_NULL(context);
294   const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
295   if (cell_reuse) {
296     return;
297   }
298   if (!parallel::ParallelContext::GetInstance()->enable_micro_interleaved()) {
299     return;
300   }
301   if (common::GetEnv("MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION") == "1") {
302     return;
303   }
304   MS_EXCEPTION_IF_NULL(graph);
305   auto manager = graph->manager();
306   MS_EXCEPTION_IF_NULL(manager);
307   std::list<CNodePtr> orders = graph->GetOrderedCnodes();
308   std::vector<CNodePtr> origin_nodes_topological(orders.cbegin(), orders.cend());
309   if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() == 1) {
310     MicroInterleavedOrderControl(manager, origin_nodes_topological);
311     return;
312   }
313   MicroInterleavedOrderControlPipeline(manager, origin_nodes_topological);
314 }
315 }  // namespace parallel
316 }  // namespace mindspore
317