• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PI_JIT_BACKWARD_FUNCTION_H_
18 #define MINDSPORE_PI_JIT_BACKWARD_FUNCTION_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include "ir/anf.h"
24 
25 namespace mindspore {
26 namespace pijit {
27 namespace grad {
28 /// \brief BackwardFunc is a class, which represent a function to calculate the gradient.
29 class BackwardFunc {
30  public:
31   /// \brief The constructor of BackwardFunc.
32   ///
33   /// \param[in] name The name of this backward function.
34   ///
35   /// \return The instance of BackwardFunc.
BackwardFunc(const std::string & name)36   explicit BackwardFunc(const std::string &name) : name_(name) {}
37 
38   /// \brief Destructor.
39   virtual ~BackwardFunc() = default;
40 
41   /// \brief Get the name of the backward function.
42   ///
43   /// \return The name of this backward function.
GetName()44   const std::string &GetName() const { return name_; }
45 
46   /// \brief Start calculate the gradient of the backward function.
47   ///
48   /// \param[in] inputs The arguments of the forward execution.
49   /// \param[in] out The output of the forward execution.
50   /// \param[in] dout The dout of the output.
51   ///
52   /// \return The gradients of the inputs of forward execution.
53   virtual ValuePtrList Run(const ValuePtrList &inputs, const ValuePtr &out, const ValuePtr &dout) = 0;
54 
55   /// \brief Postprocess gradients from func to align with next_edges.
56   ///
57   /// \param[in] gradient_value Gradients value is gradients result from func which need postprocess.
58   ///
59   /// \return Real gradients after postprocess, the size is same as next edges size.
PostProcess(const ValuePtrList & gradient_value)60   virtual ValuePtrList PostProcess(const ValuePtrList &gradient_value) { return gradient_value; }
61 
62   /// \brief Get indexes of inputs required to calculate the gradient.
63   ///
64   /// \return The indexes of inputs required to calculate the gradient.
GetGradientIndexes()65   const std::vector<size_t> &GetGradientIndexes() const { return gradient_indexes_; }
66 
67   /// \brief Set the indexes of forward's inputs required to calculate the gradient.
68   ///
69   /// \param[in] indexes The indexes of inputs required to calculate the gradient.
SetGradientIndexes(const std::vector<size_t> & indexes)70   void SetGradientIndexes(const std::vector<size_t> &indexes) { gradient_indexes_ = indexes; }
71 
72   /// \brief Add a index of forward's input required to calculate the gradient.
73   ///
74   /// \param[in] index The index of forward's input required to calculate the gradient.
AddGradientIndex(size_t index)75   void AddGradientIndex(size_t index) { gradient_indexes_.push_back(index); }
76 
77   /// \brief Create the value filled with one, shape like the input.
78   ///
79   /// \param[in] value The input value.
80   ///
81   /// \return The value filled with one.
82   virtual ValuePtr Ones(const ValuePtr &value) const = 0;
83 
84   /// \brief Create the value filled with zero, shape like the input.
85   ///
86   /// \param[in] value The input value.
87   ///
88   /// \return The value filled with zero.
89   virtual ValuePtr Zeros(const ValuePtr &value) const = 0;
90 
91   /// \brief Calculate the sum of inputs.
92   ///
93   /// \param[in] input The first input value.
94   /// \param[in] other The second input value.
95   ///
96   /// \return The sum of inputs.
97   virtual ValuePtr Add(const ValuePtr &input, const ValuePtr &other) const = 0;
98 
99  private:
100   /// \brief The name of this backward function.
101   std::string name_;
102   /// \brief The index set of the inputs required to calculate the gradient.
103   std::vector<size_t> gradient_indexes_;
104 };
105 
106 using BackwardFuncPtr = std::shared_ptr<BackwardFunc>;
107 }  // namespace grad
108 }  // namespace pijit
109 }  // namespace mindspore
110 #endif  // MINDSPORE_PI_JIT_BACKWARD_FUNCTION_H_
111