/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_VARIABLE_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_VARIABLE_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <typeinfo>  // for typeid used by isa<T>() below
#include "ir/anf.h"
#include "include/backend/kernel_graph.h"
#include "pipeline/pynative/grad/function/func_builder.h"

namespace mindspore::pynative::autograd {
using TensorPtrList = tensor::TensorPtrList;

struct GradAttr {
  GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position, bool weight_param_is_tuple)
      : grad_all_inputs(get_all),
        grad_weights(get_by_list),
        has_sens(sens_param),
        get_by_position(get_by_position),
        weight_param_is_tuple(weight_param_is_tuple) {}

  bool grad_all_inputs;
  bool grad_weights;
  bool has_sens;
  bool get_by_position;
  bool weight_param_is_tuple;
};
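
// Illustrative sketch (hypothetical usage, not part of this header): the positional
// constructor arguments map onto the named flags above.
//
//   GradAttr grad_attr(/*get_all=*/true, /*get_by_list=*/false, /*sens_param=*/false,
//                      /*get_by_position=*/false, /*weight_param_is_tuple=*/true);
//   // grad_attr.grad_all_inputs == true, grad_attr.grad_weights == false,
//   // grad_attr.has_sens == false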

class Variable;
struct Edge {
  /// \brief Constructor.
  ///
  /// \param[in] variable The variable representing the object that needs a gradient.
  /// \param[in] input_index The input index, i.e. the output index of the variable.
  explicit Edge(std::shared_ptr<Variable> variable, size_t input_index)
      : variable(std::move(variable)), input_index(input_index) {}
  std::shared_ptr<Variable> variable;
  size_t input_index;
};

class BackwardNode {
 public:
  /// \brief Constructor.
  ///
  /// \param[in] name The name of the op.
  /// \param[in] output_size The output size of the op.
  explicit BackwardNode(string name, size_t output_size = 1) : name_(std::move(name)), output_size_(output_size) {}

  /// \brief Destructor.
  virtual ~BackwardNode() = default;

  /// \brief CallBackward function is used to calculate the gradient of this node.
  ///
  /// \param[in] grads Grads is the gradients of this node's outputs.
  virtual ValuePtrList CallBackward(const ValuePtrList &grads) { return {}; }

  /// \brief Collect the next edges of this node. The inputs should be flattened.
  /// \param[in] inputs Inputs is the op inputs.
  virtual void UpdateNextEdges(const std::vector<ValuePtr> &inputs);

  /// \brief Postprocess gradients from func to align with next_edges.
  /// \param[in] gradient_value Gradient values are the gradient results from func
  /// which need postprocessing.
  /// \return Real gradients after postprocessing; the size is the same as the next edges size.
  virtual ValuePtrList PostProcess(const ValuePtrList &gradient_value);

  // Update nullptr grad.
  ValuePtrList LazeUpdateZeroGradient(const ValuePtrList &dout, FuncBuilder *func_builder, const ValuePtr &output);

  /// \brief The next_edges function returns the edges to this node's inputs, through
  /// which gradients are backpropagated.
  ///
  /// \return next edges
  const std::vector<Edge> &next_edges() const { return next_edges_; }

  /// \brief The gradient_index function returns the indices of the inputs
  /// that need gradient calculation.
  ///
  /// \return gradient index
  const std::vector<size_t> &gradient_index() const { return gradient_index_; }

  /// \brief Name of this node.
  ///
  /// \return name
  const std::string &name() { return name_; }

  /// \brief Set the op output value.
  ///
  /// \param[in] op_output op output value
  void set_op_output(const ValuePtr &op_output) { op_output_ = op_output; }

  /// \brief Get the op output value.
  ///
  /// \return op output value
  const ValuePtr &op_output() { return op_output_; }

  /// \brief The size of the node output.
  ///
  /// \return output size
  size_t output_size() const { return output_size_; }

  /// \brief Release resources.
  ///
  /// \return void
  virtual void Release() {}

 protected:
  std::vector<Edge> next_edges_;
  std::vector<size_t> gradient_index_;
  std::string name_;
  ValuePtr op_output_{nullptr};
  size_t output_size_;
};
using BackwardNodePtr = std::shared_ptr<BackwardNode>;
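
// Illustrative sketch (hypothetical, not part of this header): a concrete backward node.
// A subclass overrides CallBackward to map the gradients of its outputs onto the
// gradients of its inputs; UpdateNextEdges/PostProcess declared above take care of
// wiring and aligning those gradients with next_edges(). The class name and the saved
// input `other_input_` are invented for demonstration only.
//
//   class MulBackwardNode : public BackwardNode {
//    public:
//     explicit MulBackwardNode(ValuePtr other_input)
//         : BackwardNode("Mul"), other_input_(std::move(other_input)) {}
//     ValuePtrList CallBackward(const ValuePtrList &grads) override {
//       // dL/dx = dL/dy * other_input; the actual multiply would be emitted through
//       // the runtime's op builder, which is omitted in this sketch.
//       return grads;
//     }
//
//    private:
//     ValuePtr other_input_;
//   };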

class IrFunctionNode {
 public:
  IrFunctionNode(KernelGraphPtr tape, const AnfNodePtr &dout)
      : tape_(std::move(tape)), accumulate_dout_(dout), fake_dout_(dout) {}
  void AddNextEdge(const std::shared_ptr<Variable> &next_variable, const AnfNodePtr &din);
  void UpdateAccumulativeDout(const AnfNodePtr &new_dout);
  [[nodiscard]] const std::vector<std::pair<std::shared_ptr<Variable>, AnfNodePtr>> &next_edges() const {
    return next_edges_;
  }
  const KernelGraphPtr &tape() { return tape_; }
  [[nodiscard]] const AnfNodePtr &accumulate_dout() const { return accumulate_dout_; }
  void set_accumulate_dout(const AnfNodePtr &accumulate_dout) { accumulate_dout_ = accumulate_dout; }
  void ReplaceEdges();
  [[nodiscard]] const AnfNodePtr &fake_dout() const { return fake_dout_; }

 private:
  AnfNodePtr HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node);
  // Bprop func graph
  const KernelGraphPtr tape_;
  // Input of dout for this bprop function
  AnfNodePtr accumulate_dout_;
  // First we generate a fake dout
  const AnfNodePtr fake_dout_;
  // The pair.first is a variable, pair.second is dout of the variable
  std::vector<std::pair<std::shared_ptr<Variable>, AnfNodePtr>> next_edges_;
  // Replace next_edges where din == dout in the bprop function
  std::vector<int> need_replace_edges_;
};
using IrFunctionNodePtr = std::shared_ptr<IrFunctionNode>;

// Variable represents a tensor that needs grad
class Variable {
 public:
  /// \brief Constructor.
  ///
  Variable() = default;

  /// \brief Destructor.
  ///
  virtual ~Variable() = default;

  /// \brief Constructor.
  ///
  /// \param fn, Backward function.
  /// \param is_leaf, Whether the variable is a leaf or not.
  Variable(BackwardNodePtr &&fn, bool is_leaf) : is_leaf_(is_leaf), func_node_(std::move(fn)) {}

  /// \brief Constructor.
  ///
  /// \param fn, IrFunctionNodePtr function.
  /// \param out_value, Forward output value.
  /// \param is_leaf, Whether the variable is a leaf or not.
  Variable(IrFunctionNodePtr &&fn, ValuePtr &&out_value, bool is_leaf)
      : is_leaf_(is_leaf), out_value_(std::move(out_value)), ir_function_node_(std::move(fn)) {}

  /// \brief Backward function.
  ///
  /// \return fn
  BackwardNodePtr func_node() const { return func_node_; }

  /// \brief IrFunctionNode function.
  ///
  /// \return fn for ir
  IrFunctionNodePtr ir_function_node() const { return ir_function_node_; }

  /// \brief Name of the fake op.
  ///
  /// \return fake_prim_name
  const string &fake_prim_name() const { return fake_prim_name_; }

  /// \brief Set the name of the fake op.
  ///
  /// \param fake_prim_name
  void set_fake_prim_name(const string &fake_prim_name) { fake_prim_name_ = fake_prim_name; }

  /// \brief Flag to judge whether the op is a fake op.
  ///
  bool is_fake_bprop() const { return is_fake_bprop_; }

  /// \brief Set fake bprop.
  ///
  /// \param is_fake_bprop
  void set_is_fake_bprop(bool is_fake_bprop) { is_fake_bprop_ = is_fake_bprop; }

  /// \brief Flag to judge whether the variable needs to be propagated.
  ///
  /// \return True if the variable needs to be propagated, false if not.
  bool is_need_propagate() const { return is_need_propagate_; }

  /// \brief Set need propagate.
  ///
  void set_is_need_propagate(bool is_need_grad) { is_need_propagate_ = is_need_grad; }

  /// \brief Flag to judge whether the variable needs grad.
  ///
  /// \return is need grad
  bool is_need_grad() const { return is_need_grad_; }

  /// \brief Set need grad.
  ///
  /// \param is_need_grad
  void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }

  /// \brief Judge whether the variable is a leaf node.
  ///
  /// \return True if the variable is a leaf, false if not.
  bool is_leaf() const { return is_leaf_; }

  /// \brief Get the forward output value.
  ///
  /// \return valueptr.
  ValuePtr out_value() const { return out_value_; }

  /// \brief Debug info.
  ///
  /// \return debug info.
  virtual std::string ToString() const { return {}; }

  /// \brief Release input and output tensors.
  ///
  /// \return void
  void Release() {
    MS_EXCEPTION_IF_NULL(func_node_);
    func_node_->Release();
  }

 private:
  // If the node has no bprop, we record its prim name
  std::string fake_prim_name_;
  // Record that this node is a fake bprop
  bool is_fake_bprop_{false};
  // Flag to judge whether it needs to propagate
  bool is_need_propagate_{false};
  // Flag to judge whether the variable needs grad
  bool is_need_grad_{true};
  // Flag that the variable is a leaf in bprop.
  bool is_leaf_{false};
  ValuePtr out_value_{nullptr};
  // Abstract bprop function
  BackwardNodePtr func_node_{nullptr};
  // Ir bprop function
  IrFunctionNodePtr ir_function_node_{nullptr};
};
using VariablePtr = std::shared_ptr<Variable>;

// FuncVariable represents a parameter or an output of an op
class FuncVariable : public Variable {
 public:
  /// \brief Constructor.
  ///
  FuncVariable() = default;
  ~FuncVariable() override = default;
  /// \brief Constructor.
  ///
  /// \param fn, Backward function.
  /// \param is_leaf, Whether the variable is a leaf or not.
  explicit FuncVariable(BackwardNodePtr fn, bool is_leaf = false) : Variable(std::move(fn), is_leaf) {}

  /// \brief Gradients of the variable if the variable is a leaf node, nullptr if not a leaf node.
  ///
  const tensor::BaseTensorPtr &grad() const { return grad_; }

  /// \brief Set gradients of the leaf variable.
  ///
  /// \param grad
  void set_grad(const tensor::BaseTensorPtr &grad) { grad_ = grad; }

  std::string ToString() const override;

 private:
  // Grad for this variable, only leaf nodes have grad.
  tensor::BaseTensorPtr grad_;
};
using FuncVariablePtr = std::shared_ptr<FuncVariable>;
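
// Illustrative sketch (hypothetical usage, not part of this header): a leaf FuncVariable
// wrapping a backward node; only leaf variables carry a gradient after the backward pass.
// The op name "Add" is arbitrary.
//
//   auto fn = std::make_shared<BackwardNode>("Add");
//   auto leaf = std::make_shared<FuncVariable>(std::move(fn), /*is_leaf=*/true);
//   // Once the autograd engine has filled it in, the gradient can be read back:
//   // tensor::BaseTensorPtr grad = leaf->grad();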

// IrVariable represents a parameter or an output of a middle cnode
class IrVariable : public Variable {
 public:
  IrVariable() = default;
  ~IrVariable() override = default;

  IrVariable(IrFunctionNodePtr fn, ValuePtr out_value, bool is_leaf = false)
      : Variable(std::move(fn), std::move(out_value), is_leaf) {}

  AnfNodePtr k_node() const { return k_node_; }
  void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
  AnfNodePtr RealDout();
  std::string ToString() const override;

 private:
  AnfNodePtr k_node_{nullptr};
};
using IrVariablePtr = std::shared_ptr<IrVariable>;

template <typename T>
bool isa(const BackwardNodePtr &base_ptr) {
  const auto &object = (*base_ptr);
  return typeid(object) == typeid(T);
}
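
// Illustrative sketch (hypothetical usage, not part of this header): isa<T> compares the
// exact dynamic type via typeid, so subclasses of T do not match. `MatMulBackwardNode`
// is a hypothetical BackwardNode subclass used only for demonstration.
//
//   BackwardNodePtr node = variable->func_node();
//   if (isa<MatMulBackwardNode>(node)) {
//     // node's dynamic type is exactly MatMulBackwardNode.
//   }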
}  // namespace mindspore::pynative::autograd

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_VARIABLE_H_