1 /** 2 * Copyright 2020-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ 18 19 #include <cstdint> 20 #include <ostream> 21 #include <sstream> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 26 #ifndef ENABLE_ANDROID 27 #include <opencv2/core/mat.hpp> 28 #endif 29 30 #ifdef ENABLE_PYTHON 31 #include "pybind11/pybind11.h" 32 namespace py = pybind11; 33 #endif 34 35 #include "minddata/dataset/include/dataset/constants.h" 36 #include "minddata/dataset/util/status.h" 37 #include "minddata/dataset/core/global_context.h" 38 #include "minddata/dataset/util/allocator.h" 39 40 namespace mindspore { 41 namespace dataset { 42 // Class that represents a shape of a Tensor. A shape can be: 43 // -# Known shape (mKnown = true) 44 // -# Scalar --> empty vector --> <> 45 // -# n-Dim --> not empty vector --> <d1, d2, d2, d3, ...> where di is >= 0\n 46 // Example: <1,2>, <1>, <1,13,10,11,1> 47 // -# Unknown shape (mKnown = false) 48 // -# Rank is unknown --> empty vector --> <> 49 // -# one or more dim is unknown --> not empty vector --> <d1, d2, d2, d3, ...> where di is unknown\n 50 // Example: <3,?> (the 1st dim is unknown)\n 51 // <2,?,?,?> (all dims but the 0th dim are unknown) 52 53 /// \brief TensorShape supports any dim > 0 and < 2^31-1 54 class DATASET_API TensorShape { 55 public: 56 static constexpr dsize_t kDimUnknown = -1; // constant for an unknown dimension 57 58 // Force the compiler to not create a no-arg constructor 59 TensorShape() = delete; 60 61 /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}). 62 /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown 63 /// \param[in] list Length list of each axis. 64 TensorShape(const std::initializer_list<dsize_t> &list); 65 66 /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector<dsize_t>({2,2}) ). 67 /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown 68 /// \param[in] list 69 explicit TensorShape(const std::vector<dsize_t> &list); 70 71 /// \brief Copy constructor. 72 /// \param[in] shape TensorShape to copy from. 73 TensorShape(const TensorShape &shape); 74 75 /// \brief Move constructor. 76 /// \param[in] shape TensorShape to copy from. 77 TensorShape(TensorShape &&shape) noexcept; 78 79 /// \brief Copy assignment. 80 /// \param[in] shape TensorShape to move from. 81 TensorShape &operator=(const TensorShape &shape); 82 83 /// \brief Move assignment. 84 /// \param[in] shape TensorShape to move from. 85 TensorShape &operator=(TensorShape &&shape) noexcept; 86 87 #ifdef ENABLE_PYTHON 88 /// \brief Construct a TensorShape via a python list. 89 /// \param[in] l A py::list of the shape. 90 explicit TensorShape(py::list l); 91 #endif 92 93 ~TensorShape() = default; 94 95 /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true) 96 /// \return TensorShape CreateScalar()97 static TensorShape CreateScalar() { 98 static std::vector<dsize_t> empty_shape{}; 99 return TensorShape(empty_shape); 100 } 101 102 /// \brief Create a shape with an unknown rank. 103 /// \return TensorShape 104 static TensorShape CreateUnknownRankShape(); 105 106 /// \brief Create a shape with a known rank . 107 /// \return TensorShape 108 static TensorShape CreateUnknownShapeWithRank(dsize_t rank); 109 110 /// \brief Insert a new dim into a copy of the current shape. 111 /// \param[in] dim to be added 112 /// \param[in] axis the index where dim should be added 113 /// \return New modified shape 114 TensorShape InsertDim(dsize_t axis, dsize_t dim) const; 115 116 /// \brief Insert new dim at index 0. For example, <2,4> --> PrependDim(4) --> <4,2,4> 117 /// \param[in] dim 118 /// \return 119 TensorShape PrependDim(dsize_t dim) const; 120 121 /// \brief Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> 122 /// \param[in] dim 123 /// \return 124 TensorShape AppendDim(dsize_t dim) const; 125 126 #ifndef ENABLE_ANDROID 127 /// \brief Create a shape based on OpenCV shape and type 128 /// \param[in] cv_size 129 /// \param[in] type int that represent the type in OpenCV, example CV_8U, CV_64S 130 TensorShape(cv::MatSize cv_size, uint32_t type); 131 #endif 132 Size()133 dsize_t Size() const { return raw_shape_.size(); } 134 Rank()135 dsize_t Rank() const { return raw_shape_.size(); } 136 known()137 bool known() const { return known_; } 138 empty()139 bool empty() const { return raw_shape_.empty(); } 140 141 dsize_t NumOfElements() const; 142 143 bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; } 144 145 bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); } 146 147 dsize_t operator[](const dsize_t index) const { 148 if (index < 0) { 149 return raw_shape_[raw_shape_.size() + index]; 150 } 151 return raw_shape_[index]; 152 } 153 154 /// \brief Return the Shape as a vector 155 /// \return 156 std::vector<dsize_t> AsVector() const; 157 158 /// \brief Returns the class info as a string 159 /// \return ToString()160 std::string ToString() const { 161 std::stringstream ss; 162 ss << *this; 163 return ss.str(); 164 } 165 166 /// \brief Actual print function used by operator<< 167 /// \param out output string stream 168 void Print(std::ostream &out) const; 169 170 /// \brief << Stream output operator overload 171 /// This allows you to print the info using stream operators 172 /// \param[in] out - reference to the output stream being overloaded 173 /// \param[in] rO - reference to the TensorShape to display 174 /// \return - the output stream must be returned 175 friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) { 176 so.Print(out); 177 return out; 178 } 179 180 #ifdef ENABLE_PYTHON 181 py::list AsPyList(); 182 #endif 183 184 /// \brief Checks if the given index is a valid index for this tensor. 185 /// For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not. 186 /// \param[in] index 187 /// \return bool 188 bool IsValidIndex(const std::vector<dsize_t> &index) const; 189 190 TensorShape Squeeze() const; 191 192 std::vector<dsize_t> Strides() const; 193 194 /// \brief Returns the location of the item assuming row major memory layout. 195 /// \param[in] index 196 /// \param[out] flat_index 197 /// \return 198 Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const; 199 200 private: 201 // Vector to keep the dims of the shape. 202 std::vector<dsize_t> raw_shape_; 203 // Vector to keep the strides of the shape. The size is rank+1 204 std::vector<dsize_t> strides_; 205 // True if known and valid shape, false otherwise 206 bool known_; 207 208 /// \brief Internal utility function to iterate over a list, 209 /// check if the dim is valid and then insert it into the shape. 210 /// \param[in] list Iterable list 211 /// \return true if the shape is valid and no overflow would be generated when counting the number of elements. 212 /// False otherwise. 213 template <typename T> 214 void AddListToShape(const T &list); 215 }; 216 } // namespace dataset 217 } // namespace mindspore 218 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ 219