• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PI_JIT_NATIVE_BACKWARD_FUNCTION_H_
18 #define MINDSPORE_PI_JIT_NATIVE_BACKWARD_FUNCTION_H_
19 
20 #include <memory>
21 #include <string>
22 #include "pipeline/jit/pi/auto_grad/backward_function.h"
23 #include "pipeline/pynative/grad/function/func_builder.h"
24 #include "utils/ms_context.h"
25 
26 namespace mindspore {
27 namespace pijit {
28 namespace grad {
29 using FuncBuilderPtr = pynative::autograd::FuncBuilderPtr;
30 using BpropHandlePtr = const expander::bprop::BpropHandle *;
31 
32 class NativeBackwardFunc;
33 using NativeBackwardFuncPtr = std::shared_ptr<NativeBackwardFunc>;
34 
35 /// \brief NativeBackwardFunc is a class, which represent a function to calculate the gradient.
36 class NativeBackwardFunc : public BackwardFunc {
37  public:
38   /// \brief The constructor of NativeBackwardFunc.
39   ///
40   /// \param[in] name The name of this backward function.
41   ///
42   /// \return The instance of NativeBackwardFunc.
NativeBackwardFunc(const PrimitivePtr & prim,const FuncBuilderPtr & ir_builder,const BpropHandlePtr handle)43   explicit NativeBackwardFunc(const PrimitivePtr &prim, const FuncBuilderPtr &ir_builder, const BpropHandlePtr handle)
44       : BackwardFunc(prim->name()), prim_(prim), ir_builder_(ir_builder), handle_(handle) {}
45 
46   /// \brief Destructor.
47   virtual ~NativeBackwardFunc() = default;
48 
49   /// \brief Create a instance of native backward function.
50   ///
51   /// \param[in] prim The primitive of the forward execution.
52   ///
53   /// \return The instance of native backward function.
54   static NativeBackwardFuncPtr GetInstance(const PrimitivePtr &prim);
55 
56   /// \brief Start calculate the gradient of the backward function.
57   ///
58   /// \param[in] inputs The arguments of the forward execution.
59   /// \param[in] out The output of the forward execution.
60   /// \param[in] dout The dout of the output.
61   ///
62   /// \return The gradients of the inputs of forward execution.
63   ValuePtrList Run(const ValuePtrList &inputs, const ValuePtr &out, const ValuePtr &dout) override;
64 
65   /// \brief Postprocess gradients from func to align with next_edges.
66   ///
67   /// \param[in] gradient_value Gradients value is gradients result from func which need postprocess.
68   ///
69   /// \return Real gradients after postprocess, the size is same as next edges size.
70   ValuePtrList PostProcess(const ValuePtrList &gradient_value) override;
71 
72   /// \brief Get the primitive of the forward.
73   ///
74   /// \return The primitive of the forward.
GetPrim()75   const PrimitivePtr &GetPrim() const { return prim_; }
76 
77   /// \brief Create the value filled with one, shape like the input.
78   ///
79   /// \param[in] value The input value.
80   ///
81   /// \return The value filled with one.
Ones(const ValuePtr & value)82   ValuePtr Ones(const ValuePtr &value) const override { return ir_builder_->Ones(value); }
83 
84   /// \brief Create the value filled with zero, shape like the input.
85   ///
86   /// \param[in] value The input value.
87   ///
88   /// \return The value filled with zero.
Zeros(const ValuePtr & value)89   ValuePtr Zeros(const ValuePtr &value) const override { return ir_builder_->Zeros(value); }
90 
91   /// \brief Calculate the sum of inputs.
92   ///
93   /// \param[in] input The first input value.
94   /// \param[in] other The second input value.
95   ///
96   /// \return The sum of inputs.
Add(const ValuePtr & input,const ValuePtr & other)97   ValuePtr Add(const ValuePtr &input, const ValuePtr &other) const override { return ir_builder_->Add(input, other); }
98 
99   /// \brief Convert the inputs, output and dout of forward execution into the inputs of function builder.
100   ///
101   /// \param[in] inputs The arguments of the forward execution.
102   /// \param[in] out The output of the forward execution.
103   /// \param[in] dout The dout of the output.
104   ///
105   /// \return The inputs of the function builder.
106   expander::NodePtrList PreProcess(const ValuePtrList &inputs, const ValuePtr &out, const ValuePtr &dout) const;
107 
108  private:
109   /// \brief The primitive of forward execution.
110   PrimitivePtr prim_;
111   /// \brief The function builder of this backward function.
112   FuncBuilderPtr ir_builder_;
113   /// \brief The bprop handle of the primitive.
114   const BpropHandlePtr handle_;
115 };
116 }  // namespace grad
117 }  // namespace pijit
118 }  // namespace mindspore
119 #endif  // MINDSPORE_PI_JIT_NATIVE_BACKWARD_FUNCTION_H_
120