• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_
19 
20 #include <memory>
21 #include <vector>
22 #include "ir/anf.h"
23 #include "ir/func_graph.h"
24 
25 namespace mindspore {
26 namespace ad {
27 class KPynativeCell {
28  public:
29   virtual ~KPynativeCell() = default;
30   virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) = 0;
31   // Grad for cell which may have user passed front propagate FuncGraph.
32   // c_node: CNode with contains the construct function graph of cell  (index 0) and the formal input parameters of that
33   // cell. op_args: the arguments list of each input parameters.
34   // out: the op result.
35   // fprop_fg: user defined back propagate cnode which output is the bprop_fg.
36   //           Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
37   virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out,
38                                   const FuncGraphPtr &fprop_fg) = 0;
39 };
40 
41 using KPynativeCellPtr = std::shared_ptr<KPynativeCell>;
42 
43 // bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after
44 //           just parsed. will have prototype:
45 //           (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
46 // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
47 // op_args: the arguments list of each input parameters.
48 // out: the op result.
49 // return: the returned funcgraph should have the same prototype.
50 FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args,
51                                     const ValuePtr &out);
52 
53 // Start building back propagate funcgraph for this cell.
54 // cell_inputs: the input parameter list of this cell except the weights;
55 KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs,
56                                        const std::vector<ValuePtr> &input_param_values);
57 
58 // Return the back propagate funcgraph for this cell.
59 // weights: weights parameters used in this cell.
60 // grad_inputs: return sensitivity for input parameters;
61 // grad_weights: return sensitivity for weights;
62 // has_sens_arg: caller will pass sens args;
63 // return: the returned funcgraph will have prototype:
64 // if has_sens_arg is true
65 // (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...,
66 // sens_out)
67 // else:
68 // (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...)
69 // if build_formal_param is true
70 // each cnode in primal funcgraph is replaced by formal cnode
71 // else:
72 // each cnode in primal funcgraph is replaced by value node
73 FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs,
74                                  bool grad_weights, bool has_sens_arg = false, bool build_formal_param = false);
75 
76 // Grad for each operation.
77 // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim.
78 // op_args: the arguments list of each input parameters.
79 // out: the op result.
80 bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
81                     const ValuePtr &out);
82 
83 // Grad for cell which may have user defined back propagate function.
84 // c_node: CNode with contains the construct function graph of cell  (index 0) and the formal input parameters of that
85 // cell. op_args: the arguments list of each input parameters.
86 // out: the op result.
87 // bprop_fg: user defined back propagate funcgraph, it should be passed after just parsed.
88 //           Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout)
89 bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args,
90                            const ValuePtr &out, const FuncGraphPtr &bprop_fg);
91 
92 // Clear all static resources that used in grad process
93 void ClearKPynativeCellStaticRes();
94 }  // namespace ad
95 }  // namespace mindspore
96 
97 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_
98