/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_PI_JIT_FUNCTION_NODE_H_
#define MINDSPORE_PI_JIT_FUNCTION_NODE_H_

#include <atomic>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include "pipeline/jit/pi/auto_grad/backward_function.h"
#include "pipeline/jit/pi/auto_grad/edge.h"
#include "pipeline/jit/pi/auto_grad/function_context.h"
#include "pipeline/jit/pi/auto_grad/native_backward_function.h"
#include "pipeline/pynative/pynative_utils.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace pijit {
namespace grad {
namespace py = pybind11;
using Convert = pynative::PyNativeAlgo::DataConvert;

/// \brief FunctionNode is a class that represents a way to calculate the gradient.
class FunctionNode : public FunctionContext {
 public:
  /// \brief The constructor of FunctionNode.
  ///
  /// \param[in] tensor The tensor that is asked to calculate the gradient.
  ///
  /// \return The instance of FunctionNode.
  explicit FunctionNode(const py::object &tensor) : FunctionContext(Convert::PyObjToValue(tensor)), tensor_(tensor) {}

  /// \brief The constructor of FunctionNode.
  ///
  /// \param[in] tensor The tensor that is asked to calculate the gradient.
  /// \param[in] prim The calculation that takes the tensor as input.
  /// \param[in] out The output of the calculation that takes the tensor as input.
  ///
  /// \return The instance of FunctionNode.
  explicit FunctionNode(const py::object &tensor, const py::object &prim, const py::object &out)
      : FunctionContext(Convert::PyObjToValue(prim), Convert::PyObjToValue(out)),
        tensor_(tensor),
        backward_func_(NativeBackwardFunc::GetInstance(Convert::PyObjToValue(prim)->cast<PrimitivePtr>())) {}

  /// \brief Destructor.
  virtual ~FunctionNode() = default;

  /// \brief Release all resources.
  void CleanResource();

  /// \brief Determine whether the python object has the attribute `requires_grad`.
  ///
  /// \param[in] obj The python object.
  ///
  /// \return Whether the python object has the attribute `requires_grad`.
  static bool HasAttrReqGrad(const py::handle &obj) { return py::hasattr(obj, "requires_grad"); }

  /// \brief Determine whether the python object has the attribute `requires_grad` and its value is True.
  ///
  /// \param[in] obj The python object.
  ///
  /// \return Whether the python object's attribute `requires_grad` is True.
  static bool IsRequiresGradient(const py::handle &obj);

  /// \brief Determine whether the python object has the attribute `grad_fn`.
  ///
  /// \param[in] obj The python object.
  ///
  /// \return Whether the python object has the attribute `grad_fn`.
  static bool HasGradFunc(const py::handle &obj);

  /// \brief Create a new function node.
  ///
  /// \param[in] tensor The tensor mounted by the function node.
  /// \param[in] prim The forward execution function.
  /// \param[in] out The output of the forward execution function.
  /// \param[in] inputs The inputs of the forward execution function.
  ///
  /// \return The instance of the function node.
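  ///
  /// \note A minimal usage sketch, assuming `tensor`, `prim`, `out` and `inputs`
  /// are the objects produced by a just-finished forward call (illustrative
  /// names, not defined in this header):
  /// \code
  ///   FunctionNodePtr node = FunctionNode::CreateFunctionNode(tensor, prim, out, inputs);
  /// \endcode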
  static FunctionNodePtr CreateFunctionNode(const py::object &tensor, const py::object &prim, const py::object &out,
                                            const py::list &inputs);

  /// \brief The static method to record the executed primitive during forward execution.
  ///
  /// \param[in] prim The executed primitive.
  /// \param[in] out The output of the executed primitive.
  /// \param[in] inputs The inputs of the executed primitive.
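  ///
  /// \note A minimal sketch of how a forward hook could record a primitive,
  /// assuming `prim`, `out` and `inputs` come from the forward call that just
  /// finished (illustrative names, not defined in this header):
  /// \code
  ///   FunctionNode::RecordPrimitive(prim, out, inputs);
  /// \endcode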
  static void RecordPrimitive(const py::object &prim, const py::object &out, const py::list &inputs);

  /// \brief Get the tensor that is asked to calculate the gradient.
  ///
  /// \return The tensor that is asked to calculate the gradient.
  const py::object &GetTensor() const { return tensor_; }

  /// \brief Set the inputs of the function node.
  ///
  /// \param[in] inputs The inputs.
  void SetInputs(const py::list &inputs);

  /// \brief Get the bprop function graph.
  ///
  /// \return The bprop function graph.
  const FuncGraphPtr &GetBpropFunction() const { return grad_fn_; }

  /// \brief Generate the bprop function.
  void GenerateBropFunction();

  /// \brief Start gradient calculation.
  void ApplyNative();

  /// \brief Get the called functions in the previous/next step.
  ///
  /// \return The called functions in the previous/next step.
  const EdgePtrList &GetNextEdges() const { return edges_; }

  /// \brief Set the called functions in the previous/next step.
  ///
  /// \param[in] edges The called functions.
  void SetNextEdges(const EdgePtrList &edges) { edges_ = edges; }

  /// \brief Add a called function in the previous/next step.
  ///
  /// \param[in] node The called function.
  /// \param[in] index The index of the input.
  void AddNextEdge(const FunctionNodePtr &node, size_t index) {
    edges_.push_back(std::make_shared<Edge>(node, index));
    node->dependences_.insert(shared_from_base<FunctionNode>());
  }

  /// \brief Synchronize gradient value to python object.
  void SyncGradToPyObject();

  /// \brief Generate the grad value of the function.
  ///
  /// \param[in] grad The default gradient value of the function node.
  ///
  /// \note This function node must belong to the tensor whose `backward` is called from python.
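  ///
  /// \note A minimal sketch of the intended entry point, assuming `node` is the
  /// function node of the tensor on which `backward` was invoked and `grad` is
  /// the seed gradient passed from Python (illustrative names only):
  /// \code
  ///   node->Apply(grad);
  ///   // Accumulated gradients are written back to the Python tensors via
  ///   // SyncGradToPyObject(); whether Apply triggers this itself is not
  ///   // specified in this header.
  /// \endcode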
  void Apply(const py::object &grad);

  /// \brief Generate the description of the tree nodes.
  std::string ToString() const;

 private:
  /// \brief Generate the grad value of the function.
  ///
  /// \param[in] dout The gradient of the output.
  void ApplyInner(const ValuePtr &dout);

  /// \brief Calculate the gradient of the next layer of function nodes.
  ///
  /// \param[in] grad_values The gradient values.
  void ApplyEdges(const ValuePtrList &grad_values);

  /// \brief Update data dependencies.
  void UpdateDependence();

  /// \brief Notify the function node that the gradient data is ready.
  ///
  /// \param[in] node The function node being notified.
  /// \param[in] dout The gradient data.
  void Notify(const FunctionNodePtr &node, const ValuePtr &dout);

  /// \brief Accumulate the delta of the gradient.
  ///
  /// \param[in] dout The delta of the gradient.
  /// \param[in] index The index of the gradient.
  void AccumulateGradient(const ValuePtr &dout, size_t index);

  /// \brief Determine whether the current function node can start gradient calculation.
  bool IsReady() const { return depend_cnt_.load() == dependences_.size(); }

  /// \brief Dump the function node and its children.
  ///
  /// \param[in] ss The string stream.
  /// \param[in] prefix The prefix string for this node.
  void Dump(std::stringstream &ss, const std::string &prefix) const;

  /// \brief The tensor mounted by the function node.
  py::object tensor_;
  /// \brief The function used to calculate the gradient.
  BackwardFuncPtr backward_func_;
  /// \brief The bprop function.
  FuncGraphPtr grad_fn_;
  /// \brief The accumulate function.
  FuncGraphPtr acc_fn_;
  /// \brief The called functions in the previous/next step.
  EdgePtrList edges_;
  /// \brief The mutex for accumulating the delta of the gradient.
  std::mutex mutex_;
  /// \brief Used to locate the position of the tensor in multiple outputs.
  size_t index_{0};
  /// \brief Mark whether the current node is used in the reverse calculation.
  bool is_in_reverse_chain_{false};
  /// \brief Dependency data of the current node in gradient calculation.
  std::set<FunctionNodePtr> dependences_;
  /// \brief The dependency status.
  std::atomic<size_t> depend_cnt_{0};
};

using FunctionNodePtr = std::shared_ptr<FunctionNode>;
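
// A hedged end-to-end sketch of how these pieces fit together; `prim`, `out`,
// `inputs`, `tensor`, `node` and `grad` below are illustrative names, not
// declarations from this header:
//
//   // 1. During forward execution, each executed primitive is recorded so the
//   //    backward graph can be built.
//   FunctionNode::RecordPrimitive(prim, out, inputs);
//
//   // 2. When `tensor.backward(grad)` is triggered from Python, the function
//   //    node of that tensor applies the seed gradient and propagates it along
//   //    its next edges.
//   node->Apply(grad);
//
//   // 3. The accumulated gradients are synchronized back to the Python tensors
//   //    via SyncGradToPyObject().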
}  // namespace grad
}  // namespace pijit
}  // namespace mindspore
#endif  // MINDSPORE_PI_JIT_FUNCTION_NODE_H_