/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/func_graph_cloner.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "pipeline/jit/ps/pass.h"

namespace mindspore {
namespace ad {
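// Frees the device memory held by input/output values that the optimized bprop graph
// no longer uses. out is appended to op_args so the combined list lines up with
// args_value_using_info_, which also covers the forward output.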
void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
  // args_value_using_info_ also covers out, hence the size of op_args plus one.
  if (args_value_using_info_.size() != op_args.size() + 1) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter size: " << args_value_using_info_.size()
                               << " of bp_graph: " << opt_func_graph_->ToString()
                               << " does not match the input arguments num: " << op_args.size();
  }

  ValuePtrList new_args(op_args);
  (void)new_args.emplace_back(out);
  TryFreeOneValue(new_args, args_value_using_info_);
}

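// Walks the values against their recorded using info: an unused tensor gets its
// device address released, and a tuple is recursed into so unused items can be
// freed individually.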
void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
                                                  const std::vector<ParamUsingInfo> &param_info_vec) {
  if (param_info_vec.size() != op_args.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter size: " << param_info_vec.size()
                               << " of bp_graph: " << opt_func_graph_->ToString()
                               << " does not match the input arguments num: " << op_args.size();
  }

  for (size_t i = 0; i < op_args.size(); ++i) {
    if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa<tensor::Tensor>()) {
      auto value = op_args[i]->cast<tensor::TensorPtr>();
      value->set_device_address(nullptr);
    } else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa<ValueTuple>()) {
      auto value = op_args[i]->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value);
      TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_);
    }
  }
}

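// Analyzes once, per optimized graph, how each parameter except dout (the last one)
// is used, so TryFreeArgsValue can later release the values the graph never reads.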
void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  if (analysis_finish_flg_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(opt_func_graph_);
  auto &params = opt_func_graph_->parameters();
  const auto &node_users = manager->node_users();
  args_value_using_info_.resize(params.size() - 1);
  // Analyze the value-using flag of every parameter except dout (the last one).
  for (size_t i = 0; i < params.size() - 1; ++i) {
    auto &param = params[i];
    auto &arg_info = args_value_using_info_[i];
    ArgInfoRefresh(param, &arg_info);
    AnalysisNodeUsingInfo(node_users, param, &arg_info);
  }
  analysis_finish_flg_ = true;
}

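// Marks a parameter as used according to its users in the graph. A tuple parameter
// is tracked per item: a TupleGetItem user touches a single item, while any other
// user is treated as consuming the whole tuple.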
void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
                                                        const std::shared_ptr<AnfNode> &param,
                                                        ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  auto iter = node_users.find(param);
  if (iter == node_users.end()) {
    arg_info->using_flg_ = false;
    return;
  }

  // A non-tuple parameter (e.g. a tensor) is consumed as a whole; mark it and return.
  if (!arg_info->tuple_flg_) {
    arg_info->using_flg_ = true;
    return;
  }

  // Special handling for a tuple parameter: possibly only some of its items are used.
  const auto &users_info = iter->second;
  for (auto &user_info : users_info) {
    auto user_node = user_info.first;
    arg_info->using_flg_ = true;
    MS_LOG(DEBUG) << "param:" << param->ToString() << " used by node:" << user_node->ToString();
    if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
      for (auto &sub_info : arg_info->sub_using_info_) {
        sub_info.using_flg_ = true;
      }
    } else {
      AalysisForTupleGetItem(node_users, param, arg_info, user_node);
    }
  }
}

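// Handles a TupleGetItem user of a tuple parameter: resolves the constant index and
// marks only that item as used, recursing in case the item is itself a tuple.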
void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users,
                                                         const std::shared_ptr<AnfNode> &param,
                                                         ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  MS_EXCEPTION_IF_NULL(user_node);
  auto cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const size_t tuple_get_item_size = 3;
  const size_t index = 2;
  if (cnode->size() != tuple_get_item_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "TupleGetItem node:" << user_node->ToString()
                               << " of bp_graph:" << opt_func_graph_->ToString()
                               << " input size is: " << cnode->size();
  }
  auto idx_node = cnode->input(index);
  if (!idx_node->isa<ValueNode>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Tuple:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                               << " unexpectedly used by node:" << user_node->ToString()
                               << " TupleGetItem idx node:" << idx_node->ToString();
  }

  auto vnode = idx_node->cast<ValueNodePtr>();
  auto value_ptr = vnode->value();
  if (value_ptr == nullptr || !value_ptr->isa<Int64Imm>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Tuple:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                               << " unexpectedly used by node:" << user_node->ToString()
                               << " TupleGetItem idx node:" << idx_node->ToString() << " idx value:" << value_ptr;
  }

  auto idx = LongToSize(value_ptr->cast<Int64ImmPtr>()->value());
  arg_info->sub_using_info_[idx].using_flg_ = true;
  ArgInfoRefresh(cnode, &(arg_info->sub_using_info_[idx]));

  if (arg_info->tuple_flg_) {
    AnalysisNodeUsingInfo(node_users, cnode, &(arg_info->sub_using_info_[idx]));
  }
}

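// Refreshes arg_info from the node's abstract: records whether it is a tuple and,
// if so, its size, allocating one sub-entry per item.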
void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
                                                 ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  MS_EXCEPTION_IF_NULL(param);
  auto abs = param->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    arg_info->tuple_flg_ = false;
    MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is an AbstractTensor";
  } else if (abs->isa<abstract::AbstractTuple>()) {
    auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
    MS_LOG(DEBUG) << "param abstract:" << param->ToString() << " is an AbstractTuple";
    arg_info->tuple_flg_ = true;
    arg_info->tuple_size_ = abs_tuple->size();
    arg_info->sub_using_info_.resize(abs_tuple->size());
  } else {
    arg_info->tuple_flg_ = false;
  }
}

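// Returns the process-wide singleton instance.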
PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() {
  static PrimBpropOptimizer g_prim_bprop_opt = PrimBpropOptimizer();
  return g_prim_bprop_opt;
}

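// Drops all cached optimized bprop graphs.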
void PrimBpropOptimizer::Clear() {
  prim_bprop_cache_.clear();
  tuple_list_bprop_cache_.clear();
}

// bprop_fg has the signature:
// (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, d_out)
// c_node contains the prim (input 0) and the input parameters of that prim;
// op_args contains the argument list of each input parameter, each of which may be a tensor or a tuple;
// out contains the output of c_node.
FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node,
                                                        const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  MS_EXCEPTION_IF_NULL(c_node);
  MS_EXCEPTION_IF_NULL(out);
  auto &inputs = c_node->inputs();
  if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The parameters num " << (inputs.size() - 1) << " does not match the arguments num "
                               << op_args.size() << ", CNode:" << c_node->ToString()
                               << " grad:" << bprop_fg->ToString();
  }

  if (!IsValueNode<Primitive>(inputs[0])) {
    MS_LOG(INTERNAL_EXCEPTION) << "CNode:" << c_node->ToString()
                               << " is not a primitive node, input_0 is:" << inputs[0]->ToString();
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
  MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();

  // Hook bprops (kPrimHookBackward/kPrimCellBackwardHook) and MakeTuple/MakeList take the special path.
  bool hookback_flg =
    IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
  if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
  }

  std::vector<bool> need_grad_flags;
  if (c_node->HasAttr(kAttrNeedGradFlagOfInputs)) {
    need_grad_flags = GetValue<std::vector<bool>>(c_node->GetAttr(kAttrNeedGradFlagOfInputs));
  }

  return GetOptBpropFromCache(bprop_fg, op_args, out, prim, need_grad_flags);
}

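// Looks up the optimized bprop graph in the two-level cache keyed by primitive,
// argument abstracts and need-grad flags; on a miss the missing optimization steps
// are run, and the result is cached when the pynative op graph cache is enabled.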
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                      const ValuePtr &out, const PrimitivePtr &prim,
                                                      const std::vector<bool> &need_grad_flags) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);

  PrimBpropOptGraphLevel2InfoPtr level_2_graph_info;
  PrimBpropOptGraphInfoPtr level_1_graph_info;
  ECacheQrtRes cache_res =
    GetOptBpfgFromCache(prim, abs_list, need_grad_flags, &level_2_graph_info, &level_1_graph_info);

  MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString();
  if (cache_res == E_LEVEL_2) {
    MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
    level_2_graph_info->TryFreeArgsValue(op_args, out);
    return BasicClone(level_2_graph_info->opt_func_graph());
  }

  // Level-1 cache miss: do the step1 opt and cache it per primitive.
  if (cache_res == E_NOT_FOUND) {
    bprop_fg->debug_info()->set_name(prim->ToString());
    level_1_graph_info = PrimBpropOptStep1(bprop_fg);
    prim_bprop_cache_[prim] = level_1_graph_info;
  }
  FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_);

  level_1_graph->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flags));
  // Do the step2 opt on the cloned level-1 graph.
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list, need_grad_flags);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    level_1_graph_info->graph_level_2_cache_[abs_list][need_grad_flags] = level_2_graph_info;
    return BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

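// Generates the optimized bprop for the specially handled primitives (hook bprops,
// MakeTuple/MakeList). Hook bprops are never cached; tuple/list bprops use their own
// cache keyed by (prim, argument abstracts).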
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                 const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);
  if (!hook_flg) {
    auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list));
    if (iter != tuple_list_bprop_cache_.end()) {
      return BasicClone(iter->second);
    }
  }

  // Do the step1 opt.
  bprop_fg->debug_info()->set_name(prim->ToString());
  auto level_1_graph_info = PrimBpropOptStep1(bprop_fg);

  // Do the step2 opt.
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (!hook_flg && enable_grad_cache) {
    tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

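// Step1: runs the argument-independent optimization passes on a clone of the
// original bprop graph; the result is what gets cached per primitive.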
PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) const {
  opt::irpass::OptimizeIRPassLib irpass;
  auto level_1_graph_info = std::make_shared<PrimBpropOptGraphInfo>();
  auto prim_bprop_opt_res = std::make_shared<pipeline::Resource>();
  auto prim_bprop_opt_manage = prim_bprop_opt_res->manager();
  auto graph_for_cache = BasicClone(bprop_fg);
  prim_bprop_opt_res->set_func_graph(graph_for_cache);
  prim_bprop_opt_manage->AddFuncGraph(graph_for_cache);
  auto opt_bprop_fg = PrimBpOptPassStep1(irpass, prim_bprop_opt_res);
  level_1_graph_info->opt_func_graph_ = opt_bprop_fg;
  return level_1_graph_info;
}

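// Binds the given abstracts to the graph parameters one by one.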
void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
                                             const abstract::AbstractBasePtrList &abs_list_input) const {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto &params = bprop_fg->parameters();
  if (abs_list_input.size() != params.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parameter num:" << params.size() << " does not match the inputs num "
                               << abs_list_input.size();
  }

  for (size_t i = 0; i < abs_list_input.size(); i++) {
    params[i]->set_abstract(abs_list_input[i]);
  }
}

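// Step2: binds the concrete argument abstracts to the parameters, runs the
// argument-dependent passes, and analyzes the parameter using info of the result.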
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
  const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
  const std::vector<bool> &need_grad_flags) const {
  opt::irpass::OptimizeIRPassLib irpass;
  BindAbsToParameters(bprop_fg, abs_list_input);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto manager = resource->manager();
  resource->set_func_graph(bprop_fg);
  for (const auto &abs : abs_list_input) {
    if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
      auto fg_abs = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
      manager->AddFuncGraph(fg_abs->func_graph());
    }
  }
  manager->AddFuncGraph(bprop_fg);
  auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource, need_grad_flags);
  auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
  level_2_graph_info->AnalysisArgUsingInfo(manager);
  return level_2_graph_info;
}

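// Runs the final bprop graph pass on the given resource.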
FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
  MS_EXCEPTION_IF_NULL(res);
  auto after_opt_bg = JitBpropGraphPass(res, false);
  return after_opt_bg;
}

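// Queries the two-level cache: E_NOT_FOUND means the primitive is unknown,
// E_LEVEL_1 means only the step1 graph is cached, and E_LEVEL_2 means a fully
// optimized graph matching the abstracts and need-grad flags was found.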
ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
                                                     const abstract::AbstractBasePtrList &abs_list,
                                                     const std::vector<bool> &need_grad_flags,
                                                     PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
                                                     PrimBpropOptGraphInfoPtr *level_1_graph_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(level_1_graph_info);
  MS_EXCEPTION_IF_NULL(level_2_graph_info);
  auto attrs = prim->attrs();
  for (auto &item : attrs) {
    MS_LOG(DEBUG) << "prim:" << prim->ToString() << " attr: " << item.first << " value:" << item.second->ToString();
  }

  auto iter = prim_bprop_cache_.find(prim);
  if (iter == prim_bprop_cache_.end()) {
    return E_NOT_FOUND;
  }

  *level_1_graph_info = iter->second;
  auto second_iter = (*level_1_graph_info)->graph_level_2_cache_.find(abs_list);
  if (second_iter == (*level_1_graph_info)->graph_level_2_cache_.end()) {
    return E_LEVEL_1;
  }

  auto level_2_iter = (second_iter->second).find(need_grad_flags);
  if (level_2_iter == second_iter->second.end()) {
    return E_LEVEL_1;
  }

  *level_2_graph_info = level_2_iter->second;
  return E_LEVEL_2;
}

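// Converts the argument values to abstracts. Inputs that are not constant for the
// primitive are broadened, so the cached graph can be reused across concrete values.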
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
                                   abstract::AbstractBasePtrList *abs_list) const {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(abs_list);
  auto const_input_index = prim->get_const_input_indexes();
  bool have_const_input = !const_input_index.empty();
  bool is_const_prim = prim->const_prim();
  for (size_t i = 0; i < op_args.size(); ++i) {
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    auto &arg_value = op_args[i];
    auto arg_abs = arg_value->ToAbstract();
    if (!is_const_prim && !is_const_input) {
      arg_abs = arg_abs->PartialBroaden();
      MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
    }
    (void)abs_list->emplace_back(arg_abs);
  }
}

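// Builds the bprop parameter abstracts by appending the forward output's broadened
// abstract to the argument abstracts; it is appended twice, once for out and once
// for d_out, which share one abstract.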
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
                                                                  const abstract::AbstractBasePtrList &abs_list) const {
  MS_EXCEPTION_IF_NULL(out);
  if (!out->isa<tensor::Tensor>() && !out->isa<ValueSequence>() && !out->isa<None>()) {
    MS_LOG(EXCEPTION) << "The out value is not a Tensor, ValueSequence or None, please check the input arguments.";
  }
  abstract::AbstractBasePtrList new_abs_list(abs_list);
  auto out_abs = out->ToAbstract();
  out_abs = out_abs->PartialBroaden();
  // Appended twice: for out and for d_out.
  (void)new_abs_list.emplace_back(out_abs);
  (void)new_abs_list.emplace_back(out_abs);
  return new_abs_list;
}
}  // namespace ad
}  // namespace mindspore