1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ 18 19 #include <cstdint> 20 #include <ostream> 21 #include <sstream> 22 #include <string> 23 #include <vector> 24 25 #ifndef ENABLE_ANDROID 26 #include <opencv2/core/mat.hpp> 27 #endif 28 29 #ifdef ENABLE_PYTHON 30 #include "pybind11/pybind11.h" 31 namespace py = pybind11; 32 #endif 33 34 #include "minddata/dataset/include/dataset/constants.h" 35 #include "minddata/dataset/util/status.h" 36 #include "minddata/dataset/core/global_context.h" 37 #include "minddata/dataset/util/allocator.h" 38 39 namespace mindspore { 40 namespace dataset { 41 // Class that represents a shape of a Tensor. A shape can be: 42 // -# Known shape (mKnown = true) 43 // -# Scalar --> empty vector --> <> 44 // -# n-Dim --> not empty vector --> <d1, d2, d2, d3, ...> where di is >= 0\n 45 // Example: <1,2>, <1>, <1,13,10,11,1> 46 // -# Unknown shape (mKnown = false) 47 // -# Rank is unknown --> empty vector --> <> 48 // -# one or more dim is unknown --> not empty vector --> <d1, d2, d2, d3, ...> where di is unknown\n 49 // Example: <3,?> (the 1st dim is unknown)\n 50 // <2,?,?,?> (all dims but the 0th dim are unknown) 51 52 /// \brief TensorShape supports any dim > 0 and < 2^31-1 53 class TensorShape { 54 public: 55 static constexpr dsize_t kDimUnknown = -1; // constant for an unknown dimension 56 57 // Force the compiler to not create a no-arg constructor 58 TensorShape() = delete; 59 60 /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}). 61 /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown 62 /// \param[in] list 63 explicit TensorShape(const std::initializer_list<dsize_t> &list); 64 65 /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector<dsize_t>({2,2}) ). 66 /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown 67 /// \param[in] list 68 explicit TensorShape(const std::vector<dsize_t> &list); 69 70 /// \brief Copy constructor 71 /// \param[in] shape 72 TensorShape(const TensorShape &shape); 73 74 #ifdef ENABLE_PYTHON 75 /// \brief construct a TensorShape via a python list 76 /// \param[in] py::list l - a list object from python 77 explicit TensorShape(py::list l); 78 #endif 79 80 ~TensorShape() = default; 81 82 /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true) 83 /// \return TensorShape CreateScalar()84 static TensorShape CreateScalar() { return TensorShape({}); } 85 86 /// \brief Create a shape with an unknown rank. 87 /// \return TensorShape 88 static TensorShape CreateUnknownRankShape(); 89 90 /// \brief Create a shape with a known rank . 91 /// \return TensorShape 92 static TensorShape CreateUnknownShapeWithRank(dsize_t rank); 93 94 /// \brief Insert a new dim into a copy of the current shape. 95 /// \param[in] dim to be added 96 /// \param[in] axis the index where dim should be added 97 /// \return New modified shape 98 TensorShape InsertDim(dsize_t axis, dsize_t dim) const; 99 100 /// \brief Insert new dim at index 0. For example, <2,4> --> PrependDim(4) --> <4,2,4> 101 /// \param[in] dim 102 /// \return 103 TensorShape PrependDim(dsize_t dim) const; 104 105 /// \brief Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> 106 /// \param[in] dim 107 /// \return 108 TensorShape AppendDim(dsize_t dim) const; 109 110 #ifndef ENABLE_ANDROID 111 /// \brief Create a shape based on OpenCV shape and type 112 /// \param[in] cv_size 113 /// \param[in] type int that represent the type in OpenCV, example CV_8U, CV_64S 114 TensorShape(cv::MatSize cv_size, uint32_t type); 115 #endif 116 Size()117 dsize_t Size() const { return raw_shape_.size(); } 118 Rank()119 dsize_t Rank() const { return raw_shape_.size(); } 120 known()121 bool known() const { return known_; } 122 empty()123 bool empty() const { return raw_shape_.empty(); } 124 125 dsize_t NumOfElements() const; 126 127 bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; } 128 129 bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); } 130 131 dsize_t operator[](const dsize_t index) const { 132 if (index < 0) return raw_shape_[raw_shape_.size() + index]; 133 return raw_shape_[index]; 134 } 135 136 /// \brief Return the Shape as a vector 137 /// \return 138 std::vector<dsize_t> AsVector() const; 139 140 /// \brief Returns the class info as a string 141 /// \return ToString()142 std::string ToString() const { 143 std::stringstream ss; 144 ss << *this; 145 return ss.str(); 146 } 147 148 /// \brief Actual print function used by operator<< 149 /// \param out output string stream 150 void Print(std::ostream &out) const; 151 152 /// \brief << Stream output operator overload 153 /// This allows you to print the info using stream operators 154 /// \param[in] out - reference to the output stream being overloaded 155 /// \param[in] rO - reference to the TensorShape to display 156 /// \return - the output stream must be returned 157 friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) { 158 so.Print(out); 159 return out; 160 } 161 162 #ifdef ENABLE_PYTHON 163 py::list AsPyList(); 164 #endif 165 166 /// \brief Checks if the given index is a valid index for this tensor. 167 /// For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not. 168 /// \param[in] index 169 /// \return bool 170 bool IsValidIndex(const std::vector<dsize_t> &index) const; 171 172 TensorShape Squeeze() const; 173 174 std::vector<dsize_t> Strides() const; 175 176 /// \brief Returns the location of the item assuming row major memory layout. 177 /// \param[in] index 178 /// \param[out] flat_index 179 /// \return 180 Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const; 181 182 private: 183 // True if known and valid shape, false otherwise 184 bool known_; 185 // Vector to keep the dims of the shape. 186 std::vector<dsize_t, IntAlloc> raw_shape_; 187 // Vector to keep the strides of the shape. The size is rank+1 188 std::vector<dsize_t, IntAlloc> strides_; 189 190 /// \brief Internal utility function to iterate over a list, 191 /// check if the dim is valid and then insert it into the shape. 192 /// \param[in] list Iterable list 193 /// \return true if the shape is valid and no overflow would be generated when counting the number of elements. 194 /// False otherwise. 195 template <typename T> 196 void AddListToShape(const T &list); 197 }; 198 } // namespace dataset 199 } // namespace mindspore 200 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_ 201