// (code-browser navigation chrome removed)
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_IR_BPROP_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_IR_BPROP_H_

#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/grad/variable.h"
#include "pipeline/pynative/grad/ir/ir_pass.h"
#include "frontend/expander/bprop/bprop.h"

31 namespace mindspore::pynative::autograd {
32 void ClearAutoGradCache();
33 using KernelGraph = session::KernelGraph;
34 struct AdParam {
AdParamAdParam35   AdParam() : tape_(std::make_shared<KernelGraph>()), fg_(std::make_shared<FuncGraph>()) {}
36   // Bprop funcgraph
37   KernelGraphPtr tape_;
38   FuncGraphPtr fg_;
39   IrVariablePtr last_variable_{nullptr};
40   // Just for ad graph
41   AnfNodePtr last_node_{nullptr};
42   ValuePtr sens_value_;
43   // Bprop dins of each variable or middle out
44   OrderedMap<AnfNodePtr, IrVariablePtr> anfnode_to_variable_adjoint_;
45   OrderedSet<IrVariablePtr> variable_adjoint_set_;
46   // Record cnode's input map for tape_
47   expander::bprop::UserMap users_;
48   expander::bprop::UserType reverse_users_;
49   AnfNodePtrList weights_used_in_graph_;
50   std::vector<std::tuple<AnfNodePtr, CNodePtr, size_t>> lazy_user_data_;
51 };
52 using AdParamPtr = std::shared_ptr<AdParam>;
53 
54 class IrBprop {
55  public:
56   IrBprop(AdParamPtr ad_param, std::string device_target, bool grad_by_value, bool is_run_recompute = false)
ad_param_(std::move (ad_param))57       : ad_param_(std::move(ad_param)), grad_by_value_(grad_by_value), is_run_recompute_(is_run_recompute) {
58     pass_forward_ = std::make_shared<bprop_pass::IrPassForward>(this, std::move(device_target), grad_by_value_);
59   }
60 
61   // Get graph bporp graph by ad::grad or by expander
62   std::pair<bool, FuncGraphPtr> GetBpropGraph(const GradParamPtr &grad_param);
63 
64   // Build custom
65   void BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs);
66 
67   // Create bprop_cut cnode in bprop graph
68   void BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs,
69                           bool is_need_recompute = false);
70   // Get parameter from a value
71   AnfNodePtr MapParameter(const ValuePtr &value, const abstract::AbstractBasePtr &abs);
72 
73   // Create variable for parameter
74   ParameterPtr AddParameterNode(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs);
75 
76   // Create a new parameter
77   ParameterPtr CreateTapeParameter(const tensor::BaseTensorPtr &tensor, const abstract::AbstractBasePtr &abs);
78 
79   // Update cnode dout
80   void UpdateNextEdges(const VariablePtr &variable, const std::vector<CNodePtr> &dins, const ValuePtrList &inputs_value,
81                        const abstract::AbstractBasePtrList &abs, const string &op_name = "");
82 
83   // Used for ture dout repalce
84   void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
85 
86   // Used for high grad
87   void AddReverseUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
88 
89   // Create link for op grad graph and generate a bprop graph
90   void BackPropagate();
91 
92   // Get lase node variable
93   AbstractBasePtr BuildForwardLastNode();
94 
95   // Replace for true dout
96   void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, expander::bprop::UserType *user,
97                bool need_update = false);
98 
ad_param()99   AdParamPtr ad_param() const { return ad_param_; }
bprop_graph_run_by_single_op()100   inline bool bprop_graph_run_by_single_op() { return bprop_graph_run_by_single_op_; }
set_bprop_graph_run_by_single_op(bool bprop_graph_run_by_single_op)101   inline void set_bprop_graph_run_by_single_op(bool bprop_graph_run_by_single_op) {
102     bprop_graph_run_by_single_op_ = bprop_graph_run_by_single_op_ || bprop_graph_run_by_single_op;
103   }
104 
105  private:
106   // Get bprop graph by ad::grad
107   FuncGraphPtr GetBpropGraphFromFprop(const GradParamPtr &grad_param);
108 
109   // Get Bprop by expander
110   FuncGraphPtr GetBpropGraphFromExpander(const GradParamPtr &grad_param);
111 
112   // Use topo grad for every cnode
113   void GradGraphByExpander(const GradParamPtr &grad_param);
114 
115   // Create variable for param
116   void CreateParameterAdjoint(const GradParamPtr &grad_param) const;
117 
118   // Use pass for cnode inputs
119   void PrepareGradCNodeInputs(const PrimitivePtr &prim, const CNodePtr &cnode, ValuePtrList *inputs_value,
120                               AnfNodePtrList *cnode_inputs);
121 
122   // Get knode and value for cnode inputs
123   ValuePtrList GetInputArgs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs) const;
124 
125   // Do grad for a cnode
126   void GradCNode(const PrimitivePtr &prim, const CNodePtr &cnode, const GradParamPtr &grad_param,
127                  const ValuePtrList &inputs_value, AnfNodePtrList *cnode_inputs);
128 
129   // Build knode for MakeTuple
130   AnfNodePtr BuildKNodeForMakeTuple(const AnfNodePtr &input_node);
131 
132   // Build knode for TupleGetItem
133   AnfNodePtr BuildKNodeForTupleGetItem(const AnfNodePtr &input_node);
134 
135   // Get knode for cnode inputs
136   AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input);
137 
138   // Get a compute cnode
139   AnfNodePtr GetKnode(const PrimitivePtr &prim, const CNodePtr &cnode, const AnfNodePtrList &cnode_inputs,
140                       bool jit_by_value);
141 
142   // Set dout for every input arg
143   void UpdateNextEdge(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
144                       const AbstractBasePtr &abs);
145 
146   // Used for dict inputs
147   void UpdateNextEdgeForDict(const IrFunctionNodePtr &fn, const AnfNodePtr &din, const ValuePtr &input_arg,
148                              const AbstractBasePtr &abs);
149 
150   // Set din for corresponding input
151   AnfNodePtr TraceInput(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
152                         const abstract::AbstractBasePtr &out_abs, const tensor::BaseTensorPtr &input_tensor,
153                         const AnfNodePtr &din);
154 
155   // Used for dict input
156   AnfNodePtr TraceInputForDict(const IrFunctionNodePtr &fn, const ValuePtr &out_value,
157                                const abstract::AbstractBasePtr &out_abs, const tensor::BaseTensorPtr &input_tensor,
158                                const AnfNodePtr &din);
159 
160   // Get last node variable
161   OrderedSet<IrVariablePtr>::reverse_iterator GetLastNodeReverseIter();
162 
163   // Used for tuplegetiem elimate
164   void AddTupleGetItemUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
165 
166   // For lazy user
167   void UpdateLazyUser();
168 
169   // Input node is user cnode one of input, index is user input index
170   // User->input(index) is input node
171   void LazyAddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
172 
173   AdParamPtr ad_param_{nullptr};
174   bool grad_by_value_{false};
175   bool is_run_recompute_{false};
176   // Flag for ms_funtcion and high order
177   bool bprop_graph_run_by_single_op_{false};
178   bprop_pass::PyNativePassForwardPtr pass_forward_;
179 };
180 using IrBpropPtr = std::unique_ptr<IrBprop>;
181 }  // namespace mindspore::pynative::autograd
182 #endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_IR_BPROP_H_
183