1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H 19 20 #include <vector> 21 #include <utility> 22 #include <memory> 23 #include <unordered_map> 24 #include <string> 25 #include "utils/hash_map.h" 26 #include "frontend/optimizer/irpass.h" 27 #include "ir/func_graph.h" 28 #include "pipeline/jit/ps/resource.h" 29 30 namespace mindspore { 31 namespace ad { 32 struct PrimBpropOptGraphInfo; 33 34 class PrimBpropOptGraphLevel2Info; 35 36 struct PrimitiveTotalEqual; 37 38 struct PrimitiveTupleListHasher; 39 40 struct PrimitiveTupleListEqual; 41 42 using PrimBpropOptGraphInfoPtr = std::shared_ptr<PrimBpropOptGraphInfo>; 43 44 using PrimBpropOptGraphLevel2InfoPtr = std::shared_ptr<PrimBpropOptGraphLevel2Info>; 45 46 using PrimBpropCache = std::unordered_map<PrimitivePtr, PrimBpropOptGraphInfoPtr, PrimitiveHasher, PrimitiveTotalEqual>; 47 48 using TupleListKey = std::pair<PrimitivePtr, abstract::AbstractBasePtrList>; 49 50 using PrimBpropGragFlagCache = std::unordered_map<std::vector<bool>, PrimBpropOptGraphLevel2InfoPtr>; 51 52 using PrimBpropLevel2Cache = 53 std::unordered_map<abstract::AbstractBasePtrList, PrimBpropGragFlagCache, abstract::AbstractBasePtrListHasher, 54 abstract::AbstractBasePtrListEqual>; 55 56 using PrimTupleListCache = 57 std::unordered_map<TupleListKey, FuncGraphPtr, PrimitiveTupleListHasher, PrimitiveTupleListEqual>; 58 59 struct PrimitiveTupleListHasher { operatorPrimitiveTupleListHasher60 std::size_t operator()(const TupleListKey &key) const { 61 abstract::AbstractBasePtrListHasher hasher; 62 return hasher(key.second); 63 } 64 }; 65 66 struct PrimitiveTupleListEqual { operatorPrimitiveTupleListEqual67 bool operator()(TupleListKey const &t1, TupleListKey const &t2) const { 68 MS_EXCEPTION_IF_NULL(t1.first); 69 MS_EXCEPTION_IF_NULL(t2.first); 70 71 if (!(*t1.first == *t2.first)) { 72 return false; 73 } 74 abstract::AbstractBasePtrListEqual cmp; 75 return cmp(t1.second, t2.second); 76 } 77 }; 78 79 struct PrimitiveTotalEqual { operatorPrimitiveTotalEqual80 bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { 81 MS_EXCEPTION_IF_NULL(t1); 82 MS_EXCEPTION_IF_NULL(t2); 83 return *t1 == *t2; 84 } 85 }; 86 87 enum ECacheQrtRes { E_NOT_FOUND, E_LEVEL_1, E_LEVEL_2 }; 88 89 struct PrimBpropOptGraphInfo { 90 // the level1 opt func_graph without infer, no shape/type info provide 91 FuncGraphPtr opt_func_graph_; 92 // the opt func_graph after infer, func_graph level2 cache 93 PrimBpropLevel2Cache graph_level_2_cache_; 94 }; 95 96 struct ParamUsingInfo { 97 bool using_flg_{false}; 98 bool tuple_flg_{false}; 99 size_t tuple_size_; 100 std::vector<ParamUsingInfo> sub_using_info_; 101 }; 102 103 class PrimBpropOptGraphLevel2Info { 104 public: PrimBpropOptGraphLevel2Info(const FuncGraphPtr & func_graph)105 explicit PrimBpropOptGraphLevel2Info(const FuncGraphPtr &func_graph) : opt_func_graph_(func_graph) {} 106 ~PrimBpropOptGraphLevel2Info() = default; 107 opt_func_graph()108 const FuncGraphPtr &opt_func_graph() const { return opt_func_graph_; } 109 110 void TryFreeArgsValue(const ValuePtrList &op_args, const ValuePtr &out); 111 112 void AnalysisArgUsingInfo(const FuncGraphManagerPtr &manager); 113 114 private: 115 void ArgInfoRefresh(const std::shared_ptr<AnfNode> ¶m, ParamUsingInfo *arg_info) const; 116 117 void AnalysisNodeUsingInfo(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> ¶m, 118 ParamUsingInfo *arg_info) const; 119 120 void TryFreeOneValue(const ValuePtrList &op_args, const std::vector<ParamUsingInfo> ¶m_info_vec); 121 122 void AalysisForTupleGetItem(const NodeUsersMap &node_users, const std::shared_ptr<AnfNode> ¶m, 123 ParamUsingInfo *arg_info, const AnfNodePtr &user_node) const; 124 125 // the level2 opt func_graph 126 FuncGraphPtr opt_func_graph_; 127 // to indicate arguments value using or not, if not using should free device memory 128 std::vector<ParamUsingInfo> args_value_using_info_; 129 bool analysis_finish_flg_{false}; 130 }; 131 132 class PrimBpropOptimizer { 133 public: 134 ~PrimBpropOptimizer() = default; 135 136 void Clear(); 137 138 static PrimBpropOptimizer &GetPrimBpropOptimizerInst(); 139 140 // bprop_fg has the signature: 141 // (sens_input1, sens_input2,...)bprop_fg(input1, input2, ..., out, d_out) 142 // c_node contains the prim(input 0) and the input parameters of that prim; 143 // op_args contains the arguments list of each input parameters, it maybe tensor or tuple 144 // out contains the out of c_node; 145 FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args, 146 const ValuePtr &out); 147 148 // do inline opt for final bprop graph 149 FuncGraphPtr BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const; 150 151 private: 152 PrimBpropOptimizer() = default; 153 154 ECacheQrtRes GetOptBpfgFromCache(const PrimitivePtr &prim, const abstract::AbstractBasePtrList &abs_list, 155 const std::vector<bool> &need_grad_flags, 156 PrimBpropOptGraphLevel2InfoPtr *level_2_graph_info, 157 PrimBpropOptGraphInfoPtr *level_1_graph_info); 158 159 // converter tensor args to abs value; 160 void ArgsToAbs(const PrimitivePtr &prim, const ValuePtrList &op_args, abstract::AbstractBasePtrList *abs_list) const; 161 162 // add out && dout to abs list 163 abstract::AbstractBasePtrList AddOutToAbsList(const ValuePtr &out, 164 const abstract::AbstractBasePtrList &abs_list) const; 165 166 // do opt without input info, no infer 167 PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) const; 168 169 // do opt with input info 170 PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2( 171 const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input, 172 const std::vector<bool> &need_grad_flags = std::vector<bool>()) const; 173 174 void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) const; 175 176 FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out, 177 const PrimitivePtr &prim, const std::vector<bool> &need_grad_flags); 178 179 FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out, 180 const PrimitivePtr &prim, bool hook_flg); 181 182 // cache optimized bprop graph 183 PrimBpropCache prim_bprop_cache_; 184 PrimTupleListCache tuple_list_bprop_cache_; 185 }; 186 } // namespace ad 187 } // namespace mindspore 188 189 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_PRIM_BPROP_OPTIMIZER_H 190