1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_NODE_H_ 18 #define MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_NODE_H_ 19 #include <vector> 20 #include <string> 21 #include <memory> 22 #include <utility> 23 #include "ir/anf.h" 24 #include "include/common/visible.h" 25 #include "include/common/utils/utils.h" 26 27 namespace mindspore { 28 namespace expander { 29 class Emitter; 30 using DAttr = std::vector<std::pair<std::string, ValuePtr>>; 31 32 class COMMON_EXPORT Node : public std::enable_shared_from_this<Node> { 33 public: 34 explicit Node(Emitter *emitter); Node(Emitter * emitter,const ValuePtr & value)35 Node(Emitter *emitter, const ValuePtr &value) : emitter_(emitter), value_(value) {} 36 virtual ~Node() = default; 37 get()38 virtual const AnfNodePtr &get() const { MS_EXCEPTION(NotImplementedError) << "Base Node not implement get() method"; } 39 40 virtual InputType input_type(); 41 virtual AbstractBasePtr abstract(); 42 SetValue(const ValuePtr & val)43 void SetValue(const ValuePtr &val) { value_ = val; } Value()44 ValuePtr Value() { return value_; } 45 virtual ValuePtr BuildValue(); 46 virtual bool HasAbstractValue(); 47 virtual BaseShapePtr GetShape(); 48 virtual TypePtr GetType(); 49 50 std::vector<int64_t> shape(); 51 std::vector<std::vector<int64_t>> shapes(); 52 TypePtr dtype(); 53 std::vector<TypePtr> dtypes(); emitter()54 Emitter *emitter() { return emitter_; } 55 virtual std::string ToString() const; set_debug_info(const std::string & debug_info)56 virtual void set_debug_info(const std::string &debug_info) {} debug_info()57 virtual std::string debug_info() const { return ""; } is_used_value()58 virtual bool is_used_value() const { 59 MS_EXCEPTION(NotImplementedError) << "Base Node not implement is_used_value() method"; 60 } need_compute_grad_out()61 virtual bool need_compute_grad_out() const { return true; } 62 63 protected: 64 // hold the emitter who created this node. 65 Emitter *emitter_{nullptr}; 66 // cache the output shape after first query 67 BaseShapePtr shape_{nullptr}; 68 // cache the output dtype after first query 69 TypePtr type_{nullptr}; 70 // cache the value of node 71 ValuePtr value_{nullptr}; 72 }; 73 using NodePtr = std::shared_ptr<Node>; 74 using NodePtrList = std::vector<NodePtr>; 75 76 class COMMON_EXPORT IrNode : public Node { 77 public: IrNode(const AnfNodePtr anfnode,Emitter * emitter)78 IrNode(const AnfNodePtr anfnode, Emitter *emitter) : Node(emitter), anf_node_(anfnode) {} get()79 const AnfNodePtr &get() const override { return anf_node_; } 80 InputType input_type() override; 81 AbstractBasePtr abstract() override; 82 83 ValuePtr BuildValue() override; 84 bool HasAbstractValue() override; 85 BaseShapePtr GetShape() override; 86 TypePtr GetType() override; 87 88 std::string ToString() const override; 89 void set_debug_info(const std::string &debug_info) override; 90 std::string debug_info() const override; is_used_value()91 bool is_used_value() const override { return is_used_value_; } 92 93 private: 94 // the wrapped anfnode. 95 AnfNodePtr anf_node_{nullptr}; 96 // whether use value 97 bool is_used_value_{false}; 98 }; 99 using IrNodePtr = std::shared_ptr<IrNode>; 100 101 class COMMON_EXPORT FuncNode : public Node { 102 public: FuncNode(const ValuePtr & value,const abstract::AbstractBasePtr & abs,InputType input_type,Emitter * emitter)103 FuncNode(const ValuePtr &value, const abstract::AbstractBasePtr &abs, InputType input_type, Emitter *emitter) 104 : Node(emitter, value), abstract_(abs), input_type_(input_type) {} 105 ValuePtr BuildValue() override; 106 InputType input_type() override; set_node_type(InputType input_type)107 void set_node_type(InputType input_type) { input_type_ = input_type; } 108 AbstractBasePtr abstract() override; set_abstract(const AbstractBasePtr & abs)109 void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; } 110 BaseShapePtr GetShape() override; 111 TypePtr GetType() override; ToString()112 std::string ToString() const override { return value_->ToString(); } set_debug_info(const std::string & debug_info)113 void set_debug_info(const std::string &debug_info) override {} debug_info()114 std::string debug_info() const override { return ""; } need_compute_grad_out()115 bool need_compute_grad_out() const override { return need_compute_grad_out_; } set_need_compute_grad_out(bool need_compute_grad_out)116 void set_need_compute_grad_out(bool need_compute_grad_out) { need_compute_grad_out_ = need_compute_grad_out; } 117 118 private: 119 AbstractBasePtr abstract_; 120 InputType input_type_; 121 bool need_compute_grad_out_{true}; 122 }; 123 using FuncNodePtr = std::shared_ptr<FuncNode>; 124 } // namespace expander 125 } // namespace mindspore 126 #endif // MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_NODE_H_ 127