• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_PRIMITIVE_H_
18 #define MINDSPORE_CORE_IR_PRIMITIVE_H_
19 
20 #include <unordered_map>
21 #include <vector>
22 #include <memory>
23 #include <string>
24 #include <tuple>
25 
26 #include "ir/dtype/type.h"
27 #include "abstract/abstract_value.h"
28 #include "base/base_ref.h"
29 
30 namespace mindspore {
31 // Supported meta type
32 enum PrimType {
33   kPrimTypeUnknown = 0,
34   kPrimTypeBegin = kTypeUnknown,
35   kPrimTypeBuiltIn,  // Built-in primitive operator
36   kPrimTypePyInfer,  // Primitive operator defined by custom
37   kPrimTypeUserCustom,
38   kPrimTypePyCheck  // Primitive operator with input args checking method
39 };
40 
41 class MS_CORE_API Primitive : public Named {
42  public:
43   explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
44   Primitive(const std::string &name, const std::unordered_map<std::string, ValuePtr> &attrs);
45   Primitive(const Primitive &prim);
46   MS_DECLARE_PARENT(Primitive, Named);
47   abstract::AbstractBasePtr ToAbstract() override;
48   abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
ToString()49   std::string ToString() const override { return name(); }
BeginRecordAddAttr()50   void BeginRecordAddAttr() {
51     evaluate_added_attrs_.clear();
52     record_evaluate_add_attr_ = true;
53   }
EndRecordAddAttr()54   void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
AddAttr(const std::string & name,const ValuePtr & attr)55   Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
56     attrs_[name] = attr;
57     if (record_evaluate_add_attr_) {
58       evaluate_added_attrs_[name] = attr;
59     }
60     return *this;
61   }
62 
DelAttr(const std::string & name)63   Primitive &DelAttr(const std::string &name) {
64     attrs_.erase(name);
65     return *this;
66   }
67 
SetAttrs(const std::unordered_map<std::string,ValuePtr> & attrs)68   Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
69     for (auto &attr : attrs) {
70       attrs_[attr.first] = attr.second;
71     }
72     return *this;
73   }
74 
set_attr(const std::string & attrName,const ValuePtr & attr)75   void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
EraseAttr(const std::string & attrName)76   void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
RunComputeFunction(const VectorRef & args)77   virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }
78 
GetAttr(const std::string & attrName)79   ValuePtr GetAttr(const std::string &attrName) const {
80     auto iter = attrs_.find(attrName);
81     return iter == attrs_.cend() ? nullptr : iter->second;
82   }
83 
attrs()84   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
evaluate_added_attrs()85   const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
set_evaluate_added_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)86   void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
87     for (auto &attr : attrs) {
88       MS_LOG(DEBUG) << " set evalu attrl " << name() << attr.first;
89       attrs_[attr.first] = attr.second;
90     }
91   }
92 
93   // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
HasAttr()94   bool HasAttr() const { return !attrs_.empty(); }
HasAttr(const std::string & attrName)95   bool HasAttr(const std::string &attrName) const {
96     auto iter = attrs_.find(attrName);
97     return !(iter == attrs_.cend());
98   }
set_prim_type(const PrimType t)99   void set_prim_type(const PrimType t) { prim_type_ = t; }
Clone()100   virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
set_instance_name(const std::string & s)101   void set_instance_name(const std::string &s) { instance_name_ = s; }
HasPyEvaluator()102   bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
IsCustomPrim()103   bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
104 
prim_type()105   PrimType prim_type() const { return prim_type_; }
instance_name()106   std::string instance_name() const { return instance_name_; }
107   std::string GetAttrsText() const;
108   bool operator==(const Value &other) const override;
109   bool operator==(const Primitive &other) const;
110   ~Primitive() override = default;
111 
set_has_signature(bool has_signature)112   void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
has_signature()113   bool has_signature() const { return has_signature_; }
is_base()114   bool is_base() const { return is_base_; }
RunHookFunction(const VectorRef & args)115   virtual BaseRef RunHookFunction(const VectorRef &args) const {
116     MS_LOG(EXCEPTION) << "call a empty function!";
117     BaseRef result;
118     return result;
119   }
CopyHookFunction(const PrimitivePtr & primitive)120   virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; }
set_const_prim(bool is_const_prim)121   void set_const_prim(bool is_const_prim) { is_const_prim_ = is_const_prim; }
is_const_prim()122   bool is_const_prim() const { return is_const_prim_; }
set_const_input_indexes(const std::vector<size_t> & const_input_indexes)123   void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
124     const_input_indexes_ = const_input_indexes;
125   }
get_const_input_indexes()126   const std::vector<size_t> &get_const_input_indexes() { return const_input_indexes_; }
id()127   uint64_t id() const { return id_; }
128 
129  protected:
130   std::unordered_map<std::string, ValuePtr> attrs_;
131   std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
132 
133  private:
134   std::string instance_name_;
135   bool is_base_;
136   bool has_signature_;
137   PrimType prim_type_;
138   bool record_evaluate_add_attr_;
139   bool is_const_prim_;
140   std::vector<size_t> const_input_indexes_;
141   uint64_t id_{0};
142 };
143 
144 inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
145   os << *p;
146   return os;
147 }
148 
149 struct MS_CORE_API PrimitiveEqual {
operatorPrimitiveEqual150   bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
151     MS_EXCEPTION_IF_NULL(t1);
152     MS_EXCEPTION_IF_NULL(t2);
153     return t1 == t2 || t1->name() == t2->name();
154   }
155 };
156 
157 struct MS_CORE_API PrimitiveHasher {
operatorPrimitiveHasher158   std::size_t operator()(PrimitivePtr const &prim) const {
159     MS_EXCEPTION_IF_NULL(prim);
160     return prim->Hash();
161   }
162 };
163 
164 struct MS_CORE_API PrimitiveTotalEqual {
operatorPrimitiveTotalEqual165   bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
166     MS_EXCEPTION_IF_NULL(t1);
167     MS_EXCEPTION_IF_NULL(t2);
168     return *t1 == *t2;
169   }
170 };
171 }  // namespace mindspore
172 #endif  // MINDSPORE_CORE_IR_PRIMITIVE_H_
173