• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <memory>
18 #include "ir/func_graph_cloner.h"
19 #include "frontend/optimizer/ad/prim_bprop_optimizer.h"
20 #include "pipeline/jit/pass.h"
21 
22 namespace mindspore {
23 namespace ad {
TryFreeArgsValue(const ValuePtrList & op_args,const ValuePtr & out)24 void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
25   // args_value_using_info_ contains out
26   if (args_value_using_info_.size() != op_args.size() + 1) {
27     MS_LOG(EXCEPTION) << "param size :" << args_value_using_info_.size()
28                       << " of bp_graph:" << opt_func_graph_->ToString()
29                       << " not match input arguments num:" << op_args.size();
30   }
31 
32   ValuePtrList new_args(op_args);
33   (void)new_args.emplace_back(out);
34   TryFreeOneValue(new_args, args_value_using_info_);
35 }
36 
TryFreeOneValue(const ValuePtrList & op_args,const std::vector<ParamUsingInfo> & param_info_vec)37 void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
38                                                   const std::vector<ParamUsingInfo> &param_info_vec) {
39   if (param_info_vec.size() != op_args.size()) {
40     MS_LOG(EXCEPTION) << "param size :" << param_info_vec.size() << " of bp_graph:" << opt_func_graph_->ToString()
41                       << " not match input arguments num:" << op_args.size();
42   }
43 
44   for (size_t i = 0; i < op_args.size(); ++i) {
45     if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa<tensor::Tensor>()) {
46       auto value = op_args[i]->cast<tensor::TensorPtr>();
47       value->set_device_address(nullptr);
48     } else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa<ValueTuple>()) {
49       auto value = op_args[i]->cast<ValueTuplePtr>();
50       MS_EXCEPTION_IF_NULL(value);
51       TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_);
52     }
53   }
54 }
55 
AnalysisArgUsingInfo(const FuncGraphManagerPtr & manager)56 void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
57   MS_EXCEPTION_IF_NULL(manager);
58   if (analysis_finish_flg_) {
59     return;
60   }
61   MS_EXCEPTION_IF_NULL(opt_func_graph_);
62   auto &params = opt_func_graph_->parameters();
63   const auto &node_users = manager->node_users();
64   args_value_using_info_.resize(params.size() - 1);
65   // analysis value using flg except dout
66   for (size_t i = 0; i < params.size() - 1; ++i) {
67     auto &param = params[i];
68     auto &arg_info = args_value_using_info_[i];
69     ArgInfoRefresh(param, &arg_info);
70     AnalysisNodeUsingInfo(node_users, param, &arg_info);
71   }
72   analysis_finish_flg_ = true;
73 }
74 
AnalysisNodeUsingInfo(const NodeUsersMap & node_users,const std::shared_ptr<AnfNode> & param,ParamUsingInfo * arg_info) const75 void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
76                                                         const std::shared_ptr<AnfNode> &param,
77                                                         ParamUsingInfo *arg_info) const {
78   MS_EXCEPTION_IF_NULL(arg_info);
79   auto iter = node_users.find(param);
80   if (iter == node_users.end()) {
81     arg_info->using_flg_ = false;
82     return;
83   }
84 
85   // tensor return directly
86   if (!arg_info->tuple_flg_) {
87     arg_info->using_flg_ = true;
88     return;
89   }
90 
91   // specific process for tuple parameter, may only partial items used
92   const auto &users_info = iter->second;
93   for (auto &user_info : users_info) {
94     auto user_node = user_info.first;
95     arg_info->using_flg_ = true;
96     MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
97     if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
98       for (auto &sub_info : arg_info->sub_using_info_) {
99         sub_info.using_flg_ = true;
100       }
101     } else {
102       AalysisForTupleGetItem(node_users, param, arg_info, user_node);
103     }
104   }
105 }
AalysisForTupleGetItem(const NodeUsersMap & node_users,const std::shared_ptr<AnfNode> & param,ParamUsingInfo * arg_info,const AnfNodePtr & user_node) const106 void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users,
107                                                          const std::shared_ptr<AnfNode> &param,
108                                                          ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
109   MS_EXCEPTION_IF_NULL(arg_info);
110   MS_EXCEPTION_IF_NULL(user_node);
111   auto cnode = user_node->cast<CNodePtr>();
112   MS_EXCEPTION_IF_NULL(cnode);
113   const size_t tuple_get_item_size = 3;
114   const size_t index = 2;
115   if (cnode->size() != tuple_get_item_size) {
116     MS_LOG(EXCEPTION) << "TupleGetItem Node:" << user_node->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
117                       << "input size is:" << cnode->size();
118   }
119   auto idx_node = cnode->input(index);
120   if (!idx_node->isa<ValueNode>()) {
121     MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
122                       << " unexpected used by node:" << user_node->ToString()
123                       << " TupleGetItem idx node:" << idx_node->ToString();
124   }
125 
126   auto vnode = idx_node->cast<ValueNodePtr>();
127   auto value_ptr = vnode->value();
128   if (value_ptr == nullptr || !value_ptr->isa<Int64Imm>()) {
129     MS_LOG(EXCEPTION) << "tuple :" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
130                       << " unexpected used by node:" << user_node->ToString()
131                       << " TupleGetItem idx node:" << idx_node->ToString() << " idx Value :" << value_ptr;
132   }
133 
134   auto idx = LongToSize(value_ptr->cast<Int64ImmPtr>()->value());
135   arg_info->sub_using_info_[idx].using_flg_ = true;
136   ArgInfoRefresh(cnode, &(arg_info->sub_using_info_[idx]));
137 
138   if (arg_info->tuple_flg_) {
139     AnalysisNodeUsingInfo(node_users, cnode, &(arg_info->sub_using_info_[idx]));
140   }
141 }
142 
ArgInfoRefresh(const std::shared_ptr<AnfNode> & param,ParamUsingInfo * arg_info) const143 void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
144                                                  ParamUsingInfo *arg_info) const {
145   MS_EXCEPTION_IF_NULL(arg_info);
146   MS_EXCEPTION_IF_NULL(param);
147   auto abs = param->abstract();
148   MS_EXCEPTION_IF_NULL(abs);
149   if (abs->isa<abstract::AbstractTensor>()) {
150     arg_info->tuple_flg_ = false;
151     MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTensor";
152   } else if (abs->isa<abstract::AbstractTuple>()) {
153     auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
154     MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is a AbstractTuple";
155     arg_info->tuple_flg_ = true;
156     arg_info->tuple_size_ = abs_tuple->size();
157     arg_info->sub_using_info_.resize(abs_tuple->size());
158   } else {
159     arg_info->tuple_flg_ = false;
160   }
161 }
162 
GetPrimBpropOptimizerInst()163 PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() {
164   static PrimBpropOptimizer g_prim_bprop_opt = PrimBpropOptimizer();
165   return g_prim_bprop_opt;
166 }
167 
Clear()168 void PrimBpropOptimizer::Clear() {
169   prim_bprop_cache_.clear();
170   tuple_list_bprop_cache_.clear();
171 }
172 
173 // bprop_fg has the signature:
174 // (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out)
175 // c_node contains the prim(input 0) and the input parameters of that prim;
176 // op_args contains the arguments list of each input parameters, it maybe tensor or tuple
177 // out contains the out of c_node;
OptimizeBPropFuncGraph(const FuncGraphPtr & bprop_fg,const CNodePtr & c_node,const ValuePtrList & op_args,const ValuePtr & out)178 FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node,
179                                                         const ValuePtrList &op_args, const ValuePtr &out) {
180   MS_EXCEPTION_IF_NULL(bprop_fg);
181   MS_EXCEPTION_IF_NULL(c_node);
182   MS_EXCEPTION_IF_NULL(out);
183   auto &inputs = c_node->inputs();
184   if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) {
185     MS_LOG(EXCEPTION) << "The parameters num " << inputs.size() - 1 << " not match arguments num " << op_args.size()
186                       << ", CNode:" << c_node->ToString() << " grap:" << bprop_fg->ToString();
187   }
188 
189   if (!IsValueNode<Primitive>(inputs[0])) {
190     MS_LOG(EXCEPTION) << "CNode:" << c_node->ToString()
191                       << " not a primitive node, input_0 is:" << inputs[0]->ToString();
192   }
193 
194   PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
195   MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();
196 
197   //  kPrimHookBackward
198   bool hookback_flg = IsPrimitiveEquals(prim, prim::kPrimHookBackward);
199   if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
200     return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
201   }
202 
203   return GetOptBpropFromCache(bprop_fg, op_args, out, prim);
204 }
205 
// Two-level cached optimization: level 1 is per-primitive (generic passes),
// level 2 is per-(primitive, argument abstracts) specialization.
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                      const ValuePtr &out, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);

  PrimBpropOptGraphLevel2InfoPtr level_2_graph_info;
  PrimBpropOptGraphInfoPtr level_1_graph_info;
  const ECacheQrtRes cache_res = GetOptBpfgFromCache(prim, abs_list, &level_2_graph_info, &level_1_graph_info);

  MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString();
  if (cache_res == E_LEVEL_2) {
    MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
    level_2_graph_info->TryFreeArgsValue(op_args, out);
    return BasicClone(level_2_graph_info->opt_func_graph());
  }

  // Full miss: run the step1 (generic) optimization and cache its result.
  if (cache_res == E_NOT_FOUND) {
    bprop_fg->debug_info()->set_name(prim->ToString());
    level_1_graph_info = PrimBpropOptStep1(bprop_fg);
    prim_bprop_cache_[prim] = level_1_graph_info;
  }
  FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_);

  // Step2: specialize the level-1 graph with the argument abstracts
  // (out and dout are appended by AddOutToAbsList).
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  const bool enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (!enable_grad_cache) {
    return level_2_graph_info->opt_func_graph();
  }
  level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
  return BasicClone(level_2_graph_info->opt_func_graph());
}
242 
// Specialised optimization path for HookBackward / MakeTuple / MakeList.
// Hook graphs (hook_flg) bypass the cache entirely; tuple/list graphs use
// the dedicated (prim, abstracts)-keyed cache.
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                 const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);
  if (!hook_flg) {
    const auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list));
    if (iter != tuple_list_bprop_cache_.end()) {
      return BasicClone(iter->second);
    }
  }

  // Step1: generic optimization of the fresh bprop graph.
  bprop_fg->debug_info()->set_name(prim->ToString());
  auto level_1_graph_info = PrimBpropOptStep1(bprop_fg);

  // Step2: specialize with the argument abstracts (out/dout appended).
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  const bool enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (!hook_flg && enable_grad_cache) {
    tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}
269 
PrimBpropOptStep1(const FuncGraphPtr & bprop_fg)270 PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) {
271   opt::irpass::OptimizeIRPassLib irpass;
272   auto level_1_graph_info = std::make_shared<PrimBpropOptGraphInfo>();
273   auto prim_bprop_opt_res = std::make_shared<pipeline::Resource>();
274   auto prim_bprop_opt_manage = prim_bprop_opt_res->manager();
275   auto graph_for_cache = BasicClone(bprop_fg);
276   prim_bprop_opt_res->set_func_graph(graph_for_cache);
277   prim_bprop_opt_manage->AddFuncGraph(graph_for_cache);
278   auto opt_bprop_fg = PrimBpOptPassStep1(irpass, prim_bprop_opt_res);
279   level_1_graph_info->opt_func_graph_ = opt_bprop_fg;
280   return level_1_graph_info;
281 }
282 
BindAbsToParameters(const FuncGraphPtr & bprop_fg,const abstract::AbstractBasePtrList & abs_list_input)283 void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
284                                              const abstract::AbstractBasePtrList &abs_list_input) {
285   MS_EXCEPTION_IF_NULL(bprop_fg);
286   auto &params = bprop_fg->parameters();
287   if (abs_list_input.size() != params.size()) {
288     MS_LOG(EXCEPTION) << "Param num:" << params.size() << " not match inputs num " << abs_list_input.size();
289   }
290 
291   for (size_t i = 0; i < abs_list_input.size(); i++) {
292     params[i]->set_abstract(abs_list_input[i]);
293   }
294 }
295 
PrimBpropOptStep2(const FuncGraphPtr & bprop_fg,const abstract::AbstractBasePtrList & abs_list_input)296 PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
297   const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) {
298   opt::irpass::OptimizeIRPassLib irpass;
299   BindAbsToParameters(bprop_fg, abs_list_input);
300   pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
301   auto manager = resource->manager();
302   resource->set_func_graph(bprop_fg);
303   manager->AddFuncGraph(bprop_fg);
304   auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource);
305   auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
306   level_2_graph_info->AnalysisArgUsingInfo(manager);
307   return level_2_graph_info;
308 }
309 
BpropGraphFinalOpt(const pipeline::ResourcePtr & res) const310 FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
311   MS_EXCEPTION_IF_NULL(res);
312   auto after_opt_bg = BpropGraphFinalOptPass(res);
313   return after_opt_bg;
314 }
315 
GetOptBpfgFromCache(const PrimitivePtr & prim,const abstract::AbstractBasePtrList & abs_list,PrimBpropOptGraphLevel2InfoPtr * level_2_graph_info,PrimBpropOptGraphInfoPtr * level_1_graph_info)316 ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
317                                                      const abstract::AbstractBasePtrList &abs_list,
318                                                      PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
319                                                      PrimBpropOptGraphInfoPtr *level_1_graph_info) {
320   MS_EXCEPTION_IF_NULL(prim);
321   MS_EXCEPTION_IF_NULL(level_1_graph_info);
322   MS_EXCEPTION_IF_NULL(level_2_graph_info);
323   auto attrs_ = prim->attrs();
324   for (auto &item : attrs_) {
325     MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
326   }
327 
328   auto iter = prim_bprop_cache_.find(prim);
329   if (iter == prim_bprop_cache_.end()) {
330     return E_NOT_FOUND;
331   }
332 
333   *level_1_graph_info = iter->second;
334   auto second_iter = (*level_1_graph_info)->graph_level_2_cache_.find(abs_list);
335   if (second_iter == (*level_1_graph_info)->graph_level_2_cache_.end()) {
336     return E_LEVEL_1;
337   }
338   *level_2_graph_info = second_iter->second;
339   return E_LEVEL_2;
340 }
341 
ArgsToAbs(const PrimitivePtr & prim,const ValuePtrList & op_args,abstract::AbstractBasePtrList * abs_list)342 void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
343                                    abstract::AbstractBasePtrList *abs_list) {
344   MS_EXCEPTION_IF_NULL(prim);
345   MS_EXCEPTION_IF_NULL(abs_list);
346   auto const_input_index = prim->get_const_input_indexes();
347   bool have_const_input = !const_input_index.empty();
348   bool is_const_prim = prim->is_const_prim();
349   for (size_t i = 0; i < op_args.size(); ++i) {
350     bool is_const_input =
351       have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
352     auto &arg_value = op_args[i];
353     auto arg_abs = arg_value->ToAbstract();
354     if (!is_const_prim && !is_const_input) {
355       arg_abs = arg_abs->PartialBroaden();
356       MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
357     }
358     (void)abs_list->emplace_back(arg_abs);
359   }
360 }
361 
// Extend the argument-abstract list with the forward output's abstract.
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
                                                                  const abstract::AbstractBasePtrList &abs_list) {
  MS_EXCEPTION_IF_NULL(out);
  if (!out->isa<tensor::Tensor>() && !out->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Out value not Tensor or Tuple, please check the input arguments.";
  }
  abstract::AbstractBasePtrList new_abs_list(abs_list);
  auto out_abs = out->ToAbstract()->PartialBroaden();
  // Appended twice on purpose: the bprop signature ends with (..., out, dout)
  // and both trailing parameters share this abstract.
  (void)new_abs_list.emplace_back(out_abs);
  (void)new_abs_list.emplace_back(out_abs);
  return new_abs_list;
}
375 }  // namespace ad
376 }  // namespace mindspore
377