/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_IR_BPROP_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_IR_BPROP_H_

#include <utility>
#include <tuple>
#include <vector>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/variable.h"
#include "pipeline/pynative/grad/ir/ir_pass.h"
#include "frontend/expander/bprop/bprop.h"

namespace mindspore::pynative::autograd {
void ClearAutoGradCache();
using KernelGraph = session::KernelGraph;
struct AdParam {
  AdParam() : tape_(std::make_shared<KernelGraph>()), fg_(std::make_shared<FuncGraph>()) {}
  // Bprop funcgraph
  KernelGraphPtr tape_;
  FuncGraphPtr fg_;
  IrVariablePtr last_variable_{nullptr};
  // Just for ad graph
  AnfNodePtr last_node_{nullptr};
  ValuePtr sens_value_;
  // Bprop dins of each variable or middle out
  OrderedMap<AnfNodePtr, IrVariablePtr> anfnode_to_variable_adjoint_;
  OrderedSet<IrVariablePtr> variable_adjoint_set_;
  // Record cnode's input map for tape_
  expander::bprop::UserMap users_;
  expander::bprop::UserType reverse_users_;
  AnfNodePtrList weights_used_in_graph_;
  std::vector<std::tuple<AnfNodePtr, CNodePtr, size_t>> lazy_user_data_;
};
using AdParamPtr = std::shared_ptr<AdParam>;
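
// A minimal usage sketch (illustrative only, not part of this header): a
// caller constructs an AdParam, hands it to the IrBprop class declared below,
// and asks for a bprop graph. The device target "CPU" and the prepared
// `grad_param` argument are assumptions for demonstration:
//   auto ad_param = std::make_shared<AdParam>();
//   IrBprop ir_bprop(ad_param, "CPU", /*grad_by_value=*/true);
//   std::pair<bool, FuncGraphPtr> result = ir_bprop.GetBpropGraph(grad_param);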

class IrBprop {
 public:
  IrBprop(AdParamPtr ad_param, std::string device_target, bool grad_by_value, bool is_run_recompute = false)
      : ad_param_(std::move(ad_param)), grad_by_value_(grad_by_value), is_run_recompute_(is_run_recompute) {
    pass_forward_ = std::make_shared<bprop_pass::IrPassForward>(this, std::move(device_target), grad_by_value_);
  }

  // Get bprop graph by ad::grad or by expander
  std::pair<bool, FuncGraphPtr> GetBpropGraph(const GradParamPtr &grad_param);

  // Build custom bprop cnode
  void BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs);

  // Create bprop_cut cnode in bprop graph
  void BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs,
                          bool is_need_recompute = false);

  // Get parameter from a value
  AnfNodePtr MapParameter(const ValuePtr &value, const abstract::AbstractBasePtr &abs);

  // Create variable for parameter
  ParameterPtr AddParameterNode(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs);

  // Create a new parameter
  ParameterPtr CreateTapeParameter(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs);

  // Update cnode dout
  void UpdateNextEdges(const VariablePtr &variable, const std::vector<CNodePtr> &dins,
                       const ValuePtrList &inputs_value, const abstract::AbstractBasePtrList &abs,
                       const string &op_name = "");

  // Used for true dout replace
  void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);

  // Used for high-order grad
  void AddReverseUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);

  // Create links between op grad graphs and generate the bprop graph
  void BackPropagate();

  // Get last node variable
  AbstractBasePtr BuildForwardLastNode();

  // Replace for true dout
  void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, expander::bprop::UserType *user,
               bool need_update = false);

  AdParamPtr ad_param() const { return ad_param_; }
  inline bool bprop_graph_run_by_single_op() { return bprop_graph_run_by_single_op_; }
  inline void set_bprop_graph_run_by_single_op(bool bprop_graph_run_by_single_op) {
    bprop_graph_run_by_single_op_ = bprop_graph_run_by_single_op_ || bprop_graph_run_by_single_op;
  }
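
  // Note on the setter above (a behavioral observation, not new API): the flag
  // is sticky, because the setter ORs the new value with the current one. Once
  // raised, it cannot be lowered through this interface:
  //   ir_bprop.set_bprop_graph_run_by_single_op(true);   // flag becomes true
  //   ir_bprop.set_bprop_graph_run_by_single_op(false);  // flag stays true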

 private:
  // Get bprop graph by ad::grad
  FuncGraphPtr GetBpropGraphFromFprop(const GradParamPtr &grad_param);

  // Get bprop graph by expander
  FuncGraphPtr GetBpropGraphFromExpander(const GradParamPtr &grad_param);

  // Grad every cnode in topological order
  void GradGraphByExpander(const GradParamPtr &grad_param);

  // Create variable for param
  void CreateParameterAdjoint(const GradParamPtr &grad_param) const;

  // Run passes on cnode inputs
  void PrepareGradCNodeInputs(const PrimitivePtr &prim, const CNodePtr &cnode, ValuePtrList *inputs_value,
                              AnfNodePtrList *cnode_inputs);

  // Get knode and value for cnode inputs
  ValuePtrList GetInputArgs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs) const;

  // Do grad for a cnode
  void GradCNode(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param,
                 const ValuePtrList &inputs_value, AnfNodePtrList *cnode_inputs);

  // Build knode for MakeTuple
  AnfNodePtr BuildKNodeForMakeTuple(const AnfNodePtr &input_node);

  // Build knode for TupleGetItem
  AnfNodePtr BuildKNodeForTupleGetItem(const AnfNodePtr &input_node);

  // Get knode for cnode inputs
  AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input);

  // Get a compute cnode
  AnfNodePtr GetKnode(const PrimitivePtr &prim, const CNodePtr &cnode, const AnfNodePtrList &cnode_inputs,
                      bool jit_by_value);

  // Set dout for every input arg
  void UpdateNextEdge(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
                      const AbstractBasePtr &abs);

  // Used for dict inputs
  void UpdateNextEdgeForDict(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
                             const AbstractBasePtr &abs);

  // Set din for the corresponding input
  AnfNodePtr TraceInput(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
                        const abstract::AbstractBasePtr &out_abs, const tensor::BaseTensorPtr &input_tensor,
                        const AnfNodePtr &din);

  // Used for dict input
  AnfNodePtr TraceInputForDict(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
                               const abstract::AbstractBasePtr &out_abs, const tensor::BaseTensorPtr &input_tensor,
                               const AnfNodePtr &din);

  // Get last node variable
  OrderedSet<IrVariablePtr>::reverse_iterator GetLastNodeReverseIter();

  // Used for TupleGetItem elimination
  void AddTupleGetItemUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);

  // For lazy user
  void UpdateLazyUser();

  // `node` is one of `user`'s inputs and `index` is its input index,
  // i.e. user->input(index) is node
  void LazyAddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);

  AdParamPtr ad_param_{nullptr};
  bool grad_by_value_{false};
  bool is_run_recompute_{false};
  // Flag for ms_function and high-order grad
  bool bprop_graph_run_by_single_op_{false};
  bprop_pass::PyNativePassForwardPtr pass_forward_;
};
using IrBpropPtr = std::unique_ptr<IrBprop>;
}  // namespace mindspore::pynative::autograd
#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_IR_BPROP_H_