• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_
18 
19 #include <cstdint>
20 #include <ostream>
21 #include <sstream>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #ifndef ENABLE_ANDROID
27 #include <opencv2/core/mat.hpp>
28 #endif
29 
30 #ifdef ENABLE_PYTHON
31 #include "pybind11/pybind11.h"
32 namespace py = pybind11;
33 #endif
34 
35 #include "minddata/dataset/include/dataset/constants.h"
36 #include "minddata/dataset/util/status.h"
37 #include "minddata/dataset/core/global_context.h"
38 #include "minddata/dataset/util/allocator.h"
39 
40 namespace mindspore {
41 namespace dataset {
// Class that represents a shape of a Tensor. A shape can be:
// -# Known shape (known_ = true)
//        -# Scalar --> empty vector        --> <>
//        -# n-Dim  --> not empty vector    --> <d1, d2, d3, ...> where di is >= 0
//           Example: <1,2>,  <1>,   <1,13,10,11,1>
// -# Unknown shape (known_ = false)
//        -# Rank is unknown            --> empty vector     --> <>
//        -# one or more dim is unknown --> not empty vector --> <d1, d2, d3, ...> where some di is unknown
//           Example: <3,?> (the 1st dim is unknown)
//              <2,?,?,?> (all dims but the 0th dim are unknown)
52 
53 /// \brief  TensorShape supports any dim > 0 and < 2^31-1
54 class DATASET_API TensorShape {
55  public:
56   static constexpr dsize_t kDimUnknown = -1;  // constant for an unknown dimension
57 
58   // Force the compiler to not create a no-arg constructor
59   TensorShape() = delete;
60 
61   /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}).
62   ///     If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown
63   /// \param[in] list Length list of each axis.
64   TensorShape(const std::initializer_list<dsize_t> &list);
65 
66   /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector<dsize_t>({2,2}) ).
67   ///     If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown
68   /// \param[in] list
69   explicit TensorShape(const std::vector<dsize_t> &list);
70 
71   /// \brief Copy constructor.
72   /// \param[in] shape TensorShape to copy from.
73   TensorShape(const TensorShape &shape);
74 
75   /// \brief Move constructor.
76   /// \param[in] shape TensorShape to copy from.
77   TensorShape(TensorShape &&shape) noexcept;
78 
79   /// \brief Copy assignment.
80   /// \param[in] shape TensorShape to move from.
81   TensorShape &operator=(const TensorShape &shape);
82 
83   /// \brief Move assignment.
84   /// \param[in] shape TensorShape to move from.
85   TensorShape &operator=(TensorShape &&shape) noexcept;
86 
87 #ifdef ENABLE_PYTHON
88   /// \brief Construct a TensorShape via a python list.
89   /// \param[in] l A py::list of the shape.
90   explicit TensorShape(py::list l);
91 #endif
92 
93   ~TensorShape() = default;
94 
95   /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true)
96   /// \return TensorShape
CreateScalar()97   static TensorShape CreateScalar() {
98     static std::vector<dsize_t> empty_shape{};
99     return TensorShape(empty_shape);
100   }
101 
102   /// \brief Create a shape with an unknown rank.
103   /// \return TensorShape
104   static TensorShape CreateUnknownRankShape();
105 
106   /// \brief Create a shape with a known rank .
107   /// \return TensorShape
108   static TensorShape CreateUnknownShapeWithRank(dsize_t rank);
109 
110   /// \brief Insert a new dim into a copy of the current shape.
111   /// \param[in] dim to be added
112   /// \param[in] axis the index where dim should be added
113   /// \return New modified shape
114   TensorShape InsertDim(dsize_t axis, dsize_t dim) const;
115 
116   /// \brief Insert new dim at index 0. For example,  <2,4> --> PrependDim(4) --> <4,2,4>
117   /// \param[in] dim
118   /// \return
119   TensorShape PrependDim(dsize_t dim) const;
120 
121   /// \brief Insert a new dim at the end of the shape. For example,  <2,4> --> AppendDim(4) --> <2,4,4>
122   /// \param[in] dim
123   /// \return
124   TensorShape AppendDim(dsize_t dim) const;
125 
126 #ifndef ENABLE_ANDROID
127   /// \brief Create a shape based on OpenCV shape and type
128   /// \param[in] cv_size
129   /// \param[in] type int that represent the type in OpenCV, example CV_8U, CV_64S
130   TensorShape(cv::MatSize cv_size, uint32_t type);
131 #endif
132 
Size()133   dsize_t Size() const { return raw_shape_.size(); }
134 
Rank()135   dsize_t Rank() const { return raw_shape_.size(); }
136 
known()137   bool known() const { return known_; }
138 
empty()139   bool empty() const { return raw_shape_.empty(); }
140 
141   dsize_t NumOfElements() const;
142 
143   bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; }
144 
145   bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); }
146 
147   dsize_t operator[](const dsize_t index) const {
148     if (index < 0) {
149       return raw_shape_[raw_shape_.size() + index];
150     }
151     return raw_shape_[index];
152   }
153 
154   /// \brief Return the Shape as a vector
155   /// \return
156   std::vector<dsize_t> AsVector() const;
157 
158   /// \brief Returns the class info as a string
159   /// \return
ToString()160   std::string ToString() const {
161     std::stringstream ss;
162     ss << *this;
163     return ss.str();
164   }
165 
166   /// \brief Actual print function used by operator<<
167   /// \param out output string stream
168   void Print(std::ostream &out) const;
169 
170   /// \brief << Stream output operator overload
171   ///     This allows you to print the info using stream operators
172   /// \param[in] out - reference to the output stream being overloaded
173   /// \param[in] rO - reference to the TensorShape to display
174   /// \return - the output stream must be returned
175   friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) {
176     so.Print(out);
177     return out;
178   }
179 
180 #ifdef ENABLE_PYTHON
181   py::list AsPyList();
182 #endif
183 
184   /// \brief Checks if the given index is a valid index for this tensor.
185   ///     For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not.
186   /// \param[in] index
187   /// \return bool
188   bool IsValidIndex(const std::vector<dsize_t> &index) const;
189 
190   TensorShape Squeeze() const;
191 
192   std::vector<dsize_t> Strides() const;
193 
194   /// \brief Returns the location of the item assuming row major memory layout.
195   /// \param[in] index
196   /// \param[out] flat_index
197   /// \return
198   Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const;
199 
200  private:
201   // Vector to keep the dims of the shape.
202   std::vector<dsize_t> raw_shape_;
203   // Vector to keep the strides of the shape. The size is rank+1
204   std::vector<dsize_t> strides_;
205   // True if known and valid shape, false otherwise
206   bool known_;
207 
208   /// \brief Internal utility function to iterate over a list,
209   ///     check if the dim is valid and then insert it into the shape.
210   /// \param[in] list Iterable list
211   /// \return true if the shape is valid and no overflow would be generated when counting the number of elements.
212   ///     False otherwise.
213   template <typename T>
214   void AddListToShape(const T &list);
215 };
216 }  // namespace dataset
217 }  // namespace mindspore
218 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_TENSOR_SHAPE_H_
219