1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "frontend/optimizer/ad/grad.h"
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph_cloner.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "include/common/utils/parallel_context.h"
27
28 namespace mindspore {
29 namespace ad {
30 namespace {
PartialEliminateOptPass(const pipeline::ResourcePtr & resource,const FuncGraphPtr & func_graph)31 FuncGraphPtr PartialEliminateOptPass(const pipeline::ResourcePtr &resource, const FuncGraphPtr &func_graph) {
32 MS_EXCEPTION_IF_NULL(resource);
33
34 opt::irpass::OptimizeIRPassLib irpass;
35 opt::OptPassConfig partial_eliminate_opt_ = opt::OptPassConfig(
36 {irpass.partial_eliminate_, irpass.switch_partial_eliminater_, irpass.switch_layer_partial_eliminater_});
37 opt::OptPassGroupMap map({{"partial_eliminate_", partial_eliminate_opt_}});
38
39 auto after_lift_opt = opt::Optimizer::MakeOptimizer("partial_eliminate", resource, map);
40
41 FuncGraphPtr opt_fg = nullptr;
42 ProfileExecute(MsProfile::GetProfile()->Step("partial_eliminate_before_grad"),
43 [&after_lift_opt, func_graph, &opt_fg]() { opt_fg = after_lift_opt->step(func_graph, true); });
44 return opt_fg;
45 }
46
PartialEliminateMulti(const pipeline::ResourceBasePtr & resource,const FuncGraphVector & func_graphs)47 FuncGraphVector PartialEliminateMulti(const pipeline::ResourceBasePtr &resource, const FuncGraphVector &func_graphs) {
48 auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource);
49 if (new_res == nullptr) {
50 MS_LOG(INTERNAL_EXCEPTION) << "Parameter resources is not a pipeline::Resource";
51 }
52 FuncGraphVector opt_fgs;
53 for (const auto &func_graph : func_graphs) {
54 auto opt_fg = PartialEliminateOptPass(new_res, func_graph);
55 #ifdef ENABLE_DUMP_IR
56 auto context = MsContext::GetInstance();
57 MS_EXCEPTION_IF_NULL(context);
58 if (context->CanDump(kIntroductory)) {
59 DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg);
60 }
61 #endif
62 opt_fgs.push_back(opt_fg);
63 }
64 return opt_fgs;
65 }
66
LiftFv(const pipeline::ResourceBasePtr & resource,const FuncGraphPtr & func_graph)67 FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPtr &func_graph) {
68 #ifdef ENABLE_DUMP_IR
69 auto context = MsContext::GetInstance();
70 MS_EXCEPTION_IF_NULL(context);
71 bool enable_save_graphs = context->CanDump(kIntroductory);
72 if (enable_save_graphs) {
73 DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph);
74 }
75 #endif
76 FuncGraphPtr new_fg = LiftingClone(func_graph);
77 #ifdef ENABLE_DUMP_IR
78 if (enable_save_graphs) {
79 DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg);
80 }
81 #endif
82 auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource);
83 if (new_res == nullptr) {
84 MS_LOG(INTERNAL_EXCEPTION) << "Parameter resources is not a pipeline::Resource";
85 }
86 auto opt_fg = PartialEliminateOptPass(new_res, new_fg);
87 #ifdef ENABLE_DUMP_IR
88 if (enable_save_graphs) {
89 DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg);
90 }
91 #endif
92 return opt_fg;
93 }
94
LiftFvMulti(const pipeline::ResourceBasePtr & resource,const FuncGraphVector & func_graphs)95 FuncGraphVector LiftFvMulti(const pipeline::ResourceBasePtr &resource, const FuncGraphVector &func_graphs) {
96 #ifdef ENABLE_DUMP_IR
97 auto context = MsContext::GetInstance();
98 MS_EXCEPTION_IF_NULL(context);
99 if (context->CanDump(kIntroductory)) {
100 for (const auto &func_graph : func_graphs) {
101 DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph);
102 }
103 }
104 #endif
105 bool has_used_fg = std::any_of(func_graphs.cbegin(), func_graphs.cend(), [](const FuncGraphPtr &func_graph) {
106 return func_graph->func_graphs_used().size() != 0;
107 });
108 // All func_graphs being graded don't have used funcgraphs, no need to do lifting clone.
109 if (!has_used_fg) {
110 return func_graphs;
111 }
112 FuncGraphVector new_fgs = LiftingCloneMulti(func_graphs);
113 #ifdef ENABLE_DUMP_IR
114 if (context->CanDump(kIntroductory)) {
115 for (const auto &new_fg : new_fgs) {
116 DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg);
117 }
118 }
119 #endif
120 return PartialEliminateMulti(resource, new_fgs);
121 }
122
ForwardInputsEqual(const AnfNodeWeakPtrList & first_inputs,const AnfNodeWeakPtrList & second_inputs)123 bool ForwardInputsEqual(const AnfNodeWeakPtrList &first_inputs, const AnfNodeWeakPtrList &second_inputs) {
124 if (first_inputs.size() != second_inputs.size()) {
125 return false;
126 }
127 for (size_t i = 1; i < first_inputs.size(); ++i) {
128 if (HasAbstractMonad(first_inputs[i].lock()) && HasAbstractMonad(second_inputs[i].lock())) {
129 continue;
130 }
131 if (first_inputs[i].lock() != second_inputs[i].lock()) {
132 return false;
133 }
134 }
135 return true;
136 }
137
// Return the unique user node of `j_node`, or nullptr when it has no users.
// Raises if the J node has more than one user (a J call must be consumed by
// exactly one call site).
AnfNodePtr GetJUser(const FuncGraphManagerPtr &manager, const AnfNodePtr &j_node) {
  // Look up the user map once instead of twice.
  const auto &node_users = manager->node_users();
  auto iter = node_users.find(j_node);
  if (iter == node_users.end()) {
    return nullptr;
  }
  // Bind by reference: the original `auto users = iter->second;` copied the
  // whole user set just to inspect its size and first element.
  const auto &users = iter->second;
  if (users.size() != 1) {
    MS_LOG(EXCEPTION) << "The size of J users should be 1, but got " << users.size();
  }
  return users.begin()->first;
}
149 } // namespace
150
// Apply the J (grad) transform to a single `func_graph` via DFunctor and
// return the resulting k-graph. The result is cached in the graph's "grad"
// transform, so repeated calls return the same graph.
// `is_top` marks the outermost invocation and triggers DFunctor::Clear()
// once the transform is done. `level` is recorded on the result as the
// kAttrBpropAutoMonadLevel attribute for later bprop auto-monad handling.
FuncGraphPtr GradOneFuncGraph(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer, bool is_top,
                              BpropAutoMonadLevel level) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Return the cached gradient graph if this graph was already transformed.
  auto gradkv = func_graph->transforms().find("grad");
  if (gradkv != func_graph->transforms().end()) {
    return gradkv->second.func_graph();
  }
  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  manager_ptr->AddFuncGraph(func_graph);
  // Propagate FUNC_GRAPH_FLAG_IGNORE_VALUE to the produced graph when
  // multi-graph sink is enabled in the context.
  auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
      if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
        f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
      }
    }
  };

  auto f = std::make_shared<DFunctor>(func_graph, resources, is_top);
  // A user-defined bprop (if registered for this graph) short-circuits the
  // generic transform below.
  auto user_defined = f->KUserDefined(func_graph);
  if (user_defined != nullptr) {
    multi_graph_sink(user_defined);
    if (is_top) {
      DFunctor::Clear();
    }
    return user_defined;
  }
  // Generic path: the DFunctor phases must run in exactly this order.
  f->Init(is_top);
  f->MapObject();
  f->MapMorphism();
  f->Finish();
  auto res = f->k_graph();
  res->set_attr(kAttrBpropAutoMonadLevel, MakeValue<int>(level));
  auto tape = f->tape();
  tape->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
  // Clear DFunctor's global caches only at the top of the transform stack.
  if (is_top) {
    DFunctor::Clear();
  }

  multi_graph_sink(res);
  // Cache the result so subsequent Grad calls on this graph are no-ops.
  (void)func_graph->transforms().emplace("grad", FuncGraphTransform(res));
  return res;
}
195
// Gradient transform of a single graph. Lifts free variables first when this
// is a first-order J on a graph that uses other funcgraphs, then delegates to
// GradOneFuncGraph. Results are cached in the graph's "grad" transform.
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer, bool is_top,
                  BpropAutoMonadLevel level) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Fast path: reuse a previously computed gradient graph.
  auto cached = func_graph->transforms().find("grad");
  if (cached != func_graph->transforms().end()) {
    return cached->second.func_graph();
  }

  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  manager_ptr->AddFuncGraph(func_graph);

  // Free variables are lifted only for first-order J on graphs that actually
  // use other funcgraphs; the global flag records which mode is active.
  lift_fv_before_grad = func_graph->func_graphs_used().size() != 0 && optimizer->is_first_order_j();
  FuncGraphPtr source_fg = func_graph;
  if (lift_fv_before_grad) {
    source_fg = LiftFv(resources, func_graph);
  }
  return GradOneFuncGraph(source_fg, optimizer, is_top, level);
}
218
// Gradient transform of a batch of graphs. Under (semi-)auto-parallel mode the
// bprop auto-monad level is restricted to the top level; otherwise the whole
// graph is handled. A single-element batch goes through the scalar Grad path.
FuncGraphVector GradMultiFuncGraph(const FuncGraphVector &func_graphs, const opt::OptimizerPtr &optimizer,
                                   bool is_top) {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  const bool is_parallel_mode =
    parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
  const BpropAutoMonadLevel bprop_auto_monad_level = is_parallel_mode ? kLevelTop : kLevelWhole;

  if (func_graphs.size() == 1) {
    // Scalar path also decides fv lifting on its own.
    return {Grad(func_graphs[0], optimizer, is_top, bprop_auto_monad_level)};
  }

  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  for (const auto &fg : func_graphs) {
    manager_ptr->AddFuncGraph(fg);
  }

  // First-order J lifts free variables for the whole batch; the global flag
  // must be set before LiftFvMulti runs.
  FuncGraphVector source_fgs;
  lift_fv_before_grad = optimizer->is_first_order_j();
  if (lift_fv_before_grad) {
    source_fgs = LiftFvMulti(resources, func_graphs);
  } else {
    source_fgs = func_graphs;
  }

  FuncGraphVector grad_fgs;
  for (const auto &fg : source_fgs) {
    grad_fgs.push_back(GradOneFuncGraph(fg, optimizer, is_top, bprop_auto_monad_level));
  }
  return grad_fgs;
}
253
Kprim(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)254 FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
255 auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources);
256 if (fg == nullptr) {
257 return nullptr;
258 }
259 return BasicClone(fg);
260 }
261
Kmeta(const PrimitivePtr & prim,const pipeline::ResourceBasePtr &)262 MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) {
263 MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim);
264 return fg;
265 }
266
CleanRes()267 void CleanRes() { DFunctor::Clear(); }
268
MergeForward(const FuncGraphPtr & root,const opt::OptimizerPtr & opt)269 bool MergeForward(const FuncGraphPtr &root, const opt::OptimizerPtr &opt) {
270 auto manager = opt->manager();
271 MS_EXCEPTION_IF_NULL(manager);
272 std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> forward_fg_to_j_nodes;
273 auto all_nodes = TopoSort(root->get_return(), SuccDeeperSimple, AlwaysInclude);
274 for (const auto &node : all_nodes) {
275 if (!IsPrimitiveCNode(node, prim::kPrimJ)) {
276 continue;
277 }
278 auto cnode = node->cast<CNodePtr>();
279 auto merge_forward = cnode->user_data<bool>("merge_forward");
280 if (merge_forward == nullptr || !(*merge_forward)) {
281 continue;
282 }
283 auto forward_fg = GetValueNode<FuncGraphPtr>(cnode->input(1));
284 if (forward_fg == nullptr) {
285 continue;
286 }
287 (void)forward_fg_to_j_nodes[forward_fg].emplace_back(node);
288 }
289 bool change = false;
290 for (const auto &iter : forward_fg_to_j_nodes) {
291 auto &j_nodes = iter.second;
292 MS_LOG(DEBUG) << "J nodes size is " << j_nodes.size();
293 if (j_nodes.size() <= 1) {
294 continue;
295 }
296 auto first_j_user = GetJUser(manager, j_nodes[0]);
297 if (first_j_user == nullptr) {
298 continue;
299 }
300 const auto &first_forward_inputs = first_j_user->cast<CNodePtr>()->weak_inputs();
301 for (size_t i = 1; i < j_nodes.size(); ++i) {
302 auto j_user = GetJUser(manager, j_nodes[i]);
303 const auto &forward_inputs = j_user->cast<CNodePtr>()->weak_inputs();
304 if (!ForwardInputsEqual(first_forward_inputs, forward_inputs)) {
305 continue;
306 }
307 manager->Replace(j_user, first_j_user);
308 MS_LOG(DEBUG) << "Replace J user " << j_user->DebugString() << " with the first J user "
309 << first_j_user->DebugString();
310 change = true;
311 }
312 }
313 return change;
314 }
315 } // namespace ad
316 } // namespace mindspore
317