1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_PRIMITIVE_H_ 18 #define MINDSPORE_CORE_IR_PRIMITIVE_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include <tuple> 24 #include <utility> 25 #include <shared_mutex> 26 #include <initializer_list> 27 28 #include "utils/hash_map.h" 29 #include "ir/signature.h" 30 #include "ir/dtype/type.h" 31 #include "abstract/abstract_value.h" 32 #include "base/base_ref.h" 33 34 namespace mindspore { 35 // Supported meta type 36 enum PrimType { 37 kPrimTypeUnknown = 0, 38 kPrimTypeBegin = kPrimTypeUnknown, 39 kPrimTypeBuiltIn, // Built-in primitive operator 40 kPrimTypePyInfer, // Primitive operator with python infer function 41 kPrimTypeUserCustom, // Primitive operator defined by custom 42 kPrimTypePyCheck // Primitive operator with input args checking method 43 }; 44 45 class MS_CORE_API PrimitiveReadLock { 46 public: PrimitiveReadLock(std::shared_ptr<std::shared_mutex> shared_mutex)47 explicit PrimitiveReadLock(std::shared_ptr<std::shared_mutex> shared_mutex) : shared_mutex_(std::move(shared_mutex)) { 48 if (shared_mutex_ != nullptr) { 49 shared_mutex_->lock_shared(); 50 } 51 } ~PrimitiveReadLock()52 ~PrimitiveReadLock() { 53 if (shared_mutex_ != nullptr) { 54 // cppcheck-suppress unreadVariable 55 shared_mutex_->unlock_shared(); 56 } 57 } 58 59 private: 60 std::shared_ptr<std::shared_mutex> shared_mutex_; 61 }; 62 63 class MS_CORE_API PrimitiveWriteLock { 64 public: PrimitiveWriteLock(std::shared_ptr<std::shared_mutex> shared_mutex)65 explicit PrimitiveWriteLock(std::shared_ptr<std::shared_mutex> shared_mutex) 66 : shared_mutex_(std::move(shared_mutex)) { 67 if (shared_mutex_ != nullptr) { 68 shared_mutex_->lock(); 69 } 70 } ~PrimitiveWriteLock()71 ~PrimitiveWriteLock() { 72 if (shared_mutex_ != nullptr) { 73 shared_mutex_->unlock(); 74 } 75 } 76 77 private: 78 std::shared_ptr<std::shared_mutex> shared_mutex_; 79 }; 80 81 /// \brief Primitive defines a operator primitive of MindSpore. 82 class MS_CORE_API Primitive : public Named { 83 public: 84 /// \brief The constructor of Primitive. 85 /// 86 /// \param[in] name The name of primitive. 87 /// \param[in] is_base True means the basic Primitive without BProp function inside. 88 /// \param[in] prim_type The type of primitive. 89 explicit Primitive(const std::string &name, bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn, 90 bool inplace_prim = false); 91 Primitive(const std::string &name, const mindspore::HashMap<std::string, ValuePtr> &attrs, bool inplace_prim = false); 92 /// \brief The constructor for Primitive, create a primitive for another primitive. 93 /// 94 /// \param[in] prim The input primitive. 95 Primitive(const Primitive &prim); 96 /// \brief The copy assignment operator for Primitive. 97 /// 98 /// \param[in] other An existing Primitive object. 99 /// \return A Primitive object set with the same members as other. 100 Primitive &operator=(const Primitive &other); 101 MS_DECLARE_PARENT(Primitive, Named); 102 abstract::AbstractBasePtr ToAbstract() override; 103 std::string ToString() const override; 104 /// \brief Ready to recording the attribute if the attribute needs to be added when deducing shape and type. 105 /// This attributes has been recorded needs to add in infer cache. BeginRecordAddAttr()106 void BeginRecordAddAttr() { 107 evaluate_added_attrs_.clear(); 108 record_evaluate_add_attr_ = true; 109 } 110 /// \brief End recording attribute. EndRecordAddAttr()111 void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } 112 /// \brief Add attribute to primitive attribute map and record the new attribute to evaluate_added_attrs_, 113 /// if record_evaluate_add_attr_ is true. 114 /// 115 /// \param[in] name The name of attribute. 116 /// \param[in] attr The value of attribute. 117 /// \return The primitive to which attribute has been added. AddAttr(const std::string & name,const ValuePtr & attr)118 Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { 119 // cppcheck-suppress unreadVariable 120 PrimitiveWriteLock write_lock(shared_mutex_); 121 attrs_[name] = attr; 122 if (record_evaluate_add_attr_) { 123 evaluate_added_attrs_[name] = attr; 124 } 125 return *this; 126 } 127 /// \brief Delete the attribute. 128 /// 129 /// \param[in] name The name of attribute to be delete. 130 /// \return The primitive to which attribute has been added. DelAttr(const std::string & name)131 Primitive &DelAttr(const std::string &name) { 132 // cppcheck-suppress unreadVariable 133 PrimitiveWriteLock write_lock(shared_mutex_); 134 (void)attrs_.erase(name); 135 return *this; 136 } 137 /// \brief Use add attribute by using a map,all elements of the map will be added in the primitive's attribute map. 138 /// 139 /// \param[in] attrs The attribute map needs to be added in the primitive attribute. 140 /// \return The primitive to which attribute has been added. SetAttrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)141 Primitive &SetAttrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) { 142 PrimitiveWriteLock write_lock(shared_mutex_); 143 for (auto &attr : attrs) { 144 attrs_[attr.first] = attr.second; 145 } 146 return *this; 147 } 148 /// \brief Use add attribute by using initializer_list, all elements of the vector will be added in the primitive's 149 /// attribute map. 150 /// 151 /// \param[in] attrs The attribute vector needs to be added in the primitive attribute. 152 /// \return The primitive to which attribute has been added. SetAttrs(const std::initializer_list<std::pair<std::string,ValuePtr>> & attrs)153 Primitive &SetAttrs(const std::initializer_list<std::pair<std::string, ValuePtr>> &attrs) { 154 PrimitiveWriteLock write_lock(shared_mutex_); 155 for (auto &attr : attrs) { 156 attrs_[attr.first] = attr.second; 157 } 158 return *this; 159 } 160 /// \brief Use add attribute by using a vector, all elements of the vector will be added in the primitive's attribute 161 /// map. 162 /// 163 /// \param[in] attrs The attribute vector needs to be added in the primitive attribute. 164 /// \return The primitive to which attribute has been added. SetAttrs(const std::vector<std::pair<std::string,ValuePtr>> & attrs)165 Primitive &SetAttrs(const std::vector<std::pair<std::string, ValuePtr>> &attrs) { 166 PrimitiveWriteLock write_lock(shared_mutex_); 167 for (auto &attr : attrs) { 168 attrs_[attr.first] = attr.second; 169 } 170 return *this; 171 } 172 /// \brief Set attribute to the primitive attribute map. set_attr(const std::string & attrName,const ValuePtr & attr)173 void set_attr(const std::string &attrName, const ValuePtr &attr) { 174 // cppcheck-suppress unreadVariable 175 PrimitiveWriteLock write_lock(shared_mutex_); 176 attrs_[attrName] = attr; 177 } 178 /// \brief Erase attribute to the primitive attribute map. EraseAttr(const std::string & attrName)179 void EraseAttr(const std::string &attrName) { 180 // cppcheck-suppress unreadVariable 181 PrimitiveWriteLock write_lock(shared_mutex_); 182 (void)attrs_.erase(attrName); 183 } 184 /// \brief Run Primitive's compute function if the compute function has been implemented. 185 /// 186 /// \param[in] args The arguments of primitive need to compute. 187 /// \return The primitive's calculation result. RunComputeFunction(const VectorRef & args)188 virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; } 189 /// \brief Get Primitive's attribute. 190 /// 191 /// \param[in] attrName Primitive attribute name. 192 /// \return The value of attribute in primitive attribute map, if the map is not GetAttr(const std::string & attrName)193 ValuePtr GetAttr(const std::string &attrName) const { 194 PrimitiveReadLock read_lock(shared_mutex_); 195 auto iter = attrs_.find(attrName); 196 return iter == attrs_.cend() ? nullptr : iter->second; 197 } 198 /// \brief Get Primitive's all attributes. 199 /// 200 /// \return The Primitive's all attribute. attrs()201 const mindspore::HashMap<std::string, ValuePtr> &attrs() const { return attrs_; } 202 /// \brief Get the attributes added in MindSpore renormalize stage. 203 /// 204 /// \return Attributes which have been added in MindSpore renormalize stage. evaluate_added_attrs()205 const mindspore::HashMap<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; } 206 /// \brief Use add attribute using a map,all elements of the map will be added in the primitive's attribute map. 207 /// 208 /// \param[in] attrs The attribute map needs to be added in the primitive attribute. set_evaluate_added_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)209 void set_evaluate_added_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) { 210 // cppcheck-suppress unreadVariable 211 PrimitiveWriteLock write_lock(shared_mutex_); 212 for (auto &attr : attrs) { 213 (void)attrs_.insert_or_assign(attr.first, attr.second); 214 } 215 evaluate_added_attrs_ = attrs; 216 } 217 /// \brief Check if Primitive has any attribute. 218 /// for example Primitives like scalar_add, return, etc, don't have any attribute. 219 /// 220 /// \return Return ture, If Primitive has attributes, else return false. HasAttr()221 bool HasAttr() const { return !attrs_.empty(); } 222 /// \brief Check If Primitive has an attribute named attrName. 223 /// 224 /// \param[in] attrName The name of attribute. 225 /// \return Return true if Primitive has an attribute named attrName,else return false. HasAttr(const std::string & attrName)226 bool HasAttr(const std::string &attrName) const { 227 auto iter = attrs_.find(attrName); 228 return !(iter == attrs_.cend()); 229 } 230 /// \brief Set the name of primitive. 231 /// 232 /// \param t The primitive type that needs to be set. set_prim_type(const PrimType t)233 void set_prim_type(const PrimType t) { prim_type_ = t; } 234 /// \brief Clone a Primitive. 235 /// 236 /// \return A Primitive which cloned by current primitive. Clone()237 virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); } 238 /// \brief Set primitive instance_name. 239 /// 240 /// \param[in] s The primitive instance name to be set. set_instance_name(const std::string & s)241 void set_instance_name(const std::string &s) { instance_name_ = s; } 242 /// \brief Check whether the primitive type if has the Python infer function, 243 /// 244 /// \return Return true if Primitive's type is kPrimTypePyInfer or kPrimTypeUserCustom, else return false. HasPyEvaluator()245 bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInfer || prim_type_ == kPrimTypeUserCustom; } 246 /// \brief Check whether the primitive type if has the python infer function, 247 /// 248 /// \return Return true if Primitive's type is kPrimTypeUserCustom, else return false. IsCustomPrim()249 bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } 250 /// \brief Get Primitive type. 251 /// 252 /// \return The type of Primitive. prim_type()253 PrimType prim_type() const { return prim_type_; } 254 /// \brief Get primitive instance name. 255 /// 256 /// \return The instance name of primitive. instance_name()257 std::string instance_name() const { return instance_name_; } 258 /// \brief Get primitive attribute debug string. 259 /// If the attribute name of primitive is a,the value is b 260 /// The return value of GetAttrsText function is [a=b]. 261 /// 262 /// \return Get attribute debug string of primitive. 263 std::string GetAttrsText() const; 264 bool operator==(const Value &other) const override; 265 /// \brief To compare whether two Primitive objects are equal. 266 /// 267 /// \param[in] other The other Primitive be compared with. 268 /// \return return true if the name and attributes of primitives are the same,otherwise return false. 269 bool operator==(const Primitive &other) const; 270 /// \brief Destructor of Primitive. 271 ~Primitive() override = default; 272 /// \brief The flag to be set in primitive. 273 /// 274 /// \param[in] has_signature Set the flag whether there is a signature for the primitive. set_has_signature(bool has_signature)275 void set_has_signature(bool has_signature) { has_signature_ = has_signature; } 276 /// \brief Check whether the primitive has signature. 277 /// 278 /// \return Return true if primitive has signature flag , else return false. has_signature()279 bool has_signature() const { return has_signature_; } 280 /// \brief Set signatures of primitive. 281 /// 282 /// \param[in] signatures Set signatures of primitive. 283 void set_signatures(const std::vector<Signature> &signatures); 284 /// \brief Get signatures of primitive. 285 /// 286 /// \return Return signatures of primitive. signatures()287 const std::vector<Signature> &signatures() const { return signatures_; } 288 /// \brief Check whether the primitive is a basic primitive. 289 /// 290 /// \return Return true if the primitive is basic, else return false. is_base()291 bool is_base() const { return is_base_; } 292 /// \brief Set primitive const flag. 293 /// If the const_prim_ of primitive is true means the primitive will be eliminated in constant folding. 294 /// 295 /// \param is_const_prim The flag of primitive to be set. set_const_prim(bool is_const_prim)296 void set_const_prim(bool is_const_prim) { const_prim_ = is_const_prim; } 297 /// \brief Check whether the primitive is const primitive. 298 /// 299 /// \return Return true if primitive is a const primitive, else return false. const_prim()300 bool const_prim() const { return const_prim_; } 301 /// \brief Set const input index for primitive. 302 /// 303 /// \param const_input_indexes The const input index of the primitive to be set. set_const_input_indexes(const std::vector<size_t> & const_input_indexes)304 void set_const_input_indexes(const std::vector<size_t> &const_input_indexes) { 305 const_input_indexes_ = const_input_indexes; 306 } 307 /// \brief Get const input index of the primitive. 308 /// 309 /// \return Const input indexes of the primitive. get_const_input_indexes()310 const std::vector<size_t> &get_const_input_indexes() const { return const_input_indexes_; } 311 /// \brief Get Primitive's id. 312 /// 313 /// \return primitive's Id. id()314 uint64_t id() const { return id_; } 315 316 /// \brief Check whether the primitive is inplace primitive. 317 /// 318 /// \return Return true if primitive is a inplace primitive, else return false. inplace_prim()319 bool inplace_prim() const { return inplace_prim_; } 320 /// \brief Set primitive inplace flag. 321 /// 322 /// \param inplace_prim The flag of primitive to be set. set_inplace_prim(bool inplace_prim)323 void set_inplace_prim(bool inplace_prim) { inplace_prim_ = inplace_prim; } 324 325 /// \brief Enable primitive read/write lock. EnableSharedMutex()326 void EnableSharedMutex() { 327 if (shared_mutex_ == nullptr) { 328 shared_mutex_ = std::make_shared<std::shared_mutex>(); 329 } 330 } 331 332 /// \brief Get primitive shared_mutex. 333 /// 334 /// \return Return shared_mutex of the primitive. shared_mutex()335 const std::shared_ptr<std::shared_mutex> &shared_mutex() const { return shared_mutex_; } 336 IsPythonPrim()337 virtual bool IsPythonPrim() { return false; } 338 339 protected: 340 mindspore::HashMap<std::string, ValuePtr> attrs_; 341 mindspore::HashMap<std::string, ValuePtr> evaluate_added_attrs_; 342 343 private: 344 std::string instance_name_; 345 PrimType prim_type_; 346 bool is_base_; 347 bool has_signature_; 348 std::vector<Signature> signatures_; 349 bool record_evaluate_add_attr_; 350 bool const_prim_; 351 bool inplace_prim_; 352 std::vector<size_t> const_input_indexes_; 353 uint64_t id_{0}; 354 std::shared_ptr<std::shared_mutex> shared_mutex_{nullptr}; 355 }; 356 357 inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { 358 os << *p; 359 return os; 360 } 361 362 /// \brief Equal operator for Primitive. 363 struct MS_CORE_API PrimitiveEqual { 364 /// \brief Implementation of Equal operation. 365 /// 366 /// \param t1 The left Primitive to compare. 367 /// \param t2 The right Primitive to compare. 368 /// \return The comparison result,Return true if the name and address of t1 and t2 are the same ,else return false. operatorPrimitiveEqual369 bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { 370 MS_EXCEPTION_IF_NULL(t1); 371 MS_EXCEPTION_IF_NULL(t2); 372 return t1 == t2 || t1->name() == t2->name(); 373 } 374 }; 375 376 /// \brief Implementation of hash operation. 377 struct MS_CORE_API PrimitiveHasher { 378 /// \brief Implementation of hash operation. 379 /// 380 /// \param name The PrimitiveHasher to be hashed. 381 /// \return The hash result. operatorPrimitiveHasher382 std::size_t operator()(PrimitivePtr const &prim) const { 383 MS_EXCEPTION_IF_NULL(prim); 384 return prim->Hash(); 385 } 386 }; 387 388 /// \brief Equal operator for Primitive. 389 struct MS_CORE_API PrimitiveTotalEqual { 390 /// \brief Implementation of Equal operation. 391 /// 392 /// \param t1 The left Primitive to compare. 393 /// \param t2 The right Primitive to compare. 394 /// \return The comparison result,Return true if t1 and t2 are the same,else return false. operatorPrimitiveTotalEqual395 bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { 396 MS_EXCEPTION_IF_NULL(t1); 397 MS_EXCEPTION_IF_NULL(t2); 398 return *t1 == *t2; 399 } 400 }; 401 } // namespace mindspore 402 #endif // MINDSPORE_CORE_IR_PRIMITIVE_H_ 403