/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_ #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_KPYNATIVE_H_ #include #include #include "ir/anf.h" #include "ir/func_graph.h" namespace mindspore { namespace ad { class KPynativeCell { public: virtual ~KPynativeCell() = default; virtual void UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node) = 0; // Grad for cell which may have user passed front propagate FuncGraph. // c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that // cell. op_args: the arguments list of each input parameters. // out: the op result. // fprop_fg: user defined back propagate cnode which output is the bprop_fg. // Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) virtual bool KPynativeWithFProp(const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out, const FuncGraphPtr &fprop_fg) = 0; }; using KPynativeCellPtr = std::shared_ptr; // bprop_fg: user defined back propagate funcgraph or back propagate funcgraph of primitive, it will be passed after // just parsed. will have prototype: // (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. // op_args: the arguments list of each input parameters. // out: the op result. // return: the returned funcgraph should have the same prototype. FuncGraphPtr OptimizeBPropFuncGraph(const FuncGraphPtr &bprop_fg, const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out); // Start building back propagate funcgraph for this cell. // cell_inputs: the input parameter list of this cell except the weights; KPynativeCellPtr GradPynativeCellBegin(const AnfNodePtrList &cell_inputs, const std::vector &input_param_values); // Return the back propagate funcgraph for this cell. // weights: weights parameters used in this cell. // grad_inputs: return sensitivity for input parameters; // grad_weights: return sensitivity for weights; // has_sens_arg: caller will pass sens args; // return: the returned funcgraph will have prototype: // if has_sens_arg is true // (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ..., // sens_out) // else: // (sens_input1, sens_input2, ..., sens_weight0, sens_weight1, ) bprop_fg(input1, input2, ..., weight0, weight1, ...) // if build_formal_param is true // each cnode in primal funcgraph is replaced by formal cnode // else: // each cnode in primal funcgraph is replaced by value node FuncGraphPtr GradPynativeCellEnd(const KPynativeCellPtr &k_cell, const AnfNodePtrList &weights, bool grad_inputs, bool grad_weights, bool has_sens_arg = false, bool build_formal_param = false); // Grad for each operation. // c_node: CNode with contains the prim (index 0) and the formal input parameters of that prim. // op_args: the arguments list of each input parameters. // out: the op result. bool GradPynativeOp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out); // Grad for cell which may have user defined back propagate function. // c_node: CNode with contains the construct function graph of cell (index 0) and the formal input parameters of that // cell. op_args: the arguments list of each input parameters. // out: the op result. // bprop_fg: user defined back propagate funcgraph, it should be passed after just parsed. // Should have prototype: (sens_input1, sens_input2, ...) bprop_fg(input1, input2, ..., out, dout) bool GradPynativeWithBProp(const KPynativeCellPtr &k_cell, const CNodePtr &c_node, const ValuePtrList &op_args, const ValuePtr &out, const FuncGraphPtr &bprop_fg); // Clear all static resources that used in grad process void ClearKPynativeCellStaticRes(); } // namespace ad } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_GRAD_H_