• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <memory>
18 #include <list>
19 #include <set>
20 #include <queue>
21 #include <algorithm>
22 #include "mindspore/core/ops/other_ops.h"
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
26 #include "frontend/parallel/graph_util/grad_accumulation_utils.h"
27 #include "frontend/parallel/parameter_manager.h"
28 #include "frontend/parallel/graph_util/generate_graph.h"
29 #include "ir/value.h"
30 #include "frontend/parallel/ops_info/ops_utils.h"
31 #include "include/common/utils/parallel_context.h"
32 #include "frontend/parallel/step_parallel.h"
33 #include "frontend/parallel/step_parallel_utils.h"
34 #include "utils/parallel_node_check.h"
35 
36 namespace mindspore {
37 namespace parallel {
38 constexpr char GRAD_ACCU_NUM[] = "grad_accu_num";
39 constexpr char GRAD_ACCU_FORWARD_BEGIN[] = "grad_accu_forward_begin";
40 constexpr char GRAD_ACCU_FORWARD_END[] = "grad_accu_forward_end";
41 constexpr char GRAD_ACCU_BACKWARD_END[] = "grad_accu_backward_end";
42 constexpr char FIRST_PARAMETER_CNODE[] = "first_parameter_cnode";
TagMicroBatchStart(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & all_nodes)43 void TagMicroBatchStart(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &all_nodes) {
44   auto node_users_map = manager->node_users();
45   for (const auto &node : all_nodes) {
46     if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
47       continue;
48     }
49     auto slice_cnode = node->cast<CNodePtr>();
50     auto slice_prim = GetCNodePrimitive(slice_cnode);
51     if (!slice_prim->HasAttr(GRAD_ACCU_NUM)) {
52       continue;
53     }
54     auto accu_step = GetValue<int64_t>(slice_prim->GetAttr(GRAD_ACCU_NUM));
55     ParallelContext::GetInstance()->set_grad_accumulation_step(accu_step);
56     auto value = GetValueNode(slice_cnode->input(2));
57     MS_EXCEPTION_IF_NULL(value);
58     auto tuple = GetValue<std::vector<int64_t>>(value);
59     auto input_tmp = GetNodeShape(slice_cnode->input(1));
60     auto input_shape = input_tmp.at(0);
61     int64_t micro = tuple.at(0) * accu_step / input_shape.at(0);
62     slice_cnode->AddPrimalAttr(MICRO, MakeValue(micro));
63     slice_cnode->AddPrimalAttr(GRAD_ACCU_FORWARD_BEGIN, MakeValue(micro));
64     MS_LOG(INFO) << "Find grad accumulation begin node.";
65     BroadCastMicroBatch(slice_cnode, &node_users_map, MakeValue(micro), 0);
66   }
67 }
68 
TagMicroBatchEnd(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & all_nodes)69 void TagMicroBatchEnd(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &all_nodes) {
70   for (const auto &node : all_nodes) {
71     if (!IsPrimitiveCNode(node)) {
72       continue;
73     }
74     auto end_cnode = node->cast<CNodePtr>();
75     auto end_prim = GetCNodePrimitive(end_cnode);
76     if (!end_prim->HasAttr(FORWARD_END)) {
77       continue;
78     }
79     if (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !end_cnode->HasPrimalAttr(MICRO)) {
80       MS_LOG(EXCEPTION) << "Cannot find micro attribute for forward_end nodes";
81     }
82     for (size_t i = 0; i < end_cnode->size(); ++i) {
83       auto temp_node = GetRealKernelNode(end_cnode->input(i), -1, nullptr).first;
84       if (!temp_node->isa<CNode>()) {
85         continue;
86       }
87       auto temp_prim = GetCNodePrimitive(temp_node);
88       if (!temp_prim || temp_prim->HasAttr(FORWARD_END)) {
89         continue;
90       }
91       InsertVirtualPipelineEndNode(end_cnode, manager, i, GRAD_ACCU_FORWARD_END);
92     }
93   }
94 }
95 
SearchPreNodeMicro(const CNodePtr & cnode)96 ValuePtr SearchPreNodeMicro(const CNodePtr &cnode) {
97   if (cnode->HasPrimalAttr(MICRO)) {
98     return cnode->GetPrimalAttr(MICRO);
99   }
100   for (size_t i = 1; i < cnode->size(); ++i) {
101     if (!cnode->input(i)->isa<CNode>()) {
102       continue;
103     }
104     return SearchPreNodeMicro(cnode->input(i)->cast<CNodePtr>());
105   }
106   return nullptr;
107 }
108 
// Tags the gradient-accumulation backward-end boundary when cell sharing
// (graphs flagged "no_inline", i.e. lazy-inline cell reuse) is enabled.
// A backward call of a reused cell appears as call(TupleGetItem(call_fg, idx), ...);
// for each such call this finds its micro index from an ancestor node and stamps
// GRAD_ACCU_BACKWARD_END on the first valid TupleGetItem user of the call.
void TagMicroBatchBpEndInCellShare(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  auto node_users_map = manager->node_users();
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (const auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    // Only consider indirect calls whose callee is extracted by TupleGetItem.
    if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimTupleGetItem)) {
      continue;
    }

    // The TupleGetItem's source must itself be a call to a FuncGraph value node.
    auto tuple_getitem_cnode = cnode->input(0)->cast<CNodePtr>();
    auto tuple_getitem_cnode_input = tuple_getitem_cnode->input(1)->cast<CNodePtr>();
    if (!tuple_getitem_cnode_input || !IsValueNode<FuncGraph>(tuple_getitem_cnode_input->input(0))) {
      continue;
    }
    auto reuse_graph = GetValueNode<FuncGraphPtr>(tuple_getitem_cnode_input->input(0));
    // Only reused (lazy-inline) cells are relevant here.
    if (!reuse_graph->has_flag("no_inline")) {
      continue;
    }
    MS_LOG(INFO) << "Find bp call func node:" << node->DebugString();

    // Micro index is inherited from the nearest tagged ancestor of the call.
    auto micro = SearchPreNodeMicro(cnode);
    if (!micro) {
      MS_LOG(EXCEPTION) << "Cannot find micro info in cell share for node:" << node->DebugString();
    }
    // Tag exactly one user: the first TupleGetItem consumer that is still
    // reachable from the root's return (a valid, live node).
    const auto &users = node_users_map[node];
    for (const auto &user : users) {
      const auto &cuser = user.first->cast<CNodePtr>();
      if (!cuser) {
        continue;
      }
      if (IsPrimitiveCNode(cuser, prim::kPrimTupleGetItem) && IsValidNode(cuser, root->get_return(), node_users_map)) {
        cuser->AddPrimalAttr(GRAD_ACCU_BACKWARD_END, micro);
        break;
      }
    }
  }
}
151 
// Marks, inside the graph that carries the parallel computation (identified by
// containing a VirtualDataset node), the first parallel-considered operator that
// consumes a Parameter: its primitive receives the FIRST_PARAMETER_CNODE attr.
// TagMicroBatchBpEnd later uses this attribute to locate backward-end nodes.
void TagMicroBatchBpEndPrim(const FuncGraphPtr &root) {
  FuncGraphPtr parallel_care_graph = nullptr;
  // NOTE(review): the inner break only exits the node loop; if several graphs
  // contain VirtualDataset, the LAST one found wins — confirm this is intended.
  for (auto &fg : root->manager()->func_graphs()) {
    for (auto &node : fg->nodes()) {
      if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
        parallel_care_graph = fg;
        break;
      }
    }
  }
  if (!parallel_care_graph) {
    MS_LOG(EXCEPTION) << "Cannot find parallel care graph with VirtualDataset";
  }
  bool is_found = false;
  // Walk CNodes in execution order so the FIRST parameter consumer is tagged.
  auto orders = parallel_care_graph->GetOrderedCnodes();
  for (auto node = orders.cbegin(); node != orders.cend(); ++node) {
    auto cnode = (*node)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto prim = GetCNodePrimitive(cnode);
    // Skip non-parallel ops and pure tuple plumbing (TupleGetItem/MakeTuple).
    if (!prim || !IsParallelConsiderCNode(cnode) ||
        IsSomePrimitiveList(cnode, {prim::kPrimTupleGetItem->name(), prim::kPrimMakeTuple->name()})) {
      continue;
    }
    // Tag the first op whose any real input traces back to a Parameter.
    for (size_t i = 1; i < cnode->size(); ++i) {
      std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(i), parallel_care_graph);
      if (param_node_pair.first) {
        (void)prim->AddAttr(FIRST_PARAMETER_CNODE, MakeValue(0));
        is_found = true;
        break;
      }
    }
    if (is_found) {
      break;
    }
  }
}
188 
TagMicroBatchBpEnd(const FuncGraphPtr & root)189 void TagMicroBatchBpEnd(const FuncGraphPtr &root) {
190   AnfNodePtr ret = root->get_return();
191   MS_EXCEPTION_IF_NULL(ret);
192   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
193   for (const auto &node : all_nodes) {
194     if (!IsPrimitiveCNode(node)) {
195       continue;
196     }
197     auto cnode = node->cast<CNodePtr>();
198     auto prim = GetCNodePrimitive(cnode);
199     if (!prim->HasAttr(FIRST_PARAMETER_CNODE)) {
200       continue;
201     }
202     auto micro = SearchPreNodeMicro(cnode->cast<CNodePtr>());
203     if (!micro) {
204       MS_LOG(EXCEPTION) << "Cannot find micro info for node:" << node->DebugString();
205     }
206     cnode->AddPrimalAttr(GRAD_ACCU_BACKWARD_END, micro);
207   }
208 }
209 
ExtractMicroBatchBorderNodes(const FuncGraphPtr & root,std::unordered_map<int64_t,std::vector<CNodePtr>> * forward_start,std::unordered_map<int64_t,std::vector<CNodePtr>> * backward_end)210 void ExtractMicroBatchBorderNodes(const FuncGraphPtr &root,
211                                   std::unordered_map<int64_t, std::vector<CNodePtr>> *forward_start,
212                                   std::unordered_map<int64_t, std::vector<CNodePtr>> *backward_end) {
213   AnfNodePtr ret = root->get_return();
214   MS_EXCEPTION_IF_NULL(ret);
215   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
216   for (const auto &node : all_nodes) {
217     if (!IsPrimitiveCNode(node)) {
218       continue;
219     }
220     auto cnode = node->cast<CNodePtr>();
221     bool is_bp_node = cnode->HasPrimalAttr(kPrimalAttrForwardNodeName);
222     if (!is_bp_node && cnode->HasPrimalAttr(GRAD_ACCU_FORWARD_BEGIN)) {
223       auto accu_forward_begin_micro = GetValue<int64_t>(cnode->GetPrimalAttr(GRAD_ACCU_FORWARD_BEGIN));
224       (*forward_start)[accu_forward_begin_micro].push_back(cnode);
225     }
226     if ((is_bp_node || IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) &&
227         cnode->HasPrimalAttr(GRAD_ACCU_BACKWARD_END)) {
228       auto accu_backward_end_micro = GetValue<int64_t>(cnode->GetPrimalAttr(GRAD_ACCU_BACKWARD_END));
229       (*backward_end)[accu_backward_end_micro].push_back(cnode);
230     }
231   }
232 }
233 
ReorderGradAccumulation(const FuncGraphPtr & root,const std::unordered_map<int64_t,std::vector<CNodePtr>> & forward_start,const std::unordered_map<int64_t,std::vector<CNodePtr>> & backward_end)234 void ReorderGradAccumulation(const FuncGraphPtr &root,
235                              const std::unordered_map<int64_t, std::vector<CNodePtr>> &forward_start,
236                              const std::unordered_map<int64_t, std::vector<CNodePtr>> &backward_end) {
237   if (forward_start.empty() || backward_end.empty()) {
238     MS_LOG(EXCEPTION) << "Cannot find grad_accumulation border node.";
239   }
240   auto manager = root->manager();
241   for (int64_t micro = 0; micro < ParallelContext::GetInstance()->grad_accumulation_step() - 1; ++micro) {
242     if (forward_start.find(micro + 1) == forward_start.end()) {
243       MS_LOG(EXCEPTION) << "Micro " << micro + 1 << " cannot find forward_start nodes.";
244     }
245     if (backward_end.find(micro) == backward_end.end()) {
246       MS_LOG(EXCEPTION) << "Micro " << micro << " cannot find backward_end nodes.";
247     }
248     // backward_end -> depend -> next_forward_start
249     std::vector<AnfNodePtr> backward_end_inputs{NewValueNode(prim::kPrimMakeTuple)};
250     std::copy(backward_end.at(micro).begin(), backward_end.at(micro).end(), std::back_inserter(backward_end_inputs));
251     auto backward_end_make_tuple_cnode = root->NewCNode(backward_end_inputs);
252     for (const auto &forward_start_node : forward_start.at(micro + 1)) {
253       std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), forward_start_node->input(1),
254                                             backward_end_make_tuple_cnode};
255       auto depend_node = root->NewCNode(depend_inputs);
256       depend_node->AddAttr("grad_accu_reorder2", MakeValue(micro));
257       manager->SetEdge(forward_start_node, 1, depend_node);
258     }
259   }
260 }
261 }  // namespace parallel
262 }  // namespace mindspore
263