• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_SYMBOL_H_
17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_SYMBOL_H_
18 #include <memory>
19 #include <vector>
20 #include <algorithm>
21 #include <ostream>
22 #include <string>
23 #include <utility>
24 #include "base/base.h"
25 #include "ir/value.h"
26 
// Branch-prediction hints.
//
// MS_LIKELY(x) / MS_UNLIKELY(x) evaluate `x` exactly once and yield a value that is truthy exactly when `x` is
// truthy. On GNU-compatible compilers (GCC, Clang) they expand to `__builtin_expect`, which tells the optimizer
// which outcome is the common one. On MSVC, and on any other compiler where `__builtin_expect` cannot be assumed
// to exist, they expand to the plain expression, so the code still compiles.
//
// Both macros are only defined if the name is not already defined.
#ifndef MS_UNLIKELY
#if !defined(_MSC_VER) && (defined(__GNUC__) || defined(__clang__))
#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#define MS_UNLIKELY(x) (x)
#endif
#endif

#ifndef MS_LIKELY
#if !defined(_MSC_VER) && (defined(__GNUC__) || defined(__clang__))
#define MS_LIKELY(x) __builtin_expect(!!(x), 1)
#else
#define MS_LIKELY(x) (x)
#endif
#endif
42 
43 namespace mindspore {
44 namespace symshape {
// Forward declarations of the types referenced by the aliases below.
class Symbol;
class IntSymbol;
class Operation;

// Shared-ownership handles for symbols.
using SymbolPtr = std::shared_ptr<Symbol>;
using SymbolPtrList = std::vector<SymbolPtr>;
using IntSymbolPtr = std::shared_ptr<IntSymbol>;

// Handles for operations: owning, list-of-owning, and non-owning (weak).
using OpPtr = std::shared_ptr<Operation>;
using OpPtrList = std::vector<OpPtr>;
using OpWeakPtr = std::weak_ptr<Operation>;
56 
57 /// \brief The base class of symbol objects in symbolic shape.
58 ///
59 /// The symbol can represent a shape, items of shape, or values for inferring shape, etc.
60 ///
61 /// NOTE: The 'cast' and 'isa' function of Base is hid, 'cast_ptr' can be used to convert the original symbol.
62 /// Use 'as' and 'is' to cast and check the type of symbol, to make the `DynamicSymbol` transparent in most situation.
63 class MS_CORE_API Symbol : public Base {
64  public:
65   /// \brief Constructor of Symbol
66   ///
67   /// \param[in] op The operation that built this symbol (if exists)
operation_(op)68   explicit Symbol(const OpPtr &op = nullptr) : operation_(op) {}
69   ~Symbol() override = default;
MS_DECLARE_PARENT(Symbol,Base)70   MS_DECLARE_PARENT(Symbol, Base)
71 
72   /// \brief Update the symbol data in runtime. Only variable symbol can be updated.
73   inline void Update(const SymbolPtr &s) {
74     if (MS_LIKELY(s != nullptr && s.get() != this)) {
75       UpdateImpl(s);
76     }
77   }
78 
79   /// \brief Whether the symbol has data.
80   ///
81   /// Variable symbol has no data in compiling, it has data after updating in runtime.
82   /// Constant symbol always has data.
HasData()83   virtual bool HasData() const { return true; }
84 
85   /// @brief Whether the symbol can be updated in runtime, only variable symbol can be updated.
CanUpdate()86   virtual bool CanUpdate() const { return true; }
87 
88   /// \brief Whether two symbols are equal in mathematic.
89   virtual bool operator==(const Symbol &s) const { return this == &s; }
90 
91   /// \brief Whether two symbols are equal in mathematic.
EqualsTo(const SymbolPtr & other)92   inline bool EqualsTo(const SymbolPtr &other) const { return (other != nullptr) && ((*this) == (*other)); }
93 
94   /// \brief Get the raw data of symbol.
ToRawString()95   virtual std::string ToRawString() const { return ToString(); }
96 
97   /// \brief Convert the symbol to a ValuePtr
ToValue()98   virtual ValuePtr ToValue() const { return kValueAny; }
ToValueOf(const TypePtr &)99   virtual ValuePtr ToValueOf(const TypePtr &) const { return ToValue(); }
100 
101   /// \brief Get the operation that built this symbol.
operation()102   inline OpPtr operation() const { return operation_.lock(); }
103 
104   /// \brief Judge whether this object is an instance of a given class which is derived from Symbol.
105   template <typename T>
is()106   inline bool is() const {
107     auto *s = const_cast<Symbol *>(this)->real_symbol();
108     return s != nullptr && s->isa<T>();
109   }
110 
111   /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, an exception will be thrown.
112   template <typename T>
as()113   inline T *as() {
114     auto ret = as_noexcept<T>();
115     if (MS_UNLIKELY(ret == nullptr)) {
116       MS_LOG(INTERNAL_EXCEPTION) << "Failed to cast the symbol " << ToString() << " to " << typeid(T).name();
117     }
118     return ret;
119   }
120 
121   /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, an exception will be thrown.
122   template <typename T>
as()123   inline const T *as() const {
124     auto ret = as_noexcept<T>();
125     if (MS_UNLIKELY(ret == nullptr)) {
126       MS_LOG(INTERNAL_EXCEPTION) << "Failed to cast the symbol " << ToString() << " to " << typeid(T).name();
127     }
128     return ret;
129   }
130 
131   /// \brief Cast to a shared_ptr of the given class, if the object type doesn't match, an exception will be thrown.
132   template <typename T>
as_sptr()133   inline std::shared_ptr<T> as_sptr() {
134     auto ret = as_sptr_noexcept<T>();
135     if (MS_UNLIKELY(ret == nullptr)) {
136       MS_LOG(INTERNAL_EXCEPTION) << "Failed to cast the symbol " << ToString() << " to " << typeid(T).name();
137     }
138     return ret;
139   }
140 
141   /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, a nullptr will be returned.
142   template <typename T>
as_noexcept()143   inline T *as_noexcept() {
144     auto s = real_symbol();
145     return MS_UNLIKELY(s == nullptr) ? nullptr : s->cast_ptr<T>();
146   }
147 
148   /// \brief Cast to a raw pointer of the given class, if the object type doesn't match, a nullptr will be returned.
149   template <typename T>
as_noexcept()150   inline const T *as_noexcept() const {
151     auto *s = const_cast<Symbol *>(this)->real_symbol();
152     return MS_UNLIKELY(s == nullptr) ? nullptr : s->cast_ptr<T>();
153   }
154 
155   /// \brief Cast to a shared_ptr of the given class, if the object type doesn't match, a nullptr will be returned.
156   template <typename T>
as_sptr_noexcept()157   inline std::shared_ptr<T> as_sptr_noexcept() {
158     auto s = real_symbol();
159     return MS_UNLIKELY(s == nullptr) ? nullptr : s->cast<std::shared_ptr<T>>();
160   }
161 
162  protected:
163   using Base::cast;
164   using Base::isa;
UpdateImpl(const SymbolPtr & s)165   virtual void UpdateImpl(const SymbolPtr &s) {
166     MS_EXCEPTION(NotImplementedError) << "The 'Update' of " << type_name() << " is not implemented.";
167   }
real_symbol()168   virtual Symbol *real_symbol() { return this; }
sid()169   inline std::string sid() const { return "s" + std::to_string(id()); }
170   OpWeakPtr operation_;
171 
172  private:
173   size_t id() const;
174   mutable size_t id_{0};
175 };
176 
177 /// \brief DynamicSymbol represents the symbol type is dynamic, such as "symbol of scalar or list".
178 class MS_CORE_API DynamicSymbol : public Symbol {
179  public:
180   using Symbol::Symbol;
181   ~DynamicSymbol() override = default;
MS_DECLARE_PARENT(DynamicSymbol,Symbol)182   MS_DECLARE_PARENT(DynamicSymbol, Symbol)
183   inline static std::shared_ptr<DynamicSymbol> Make(const OpPtr &op = nullptr) {
184     return std::make_shared<DynamicSymbol>(op);
185   }
186   bool operator==(const Symbol &s) const override { return (this == &s) || ((symbol_ != nullptr) && (*symbol_ == s)); }
HasData()187   bool HasData() const override { return symbol_ != nullptr; }
ToString()188   std::string ToString() const override { return symbol_ == nullptr ? "DYN-" + sid() : symbol_->ToString(); }
ToRawString()189   std::string ToRawString() const override { return symbol_ == nullptr ? sid() : symbol_->ToRawString(); }
ToValue()190   ValuePtr ToValue() const override { return symbol_ == nullptr ? Symbol::ToValue() : symbol_->ToValue(); }
ToValueOf(const TypePtr & type)191   ValuePtr ToValueOf(const TypePtr &type) const override {
192     return symbol_ == nullptr ? Symbol::ToValue() : symbol_->ToValueOf(type);
193   }
symbol()194   const SymbolPtr &symbol() const { return symbol_; }
195 
196  protected:
197   void UpdateImpl(const SymbolPtr &s) override;
real_symbol()198   Symbol *real_symbol() override { return symbol_.get(); }
199   SymbolPtr symbol_{nullptr};
200 };
201 using DynamicSymbolPtr = std::shared_ptr<DynamicSymbol>;
202 
203 /// \brief The base class of scalar objects.
204 class MS_CORE_API ScalarSymbol : public Symbol {
205  public:
ScalarSymbol(bool is_const,bool has_data,const OpPtr & op)206   ScalarSymbol(bool is_const, bool has_data, const OpPtr &op) : Symbol(op), is_const_(is_const), has_data_(has_data) {}
207   ~ScalarSymbol() override = default;
MS_DECLARE_PARENT(ScalarSymbol,Symbol)208   MS_DECLARE_PARENT(ScalarSymbol, Symbol)
209   bool HasData() const override { return has_data_; }
CanUpdate()210   bool CanUpdate() const override { return !is_const_; }
211   bool operator==(const Symbol &s) const override;
ToString()212   std::string ToString() const override { return ToRawString(); }
is_const()213   bool is_const() const { return is_const_; }
214 
215  protected:
216   void UpdateImpl(const SymbolPtr &s) override;
217   /// \brief set value, called by `UpdateImpl`
SetValueByScalar(const Symbol * s)218   virtual void SetValueByScalar(const Symbol *s) {
219     MS_EXCEPTION(NotImplementedError) << "The 'SetValueByScalar' of " << type_name() << " is not implemented.";
220   }
221   /// \brief check value equal, called by `operator==`
CheckEqualValue(const Symbol * s)222   virtual bool CheckEqualValue(const Symbol *s) const {
223     MS_EXCEPTION(NotImplementedError) << "The 'CheckEqualValue' of " << type_name() << " is not implemented.";
224   }
225 
226   bool is_const_;
227   bool has_data_;
228 };
229 using ScalarSymbolPtr = std::shared_ptr<ScalarSymbol>;
230 
231 class MS_CORE_API BoolSymbol final : public ScalarSymbol {
232  public:
233   using elem_type = bool;
234   using ScalarSymbol::ScalarSymbol;
235   ~BoolSymbol() override = default;
MS_DECLARE_PARENT(BoolSymbol,ScalarSymbol)236   MS_DECLARE_PARENT(BoolSymbol, ScalarSymbol)
237   static inline std::shared_ptr<BoolSymbol> Make(bool val, const OpPtr &op = nullptr) {
238     auto s = std::make_shared<BoolSymbol>(true, true, op);
239     s->value_ = val;
240     return s;
241   }
242   static inline std::shared_ptr<BoolSymbol> Make(const OpPtr &op = nullptr) {
243     return std::make_shared<BoolSymbol>(false, false, op);
244   }
SetValue(bool v)245   inline void SetValue(bool v) {
246     MS_EXCEPTION_IF_CHECK_FAIL(!is_const_, ToString() + " is const symbol and cannot be updated.");
247     has_data_ = true;
248     value_ = v;
249   }
value()250   inline bool value() const {
251     MS_EXCEPTION_IF_CHECK_FAIL(has_data_, ToString() + "has no value.");
252     return value_;
253   }
254   std::string ToRawString() const override;
255   ValuePtr ToValue() const override;
256 
257  protected:
SetValueByScalar(const Symbol * s)258   void SetValueByScalar(const Symbol *s) override { value_ = static_cast<const BoolSymbol *>(s)->value_; }
CheckEqualValue(const Symbol * s)259   bool CheckEqualValue(const Symbol *s) const override { return value_ == static_cast<const BoolSymbol *>(s)->value_; }
260 
261   bool value_{false};
262 };
263 using BoolSymbolPtr = std::shared_ptr<BoolSymbol>;
264 
265 class MS_CORE_API FloatSymbol final : public ScalarSymbol {
266  public:
267   using elem_type = double;
268   using ScalarSymbol::ScalarSymbol;
269   ~FloatSymbol() override = default;
MS_DECLARE_PARENT(FloatSymbol,ScalarSymbol)270   MS_DECLARE_PARENT(FloatSymbol, ScalarSymbol)
271   static inline std::shared_ptr<FloatSymbol> Make(elem_type val, const OpPtr &op = nullptr) {
272     auto s = std::make_shared<FloatSymbol>(true, true, op);
273     s->value_ = val;
274     return s;
275   }
276   static inline std::shared_ptr<FloatSymbol> Make(const OpPtr &op = nullptr) {
277     return std::make_shared<FloatSymbol>(false, false, op);
278   }
SetValue(elem_type v)279   inline void SetValue(elem_type v) {
280     MS_EXCEPTION_IF_CHECK_FAIL(!is_const_, ToString() + " is const symbol and cannot be updated.");
281     has_data_ = true;
282     value_ = v;
283   }
value()284   inline elem_type value() const {
285     MS_EXCEPTION_IF_CHECK_FAIL(has_data_, ToString() + "has no value.");
286     return value_;
287   }
288   std::string ToRawString() const override;
289   ValuePtr ToValue() const override;
290   ValuePtr ToValueOf(const TypePtr &type) const override;
291 
292  protected:
SetValueByScalar(const Symbol * s)293   void SetValueByScalar(const Symbol *s) override { value_ = static_cast<const FloatSymbol *>(s)->value_; }
CheckEqualValue(const Symbol * s)294   bool CheckEqualValue(const Symbol *s) const override { return value_ == static_cast<const FloatSymbol *>(s)->value_; }
295 
296   elem_type value_{0};
297 };
298 using FloatSymbolPtr = std::shared_ptr<FloatSymbol>;
299 
300 class MS_CORE_API StrSymbol final : public ScalarSymbol {
301  public:
302   using ScalarSymbol::ScalarSymbol;
303   ~StrSymbol() override = default;
MS_DECLARE_PARENT(StrSymbol,ScalarSymbol)304   MS_DECLARE_PARENT(StrSymbol, ScalarSymbol)
305   static inline std::shared_ptr<StrSymbol> Make(const std::string &val, const OpPtr &op = nullptr) {
306     auto s = std::make_shared<StrSymbol>(true, true, op);
307     s->value_ = val;
308     return s;
309   }
310   static inline std::shared_ptr<StrSymbol> Make(const OpPtr &op = nullptr) {
311     return std::make_shared<StrSymbol>(false, false, op);
312   }
SetValue(const std::string & v)313   inline void SetValue(const std::string &v) {
314     MS_EXCEPTION_IF_CHECK_FAIL(!is_const_, ToString() + " is const symbol and cannot be updated.");
315     has_data_ = true;
316     value_ = v;
317   }
value()318   inline const std::string &value() const {
319     MS_EXCEPTION_IF_CHECK_FAIL(has_data_, ToString() + "has no value.");
320     return value_;
321   }
322   std::string ToRawString() const override;
323   ValuePtr ToValue() const override;
324 
325  protected:
SetValueByScalar(const Symbol * s)326   void SetValueByScalar(const Symbol *s) override { value_ = static_cast<const StrSymbol *>(s)->value_; }
CheckEqualValue(const Symbol * s)327   bool CheckEqualValue(const Symbol *s) const override { return value_ == static_cast<const StrSymbol *>(s)->value_; }
328 
329   std::string value_;
330 };
331 using StrSymbolPtr = std::shared_ptr<StrSymbol>;
332 
333 class MS_CORE_API ListSymbol final : public Symbol {
334  public:
335   using SPtr = std::shared_ptr<ListSymbol>;
ListSymbol(const SymbolPtrList & slist,const OpPtr & op)336   ListSymbol(const SymbolPtrList &slist, const OpPtr &op) : Symbol(op), symbols_(slist) {}
ListSymbol(SymbolPtrList && slist,const OpPtr & op)337   ListSymbol(SymbolPtrList &&slist, const OpPtr &op) : Symbol(op), symbols_(slist) {}
ListSymbol(const std::initializer_list<SymbolPtr> & slist,const OpPtr & op)338   ListSymbol(const std::initializer_list<SymbolPtr> &slist, const OpPtr &op) : Symbol(op), symbols_(slist) {}
ListSymbol(const OpPtr & op)339   explicit ListSymbol(const OpPtr &op) : Symbol(op), is_dyn_len_(true), has_data_(false) {}
340   ~ListSymbol() override = default;
MS_DECLARE_PARENT(ListSymbol,Symbol)341   MS_DECLARE_PARENT(ListSymbol, Symbol)
342 
343   static inline SPtr Make(const SymbolPtrList &slist, const OpPtr &op = nullptr) {
344     return std::make_shared<ListSymbol>(slist, op);
345   }
346   static inline SPtr Make(SymbolPtrList &&slist, const OpPtr &op = nullptr) {
347     return std::make_shared<ListSymbol>(slist, op);
348   }
349   static inline SPtr Make(const std::initializer_list<SymbolPtr> &slist, const OpPtr &op = nullptr) {
350     return std::make_shared<ListSymbol>(slist, op);
351   }
352   static inline SPtr Make(const OpPtr &op = nullptr) { return std::make_shared<ListSymbol>(op); }
353 
354   bool operator==(const Symbol &s) const override;
355   std::string ToString() const override;
356   std::string ToRawString() const override;
357   ValuePtr ToValue() const override;
358   ValuePtr ToValueOf(const TypePtr &type) const override;
359 
HasData()360   bool HasData() const override { return has_data_; }
AllHaveData()361   bool AllHaveData() const {
362     return has_data_ && std::all_of(symbols_.cbegin(), symbols_.cend(), [](const SymbolPtr &s) {
363              return s->is<ListSymbol>() ? s->as_noexcept<ListSymbol>()->AllHaveData() : s->HasData();
364            });
365   }
CanUpdate()366   bool CanUpdate() const override {
367     return is_dyn_len_ || std::any_of(symbols_.cbegin(), symbols_.cend(), [](auto &s) { return s->CanUpdate(); });
368   }
369   void UpdateList(const SymbolPtrList &slist);
UpdateList(SymbolPtrList && slist)370   inline void UpdateList(SymbolPtrList &&slist) {
371     if (is_dyn_len_) {
372       has_data_ = true;
373       symbols_ = slist;
374     } else {
375       UpdateList(static_cast<const SymbolPtrList &>(slist));
376     }
377   }
378   const SymbolPtr &item(size_t i) const;
379   template <typename T>
item_as(size_t i)380   const T *item_as(size_t i) const {
381     auto ret = item(i)->as_noexcept<T>();
382     if (MS_UNLIKELY(ret == nullptr)) {
383       MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for item " << i << " of " << ToString();
384     }
385     return ret;
386   }
387   template <typename T>
item_as_sptr(size_t i)388   std::shared_ptr<T> item_as_sptr(size_t i) const {
389     auto ret = item(i)->as_sptr_noexcept<T>();
390     if (MS_UNLIKELY(ret == nullptr)) {
391       MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for item " << i << " of " << ToString();
392     }
393     return ret;
394   }
symbols()395   const SymbolPtrList &symbols() const { return symbols_; }
size()396   size_t size() const { return symbols_.size(); }
is_dyn_len()397   bool is_dyn_len() const { return is_dyn_len_; }
398 
399  protected:
400   void UpdateImpl(const SymbolPtr &s) override;
401   SymbolPtrList symbols_;
402   bool is_dyn_len_{false};
403   bool has_data_{true};
404 };
405 using ListSymbolPtr = std::shared_ptr<ListSymbol>;
406 }  // namespace symshape
407 
408 using symshape::BoolSymbol;
409 using symshape::BoolSymbolPtr;
410 using symshape::DynamicSymbol;
411 using symshape::DynamicSymbolPtr;
412 using symshape::FloatSymbol;
413 using symshape::FloatSymbolPtr;
414 using symshape::IntSymbol;
415 using symshape::IntSymbolPtr;
416 using symshape::ListSymbol;
417 using symshape::ListSymbolPtr;
418 using symshape::ScalarSymbol;
419 using symshape::ScalarSymbolPtr;
420 using symshape::Symbol;
421 using symshape::SymbolPtr;
422 using symshape::SymbolPtrList;
423 }  // namespace mindspore
424 #endif  // MINDSPORE_CORE_SYMBOLIC_SHAPE_SYMBOL_H_
425