• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_
18 
19 #include <cstdint>
20 #include <ostream>
21 #include <sstream>
22 #include <string>
23 #include <vector>
24 
25 #ifndef ENABLE_ANDROID
26 #include <opencv2/core/mat.hpp>
27 #endif
28 
29 #ifdef ENABLE_PYTHON
30 #include "pybind11/pybind11.h"
31 namespace py = pybind11;
32 #endif
33 
34 #include "minddata/dataset/include/dataset/constants.h"
35 #include "minddata/dataset/util/status.h"
36 #include "minddata/dataset/core/global_context.h"
37 #include "minddata/dataset/util/allocator.h"
38 
39 namespace mindspore {
40 namespace dataset {
41 // Class that represents a shape of a Tensor. A shape can be:
42 // -# Known shape (known_ = true)
43 //        -# Scalar --> empty vector        --> <>
44 //        -# n-Dim  --> not empty vector    --> <d1, d2, d3, ...> where each di is >= 0
45 //           Example: <1,2>,  <1>,   <1,13,10,11,1>
46 // -# Unknown shape (known_ = false)
47 //        -# Rank is unknown            --> empty vector     --> <>
48 //        -# one or more dim is unknown --> not empty vector --> <d1, d2, d3, ...> where at least one di is unknown
49 //           Example: <3,?> (the dim at index 1 is unknown)
50 //              <2,?,?,?> (all dims but the 0th dim are unknown)
51 
52 /// \brief  TensorShape supports any dim > 0 and < 2^31-1
53 class TensorShape {
54  public:
55   static constexpr dsize_t kDimUnknown = -1;  // constant for an unknown dimension
56 
57   // Force the compiler to not create a no-arg constructor
58   TensorShape() = delete;
59 
60   /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}).
61   ///     If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown
62   /// \param[in] list
63   explicit TensorShape(const std::initializer_list<dsize_t> &list);
64 
65   /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector<dsize_t>({2,2}) ).
66   ///     If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown
67   /// \param[in] list
68   explicit TensorShape(const std::vector<dsize_t> &list);
69 
70   /// \brief Copy constructor
71   /// \param[in] shape
72   TensorShape(const TensorShape &shape);
73 
74 #ifdef ENABLE_PYTHON
75   /// \brief construct a TensorShape via a python list
76   /// \param[in] py::list l - a list object from python
77   explicit TensorShape(py::list l);
78 #endif
79 
80   ~TensorShape() = default;
81 
82   /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true)
83   /// \return TensorShape
CreateScalar()84   static TensorShape CreateScalar() { return TensorShape({}); }
85 
86   /// \brief Create a shape with an unknown rank.
87   /// \return TensorShape
88   static TensorShape CreateUnknownRankShape();
89 
90   /// \brief Create a shape with a known rank .
91   /// \return TensorShape
92   static TensorShape CreateUnknownShapeWithRank(dsize_t rank);
93 
94   /// \brief Insert a new dim into a copy of the current shape.
95   /// \param[in] dim to be added
96   /// \param[in] axis the index where dim should be added
97   /// \return New modified shape
98   TensorShape InsertDim(dsize_t axis, dsize_t dim) const;
99 
100   /// \brief Insert new dim at index 0. For example,  <2,4> --> PrependDim(4) --> <4,2,4>
101   /// \param[in] dim
102   /// \return
103   TensorShape PrependDim(dsize_t dim) const;
104 
105   /// \brief Insert a new dim at the end of the shape. For example,  <2,4> --> AppendDim(4) --> <2,4,4>
106   /// \param[in] dim
107   /// \return
108   TensorShape AppendDim(dsize_t dim) const;
109 
110 #ifndef ENABLE_ANDROID
111   /// \brief Create a shape based on OpenCV shape and type
112   /// \param[in] cv_size
113   /// \param[in] type int that represent the type in OpenCV, example CV_8U, CV_64S
114   TensorShape(cv::MatSize cv_size, uint32_t type);
115 #endif
116 
Size()117   dsize_t Size() const { return raw_shape_.size(); }
118 
Rank()119   dsize_t Rank() const { return raw_shape_.size(); }
120 
known()121   bool known() const { return known_; }
122 
empty()123   bool empty() const { return raw_shape_.empty(); }
124 
125   dsize_t NumOfElements() const;
126 
127   bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; }
128 
129   bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); }
130 
131   dsize_t operator[](const dsize_t index) const {
132     if (index < 0) return raw_shape_[raw_shape_.size() + index];
133     return raw_shape_[index];
134   }
135 
136   /// \brief Return the Shape as a vector
137   /// \return
138   std::vector<dsize_t> AsVector() const;
139 
140   /// \brief Returns the class info as a string
141   /// \return
ToString()142   std::string ToString() const {
143     std::stringstream ss;
144     ss << *this;
145     return ss.str();
146   }
147 
148   /// \brief Actual print function used by operator<<
149   /// \param out output string stream
150   void Print(std::ostream &out) const;
151 
152   /// \brief << Stream output operator overload
153   ///     This allows you to print the info using stream operators
154   /// \param[in] out - reference to the output stream being overloaded
155   /// \param[in] rO - reference to the TensorShape to display
156   /// \return - the output stream must be returned
157   friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) {
158     so.Print(out);
159     return out;
160   }
161 
162 #ifdef ENABLE_PYTHON
163   py::list AsPyList();
164 #endif
165 
166   /// \brief Checks if the given index is a valid index for this tensor.
167   ///     For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not.
168   /// \param[in] index
169   /// \return bool
170   bool IsValidIndex(const std::vector<dsize_t> &index) const;
171 
172   TensorShape Squeeze() const;
173 
174   std::vector<dsize_t> Strides() const;
175 
176   /// \brief Returns the location of the item assuming row major memory layout.
177   /// \param[in] index
178   /// \param[out] flat_index
179   /// \return
180   Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const;
181 
182  private:
183   // True if known and valid shape, false otherwise
184   bool known_;
185   // Vector to keep the dims of the shape.
186   std::vector<dsize_t, IntAlloc> raw_shape_;
187   // Vector to keep the strides of the shape. The size is rank+1
188   std::vector<dsize_t, IntAlloc> strides_;
189 
190   /// \brief Internal utility function to iterate over a list,
191   ///     check if the dim is valid and then insert it into the shape.
192   /// \param[in] list Iterable list
193   /// \return true if the shape is valid and no overflow would be generated when counting the number of elements.
194   ///     False otherwise.
195   template <typename T>
196   void AddListToShape(const T &list);
197 };
198 }  // namespace dataset
199 }  // namespace mindspore
200 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_
201