1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_GRAD_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_GRAD_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <map> 23 #include <vector> 24 #include <string> 25 #include <tuple> 26 #include "ir/anf.h" 27 #include "ir/func_graph.h" 28 #include "pipeline/pynative/base.h" 29 #include "pipeline/pynative/grad/variable.h" 30 #include "pipeline/pynative/grad/ir/ir_bprop.h" 31 #include "pipeline/pynative/grad/auto_grad.h" 32 #include "pipeline/pynative/grad/function/func_builder.h" 33 34 namespace mindspore::pynative::autograd { 35 class FuncBackwardNode : public BackwardNode { 36 public: FuncBackwardNode(const string & name,expander::bprop::BpropBuilderFunc func,mindspore::HashMap<std::string,ValuePtr> attrs,ValuePtrList op_inputs,AbstractBasePtrList input_abstract,ValuePtr op_output,size_t output_size,AbstractBasePtr out_abstract,std::vector<InputType> grad_type)37 FuncBackwardNode(const string &name, expander::bprop::BpropBuilderFunc func, 38 mindspore::HashMap<std::string, ValuePtr> attrs, ValuePtrList op_inputs, 39 AbstractBasePtrList input_abstract, ValuePtr op_output, size_t output_size, 40 AbstractBasePtr out_abstract, std::vector<InputType> grad_type) 41 : BackwardNode(name, output_size), 42 attrs_(std::move(attrs)), 43 op_inputs_(std::move(op_inputs)), 44 input_abstract_(std::move(input_abstract)), 45 grad_type_(std::move(grad_type)), 46 out_abstract_(std::move(out_abstract)), 47 func_(std::move(func)) { 48 op_output_ = std::move(op_output); 49 } 50 ~FuncBackwardNode() override = default; 51 ValuePtrList CallBackward(const ValuePtrList &grads) override; 52 NodePtrList PreProcess(const ValuePtrList &dout, FuncBuilder *emitter); grad_func()53 const expander::bprop::BpropBuilderFunc &grad_func() { return func_; } set_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)54 void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) { attrs_ = attrs; } 55 void Release() override; 56 57 private: 58 mindspore::HashMap<std::string, ValuePtr> attrs_; 59 ValuePtrList op_inputs_; 60 abstract::AbstractBasePtrList input_abstract_; 61 std::vector<InputType> grad_type_; 62 abstract::AbstractBasePtr out_abstract_; 63 expander::bprop::BpropBuilderFunc func_; 64 }; 65 66 class HookBackwardNode : public BackwardNode { 67 public: HookBackwardNode(const string & name,PrimitivePyPtr prim,VectorRef && args,size_t output_size,abstract::AbstractBasePtr out_abstract)68 HookBackwardNode(const string &name, PrimitivePyPtr prim, VectorRef &&args, size_t output_size, 69 abstract::AbstractBasePtr out_abstract) 70 : BackwardNode(name, output_size), prim_(std::move(prim)), args_(args), out_abstract_(std::move(out_abstract)) {} 71 ValuePtrList CallBackward(const ValuePtrList &grads) override; 72 void Release() override; 73 74 private: 75 PrimitivePyPtr prim_; 76 VectorRef args_; 77 abstract::AbstractBasePtr out_abstract_; 78 }; 79 80 class GraphBackwardNode : public BackwardNode { 81 public: GraphBackwardNode(const string & name,FuncGraphPtr func_graph,const VectorRef & args,const ValuePtr & op_output,size_t output_size,std::string cache_key,bool is_control_flow,bool is_jit_graph,bool is_dynamic_shape_process,bool jit_out_has_dict)82 explicit GraphBackwardNode(const string &name, FuncGraphPtr func_graph, const VectorRef &args, 83 const ValuePtr &op_output, size_t output_size, std::string cache_key, bool is_control_flow, 84 bool is_jit_graph, bool is_dynamic_shape_process, bool jit_out_has_dict) 85 : BackwardNode(name, output_size), 86 func_graph_(std::move(func_graph)), 87 args_(args), 88 cache_key_(std::move(cache_key)), 89 graph_call_condition_(is_control_flow, is_jit_graph, is_dynamic_shape_process, jit_out_has_dict, true) { 90 op_output_ = op_output; 91 } 92 ValuePtrList CallBackward(const ValuePtrList &grads) override; 93 94 private: 95 FuncGraphPtr func_graph_; 96 VectorRef args_; 97 std::string cache_key_{false}; 98 GraphCallCondition graph_call_condition_; 99 }; 100 101 class GraphRoot : public BackwardNode { 102 public: GraphRoot(const string & name)103 explicit GraphRoot(const string &name) : BackwardNode(name) {} 104 ~GraphRoot() override = default; CallBackward(const ValuePtrList & grads)105 ValuePtrList CallBackward(const ValuePtrList &grads) override { return grads; } 106 ValuePtrList BuildFlattenSensGradient(const ValuePtrList &sens_gradient) const; 107 }; 108 109 class FakeBackwardNode : public BackwardNode { 110 public: BackwardNode(name,output_size)111 explicit FakeBackwardNode(const string &name, size_t output_size = 1) : BackwardNode(name, output_size) {} 112 ~FakeBackwardNode() override = default; CallBackward(const ValuePtrList & grads)113 ValuePtrList CallBackward(const ValuePtrList &grads) override { 114 MS_LOG(EXCEPTION) << "Illegal primitive " << name() << "'s bprop not defined"; 115 } 116 }; 117 118 class FuncGrad : public AutoGrad { 119 public: 120 FuncGrad(const ValuePtrList &input_param_values, size_t op_num_in_bprop_graph, bool grad_by_value, 121 bool is_run_recompute); 122 ~FuncGrad() override = default; 123 124 bool KPynativeOp(const GradParamPtr &grad_param) override; 125 // Update top cell output, record last_node 126 void UpdateOutputNodeOfTopCell(const ValuePtr &sens_out) override; 127 // Reverse connect jit or higher order sub bprop funcgraph 128 bool KPynativeWithFProp(const GradParamPtr &grad_param) override; 129 130 ValuePtr Finish(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position, 131 const GradAttr &grad_attr, const ValuePtr &sens = nullptr); 132 133 private: 134 void BackPropagate(); 135 void BuildForwardLastNode(const ValuePtr &sens_gradient); 136 OrderedSet<FuncVariablePtr>::reverse_iterator GetLastNodeReverseIter(); 137 void ConstructParameterNodes(const ValuePtrList &inputs); 138 139 BackwardNodePtr BuildFuncBackwardNode(const PrimitivePtr &prim, const expander::bprop::BpropBuilderFunc &func, 140 const ValuePtrList &flatten_inputs, const OpGradInfoPtr &op_grad_info); 141 BackwardNodePtr BuildCustomBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs, 142 const OpGradInfoPtr &op_grad_info); 143 BackwardNodePtr BuildHookBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs, 144 const OpGradInfoPtr &op_grad_info); 145 BackwardNodePtr BuildFakeBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs, 146 const OpGradInfoPtr &op_grad_info); 147 BackwardNodePtr BuildGraphBackwardNode(const GradParamPtr &grad_param); 148 ValuePtr GetGrads(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position, 149 const GradAttr &grad_attr); 150 ValuePtr GetInputGrads(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position); 151 ValuePtr GetWeightGrads(bool grad_weights, const tensor::BaseTensorPtrList &weights, bool weight_param_is_tuple); 152 ValuePtr GetWeightGrad(const tensor::BaseTensorPtr &weight); 153 void ClearGrads(const tensor::BaseTensorPtrList &weights); 154 ValuePtrList OnsLike(const ValuePtr &value); 155 void CheckSensShapeAndType(const ValuePtr &sens_gradient); 156 void PruningGradGraph(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr, 157 const std::vector<size_t> &grad_position); 158 void PruningInput(const GradAttr &grad_attr, const std::vector<size_t> &grad_position); 159 void PruningWeights(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr); 160 161 bool is_run_recompute_{false}; 162 std::shared_ptr<FuncBuilder> func_impl_; 163 OrderedSet<FuncVariablePtr> variable_set_; 164 std::vector<std::pair<ValuePtr, FuncVariablePtr>> cell_inputs_; 165 std::vector<tensor::BaseTensorPtr> weights_used_in_graph_; 166 ValuePtr sens_value_{nullptr}; 167 FuncVariablePtr last_variable_{nullptr}; 168 ValuePtrList root_gradients_; 169 }; 170 } // namespace mindspore::pynative::autograd 171 172 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_GRAD_H_ 173