1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_BPROP_PASS_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_BPROP_PASS_H_ 19 20 #include <string> 21 #include <utility> 22 #include <memory> 23 #include "ir/anf.h" 24 #include "include/backend/kernel_graph.h" 25 26 namespace mindspore { 27 namespace pynative { 28 namespace autograd { 29 class IrBprop; 30 } 31 32 namespace bprop_pass { 33 constexpr auto kIsKNode = "is_knode"; 34 35 struct IrPassForward { IrPassForwardIrPassForward36 explicit IrPassForward(autograd::IrBprop *ir_bprop, std::string &&device_target, bool grad_by_value) 37 : ir_bprop_(ir_bprop), device_target_(std::move(device_target)), grad_by_value_(grad_by_value) {} 38 39 // Pass for expander outputs 40 CNodePtr PassForDin(const CNodePtr &cnode, const std::string &op_name, bool is_dynamic_shape); 41 // Plant op input which is tuple, and set kAttrDynInputSizes attr 42 void ConvertMakeTupleInputToDynamicInput(const AnfNodePtr &node, SeenNum seen, bool run_by_single_op); 43 AnfNodePtr PassBackwardHook(const ValuePtr &value, const AnfNodePtr &grad_node); 44 // Reverse operation for pass in high grad 45 void ReversePassFuncGraph(const FuncGraphPtr &func_graph); 46 void ReversePassCNode(const CNodePtr &cnode, ValuePtrList *inputs_value, AnfNodePtrList *cnode_inputs); need_reverse_graphIrPassForward47 static inline bool need_reverse_graph() { return need_reverse_graph_; } 48 49 private: 50 CNodePtr ConvertConstInputToAttr(const CNodePtr &cnode, bool is_dynamic_shape); 51 AnfNodePtr BatchNormGradToBNInferGrad(const AnfNodePtr &node, const std::string &op_name); 52 void ReverseConstantToAttrNode(const CNodePtr &cnode, ValuePtrList *inputs_value, AnfNodePtrList *cnode_inputs); 53 void ReverseMakeTupleNode(const CNodePtr &cnode, ValuePtrList *inputs_value, AnfNodePtrList *cnode_inputs); 54 void ReverseBNInfer(const CNodePtr &cnode); 55 void ReverseCNodeInputs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs, ValuePtrList *inputs_value); 56 57 autograd::IrBprop *ir_bprop_{nullptr}; 58 std::string device_target_; 59 bool grad_by_value_{false}; 60 static bool need_reverse_graph_; 61 }; 62 using PyNativePassForwardPtr = std::shared_ptr<IrPassForward>; 63 64 void ClearCache(); 65 } // namespace bprop_pass 66 } // namespace pynative 67 } // namespace mindspore 68 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_IR_BPROP_PASS_H_ 69