1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_SYMBOL_H_ 17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_SYMBOL_H_ 18 #include <memory> 19 #include <vector> 20 #include <algorithm> 21 #include <ostream> 22 #include <string> 23 #include <utility> 24 #include "base/base.h" 25 #include "ir/value.h" 26 27 #ifndef MS_UNLIKELY 28 #ifdef _MSC_VER 29 #define MS_UNLIKELY(x) (x) 30 #else 31 #define MS_UNLIKELY(x) __builtin_expect(!!(x), 0) 32 #endif 33 #endif 34 35 #ifndef MS_LIKELY 36 #ifdef _MSC_VER 37 #define MS_LIKELY(x) (x) 38 #else 39 #define MS_LIKELY(x) __builtin_expect(!!(x), 1) 40 #endif 41 #endif 42 43 namespace mindspore { 44 namespace symshape { 45 class Symbol; 46 using SymbolPtr = std::shared_ptr<Symbol>; 47 using SymbolPtrList = std::vector<SymbolPtr>; 48 49 class IntSymbol; 50 using IntSymbolPtr = std::shared_ptr<IntSymbol>; 51 52 class Operation; 53 using OpPtr = std::shared_ptr<Operation>; 54 using OpPtrList = std::vector<OpPtr>; 55 using OpWeakPtr = std::weak_ptr<Operation>; 56 57 /// \brief The base class of symbol objects in symbolic shape. 58 /// 59 /// The symbol can represent a shape, items of shape, or values for inferring shape, etc. 60 /// 61 /// NOTE: The 'cast' and 'isa' function of Base is hid, 'cast_ptr' can be used to convert the original symbol. 62 /// Use 'as' and 'is' to cast and check the type of symbol, to make the `DynamicSymbol` transparent in most situation. 63 class MS_CORE_API Symbol : public Base { 64 public: 65 /// \brief Constructor of Symbol 66 /// 67 /// \param[in] op The operation that built this symbol (if exists) operation_(op)68 explicit Symbol(const OpPtr &op = nullptr) : operation_(op) {} 69 ~Symbol() override = default; MS_DECLARE_PARENT(Symbol,Base)70 MS_DECLARE_PARENT(Symbol, Base) 71 72 /// \brief Update the symbol data in runtime. Only variable symbol can be updated. 73 inline void Update(const SymbolPtr &s) { 74 if (MS_LIKELY(s != nullptr && s.get() != this)) { 75 UpdateImpl(s); 76 } 77 } 78 79 /// \brief Whether the symbol has data. 80 /// 81 /// Variable symbol has no data in compiling, it has data after updating in runtime. 82 /// Constant symbol always has data. HasData()83 virtual bool HasData() const { return true; } 84 85 /// @brief Whether the symbol can be updated in runtime, only variable symbol can be updated. CanUpdate()86 virtual bool CanUpdate() const { return true; } 87 88 /// \brief Whether two symbols are equal in mathematic. 89 virtual bool operator==(const Symbol &s) const { return this == &s; } 90 91 /// \brief Whether two symbols are equal in mathematic. EqualsTo(const SymbolPtr & other)92 inline bool EqualsTo(const SymbolPtr &other) const { return (other != nullptr) && ((*this) == (*other)); } 93 94 /// \brief Get the raw data of symbol. ToRawString()95 virtual std::string ToRawString() const { return ToString(); } 96 97 /// \brief Convert the symbol to a ValuePtr ToValue()98 virtual ValuePtr ToValue() const { return kValueAny; } ToValueOf(const TypePtr &)99 virtual ValuePtr ToValueOf(const TypePtr &) const { return ToValue(); } 100 101 /// \brief Get the operation that built this symbol. operation()102 inline OpPtr operation() const { return operation_.lock(); } 103 104 /// \brief Judge whether this object is an instance of a given class which is derived from Symbol. 105 template <typename T> is()106 inline bool is() const { 107 auto *s = const_cast<Symbol *>(this)->real_symbol(); 108 return s != nullptr && s->isa<T>(); 109 } 110 111 /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, an exception will be thrown. 112 template <typename T> as()113 inline T *as() { 114 auto ret = as_noexcept<T>(); 115 if (MS_UNLIKELY(ret == nullptr)) { 116 MS_LOG(INTERNAL_EXCEPTION) << "Failed to cast the symbol " << ToString() << " to " << typeid(T).name(); 117 } 118 return ret; 119 } 120 121 /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, an exception will be thrown. 122 template <typename T> as()123 inline const T *as() const { 124 auto ret = as_noexcept<T>(); 125 if (MS_UNLIKELY(ret == nullptr)) { 126 MS_LOG(INTERNAL_EXCEPTION) << "Failed to cast the symbol " << ToString() << " to " << typeid(T).name(); 127 } 128 return ret; 129 } 130 131 /// \brief Cast to a shared_ptr of the given class, if the object type doesn't match, an exception will be thrown. 132 template <typename T> as_sptr()133 inline std::shared_ptr<T> as_sptr() { 134 auto ret = as_sptr_noexcept<T>(); 135 if (MS_UNLIKELY(ret == nullptr)) { 136 MS_LOG(INTERNAL_EXCEPTION) << "Failed to cast the symbol " << ToString() << " to " << typeid(T).name(); 137 } 138 return ret; 139 } 140 141 /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, a nullptr will be returned. 142 template <typename T> as_noexcept()143 inline T *as_noexcept() { 144 auto s = real_symbol(); 145 return MS_UNLIKELY(s == nullptr) ? nullptr : s->cast_ptr<T>(); 146 } 147 148 /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, a nullptr will be returned. 149 template <typename T> as_noexcept()150 inline const T *as_noexcept() const { 151 auto *s = const_cast<Symbol *>(this)->real_symbol(); 152 return MS_UNLIKELY(s == nullptr) ? nullptr : s->cast_ptr<T>(); 153 } 154 155 /// \brief Cast to a shared_ptr of the given class, if the object type doesn't match, a nullptr will be returned. 156 template <typename T> as_sptr_noexcept()157 inline std::shared_ptr<T> as_sptr_noexcept() { 158 auto s = real_symbol(); 159 return MS_UNLIKELY(s == nullptr) ? nullptr : s->cast<std::shared_ptr<T>>(); 160 } 161 162 protected: 163 using Base::cast; 164 using Base::isa; UpdateImpl(const SymbolPtr & s)165 virtual void UpdateImpl(const SymbolPtr &s) { 166 MS_EXCEPTION(NotImplementedError) << "The 'Update' of " << type_name() << " is not implemented."; 167 } real_symbol()168 virtual Symbol *real_symbol() { return this; } sid()169 inline std::string sid() const { return "s" + std::to_string(id()); } 170 OpWeakPtr operation_; 171 172 private: 173 size_t id() const; 174 mutable size_t id_{0}; 175 }; 176 177 /// \brief DynamicSymbol represents the symbol type is dynamic, such as "symbol of scalar or list". 178 class MS_CORE_API DynamicSymbol : public Symbol { 179 public: 180 using Symbol::Symbol; 181 ~DynamicSymbol() override = default; MS_DECLARE_PARENT(DynamicSymbol,Symbol)182 MS_DECLARE_PARENT(DynamicSymbol, Symbol) 183 inline static std::shared_ptr<DynamicSymbol> Make(const OpPtr &op = nullptr) { 184 return std::make_shared<DynamicSymbol>(op); 185 } 186 bool operator==(const Symbol &s) const override { return (this == &s) || ((symbol_ != nullptr) && (*symbol_ == s)); } HasData()187 bool HasData() const override { return symbol_ != nullptr; } ToString()188 std::string ToString() const override { return symbol_ == nullptr ? "DYN-" + sid() : symbol_->ToString(); } ToRawString()189 std::string ToRawString() const override { return symbol_ == nullptr ? sid() : symbol_->ToRawString(); } ToValue()190 ValuePtr ToValue() const override { return symbol_ == nullptr ? Symbol::ToValue() : symbol_->ToValue(); } ToValueOf(const TypePtr & type)191 ValuePtr ToValueOf(const TypePtr &type) const override { 192 return symbol_ == nullptr ? Symbol::ToValue() : symbol_->ToValueOf(type); 193 } symbol()194 const SymbolPtr &symbol() const { return symbol_; } 195 196 protected: 197 void UpdateImpl(const SymbolPtr &s) override; real_symbol()198 Symbol *real_symbol() override { return symbol_.get(); } 199 SymbolPtr symbol_{nullptr}; 200 }; 201 using DynamicSymbolPtr = std::shared_ptr<DynamicSymbol>; 202 203 /// \brief The base class of scalar objects. 204 class MS_CORE_API ScalarSymbol : public Symbol { 205 public: ScalarSymbol(bool is_const,bool has_data,const OpPtr & op)206 ScalarSymbol(bool is_const, bool has_data, const OpPtr &op) : Symbol(op), is_const_(is_const), has_data_(has_data) {} 207 ~ScalarSymbol() override = default; MS_DECLARE_PARENT(ScalarSymbol,Symbol)208 MS_DECLARE_PARENT(ScalarSymbol, Symbol) 209 bool HasData() const override { return has_data_; } CanUpdate()210 bool CanUpdate() const override { return !is_const_; } 211 bool operator==(const Symbol &s) const override; ToString()212 std::string ToString() const override { return ToRawString(); } is_const()213 bool is_const() const { return is_const_; } 214 215 protected: 216 void UpdateImpl(const SymbolPtr &s) override; 217 /// \brief set value, called by `UpdateImpl` SetValueByScalar(const Symbol * s)218 virtual void SetValueByScalar(const Symbol *s) { 219 MS_EXCEPTION(NotImplementedError) << "The 'SetValueByScalar' of " << type_name() << " is not implemented."; 220 } 221 /// \brief check value equal, called by `operator==` CheckEqualValue(const Symbol * s)222 virtual bool CheckEqualValue(const Symbol *s) const { 223 MS_EXCEPTION(NotImplementedError) << "The 'CheckEqualValue' of " << type_name() << " is not implemented."; 224 } 225 226 bool is_const_; 227 bool has_data_; 228 }; 229 using ScalarSymbolPtr = std::shared_ptr<ScalarSymbol>; 230 231 class MS_CORE_API BoolSymbol final : public ScalarSymbol { 232 public: 233 using elem_type = bool; 234 using ScalarSymbol::ScalarSymbol; 235 ~BoolSymbol() override = default; MS_DECLARE_PARENT(BoolSymbol,ScalarSymbol)236 MS_DECLARE_PARENT(BoolSymbol, ScalarSymbol) 237 static inline std::shared_ptr<BoolSymbol> Make(bool val, const OpPtr &op = nullptr) { 238 auto s = std::make_shared<BoolSymbol>(true, true, op); 239 s->value_ = val; 240 return s; 241 } 242 static inline std::shared_ptr<BoolSymbol> Make(const OpPtr &op = nullptr) { 243 return std::make_shared<BoolSymbol>(false, false, op); 244 } SetValue(bool v)245 inline void SetValue(bool v) { 246 MS_EXCEPTION_IF_CHECK_FAIL(!is_const_, ToString() + " is const symbol and cannot be updated."); 247 has_data_ = true; 248 value_ = v; 249 } value()250 inline bool value() const { 251 MS_EXCEPTION_IF_CHECK_FAIL(has_data_, ToString() + "has no value."); 252 return value_; 253 } 254 std::string ToRawString() const override; 255 ValuePtr ToValue() const override; 256 257 protected: SetValueByScalar(const Symbol * s)258 void SetValueByScalar(const Symbol *s) override { value_ = static_cast<const BoolSymbol *>(s)->value_; } CheckEqualValue(const Symbol * s)259 bool CheckEqualValue(const Symbol *s) const override { return value_ == static_cast<const BoolSymbol *>(s)->value_; } 260 261 bool value_{false}; 262 }; 263 using BoolSymbolPtr = std::shared_ptr<BoolSymbol>; 264 265 class MS_CORE_API FloatSymbol final : public ScalarSymbol { 266 public: 267 using elem_type = double; 268 using ScalarSymbol::ScalarSymbol; 269 ~FloatSymbol() override = default; MS_DECLARE_PARENT(FloatSymbol,ScalarSymbol)270 MS_DECLARE_PARENT(FloatSymbol, ScalarSymbol) 271 static inline std::shared_ptr<FloatSymbol> Make(elem_type val, const OpPtr &op = nullptr) { 272 auto s = std::make_shared<FloatSymbol>(true, true, op); 273 s->value_ = val; 274 return s; 275 } 276 static inline std::shared_ptr<FloatSymbol> Make(const OpPtr &op = nullptr) { 277 return std::make_shared<FloatSymbol>(false, false, op); 278 } SetValue(elem_type v)279 inline void SetValue(elem_type v) { 280 MS_EXCEPTION_IF_CHECK_FAIL(!is_const_, ToString() + " is const symbol and cannot be updated."); 281 has_data_ = true; 282 value_ = v; 283 } value()284 inline elem_type value() const { 285 MS_EXCEPTION_IF_CHECK_FAIL(has_data_, ToString() + "has no value."); 286 return value_; 287 } 288 std::string ToRawString() const override; 289 ValuePtr ToValue() const override; 290 ValuePtr ToValueOf(const TypePtr &type) const override; 291 292 protected: SetValueByScalar(const Symbol * s)293 void SetValueByScalar(const Symbol *s) override { value_ = static_cast<const FloatSymbol *>(s)->value_; } CheckEqualValue(const Symbol * s)294 bool CheckEqualValue(const Symbol *s) const override { return value_ == static_cast<const FloatSymbol *>(s)->value_; } 295 296 elem_type value_{0}; 297 }; 298 using FloatSymbolPtr = std::shared_ptr<FloatSymbol>; 299 300 class MS_CORE_API StrSymbol final : public ScalarSymbol { 301 public: 302 using ScalarSymbol::ScalarSymbol; 303 ~StrSymbol() override = default; MS_DECLARE_PARENT(StrSymbol,ScalarSymbol)304 MS_DECLARE_PARENT(StrSymbol, ScalarSymbol) 305 static inline std::shared_ptr<StrSymbol> Make(const std::string &val, const OpPtr &op = nullptr) { 306 auto s = std::make_shared<StrSymbol>(true, true, op); 307 s->value_ = val; 308 return s; 309 } 310 static inline std::shared_ptr<StrSymbol> Make(const OpPtr &op = nullptr) { 311 return std::make_shared<StrSymbol>(false, false, op); 312 } SetValue(const std::string & v)313 inline void SetValue(const std::string &v) { 314 MS_EXCEPTION_IF_CHECK_FAIL(!is_const_, ToString() + " is const symbol and cannot be updated."); 315 has_data_ = true; 316 value_ = v; 317 } value()318 inline const std::string &value() const { 319 MS_EXCEPTION_IF_CHECK_FAIL(has_data_, ToString() + "has no value."); 320 return value_; 321 } 322 std::string ToRawString() const override; 323 ValuePtr ToValue() const override; 324 325 protected: SetValueByScalar(const Symbol * s)326 void SetValueByScalar(const Symbol *s) override { value_ = static_cast<const StrSymbol *>(s)->value_; } CheckEqualValue(const Symbol * s)327 bool CheckEqualValue(const Symbol *s) const override { return value_ == static_cast<const StrSymbol *>(s)->value_; } 328 329 std::string value_; 330 }; 331 using StrSymbolPtr = std::shared_ptr<StrSymbol>; 332 333 class MS_CORE_API ListSymbol final : public Symbol { 334 public: 335 using SPtr = std::shared_ptr<ListSymbol>; ListSymbol(const SymbolPtrList & slist,const OpPtr & op)336 ListSymbol(const SymbolPtrList &slist, const OpPtr &op) : Symbol(op), symbols_(slist) {} ListSymbol(SymbolPtrList && slist,const OpPtr & op)337 ListSymbol(SymbolPtrList &&slist, const OpPtr &op) : Symbol(op), symbols_(slist) {} ListSymbol(const std::initializer_list<SymbolPtr> & slist,const OpPtr & op)338 ListSymbol(const std::initializer_list<SymbolPtr> &slist, const OpPtr &op) : Symbol(op), symbols_(slist) {} ListSymbol(const OpPtr & op)339 explicit ListSymbol(const OpPtr &op) : Symbol(op), is_dyn_len_(true), has_data_(false) {} 340 ~ListSymbol() override = default; MS_DECLARE_PARENT(ListSymbol,Symbol)341 MS_DECLARE_PARENT(ListSymbol, Symbol) 342 343 static inline SPtr Make(const SymbolPtrList &slist, const OpPtr &op = nullptr) { 344 return std::make_shared<ListSymbol>(slist, op); 345 } 346 static inline SPtr Make(SymbolPtrList &&slist, const OpPtr &op = nullptr) { 347 return std::make_shared<ListSymbol>(slist, op); 348 } 349 static inline SPtr Make(const std::initializer_list<SymbolPtr> &slist, const OpPtr &op = nullptr) { 350 return std::make_shared<ListSymbol>(slist, op); 351 } 352 static inline SPtr Make(const OpPtr &op = nullptr) { return std::make_shared<ListSymbol>(op); } 353 354 bool operator==(const Symbol &s) const override; 355 std::string ToString() const override; 356 std::string ToRawString() const override; 357 ValuePtr ToValue() const override; 358 ValuePtr ToValueOf(const TypePtr &type) const override; 359 HasData()360 bool HasData() const override { return has_data_; } AllHaveData()361 bool AllHaveData() const { 362 return has_data_ && std::all_of(symbols_.cbegin(), symbols_.cend(), [](const SymbolPtr &s) { 363 return s->is<ListSymbol>() ? s->as_noexcept<ListSymbol>()->AllHaveData() : s->HasData(); 364 }); 365 } CanUpdate()366 bool CanUpdate() const override { 367 return is_dyn_len_ || std::any_of(symbols_.cbegin(), symbols_.cend(), [](auto &s) { return s->CanUpdate(); }); 368 } 369 void UpdateList(const SymbolPtrList &slist); UpdateList(SymbolPtrList && slist)370 inline void UpdateList(SymbolPtrList &&slist) { 371 if (is_dyn_len_) { 372 has_data_ = true; 373 symbols_ = slist; 374 } else { 375 UpdateList(static_cast<const SymbolPtrList &>(slist)); 376 } 377 } 378 const SymbolPtr &item(size_t i) const; 379 template <typename T> item_as(size_t i)380 const T *item_as(size_t i) const { 381 auto ret = item(i)->as_noexcept<T>(); 382 if (MS_UNLIKELY(ret == nullptr)) { 383 MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for item " << i << " of " << ToString(); 384 } 385 return ret; 386 } 387 template <typename T> item_as_sptr(size_t i)388 std::shared_ptr<T> item_as_sptr(size_t i) const { 389 auto ret = item(i)->as_sptr_noexcept<T>(); 390 if (MS_UNLIKELY(ret == nullptr)) { 391 MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for item " << i << " of " << ToString(); 392 } 393 return ret; 394 } symbols()395 const SymbolPtrList &symbols() const { return symbols_; } size()396 size_t size() const { return symbols_.size(); } is_dyn_len()397 bool is_dyn_len() const { return is_dyn_len_; } 398 399 protected: 400 void UpdateImpl(const SymbolPtr &s) override; 401 SymbolPtrList symbols_; 402 bool is_dyn_len_{false}; 403 bool has_data_{true}; 404 }; 405 using ListSymbolPtr = std::shared_ptr<ListSymbol>; 406 } // namespace symshape 407 408 using symshape::BoolSymbol; 409 using symshape::BoolSymbolPtr; 410 using symshape::DynamicSymbol; 411 using symshape::DynamicSymbolPtr; 412 using symshape::FloatSymbol; 413 using symshape::FloatSymbolPtr; 414 using symshape::IntSymbol; 415 using symshape::IntSymbolPtr; 416 using symshape::ListSymbol; 417 using symshape::ListSymbolPtr; 418 using symshape::ScalarSymbol; 419 using symshape::ScalarSymbolPtr; 420 using symshape::Symbol; 421 using symshape::SymbolPtr; 422 using symshape::SymbolPtrList; 423 } // namespace mindspore 424 #endif // MINDSPORE_CORE_SYMBOLIC_SHAPE_SYMBOL_H_ 425