1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PI_JIT_BACKWARD_FUNCTION_H_ 18 #define MINDSPORE_PI_JIT_BACKWARD_FUNCTION_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include "ir/anf.h" 24 25 namespace mindspore { 26 namespace pijit { 27 namespace grad { 28 /// \brief BackwardFunc is a class, which represent a function to calculate the gradient. 29 class BackwardFunc { 30 public: 31 /// \brief The constructor of BackwardFunc. 32 /// 33 /// \param[in] name The name of this backward function. 34 /// 35 /// \return The instance of BackwardFunc. BackwardFunc(const std::string & name)36 explicit BackwardFunc(const std::string &name) : name_(name) {} 37 38 /// \brief Destructor. 39 virtual ~BackwardFunc() = default; 40 41 /// \brief Get the name of the backward function. 42 /// 43 /// \return The name of this backward function. GetName()44 const std::string &GetName() const { return name_; } 45 46 /// \brief Start calculate the gradient of the backward function. 47 /// 48 /// \param[in] inputs The arguments of the forward execution. 49 /// \param[in] out The output of the forward execution. 50 /// \param[in] dout The dout of the output. 51 /// 52 /// \return The gradients of the inputs of forward execution. 53 virtual ValuePtrList Run(const ValuePtrList &inputs, const ValuePtr &out, const ValuePtr &dout) = 0; 54 55 /// \brief Postprocess gradients from func to align with next_edges. 56 /// 57 /// \param[in] gradient_value Gradients value is gradients result from func which need postprocess. 58 /// 59 /// \return Real gradients after postprocess, the size is same as next edges size. PostProcess(const ValuePtrList & gradient_value)60 virtual ValuePtrList PostProcess(const ValuePtrList &gradient_value) { return gradient_value; } 61 62 /// \brief Get indexes of inputs required to calculate the gradient. 63 /// 64 /// \return The indexes of inputs required to calculate the gradient. GetGradientIndexes()65 const std::vector<size_t> &GetGradientIndexes() const { return gradient_indexes_; } 66 67 /// \brief Set the indexes of forward's inputs required to calculate the gradient. 68 /// 69 /// \param[in] indexes The indexes of inputs required to calculate the gradient. SetGradientIndexes(const std::vector<size_t> & indexes)70 void SetGradientIndexes(const std::vector<size_t> &indexes) { gradient_indexes_ = indexes; } 71 72 /// \brief Add a index of forward's input required to calculate the gradient. 73 /// 74 /// \param[in] index The index of forward's input required to calculate the gradient. AddGradientIndex(size_t index)75 void AddGradientIndex(size_t index) { gradient_indexes_.push_back(index); } 76 77 /// \brief Create the value filled with one, shape like the input. 78 /// 79 /// \param[in] value The input value. 80 /// 81 /// \return The value filled with one. 82 virtual ValuePtr Ones(const ValuePtr &value) const = 0; 83 84 /// \brief Create the value filled with zero, shape like the input. 85 /// 86 /// \param[in] value The input value. 87 /// 88 /// \return The value filled with zero. 89 virtual ValuePtr Zeros(const ValuePtr &value) const = 0; 90 91 /// \brief Calculate the sum of inputs. 92 /// 93 /// \param[in] input The first input value. 94 /// \param[in] other The second input value. 95 /// 96 /// \return The sum of inputs. 97 virtual ValuePtr Add(const ValuePtr &input, const ValuePtr &other) const = 0; 98 99 private: 100 /// \brief The name of this backward function. 101 std::string name_; 102 /// \brief The index set of the inputs required to calculate the gradient. 103 std::vector<size_t> gradient_indexes_; 104 }; 105 106 using BackwardFuncPtr = std::shared_ptr<BackwardFunc>; 107 } // namespace grad 108 } // namespace pijit 109 } // namespace mindspore 110 #endif // MINDSPORE_PI_JIT_BACKWARD_FUNCTION_H_ 111