• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
19 
20 #include <vector>
21 #include <utility>
22 #include <memory>
23 #include <unordered_map>
24 #include <string>
25 #include "utils/hash_map.h"
26 #include "frontend/optimizer/irpass.h"
27 #include "ir/func_graph.h"
28 #include "pipeline/jit/ps/resource.h"
29 
30 namespace mindspore {
31 namespace ad {
32 struct PrimBpropOptGraphInfo;
33 
34 class PrimBpropOptGraphLevel2Info;
35 
36 struct PrimitiveTotalEqual;
37 
38 struct PrimitiveTupleListHasher;
39 
40 struct PrimitiveTupleListEqual;
41 
42 using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>;
43 
44 using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>;
45 
46 using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
47 
48 using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>;
49 
50 using PrimBpropGragFlagCache = std::unordered_map<std::vector<bool>, PrimBpropOptGraphLevel2InfoPtr>;
51 
52 using PrimBpropLevel2Cache =
53   std::unordered_map<abstract::AbstractBasePtrList, PrimBpropGragFlagCache, abstract::AbstractBasePtrListHasher,
54                      abstract::AbstractBasePtrListEqual>;
55 
56 using PrimTupleListCache =
57   std::unordered_map<TupleListKey, FuncGraphPtr, PrimitiveTupleListHasher, PrimitiveTupleListEqual>;
58 
59 struct PrimitiveTupleListHasher {
operatorPrimitiveTupleListHasher60   std::size_t operator()(const TupleListKey &key) const {
61     abstract::AbstractBasePtrListHasher hasher;
62     return hasher(key.second);
63   }
64 };
65 
66 struct PrimitiveTupleListEqual {
operatorPrimitiveTupleListEqual67   bool operator()(TupleListKey const &t1, TupleListKey const &t2) const {
68     MS_EXCEPTION_IF_NULL(t1.first);
69     MS_EXCEPTION_IF_NULL(t2.first);
70 
71     if (!(*t1.first == *t2.first)) {
72       return false;
73     }
74     abstract::AbstractBasePtrListEqual cmp;
75     return cmp(t1.second, t2.second);
76   }
77 };
78 
79 struct PrimitiveTotalEqual {
operatorPrimitiveTotalEqual80   bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
81     MS_EXCEPTION_IF_NULL(t1);
82     MS_EXCEPTION_IF_NULL(t2);
83     return *t1 == *t2;
84   }
85 };
86 
// Return type of PrimBpropOptimizer::GetOptBpfgFromCache.
// NOTE(review): presumably E_LEVEL_1 / E_LEVEL_2 report which cache level produced a hit - confirm in the .cc file.
enum ECacheQrtRes {
  E_NOT_FOUND = 0,
  E_LEVEL_1 = 1,
  E_LEVEL_2 = 2
};
88 
89 struct PrimBpropOptGraphInfo {
90   // the level1 opt func_graph without infer, no shape/type info provide
91   FuncGraphPtr opt_func_graph_;
92   // the opt func_graph after infer, func_graph level2 cache
93   PrimBpropLevel2Cache graph_level_2_cache_;
94 };
95 
// Usage record for one argument of a bprop graph; records for nested elements
// are stored recursively in sub_using_info_.
// NOTE(review): field meanings below are inferred from the names - confirm against the .cc file.
struct ParamUsingInfo {
  // Presumably true when the value is used by the graph.
  bool using_flg_{false};
  // Presumably true when the value is a tuple.
  bool tuple_flg_{false};
  // Presumably the element count when tuple_flg_ is set. Initialized to 0 so a
  // default-constructed record never exposes an indeterminate size.
  size_t tuple_size_{0};
  // Nested usage records.
  std::vector<ParamUsingInfo> sub_using_info_;
};
102 
103 class PrimBpropOptGraphLevel2Info {
104  public:
PrimBpropOptGraphLevel2Info(const FuncGraphPtr & func_graph)105   explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {}
106   ~PrimBpropOptGraphLevel2Info() = default;
107 
opt_func_graph()108   const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; }
109 
110   void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out);
111 
112   void AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager);
113 
114  private:
115   void ArgInfoRefresh(const std::shared_ptr<AnfNode> &param, ParamUsingInfo *arg_info) const;
116 
117   void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
118                              ParamUsingInfo *arg_info) const;
119 
120   void TryFreeOneValue(const ValuePtrList &op_args, const std::vector<ParamUsingInfo> &param_info_vec);
121 
122   void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
123                               ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const;
124 
125   // the level2 opt func_graph
126   FuncGraphPtr opt_func_graph_;
127   // to indicate arguments value using or not, if not using should free device memory
128   std::vector<ParamUsingInfo> args_value_using_info_;
129   bool analysis_finish_flg_{false};
130 };
131 
132 class PrimBpropOptimizer {
133  public:
134   ~PrimBpropOptimizer() = default;
135 
136   void Clear();
137 
138   static PrimBpropOptimizer &GetPrimBpropOptimizerInst();
139 
140   // bprop_fg has the signature:
141   // (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out)
142   // c_node contains the prim(input 0) and the input parameters of that prim;
143   // op_args contains the arguments list of each input parameters, it maybe tensor or tuple
144   // out contains the out of c_node;
145   FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
146                                       const ValuePtr &out);
147 
148   // do inline opt for final bprop graph
149   FuncGraphPtr BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const;
150 
151  private:
152   PrimBpropOptimizer() = default;
153 
154   ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list,
155                                    const std::vector<bool> &need_grad_flags,
156                                    PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
157                                    PrimBpropOptGraphInfoPtr *level_1_graph_info);
158 
159   // converter tensor args to abs value;
160   void ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList *abs_list) const;
161 
162   // add out && dout to abs list
163   abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out,
164                                                 const abstract::AbstractBasePtrList &abs_list) const;
165 
166   // do opt without input info, no infer
167   PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) const;
168 
169   // do opt with input info
170   PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(
171     const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
172     const std::vector<bool> &need_grad_flags = std::vector<bool>()) const;
173 
174   void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) const;
175 
176   FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
177                                     const PrimitivePtr &prim, const std::vector<bool> &need_grad_flags);
178 
179   FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
180                                const PrimitivePtr &prim, bool hook_flg);
181 
182   // cache optimized bprop graph
183   PrimBpropCache prim_bprop_cache_;
184   PrimTupleListCache tuple_list_bprop_cache_;
185 };
186 }  // namespace ad
187 }  // namespace mindspore
188 
189 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
190