/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_GRAD_H_
18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_GRAD_H_
19 
20 #include <memory>
21 #include <utility>
22 #include <map>
23 #include <vector>
24 #include <string>
25 #include <tuple>
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "pipeline/pynative/base.h"
29 #include "pipeline/pynative/grad/variable.h"
30 #include "pipeline/pynative/grad/ir/ir_bprop.h"
31 #include "pipeline/pynative/grad/auto_grad.h"
32 #include "pipeline/pynative/grad/function/func_builder.h"
33 
34 namespace mindspore::pynative::autograd {
35 class FuncBackwardNode : public BackwardNode {
36  public:
FuncBackwardNode(const string & name,expander::bprop::BpropBuilderFunc func,mindspore::HashMap<std::string,ValuePtr> attrs,ValuePtrList op_inputs,AbstractBasePtrList input_abstract,ValuePtr op_output,size_t output_size,AbstractBasePtr out_abstract,std::vector<InputType> grad_type)37   FuncBackwardNode(const string &name, expander::bprop::BpropBuilderFunc func,
38                    mindspore::HashMap<std::string, ValuePtr> attrs, ValuePtrList op_inputs,
39                    AbstractBasePtrList input_abstract, ValuePtr op_output, size_t output_size,
40                    AbstractBasePtr out_abstract, std::vector<InputType> grad_type)
41       : BackwardNode(name, output_size),
42         attrs_(std::move(attrs)),
43         op_inputs_(std::move(op_inputs)),
44         input_abstract_(std::move(input_abstract)),
45         grad_type_(std::move(grad_type)),
46         out_abstract_(std::move(out_abstract)),
47         func_(std::move(func)) {
48     op_output_ = std::move(op_output);
49   }
50   ~FuncBackwardNode() override = default;
51   ValuePtrList CallBackward(const ValuePtrList &grads) override;
52   NodePtrList PreProcess(const ValuePtrList &dout, FuncBuilder *emitter);
grad_func()53   const expander::bprop::BpropBuilderFunc &grad_func() { return func_; }
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)54   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) { attrs_ = attrs; }
55   void Release() override;
56 
57  private:
58   mindspore::HashMap<std::string, ValuePtr> attrs_;
59   ValuePtrList op_inputs_;
60   abstract::AbstractBasePtrList input_abstract_;
61   std::vector<InputType> grad_type_;
62   abstract::AbstractBasePtr out_abstract_;
63   expander::bprop::BpropBuilderFunc func_;
64 };
65 
66 class HookBackwardNode : public BackwardNode {
67  public:
HookBackwardNode(const string & name,PrimitivePyPtr prim,VectorRef && args,size_t output_size,abstract::AbstractBasePtr out_abstract)68   HookBackwardNode(const string &name, PrimitivePyPtr prim, VectorRef &&args, size_t output_size,
69                    abstract::AbstractBasePtr out_abstract)
70       : BackwardNode(name, output_size), prim_(std::move(prim)), args_(args), out_abstract_(std::move(out_abstract)) {}
71   ValuePtrList CallBackward(const ValuePtrList &grads) override;
72   void Release() override;
73 
74  private:
75   PrimitivePyPtr prim_;
76   VectorRef args_;
77   abstract::AbstractBasePtr out_abstract_;
78 };
79 
80 class GraphBackwardNode : public BackwardNode {
81  public:
GraphBackwardNode(const string & name,FuncGraphPtr func_graph,const VectorRef & args,const ValuePtr & op_output,size_t output_size,std::string cache_key,bool is_control_flow,bool is_jit_graph,bool is_dynamic_shape_process,bool jit_out_has_dict)82   explicit GraphBackwardNode(const string &name, FuncGraphPtr func_graph, const VectorRef &args,
83                              const ValuePtr &op_output, size_t output_size, std::string cache_key, bool is_control_flow,
84                              bool is_jit_graph, bool is_dynamic_shape_process, bool jit_out_has_dict)
85       : BackwardNode(name, output_size),
86         func_graph_(std::move(func_graph)),
87         args_(args),
88         cache_key_(std::move(cache_key)),
89         graph_call_condition_(is_control_flow, is_jit_graph, is_dynamic_shape_process, jit_out_has_dict, true) {
90     op_output_ = op_output;
91   }
92   ValuePtrList CallBackward(const ValuePtrList &grads) override;
93 
94  private:
95   FuncGraphPtr func_graph_;
96   VectorRef args_;
97   std::string cache_key_{false};
98   GraphCallCondition graph_call_condition_;
99 };
100 
101 class GraphRoot : public BackwardNode {
102  public:
GraphRoot(const string & name)103   explicit GraphRoot(const string &name) : BackwardNode(name) {}
104   ~GraphRoot() override = default;
CallBackward(const ValuePtrList & grads)105   ValuePtrList CallBackward(const ValuePtrList &grads) override { return grads; }
106   ValuePtrList BuildFlattenSensGradient(const ValuePtrList &sens_gradient) const;
107 };
108 
109 class FakeBackwardNode : public BackwardNode {
110  public:
BackwardNode(name,output_size)111   explicit FakeBackwardNode(const string &name, size_t output_size = 1) : BackwardNode(name, output_size) {}
112   ~FakeBackwardNode() override = default;
CallBackward(const ValuePtrList & grads)113   ValuePtrList CallBackward(const ValuePtrList &grads) override {
114     MS_LOG(EXCEPTION) << "Illegal primitive " << name() << "'s bprop not defined";
115   }
116 };
117 
118 class FuncGrad : public AutoGrad {
119  public:
120   FuncGrad(const ValuePtrList &input_param_values, size_t op_num_in_bprop_graph, bool grad_by_value,
121            bool is_run_recompute);
122   ~FuncGrad() override = default;
123 
124   bool KPynativeOp(const GradParamPtr &grad_param) override;
125   // Update top cell output, record last_node
126   void UpdateOutputNodeOfTopCell(const ValuePtr &sens_out) override;
127   // Reverse connect jit or higher order sub bprop funcgraph
128   bool KPynativeWithFProp(const GradParamPtr &grad_param) override;
129 
130   ValuePtr Finish(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
131                   const GradAttr &grad_attr, const ValuePtr &sens = nullptr);
132 
133  private:
134   void BackPropagate();
135   void BuildForwardLastNode(const ValuePtr &sens_gradient);
136   OrderedSet<FuncVariablePtr>::reverse_iterator GetLastNodeReverseIter();
137   void ConstructParameterNodes(const ValuePtrList &inputs);
138 
139   BackwardNodePtr BuildFuncBackwardNode(const PrimitivePtr &prim, const expander::bprop::BpropBuilderFunc &func,
140                                         const ValuePtrList &flatten_inputs, const OpGradInfoPtr &op_grad_info);
141   BackwardNodePtr BuildCustomBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
142                                           const OpGradInfoPtr &op_grad_info);
143   BackwardNodePtr BuildHookBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
144                                         const OpGradInfoPtr &op_grad_info);
145   BackwardNodePtr BuildFakeBackwardNode(const PrimitivePtr &prim, const ValuePtrList &flatten_inputs,
146                                         const OpGradInfoPtr &op_grad_info);
147   BackwardNodePtr BuildGraphBackwardNode(const GradParamPtr &grad_param);
148   ValuePtr GetGrads(const tensor::BaseTensorPtrList &weights, const std::vector<size_t> &grad_position,
149                     const GradAttr &grad_attr);
150   ValuePtr GetInputGrads(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
151   ValuePtr GetWeightGrads(bool grad_weights, const tensor::BaseTensorPtrList &weights, bool weight_param_is_tuple);
152   ValuePtr GetWeightGrad(const tensor::BaseTensorPtr &weight);
153   void ClearGrads(const tensor::BaseTensorPtrList &weights);
154   ValuePtrList OnsLike(const ValuePtr &value);
155   void CheckSensShapeAndType(const ValuePtr &sens_gradient);
156   void PruningGradGraph(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr,
157                         const std::vector<size_t> &grad_position);
158   void PruningInput(const GradAttr &grad_attr, const std::vector<size_t> &grad_position);
159   void PruningWeights(const tensor::BaseTensorPtrList &weights, const GradAttr &grad_attr);
160 
161   bool is_run_recompute_{false};
162   std::shared_ptr<FuncBuilder> func_impl_;
163   OrderedSet<FuncVariablePtr> variable_set_;
164   std::vector<std::pair<ValuePtr, FuncVariablePtr>> cell_inputs_;
165   std::vector<tensor::BaseTensorPtr> weights_used_in_graph_;
166   ValuePtr sens_value_{nullptr};
167   FuncVariablePtr last_variable_{nullptr};
168   ValuePtrList root_gradients_;
169 };
170 }  // namespace mindspore::pynative::autograd
171 
172 #endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_GRAD_H_
173