1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <iterator>
18 #include <memory>
19 #include <list>
20 #include <set>
21 #include <algorithm>
22 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
23 #include "frontend/parallel/graph_util/generate_graph.h"
24 #include "base/core_ops.h"
25 #include "ir/value.h"
26 #include "frontend/parallel/ops_info/ops_utils.h"
27 #include "frontend/parallel/device_manager.h"
28 #include "frontend/parallel/context.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/graph_util/node_info.h"
31 #include "utils/parallel_node_check.h"
32
33 namespace mindspore {
34 namespace parallel {
// Primitives that are never accepted as a pipeline end node: IsInEndNodeBlackList returns
// true for any CNode whose primitive matches one of these entries.
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
  prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits,
  prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape};
38
IsInEndNodeBlackList(const CNodePtr & cnode)39 static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
40 MS_EXCEPTION_IF_NULL(cnode);
41 if (!IsValueNode<Primitive>(cnode->input(0))) {
42 return true;
43 }
44 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
45 if (IsInParallelBlackList(prim)) {
46 return true;
47 }
48 for (auto &prim_node : END_NODE_BLACK_LIST) {
49 if (IsPrimitiveCNode(cnode, prim_node)) {
50 return true;
51 }
52 }
53 return false;
54 }
55
FindAccuGrad(const CNodePtr & cnode)56 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
57 auto pre_node = cnode->input(1);
58 size_t depth = 0;
59 while (true) {
60 if (depth > MAX_RECURSIVE_DEPTH) {
61 return nullptr;
62 }
63 depth += 1;
64 if (pre_node->isa<Parameter>()) {
65 return pre_node;
66 } else {
67 if (pre_node->isa<CNode>()) {
68 auto pre_cnode = pre_node->cast<CNodePtr>();
69 pre_node = pre_cnode->input(1);
70 } else {
71 return nullptr;
72 }
73 }
74 }
75 return nullptr;
76 }
77
IsLastStage()78 bool IsLastStage() {
79 MS_EXCEPTION_IF_NULL(g_device_manager);
80 auto stage_num = g_device_manager->stage_num();
81 auto stage_id = g_device_manager->stage_id();
82 return ((stage_num - 1) == stage_id);
83 }
84
SetStridedSliceStrategy(const AnfNodePtr & node)85 void SetStridedSliceStrategy(const AnfNodePtr &node) {
86 MS_EXCEPTION_IF_NULL(node);
87 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
88 return;
89 }
90 auto cnode = node->cast<CNodePtr>();
91 MS_EXCEPTION_IF_NULL(cnode);
92 std::vector<Shapes> shape_list = ExtractShape(cnode);
93 if (shape_list.empty()) {
94 MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " failed to extract shape";
95 }
96 std::vector<ValuePtr> elements;
97 for (size_t i = 0; i < shape_list[0].size(); i++) {
98 if (shape_list[0][i].empty()) {
99 MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
100 }
101 Dimensions input_strategy;
102 for (size_t j = 0; j < shape_list[0][i].size(); j++) {
103 input_strategy.push_back(1);
104 }
105 elements.push_back(MakeValue(input_strategy));
106 }
107 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
108 cnode->AddPrimalAttr(STRATEGY, strategy);
109 }
110
InsertVirtualAssignAdd(const std::pair<AnfNodePtr,int> & node_user,const FuncGraphManagerPtr & manager,const AnfNodePtr & accu_parameter)111 void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
112 const AnfNodePtr &accu_parameter) {
113 auto cnode = node_user.first->cast<CNodePtr>();
114 if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
115 return;
116 }
117 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
118 bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
119 if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
120 return;
121 }
122 auto prim = GetCNodePrimitive(cnode);
123 if (prim == nullptr) {
124 MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
125 return;
126 }
127 OperatorAttrs attrs;
128 auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
129 auto value_node = NewValueNode(py_instance);
130 std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
131 auto graph = cnode->func_graph();
132 auto virtual_node = graph->NewCNode(virtual_node_input);
133 manager->SetEdge(cnode, node_user.second, virtual_node);
134 }
135
InsertVirtualAccuGrad(const AnfNodePtr & recv,const FuncGraphManagerPtr & manager,const AnfNodePtr & param)136 void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr ¶m) {
137 auto cnode = recv->cast<CNodePtr>();
138 MS_EXCEPTION_IF_NULL(cnode);
139 OperatorAttrs attrs;
140 auto py_instance = CreatOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
141 auto value_node = NewValueNode(py_instance);
142 std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
143 auto graph = cnode->func_graph();
144 MS_EXCEPTION_IF_NULL(graph);
145 auto virtual_node = graph->NewCNode(virtual_node_input);
146 (void)manager->Replace(recv, virtual_node);
147 }
148
FindGradAccuParameter(const std::vector<AnfNodePtr> & parameters,const std::string & name)149 AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> ¶meters, const std::string &name) {
150 for (auto ¶meter : parameters) {
151 auto param_ptr = parameter->cast<ParameterPtr>();
152 MS_EXCEPTION_IF_NULL(param_ptr);
153 if (param_ptr->name() == name) {
154 continue;
155 }
156 auto expect_name = "accu_grads." + name;
157 if (param_ptr->name() == expect_name) {
158 return parameter;
159 }
160 }
161 return nullptr;
162 }
163
HandleReceiveParam(const FuncGraphPtr & root,const std::vector<AnfNodePtr> & all_nodes)164 void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
165 auto parameters = root->parameters();
166 auto node_users_map = root->manager()->node_users();
167 for (auto &node : all_nodes) {
168 if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
169 continue;
170 }
171 auto cnode = node->cast<CNodePtr>();
172 if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
173 continue;
174 }
175 auto parameter_ptr = cnode->input(1)->cast<ParameterPtr>();
176 MS_EXCEPTION_IF_NULL(parameter_ptr);
177 auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
178 if (!accu_parameter) {
179 continue;
180 }
181 auto node_users = node_users_map[node];
182 for (auto &temp_user : node_users) {
183 auto temp_node = temp_user.first;
184 // Micro virtual operator might be inserted after cast
185 if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
186 temp_node = node_users_map[temp_node].begin()->first;
187 }
188 if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
189 IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
190 auto node_set = node_users_map[temp_node];
191 for (auto &node_user : node_set) {
192 InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
193 }
194 } else {
195 InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
196 }
197 }
198 InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
199 }
200 }
201
AddVirtualAssignAdd(const FuncGraphPtr & root)202 void AddVirtualAssignAdd(const FuncGraphPtr &root) {
203 auto parameters = root->parameters();
204 auto node_users_map = root->manager()->node_users();
205 for (auto ¶meter : parameters) {
206 auto parameter_ptr = parameter->cast<ParameterPtr>();
207 auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
208 if (!accu_parameter) {
209 continue;
210 }
211 auto node_users = node_users_map[parameter];
212 for (auto &temp_user : node_users) {
213 auto temp_node = temp_user.first;
214 // Micro virtual operator might be inserted after cast
215 if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
216 temp_node = node_users_map[temp_node].begin()->first;
217 }
218 if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
219 IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
220 auto node_set = node_users_map[temp_node];
221 for (auto &node_user : node_set) {
222 InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
223 }
224 } else {
225 InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
226 }
227 }
228 }
229 }
230
CompFunc(const AnfNodePtr & node1,const AnfNodePtr & node2)231 bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
232 MS_EXCEPTION_IF_NULL(node1);
233 MS_EXCEPTION_IF_NULL(node2);
234 auto cnode1 = node1->cast<CNodePtr>();
235 auto cnode2 = node2->cast<CNodePtr>();
236 MS_EXCEPTION_IF_NULL(cnode1);
237 MS_EXCEPTION_IF_NULL(cnode2);
238 auto micro1 = cnode1->GetPrimalAttr(MICRO);
239 auto micro2 = cnode2->GetPrimalAttr(MICRO);
240 MS_EXCEPTION_IF_NULL(micro1);
241 MS_EXCEPTION_IF_NULL(micro2);
242 auto micro1_value = GetValue<int64_t>(micro1);
243 auto micro2_value = GetValue<int64_t>(micro2);
244 if (micro1_value == micro2_value) {
245 auto prim1 = GetCNodePrimitive(cnode1);
246 auto prim2 = GetCNodePrimitive(cnode2);
247 MS_EXCEPTION_IF_NULL(prim1);
248 MS_EXCEPTION_IF_NULL(prim2);
249 auto rank_tag1 = prim1->GetAttr(SRC_RANK);
250 auto rank_tag2 = prim2->GetAttr(SRC_RANK);
251 if (rank_tag1 == nullptr) {
252 rank_tag1 = prim1->GetAttr(DEST_RANK);
253 }
254 if (rank_tag2 == nullptr) {
255 rank_tag2 = prim2->GetAttr(DEST_RANK);
256 }
257 if (!rank_tag1 || !rank_tag2) {
258 return false;
259 }
260 auto rank1_value = GetValue<int64_t>(rank_tag1);
261 auto rank2_value = GetValue<int64_t>(rank_tag2);
262 if (rank1_value == rank2_value) {
263 auto sr_tag1 = prim1->GetAttr(SR_TAG);
264 auto sr_tag2 = prim2->GetAttr(SR_TAG);
265 MS_EXCEPTION_IF_NULL(sr_tag1);
266 MS_EXCEPTION_IF_NULL(sr_tag2);
267 auto sr1_value = GetValue<int64_t>(sr_tag1);
268 auto sr2_value = GetValue<int64_t>(sr_tag2);
269 return sr1_value < sr2_value;
270 }
271 return rank1_value < rank2_value;
272 }
273 return micro1_value < micro2_value;
274 }
275
InsertDepend(const AnfNodePtr & prior_node,const AnfNodePtr & post_node,const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)276 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
277 const FuncGraphPtr &root) {
278 MS_EXCEPTION_IF_NULL(prior_node);
279 MS_EXCEPTION_IF_NULL(post_node);
280 auto post_cnode = post_node->cast<CNodePtr>();
281 MS_EXCEPTION_IF_NULL(post_cnode);
282 std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node};
283 auto depend_node = root->NewCNode(depend_input);
284 manager->SetEdge(post_node, 1, depend_node);
285 }
286
ReorderForForward(const std::vector<AnfNodePtr> & forward_start,const std::vector<AnfNodePtr> & forward_end,const FuncGraphPtr & root)287 void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
288 const FuncGraphPtr &root) {
289 MS_EXCEPTION_IF_NULL(g_device_manager);
290 MS_EXCEPTION_IF_NULL(root);
291 auto manager = root->manager();
292 MS_EXCEPTION_IF_NULL(manager);
293 auto stage_num = g_device_manager->stage_num();
294 auto stage_id = g_device_manager->stage_id();
295 for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) {
296 auto prior_node = forward_end[i - 1];
297 auto post_node = forward_start[i];
298 InsertDepend(prior_node, post_node, manager, root);
299 }
300 }
301
// Inserts Depend edges between forward and backward border nodes.
// Each PipelinePair argument is read through its .first and .second vectors. Every
// InsertDepend(prior, post, ...) call below rewires input 1 of `post` to a Depend node over
// (original input 1, prior). Indices into the backward vectors are the forward index shifted
// by (stage_num - stage_id).
// NOTE(review): none of the vector accesses is bounds-checked here; presumably the caller
// has validated the sizes beforehand (e.g. via CheckBorderNode) - confirm.
void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
                        const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
                        const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(g_device_manager);
  MS_EXCEPTION_IF_NULL(root);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto stage_num = g_device_manager->stage_num();
  auto stage_id = g_device_manager->stage_id();
  // For forward index i >= stage_num - stage_id:
  //   backward_start.first[i - (stage_num - stage_id) + 1] depends on forward_end_before.second[i];
  //   forward_start.first[i] depends on backward_end.second[i - (stage_num - stage_id)].
  for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) {
    auto prior_node1 = forward_end_before_pair.second[i];
    auto post_node1 = backward_start_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id + 1)];
    InsertDepend(prior_node1, post_node1, manager, root);
    auto prior_node2 = backward_end_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
    auto post_node2 = forward_start_pair.first[i];
    InsertDepend(prior_node2, post_node2, manager, root);
  }
  // Same start index but one extra iteration (upper bound is size + 1); skipped entirely on the last stage:
  //   forward_end.first[i - 1] depends on backward_start.second[i - (stage_num - stage_id)];
  //   backward_end.first[i - (stage_num - stage_id)] depends on forward_end.second[i - 1].
  for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) {
    if (!IsLastStage()) {
      auto prior_node3 = backward_start_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
      auto post_node3 = forward_end_pair.first[i - 1];
      InsertDepend(prior_node3, post_node3, manager, root);
      auto prior_node4 = forward_end_pair.second[i - 1];
      auto post_node4 = backward_end_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id)];
      InsertDepend(prior_node4, post_node4, manager, root);
    }
  }
  // Tail of the backward list (the last stage_num - stage_id - 1 entries):
  //   backward_start.first[j] depends on backward_end.second[j - 1].
  for (size_t j = LongToSize(SizeToLong(backward_start_pair.first.size()) - stage_num + stage_id + 1);
       j < backward_start_pair.first.size(); ++j) {
    auto prior_node5 = backward_end_pair.second[j - 1];
    auto post_node5 = backward_start_pair.first[j];
    InsertDepend(prior_node5, post_node5, manager, root);
  }
  // On every stage but the last, the first backward start node depends on
  // forward_end_before.second[stage_num - 1 - stage_id].
  if (!IsLastStage()) {
    auto prior_node6 = forward_end_before_pair.second[LongToSize(stage_num - 1 - stage_id)];
    auto post_node6 = backward_start_pair.first[0];
    InsertDepend(prior_node6, post_node6, manager, root);
  }
}
341
ReorderForParams(const std::vector<AnfNodePtr> & backward_params,const std::vector<AnfNodePtr> & forward_params,const std::vector<AnfNodePtr> & allreduce_params,const PipelinePair & forward_params_pair,const PipelinePair & backward_params_pair,const std::vector<AnfNodePtr> & backward_end,const PipelinePair & forward_start_pair,const FuncGraphPtr & root)342 void ReorderForParams(const std::vector<AnfNodePtr> &backward_params, const std::vector<AnfNodePtr> &forward_params,
343 const std::vector<AnfNodePtr> &allreduce_params, const PipelinePair &forward_params_pair,
344 const PipelinePair &backward_params_pair, const std::vector<AnfNodePtr> &backward_end,
345 const PipelinePair &forward_start_pair, const FuncGraphPtr &root) {
346 auto manager = root->manager();
347 MS_EXCEPTION_IF_NULL(manager);
348 if (!forward_params.empty()) {
349 auto prior_node = forward_params_pair.second[0];
350 auto post_node = forward_start_pair.first[0];
351 InsertDepend(prior_node, post_node, manager, root);
352 }
353 if (!backward_params.empty()) {
354 if (!allreduce_params.empty()) {
355 for (auto &node : allreduce_params) {
356 auto post_node1 = backward_params_pair.first[0];
357 InsertDepend(node, post_node1, manager, root);
358 }
359 }
360 auto prior_node2 = backward_end.back();
361 auto post_node2 = backward_params[0];
362 InsertDepend(prior_node2, post_node2, manager, root);
363 }
364 }
365
GetMicroBatch(const AnfNodePtr & node)366 int64_t GetMicroBatch(const AnfNodePtr &node) {
367 MS_EXCEPTION_IF_NULL(node);
368 auto cnode = node->cast<CNodePtr>();
369 MS_EXCEPTION_IF_NULL(cnode);
370 auto micro_value = cnode->GetPrimalAttr(MICRO);
371 MS_EXCEPTION_IF_NULL(micro_value);
372 return GetValue<int64_t>(micro_value);
373 }
374
Deduplicate(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max)375 PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max) {
376 std::vector<AnfNodePtr> temp_vec;
377 std::vector<AnfNodePtr> out_vec_begin;
378 std::vector<AnfNodePtr> out_vec_end;
379 auto manager = root->manager();
380 for (int64_t i = 0; i <= micro_max; ++i) {
381 temp_vec.clear();
382 if (!root->has_flag(TRAINING)) {
383 temp_vec = node_vector;
384 } else {
385 for (auto &node : node_vector) {
386 auto node_micro = GetMicroBatch(node);
387 if (node_micro == i) {
388 temp_vec.push_back(node);
389 }
390 }
391 }
392 if (temp_vec.size() <= 1) {
393 MS_LOG(INFO) << "No Duplicate MicroBatch.";
394 continue;
395 }
396 std::sort(temp_vec.begin(), temp_vec.end(), CompFunc);
397 for (size_t j = 0; j < temp_vec.size() - 1; ++j) {
398 auto prior_node = temp_vec[j];
399 auto post_node = temp_vec[j + 1];
400 InsertDepend(prior_node, post_node, manager, root);
401 }
402 if (!temp_vec.empty()) {
403 out_vec_begin.push_back(temp_vec.front());
404 out_vec_end.push_back(temp_vec.back());
405 }
406 }
407 if (out_vec_begin.empty()) {
408 return std::make_pair(node_vector, node_vector);
409 }
410 return std::make_pair(out_vec_begin, out_vec_end);
411 }
412
BroadCastMicroBatch(const CNodePtr & node,NodeUsersMap * node_users_map,const ValuePtr & value,size_t max_depth)413 void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth) {
414 auto node_users = (*node_users_map)[node];
415 if (max_depth > MAX_RECURSIVE_DEPTH) {
416 MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
417 }
418 for (auto &node_pair : node_users) {
419 auto user_node = node_pair.first->cast<CNodePtr>();
420 if (user_node->HasPrimalAttr(MICRO)) {
421 continue;
422 }
423 user_node->AddPrimalAttr(MICRO, value);
424 BroadCastMicroBatch(user_node, node_users_map, value, max_depth + 1);
425 }
426 }
427
BroadCastNeedGrad(const AnfNodePtr & node,NodeUsersMap * node_user_map,const FuncGraphPtr & root)428 void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root) {
429 auto node_users = (*node_user_map)[node];
430 for (auto &node_user : node_users) {
431 auto cnode = node_user.first->cast<CNodePtr>();
432 MS_EXCEPTION_IF_NULL(cnode);
433 if (cnode->HasPrimalAttr(NEED_GRAD)) {
434 continue;
435 }
436 if (cnode->func_graph() == root) {
437 continue;
438 }
439 cnode->AddPrimalAttr(NEED_GRAD, MakeValue(1));
440 BroadCastNeedGrad(cnode, node_user_map, root);
441 }
442 }
443
444 // Label node that need backpropagation
LabelNeedGrad(const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)445 void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root) {
446 auto parameters = root->parameters();
447 auto node_user_map = manager->node_users();
448 for (auto ¶meter : parameters) {
449 if (!ParameterRequireGrad(parameter)) {
450 continue;
451 }
452 auto param_ptr = parameter->cast<ParameterPtr>();
453 MS_EXCEPTION_IF_NULL(param_ptr);
454 if (param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
455 continue;
456 }
457 BroadCastNeedGrad(parameter, &node_user_map, root);
458 }
459 }
460
// Breadth-first search from `node` through its data inputs (inputs 1..n) for the first CNode
// that is not rejected by IsInEndNodeBlackList and carries the NEED_GRAD primal attribute.
// `node` must be a CNode. Logs an EXCEPTION when no such node is reachable.
// The search order, and therefore the node returned, is that of a plain FIFO traversal; each
// CNode is expanded only once.
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> node_queue = {node};
  std::set<AnfNodePtr> visited;
  // `head` is the read cursor of the FIFO; advancing it replaces erase(begin()).
  for (size_t head = 0; head < node_queue.size(); ++head) {
    // Copy the pointer: push_back below may reallocate the queue.
    const AnfNodePtr candidate = node_queue[head];
    if (candidate == nullptr) {
      continue;
    }
    auto cur_node = candidate->cast<CNodePtr>();
    if (!cur_node) {
      continue;
    }
    // A node seen before was already tested and its inputs already queued.
    if (!visited.insert(candidate).second) {
      continue;
    }
    if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
      MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
      return cur_node;
    }
    const auto &cur_inputs = cur_node->inputs();
    for (size_t k = 1; k < cur_inputs.size(); ++k) {
      node_queue.push_back(cur_inputs[k]);
    }
  }
  MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
}
480
LastStageEndNode(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)481 void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
482 const FuncGraphPtr &root) {
483 if (!IsLastStage()) {
484 return;
485 }
486 LabelNeedGrad(manager, root);
487 for (auto &node : all_nodes) {
488 if (!node->isa<CNode>()) {
489 continue;
490 }
491 auto cnode = node->cast<CNodePtr>();
492 if (!cnode->HasPrimalAttr(MICRO)) {
493 continue;
494 }
495 auto prim = GetCNodePrimitive(node);
496 if (prim && prim->HasAttr(PIPELINE_END)) {
497 for (auto &temp_node : cnode->inputs()) {
498 if (!temp_node->isa<CNode>()) {
499 continue;
500 }
501 auto temp_prim = GetCNodePrimitive(temp_node);
502 if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
503 continue;
504 }
505 auto end_node = GetPreNode(temp_node);
506 MS_EXCEPTION_IF_NULL(end_node);
507 auto end_cnode = end_node->cast<CNodePtr>();
508 MS_EXCEPTION_IF_NULL(end_cnode);
509 auto end_prim = GetCNodePrimitive(end_node);
510 OperatorAttrs attrs_;
511 auto op = CreatOpInstance(attrs_, end_prim->name(), "");
512 auto value_node = NewValueNode(op);
513 auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
514 (void)new_prim->SetAttrs(end_prim->attrs());
515 manager->SetEdge(end_node, 0, value_node);
516 end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
517 }
518 }
519 }
520 }
521
Micro(const CNodePtr & cnode,NodeUsersMap * node_users_map,size_t max_depth)522 ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map, size_t max_depth) {
523 if (max_depth > MAX_RECURSIVE_DEPTH) {
524 MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
525 }
526 if (cnode->HasPrimalAttr(MICRO)) {
527 return cnode->GetPrimalAttr(MICRO);
528 }
529 auto node_users = (*node_users_map)[cnode];
530 for (auto &node_pair : node_users) {
531 auto user_node = node_pair.first->cast<CNodePtr>();
532 auto micro = Micro(user_node, node_users_map, max_depth + 1);
533 if (micro) {
534 return micro;
535 }
536 }
537 return nullptr;
538 }
539
ParameterStartNode(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)540 void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
541 auto node_users_map = manager->node_users();
542 for (auto &node : all_nodes) {
543 if (!node->isa<CNode>()) {
544 continue;
545 }
546 auto cnode = node->cast<CNodePtr>();
547 auto prim = GetCNodePrimitive(node);
548 if (prim && prim->HasAttr(PARAMETER_START)) {
549 auto micro = Micro(cnode, &node_users_map, 0);
550 cnode->AddPrimalAttr(MICRO, micro);
551 cnode->AddPrimalAttr(PARAMETER_START, micro);
552 }
553 }
554 }
555
HandleMicroBatch(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)556 void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
557 auto node_users_map = manager->node_users();
558 for (auto &node : all_nodes) {
559 if (!node->isa<CNode>()) {
560 continue;
561 }
562 auto cnode = node->cast<CNodePtr>();
563 if (!cnode->HasPrimalAttr(MICRO)) {
564 continue;
565 }
566 auto micro = cnode->GetPrimalAttr(MICRO);
567 MS_EXCEPTION_IF_NULL(micro);
568 BroadCastMicroBatch(cnode, &node_users_map, micro, 0);
569 }
570 }
571
// Looks through a single Depend node: returns input(1) when `node` is a Depend CNode,
// otherwise returns `node` unchanged.
AnfNodePtr GetActualOp(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
    return node;
  }
  auto depend_cnode = node->cast<CNodePtr>();
  return depend_cnode->input(1);
}
579
GetBorderNode(std::vector<AnfNodePtr> * forward_start,std::vector<AnfNodePtr> * forward_end,std::vector<AnfNodePtr> * backward_start,std::vector<AnfNodePtr> * backward_end,std::vector<AnfNodePtr> * forward_params,std::vector<AnfNodePtr> * backward_params,std::vector<AnfNodePtr> * allreduce_params,const FuncGraphPtr & root)580 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
581 std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
582 std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
583 std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root) {
584 std::list<ValuePtr> name_list = {};
585 auto stage_id = g_device_manager->stage_id();
586 for (auto &node : root->nodes()) {
587 if (!node->isa<CNode>()) {
588 continue;
589 }
590 if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimZerosLike)) {
591 continue;
592 }
593 auto prim = GetCNodePrimitive(node);
594 auto cnode = node->cast<CNodePtr>();
595 if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
596 auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName);
597 if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) {
598 continue;
599 }
600 name_list.push_back(forward_node_name);
601 if (cnode->HasPrimalAttr(PIPELINE_END)) {
602 backward_start->push_back(node);
603 }
604 if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
605 backward_end->push_back(node);
606 }
607 if (cnode->HasPrimalAttr(PARAMETER_START)) {
608 backward_end->push_back(node);
609 }
610 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
611 backward_params->push_back(node);
612 }
613 if (prim->HasAttr(PARAMETER_MICRO)) {
614 allreduce_params->push_back(node);
615 }
616 } else {
617 if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
618 if (stage_id != 0 && IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
619 continue;
620 }
621 forward_start->push_back(node);
622 }
623 if (cnode->HasPrimalAttr(PIPELINE_END)) {
624 forward_end->push_back(node);
625 }
626 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
627 forward_params->push_back(node);
628 }
629 }
630 }
631 std::sort((*backward_start).begin(), (*backward_start).end(), CompFunc);
632 std::sort((*backward_end).begin(), (*backward_end).end(), CompFunc);
633 std::sort((*forward_start).begin(), (*forward_start).end(), CompFunc);
634 std::sort((*forward_end).begin(), (*forward_end).end(), CompFunc);
635 std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc);
636 std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc);
637 }
638
CheckBorderNode(const PipelinePair & forward_start_pair,const PipelinePair & forward_end_pair,const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,size_t micro_size)639 void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
640 const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
641 size_t micro_size) {
642 micro_size = micro_size + 1;
643 if (forward_start_pair.first.size() != micro_size) {
644 MS_LOG(EXCEPTION) << "forward_node's size:" << forward_start_pair.first.size()
645 << "is not equal to micro size:" << micro_size;
646 }
647 if (forward_end_pair.first.size() != micro_size) {
648 MS_LOG(EXCEPTION) << "forward_node's size:" << forward_end_pair.first.size()
649 << "is not equal to micro size:" << micro_size;
650 }
651 if (backward_start_pair.first.size() != micro_size) {
652 MS_LOG(EXCEPTION) << "backward_node's size:" << backward_start_pair.first.size()
653 << "is not equal to micro size:" << micro_size;
654 }
655 if (backward_end_pair.first.size() != micro_size) {
656 MS_LOG(EXCEPTION) << "backward_node's size:" << backward_end_pair.first.size()
657 << "is not equal to micro size:" << micro_size;
658 }
659 }
660
Reorder(const FuncGraphPtr & root)661 void Reorder(const FuncGraphPtr &root) {
662 std::vector<AnfNodePtr> forward_start;
663 std::vector<AnfNodePtr> forward_end;
664 std::vector<AnfNodePtr> forward_params;
665 std::vector<AnfNodePtr> backward_start;
666 std::vector<AnfNodePtr> backward_end;
667 std::vector<AnfNodePtr> backward_params;
668 std::vector<AnfNodePtr> allreduce_params;
669 GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
670 &allreduce_params, root);
671 int64_t micro_max = 0;
672 if (root->has_flag(TRAINING)) {
673 auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
674 auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
675 MS_EXCEPTION_IF_NULL(micro_size);
676 micro_max = GetValue<int64_t>(micro_size);
677 }
678 auto backward_start_pair = Deduplicate(backward_start, root, micro_max);
679 auto backward_end_pair = Deduplicate(backward_end, root, micro_max);
680 auto forward_start_pair = Deduplicate(forward_start, root, micro_max);
681 auto forward_end_pair = Deduplicate(forward_end, root, micro_max);
682 auto forward_params_pair = Deduplicate(forward_params, root, micro_max);
683 auto backward_params_pair = Deduplicate(backward_params, root, micro_max);
684 CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, LongToSize(micro_max));
685 PipelinePair forward_end_before_pair;
686 if (!IsLastStage()) {
687 for (auto &node : forward_end_pair.first) {
688 auto cnode = node->cast<CNodePtr>();
689 auto temp_node = GetActualOp(cnode->input(1));
690 MS_EXCEPTION_IF_NULL(temp_node);
691 forward_end_before_pair.first.push_back(temp_node);
692 }
693 for (auto &node : forward_end_pair.second) {
694 auto cnode = node->cast<CNodePtr>();
695 auto temp_node = GetActualOp(cnode->input(1));
696 MS_EXCEPTION_IF_NULL(temp_node);
697 forward_end_before_pair.second.push_back(temp_node);
698 }
699 } else {
700 forward_end_before_pair = forward_end_pair;
701 }
702 ReorderForForward(forward_start_pair.first, forward_end_pair.second, root);
703 ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair,
704 forward_end_before_pair, root);
705 ReorderForParams(backward_params, forward_params, allreduce_params, forward_params_pair, backward_params_pair,
706 backward_end, forward_start_pair, root);
707 }
708
ReorderForPredict(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)709 void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
710 std::vector<AnfNodePtr> forward_end;
711 std::vector<AnfNodePtr> forward_start;
712 std::vector<AnfNodePtr> forward_params;
713 for (auto &node : root->nodes()) {
714 if (!node->isa<CNode>()) {
715 continue;
716 }
717 auto cnode = node->cast<CNodePtr>();
718 if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
719 forward_start.push_back(node);
720 }
721 if (cnode->HasPrimalAttr(PIPELINE_END)) {
722 forward_end.push_back(node);
723 }
724 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
725 forward_params.push_back(node);
726 }
727 }
728 std::sort(forward_start.begin(), forward_start.end(), CompFunc);
729 std::sort(forward_end.begin(), forward_end.end(), CompFunc);
730 std::sort(forward_params.begin(), forward_params.end(), CompFunc);
731 auto forward_start_pair = Deduplicate(forward_start, root, 0);
732 auto forward_end_pair = Deduplicate(forward_end, root, 0);
733 auto forward_params_pair = Deduplicate(forward_params, root, 0);
734 if (!forward_end.empty() && !forward_params.empty()) {
735 InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root);
736 }
737 if (!forward_start.empty() && !forward_params.empty()) {
738 InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root);
739 }
740 }
741 } // namespace parallel
742 } // namespace mindspore
743