1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_PASS_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_PASS_H_ 19 20 #include <string> 21 #include <utility> 22 #include <memory> 23 #include "ir/anf.h" 24 #include "include/backend/kernel_graph.h" 25 #include "include/common/expander/core/node.h" 26 27 namespace mindspore { 28 namespace pynative { 29 namespace autograd { 30 class FuncBuilder; 31 } 32 33 namespace bprop_pass { 34 using NodePtr = expander::NodePtr; 35 using NodePtrList = expander::NodePtrList; 36 37 struct FuncPassForward { FuncPassForwardFuncPassForward38 explicit FuncPassForward(autograd::FuncBuilder *func_builder, std::string &&device_target) 39 : func_builder_(func_builder), device_target_(std::move(device_target)) {} 40 41 // Pass for expander outputs 42 NodePtrList PassForOpInput(const PrimitivePtr &prim, const NodePtrList &inputs); 43 NodePtr BatchNormGradToBNInferGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad); 44 NodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(const NodePtrList &inputs, const expander::DAttr &attrs, 45 const NodePtr &out, const NodePtr &dout, 46 bool is_graph_mode); 47 48 private: 49 NodePtrList ConvertConstInputToAttr(const PrimitivePtr &prim, const NodePtrList &inputs); 50 // Plant op input which is tuple, and set kAttrDynInputSizes attr 51 NodePtrList ConvertMakeTupleInputToDynamicInput(const PrimitivePtr &prim, const NodePtrList &inputs); 52 53 autograd::FuncBuilder *func_builder_{nullptr}; 54 std::string device_target_; 55 }; 56 using FuncPassForwardPtr = std::shared_ptr<FuncPassForward>; 57 } // namespace bprop_pass 58 } // namespace pynative 59 } // namespace mindspore 60 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_PASS_H_ 61