• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "frontend/optimizer/ad/grad.h"

#include <algorithm>
#include <memory>
#include <unordered_map>
#include <vector>

#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph_cloner.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "include/common/utils/parallel_context.h"
27 
28 namespace mindspore {
29 namespace ad {
30 namespace {
PartialEliminateOptPass(const pipeline::ResourcePtr & resource,const FuncGraphPtr & func_graph)31 FuncGraphPtr PartialEliminateOptPass(const pipeline::ResourcePtr &resource, const FuncGraphPtr &func_graph) {
32   MS_EXCEPTION_IF_NULL(resource);
33 
34   opt::irpass::OptimizeIRPassLib irpass;
35   opt::OptPassConfig partial_eliminate_opt_ = opt::OptPassConfig(
36     {irpass.partial_eliminate_, irpass.switch_partial_eliminater_, irpass.switch_layer_partial_eliminater_});
37   opt::OptPassGroupMap map({{"partial_eliminate_", partial_eliminate_opt_}});
38 
39   auto after_lift_opt = opt::Optimizer::MakeOptimizer("partial_eliminate", resource, map);
40 
41   FuncGraphPtr opt_fg = nullptr;
42   ProfileExecute(MsProfile::GetProfile()->Step("partial_eliminate_before_grad"),
43                  [&after_lift_opt, func_graph, &opt_fg]() { opt_fg = after_lift_opt->step(func_graph, true); });
44   return opt_fg;
45 }
46 
PartialEliminateMulti(const pipeline::ResourceBasePtr & resource,const FuncGraphVector & func_graphs)47 FuncGraphVector PartialEliminateMulti(const pipeline::ResourceBasePtr &resource, const FuncGraphVector &func_graphs) {
48   auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource);
49   if (new_res == nullptr) {
50     MS_LOG(INTERNAL_EXCEPTION) << "Parameter resources is not a pipeline::Resource";
51   }
52   FuncGraphVector opt_fgs;
53   for (const auto &func_graph : func_graphs) {
54     auto opt_fg = PartialEliminateOptPass(new_res, func_graph);
55 #ifdef ENABLE_DUMP_IR
56     auto context = MsContext::GetInstance();
57     MS_EXCEPTION_IF_NULL(context);
58     if (context->CanDump(kIntroductory)) {
59       DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg);
60     }
61 #endif
62     opt_fgs.push_back(opt_fg);
63   }
64   return opt_fgs;
65 }
66 
// Lift free variables of `func_graph` into explicit parameters via
// LiftingClone, then run the partial-eliminate pass on the lifted graph.
// When IR dumping is enabled, dumps the graph before lifting, after lifting,
// and after the eliminate pass.
FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPtr &func_graph) {
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  bool enable_save_graphs = context->CanDump(kIntroductory);
  if (enable_save_graphs) {
    DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph);
  }
#endif
  // Clone with free-variable lifting: FVs become explicit parameters.
  FuncGraphPtr new_fg = LiftingClone(func_graph);
#ifdef ENABLE_DUMP_IR
  if (enable_save_graphs) {
    DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg);
  }
#endif
  auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource);
  if (new_res == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter resources is not a pipeline::Resource";
  }
  // Lifting introduces partial applications; eliminate them before grad.
  auto opt_fg = PartialEliminateOptPass(new_res, new_fg);
#ifdef ENABLE_DUMP_IR
  if (enable_save_graphs) {
    DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg);
  }
#endif
  return opt_fg;
}
94 
// Batch variant of LiftFv: lift free variables for several func graphs at
// once. If no graph in the batch uses another func graph, the input vector
// is returned unchanged (no lifting clone is needed); otherwise all graphs
// are lifting-cloned together and partial applications are eliminated.
FuncGraphVector LiftFvMulti(const pipeline::ResourceBasePtr &resource, const FuncGraphVector &func_graphs) {
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    for (const auto &func_graph : func_graphs) {
      DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph);
    }
  }
#endif
  bool has_used_fg = std::any_of(func_graphs.cbegin(), func_graphs.cend(), [](const FuncGraphPtr &func_graph) {
    return func_graph->func_graphs_used().size() != 0;
  });
  // All func_graphs being graded don't have used funcgraphs, no need to do lifting clone.
  if (!has_used_fg) {
    return func_graphs;
  }
  FuncGraphVector new_fgs = LiftingCloneMulti(func_graphs);
#ifdef ENABLE_DUMP_IR
  // `context` was declared in the dump block above; it is in scope here
  // exactly when ENABLE_DUMP_IR is defined.
  if (context->CanDump(kIntroductory)) {
    for (const auto &new_fg : new_fgs) {
      DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg);
    }
  }
#endif
  return PartialEliminateMulti(resource, new_fgs);
}
122 
ForwardInputsEqual(const AnfNodeWeakPtrList & first_inputs,const AnfNodeWeakPtrList & second_inputs)123 bool ForwardInputsEqual(const AnfNodeWeakPtrList &first_inputs, const AnfNodeWeakPtrList &second_inputs) {
124   if (first_inputs.size() != second_inputs.size()) {
125     return false;
126   }
127   for (size_t i = 1; i < first_inputs.size(); ++i) {
128     if (HasAbstractMonad(first_inputs[i].lock()) && HasAbstractMonad(second_inputs[i].lock())) {
129       continue;
130     }
131     if (first_inputs[i].lock() != second_inputs[i].lock()) {
132       return false;
133     }
134   }
135   return true;
136 }
137 
// Return the unique user node of `j_node`, or nullptr when the node has no
// users. Raises when the J node unexpectedly has more than one user.
AnfNodePtr GetJUser(const FuncGraphManagerPtr &manager, const AnfNodePtr &j_node) {
  auto iter = manager->node_users().find(j_node);
  if (iter == manager->node_users().end()) {
    return nullptr;
  }
  // Fix: bind by const reference — the original copied the entire user set
  // on every call just to inspect its size and first element.
  const auto &users = iter->second;
  if (users.size() != 1) {
    MS_LOG(EXCEPTION) << "The size of J users should be 1, but got " << users.size();
  }
  return users.begin()->first;
}
149 }  // namespace
150 
// Compute the grad (K) transform of a single func graph via DFunctor.
// Returns the cached result from the graph's "grad" transform when one
// exists. `is_top` marks the outermost graph of this grad invocation;
// `level` is stored on the result as kAttrBpropAutoMonadLevel.
FuncGraphPtr GradOneFuncGraph(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer, bool is_top,
                              BpropAutoMonadLevel level) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto gradkv = func_graph->transforms().find("grad");
  if (gradkv != func_graph->transforms().end()) {
    // Already graded: reuse the cached gradient graph.
    return gradkv->second.func_graph();
  }
  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  manager_ptr->AddFuncGraph(func_graph);
  // Propagate FUNC_GRAPH_FLAG_IGNORE_VALUE from the source graph onto the
  // produced graph when multi-graph sink is enabled in the context.
  auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
      if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
        f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
      }
    }
  };

  auto f = std::make_shared<DFunctor>(func_graph, resources, is_top);
  // A user-defined bprop takes precedence over the automatic transform.
  auto user_defined = f->KUserDefined(func_graph);
  if (user_defined != nullptr) {
    multi_graph_sink(user_defined);
    if (is_top) {
      DFunctor::Clear();
    }
    return user_defined;
  }
  // Automatic differentiation pipeline; the call order is significant:
  // Init -> MapObject -> MapMorphism -> Finish.
  f->Init(is_top);
  f->MapObject();
  f->MapMorphism();
  f->Finish();
  auto res = f->k_graph();
  res->set_attr(kAttrBpropAutoMonadLevel, MakeValue<int>(level));
  auto tape = f->tape();
  tape->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
  if (is_top) {
    // The top-level call owns the shared DFunctor cache; release it here.
    DFunctor::Clear();
  }

  multi_graph_sink(res);
  // Cache under "grad" so repeated grads of this graph hit the early return.
  (void)func_graph->transforms().emplace("grad", FuncGraphTransform(res));
  return res;
}
195 
// Grad a single func graph. If the graph uses other func graphs and this is
// a first-order J application, free variables are lifted first via LiftFv;
// the global `lift_fv_before_grad` records whether lifting happened.
// Returns the cached "grad" transform when one already exists.
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer, bool is_top,
                  BpropAutoMonadLevel level) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto gradkv = func_graph->transforms().find("grad");
  if (gradkv != func_graph->transforms().end()) {
    // Already graded: reuse the cached gradient graph.
    return gradkv->second.func_graph();
  }

  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  manager_ptr->AddFuncGraph(func_graph);

  FuncGraphPtr grad_fg = func_graph;
  // Lift free variables only for first-order grad; higher-order J skips it.
  if (func_graph->func_graphs_used().size() != 0 && optimizer->is_first_order_j()) {
    lift_fv_before_grad = true;
    grad_fg = LiftFv(resources, func_graph);
  } else {
    lift_fv_before_grad = false;
  }
  return GradOneFuncGraph(grad_fg, optimizer, is_top, level);
}
218 
// Grad a batch of func graphs. Under (semi-)auto-parallel mode the bprop
// auto-monad level is kLevelTop, otherwise kLevelWhole. A single-element
// batch falls back to Grad(); for multiple graphs, free-variable lifting is
// performed jointly via LiftFvMulti before grading each graph.
FuncGraphVector GradMultiFuncGraph(const FuncGraphVector &func_graphs, const opt::OptimizerPtr &optimizer,
                                   bool is_top) {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  const bool is_parallel_mode =
    parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
  BpropAutoMonadLevel bprop_auto_monad_level = is_parallel_mode ? kLevelTop : kLevelWhole;
  FuncGraphVector grad_fgs;
  if (func_graphs.size() == 1) {
    // Single-graph case: the plain Grad path (with its own lifting logic)
    // already covers it.
    auto grad_fg = Grad(func_graphs[0], optimizer, is_top, bprop_auto_monad_level);
    grad_fgs.push_back(grad_fg);
    return grad_fgs;
  }
  const auto &resources = optimizer->resource();
  auto manager_ptr = resources->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  for (const auto &func_graph : func_graphs) {
    manager_ptr->AddFuncGraph(func_graph);
  }
  FuncGraphVector before_grad_fgs;
  // Lift free variables jointly for first-order grad only; the global flag
  // records whether lifting happened for downstream passes.
  if (optimizer->is_first_order_j()) {
    lift_fv_before_grad = true;
    before_grad_fgs = LiftFvMulti(resources, func_graphs);
  } else {
    before_grad_fgs = func_graphs;
    lift_fv_before_grad = false;
  }
  for (const auto &func_graph : before_grad_fgs) {
    auto grad_fg = GradOneFuncGraph(func_graph, optimizer, is_top, bprop_auto_monad_level);
    grad_fgs.push_back(grad_fg);
  }
  return grad_fgs;
}
253 
Kprim(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)254 FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
255   auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources);
256   if (fg == nullptr) {
257     return nullptr;
258   }
259   return BasicClone(fg);
260 }
261 
Kmeta(const PrimitivePtr & prim,const pipeline::ResourceBasePtr &)262 MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) {
263   MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim);
264   return fg;
265 }
266 
CleanRes()267 void CleanRes() { DFunctor::Clear(); }
268 
MergeForward(const FuncGraphPtr & root,const opt::OptimizerPtr & opt)269 bool MergeForward(const FuncGraphPtr &root, const opt::OptimizerPtr &opt) {
270   auto manager = opt->manager();
271   MS_EXCEPTION_IF_NULL(manager);
272   std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> forward_fg_to_j_nodes;
273   auto all_nodes = TopoSort(root->get_return(), SuccDeeperSimple, AlwaysInclude);
274   for (const auto &node : all_nodes) {
275     if (!IsPrimitiveCNode(node, prim::kPrimJ)) {
276       continue;
277     }
278     auto cnode = node->cast<CNodePtr>();
279     auto merge_forward = cnode->user_data<bool>("merge_forward");
280     if (merge_forward == nullptr || !(*merge_forward)) {
281       continue;
282     }
283     auto forward_fg = GetValueNode<FuncGraphPtr>(cnode->input(1));
284     if (forward_fg == nullptr) {
285       continue;
286     }
287     (void)forward_fg_to_j_nodes[forward_fg].emplace_back(node);
288   }
289   bool change = false;
290   for (const auto &iter : forward_fg_to_j_nodes) {
291     auto &j_nodes = iter.second;
292     MS_LOG(DEBUG) << "J nodes size is " << j_nodes.size();
293     if (j_nodes.size() <= 1) {
294       continue;
295     }
296     auto first_j_user = GetJUser(manager, j_nodes[0]);
297     if (first_j_user == nullptr) {
298       continue;
299     }
300     const auto &first_forward_inputs = first_j_user->cast<CNodePtr>()->weak_inputs();
301     for (size_t i = 1; i < j_nodes.size(); ++i) {
302       auto j_user = GetJUser(manager, j_nodes[i]);
303       const auto &forward_inputs = j_user->cast<CNodePtr>()->weak_inputs();
304       if (!ForwardInputsEqual(first_forward_inputs, forward_inputs)) {
305         continue;
306       }
307       manager->Replace(j_user, first_j_user);
308       MS_LOG(DEBUG) << "Replace J user " << j_user->DebugString() << " with the first J user "
309                     << first_j_user->DebugString();
310       change = true;
311     }
312   }
313   return change;
314 }
315 }  // namespace ad
316 }  // namespace mindspore
317