1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_ 19 20 #include <memory> 21 #include <vector> 22 #include "ir/anf.h" 23 #include "ir/func_graph.h" 24 25 namespace mindspore { 26 namespace ad { 27 class KPynativeCell { 28 public: 29 virtual ~KPynativeCell() = default; 30 virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) = 0; 31 // Grad for cell which may have user passed front propagate FuncGraph. 32 // c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that 33 // cell. op_args: the arguments list of each input parameters. 34 // out: the op result. 35 // fprop_fg: user defined back propagate cnode which output is the bprop_fg. 36 // Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) 37 virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out, 38 const FuncGraphPtr &fprop_fg) = 0; 39 }; 40 41 using KPynativeCellPtr = std::shared_ptr<KPynativeCell>; 42 43 // bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after 44 // just parsed. will have prototype: 45 // (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) 46 // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. 47 // op_args: the arguments list of each input parameters. 48 // out: the op result. 49 // return: the returned funcgraph should have the same prototype. 50 FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args, 51 const ValuePtr &out); 52 53 // Start building back propagate funcgraph for this cell. 54 // cell_inputs: the input parameter list of this cell except the weights; 55 KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs, 56 const std::vector<ValuePtr> &input_param_values); 57 58 // Return the back propagate funcgraph for this cell. 59 // weights: weights parameters used in this cell. 60 // grad_inputs: return sensitivity for input parameters; 61 // grad_weights: return sensitivity for weights; 62 // has_sens_arg: caller will pass sens args; 63 // return: the returned funcgraph will have prototype: 64 // if has_sens_arg is true 65 // (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ..., 66 // sens_out) 67 // else: 68 // (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...) 69 // if build_formal_param is true 70 // each cnode in primal funcgraph is replaced by formal cnode 71 // else: 72 // each cnode in primal funcgraph is replaced by value node 73 FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs, 74 bool grad_weights, bool has_sens_arg = false, bool build_formal_param = false); 75 76 // Grad for each operation. 77 // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. 78 // op_args: the arguments list of each input parameters. 79 // out: the op result. 80 bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args, 81 const ValuePtr &out); 82 83 // Grad for cell which may have user defined back propagate function. 84 // c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that 85 // cell. op_args: the arguments list of each input parameters. 86 // out: the op result. 87 // bprop_fg: user defined back propagate funcgraph, it should be passed after just parsed. 88 // Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) 89 bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args, 90 const ValuePtr &out, const FuncGraphPtr &bprop_fg); 91 92 // Clear all static resources that used in grad process 93 void ClearKPynativeCellStaticRes(); 94 } // namespace ad 95 } // namespace mindspore 96 97 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_ 98