1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PI_JIT_NATIVE_BACKWARD_FUNCTION_H_ 18 #define MINDSPORE_PI_JIT_NATIVE_BACKWARD_FUNCTION_H_ 19 20 #include <memory> 21 #include <string> 22 #include "pipeline/jit/pi/auto_grad/backward_function.h" 23 #include "pipeline/pynative/grad/function/func_builder.h" 24 #include "utils/ms_context.h" 25 26 namespace mindspore { 27 namespace pijit { 28 namespace grad { 29 using FuncBuilderPtr = pynative::autograd::FuncBuilderPtr; 30 using BpropHandlePtr = const expander::bprop::BpropHandle *; 31 32 class NativeBackwardFunc; 33 using NativeBackwardFuncPtr = std::shared_ptr<NativeBackwardFunc>; 34 35 /// \brief NativeBackwardFunc is a class, which represent a function to calculate the gradient. 36 class NativeBackwardFunc : public BackwardFunc { 37 public: 38 /// \brief The constructor of NativeBackwardFunc. 39 /// 40 /// \param[in] name The name of this backward function. 41 /// 42 /// \return The instance of NativeBackwardFunc. NativeBackwardFunc(const PrimitivePtr & prim,const FuncBuilderPtr & ir_builder,const BpropHandlePtr handle)43 explicit NativeBackwardFunc(const PrimitivePtr &prim, const FuncBuilderPtr &ir_builder, const BpropHandlePtr handle) 44 : BackwardFunc(prim->name()), prim_(prim), ir_builder_(ir_builder), handle_(handle) {} 45 46 /// \brief Destructor. 47 virtual ~NativeBackwardFunc() = default; 48 49 /// \brief Create a instance of native backward function. 50 /// 51 /// \param[in] prim The primitive of the forward execution. 52 /// 53 /// \return The instance of native backward function. 54 static NativeBackwardFuncPtr GetInstance(const PrimitivePtr &prim); 55 56 /// \brief Start calculate the gradient of the backward function. 57 /// 58 /// \param[in] inputs The arguments of the forward execution. 59 /// \param[in] out The output of the forward execution. 60 /// \param[in] dout The dout of the output. 61 /// 62 /// \return The gradients of the inputs of forward execution. 63 ValuePtrList Run(const ValuePtrList &inputs, const ValuePtr &out, const ValuePtr &dout) override; 64 65 /// \brief Postprocess gradients from func to align with next_edges. 66 /// 67 /// \param[in] gradient_value Gradients value is gradients result from func which need postprocess. 68 /// 69 /// \return Real gradients after postprocess, the size is same as next edges size. 70 ValuePtrList PostProcess(const ValuePtrList &gradient_value) override; 71 72 /// \brief Get the primitive of the forward. 73 /// 74 /// \return The primitive of the forward. GetPrim()75 const PrimitivePtr &GetPrim() const { return prim_; } 76 77 /// \brief Create the value filled with one, shape like the input. 78 /// 79 /// \param[in] value The input value. 80 /// 81 /// \return The value filled with one. Ones(const ValuePtr & value)82 ValuePtr Ones(const ValuePtr &value) const override { return ir_builder_->Ones(value); } 83 84 /// \brief Create the value filled with zero, shape like the input. 85 /// 86 /// \param[in] value The input value. 87 /// 88 /// \return The value filled with zero. Zeros(const ValuePtr & value)89 ValuePtr Zeros(const ValuePtr &value) const override { return ir_builder_->Zeros(value); } 90 91 /// \brief Calculate the sum of inputs. 92 /// 93 /// \param[in] input The first input value. 94 /// \param[in] other The second input value. 95 /// 96 /// \return The sum of inputs. Add(const ValuePtr & input,const ValuePtr & other)97 ValuePtr Add(const ValuePtr &input, const ValuePtr &other) const override { return ir_builder_->Add(input, other); } 98 99 /// \brief Convert the inputs, output and dout of forward execution into the inputs of function builder. 100 /// 101 /// \param[in] inputs The arguments of the forward execution. 102 /// \param[in] out The output of the forward execution. 103 /// \param[in] dout The dout of the output. 104 /// 105 /// \return The inputs of the function builder. 106 expander::NodePtrList PreProcess(const ValuePtrList &inputs, const ValuePtr &out, const ValuePtr &dout) const; 107 108 private: 109 /// \brief The primitive of forward execution. 110 PrimitivePtr prim_; 111 /// \brief The function builder of this backward function. 112 FuncBuilderPtr ir_builder_; 113 /// \brief The bprop handle of the primitive. 114 const BpropHandlePtr handle_; 115 }; 116 } // namespace grad 117 } // namespace pijit 118 } // namespace mindspore 119 #endif // MINDSPORE_PI_JIT_NATIVE_BACKWARD_FUNCTION_H_ 120