• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
20 #define MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
21 
22 #include <vector>
23 #include <string>
24 #include <sstream>
25 #include <typeindex>
26 #include <memory>
27 #include <utility>
28 #include <algorithm>
29 
30 #include "utils/hashing.h"
31 #include "utils/log_adapter.h"
32 #include "base/base.h"
33 #include "mindapi/base/shape_vector.h"
34 #include "mindspore/core/symbolic_shape/symbol.h"
35 
36 namespace mindspore {
37 namespace abstract {
38 class BaseShape;
39 using BaseShapePtr = std::shared_ptr<BaseShape>;
40 using BaseShapePtrList = std::vector<BaseShapePtr>;
41 
42 /// \brief BaseShape defines the basic virtual class of NoShape and Shape classes.
43 class MS_CORE_API BaseShape : public Base {
44  public:
45   /// \brief Constructor of BaseShape.
46   BaseShape() = default;
47 
48   /// \brief Destructor of BaseShape.
49   ~BaseShape() override = default;
50 
51   MS_DECLARE_PARENT(BaseShape, Base)
52 
53   /// \brief Check whether 2 objects are equal.
54   ///
55   /// \param[in] other Another object.
56   /// \return True if current object is equal to another, otherwise false.
57   virtual bool operator==(const BaseShape &other) const;
58 
59   /// \brief Check whether 2 objects are not equal.
60   ///
61   /// \param[in] other Another object.
62   /// \return True if current object is not equal to another, otherwise false.
63   bool operator!=(const BaseShape &other) const;
64 
65   /// \brief Calculate the hash value of BaseShape.
66   ///
67   /// \return The hash value of BaseShape.
hash()68   std::size_t hash() const override { return tid(); }
69 
70   /// \brief Whether the object's dimensions are dynamic.
71   ///
72   /// \return True if the object's dimensions are dynamic, otherwise false.
73   virtual bool IsDynamic() const = 0;
74 
75   /// \brief Whether the object's dimension is zero.
76   ///
77   /// \return True if the object's dimension is zero, otherwise false.
78   virtual bool IsDimZero() const = 0;
79 
80   /// \brief Whether the object's dimensions are unknown.
81   ///
82   /// \return True if the object's dimensions are unknown, otherwise false.
83   virtual bool IsDimUnknown() const = 0;
84 
85   /// \brief Clone a new object by this one.
86   ///
87   /// \return New cloned object.
88   virtual BaseShapePtr Clone() const = 0;
89 
90   /// \brief Broaden the shape.
Broaden()91   virtual void Broaden() {}
92 
93   /// \brief Get shape dimensions of BaseShape object.
94   ///
95   /// \return Shape dimensions.
GetShapeVector()96   virtual const ShapeVector &GetShapeVector() const {
97     MS_LOG(EXCEPTION) << "The method 'GetShapeVector()' doesn't implement";
98   }
99 
100   /// \brief Set shape dimensions of BaseShape object.
101   ///
102   /// \param[in] shape Dimensions of shape.
SetShapeVector(const ShapeVector & shape)103   virtual void SetShapeVector(const ShapeVector &shape) {
104     MS_LOG(EXCEPTION) << "The method 'SetShapeVector()' doesn't implement";
105   }
106 
107   /// \brief Build symbolic shape according to the digital shape.
108   /// Constant symbols are generated for static dims, and variable symbols are generated for dynamic dims.
109   ///
110   /// \return Symbolic Shape.
BuildSymbolicShape()111   virtual ListSymbolPtr BuildSymbolicShape() const {
112     MS_LOG(EXCEPTION) << "The method 'BuildSymbolicShape()' doesn't implement";
113   }
114 };
115 
116 /// \brief NoShape defines an invalid shape.
117 class MS_CORE_API NoShape final : public BaseShape {
118  public:
MS_DECLARE_PARENT(NoShape,BaseShape)119   MS_DECLARE_PARENT(NoShape, BaseShape)
120 
121   BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); }
122 
123   /// \brief Get the description string about the NoShape object.
124   ///
125   /// \return The description string about the NoShape object.
ToString()126   std::string ToString() const override { return type_name(); }
127 
IsDynamic()128   bool IsDynamic() const override { return false; }
129 
IsDimZero()130   bool IsDimZero() const override { return true; };
131 
IsDimUnknown()132   bool IsDimUnknown() const override { return false; }
133 
BuildSymbolicShape()134   ListSymbolPtr BuildSymbolicShape() const override { return ListSymbol::Make({}); }
135 };
136 
137 GVAR_DEF(std::shared_ptr<NoShape>, kNoShape, std::make_shared<NoShape>());
138 
139 /// \brief TensorShape defines dimensions of tensor.
140 class MS_CORE_API TensorShape final : public BaseShape {
141  public:
142   static constexpr ShapeValueDType kShapeDimAny = -1;
143   static constexpr ShapeValueDType kShapeRankAny = -2;
144   static constexpr ShapeValueDType kShapeError = -3;
145   static constexpr size_t kDynamicRankLen = 1;
146 
147   /// \brief Constructor of TensorShape.
TensorShape()148   TensorShape() : shape_() {}
149 
150   /// \brief Constructor of TensorShape.
151   ///
152   /// \param[in] list Initial shape dimensions.
TensorShape(const std::initializer_list<ShapeValueDType> & list)153   TensorShape(const std::initializer_list<ShapeValueDType> &list) : shape_(list) {}
154 
155   /// \brief Constructor of TensorShape.
156   ///
157   /// \param[in] list Initial shape dimensions.
TensorShape(const ShapeVector & list)158   explicit TensorShape(const ShapeVector &list) : shape_(list) {}
159 
160   /// \brief Constructor of TensorShape with rvalue input.
161   ///
162   /// \param[in] list Initial shape dimensions.
TensorShape(ShapeVector && list)163   explicit TensorShape(ShapeVector &&list) : shape_(std::move(list)) {}
164 
165   /// \brief Constructor of Shape.
166   ///
167   /// \param[in] list Initial shape dimensions.
168   /// \param[in] max_shape Maximum shape dimensions of dynamic shape.
TensorShape(const ShapeVector & list,const ShapeVector & max_shape)169   TensorShape(const ShapeVector &list, const ShapeVector &max_shape) : shape_(list), max_shape_(max_shape) {}
170 
171   /// \brief Destructor of TensorShape.
172   ~TensorShape() override = default;
MS_DECLARE_PARENT(TensorShape,BaseShape)173   MS_DECLARE_PARENT(TensorShape, BaseShape)
174 
175   /// \brief Calculate the hash value for TensorShape.
176   ///
177   /// \return The hash value of TensorShape.
178   std::size_t hash() const override {
179     auto hash_code = static_cast<std::size_t>(tid());
180     for (auto dim : shape_) {
181       hash_code = hash_combine(hash_code, static_cast<size_t>(dim));
182     }
183     return hash_code;
184   }
185 
186   /// \brief Get the description string about the TensorShape object.
187   ///
188   /// \return The description string about the TensorShape object.
189   std::string ToString() const override;
190 
191   /// \brief Get the debug information about the TensorShape object.
192   ///
193   /// \return The debug information about the TensorShape object.
194   std::string DumpText() const override;
195 
196   bool operator==(const BaseShape &other) const override;
197 
Clone()198   BaseShapePtr Clone() const override { return std::make_shared<TensorShape>(shape_); }
199 
200   void Broaden() override;
201 
202   /// \brief Set shape dimensions of TensorShape object.
203   ///
204   /// \param[in] shape Dimensions of shape.
set_shape(const ShapeVector & shape)205   void set_shape(const ShapeVector &shape) { shape_ = shape; }
206 
207   /// \brief Get shape dimensions.
208   ///
209   /// \return TensorShape dimensions.
shape()210   const ShapeVector &shape() const { return shape_; }
211 
212   /// \brief Get maximum shape dimensions.
213   ///
214   /// \return Maximum shape dimensions.
max_shape()215   const ShapeVector &max_shape() const { return max_shape_; }
216 
217   /// \brief Get shape dimensions of a tensor shape.
218   ///
219   /// \return Shape dimensions.
GetShapeVector()220   const ShapeVector &GetShapeVector() const override { return shape_; }
221 
222   /// \brief Set shape dimensions of TensorShape object.
223   ///
224   /// \param[in] shape Dimensions of shape.
SetShapeVector(const ShapeVector & shape)225   void SetShapeVector(const ShapeVector &shape) override { shape_ = shape; }
226 
227   bool IsDynamic() const override;
228 
IsDimZero()229   bool IsDimZero() const override { return shape_.empty(); };
230 
IsDimUnknown()231   bool IsDimUnknown() const override {
232     return std::any_of(shape_.begin(), shape_.end(), [](ShapeValueDType s) { return s < -1; });
233   }
234 
235   ListSymbolPtr BuildSymbolicShape() const override;
236 
237  private:
238   ShapeVector shape_;      // use kShapeDimAny to implement the any shape in python
239   ShapeVector max_shape_;  // record maximum length for each dynamic dimension
240 };
241 using TensorShapePtr = std::shared_ptr<TensorShape>;
242 using TensorShapePtrList = std::vector<TensorShapePtr>;
243 // `Shape` is deprecated, which will be removed in the future, please use `TensorShape` instead.
244 using Shape = TensorShape;
245 using ShapePtr = TensorShapePtr;
246 using ShapePtrList = TensorShapePtrList;
247 
248 /// \brief DynamicSequenceShape defines shape of dynamic sequence.
249 class MS_CORE_API DynamicSequenceShape : public BaseShape {
250  public:
251   /// \brief Constructor of DynamicSequenceShape.
252   DynamicSequenceShape() = default;
253 
254   /// \brief Constructor of DynamicSequenceShape.
DynamicSequenceShape(const BaseShapePtr & element_shape)255   explicit DynamicSequenceShape(const BaseShapePtr &element_shape) : element_shape_(element_shape) {}
256 
257   /// \brief Destructor of DynamicSequenceShape.
258   ~DynamicSequenceShape() override = default;
259   MS_DECLARE_PARENT(DynamicSequenceShape, BaseShape);
260 
261   /// \brief Get the description string about the DynamicSequenceShape object.
262   ///
263   /// \return The description string about the DynamicSequenceShape object.
ToString()264   std::string ToString() const override { return type_name(); }
265 
266   /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape or dynamic rank.
267   ///
268   /// \return True if any element shape of DynamicSequenceShape is dynamic shape or dynamic rank, otherwise false.
269   bool IsDynamic() const override;
270 
271   /// \brief Check whether all elements shape of DynamicSequenceShape are empty shape.
272   ///
273   /// \return True if all elements shape of DynamicSequenceShape are empty shape.
274   bool IsDimZero() const override;
275 
276   /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape.
277   ///
278   /// \return True if any element shape of DynamicSequenceShape is dynamic shape.
279   bool IsDimUnknown() const override;
280 
BuildSymbolicShape()281   ListSymbolPtr BuildSymbolicShape() const override { return ListSymbol::Make(); }
282 
Clone()283   BaseShapePtr Clone() const override {
284     if (element_shape_ == nullptr) {
285       return std::make_shared<DynamicSequenceShape>(nullptr);
286     }
287     return std::make_shared<DynamicSequenceShape>(element_shape_->Clone());
288   }
289 
290   bool operator==(const BaseShape &other) const override;
291 
292   /// \brief Calculate the hash value for DynamicSequenceShape.
293   ///
294   /// \return The hash value of Shape.
295   std::size_t hash() const override;
296 
element_shape()297   BaseShapePtr element_shape() { return element_shape_; }
298 
299  private:
300   // element's shape
301   BaseShapePtr element_shape_{nullptr};
302 };
303 using DynamicSequenceShapePtr = std::shared_ptr<DynamicSequenceShape>;
304 GVAR_DEF(std::shared_ptr<DynamicSequenceShape>, kDynamicSequenceShape, std::make_shared<DynamicSequenceShape>());
305 
306 /// \brief SequequeShape defines base class of multiple-shape classes.
307 class MS_CORE_API SequenceShape : public BaseShape {
308  public:
309   /// \brief Constructor of SequenceShape.
SequenceShape()310   SequenceShape() : p_shapes_() {}
311 
312   /// \brief Constructor of SequenceShape.
313   ///
314   /// \param[in]  shapes All element-shapes.
SequenceShape(const BaseShapePtrList & shapes)315   explicit SequenceShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {}
316 
317   /// \brief Constructor of SequenceShape with rvalue inputs.
318   ///
319   /// \param[in] shapes All element-shapes.
SequenceShape(BaseShapePtrList && shapes)320   explicit SequenceShape(BaseShapePtrList &&shapes) : p_shapes_(std::move(shapes)) {}
321 
322   /// \brief Destructor of SequenceShape.
323   ~SequenceShape() override = default;
324   MS_DECLARE_PARENT(SequenceShape, BaseShape)
325 
326   /// \brief Get the description string about the SequenceShape object.
327   ///
328   /// \return The description string about the SequenceShape object.
329   std::string ToString() const override;
330 
331   /// \brief Clone all element-shapes.
332   ///
333   /// \return New cloned element-shapes.
334   BaseShapePtrList ElementsClone() const;
335 
336   /// \brief Check whether SequenceShape object is equal to a BaseShape object.
337   ///
338   /// \param[in] other Another SequenceShape object.
339   /// \return True if current SequenceShape object is equal to another BaseShape object, otherwise false.
340   template <typename T>
SequenceEqual(const BaseShape & other)341   bool SequenceEqual(const BaseShape &other) const {
342     if (tid() != other.tid()) {
343       return false;
344     }
345     auto &other_shapes = static_cast<const T &>(other).p_shapes_;
346     if (other_shapes.size() != p_shapes_.size()) {
347       return false;
348     }
349     for (uint64_t i = 0; i < p_shapes_.size(); ++i) {
350       MS_EXCEPTION_IF_NULL(p_shapes_[i]);
351       MS_EXCEPTION_IF_NULL(other_shapes[i]);
352       if (!(*p_shapes_[i] == *other_shapes[i])) {
353         return false;
354       }
355     }
356     return true;
357   }
358 
359   /// \brief Get all element-shapes.
360   ///
361   /// \return  All element-shapes.
shape()362   const BaseShapePtrList &shape() const { return p_shapes_; }
363 
364   /// \brief Get the number of element-shapes.
365   ///
366   /// \return The number of element-shapes.
size()367   size_t size() const { return p_shapes_.size(); }
368 
369   /// \brief Get the element-shape by index through operator '[]'.
370   ///
371   /// \param[in] dim The index of element shape.
372   /// \return The element shape got by index.
373   const BaseShapePtr &operator[](std::size_t dim) const { return p_shapes_[dim]; }
374 
375   /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape or dynamic rank.
376   ///
377   /// \return True if any element shape of DynamicSequenceShape is dynamic shape or dynamic rank, otherwise false.
IsDynamic()378   bool IsDynamic() const override {
379     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
380   }
381 
382   /// \brief Check whether all elements shape of DynamicSequenceShape are empty shape.
383   ///
384   /// \return True if all elements shape of DynamicSequenceShape are empty shape.
IsDimZero()385   bool IsDimZero() const override {
386     return std::all_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimZero(); });
387   };
388 
389   /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape.
390   ///
391   /// \return True if any element shape of DynamicSequenceShape is dynamic shape.
IsDimUnknown()392   bool IsDimUnknown() const override {
393     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); });
394   }
395 
396   ListSymbolPtr BuildSymbolicShape() const override;
397 
398  protected:
399   BaseShapePtrList p_shapes_;  // shape list of each elements
400 };
401 using SequenceShapePtr = std::shared_ptr<SequenceShape>;
402 
403 /// \brief TupleShape defines shape used by tuple with tensor inside.
404 class MS_CORE_API TupleShape final : public SequenceShape {
405  public:
406   /// \brief Constructor of TupleShape.
TupleShape()407   TupleShape() : SequenceShape() {}
408 
409   /// \brief Constructor of TupleShape.
410   ///
411   /// \param[in] shapes Element-shapes of TupleShape.
TupleShape(const BaseShapePtrList & shapes)412   explicit TupleShape(const BaseShapePtrList &shapes) : SequenceShape(shapes) {}
413 
414   /// \brief Constructor of TupleShape with rvalue input.
415   ///
416   /// \param[in] shapes Element-shapes of TupleShape.
TupleShape(BaseShapePtrList && shapes)417   explicit TupleShape(BaseShapePtrList &&shapes) : SequenceShape(std::move(shapes)) {}
418 
419   /// \brief Destructor of TupleShape.
420   ~TupleShape() override = default;
MS_DECLARE_PARENT(TupleShape,SequenceShape)421   MS_DECLARE_PARENT(TupleShape, SequenceShape)
422 
423   std::string ToString() const override { return type_name() + "(" + SequenceShape::ToString() + ")"; }
424 
Clone()425   BaseShapePtr Clone() const override { return std::make_shared<TupleShape>(ElementsClone()); }
426 
427   bool operator==(const BaseShape &other) const override { return SequenceEqual<TupleShape>(other); }
428 };
429 using TupleShapePtr = std::shared_ptr<TupleShape>;
430 
431 /// \brief ListShape defines shape used by list with tensor inside.
432 class MS_CORE_API ListShape final : public SequenceShape {
433  public:
434   /// \brief Constructor of ListShape.
ListShape()435   ListShape() : SequenceShape() {}
436   /// \brief Constructor of ListShape.
437   ///
438   /// \param[in] shapes Element-shapes of ListShape.
ListShape(const BaseShapePtrList & shapes)439   explicit ListShape(const BaseShapePtrList &shapes) : SequenceShape(shapes) {}
440 
441   /// \brief Constructor of ListShape with rvalue input.
442   ///
443   /// \param[in] shapes Element-shapes of ListShape.
ListShape(BaseShapePtrList && shapes)444   explicit ListShape(BaseShapePtrList &&shapes) : SequenceShape(std::move(shapes)) {}
445 
446   /// \brief Destructor of ListShape.
447   ~ListShape() override = default;
MS_DECLARE_PARENT(ListShape,SequenceShape)448   MS_DECLARE_PARENT(ListShape, SequenceShape)
449 
450   std::string ToString() const override { return type_name() + "[" + SequenceShape::ToString() + "]"; }
451 
Clone()452   BaseShapePtr Clone() const override { return std::make_shared<ListShape>(SequenceShape::ElementsClone()); }
453 
454   bool operator==(const BaseShape &other) const override { return SequenceEqual<ListShape>(other); }
455 };
456 using ListShapePtr = std::shared_ptr<ListShape>;
457 }  // namespace abstract
458 }  // namespace mindspore
459 
460 #endif  // MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
461