1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ 20 #define MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ 21 22 #include <vector> 23 #include <string> 24 #include <sstream> 25 #include <unordered_map> 26 #include <typeindex> 27 #include <memory> 28 #include <algorithm> 29 30 #include "utils/log_adapter.h" 31 #include "base/base.h" 32 #include "utils/shape_utils.h" 33 34 namespace mindspore { 35 namespace abstract { 36 class BaseShape; 37 using BaseShapePtr = std::shared_ptr<BaseShape>; 38 using BaseShapePtrList = std::vector<BaseShapePtr>; 39 40 class MS_CORE_API BaseShape : public Base { 41 public: 42 BaseShape() = default; 43 ~BaseShape() override = default; 44 45 MS_DECLARE_PARENT(BaseShape, Base) 46 virtual bool operator==(const BaseShape &other) const; 47 bool operator!=(const BaseShape &other) const; hash()48 std::size_t hash() const override { return tid(); } 49 virtual bool IsDynamic() const = 0; 50 51 // return a deep copy 52 virtual BaseShapePtr Clone() const = 0; Broaden()53 virtual void Broaden() {} 54 }; 55 56 class MS_CORE_API NoShape : public BaseShape { 57 public: MS_DECLARE_PARENT(NoShape,BaseShape)58 MS_DECLARE_PARENT(NoShape, BaseShape) 59 BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); } ToString()60 std::string ToString() const override { return type_name(); } IsDynamic()61 bool IsDynamic() const override { return false; } 62 }; 63 64 inline const std::shared_ptr<NoShape> kNoShape = std::make_shared<NoShape>(); 65 66 class MS_CORE_API Shape : public BaseShape { 67 public: 68 static const int64_t SHP_ANY = -1; Shape()69 Shape() : shape_() {} Shape(const std::initializer_list<int64_t> & list)70 Shape(const std::initializer_list<int64_t> &list) : shape_(list) {} Shape(const ShapeVector & list)71 explicit Shape(const ShapeVector &list) : shape_(list) {} Shape(const ShapeVector & list,const ShapeVector & min_shape,const ShapeVector & max_shape)72 Shape(const ShapeVector &list, const ShapeVector &min_shape, const ShapeVector &max_shape) 73 : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {} 74 ~Shape() override = default; 75 MS_DECLARE_PARENT(Shape, BaseShape) 76 std::string ToString() const override; 77 std::string DumpText() const override; 78 bool operator==(const BaseShape &other) const override; Clone()79 BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); } 80 void Broaden() override; set_shape(const ShapeVector & shape)81 void set_shape(const ShapeVector &shape) { shape_ = shape; } shape()82 const ShapeVector &shape() { return shape_; } min_shape()83 const ShapeVector &min_shape() { return min_shape_; } max_shape()84 const ShapeVector &max_shape() { return max_shape_; } IsDynamic()85 bool IsDynamic() const override { 86 return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; }); 87 } 88 89 private: 90 ShapeVector shape_; // use SHP_ANY to implement the any shape in python 91 ShapeVector min_shape_; // record minimum length for each dynamic dimension 92 ShapeVector max_shape_; // record maximum length for each dynamic dimension 93 }; 94 using ShapePtr = std::shared_ptr<Shape>; 95 using ShapePtrList = std::vector<ShapePtr>; 96 97 class MS_CORE_API SequeueShape : public BaseShape { 98 public: SequeueShape()99 SequeueShape() : p_shapes_() {} SequeueShape(const BaseShapePtrList & shapes)100 explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} 101 ~SequeueShape() override = default; 102 MS_DECLARE_PARENT(SequeueShape, BaseShape) 103 104 std::string ToString() const override; 105 BaseShapePtrList ElementsClone() const; 106 107 template <typename T> 108 bool SequeueEqual(const BaseShape &other) const; 109 shape()110 const BaseShapePtrList &shape() const { return p_shapes_; } size()111 size_t size() const { return p_shapes_.size(); } 112 const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } IsDynamic()113 bool IsDynamic() const override { 114 return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); }); 115 } 116 117 protected: 118 BaseShapePtrList p_shapes_; // shape list of each elements 119 }; 120 using SequeueShapePtr = std::shared_ptr<SequeueShape>; 121 122 class MS_CORE_API TupleShape : public SequeueShape { 123 public: TupleShape()124 TupleShape() : SequeueShape() {} TupleShape(const BaseShapePtrList & shapes)125 explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} 126 ~TupleShape() override = default; MS_DECLARE_PARENT(TupleShape,SequeueShape)127 MS_DECLARE_PARENT(TupleShape, SequeueShape) 128 129 std::string ToString() const override { return type_name() + "(" + SequeueShape::ToString() + ")"; } 130 Clone()131 BaseShapePtr Clone() const override { return std::make_shared<TupleShape>(ElementsClone()); } 132 133 bool operator==(const BaseShape &other) const override { return SequeueEqual<TupleShape>(other); } 134 }; 135 using TupleShapePtr = std::shared_ptr<TupleShape>; 136 137 class MS_CORE_API ListShape : public SequeueShape { 138 public: ListShape()139 ListShape() : SequeueShape() {} ListShape(const BaseShapePtrList & shapes)140 explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} 141 ~ListShape() override = default; MS_DECLARE_PARENT(ListShape,SequeueShape)142 MS_DECLARE_PARENT(ListShape, SequeueShape) 143 144 std::string ToString() const override { return type_name() + "[" + SequeueShape::ToString() + "]"; } 145 Clone()146 BaseShapePtr Clone() const override { return std::make_shared<ListShape>(SequeueShape::ElementsClone()); } 147 148 bool operator==(const BaseShape &other) const override { return SequeueEqual<ListShape>(other); } 149 }; 150 using ListShapePtr = std::shared_ptr<ListShape>; 151 } // namespace abstract 152 } // namespace mindspore 153 154 #endif // MINDSPORE_CORE_ABSTRACT_DSHAPE_H_ 155