• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
20 #define MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
21 
22 #include <vector>
23 #include <string>
24 #include <sstream>
25 #include <unordered_map>
26 #include <typeindex>
27 #include <memory>
28 #include <algorithm>
29 
30 #include "utils/log_adapter.h"
31 #include "base/base.h"
32 #include "utils/shape_utils.h"
33 
34 namespace mindspore {
35 namespace abstract {
// BaseShape is forward-declared so the handle aliases below can be spelled before
// the class body is seen.
class BaseShape;
// Reference-counted handle to a shape object (any BaseShape subclass).
using BaseShapePtr = std::shared_ptr<BaseShape>;
// Ordered list of shape handles; sequence shapes store their element shapes in one of these.
using BaseShapePtrList = std::vector<std::shared_ptr<BaseShape>>;
39 
40 class MS_CORE_API BaseShape : public Base {
41  public:
42   BaseShape() = default;
43   ~BaseShape() override = default;
44 
45   MS_DECLARE_PARENT(BaseShape, Base)
46   virtual bool operator==(const BaseShape &other) const;
47   bool operator!=(const BaseShape &other) const;
hash()48   std::size_t hash() const override { return tid(); }
49   virtual bool IsDynamic() const = 0;
50 
51   // return a deep copy
52   virtual BaseShapePtr Clone() const = 0;
Broaden()53   virtual void Broaden() {}
54 };
55 
56 class MS_CORE_API NoShape : public BaseShape {
57  public:
MS_DECLARE_PARENT(NoShape,BaseShape)58   MS_DECLARE_PARENT(NoShape, BaseShape)
59   BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); }
ToString()60   std::string ToString() const override { return type_name(); }
IsDynamic()61   bool IsDynamic() const override { return false; }
62 };
63 
64 inline const std::shared_ptr<NoShape> kNoShape = std::make_shared<NoShape>();
65 
66 class MS_CORE_API Shape : public BaseShape {
67  public:
68   static const int64_t SHP_ANY = -1;
Shape()69   Shape() : shape_() {}
Shape(const std::initializer_list<int64_t> & list)70   Shape(const std::initializer_list<int64_t> &list) : shape_(list) {}
Shape(const ShapeVector & list)71   explicit Shape(const ShapeVector &list) : shape_(list) {}
Shape(const ShapeVector & list,const ShapeVector & min_shape,const ShapeVector & max_shape)72   Shape(const ShapeVector &list, const ShapeVector &min_shape, const ShapeVector &max_shape)
73       : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {}
74   ~Shape() override = default;
75   MS_DECLARE_PARENT(Shape, BaseShape)
76   std::string ToString() const override;
77   std::string DumpText() const override;
78   bool operator==(const BaseShape &other) const override;
Clone()79   BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); }
80   void Broaden() override;
set_shape(const ShapeVector & shape)81   void set_shape(const ShapeVector &shape) { shape_ = shape; }
shape()82   const ShapeVector &shape() { return shape_; }
min_shape()83   const ShapeVector &min_shape() { return min_shape_; }
max_shape()84   const ShapeVector &max_shape() { return max_shape_; }
IsDynamic()85   bool IsDynamic() const override {
86     return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; });
87   }
88 
89  private:
90   ShapeVector shape_;      // use SHP_ANY to implement the any shape in python
91   ShapeVector min_shape_;  // record minimum length for each dynamic dimension
92   ShapeVector max_shape_;  // record maximum length for each dynamic dimension
93 };
94 using ShapePtr = std::shared_ptr<Shape>;
95 using ShapePtrList = std::vector<ShapePtr>;
96 
97 class MS_CORE_API SequeueShape : public BaseShape {
98  public:
SequeueShape()99   SequeueShape() : p_shapes_() {}
SequeueShape(const BaseShapePtrList & shapes)100   explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {}
101   ~SequeueShape() override = default;
102   MS_DECLARE_PARENT(SequeueShape, BaseShape)
103 
104   std::string ToString() const override;
105   BaseShapePtrList ElementsClone() const;
106 
107   template <typename T>
108   bool SequeueEqual(const BaseShape &other) const;
109 
shape()110   const BaseShapePtrList &shape() const { return p_shapes_; }
size()111   size_t size() const { return p_shapes_.size(); }
112   const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; }
IsDynamic()113   bool IsDynamic() const override {
114     return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); });
115   }
116 
117  protected:
118   BaseShapePtrList p_shapes_;  // shape list of each elements
119 };
120 using SequeueShapePtr = std::shared_ptr<SequeueShape>;
121 
122 class MS_CORE_API TupleShape : public SequeueShape {
123  public:
TupleShape()124   TupleShape() : SequeueShape() {}
TupleShape(const BaseShapePtrList & shapes)125   explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {}
126   ~TupleShape() override = default;
MS_DECLARE_PARENT(TupleShape,SequeueShape)127   MS_DECLARE_PARENT(TupleShape, SequeueShape)
128 
129   std::string ToString() const override { return type_name() + "(" + SequeueShape::ToString() + ")"; }
130 
Clone()131   BaseShapePtr Clone() const override { return std::make_shared<TupleShape>(ElementsClone()); }
132 
133   bool operator==(const BaseShape &other) const override { return SequeueEqual<TupleShape>(other); }
134 };
135 using TupleShapePtr = std::shared_ptr<TupleShape>;
136 
137 class MS_CORE_API ListShape : public SequeueShape {
138  public:
ListShape()139   ListShape() : SequeueShape() {}
ListShape(const BaseShapePtrList & shapes)140   explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {}
141   ~ListShape() override = default;
MS_DECLARE_PARENT(ListShape,SequeueShape)142   MS_DECLARE_PARENT(ListShape, SequeueShape)
143 
144   std::string ToString() const override { return type_name() + "[" + SequeueShape::ToString() + "]"; }
145 
Clone()146   BaseShapePtr Clone() const override { return std::make_shared<ListShape>(SequeueShape::ElementsClone()); }
147 
148   bool operator==(const BaseShape &other) const override { return SequeueEqual<ListShape>(other); }
149 };
150 using ListShapePtr = std::shared_ptr<ListShape>;
151 }  // namespace abstract
152 }  // namespace mindspore
153 
154 #endif  // MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
155