1 /** 2 * Copyright 2019-2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_PRIMITIVE_H_ 18 #define MINDSPORE_CORE_IR_PRIMITIVE_H_ 19 20 #include <unordered_map> 21 #include <vector> 22 #include <memory> 23 #include <string> 24 #include <tuple> 25 26 #include "ir/dtype/type.h" 27 #include "abstract/abstract_value.h" 28 #include "base/base_ref.h" 29 30 namespace mindspore { 31 // Supported meta type 32 enum PrimType { 33 kPrimTypeUnknown = 0, 34 kPrimTypeBegin = kTypeUnknown, 35 kPrimTypeBuiltIn, // Built-in primitive operator 36 kPrimTypePyInfer, // Primitive operator defined by custom 37 kPrimTypeUserCustom, 38 kPrimTypePyCheck // Primitive operator with input args checking method 39 }; 40 41 class MS_CORE_API Primitive : public Named { 42 public: 43 explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn); 44 Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs); 45 Primitive(const Primitive &prim); 46 MS_DECLARE_PARENT(Primitive, Named); 47 abstract::AbstractBasePtr ToAbstract() override; 48 abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); ToString()49 std::string ToString() const override { return name(); } BeginRecordAddAttr()50 void BeginRecordAddAttr() { 51 evaluate_added_attrs_.clear(); 52 record_evaluate_add_attr_ = true; 53 } EndRecordAddAttr()54 void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } AddAttr(const std::string & name,const ValuePtr & attr)55 Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { 56 attrs_[name] = attr; 57 if (record_evaluate_add_attr_) { 58 evaluate_added_attrs_[name] = attr; 59 } 60 return *this; 61 } 62 DelAttr(const std::string & name)63 Primitive &DelAttr(const std::string &name) { 64 attrs_.erase(name); 65 return *this; 66 } 67 SetAttrs(const std::unordered_map<std::string,ValuePtr> & attrs)68 Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { 69 for (auto &attr : attrs) { 70 attrs_[attr.first] = attr.second; 71 } 72 return *this; 73 } 74 set_attr(const std::string & attrName,const ValuePtr & attr)75 void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } EraseAttr(const std::string & attrName)76 void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } RunComputeFunction(const VectorRef & args)77 virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; } 78 GetAttr(const std::string & attrName)79 ValuePtr GetAttr(const std::string &attrName) const { 80 auto iter = attrs_.find(attrName); 81 return iter == attrs_.cend() ? nullptr : iter->second; 82 } 83 attrs()84 const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } evaluate_added_attrs()85 const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; } set_evaluate_added_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)86 void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) { 87 for (auto &attr : attrs) { 88 MS_LOG(DEBUG) << " set evalu attrl " << name() << attr.first; 89 attrs_[attr.first] = attr.second; 90 } 91 } 92 93 // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. HasAttr()94 bool HasAttr() const { return !attrs_.empty(); } HasAttr(const std::string & attrName)95 bool HasAttr(const std::string &attrName) const { 96 auto iter = attrs_.find(attrName); 97 return !(iter == attrs_.cend()); 98 } set_prim_type(const PrimType t)99 void set_prim_type(const PrimType t) { prim_type_ = t; } Clone()100 virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); } set_instance_name(const std::string & s)101 void set_instance_name(const std::string &s) { instance_name_ = s; } HasPyEvaluator()102 bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; } IsCustomPrim()103 bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } 104 prim_type()105 PrimType prim_type() const { return prim_type_; } instance_name()106 std::string instance_name() const { return instance_name_; } 107 std::string GetAttrsText() const; 108 bool operator==(const Value &other) const override; 109 bool operator==(const Primitive &other) const; 110 ~Primitive() override = default; 111 set_has_signature(bool has_signature)112 void set_has_signature(bool has_signature) { has_signature_ = has_signature; } has_signature()113 bool has_signature() const { return has_signature_; } is_base()114 bool is_base() const { return is_base_; } RunHookFunction(const VectorRef & args)115 virtual BaseRef RunHookFunction(const VectorRef &args) const { 116 MS_LOG(EXCEPTION) << "call a empty function!"; 117 BaseRef result; 118 return result; 119 } CopyHookFunction(const PrimitivePtr & primitive)120 virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } set_const_prim(bool is_const_prim)121 void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; } is_const_prim()122 bool is_const_prim() const { return is_const_prim_; } set_const_input_indexes(const std::vector<size_t> & const_input_indexes)123 void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) { 124 const_input_indexes_ = const_input_indexes; 125 } get_const_input_indexes()126 const std::vector<size_t> &get_const_input_indexes() { return const_input_indexes_; } id()127 uint64_t id() const { return id_; } 128 129 protected: 130 std::unordered_map<std::string, ValuePtr> attrs_; 131 std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_; 132 133 private: 134 std::string instance_name_; 135 bool is_base_; 136 bool has_signature_; 137 PrimType prim_type_; 138 bool record_evaluate_add_attr_; 139 bool is_const_prim_; 140 std::vector<size_t> const_input_indexes_; 141 uint64_t id_{0}; 142 }; 143 144 inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { 145 os << *p; 146 return os; 147 } 148 149 struct MS_CORE_API PrimitiveEqual { operatorPrimitiveEqual150 bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { 151 MS_EXCEPTION_IF_NULL(t1); 152 MS_EXCEPTION_IF_NULL(t2); 153 return t1 == t2 || t1->name() == t2->name(); 154 } 155 }; 156 157 struct MS_CORE_API PrimitiveHasher { operatorPrimitiveHasher158 std::size_t operator()(PrimitivePtr const &prim) const { 159 MS_EXCEPTION_IF_NULL(prim); 160 return prim->Hash(); 161 } 162 }; 163 164 struct MS_CORE_API PrimitiveTotalEqual { operatorPrimitiveTotalEqual165 bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { 166 MS_EXCEPTION_IF_NULL(t1); 167 MS_EXCEPTION_IF_NULL(t2); 168 return *t1 == *t2; 169 } 170 }; 171 } // namespace mindspore 172 #endif // MINDSPORE_CORE_IR_PRIMITIVE_H_ 173