/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_VARIABLE_H_
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_VARIABLE_H_

#include <utility>
#include <vector>
#include <string>
#include <memory>
#include <typeinfo>
#include "ir/anf.h"
#include "include/backend/kernel_graph.h"
#include "pipeline/pynative/grad/function/func_builder.h"

namespace mindspore::pynative::autograd {
using TensorPtrList = tensor::TensorPtrList;

struct GradAttr {
  GradAttr(bool get_all, bool get_by_list, bool sens_param, bool get_by_position, bool weight_param_is_tuple)
      : grad_all_inputs(get_all),
        grad_weights(get_by_list),
        has_sens(sens_param),
        get_by_position(get_by_position),
        weight_param_is_tuple(weight_param_is_tuple) {}

  bool grad_all_inputs;
  bool grad_weights;
  bool has_sens;
  bool get_by_position;
  bool weight_param_is_tuple;
};

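// A minimal construction sketch (illustrative only, not taken from the framework sources):
// gradients w.r.t. all network inputs, without weight gradients, a sens parameter, or
// positional selection.
//   GradAttr attr(/*get_all=*/true, /*get_by_list=*/false, /*sens_param=*/false,
//                 /*get_by_position=*/false, /*weight_param_is_tuple=*/false);
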
class Variable;
struct Edge {
  /// \brief Constructor.
  ///
  /// \param[in] variable The variable represents the object that needs gradient.
  /// \param[in] input_index The input index is the output index of the variable.
  explicit Edge(std::shared_ptr<Variable> variable, size_t input_index)
      : variable(std::move(variable)), input_index(input_index) {}
  std::shared_ptr<Variable> variable;
  size_t input_index;
};

class BackwardNode {
 public:
  /// \brief Constructor.
  ///
  /// \param[in] name The name represents the op name.
  /// \param[in] output_size The output_size is the output size of the op.
  explicit BackwardNode(string name, size_t output_size = 1) : name_(std::move(name)), output_size_(output_size) {}

  /// \brief Destructor.
  virtual ~BackwardNode() = default;

  /// \brief The CallBackward function is used to calculate the gradients of this node.
  ///
  /// \param[in] grads Grads is the gradients of this node's outputs.
  virtual ValuePtrList CallBackward(const ValuePtrList &grads) { return {}; }

  /// \brief Collect the next edges of this node. The inputs should be flattened.
  /// \param[in] inputs Inputs is the op inputs.
  virtual void UpdateNextEdges(const std::vector<ValuePtr> &inputs);

  /// \brief Postprocess the gradients from func to align with next_edges.
  /// \param[in] gradient_value Gradient value is the gradient result from func
  /// which needs postprocessing.
  /// \return Real gradients after postprocessing; the size is the same as the next edges size.
  virtual ValuePtrList PostProcess(const ValuePtrList &gradient_value);

  // Update nullptr grad.
  ValuePtrList LazeUpdateZeroGradient(const ValuePtrList &dout, FuncBuilder *func_builder, const ValuePtr &output);

  /// \brief The next_edges function returns the edges to this node's inputs, through which
  /// gradients are backpropagated.
  ///
  /// \return next edges
  const std::vector<Edge> &next_edges() const { return next_edges_; }

  /// \brief The gradient_index function returns the indices of the inputs
  /// which need gradient calculation.
  ///
  /// \return gradient index
  const std::vector<size_t> &gradient_index() const { return gradient_index_; }

  /// \brief Name of this node.
  ///
  /// \return name
  const std::string &name() { return name_; }

  /// \brief Set the op output value.
  ///
  /// \param[in] op_output The op output value.
  void set_op_output(const ValuePtr &op_output) { op_output_ = op_output; }

  /// \brief Get the op output value.
  ///
  /// \return op output value
  const ValuePtr &op_output() { return op_output_; }

  /// \brief The size of the node output.
  ///
  /// \return output size
  size_t output_size() const { return output_size_; }

  /// \brief Release resources.
  ///
  /// \return void
  virtual void Release() {}

 protected:
  std::vector<Edge> next_edges_;
  std::vector<size_t> gradient_index_;
  std::string name_;
  ValuePtr op_output_{nullptr};
  size_t output_size_;
};
using BackwardNodePtr = std::shared_ptr<BackwardNode>;

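// A minimal subclass sketch (hypothetical example; MulBackwardNode is not part of this header):
// a concrete node overrides CallBackward to turn its output gradients into input gradients,
// typically finishing with PostProcess so the result aligns with next_edges().
//
//   class MulBackwardNode : public BackwardNode {
//    public:
//     explicit MulBackwardNode(string name, size_t output_size = 1)
//         : BackwardNode(std::move(name), output_size) {}
//     ValuePtrList CallBackward(const ValuePtrList &grads) override {
//       ValuePtrList gradient_value = /* compute input gradients from grads and op_output() */ {};
//       return PostProcess(gradient_value);
//     }
//   };
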
class IrFunctionNode {
 public:
  IrFunctionNode(KernelGraphPtr tape, const AnfNodePtr &dout)
      : tape_(std::move(tape)), accumulate_dout_(dout), fake_dout_(dout) {}
  void AddNextEdge(const std::shared_ptr<Variable> &next_variable, const AnfNodePtr &din);
  void UpdateAccumulativeDout(const AnfNodePtr &new_dout);
  [[nodiscard]] const std::vector<std::pair<std::shared_ptr<Variable>, AnfNodePtr>> &next_edges() const {
    return next_edges_;
  }
  const KernelGraphPtr &tape() { return tape_; }
  [[nodiscard]] const AnfNodePtr &accumulate_dout() const { return accumulate_dout_; }
  void set_accumulate_dout(const AnfNodePtr &accumulate_dout) { accumulate_dout_ = accumulate_dout; }
  void ReplaceEdges();
  [[nodiscard]] const AnfNodePtr &fake_dout() const { return fake_dout_; }

 private:
  AnfNodePtr HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node);
  // Bprop func graph
  const KernelGraphPtr tape_;
  // Input of dout for this bprop function
  AnfNodePtr accumulate_dout_;
  // First we generate a fake dout
  const AnfNodePtr fake_dout_;
  // The pair.first is a variable, pair.second is the dout of the variable
  std::vector<std::pair<std::shared_ptr<Variable>, AnfNodePtr>> next_edges_;
  // Replace the next_edges where din == dout in the bprop function
  std::vector<int> need_replace_edges_;
};
using IrFunctionNodePtr = std::shared_ptr<IrFunctionNode>;

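// A typical call sequence (a sketch of the declared interface only; `tape`, `fake_dout`,
// `next_variable`, `din` and `new_dout` are placeholders):
//   auto fn = std::make_shared<IrFunctionNode>(tape, fake_dout);
//   fn->AddNextEdge(next_variable, din);   // record a (variable, din) pair
//   fn->UpdateAccumulativeDout(new_dout);  // accumulate a newly produced dout
//   fn->ReplaceEdges();                    // replace edges whose din still refers to the fake dout
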
// Variable represents a tensor that needs grad.
class Variable {
 public:
  /// \brief Constructor.
  ///
  Variable() = default;

  /// \brief Destructor.
  ///
  virtual ~Variable() = default;

  /// \brief Constructor.
  ///
  /// \param fn, Backward function.
  /// \param is_leaf, Whether the variable is a leaf or not.
  Variable(BackwardNodePtr &&fn, bool is_leaf) : is_leaf_(is_leaf), func_node_(std::move(fn)) {}

  /// \brief Constructor.
  ///
  /// \param fn, IrFunctionNodePtr function.
  /// \param is_leaf, Whether the variable is a leaf or not.
  Variable(IrFunctionNodePtr &&fn, ValuePtr &&out_value, bool is_leaf)
      : is_leaf_(is_leaf), out_value_(std::move(out_value)), ir_function_node_(std::move(fn)) {}

  /// \brief Backward function.
  ///
  /// \return fn
  BackwardNodePtr func_node() const { return func_node_; }

  /// \brief IrFunctionNode function.
  ///
  /// \return fn for ir
  IrFunctionNodePtr ir_function_node() const { return ir_function_node_; }

  /// \brief Name of the fake op.
  ///
  /// \return fake_prim_name
  const string &fake_prim_name() const { return fake_prim_name_; }

  /// \brief Set the name of the fake op.
  ///
  /// \param fake_prim_name
  void set_fake_prim_name(const string &fake_prim_name) { fake_prim_name_ = fake_prim_name; }

  /// \brief Flag to judge whether the op is a fake op.
  ///
  bool is_fake_bprop() const { return is_fake_bprop_; }

  /// \brief Set fake bprop.
  ///
  /// \param is_fake_bprop
  void set_is_fake_bprop(bool is_fake_bprop) { is_fake_bprop_ = is_fake_bprop; }

  /// \brief Flag to judge whether the variable needs to propagate.
  ///
  /// \return True if the variable needs to propagate, false if not.
  bool is_need_propagate() const { return is_need_propagate_; }

  /// \brief Set need propagate.
  ///
  void set_is_need_propagate(bool is_need_grad) { is_need_propagate_ = is_need_grad; }

  /// \brief Flag to judge whether the variable needs grad.
  ///
  /// \return is need grad
  bool is_need_grad() const { return is_need_grad_; }

  /// \brief Set need grad.
  ///
  /// \param is_need_grad
  void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }

  /// \brief Judge whether the variable is a leaf node.
  ///
  /// \return True if the variable is a leaf, false if not.
  bool is_leaf() const { return is_leaf_; }

  /// \brief Get the forward output value.
  ///
  /// \return valueptr.
  ValuePtr out_value() const { return out_value_; }

  /// \brief Debug info.
  ///
  /// \return debug info.
  virtual std::string ToString() const { return {}; }

  /// \brief Release input and output tensors.
  ///
  /// \return void
  void Release() {
    MS_EXCEPTION_IF_NULL(func_node_);
    func_node_->Release();
  }

 private:
  // If a node has no bprop, we record its prim name
  std::string fake_prim_name_;
  // Record that this node is a fake bprop
  bool is_fake_bprop_{false};
  // Flag to judge whether propagation is needed
  bool is_need_propagate_{false};
  // Flag to judge whether the variable needs grad
  bool is_need_grad_{true};
  // Flag that the variable is a leaf in bprop.
  bool is_leaf_{false};
  ValuePtr out_value_{nullptr};
  // Abstract bprop function
  BackwardNodePtr func_node_{nullptr};
  // IR bprop function node
  IrFunctionNodePtr ir_function_node_{nullptr};
};
using VariablePtr = std::shared_ptr<Variable>;

// FuncVariable represents a parameter or an output of an op.
class FuncVariable : public Variable {
 public:
  /// \brief Constructor.
  ///
  FuncVariable() = default;
  ~FuncVariable() override = default;
  /// \brief Constructor.
  ///
  /// \param fn, Backward function.
  /// \param is_leaf, Whether the variable is a leaf or not.
  explicit FuncVariable(BackwardNodePtr fn, bool is_leaf = false) : Variable(std::move(fn), is_leaf) {}

  /// \brief Gradient of the variable if the variable is a leaf node, nullptr if it is not a leaf node.
  ///
  const tensor::BaseTensorPtr &grad() const { return grad_; }

  /// \brief Set gradients of the leaf variable.
  ///
  /// \param grad
  void set_grad(const tensor::BaseTensorPtr &grad) { grad_ = grad; }

  std::string ToString() const override;

 private:
  // Grad for this variable; only leaf nodes have grad.
  tensor::BaseTensorPtr grad_;
};
using FuncVariablePtr = std::shared_ptr<FuncVariable>;

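// A minimal usage sketch (hypothetical; `node` is a BackwardNodePtr and `grad_tensor`
// a tensor::BaseTensorPtr, both placeholders):
//   auto leaf = std::make_shared<FuncVariable>(node, /*is_leaf=*/true);
//   leaf->set_grad(grad_tensor);  // only leaf variables hold an accumulated grad
//   const tensor::BaseTensorPtr &g = leaf->grad();
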
// IrVariable represents a parameter or an output of an intermediate cnode.
class IrVariable : public Variable {
 public:
  IrVariable() = default;
  ~IrVariable() override = default;

  IrVariable(IrFunctionNodePtr fn, ValuePtr out_value, bool is_leaf = false)
      : Variable(std::move(fn), std::move(out_value), is_leaf) {}

  AnfNodePtr k_node() const { return k_node_; }
  void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
  AnfNodePtr RealDout();
  std::string ToString() const override;

 private:
  AnfNodePtr k_node_{nullptr};
};
using IrVariablePtr = std::shared_ptr<IrVariable>;

template <typename T>
bool isa(const BackwardNodePtr &base_ptr) {
  const auto &object = (*base_ptr);
  return typeid(object) == typeid(T);
}
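
// Usage sketch (hypothetical; assumes SomeBackwardNode is a concrete BackwardNode subclass
// and `node` is a BackwardNodePtr):
//   if (isa<SomeBackwardNode>(node)) { /* node's dynamic type is exactly SomeBackwardNode */ }
// Note that the typeid comparison matches the dynamic type exactly; subclasses of T do not match.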
}  // namespace mindspore::pynative::autograd

#endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_VARIABLE_H_