/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/func_graph_cloner.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "pipeline/jit/ps/pass.h"

namespace mindspore {
namespace ad {
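// Frees the device memory held by input argument values that the optimized
// bprop graph no longer uses. `op_args` holds the forward inputs and `out` is
// the forward output, so together they must match args_value_using_info_,
// which keeps one usage record per input plus one for out.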
void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
  // args_value_using_info_ contains out
  if (args_value_using_info_.size() != op_args.size() + 1) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter size:" << args_value_using_info_.size()
                               << " of bp_graph:" << opt_func_graph_->ToString()
                               << " does not match input arguments num:" << op_args.size();
  }

  ValuePtrList new_args(op_args);
  (void)new_args.emplace_back(out);
  TryFreeOneValue(new_args, args_value_using_info_);
}

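// Walks the argument list in parallel with its usage info: an unused tensor
// argument gets its device address dropped so the device memory can be
// reclaimed, while tuple arguments are recursed into item by item.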
void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
                                                  const std::vector<ParamUsingInfo> &param_info_vec) {
  if (param_info_vec.size() != op_args.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter size:" << param_info_vec.size()
                               << " of bp_graph:" << opt_func_graph_->ToString()
                               << " does not match input arguments num:" << op_args.size();
  }

  for (size_t i = 0; i < op_args.size(); ++i) {
    if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa<tensor::Tensor>()) {
      auto value = op_args[i]->cast<tensor::TensorPtr>();
      value->set_device_address(nullptr);
    } else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa<ValueTuple>()) {
      auto value = op_args[i]->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value);
      TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_);
    }
  }
}

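// Analyzes, once per graph, which parameters of the optimized bprop graph are
// actually used. The last parameter (dout) is skipped, hence the
// params.size() - 1; the result is latched by analysis_finish_flg_.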
void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  if (analysis_finish_flg_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(opt_func_graph_);
  auto &params = opt_func_graph_->parameters();
  const auto &node_users = manager->node_users();
  args_value_using_info_.resize(params.size() - 1);
  // Analyze the using flag of every parameter value except dout.
  for (size_t i = 0; i < params.size() - 1; ++i) {
    auto &param = params[i];
    auto &arg_info = args_value_using_info_[i];
    ArgInfoRefresh(param, &arg_info);
    AnalysisNodeUsingInfo(node_users, param, &arg_info);
  }
  analysis_finish_flg_ = true;
}

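// Marks a parameter as used or unused based on the manager's node-user map.
// Non-tuple parameters are simply flagged; tuple parameters are inspected
// user by user, since possibly only some of their items are consumed.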
void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
                                                        const std::shared_ptr<AnfNode> &param,
                                                        ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  auto iter = node_users.find(param);
  if (iter == node_users.end()) {
    arg_info->using_flg_ = false;
    return;
  }

  // A tensor parameter is either used or not; flag it and return directly.
  if (!arg_info->tuple_flg_) {
    arg_info->using_flg_ = true;
    return;
  }

  // Special handling for tuple parameters: possibly only some items are used.
  const auto &users_info = iter->second;
  for (auto &user_info : users_info) {
    auto user_node = user_info.first;
    arg_info->using_flg_ = true;
    MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
    if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
      // The whole tuple is consumed, so every item counts as used.
      for (auto &sub_info : arg_info->sub_using_info_) {
        sub_info.using_flg_ = true;
      }
    } else {
      // Only the item selected by TupleGetItem is used.
      AalysisForTupleGetItem(node_users, param, arg_info, user_node);
    }
  }
}
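// Handles a TupleGetItem user of a tuple parameter: extracts the constant
// index, marks that item as used, and recurses in case the item is itself
// a tuple.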
void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users,
                                                         const std::shared_ptr<AnfNode> &param,
                                                         ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  MS_EXCEPTION_IF_NULL(user_node);
  auto cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const size_t tuple_get_item_size = 3;
  const size_t index = 2;
  if (cnode->size() != tuple_get_item_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "TupleGetItem Node:" << user_node->ToString()
                               << " of bp_graph:" << opt_func_graph_->ToString() << " input size is:" << cnode->size();
  }
  auto idx_node = cnode->input(index);
  if (!idx_node->isa<ValueNode>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Tuple:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                               << " unexpectedly used by node:" << user_node->ToString()
                               << " TupleGetItem idx node:" << idx_node->ToString();
  }

  auto vnode = idx_node->cast<ValueNodePtr>();
  auto value_ptr = vnode->value();
  if (value_ptr == nullptr || !value_ptr->isa<Int64Imm>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Tuple:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                               << " unexpectedly used by node:" << user_node->ToString()
                               << " TupleGetItem idx node:" << idx_node->ToString() << " idx Value:" << value_ptr;
  }

  auto idx = LongToSize(value_ptr->cast<Int64ImmPtr>()->value());
  arg_info->sub_using_info_[idx].using_flg_ = true;
  ArgInfoRefresh(cnode, &(arg_info->sub_using_info_[idx]));

  if (arg_info->tuple_flg_) {
    AnalysisNodeUsingInfo(node_users, cnode, &(arg_info->sub_using_info_[idx]));
  }
}

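// Refreshes a parameter's usage record from its abstract: tensors are marked
// as plain values, while tuples record their size and get one sub-record per
// item.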
void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
                                                 ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  MS_EXCEPTION_IF_NULL(param);
  auto abs = param->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    arg_info->tuple_flg_ = false;
    MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is an AbstractTensor";
  } else if (abs->isa<abstract::AbstractTuple>()) {
    auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
    MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is an AbstractTuple";
    arg_info->tuple_flg_ = true;
    arg_info->tuple_size_ = abs_tuple->size();
    arg_info->sub_using_info_.resize(abs_tuple->size());
  } else {
    arg_info->tuple_flg_ = false;
  }
}

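// Returns the process-wide singleton optimizer instance.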
PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() {
  static PrimBpropOptimizer g_prim_bprop_opt = PrimBpropOptimizer();
  return g_prim_bprop_opt;
}

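// Drops both levels of cached bprop graphs.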
void PrimBpropOptimizer::Clear() {
  prim_bprop_cache_.clear();
  tuple_list_bprop_cache_.clear();
}

// bprop_fg has the signature:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, d_out)
// c_node contains the prim (input 0) and the input parameters of that prim;
// op_args contains the argument list of each input parameter; each argument may be a tensor or a tuple;
// out is the output of c_node.
FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node,
                                                        const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  MS_EXCEPTION_IF_NULL(c_node);
  MS_EXCEPTION_IF_NULL(out);
  auto &inputs = c_node->inputs();
  if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The parameters num " << (inputs.size() - 1) << " does not match arguments num "
                               << op_args.size() << ", CNode:" << c_node->ToString()
                               << " grad:" << bprop_fg->ToString();
  }

  if (!IsValueNode<Primitive>(inputs[0])) {
    MS_LOG(INTERNAL_EXCEPTION) << "CNode:" << c_node->ToString()
                               << " not a primitive node, input_0 is:" << inputs[0]->ToString();
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
  MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();

  // Hook-backward and tuple/list construction primitives take a separate,
  // specialized optimization path.
  bool hookback_flg =
    IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
  if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
  }

  std::vector<bool> need_grad_flags;
  if (c_node->HasAttr(kAttrNeedGradFlagOfInputs)) {
    need_grad_flags = GetValue<std::vector<bool>>(c_node->GetAttr(kAttrNeedGradFlagOfInputs));
  }

  return GetOptBpropFromCache(bprop_fg, op_args, out, prim, need_grad_flags);
}

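// Looks up the two-level cache keyed first by primitive, then by argument
// abstracts plus need-grad flags. A level-2 hit returns a clone of the fully
// optimized graph; otherwise the missing optimization steps are run and, if
// the PyNative op graph cache is enabled, the result is stored for reuse.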
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                      const ValuePtr &out, const PrimitivePtr &prim,
                                                      const std::vector<bool> &need_grad_flags) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);

  PrimBpropOptGraphLevel2InfoPtr level_2_graph_info;
  PrimBpropOptGraphInfoPtr level_1_graph_info;
  ECacheQrtRes cache_res =
    GetOptBpfgFromCache(prim, abs_list, need_grad_flags, &level_2_graph_info, &level_1_graph_info);

  MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString();
  if (cache_res == E_LEVEL_2) {
    MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
    level_2_graph_info->TryFreeArgsValue(op_args, out);
    return BasicClone(level_2_graph_info->opt_func_graph());
  }

  // Run step-1 optimization on a full cache miss.
  if (cache_res == E_NOT_FOUND) {
    bprop_fg->debug_info()->set_name(prim->ToString());
    level_1_graph_info = PrimBpropOptStep1(bprop_fg);
    prim_bprop_cache_[prim] = level_1_graph_info;
  }
  FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_);

  level_1_graph->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flags));
  // Run step-2 optimization specialized for these argument abstracts.
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list, need_grad_flags);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    level_1_graph_info->graph_level_2_cache_[abs_list][need_grad_flags] = level_2_graph_info;
    return BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

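// Specialized path for hook-backward and MakeTuple/MakeList bprops. Hook
// bprops bypass the cache entirely; the tuple/list bprops are cached in
// tuple_list_bprop_cache_ keyed by (prim, argument abstracts).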
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                 const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);
  if (!hook_flg) {
    auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list));
    if (iter != tuple_list_bprop_cache_.end()) {
      return BasicClone(iter->second);
    }
  }

  // Run step-1 optimization.
  bprop_fg->debug_info()->set_name(prim->ToString());
  auto level_1_graph_info = PrimBpropOptStep1(bprop_fg);

  // Run step-2 optimization.
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (!hook_flg && enable_grad_cache) {
    tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

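// Step 1: clone the raw bprop graph and run the first IR pass pipeline
// (PrimBpOptPassStep1) on it. The result depends only on the primitive, not
// on the concrete argument abstracts, so it is cached per primitive.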
PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) const {
  opt::irpass::OptimizeIRPassLib irpass;
  auto level_1_graph_info = std::make_shared<PrimBpropOptGraphInfo>();
  auto prim_bprop_opt_res = std::make_shared<pipeline::Resource>();
  auto prim_bprop_opt_manage = prim_bprop_opt_res->manager();
  auto graph_for_cache = BasicClone(bprop_fg);
  prim_bprop_opt_res->set_func_graph(graph_for_cache);
  prim_bprop_opt_manage->AddFuncGraph(graph_for_cache);
  auto opt_bprop_fg = PrimBpOptPassStep1(irpass, prim_bprop_opt_res);
  level_1_graph_info->opt_func_graph_ = opt_bprop_fg;
  return level_1_graph_info;
}

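// Binds the given abstracts to the graph parameters one-to-one, preparing the
// graph for abstract-specific optimization in step 2.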
void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
                                             const abstract::AbstractBasePtrList &abs_list_input) const {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto &params = bprop_fg->parameters();
  if (abs_list_input.size() != params.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter num:" << params.size() << " does not match inputs num "
                               << abs_list_input.size();
  }

  for (size_t i = 0; i < abs_list_input.size(); i++) {
    params[i]->set_abstract(abs_list_input[i]);
  }
}

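// Step 2: bind the argument abstracts to the parameters, run the second IR
// pass pipeline (PrimBpOptPassStep2), then analyze which parameters the
// optimized graph still uses so that unused inputs can be freed later.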
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
  const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
  const std::vector<bool> &need_grad_flags) const {
  opt::irpass::OptimizeIRPassLib irpass;
  BindAbsToParameters(bprop_fg, abs_list_input);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto manager = resource->manager();
  resource->set_func_graph(bprop_fg);
  for (const auto &abs : abs_list_input) {
    if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
      auto fg_abs = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
      manager->AddFuncGraph(fg_abs->func_graph());
    }
  }
  manager->AddFuncGraph(bprop_fg);
  auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource, need_grad_flags);
  auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
  level_2_graph_info->AnalysisArgUsingInfo(manager);
  return level_2_graph_info;
}

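// Runs the final JIT bprop graph pass over an already-built bprop graph.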
FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
  MS_EXCEPTION_IF_NULL(res);
  auto after_opt_bg = JitBpropGraphPass(res, false);
  return after_opt_bg;
}

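// Queries the cache: E_NOT_FOUND if the primitive has never been optimized,
// E_LEVEL_1 if only the abstract-independent graph exists, and E_LEVEL_2 if a
// graph specialized for these abstracts and need-grad flags is available.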
ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
                                                     const abstract::AbstractBasePtrList &abs_list,
                                                     const std::vector<bool> &need_grad_flags,
                                                     PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
                                                     PrimBpropOptGraphInfoPtr *level_1_graph_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(level_1_graph_info);
  MS_EXCEPTION_IF_NULL(level_2_graph_info);
  auto attrs = prim->attrs();
  for (auto &item : attrs) {
    MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
  }

  auto iter = prim_bprop_cache_.find(prim);
  if (iter == prim_bprop_cache_.end()) {
    return E_NOT_FOUND;
  }

  *level_1_graph_info = iter->second;
  auto second_iter = (*level_1_graph_info)->graph_level_2_cache_.find(abs_list);
  if (second_iter == (*level_1_graph_info)->graph_level_2_cache_.end()) {
    return E_LEVEL_1;
  }

  auto level_2_iter = (second_iter->second).find(need_grad_flags);
  if (level_2_iter == second_iter->second.end()) {
    return E_LEVEL_1;
  }

  *level_2_graph_info = level_2_iter->second;
  return E_LEVEL_2;
}

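// Converts each argument value to its abstract. Non-constant inputs are
// partially broadened so the cached graph is keyed by shape and type rather
// than by the particular input values.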
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
                                   abstract::AbstractBasePtrList *abs_list) const {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(abs_list);
  auto const_input_index = prim->get_const_input_indexes();
  bool have_const_input = !const_input_index.empty();
  bool is_const_prim = prim->const_prim();
  for (size_t i = 0; i < op_args.size(); ++i) {
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    auto &arg_value = op_args[i];
    auto arg_abs = arg_value->ToAbstract();
    if (!is_const_prim && !is_const_input) {
      arg_abs = arg_abs->PartialBroaden();
      MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
    }
    (void)abs_list->emplace_back(arg_abs);
  }
}

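// Appends the forward output's abstract twice: once for the out parameter and
// once for d_out, which share the same broadened abstract (see the bprop_fg
// signature documented above OptimizeBPropFuncGraph).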
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
                                                                  const abstract::AbstractBasePtrList &abs_list) const {
  MS_EXCEPTION_IF_NULL(out);
  if (!out->isa<tensor::Tensor>() && !out->isa<ValueSequence>() && !out->isa<None>()) {
    MS_LOG(EXCEPTION) << "Out value is not a Tensor, Tuple or None, please check the input arguments.";
  }
  abstract::AbstractBasePtrList new_abs_list(abs_list);
  auto out_abs = out->ToAbstract();
  out_abs = out_abs->PartialBroaden();
  // The same broadened abstract serves both the `out` and `d_out` parameters.
  (void)new_abs_list.emplace_back(out_abs);
  (void)new_abs_list.emplace_back(out_abs);
  return new_abs_list;
}
}  // namespace ad
}  // namespace mindspore