/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "pipeline/jit/pass.h"

namespace mindspore {
namespace ad {
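// Try to release the device memory held by op_args and out once the optimized
// bprop graph no longer needs their values; out is tracked as one extra entry
// behind the forward inputs in args_value_using_info_.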
void PrimBpropOptGraphLevel2Info::TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out) {
  // args_value_using_info_ also holds an entry for out, hence the +1.
  if (args_value_using_info_.size() != op_args.size() + 1) {
    MS_LOG(EXCEPTION) << "Param size " << args_value_using_info_.size()
                      << " of bp_graph " << opt_func_graph_->ToString()
                      << " does not match the input arguments num " << op_args.size();
  }

  ValuePtrList new_args(op_args);
  (void)new_args.emplace_back(out);
  TryFreeOneValue(new_args, args_value_using_info_);
}

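// Walk the arguments against their using info: an unused non-tuple tensor gets
// its device address released; tuple arguments are processed recursively.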
void PrimBpropOptGraphLevel2Info::TryFreeOneValue(const ValuePtrList &op_args,
                                                  const std::vector<ParamUsingInfo> &param_info_vec) {
  if (param_info_vec.size() != op_args.size()) {
    MS_LOG(EXCEPTION) << "Param size " << param_info_vec.size() << " of bp_graph " << opt_func_graph_->ToString()
                      << " does not match the input arguments num " << op_args.size();
  }

  for (size_t i = 0; i < op_args.size(); ++i) {
    if (!param_info_vec[i].using_flg_ && !param_info_vec[i].tuple_flg_ && op_args[i]->isa<tensor::Tensor>()) {
      // An unused tensor argument: release its device memory.
      auto value = op_args[i]->cast<tensor::TensorPtr>();
      value->set_device_address(nullptr);
    } else if (param_info_vec[i].tuple_flg_ && op_args[i]->isa<ValueTuple>()) {
      // A tuple argument: recurse into its items.
      auto value = op_args[i]->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value);
      TryFreeOneValue(value->value(), param_info_vec[i].sub_using_info_);
    }
  }
}

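// Analyze, once per optimized graph, which parameters (and which tuple items)
// the bprop graph really uses; the result drives TryFreeArgsValue.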
void PrimBpropOptGraphLevel2Info::AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  if (analysis_finish_flg_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(opt_func_graph_);
  auto &params = opt_func_graph_->parameters();
  const auto &node_users = manager->node_users();
  args_value_using_info_.resize(params.size() - 1);
  // Analyze the using flag of every parameter except the last one (dout).
  for (size_t i = 0; i < params.size() - 1; ++i) {
    auto &param = params[i];
    auto &arg_info = args_value_using_info_[i];
    ArgInfoRefresh(param, &arg_info);
    AnalysisNodeUsingInfo(node_users, param, &arg_info);
  }
  analysis_finish_flg_ = true;
}

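// Mark a parameter as used according to its users in the graph; for a tuple
// parameter, try to narrow the result down to the items actually consumed.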
void PrimBpropOptGraphLevel2Info::AnalysisNodeUsingInfo(const NodeUsersMap &node_users,
                                                        const std::shared_ptr<AnfNode> &param,
                                                        ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  auto iter = node_users.find(param);
  if (iter == node_users.end()) {
    arg_info->using_flg_ = false;
    return;
  }

  // A non-tuple (e.g. tensor) parameter with any user is simply marked as used.
  if (!arg_info->tuple_flg_) {
    arg_info->using_flg_ = true;
    return;
  }

  // Specific processing for a tuple parameter: possibly only some of its items are used.
  const auto &users_info = iter->second;
  for (auto &user_info : users_info) {
    auto user_node = user_info.first;
    arg_info->using_flg_ = true;
    MS_LOG(DEBUG) << "Param:" << param->ToString() << " used by node:" << user_node->ToString();
    if (!IsPrimitiveCNode(user_node, prim::kPrimTupleGetItem)) {
      // Used by a node other than TupleGetItem: conservatively mark all items as used.
      for (auto &sub_info : arg_info->sub_using_info_) {
        sub_info.using_flg_ = true;
      }
    } else {
      AalysisForTupleGetItem(node_users, param, arg_info, user_node);
    }
  }
}

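// Handle one TupleGetItem user of a tuple parameter: decode the constant index
// and mark only the selected item as used.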
void PrimBpropOptGraphLevel2Info::AalysisForTupleGetItem(const NodeUsersMap &node_users,
                                                         const std::shared_ptr<AnfNode> &param,
                                                         ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  MS_EXCEPTION_IF_NULL(user_node);
  auto cnode = user_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const size_t tuple_get_item_size = 3;
  const size_t index = 2;
  if (cnode->size() != tuple_get_item_size) {
    MS_LOG(EXCEPTION) << "TupleGetItem node:" << user_node->ToString()
                      << " of bp_graph:" << opt_func_graph_->ToString()
                      << " has unexpected input size:" << cnode->size();
  }
  auto idx_node = cnode->input(index);
  if (!idx_node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "Tuple:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                      << " is unexpectedly used by node:" << user_node->ToString()
                      << ", TupleGetItem idx node:" << idx_node->ToString();
  }

  auto vnode = idx_node->cast<ValueNodePtr>();
  auto value_ptr = vnode->value();
  if (value_ptr == nullptr || !value_ptr->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << "Tuple:" << param->ToString() << " of bp_graph:" << opt_func_graph_->ToString()
                      << " is unexpectedly used by node:" << user_node->ToString()
                      << ", TupleGetItem idx node:" << idx_node->ToString() << ", idx value:" << value_ptr;
  }

  // Only the item at idx is used; refresh its info and recurse in case it is a nested tuple.
  auto idx = LongToSize(value_ptr->cast<Int64ImmPtr>()->value());
  arg_info->sub_using_info_[idx].using_flg_ = true;
  ArgInfoRefresh(cnode, &(arg_info->sub_using_info_[idx]));

  if (arg_info->tuple_flg_) {
    AnalysisNodeUsingInfo(node_users, cnode, &(arg_info->sub_using_info_[idx]));
  }
}

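// Refresh the tuple flag and tuple size of arg_info from the node's abstract;
// for a tuple, sub_using_info_ is resized to hold one entry per item.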
void PrimBpropOptGraphLevel2Info::ArgInfoRefresh(const std::shared_ptr<AnfNode> &param,
                                                 ParamUsingInfo *arg_info) const {
  MS_EXCEPTION_IF_NULL(arg_info);
  MS_EXCEPTION_IF_NULL(param);
  auto abs = param->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTensor>()) {
    arg_info->tuple_flg_ = false;
    MS_LOG(DEBUG) << "Param abstract:" << param->ToString() << " is an AbstractTensor";
  } else if (abs->isa<abstract::AbstractTuple>()) {
    auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
    MS_LOG(DEBUG) << "Param abstract:" << param->ToString() << " is an AbstractTuple";
    arg_info->tuple_flg_ = true;
    arg_info->tuple_size_ = abs_tuple->size();
    arg_info->sub_using_info_.resize(abs_tuple->size());
  } else {
    arg_info->tuple_flg_ = false;
  }
}

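// Return the process-wide singleton optimizer instance.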
PrimBpropOptimizer &PrimBpropOptimizer::GetPrimBpropOptimizerInst() {
  static PrimBpropOptimizer g_prim_bprop_opt;
  return g_prim_bprop_opt;
}

void PrimBpropOptimizer::Clear() {
  prim_bprop_cache_.clear();
  tuple_list_bprop_cache_.clear();
}

// bprop_fg has the signature:
//   (sens_input1, sens_input2, ...) = bprop_fg(input1, input2, ..., out, d_out)
// c_node contains the prim (input 0) and the input parameters of that prim;
// op_args contains the argument value of each input parameter, which may be a tensor or a tuple;
// out contains the output value of c_node.
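// For illustration only, assuming the usual MindSpore bprop convention and a
// hypothetical Mul primitive with forward inputs (x, y):
//   bprop_fg(x, y, out, dout) -> (dout * y, dout * x)
// i.e. the bprop graph returns one sensitivity per forward input.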
FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node,
                                                        const ValuePtrList &op_args, const ValuePtr &out) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  MS_EXCEPTION_IF_NULL(c_node);
  MS_EXCEPTION_IF_NULL(out);
  auto &inputs = c_node->inputs();
  if (inputs.size() < 1 || inputs.size() - 1 != op_args.size()) {
    MS_LOG(EXCEPTION) << "The parameters num " << inputs.size() - 1 << " does not match the arguments num "
                      << op_args.size() << ", CNode:" << c_node->ToString() << ", graph:" << bprop_fg->ToString();
  }

  if (!IsValueNode<Primitive>(inputs[0])) {
    MS_LOG(EXCEPTION) << "CNode:" << c_node->ToString()
                      << " is not a primitive node, input_0 is:" << inputs[0]->ToString();
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
  MS_LOG(DEBUG) << "Hash of prim " << prim->ToString() << " is:" << prim->hash();

  // HookBackward, MakeTuple and MakeList take the special optimization path.
  bool hookback_flg = IsPrimitiveEquals(prim, prim::kPrimHookBackward);
  if (hookback_flg || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
    return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
  }

  return GetOptBpropFromCache(bprop_fg, op_args, out, prim);
}

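// Look up the two-level cache: level 1 is keyed by the primitive, level 2 by the
// abstracts of the arguments. On a level-2 hit the cached graph is cloned and
// returned; on a miss, step1 (and, if needed, step2) optimization is run and the
// results are cached when MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE is set.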
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                      const ValuePtr &out, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);

  PrimBpropOptGraphLevel2InfoPtr level_2_graph_info;
  PrimBpropOptGraphInfoPtr level_1_graph_info;
  ECacheQrtRes cache_res = GetOptBpfgFromCache(prim, abs_list, &level_2_graph_info, &level_1_graph_info);

  MS_LOG(DEBUG) << "Cache match result " << cache_res << ", prim: " << prim->ToString();
  if (cache_res == E_LEVEL_2) {
    MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
    level_2_graph_info->TryFreeArgsValue(op_args, out);
    return BasicClone(level_2_graph_info->opt_func_graph());
  }

  // Do the step1 optimization and cache its result per primitive.
  if (cache_res == E_NOT_FOUND) {
    bprop_fg->debug_info()->set_name(prim->ToString());
    level_1_graph_info = PrimBpropOptStep1(bprop_fg);
    prim_bprop_cache_[prim] = level_1_graph_info;
  }
  FuncGraphPtr level_1_graph = BasicClone(level_1_graph_info->opt_func_graph_);

  // Do the step2 optimization on a clone of the level-1 graph.
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  level_2_graph_info = PrimBpropOptStep2(level_1_graph, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
    return BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

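// Special path for HookBackward/MakeTuple/MakeList: always run both optimization
// steps; the result is cached (keyed by prim and argument abstracts) only for the
// non-hook primitives.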
FuncGraphPtr PrimBpropOptimizer::GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                 const ValuePtr &out, const PrimitivePtr &prim, bool hook_flg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);
  if (!hook_flg) {
    auto iter = tuple_list_bprop_cache_.find(std::pair(prim, abs_list));
    if (iter != tuple_list_bprop_cache_.end()) {
      return BasicClone(iter->second);
    }
  }

  // Do the step1 optimization.
  bprop_fg->debug_info()->set_name(prim->ToString());
  auto level_1_graph_info = PrimBpropOptStep1(bprop_fg);

  // Do the step2 optimization.
  auto new_abs_list = AddOutToAbsList(out, abs_list);
  auto level_2_graph_info = PrimBpropOptStep2(level_1_graph_info->opt_func_graph_, new_abs_list);
  level_2_graph_info->TryFreeArgsValue(op_args, out);
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (!hook_flg && enable_grad_cache) {
    tuple_list_bprop_cache_[std::pair(prim, abs_list)] = BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
}

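// Step1: run the argument-independent IR passes on a clone of the original bprop
// graph; the result is cached per primitive.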
PrimBpropOptGraphInfoPtr PrimBpropOptimizer::PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) {
  opt::irpass::OptimizeIRPassLib irpass;
  auto level_1_graph_info = std::make_shared<PrimBpropOptGraphInfo>();
  auto prim_bprop_opt_res = std::make_shared<pipeline::Resource>();
  auto prim_bprop_opt_manage = prim_bprop_opt_res->manager();
  auto graph_for_cache = BasicClone(bprop_fg);
  prim_bprop_opt_res->set_func_graph(graph_for_cache);
  prim_bprop_opt_manage->AddFuncGraph(graph_for_cache);
  auto opt_bprop_fg = PrimBpOptPassStep1(irpass, prim_bprop_opt_res);
  level_1_graph_info->opt_func_graph_ = opt_bprop_fg;
  return level_1_graph_info;
}

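// Bind the given abstracts one-to-one onto the parameters of bprop_fg.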
void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
                                             const abstract::AbstractBasePtrList &abs_list_input) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto &params = bprop_fg->parameters();
  if (abs_list_input.size() != params.size()) {
    MS_LOG(EXCEPTION) << "Param num " << params.size() << " does not match the inputs num " << abs_list_input.size();
  }

  for (size_t i = 0; i < abs_list_input.size(); i++) {
    params[i]->set_abstract(abs_list_input[i]);
  }
}

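// Step2: specialize the level-1 graph with the argument abstracts, run the
// remaining IR passes, and analyze which arguments the result still uses.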
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
  const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) {
  opt::irpass::OptimizeIRPassLib irpass;
  BindAbsToParameters(bprop_fg, abs_list_input);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  auto manager = resource->manager();
  resource->set_func_graph(bprop_fg);
  manager->AddFuncGraph(bprop_fg);
  auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource);
  auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
  level_2_graph_info->AnalysisArgUsingInfo(manager);
  return level_2_graph_info;
}

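// Run the final optimization pass over an already-built bprop graph resource.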
FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
  MS_EXCEPTION_IF_NULL(res);
  auto after_opt_bg = BpropGraphFinalOptPass(res);
  return after_opt_bg;
}

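// Query the two-level cache and report how deep the match went; on success the
// matching level-1/level-2 info is returned through the out parameters.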
ECacheQrtRes PrimBpropOptimizer::GetOptBpfgFromCache(const PrimitivePtr &prim,
                                                     const abstract::AbstractBasePtrList &abs_list,
                                                     PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
                                                     PrimBpropOptGraphInfoPtr *level_1_graph_info) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(level_1_graph_info);
  MS_EXCEPTION_IF_NULL(level_2_graph_info);
  auto attrs = prim->attrs();
  for (auto &item : attrs) {
    MS_LOG(DEBUG) << "Prim:" << prim->ToString() << " attr:" << item.first << " value:" << item.second->ToString();
  }

  auto iter = prim_bprop_cache_.find(prim);
  if (iter == prim_bprop_cache_.end()) {
    return E_NOT_FOUND;
  }

  *level_1_graph_info = iter->second;
  auto second_iter = (*level_1_graph_info)->graph_level_2_cache_.find(abs_list);
  if (second_iter == (*level_1_graph_info)->graph_level_2_cache_.end()) {
    return E_LEVEL_1;
  }
  *level_2_graph_info = second_iter->second;
  return E_LEVEL_2;
}

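// Convert the argument values to abstracts; an argument is partially broadened
// unless the primitive or that input is marked const, so the cached graph can be
// reused across calls with different concrete values.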
void PrimBpropOptimizer::ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args,
                                   abstract::AbstractBasePtrList *abs_list) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(abs_list);
  auto const_input_index = prim->get_const_input_indexes();
  bool have_const_input = !const_input_index.empty();
  bool is_const_prim = prim->is_const_prim();
  for (size_t i = 0; i < op_args.size(); ++i) {
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    auto &arg_value = op_args[i];
    auto arg_abs = arg_value->ToAbstract();
    if (!is_const_prim && !is_const_input) {
      arg_abs = arg_abs->PartialBroaden();
      MS_LOG(DEBUG) << "Broaden for " << prim->ToString();
    }
    (void)abs_list->emplace_back(arg_abs);
  }
}

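// Append the abstract of out twice: once for the out parameter and once for
// dout, which shares out's shape and type.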
abstract::AbstractBasePtrList PrimBpropOptimizer::AddOutToAbsList(const ValuePtr &out,
                                                                  const abstract::AbstractBasePtrList &abs_list) {
  MS_EXCEPTION_IF_NULL(out);
  if (!out->isa<tensor::Tensor>() && !out->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The out value is neither a Tensor nor a ValueTuple, please check the input arguments.";
  }
  abstract::AbstractBasePtrList new_abs_list(abs_list);
  auto out_abs = out->ToAbstract();
  out_abs = out_abs->PartialBroaden();
  (void)new_abs_list.emplace_back(out_abs);
  (void)new_abs_list.emplace_back(out_abs);
  return new_abs_list;
}
}  // namespace ad
}  // namespace mindspore