1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_BPROP_FUNC_CONTEXT_H_ 17 #define MINDSPORE_PI_JIT_BPROP_FUNC_CONTEXT_H_ 18 19 #include <memory> 20 #include <vector> 21 #include "ir/value.h" 22 23 namespace mindspore { 24 namespace pijit { 25 namespace grad { 26 /// \brief FunctionContext is a class, which include the inputs and output of FunctionNode. 27 class FunctionContext : public std::enable_shared_from_this<FunctionContext> { 28 public: 29 /// \brief The constructor of FunctionContext. 30 /// 31 /// \param[in] output The output of FunctionContext. 32 /// 33 /// \return The instance of FunctionContext. FunctionContext(const ValuePtr & output)34 explicit FunctionContext(const ValuePtr &output) : FunctionContext(kNone, output) {} 35 36 /// \brief The constructor of FunctionContext. 37 /// 38 /// \param[in] fn The function of FunctionContext. 39 /// \param[in] output The output of FunctionContext. 40 /// 41 /// \return The instance of FunctionContext. FunctionContext(const ValuePtr & fn,const ValuePtr & output)42 explicit FunctionContext(const ValuePtr &fn, const ValuePtr &output) 43 : FunctionContext(fn, output, 44 ValuePtrList(output->isa<ValueTuple>() ? output->cast<ValueTuplePtr>()->size() : 1, kNone)) {} 45 46 /// \brief The constructor of FunctionContext. 47 /// 48 /// \param[in] fn The function of FunctionContext. 49 /// \param[in] output The output of FunctionContext. 50 /// \param[in] dout The gradient value list of FunctionContext. 51 /// 52 /// \return The instance of FunctionContext. FunctionContext(const ValuePtr & fn,const ValuePtr & output,const ValuePtrList & dout)53 explicit FunctionContext(const ValuePtr &fn, const ValuePtr &output, const ValuePtrList &dout) 54 : fn_(fn), inputs_({}), output_(output), dout_(dout) {} 55 56 /// \brief Destructor. 57 virtual ~FunctionContext() = default; 58 59 /// \brief Get the function of the function node. 60 /// 61 /// \return The function of the function node. GetFunction()62 const ValuePtr &GetFunction() const { return fn_; } 63 64 /// \brief Set the function of the function node. 65 /// 66 /// \param[in] fn The function. SetFunction(const ValuePtr & fn)67 void SetFunction(const ValuePtr &fn) { fn_ = fn; } 68 69 /// \brief Get the inputs of the function node. 70 /// 71 /// \return The inputs of the function node. GetInputs()72 const ValuePtrList &GetInputs() const { return inputs_; } 73 74 /// \brief Remove all inputs of the function node. RemoveInputs()75 void RemoveInputs() { inputs_.clear(); } 76 77 /// \brief Set the inputs of the function node. 78 /// 79 /// \param[in] inputs The inputs. SetInputs(const ValuePtrList & inputs)80 void SetInputs(const ValuePtrList &inputs) { inputs_ = inputs; } 81 82 /// \brief Add a input at the end of the input list. 83 /// 84 /// \param[in] input The input. AddInput(const ValuePtr & input)85 void AddInput(const ValuePtr &input) { inputs_.push_back(input); } 86 87 /// \brief Get the output of the function node. 88 /// 89 /// \return The output of the function node. GetOutput()90 const ValuePtr &GetOutput() const { return output_; } 91 92 /// \brief Set the output of the function node. 93 /// 94 /// \param[in] output The output. SetOutput(const ValuePtr & output)95 void SetOutput(const ValuePtr &output) { output_ = output; } 96 97 /// \brief Get the grad value list of the function node. 98 /// 99 /// \return The grad value list of the function node. GetGrad()100 const ValuePtrList &GetGrad() const { return dout_; } 101 102 /// \brief Set the grad value list of the function node. 103 /// 104 /// \param[in] grads The grad value list. SetGrad(const ValuePtrList & grads)105 void SetGrad(const ValuePtrList &grads) { dout_ = grads; } 106 107 /// \brief Set the grad value list of the function node. 108 /// 109 /// \param[in] grad The grad value list. 110 /// \param[in] index The index of the tensor in input. SetGrad(const ValuePtr & grad,size_t index)111 void SetGrad(const ValuePtr &grad, size_t index) { dout_[index] = grad; } 112 113 /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived> 114 /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr 115 /// \return A shared_ptr casted to the derived class 116 template <typename Derived> shared_from_base()117 std::shared_ptr<Derived> shared_from_base() { 118 return std::static_pointer_cast<Derived>(shared_from_this()); 119 } 120 121 private: 122 /// \brief The function of the function node. 123 ValuePtr fn_; 124 /// \brief The input list of the function node. 125 ValuePtrList inputs_; 126 /// \brief The output of the function node. 127 ValuePtr output_; 128 /// \brief The delta out of the function node. 129 ValuePtrList dout_; 130 }; 131 132 using FuncCtxPtr = std::shared_ptr<FunctionContext>; 133 } // namespace grad 134 } // namespace pijit 135 } // namespace mindspore 136 #endif // MINDSPORE_PI_JIT_BPROP_FUNC_CONTEXT_H_ 137