1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_BUILDER_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_BUILDER_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include "utils/hash_map.h" 25 #include "frontend/expander/bprop/bprop_irbuilder.h" 26 #include "pipeline/pynative/grad/function/func_pass.h" 27 28 namespace mindspore::pynative::autograd { 29 using NodePtr = expander::NodePtr; 30 using NodePtrList = expander::NodePtrList; 31 using BpropBuilder = expander::bprop::BpropBuilder; 32 33 class FuncBuilder : public BpropBuilder { 34 public: 35 FuncBuilder(const std::string &name, std::string device_target, const expander::ExpanderInferPtr &infer = nullptr) BpropBuilder(name,infer)36 : BpropBuilder(name, infer), device_target_(device_target) { 37 pass_forward_ = std::make_shared<bprop_pass::FuncPassForward>(this, std::move(device_target)); 38 } 39 ~FuncBuilder() override = default; 40 NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override; 41 NodePtr EmitValue(const ValuePtr &value) override; 42 // Override Stack to flatten tuple input. 43 NodePtr Stack(const NodePtr &x, const ValuePtr &axis) override; 44 NodePtr Stack(const NodePtrList &x, int64_t axis) override; 45 NodePtr BatchNormGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad) override; 46 NodePtr SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const expander::DAttr &attrs, 47 const NodePtr &out, const NodePtr &dout, bool is_graph_mode) override; 48 NodePtr Depend(const NodePtr &value, const NodePtr &expr) override; 49 NodePtr TupleGetItem(const NodePtr &input, size_t i) override; 50 NodePtr TupleGetItem(const NodePtr &input, const NodePtr &index) override; 51 NodePtr MakeTuple(const NodePtrList &inputs) override; 52 NodePtr MakeList(const NodePtrList &inputs) override; 53 NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) override; 54 NodePtr ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) override; 55 NodePtr OutZeros(const NodePtr &node) override; 56 ValuePtr Ones(const ValuePtr &value); 57 ValuePtr Zeros(const ValuePtr &value); 58 ValuePtr Add(const ValuePtr &input, const ValuePtr &other); 59 void SetInputs(std::string instance_name, const std::vector<NodePtr> *inputs, 60 mindspore::HashMap<std::string, ValuePtr> *attrs_ptr); 61 ValuePtr FillZeros(const ValuePtr &value, const abstract::AbstractBasePtr &abs); 62 63 private: 64 NodePtrList FlattenNode(const NodePtr &input); 65 std::string device_target_; 66 bprop_pass::FuncPassForwardPtr pass_forward_; 67 }; 68 using FuncBuilderPtr = std::shared_ptr<FuncBuilder>; 69 } // namespace mindspore::pynative::autograd 70 71 #endif // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_GRAD_FUNCTION_FUNC_BUILDER_H_ 72