• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_BPROP_FUNC_CONTEXT_H_
17 #define MINDSPORE_PI_JIT_BPROP_FUNC_CONTEXT_H_
18 
19 #include <memory>
20 #include <vector>
21 #include "ir/value.h"
22 
23 namespace mindspore {
24 namespace pijit {
25 namespace grad {
26 /// \brief FunctionContext is a class, which include the inputs and output of FunctionNode.
27 class FunctionContext : public std::enable_shared_from_this<FunctionContext> {
28  public:
29   /// \brief The constructor of FunctionContext.
30   ///
31   /// \param[in] output The output of FunctionContext.
32   ///
33   /// \return The instance of FunctionContext.
FunctionContext(const ValuePtr & output)34   explicit FunctionContext(const ValuePtr &output) : FunctionContext(kNone, output) {}
35 
36   /// \brief The constructor of FunctionContext.
37   ///
38   /// \param[in] fn The function of FunctionContext.
39   /// \param[in] output The output of FunctionContext.
40   ///
41   /// \return The instance of FunctionContext.
FunctionContext(const ValuePtr & fn,const ValuePtr & output)42   explicit FunctionContext(const ValuePtr &fn, const ValuePtr &output)
43       : FunctionContext(fn, output,
44                         ValuePtrList(output->isa<ValueTuple>() ? output->cast<ValueTuplePtr>()->size() : 1, kNone)) {}
45 
46   /// \brief The constructor of FunctionContext.
47   ///
48   /// \param[in] fn The function of FunctionContext.
49   /// \param[in] output The output of FunctionContext.
50   /// \param[in] dout The gradient value list of FunctionContext.
51   ///
52   /// \return The instance of FunctionContext.
FunctionContext(const ValuePtr & fn,const ValuePtr & output,const ValuePtrList & dout)53   explicit FunctionContext(const ValuePtr &fn, const ValuePtr &output, const ValuePtrList &dout)
54       : fn_(fn), inputs_({}), output_(output), dout_(dout) {}
55 
56   /// \brief Destructor.
57   virtual ~FunctionContext() = default;
58 
59   /// \brief Get the function of the function node.
60   ///
61   /// \return The function of the function node.
GetFunction()62   const ValuePtr &GetFunction() const { return fn_; }
63 
64   /// \brief Set the function of the function node.
65   ///
66   /// \param[in] fn The function.
SetFunction(const ValuePtr & fn)67   void SetFunction(const ValuePtr &fn) { fn_ = fn; }
68 
69   /// \brief Get the inputs of the function node.
70   ///
71   /// \return The inputs of the function node.
GetInputs()72   const ValuePtrList &GetInputs() const { return inputs_; }
73 
74   /// \brief Remove all inputs of the function node.
RemoveInputs()75   void RemoveInputs() { inputs_.clear(); }
76 
77   /// \brief Set the inputs of the function node.
78   ///
79   /// \param[in] inputs The inputs.
SetInputs(const ValuePtrList & inputs)80   void SetInputs(const ValuePtrList &inputs) { inputs_ = inputs; }
81 
82   /// \brief Add a input at the end of the input list.
83   ///
84   /// \param[in] input The input.
AddInput(const ValuePtr & input)85   void AddInput(const ValuePtr &input) { inputs_.push_back(input); }
86 
87   /// \brief Get the output of the function node.
88   ///
89   /// \return The output of the function node.
GetOutput()90   const ValuePtr &GetOutput() const { return output_; }
91 
92   /// \brief Set the output of the function node.
93   ///
94   /// \param[in] output The output.
SetOutput(const ValuePtr & output)95   void SetOutput(const ValuePtr &output) { output_ = output; }
96 
97   /// \brief Get the grad value list of the function node.
98   ///
99   /// \return The grad value list of the function node.
GetGrad()100   const ValuePtrList &GetGrad() const { return dout_; }
101 
102   /// \brief Set the grad value list of the function node.
103   ///
104   /// \param[in] grads The grad value list.
SetGrad(const ValuePtrList & grads)105   void SetGrad(const ValuePtrList &grads) { dout_ = grads; }
106 
107   /// \brief Set the grad value list of the function node.
108   ///
109   /// \param[in] grad The grad value list.
110   /// \param[in] index The index of the tensor in input.
SetGrad(const ValuePtr & grad,size_t index)111   void SetGrad(const ValuePtr &grad, size_t index) { dout_[index] = grad; }
112 
113   /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
114   ///     Similar to shared_from_this, except this one will give you the derived class as shared_ptr
115   /// \return A shared_ptr casted to the derived class
116   template <typename Derived>
shared_from_base()117   std::shared_ptr<Derived> shared_from_base() {
118     return std::static_pointer_cast<Derived>(shared_from_this());
119   }
120 
121  private:
122   /// \brief The function of the function node.
123   ValuePtr fn_;
124   /// \brief The input list of the function node.
125   ValuePtrList inputs_;
126   /// \brief The output of the function node.
127   ValuePtr output_;
128   /// \brief The delta out of the function node.
129   ValuePtrList dout_;
130 };
131 
132 using FuncCtxPtr = std::shared_ptr<FunctionContext>;
133 }  // namespace grad
134 }  // namespace pijit
135 }  // namespace mindspore
136 #endif  // MINDSPORE_PI_JIT_BPROP_FUNC_CONTEXT_H_
137