1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H 19 20 #include <vector> 21 #include <utility> 22 #include <unordered_map> 23 #include <memory> 24 25 #include "frontend/optimizer/irpass.h" 26 #include "ir/func_graph.h" 27 #include "pipeline/jit/resource.h" 28 29 namespace mindspore { 30 namespace ad { 31 struct PrimBpropOptGraphInfo; 32 33 class PrimBpropOptGraphLevel2Info; 34 35 struct PrimitiveTotalEqual; 36 37 struct PrimitiveTupleListHasher; 38 39 struct PrimitiveTupleListEqual; 40 41 using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>; 42 43 using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>; 44 45 using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>; 46 47 using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>; 48 49 using PrimBpropLevel2Cache = 50 std::unordered_map<abstract::AbstractBasePtrList, PrimBpropOptGraphLevel2InfoPtr, abstract::AbstractBasePtrListHasher, 51 abstract::AbstractBasePtrListEqual>; 52 53 using PrimTupleListCache = 54 std::unordered_map<TupleListKey, FuncGraphPtr, PrimitiveTupleListHasher, PrimitiveTupleListEqual>; 55 56 struct PrimitiveTupleListHasher { operatorPrimitiveTupleListHasher57 bool operator()(const TupleListKey &key) const { 58 abstract::AbstractBasePtrListHasher hasher; 59 return hasher(key.second); 60 } 61 }; 62 63 struct PrimitiveTupleListEqual { operatorPrimitiveTupleListEqual64 bool operator()(TupleListKey const &t1, TupleListKey const &t2) const { 65 MS_EXCEPTION_IF_NULL(t1.first); 66 MS_EXCEPTION_IF_NULL(t2.first); 67 68 if (!(*t1.first == *t2.first)) { 69 return false; 70 } 71 abstract::AbstractBasePtrListEqual cmp; 72 return cmp(t1.second, t2.second); 73 } 74 }; 75 76 struct PrimitiveTotalEqual { operatorPrimitiveTotalEqual77 bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { 78 MS_EXCEPTION_IF_NULL(t1); 79 MS_EXCEPTION_IF_NULL(t2); 80 return *t1 == *t2; 81 } 82 }; 83 84 enum ECacheQrtRes { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 }; 85 86 struct PrimBpropOptGraphInfo { 87 // the level1 opt func_graph without infer, no shape/type info provide 88 FuncGraphPtr opt_func_graph_; 89 // the opt func_graph after infer, func_graph level2 cache 90 PrimBpropLevel2Cache graph_level_2_cache_; 91 }; 92 93 struct ParamUsingInfo { 94 bool using_flg_{false}; 95 bool tuple_flg_{false}; 96 size_t tuple_size_; 97 std::vector<ParamUsingInfo> sub_using_info_; 98 }; 99 100 class PrimBpropOptGraphLevel2Info { 101 public: PrimBpropOptGraphLevel2Info(const FuncGraphPtr & func_graph)102 explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {} 103 ~PrimBpropOptGraphLevel2Info() = default; 104 opt_func_graph()105 const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; } 106 107 void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out); 108 109 void AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager); 110 111 private: 112 void ArgInfoRefresh(const std::shared_ptr<AnfNode> ¶m, ParamUsingInfo *arg_info) const; 113 114 void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> ¶m, 115 ParamUsingInfo *arg_info) const; 116 117 void TryFreeOneValue(const ValuePtrList &op_args, const std::vector<ParamUsingInfo> ¶m_info_vec); 118 119 void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> ¶m, 120 ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const; 121 122 private: 123 // the level2 opt func_graph 124 FuncGraphPtr opt_func_graph_; 125 // to indicate arguments value using or not, if not using should free device memory 126 std::vector<ParamUsingInfo> args_value_using_info_; 127 bool analysis_finish_flg_{false}; 128 }; 129 130 class PrimBpropOptimizer { 131 public: 132 ~PrimBpropOptimizer() = default; 133 134 void Clear(); 135 136 static PrimBpropOptimizer &GetPrimBpropOptimizerInst(); 137 138 // bprop_fg has the signature: 139 // (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out) 140 // c_node contains the prim(input 0) and the input parameters of that prim; 141 // op_args contains the arguments list of each input parameters, it maybe tensor or tuple 142 // out contains the out of c_node; 143 FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args, 144 const ValuePtr &out); 145 146 // do inline opt for final bprop graph 147 FuncGraphPtr BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const; 148 149 private: 150 PrimBpropOptimizer() = default; 151 152 ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list, 153 PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info, 154 PrimBpropOptGraphInfoPtr *level_1_graph_info); 155 156 // converter tensor args to abs value; 157 void ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList *abs_list); 158 159 // add out && dout to abs list 160 abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out, const abstract::AbstractBasePtrList &abs_list); 161 162 // do opt without input info, no infer 163 PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg); 164 165 // do opt with input info 166 PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg, 167 const abstract::AbstractBasePtrList &abs_list_input); 168 169 void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input); 170 171 FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out, 172 const PrimitivePtr &prim); 173 174 FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out, 175 const PrimitivePtr &prim, bool hook_flg); 176 177 private: 178 // cache optimized bprop graph 179 PrimBpropCache prim_bprop_cache_; 180 PrimTupleListCache tuple_list_bprop_cache_; 181 }; 182 } // namespace ad 183 } // namespace mindspore 184 185 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H 186