/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_PI_JIT_FUNCTION_NODE_H_
#define MINDSPORE_PI_JIT_FUNCTION_NODE_H_

#include <atomic>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include "pipeline/jit/pi/auto_grad/backward_function.h"
#include "pipeline/jit/pi/auto_grad/edge.h"
#include "pipeline/jit/pi/auto_grad/function_context.h"
#include "pipeline/jit/pi/auto_grad/native_backward_function.h"
#include "pipeline/pynative/pynative_utils.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace pijit {
namespace grad {
namespace py = pybind11;
using Convert = pynative::PyNativeAlgo::DataConvert;

/// \brief FunctionNode is a class that represents a way to calculate the gradient.
class FunctionNode : public FunctionContext {
 public:
  /// \brief The constructor of FunctionNode.
  ///
  /// \param[in] tensor The tensor that is asked to calculate the gradient.
  ///
  /// \return The instance of FunctionNode.
  explicit FunctionNode(const py::object &tensor) : FunctionContext(Convert::PyObjToValue(tensor)), tensor_(tensor) {}

  /// \brief The constructor of FunctionNode.
  ///
  /// \param[in] tensor The tensor that is asked to calculate the gradient.
  /// \param[in] prim The calculation that takes the tensor as input.
  /// \param[in] out The output of the calculation that takes the tensor as input.
  ///
  /// \return The instance of FunctionNode.
  explicit FunctionNode(const py::object &tensor, const py::object &prim, const py::object &out)
      : FunctionContext(Convert::PyObjToValue(prim), Convert::PyObjToValue(out)),
        tensor_(tensor),
        backward_func_(NativeBackwardFunc::GetInstance(Convert::PyObjToValue(prim)->cast<PrimitivePtr>())) {}

  /// \brief Destructor.
  virtual ~FunctionNode() = default;

  /// \brief Release all resources.
  void CleanResource();

  /// \brief Determine whether the python object has the attribute `requires_grad`.
  ///
  /// \param[in] obj The python object.
  ///
  /// \return Whether the python object has the attribute `requires_grad`.
  static bool HasAttrReqGrad(const py::handle &obj) { return py::hasattr(obj, "requires_grad"); }

  /// \brief Determine whether the python object has the attribute `requires_grad` and its value is True.
  ///
  /// \param[in] obj The python object.
  ///
  /// \return Whether the python object's attribute `requires_grad` is True.
  static bool IsRequiresGradient(const py::handle &obj);
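
  // Illustrative sketch only (not part of the runtime code): how the two `requires_grad`
  // predicates above differ. `obj` stands for any py::handle received from the python bindings.
  //
  //   if (FunctionNode::IsRequiresGradient(obj)) {
  //     // `requires_grad` exists and is True: the object takes part in gradient calculation.
  //   } else if (FunctionNode::HasAttrReqGrad(obj)) {
  //     // `requires_grad` exists but is not True: the object is skipped.
  //   }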

  /// \brief Determine whether the python object has the attribute `grad_fn`.
  ///
  /// \param[in] obj The python object.
  ///
  /// \return Whether the python object has the attribute `grad_fn`.
  static bool HasGradFunc(const py::handle &obj);

  /// \brief Create a new function node.
  ///
  /// \param[in] tensor The tensor mounted by the function node.
  /// \param[in] prim The forward execution function.
  /// \param[in] out The output of the forward execution function.
  /// \param[in] inputs The inputs of the forward execution function.
  ///
  /// \return The instance of the function node.
  static FunctionNodePtr CreateFunctionNode(const py::object &tensor, const py::object &prim, const py::object &out,
                                            const py::list &inputs);

  /// \brief The static method to record the executed primitive during forward execution.
  ///
  /// \param[in] prim The executed primitive.
  /// \param[in] out The output of the executed primitive.
  /// \param[in] inputs The inputs of the executed primitive.
  static void RecordPrimitive(const py::object &prim, const py::object &out, const py::list &inputs);

  /// \brief Get the tensor that is asked to calculate the gradient.
  ///
  /// \return The tensor that is asked to calculate the gradient.
  const py::object &GetTensor() const { return tensor_; }

  /// \brief Set the inputs of the function node.
  ///
  /// \param[in] inputs The inputs.
  void SetInputs(const py::list &inputs);

  /// \brief Get the bprop function graph.
  ///
  /// \return The bprop function graph.
  const FuncGraphPtr &GetBpropFunction() const { return grad_fn_; }

  /// \brief Generate the bprop function.
  void GenerateBropFunction();

  /// \brief Start gradient calculation.
  void ApplyNative();

  /// \brief Get the called functions in the previous/next step.
  ///
  /// \return The called functions in the previous/next step.
  const EdgePtrList &GetNextEdges() const { return edges_; }

  /// \brief Set the called functions in the previous/next step.
  ///
  /// \param[in] edges The called functions.
  void SetNextEdges(const EdgePtrList &edges) { edges_ = edges; }

  /// \brief Add a called function in the previous/next step.
  ///
  /// \param[in] node The called function.
  /// \param[in] index The index of the input.
  void AddNextEdge(const FunctionNodePtr &node, size_t index) {
    edges_.push_back(std::make_shared<Edge>(node, index));
    node->dependences_.insert(shared_from_base<FunctionNode>());
  }

  /// \brief Synchronize the gradient value to the python object.
  void SyncGradToPyObject();

  /// \brief Generate the grad value of the function.
  ///
  /// \param[in] grad The default gradient value of the function node.
  ///
  /// \note This function node must belong to the tensor whose `backward` is called from python.
  void Apply(const py::object &grad);

  /// \brief Generate the description of the tree nodes.
  std::string ToString() const;
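
  // A minimal sketch of the intended call flow, inferred from the declarations above (the real
  // driver lives in the pijit runtime and is not shown here):
  //
  //   // Forward pass: record every executed primitive so the reverse graph can be built.
  //   FunctionNode::RecordPrimitive(prim, out, inputs);
  //
  //   // Backward pass, starting from the function node of the tensor whose backward() was
  //   // called in python:
  //   node->Apply(grad);            // propagate the gradient along the recorded edges
  //   node->SyncGradToPyObject();   // write the accumulated gradient back to the python tensor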

 private:
  /// \brief Generate the grad value of the function.
  ///
  /// \param[in] dout The gradient of the output.
  void ApplyInner(const ValuePtr &dout);

  /// \brief Calculate the gradient of the next layer function node.
  ///
  /// \param[in] grad_values The gradient values.
  void ApplyEdges(const ValuePtrList &grad_values);

  /// \brief Update data dependencies.
  void UpdateDependence();

  /// \brief Notify the function node that the gradient data is ready.
  ///
  /// \param[in] node The function node being notified.
  /// \param[in] dout The gradient data.
  void Notify(const FunctionNodePtr &node, const ValuePtr &dout);

  /// \brief Accumulate the delta of the gradient.
  ///
  /// \param[in] dout The delta of the gradient.
  /// \param[in] index The index of the gradient.
  void AccumulateGradient(const ValuePtr &dout, size_t index);

  /// \brief Determine whether the current function node can start gradient calculation.
  bool IsReady() const { return depend_cnt_.load() == dependences_.size(); }

  /// \brief Dump the function node and its children.
  ///
  /// \param[in] ss The string stream.
  /// \param[in] prefix The prefix string for this node.
  void Dump(std::stringstream &ss, const std::string &prefix) const;

  /// \brief The tensor mounted by the function node.
  py::object tensor_;
  /// \brief The function used to calculate the gradient.
  BackwardFuncPtr backward_func_;
  /// \brief The bprop function.
  FuncGraphPtr grad_fn_;
  /// \brief The accumulate function.
  FuncGraphPtr acc_fn_;
  /// \brief The called functions in the previous/next step.
  EdgePtrList edges_;
  /// \brief The mutex for accumulating the delta of the gradient.
  std::mutex mutex_;
  /// \brief Used to locate the position of the tensor in multiple outputs.
  size_t index_{0};
  /// \brief Mark whether the current node is used in the reverse calculation.
  bool is_in_reverse_chain_{false};
  /// \brief Dependency data of the current node in gradient calculation.
  std::set<FunctionNodePtr> dependences_;
  /// \brief The dependency status.
  std::atomic<size_t> depend_cnt_{0};
};

using FunctionNodePtr = std::shared_ptr<FunctionNode>;
}  // namespace grad
}  // namespace pijit
}  // namespace mindspore
#endif  // MINDSPORE_PI_JIT_FUNCTION_NODE_H_