• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
19 
20 #include <vector>
21 #include <utility>
22 #include <unordered_map>
23 #include <memory>
24 
25 #include "frontend/optimizer/irpass.h"
26 #include "ir/func_graph.h"
27 #include "pipeline/jit/resource.h"
28 
29 namespace mindspore {
30 namespace ad {
31 struct PrimBpropOptGraphInfo;
32 
33 class PrimBpropOptGraphLevel2Info;
34 
35 struct PrimitiveTotalEqual;
36 
37 struct PrimitiveTupleListHasher;
38 
39 struct PrimitiveTupleListEqual;
40 
41 using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>;
42 
43 using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>;
44 
45 using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>;
46 
47 using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>;
48 
49 using PrimBpropLevel2Cache =
50   std::unordered_map<abstract::AbstractBasePtrList, PrimBpropOptGraphLevel2InfoPtr, abstract::AbstractBasePtrListHasher,
51                      abstract::AbstractBasePtrListEqual>;
52 
53 using PrimTupleListCache =
54   std::unordered_map<TupleListKey, FuncGraphPtr, PrimitiveTupleListHasher, PrimitiveTupleListEqual>;
55 
56 struct PrimitiveTupleListHasher {
operatorPrimitiveTupleListHasher57   bool operator()(const TupleListKey &key) const {
58     abstract::AbstractBasePtrListHasher hasher;
59     return hasher(key.second);
60   }
61 };
62 
63 struct PrimitiveTupleListEqual {
operatorPrimitiveTupleListEqual64   bool operator()(TupleListKey const &t1, TupleListKey const &t2) const {
65     MS_EXCEPTION_IF_NULL(t1.first);
66     MS_EXCEPTION_IF_NULL(t2.first);
67 
68     if (!(*t1.first == *t2.first)) {
69       return false;
70     }
71     abstract::AbstractBasePtrListEqual cmp;
72     return cmp(t1.second, t2.second);
73   }
74 };
75 
76 struct PrimitiveTotalEqual {
operatorPrimitiveTotalEqual77   bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
78     MS_EXCEPTION_IF_NULL(t1);
79     MS_EXCEPTION_IF_NULL(t2);
80     return *t1 == *t2;
81   }
82 };
83 
// Result code of a cache query; this is the return type of PrimBpropOptimizer::GetOptBpfgFromCache.
// NOTE(review): the names suggest "no entry", "level-1 entry found" and "level-2 entry found" -- confirm
// against the implementation.
enum ECacheQrtRes {
  E_NOT_FOUND,
  E_LEVEL_1,
  E_LEVEL_2
};
85 
86 struct PrimBpropOptGraphInfo {
87   // the level1 opt func_graph without infer, no shape/type info provide
88   FuncGraphPtr opt_func_graph_;
89   // the opt func_graph after infer, func_graph level2 cache
90   PrimBpropLevel2Cache graph_level_2_cache_;
91 };
92 
// Usage record for one parameter; records for nested elements are kept in sub_using_info_.
// NOTE(review): the exact meaning of each flag is set by code outside this header -- confirm there.
struct ParamUsingInfo {
  // Presumably true when the parameter value is used.
  bool using_flg_{false};
  // Presumably true when the parameter is a tuple.
  bool tuple_flg_{false};
  // Presumably the number of tuple elements. Zero-initialized so that a default-constructed record never
  // exposes an indeterminate value (the other members already have default initializers).
  size_t tuple_size_{0};
  // Usage records of the nested elements.
  std::vector<ParamUsingInfo> sub_using_info_;
};
99 
100 class PrimBpropOptGraphLevel2Info {
101  public:
PrimBpropOptGraphLevel2Info(const FuncGraphPtr & func_graph)102   explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {}
103   ~PrimBpropOptGraphLevel2Info() = default;
104 
opt_func_graph()105   const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; }
106 
107   void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out);
108 
109   void AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager);
110 
111  private:
112   void ArgInfoRefresh(const std::shared_ptr<AnfNode> &param, ParamUsingInfo *arg_info) const;
113 
114   void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
115                              ParamUsingInfo *arg_info) const;
116 
117   void TryFreeOneValue(const ValuePtrList &op_args, const std::vector<ParamUsingInfo> &param_info_vec);
118 
119   void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> &param,
120                               ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const;
121 
122  private:
123   // the level2 opt func_graph
124   FuncGraphPtr opt_func_graph_;
125   // to indicate arguments value using or not, if not using should free device memory
126   std::vector<ParamUsingInfo> args_value_using_info_;
127   bool analysis_finish_flg_{false};
128 };
129 
130 class PrimBpropOptimizer {
131  public:
132   ~PrimBpropOptimizer() = default;
133 
134   void Clear();
135 
136   static PrimBpropOptimizer &GetPrimBpropOptimizerInst();
137 
138   // bprop_fg has the signature:
139   // (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out)
140   // c_node contains the prim(input 0) and the input parameters of that prim;
141   // op_args contains the arguments list of each input parameters, it maybe tensor or tuple
142   // out contains the out of c_node;
143   FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
144                                       const ValuePtr &out);
145 
146   // do inline opt for final bprop graph
147   FuncGraphPtr BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const;
148 
149  private:
150   PrimBpropOptimizer() = default;
151 
152   ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list,
153                                    PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info,
154                                    PrimBpropOptGraphInfoPtr *level_1_graph_info);
155 
156   // converter tensor args to abs value;
157   void ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList *abs_list);
158 
159   // add out && dout to abs list
160   abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out, const abstract::AbstractBasePtrList &abs_list);
161 
162   // do opt without input info, no infer
163   PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg);
164 
165   // do opt with input info
166   PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg,
167                                                    const abstract::AbstractBasePtrList &abs_list_input);
168 
169   void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input);
170 
171   FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
172                                     const PrimitivePtr &prim);
173 
174   FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
175                                const PrimitivePtr &prim, bool hook_flg);
176 
177  private:
178   // cache optimized bprop graph
179   PrimBpropCache prim_bprop_cache_;
180   PrimTupleListCache tuple_list_bprop_cache_;
181 };
182 }  // namespace ad
183 }  // namespace mindspore
184 
185 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H
186