• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_PRIMITIVE_H_
18 #define MINDSPORE_CORE_IR_PRIMITIVE_H_
19 
20 #include <vector>
21 #include <memory>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <shared_mutex>
26 #include <initializer_list>
27 
28 #include "utils/hash_map.h"
29 #include "ir/signature.h"
30 #include "ir/dtype/type.h"
31 #include "abstract/abstract_value.h"
32 #include "base/base_ref.h"
33 
34 namespace mindspore {
// Supported primitive kinds. Enumerator values are spelled out explicitly; they match the implicit
// numbering (0, 0, 1, 2, 3, 4).
enum PrimType {
  kPrimTypeUnknown = 0,               // Unknown / unspecified primitive type
  kPrimTypeBegin = kPrimTypeUnknown,  // Alias for the first enumerator
  kPrimTypeBuiltIn = 1,               // Built-in primitive operator
  kPrimTypePyInfer = 2,               // Primitive operator with a Python infer function
  kPrimTypeUserCustom = 3,            // Primitive operator defined by the user (custom)
  kPrimTypePyCheck = 4                // Primitive operator with an input-argument checking method
};
44 
45 class MS_CORE_API PrimitiveReadLock {
46  public:
PrimitiveReadLock(std::shared_ptr<std::shared_mutex> shared_mutex)47   explicit PrimitiveReadLock(std::shared_ptr<std::shared_mutex> shared_mutex) : shared_mutex_(std::move(shared_mutex)) {
48     if (shared_mutex_ != nullptr) {
49       shared_mutex_->lock_shared();
50     }
51   }
~PrimitiveReadLock()52   ~PrimitiveReadLock() {
53     if (shared_mutex_ != nullptr) {
54       // cppcheck-suppress unreadVariable
55       shared_mutex_->unlock_shared();
56     }
57   }
58 
59  private:
60   std::shared_ptr<std::shared_mutex> shared_mutex_;
61 };
62 
63 class MS_CORE_API PrimitiveWriteLock {
64  public:
PrimitiveWriteLock(std::shared_ptr<std::shared_mutex> shared_mutex)65   explicit PrimitiveWriteLock(std::shared_ptr<std::shared_mutex> shared_mutex)
66       : shared_mutex_(std::move(shared_mutex)) {
67     if (shared_mutex_ != nullptr) {
68       shared_mutex_->lock();
69     }
70   }
~PrimitiveWriteLock()71   ~PrimitiveWriteLock() {
72     if (shared_mutex_ != nullptr) {
73       shared_mutex_->unlock();
74     }
75   }
76 
77  private:
78   std::shared_ptr<std::shared_mutex> shared_mutex_;
79 };
80 
81 /// \brief Primitive defines a operator primitive of MindSpore.
82 class MS_CORE_API Primitive : public Named {
83  public:
84   /// \brief The constructor of Primitive.
85   ///
86   /// \param[in] name The name of primitive.
87   /// \param[in] is_base True means the basic Primitive without BProp function inside.
88   /// \param[in] prim_type The type of primitive.
89   explicit Primitive(const std::string &name, bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn,
90                      bool inplace_prim = false);
91   Primitive(const std::string &name, const mindspore::HashMap<std::string, ValuePtr> &attrs, bool inplace_prim = false);
92   /// \brief The constructor for Primitive, create a primitive for another primitive.
93   ///
94   /// \param[in] prim The input primitive.
95   Primitive(const Primitive &prim);
96   /// \brief The copy assignment operator for Primitive.
97   ///
98   /// \param[in] other An existing Primitive object.
99   /// \return A Primitive object set with the same members as other.
100   Primitive &operator=(const Primitive &other);
101   MS_DECLARE_PARENT(Primitive, Named);
102   abstract::AbstractBasePtr ToAbstract() override;
103   std::string ToString() const override;
104   /// \brief Ready to recording the attribute if the attribute needs to be added when deducing shape and type.
105   /// This attributes has been recorded needs to add in infer cache.
BeginRecordAddAttr()106   void BeginRecordAddAttr() {
107     evaluate_added_attrs_.clear();
108     record_evaluate_add_attr_ = true;
109   }
110   /// \brief End recording attribute.
EndRecordAddAttr()111   void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
112   /// \brief Add attribute to primitive attribute map and record the new attribute to evaluate_added_attrs_,
113   /// if record_evaluate_add_attr_ is true.
114   ///
115   /// \param[in] name The name of attribute.
116   /// \param[in] attr The value of attribute.
117   /// \return The primitive to which attribute has been added.
AddAttr(const std::string & name,const ValuePtr & attr)118   Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
119     // cppcheck-suppress unreadVariable
120     PrimitiveWriteLock write_lock(shared_mutex_);
121     attrs_[name] = attr;
122     if (record_evaluate_add_attr_) {
123       evaluate_added_attrs_[name] = attr;
124     }
125     return *this;
126   }
127   /// \brief Delete the attribute.
128   ///
129   /// \param[in] name The name of attribute to be delete.
130   /// \return The primitive to which attribute has been added.
DelAttr(const std::string & name)131   Primitive &DelAttr(const std::string &name) {
132     // cppcheck-suppress unreadVariable
133     PrimitiveWriteLock write_lock(shared_mutex_);
134     (void)attrs_.erase(name);
135     return *this;
136   }
137   /// \brief Use add attribute by using a map,all elements of the map will be added in the primitive's attribute map.
138   ///
139   /// \param[in] attrs The attribute map needs to be added in the primitive attribute.
140   /// \return The primitive to which attribute has been added.
SetAttrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)141   Primitive &SetAttrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
142     PrimitiveWriteLock write_lock(shared_mutex_);
143     for (auto &attr : attrs) {
144       attrs_[attr.first] = attr.second;
145     }
146     return *this;
147   }
148   /// \brief Use add attribute by using initializer_list, all elements of the vector will be added in the primitive's
149   /// attribute map.
150   ///
151   /// \param[in] attrs The attribute vector needs to be added in the primitive attribute.
152   /// \return The primitive to which attribute has been added.
SetAttrs(const std::initializer_list<std::pair<std::string,ValuePtr>> & attrs)153   Primitive &SetAttrs(const std::initializer_list<std::pair<std::string, ValuePtr>> &attrs) {
154     PrimitiveWriteLock write_lock(shared_mutex_);
155     for (auto &attr : attrs) {
156       attrs_[attr.first] = attr.second;
157     }
158     return *this;
159   }
160   /// \brief Use add attribute by using a vector, all elements of the vector will be added in the primitive's attribute
161   /// map.
162   ///
163   /// \param[in] attrs The attribute vector needs to be added in the primitive attribute.
164   /// \return The primitive to which attribute has been added.
SetAttrs(const std::vector<std::pair<std::string,ValuePtr>> & attrs)165   Primitive &SetAttrs(const std::vector<std::pair<std::string, ValuePtr>> &attrs) {
166     PrimitiveWriteLock write_lock(shared_mutex_);
167     for (auto &attr : attrs) {
168       attrs_[attr.first] = attr.second;
169     }
170     return *this;
171   }
172   /// \brief Set attribute to the primitive attribute map.
set_attr(const std::string & attrName,const ValuePtr & attr)173   void set_attr(const std::string &attrName, const ValuePtr &attr) {
174     // cppcheck-suppress unreadVariable
175     PrimitiveWriteLock write_lock(shared_mutex_);
176     attrs_[attrName] = attr;
177   }
178   /// \brief Erase attribute to the primitive attribute map.
EraseAttr(const std::string & attrName)179   void EraseAttr(const std::string &attrName) {
180     // cppcheck-suppress unreadVariable
181     PrimitiveWriteLock write_lock(shared_mutex_);
182     (void)attrs_.erase(attrName);
183   }
184   /// \brief Run Primitive's compute function if the compute function has been implemented.
185   ///
186   /// \param[in] args The arguments of primitive need to compute.
187   /// \return The primitive's calculation result.
RunComputeFunction(const VectorRef & args)188   virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }
189   /// \brief Get Primitive's attribute.
190   ///
191   /// \param[in] attrName Primitive attribute name.
192   /// \return The value of attribute in primitive attribute map, if the map is not
GetAttr(const std::string & attrName)193   ValuePtr GetAttr(const std::string &attrName) const {
194     PrimitiveReadLock read_lock(shared_mutex_);
195     auto iter = attrs_.find(attrName);
196     return iter == attrs_.cend() ? nullptr : iter->second;
197   }
198   /// \brief Get Primitive's all attributes.
199   ///
200   /// \return The Primitive's all attribute.
attrs()201   const mindspore::HashMap<std::string, ValuePtr> &attrs() const { return attrs_; }
202   /// \brief Get the attributes added in MindSpore renormalize stage.
203   ///
204   /// \return Attributes which have been added in MindSpore renormalize stage.
evaluate_added_attrs()205   const mindspore::HashMap<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
206   /// \brief Use add attribute using a map,all elements of the map will be added in the primitive's attribute map.
207   ///
208   /// \param[in] attrs The attribute map needs to be added in the primitive attribute.
set_evaluate_added_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)209   void set_evaluate_added_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
210     // cppcheck-suppress unreadVariable
211     PrimitiveWriteLock write_lock(shared_mutex_);
212     for (auto &attr : attrs) {
213       (void)attrs_.insert_or_assign(attr.first, attr.second);
214     }
215     evaluate_added_attrs_ = attrs;
216   }
217   /// \brief Check if Primitive has any attribute.
218   /// for example Primitives like scalar_add, return, etc, don't have any attribute.
219   ///
220   /// \return Return ture, If Primitive has attributes, else return false.
HasAttr()221   bool HasAttr() const { return !attrs_.empty(); }
222   /// \brief Check If Primitive has an attribute named attrName.
223   ///
224   /// \param[in] attrName The name of attribute.
225   /// \return Return true if Primitive has an attribute named attrName,else return false.
HasAttr(const std::string & attrName)226   bool HasAttr(const std::string &attrName) const {
227     auto iter = attrs_.find(attrName);
228     return !(iter == attrs_.cend());
229   }
230   /// \brief Set the name of primitive.
231   ///
232   /// \param t The primitive type that needs to be set.
set_prim_type(const PrimType t)233   void set_prim_type(const PrimType t) { prim_type_ = t; }
234   /// \brief Clone a Primitive.
235   ///
236   /// \return A Primitive which cloned by current primitive.
Clone()237   virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
238   /// \brief Set primitive instance_name.
239   ///
240   /// \param[in] s The primitive instance name to be set.
set_instance_name(const std::string & s)241   void set_instance_name(const std::string &s) { instance_name_ = s; }
242   /// \brief Check whether the primitive type if has the Python infer function,
243   ///
244   /// \return Return true if Primitive's type is kPrimTypePyInfer or kPrimTypeUserCustom, else return false.
HasPyEvaluator()245   bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; }
246   /// \brief Check whether the primitive type if has the python infer function,
247   ///
248   /// \return Return true if Primitive's type is kPrimTypeUserCustom, else return false.
IsCustomPrim()249   bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
250   /// \brief Get Primitive type.
251   ///
252   /// \return The type of Primitive.
prim_type()253   PrimType prim_type() const { return prim_type_; }
254   /// \brief Get primitive instance name.
255   ///
256   /// \return The instance name of primitive.
instance_name()257   std::string instance_name() const { return instance_name_; }
258   /// \brief Get primitive attribute debug string.
259   /// If the attribute name of primitive is a,the value is b
260   /// The return value of GetAttrsText function is [a=b].
261   ///
262   /// \return Get attribute debug string of primitive.
263   std::string GetAttrsText() const;
264   bool operator==(const Value &other) const override;
265   /// \brief To compare whether two Primitive objects are equal.
266   ///
267   /// \param[in] other The other Primitive be compared with.
268   /// \return return true if the name and attributes of primitives are the same,otherwise return false.
269   bool operator==(const Primitive &other) const;
270   /// \brief Destructor of Primitive.
271   ~Primitive() override = default;
272   /// \brief The flag to be set in primitive.
273   ///
274   /// \param[in] has_signature Set the flag whether there is a signature for the primitive.
set_has_signature(bool has_signature)275   void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
276   /// \brief Check whether the primitive has signature.
277   ///
278   /// \return Return true if primitive has signature flag , else return false.
has_signature()279   bool has_signature() const { return has_signature_; }
280   /// \brief Set signatures of primitive.
281   ///
282   /// \param[in] signatures Set signatures of primitive.
283   void set_signatures(const std::vector<Signature> &signatures);
284   /// \brief Get signatures of primitive.
285   ///
286   /// \return Return signatures of primitive.
signatures()287   const std::vector<Signature> &signatures() const { return signatures_; }
288   /// \brief Check whether the primitive is a basic primitive.
289   ///
290   /// \return Return true if the primitive is basic, else return false.
is_base()291   bool is_base() const { return is_base_; }
292   /// \brief Set primitive const flag.
293   /// If the const_prim_ of primitive is true means the primitive will be eliminated in constant folding.
294   ///
295   /// \param is_const_prim The flag of primitive to be set.
set_const_prim(bool is_const_prim)296   void set_const_prim(bool is_const_prim) { const_prim_ = is_const_prim; }
297   /// \brief Check whether the primitive is const primitive.
298   ///
299   /// \return Return true if primitive is a const primitive, else return false.
const_prim()300   bool const_prim() const { return const_prim_; }
301   /// \brief Set const input index for primitive.
302   ///
303   /// \param const_input_indexes The const input index of the primitive to be set.
set_const_input_indexes(const std::vector<size_t> & const_input_indexes)304   void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) {
305     const_input_indexes_ = const_input_indexes;
306   }
307   /// \brief Get const input index of the primitive.
308   ///
309   /// \return  Const input indexes of the primitive.
get_const_input_indexes()310   const std::vector<size_t> &get_const_input_indexes() const { return const_input_indexes_; }
311   /// \brief Get Primitive's id.
312   ///
313   /// \return primitive's Id.
id()314   uint64_t id() const { return id_; }
315 
316   /// \brief Check whether the primitive is inplace primitive.
317   ///
318   /// \return Return true if primitive is a inplace primitive, else return false.
inplace_prim()319   bool inplace_prim() const { return inplace_prim_; }
320   /// \brief Set primitive inplace flag.
321   ///
322   /// \param inplace_prim The flag of primitive to be set.
set_inplace_prim(bool inplace_prim)323   void set_inplace_prim(bool inplace_prim) { inplace_prim_ = inplace_prim; }
324 
325   /// \brief Enable primitive read/write lock.
EnableSharedMutex()326   void EnableSharedMutex() {
327     if (shared_mutex_ == nullptr) {
328       shared_mutex_ = std::make_shared<std::shared_mutex>();
329     }
330   }
331 
332   /// \brief Get primitive shared_mutex.
333   ///
334   /// \return Return shared_mutex of the primitive.
shared_mutex()335   const std::shared_ptr<std::shared_mutex> &shared_mutex() const { return shared_mutex_; }
336 
IsPythonPrim()337   virtual bool IsPythonPrim() { return false; }
338 
339  protected:
340   mindspore::HashMap<std::string, ValuePtr> attrs_;
341   mindspore::HashMap<std::string, ValuePtr> evaluate_added_attrs_;
342 
343  private:
344   std::string instance_name_;
345   PrimType prim_type_;
346   bool is_base_;
347   bool has_signature_;
348   std::vector<Signature> signatures_;
349   bool record_evaluate_add_attr_;
350   bool const_prim_;
351   bool inplace_prim_;
352   std::vector<size_t> const_input_indexes_;
353   uint64_t id_{0};
354   std::shared_ptr<std::shared_mutex> shared_mutex_{nullptr};
355 };
356 
357 inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
358   os << *p;
359   return os;
360 }
361 
362 /// \brief Equal operator for Primitive.
363 struct MS_CORE_API PrimitiveEqual {
364   /// \brief Implementation of Equal operation.
365   ///
366   /// \param t1 The left Primitive to compare.
367   /// \param t2 The right Primitive to compare.
368   /// \return The comparison result,Return true if the name and address of t1 and t2 are the same ,else return false.
operatorPrimitiveEqual369   bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
370     MS_EXCEPTION_IF_NULL(t1);
371     MS_EXCEPTION_IF_NULL(t2);
372     return t1 == t2 || t1->name() == t2->name();
373   }
374 };
375 
376 /// \brief Implementation of hash operation.
377 struct MS_CORE_API PrimitiveHasher {
378   /// \brief Implementation of hash operation.
379   ///
380   /// \param name The PrimitiveHasher to be hashed.
381   /// \return The hash result.
operatorPrimitiveHasher382   std::size_t operator()(PrimitivePtr const &prim) const {
383     MS_EXCEPTION_IF_NULL(prim);
384     return prim->Hash();
385   }
386 };
387 
388 /// \brief Equal operator for Primitive.
389 struct MS_CORE_API PrimitiveTotalEqual {
390   /// \brief Implementation of Equal operation.
391   ///
392   /// \param t1 The left Primitive to compare.
393   /// \param t2 The right Primitive to compare.
394   /// \return The comparison result,Return true if t1 and t2 are the same,else return false.
operatorPrimitiveTotalEqual395   bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
396     MS_EXCEPTION_IF_NULL(t1);
397     MS_EXCEPTION_IF_NULL(t2);
398     return *t1 == *t2;
399   }
400 };
401 }  // namespace mindspore
402 #endif  // MINDSPORE_CORE_IR_PRIMITIVE_H_
403