1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <algorithm>
#include <list>
#include <memory>
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/grad_accumulation_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "ir/value.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "utils/parallel_node_check.h"
35
36 namespace mindspore {
37 namespace parallel {
// Attribute names used to mark micro-batch borders for gradient accumulation:
// GRAD_ACCU_NUM          - primitive attr on the input StridedSlice, total accumulation steps.
// GRAD_ACCU_FORWARD_BEGIN - primal attr tagging where a micro-batch's forward pass starts.
// GRAD_ACCU_FORWARD_END   - primal attr tagging where a micro-batch's forward pass ends.
// GRAD_ACCU_BACKWARD_END  - primal attr tagging where a micro-batch's backward pass ends.
// FIRST_PARAMETER_CNODE   - primitive attr marking the first cnode that consumes a parameter.
constexpr char GRAD_ACCU_NUM[] = "grad_accu_num";
constexpr char GRAD_ACCU_FORWARD_BEGIN[] = "grad_accu_forward_begin";
constexpr char GRAD_ACCU_FORWARD_END[] = "grad_accu_forward_end";
constexpr char GRAD_ACCU_BACKWARD_END[] = "grad_accu_backward_end";
constexpr char FIRST_PARAMETER_CNODE[] = "first_parameter_cnode";
TagMicroBatchStart(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & all_nodes)43 void TagMicroBatchStart(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &all_nodes) {
44 auto node_users_map = manager->node_users();
45 for (const auto &node : all_nodes) {
46 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
47 continue;
48 }
49 auto slice_cnode = node->cast<CNodePtr>();
50 auto slice_prim = GetCNodePrimitive(slice_cnode);
51 if (!slice_prim->HasAttr(GRAD_ACCU_NUM)) {
52 continue;
53 }
54 auto accu_step = GetValue<int64_t>(slice_prim->GetAttr(GRAD_ACCU_NUM));
55 ParallelContext::GetInstance()->set_grad_accumulation_step(accu_step);
56 auto value = GetValueNode(slice_cnode->input(2));
57 MS_EXCEPTION_IF_NULL(value);
58 auto tuple = GetValue<std::vector<int64_t>>(value);
59 auto input_tmp = GetNodeShape(slice_cnode->input(1));
60 auto input_shape = input_tmp.at(0);
61 int64_t micro = tuple.at(0) * accu_step / input_shape.at(0);
62 slice_cnode->AddPrimalAttr(MICRO, MakeValue(micro));
63 slice_cnode->AddPrimalAttr(GRAD_ACCU_FORWARD_BEGIN, MakeValue(micro));
64 MS_LOG(INFO) << "Find grad accumulation begin node.";
65 BroadCastMicroBatch(slice_cnode, &node_users_map, MakeValue(micro), 0);
66 }
67 }
68
TagMicroBatchEnd(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & all_nodes)69 void TagMicroBatchEnd(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &all_nodes) {
70 for (const auto &node : all_nodes) {
71 if (!IsPrimitiveCNode(node)) {
72 continue;
73 }
74 auto end_cnode = node->cast<CNodePtr>();
75 auto end_prim = GetCNodePrimitive(end_cnode);
76 if (!end_prim->HasAttr(FORWARD_END)) {
77 continue;
78 }
79 if (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !end_cnode->HasPrimalAttr(MICRO)) {
80 MS_LOG(EXCEPTION) << "Cannot find micro attribute for forward_end nodes";
81 }
82 for (size_t i = 0; i < end_cnode->size(); ++i) {
83 auto temp_node = GetRealKernelNode(end_cnode->input(i), -1, nullptr).first;
84 if (!temp_node->isa<CNode>()) {
85 continue;
86 }
87 auto temp_prim = GetCNodePrimitive(temp_node);
88 if (!temp_prim || temp_prim->HasAttr(FORWARD_END)) {
89 continue;
90 }
91 InsertVirtualPipelineEndNode(end_cnode, manager, i, GRAD_ACCU_FORWARD_END);
92 }
93 }
94 }
95
SearchPreNodeMicro(const CNodePtr & cnode)96 ValuePtr SearchPreNodeMicro(const CNodePtr &cnode) {
97 if (cnode->HasPrimalAttr(MICRO)) {
98 return cnode->GetPrimalAttr(MICRO);
99 }
100 for (size_t i = 1; i < cnode->size(); ++i) {
101 if (!cnode->input(i)->isa<CNode>()) {
102 continue;
103 }
104 return SearchPreNodeMicro(cnode->input(i)->cast<CNodePtr>());
105 }
106 return nullptr;
107 }
108
// For cell-reuse (lazy-inline) graphs: finds each call whose callee is a
// TupleGetItem over a call to a reused ("no_inline") sub-graph — i.e. the
// backward call of a shared cell — and tags its first still-reachable
// TupleGetItem user with GRAD_ACCU_BACKWARD_END plus the micro index found
// upstream, marking where that micro-batch's backward computation finishes.
void TagMicroBatchBpEndInCellShare(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  auto node_users_map = manager->node_users();
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (const auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    // Backward calls of a reused cell have the shape:
    //   call(TupleGetItem(call(@reuse_graph, ...), idx), ...)
    if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimTupleGetItem)) {
      continue;
    }

    auto tuple_getitem_cnode = cnode->input(0)->cast<CNodePtr>();
    auto tuple_getitem_cnode_input = tuple_getitem_cnode->input(1)->cast<CNodePtr>();
    if (!tuple_getitem_cnode_input || !IsValueNode<FuncGraph>(tuple_getitem_cnode_input->input(0))) {
      continue;
    }
    auto reuse_graph = GetValueNode<FuncGraphPtr>(tuple_getitem_cnode_input->input(0));
    // Only graphs flagged "no_inline" are reused cells (cell-share mode).
    if (!reuse_graph->has_flag("no_inline")) {
      continue;
    }
    MS_LOG(INFO) << "Find bp call func node:" << node->DebugString();

    auto micro = SearchPreNodeMicro(cnode);
    if (!micro) {
      MS_LOG(EXCEPTION) << "Cannot find micro info in cell share for node:" << node->DebugString();
    }
    // Tag only the first TupleGetItem user that is still reachable from the
    // graph output; dead users must not carry the backward-end marker.
    const auto &users = node_users_map[node];
    for (const auto &user : users) {
      const auto &cuser = user.first->cast<CNodePtr>();
      if (!cuser) {
        continue;
      }
      if (IsPrimitiveCNode(cuser, prim::kPrimTupleGetItem) && IsValidNode(cuser, root->get_return(), node_users_map)) {
        cuser->AddPrimalAttr(GRAD_ACCU_BACKWARD_END, micro);
        break;
      }
    }
  }
}
151
// In the graph that contains VirtualDataset (the "parallel care" graph), finds
// the first parallel-considered cnode that directly consumes a parameter and
// marks its primitive with FIRST_PARAMETER_CNODE. TagMicroBatchBpEnd later
// uses this attr to locate the corresponding backward-end nodes.
void TagMicroBatchBpEndPrim(const FuncGraphPtr &root) {
  FuncGraphPtr parallel_care_graph = nullptr;
  for (auto &fg : root->manager()->func_graphs()) {
    for (auto &node : fg->nodes()) {
      if (IsPrimitiveCNode(node, prim::kPrimVirtualDataset)) {
        parallel_care_graph = fg;
        // NOTE(review): this break only exits the inner loop, so if several
        // graphs contain VirtualDataset the last one wins — confirm intended.
        break;
      }
    }
  }
  if (!parallel_care_graph) {
    MS_LOG(EXCEPTION) << "Cannot find parallel care graph with VirtualDataset";
  }
  bool is_found = false;
  // Walk cnodes in topological order so the FIRST parameter consumer is tagged.
  auto orders = parallel_care_graph->GetOrderedCnodes();
  for (auto node = orders.cbegin(); node != orders.cend(); ++node) {
    auto cnode = (*node)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto prim = GetCNodePrimitive(cnode);
    // Skip non-primitive nodes, nodes parallel does not consider, and pure
    // tuple plumbing (TupleGetItem/MakeTuple).
    if (!prim || !IsParallelConsiderCNode(cnode) ||
        IsSomePrimitiveList(cnode, {prim::kPrimTupleGetItem->name(), prim::kPrimMakeTuple->name()})) {
      continue;
    }
    for (size_t i = 1; i < cnode->size(); ++i) {
      // Tag the first cnode with any input that resolves to a parameter.
      std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(i), parallel_care_graph);
      if (param_node_pair.first) {
        (void)prim->AddAttr(FIRST_PARAMETER_CNODE, MakeValue(0));
        is_found = true;
        break;
      }
    }
    if (is_found) {
      break;
    }
  }
}
188
TagMicroBatchBpEnd(const FuncGraphPtr & root)189 void TagMicroBatchBpEnd(const FuncGraphPtr &root) {
190 AnfNodePtr ret = root->get_return();
191 MS_EXCEPTION_IF_NULL(ret);
192 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
193 for (const auto &node : all_nodes) {
194 if (!IsPrimitiveCNode(node)) {
195 continue;
196 }
197 auto cnode = node->cast<CNodePtr>();
198 auto prim = GetCNodePrimitive(cnode);
199 if (!prim->HasAttr(FIRST_PARAMETER_CNODE)) {
200 continue;
201 }
202 auto micro = SearchPreNodeMicro(cnode->cast<CNodePtr>());
203 if (!micro) {
204 MS_LOG(EXCEPTION) << "Cannot find micro info for node:" << node->DebugString();
205 }
206 cnode->AddPrimalAttr(GRAD_ACCU_BACKWARD_END, micro);
207 }
208 }
209
ExtractMicroBatchBorderNodes(const FuncGraphPtr & root,std::unordered_map<int64_t,std::vector<CNodePtr>> * forward_start,std::unordered_map<int64_t,std::vector<CNodePtr>> * backward_end)210 void ExtractMicroBatchBorderNodes(const FuncGraphPtr &root,
211 std::unordered_map<int64_t, std::vector<CNodePtr>> *forward_start,
212 std::unordered_map<int64_t, std::vector<CNodePtr>> *backward_end) {
213 AnfNodePtr ret = root->get_return();
214 MS_EXCEPTION_IF_NULL(ret);
215 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
216 for (const auto &node : all_nodes) {
217 if (!IsPrimitiveCNode(node)) {
218 continue;
219 }
220 auto cnode = node->cast<CNodePtr>();
221 bool is_bp_node = cnode->HasPrimalAttr(kPrimalAttrForwardNodeName);
222 if (!is_bp_node && cnode->HasPrimalAttr(GRAD_ACCU_FORWARD_BEGIN)) {
223 auto accu_forward_begin_micro = GetValue<int64_t>(cnode->GetPrimalAttr(GRAD_ACCU_FORWARD_BEGIN));
224 (*forward_start)[accu_forward_begin_micro].push_back(cnode);
225 }
226 if ((is_bp_node || IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) &&
227 cnode->HasPrimalAttr(GRAD_ACCU_BACKWARD_END)) {
228 auto accu_backward_end_micro = GetValue<int64_t>(cnode->GetPrimalAttr(GRAD_ACCU_BACKWARD_END));
229 (*backward_end)[accu_backward_end_micro].push_back(cnode);
230 }
231 }
232 }
233
ReorderGradAccumulation(const FuncGraphPtr & root,const std::unordered_map<int64_t,std::vector<CNodePtr>> & forward_start,const std::unordered_map<int64_t,std::vector<CNodePtr>> & backward_end)234 void ReorderGradAccumulation(const FuncGraphPtr &root,
235 const std::unordered_map<int64_t, std::vector<CNodePtr>> &forward_start,
236 const std::unordered_map<int64_t, std::vector<CNodePtr>> &backward_end) {
237 if (forward_start.empty() || backward_end.empty()) {
238 MS_LOG(EXCEPTION) << "Cannot find grad_accumulation border node.";
239 }
240 auto manager = root->manager();
241 for (int64_t micro = 0; micro < ParallelContext::GetInstance()->grad_accumulation_step() - 1; ++micro) {
242 if (forward_start.find(micro + 1) == forward_start.end()) {
243 MS_LOG(EXCEPTION) << "Micro " << micro + 1 << " cannot find forward_start nodes.";
244 }
245 if (backward_end.find(micro) == backward_end.end()) {
246 MS_LOG(EXCEPTION) << "Micro " << micro << " cannot find backward_end nodes.";
247 }
248 // backward_end -> depend -> next_forward_start
249 std::vector<AnfNodePtr> backward_end_inputs{NewValueNode(prim::kPrimMakeTuple)};
250 std::copy(backward_end.at(micro).begin(), backward_end.at(micro).end(), std::back_inserter(backward_end_inputs));
251 auto backward_end_make_tuple_cnode = root->NewCNode(backward_end_inputs);
252 for (const auto &forward_start_node : forward_start.at(micro + 1)) {
253 std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), forward_start_node->input(1),
254 backward_end_make_tuple_cnode};
255 auto depend_node = root->NewCNode(depend_inputs);
256 depend_node->AddAttr("grad_accu_reorder2", MakeValue(micro));
257 manager->SetEdge(forward_start_node, 1, depend_node);
258 }
259 }
260 }
261 } // namespace parallel
262 } // namespace mindspore
263