1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ 20 #define MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ 21 22 #include <vector> 23 #include <string> 24 #include <sstream> 25 #include <typeindex> 26 #include <memory> 27 #include <utility> 28 #include <algorithm> 29 30 #include "utils/hashing.h" 31 #include "utils/log_adapter.h" 32 #include "base/base.h" 33 #include "mindapi/base/shape_vector.h" 34 #include "mindspore/core/symbolic_shape/symbol.h" 35 36 namespace mindspore { 37 namespace abstract { 38 class BaseShape; 39 using BaseShapePtr = std::shared_ptr<BaseShape>; 40 using BaseShapePtrList = std::vector<BaseShapePtr>; 41 42 /// \brief BaseShape defines the basic virtual class of NoShape and Shape classes. 43 class MS_CORE_API BaseShape : public Base { 44 public: 45 /// \brief Constructor of BaseShape. 46 BaseShape() = default; 47 48 /// \brief Destructor of BaseShape. 49 ~BaseShape() override = default; 50 51 MS_DECLARE_PARENT(BaseShape, Base) 52 53 /// \brief Check whether 2 objects are equal. 54 /// 55 /// \param[in] other Another object. 56 /// \return True if current object is equal to another, otherwise false. 57 virtual bool operator==(const BaseShape &other) const; 58 59 /// \brief Check whether 2 objects are not equal. 60 /// 61 /// \param[in] other Another object. 62 /// \return True if current object is not equal to another, otherwise false. 63 bool operator!=(const BaseShape &other) const; 64 65 /// \brief Calculate the hash value of BaseShape. 66 /// 67 /// \return The hash value of BaseShape. hash()68 std::size_t hash() const override { return tid(); } 69 70 /// \brief Whether the object's dimensions are dynamic. 71 /// 72 /// \return True if the object's dimensions are dynamic, otherwise false. 73 virtual bool IsDynamic() const = 0; 74 75 /// \brief Whether the object's dimension is zero. 76 /// 77 /// \return True if the object's dimension is zero, otherwise false. 78 virtual bool IsDimZero() const = 0; 79 80 /// \brief Whether the object's dimensions are unknown. 81 /// 82 /// \return True if the object's dimensions are unknown, otherwise false. 83 virtual bool IsDimUnknown() const = 0; 84 85 /// \brief Clone a new object by this one. 86 /// 87 /// \return New cloned object. 88 virtual BaseShapePtr Clone() const = 0; 89 90 /// \brief Broaden the shape. Broaden()91 virtual void Broaden() {} 92 93 /// \brief Get shape dimensions of BaseShape object. 94 /// 95 /// \return Shape dimensions. GetShapeVector()96 virtual const ShapeVector &GetShapeVector() const { 97 MS_LOG(EXCEPTION) << "The method 'GetShapeVector()' doesn't implement"; 98 } 99 100 /// \brief Set shape dimensions of BaseShape object. 101 /// 102 /// \param[in] shape Dimensions of shape. SetShapeVector(const ShapeVector & shape)103 virtual void SetShapeVector(const ShapeVector &shape) { 104 MS_LOG(EXCEPTION) << "The method 'SetShapeVector()' doesn't implement"; 105 } 106 107 /// \brief Build symbolic shape according to the digital shape. 108 /// Constant symbols are generated for static dims, and variable symbols are generated for dynamic dims. 109 /// 110 /// \return Symbolic Shape. BuildSymbolicShape()111 virtual ListSymbolPtr BuildSymbolicShape() const { 112 MS_LOG(EXCEPTION) << "The method 'BuildSymbolicShape()' doesn't implement"; 113 } 114 }; 115 116 /// \brief NoShape defines an invalid shape. 117 class MS_CORE_API NoShape final : public BaseShape { 118 public: MS_DECLARE_PARENT(NoShape,BaseShape)119 MS_DECLARE_PARENT(NoShape, BaseShape) 120 121 BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); } 122 123 /// \brief Get the description string about the NoShape object. 124 /// 125 /// \return The description string about the NoShape object. ToString()126 std::string ToString() const override { return type_name(); } 127 IsDynamic()128 bool IsDynamic() const override { return false; } 129 IsDimZero()130 bool IsDimZero() const override { return true; }; 131 IsDimUnknown()132 bool IsDimUnknown() const override { return false; } 133 BuildSymbolicShape()134 ListSymbolPtr BuildSymbolicShape() const override { return ListSymbol::Make({}); } 135 }; 136 137 GVAR_DEF(std::shared_ptr<NoShape>, kNoShape, std::make_shared<NoShape>()); 138 139 /// \brief TensorShape defines dimensions of tensor. 140 class MS_CORE_API TensorShape final : public BaseShape { 141 public: 142 static constexpr ShapeValueDType kShapeDimAny = -1; 143 static constexpr ShapeValueDType kShapeRankAny = -2; 144 static constexpr ShapeValueDType kShapeError = -3; 145 static constexpr size_t kDynamicRankLen = 1; 146 147 /// \brief Constructor of TensorShape. TensorShape()148 TensorShape() : shape_() {} 149 150 /// \brief Constructor of TensorShape. 151 /// 152 /// \param[in] list Initial shape dimensions. TensorShape(const std::initializer_list<ShapeValueDType> & list)153 TensorShape(const std::initializer_list<ShapeValueDType> &list) : shape_(list) {} 154 155 /// \brief Constructor of TensorShape. 156 /// 157 /// \param[in] list Initial shape dimensions. TensorShape(const ShapeVector & list)158 explicit TensorShape(const ShapeVector &list) : shape_(list) {} 159 160 /// \brief Constructor of TensorShape with rvalue input. 161 /// 162 /// \param[in] list Initial shape dimensions. TensorShape(ShapeVector && list)163 explicit TensorShape(ShapeVector &&list) : shape_(std::move(list)) {} 164 165 /// \brief Constructor of Shape. 166 /// 167 /// \param[in] list Initial shape dimensions. 168 /// \param[in] max_shape Maximum shape dimensions of dynamic shape. TensorShape(const ShapeVector & list,const ShapeVector & max_shape)169 TensorShape(const ShapeVector &list, const ShapeVector &max_shape) : shape_(list), max_shape_(max_shape) {} 170 171 /// \brief Destructor of TensorShape. 172 ~TensorShape() override = default; MS_DECLARE_PARENT(TensorShape,BaseShape)173 MS_DECLARE_PARENT(TensorShape, BaseShape) 174 175 /// \brief Calculate the hash value for TensorShape. 176 /// 177 /// \return The hash value of TensorShape. 178 std::size_t hash() const override { 179 auto hash_code = static_cast<std::size_t>(tid()); 180 for (auto dim : shape_) { 181 hash_code = hash_combine(hash_code, static_cast<size_t>(dim)); 182 } 183 return hash_code; 184 } 185 186 /// \brief Get the description string about the TensorShape object. 187 /// 188 /// \return The description string about the TensorShape object. 189 std::string ToString() const override; 190 191 /// \brief Get the debug information about the TensorShape object. 192 /// 193 /// \return The debug information about the TensorShape object. 194 std::string DumpText() const override; 195 196 bool operator==(const BaseShape &other) const override; 197 Clone()198 BaseShapePtr Clone() const override { return std::make_shared<TensorShape>(shape_); } 199 200 void Broaden() override; 201 202 /// \brief Set shape dimensions of TensorShape object. 203 /// 204 /// \param[in] shape Dimensions of shape. set_shape(const ShapeVector & shape)205 void set_shape(const ShapeVector &shape) { shape_ = shape; } 206 207 /// \brief Get shape dimensions. 208 /// 209 /// \return TensorShape dimensions. shape()210 const ShapeVector &shape() const { return shape_; } 211 212 /// \brief Get maximum shape dimensions. 213 /// 214 /// \return Maximum shape dimensions. max_shape()215 const ShapeVector &max_shape() const { return max_shape_; } 216 217 /// \brief Get shape dimensions of a tensor shape. 218 /// 219 /// \return Shape dimensions. GetShapeVector()220 const ShapeVector &GetShapeVector() const override { return shape_; } 221 222 /// \brief Set shape dimensions of TensorShape object. 223 /// 224 /// \param[in] shape Dimensions of shape. SetShapeVector(const ShapeVector & shape)225 void SetShapeVector(const ShapeVector &shape) override { shape_ = shape; } 226 227 bool IsDynamic() const override; 228 IsDimZero()229 bool IsDimZero() const override { return shape_.empty(); }; 230 IsDimUnknown()231 bool IsDimUnknown() const override { 232 return std::any_of(shape_.begin(), shape_.end(), [](ShapeValueDType s) { return s < -1; }); 233 } 234 235 ListSymbolPtr BuildSymbolicShape() const override; 236 237 private: 238 ShapeVector shape_; // use kShapeDimAny to implement the any shape in python 239 ShapeVector max_shape_; // record maximum length for each dynamic dimension 240 }; 241 using TensorShapePtr = std::shared_ptr<TensorShape>; 242 using TensorShapePtrList = std::vector<TensorShapePtr>; 243 // `Shape` is deprecated, which will be removed in the future, please use `TensorShape` instead. 244 using Shape = TensorShape; 245 using ShapePtr = TensorShapePtr; 246 using ShapePtrList = TensorShapePtrList; 247 248 /// \brief DynamicSequenceShape defines shape of dynamic sequence. 249 class MS_CORE_API DynamicSequenceShape : public BaseShape { 250 public: 251 /// \brief Constructor of DynamicSequenceShape. 252 DynamicSequenceShape() = default; 253 254 /// \brief Constructor of DynamicSequenceShape. DynamicSequenceShape(const BaseShapePtr & element_shape)255 explicit DynamicSequenceShape(const BaseShapePtr &element_shape) : element_shape_(element_shape) {} 256 257 /// \brief Destructor of DynamicSequenceShape. 258 ~DynamicSequenceShape() override = default; 259 MS_DECLARE_PARENT(DynamicSequenceShape, BaseShape); 260 261 /// \brief Get the description string about the DynamicSequenceShape object. 262 /// 263 /// \return The description string about the DynamicSequenceShape object. ToString()264 std::string ToString() const override { return type_name(); } 265 266 /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape or dynamic rank. 267 /// 268 /// \return True if any element shape of DynamicSequenceShape is dynamic shape or dynamic rank, otherwise false. 269 bool IsDynamic() const override; 270 271 /// \brief Check whether all elements shape of DynamicSequenceShape are empty shape. 272 /// 273 /// \return True if all elements shape of DynamicSequenceShape are empty shape. 274 bool IsDimZero() const override; 275 276 /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape. 277 /// 278 /// \return True if any element shape of DynamicSequenceShape is dynamic shape. 279 bool IsDimUnknown() const override; 280 BuildSymbolicShape()281 ListSymbolPtr BuildSymbolicShape() const override { return ListSymbol::Make(); } 282 Clone()283 BaseShapePtr Clone() const override { 284 if (element_shape_ == nullptr) { 285 return std::make_shared<DynamicSequenceShape>(nullptr); 286 } 287 return std::make_shared<DynamicSequenceShape>(element_shape_->Clone()); 288 } 289 290 bool operator==(const BaseShape &other) const override; 291 292 /// \brief Calculate the hash value for DynamicSequenceShape. 293 /// 294 /// \return The hash value of Shape. 295 std::size_t hash() const override; 296 element_shape()297 BaseShapePtr element_shape() { return element_shape_; } 298 299 private: 300 // element's shape 301 BaseShapePtr element_shape_{nullptr}; 302 }; 303 using DynamicSequenceShapePtr = std::shared_ptr<DynamicSequenceShape>; 304 GVAR_DEF(std::shared_ptr<DynamicSequenceShape>, kDynamicSequenceShape, std::make_shared<DynamicSequenceShape>()); 305 306 /// \brief SequequeShape defines base class of multiple-shape classes. 307 class MS_CORE_API SequenceShape : public BaseShape { 308 public: 309 /// \brief Constructor of SequenceShape. SequenceShape()310 SequenceShape() : p_shapes_() {} 311 312 /// \brief Constructor of SequenceShape. 313 /// 314 /// \param[in] shapes All element-shapes. SequenceShape(const BaseShapePtrList & shapes)315 explicit SequenceShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} 316 317 /// \brief Constructor of SequenceShape with rvalue inputs. 318 /// 319 /// \param[in] shapes All element-shapes. SequenceShape(BaseShapePtrList && shapes)320 explicit SequenceShape(BaseShapePtrList &&shapes) : p_shapes_(std::move(shapes)) {} 321 322 /// \brief Destructor of SequenceShape. 323 ~SequenceShape() override = default; 324 MS_DECLARE_PARENT(SequenceShape, BaseShape) 325 326 /// \brief Get the description string about the SequenceShape object. 327 /// 328 /// \return The description string about the SequenceShape object. 329 std::string ToString() const override; 330 331 /// \brief Clone all element-shapes. 332 /// 333 /// \return New cloned element-shapes. 334 BaseShapePtrList ElementsClone() const; 335 336 /// \brief Check whether SequenceShape object is equal to a BaseShape object. 337 /// 338 /// \param[in] other Another SequenceShape object. 339 /// \return True if current SequenceShape object is equal to another BaseShape object, otherwise false. 340 template <typename T> SequenceEqual(const BaseShape & other)341 bool SequenceEqual(const BaseShape &other) const { 342 if (tid() != other.tid()) { 343 return false; 344 } 345 auto &other_shapes = static_cast<const T &>(other).p_shapes_; 346 if (other_shapes.size() != p_shapes_.size()) { 347 return false; 348 } 349 for (uint64_t i = 0; i < p_shapes_.size(); ++i) { 350 MS_EXCEPTION_IF_NULL(p_shapes_[i]); 351 MS_EXCEPTION_IF_NULL(other_shapes[i]); 352 if (!(*p_shapes_[i] == *other_shapes[i])) { 353 return false; 354 } 355 } 356 return true; 357 } 358 359 /// \brief Get all element-shapes. 360 /// 361 /// \return All element-shapes. shape()362 const BaseShapePtrList &shape() const { return p_shapes_; } 363 364 /// \brief Get the number of element-shapes. 365 /// 366 /// \return The number of element-shapes. size()367 size_t size() const { return p_shapes_.size(); } 368 369 /// \brief Get the element-shape by index through operator '[]'. 370 /// 371 /// \param[in] dim The index of element shape. 372 /// \return The element shape got by index. 373 const BaseShapePtr &operator[](std::size_t dim) const { return p_shapes_[dim]; } 374 375 /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape or dynamic rank. 376 /// 377 /// \return True if any element shape of DynamicSequenceShape is dynamic shape or dynamic rank, otherwise false. IsDynamic()378 bool IsDynamic() const override { 379 return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); }); 380 } 381 382 /// \brief Check whether all elements shape of DynamicSequenceShape are empty shape. 383 /// 384 /// \return True if all elements shape of DynamicSequenceShape are empty shape. IsDimZero()385 bool IsDimZero() const override { 386 return std::all_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimZero(); }); 387 }; 388 389 /// \brief Check whether any element shape of DynamicSequenceShape is dynamic shape. 390 /// 391 /// \return True if any element shape of DynamicSequenceShape is dynamic shape. IsDimUnknown()392 bool IsDimUnknown() const override { 393 return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDimUnknown(); }); 394 } 395 396 ListSymbolPtr BuildSymbolicShape() const override; 397 398 protected: 399 BaseShapePtrList p_shapes_; // shape list of each elements 400 }; 401 using SequenceShapePtr = std::shared_ptr<SequenceShape>; 402 403 /// \brief TupleShape defines shape used by tuple with tensor inside. 404 class MS_CORE_API TupleShape final : public SequenceShape { 405 public: 406 /// \brief Constructor of TupleShape. TupleShape()407 TupleShape() : SequenceShape() {} 408 409 /// \brief Constructor of TupleShape. 410 /// 411 /// \param[in] shapes Element-shapes of TupleShape. TupleShape(const BaseShapePtrList & shapes)412 explicit TupleShape(const BaseShapePtrList &shapes) : SequenceShape(shapes) {} 413 414 /// \brief Constructor of TupleShape with rvalue input. 415 /// 416 /// \param[in] shapes Element-shapes of TupleShape. TupleShape(BaseShapePtrList && shapes)417 explicit TupleShape(BaseShapePtrList &&shapes) : SequenceShape(std::move(shapes)) {} 418 419 /// \brief Destructor of TupleShape. 420 ~TupleShape() override = default; MS_DECLARE_PARENT(TupleShape,SequenceShape)421 MS_DECLARE_PARENT(TupleShape, SequenceShape) 422 423 std::string ToString() const override { return type_name() + "(" + SequenceShape::ToString() + ")"; } 424 Clone()425 BaseShapePtr Clone() const override { return std::make_shared<TupleShape>(ElementsClone()); } 426 427 bool operator==(const BaseShape &other) const override { return SequenceEqual<TupleShape>(other); } 428 }; 429 using TupleShapePtr = std::shared_ptr<TupleShape>; 430 431 /// \brief ListShape defines shape used by list with tensor inside. 432 class MS_CORE_API ListShape final : public SequenceShape { 433 public: 434 /// \brief Constructor of ListShape. ListShape()435 ListShape() : SequenceShape() {} 436 /// \brief Constructor of ListShape. 437 /// 438 /// \param[in] shapes Element-shapes of ListShape. ListShape(const BaseShapePtrList & shapes)439 explicit ListShape(const BaseShapePtrList &shapes) : SequenceShape(shapes) {} 440 441 /// \brief Constructor of ListShape with rvalue input. 442 /// 443 /// \param[in] shapes Element-shapes of ListShape. ListShape(BaseShapePtrList && shapes)444 explicit ListShape(BaseShapePtrList &&shapes) : SequenceShape(std::move(shapes)) {} 445 446 /// \brief Destructor of ListShape. 447 ~ListShape() override = default; MS_DECLARE_PARENT(ListShape,SequenceShape)448 MS_DECLARE_PARENT(ListShape, SequenceShape) 449 450 std::string ToString() const override { return type_name() + "[" + SequenceShape::ToString() + "]"; } 451 Clone()452 BaseShapePtr Clone() const override { return std::make_shared<ListShape>(SequenceShape::ElementsClone()); } 453 454 bool operator==(const BaseShape &other) const override { return SequenceEqual<ListShape>(other); } 455 }; 456 using ListShapePtr = std::shared_ptr<ListShape>; 457 } // namespace abstract 458 } // namespace mindspore 459 460 #endif // MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ 461