• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_
18 
19 #ifndef ENABLE_ANDROID
20 #include <opencv2/core/hal/interface.h>
21 #endif
22 
23 #include <string>
24 #ifdef ENABLE_PYTHON
25 #include "pybind11/numpy.h"
26 #include "pybind11/pybind11.h"
27 #include "minddata/dataset/core/pybind_support.h"
28 namespace py = pybind11;
29 #else
30 #include "base/float16.h"
31 #endif
32 #include "minddata/dataset/include/dataset/constants.h"
33 namespace mindspore {
34 namespace dataset {
35 
36 // Class that represents basic data types in DataEngine.
37 class DataType {
38  public:
39   enum Type : uint8_t {
40     DE_UNKNOWN = 0,
41     DE_BOOL,
42     DE_INT8,
43     DE_UINT8,
44     DE_INT16,
45     DE_UINT16,
46     DE_INT32,
47     DE_UINT32,
48     DE_INT64,
49     DE_UINT64,
50     DE_FLOAT16,
51     DE_FLOAT32,
52     DE_FLOAT64,
53     DE_STRING,
54     NUM_OF_TYPES
55   };
56 
57   struct TypeInfo {
58     const char *name_;                          // name to be represent the type while printing
59     const uint8_t sizeInBytes_;                 // number of bytes needed for this type
60     const char *pybindType_;                    //  Python matching type, used in get_output_types
61     const std::string pybindFormatDescriptor_;  // pybind format used for numpy types
62     const uint8_t cvType_;                      // OpenCv matching type
63   };
64 
65 #ifdef ENABLE_PYTHON
66   static inline const TypeInfo kTypeInfo[] = {
67     // name, sizeInBytes, pybindTypem formatDescriptor, openCV
68     {"unknown", 0, "object", "", kCVInvalidType},                                        // DE_UNKNOWN
69     {"bool", 1, "bool", py::format_descriptor<bool>::format(), CV_8U},                   // DE_BOOL
70     {"int8", 1, "int8", py::format_descriptor<int8_t>::format(), CV_8S},                 // DE_INT8
71     {"uint8", 1, "uint8", py::format_descriptor<uint8_t>::format(), CV_8U},              // DE_UINT8
72     {"int16", 2, "int16", py::format_descriptor<int16_t>::format(), CV_16S},             // DE_INT16
73     {"uint16", 2, "uint16", py::format_descriptor<uint16_t>::format(), CV_16U},          // DE_UINT16
74     {"int32", 4, "int32", py::format_descriptor<int32_t>::format(), CV_32S},             // DE_INT32
75     {"uint32", 4, "uint32", py::format_descriptor<uint32_t>::format(), kCVInvalidType},  // DE_UINT32
76     {"int64", 8, "int64", py::format_descriptor<int64_t>::format(), kCVInvalidType},     // DE_INT64
77     {"uint64", 8, "uint64", py::format_descriptor<uint64_t>::format(), kCVInvalidType},  // DE_UINT64
78     {"float16", 2, "float16", "e", CV_16F},                                              // DE_FLOAT16
79     {"float32", 4, "float32", py::format_descriptor<float>::format(), CV_32F},           // DE_FLOAT32
80     {"float64", 8, "double", py::format_descriptor<double>::format(), CV_64F},           // DE_FLOAT64
81     {"string", 0, "bytes", "S", kCVInvalidType}                                          // DE_STRING
82   };
83 #else
84 #ifndef ENABLE_ANDROID
85   static inline const TypeInfo kTypeInfo[] = {
86     // name, sizeInBytes, pybindTypem formatDescriptor, openCV
87     {"unknown", 0, "object", "", kCVInvalidType},  // DE_UNKNOWN
88     {"bool", 1, "bool", "", CV_8U},                // DE_BOOL
89     {"int8", 1, "int8", "", CV_8S},                // DE_INT8
90     {"uint8", 1, "uint8", "", CV_8U},              // DE_UINT8
91     {"int16", 2, "int16", "", CV_16S},             // DE_INT16
92     {"uint16", 2, "uint16", "", CV_16U},           // DE_UINT16
93     {"int32", 4, "int32", "", CV_32S},             // DE_INT32
94     {"uint32", 4, "uint32", "", kCVInvalidType},   // DE_UINT32
95     {"int64", 8, "int64", "", kCVInvalidType},     // DE_INT64
96     {"uint64", 8, "uint64", "", kCVInvalidType},   // DE_UINT64
97     {"float16", 2, "float16", "", CV_16F},         // DE_FLOAT16
98     {"float32", 4, "float32", "", CV_32F},         // DE_FLOAT32
99     {"float64", 8, "double", "", CV_64F},          // DE_FLOAT64
100     {"string", 0, "bytes", "", kCVInvalidType}     // DE_STRING
101   };
102 #else
103   // android and no python
104   static inline const TypeInfo kTypeInfo[] = {
105     // name, sizeInBytes, formatDescriptor
106     {"unknown", 0, "object", "", kCVInvalidType},  // DE_UNKNOWN
107     {"bool", 1, "bool", ""},                       // DE_BOOL
108     {"int8", 1, "int8", ""},                       // DE_INT8
109     {"uint8", 1, "uint8", ""},                     // DE_UINT8
110     {"int16", 2, "int16", ""},                     // DE_INT16
111     {"uint16", 2, "uint16", ""},                   // DE_UINT16
112     {"int32", 4, "int32", ""},                     // DE_INT32
113     {"uint32", 4, "uint32", "", kCVInvalidType},   // DE_UINT32
114     {"int64", 8, "int64", "", kCVInvalidType},     // DE_INT64
115     {"uint64", 8, "uint64", "", kCVInvalidType},   // DE_UINT64
116     {"float16", 2, "float16", ""},                 // DE_FLOAT16
117     {"float32", 4, "float32", ""},                 // DE_FLOAT32
118     {"float64", 8, "double", ""},                  // DE_FLOAT64
119     {"string", 0, "bytes", "", kCVInvalidType}     // DE_STRING
120   };
121 #endif
122 #endif
123   // No arg constructor to create an unknown shape
DataType()124   DataType() : type_(DE_UNKNOWN) {}
125 
126   // Create a type from a given string
127   /// \param type_str
128   explicit DataType(const std::string &type_str);
129 
130   // Default destructor
131   ~DataType() = default;
132 
133   // Create a type from a given enum
134   /// \param d
DataType(Type d)135   constexpr explicit DataType(Type d) : type_(d) {}
136 
137   constexpr bool operator==(const DataType a) const { return type_ == a.type_; }
138 
139   constexpr bool operator==(const Type a) const { return type_ == a; }
140 
141   constexpr bool operator!=(const DataType a) const { return type_ != a.type_; }
142 
143   constexpr bool operator!=(const Type a) const { return type_ != a; }
144 
145   // Disable this usage `if(d)` where d is of type DataType
146   /// \return return nothing since we deiable this function.
147   operator bool() = delete;
148 
149   // To be used in Switch/case
150   /// \return data type internal.
Type()151   operator Type() const { return type_; }
152 
153   // The number of bytes needed to store one value of this type
154   /// \return the number of bytes of the type.
155   uint8_t SizeInBytes() const;
156 
157 #ifndef ENABLE_ANDROID
158   // Convert from DataType to OpenCV type
159   /// \return
160   uint8_t AsCVType() const;
161 
162   // Convert from OpenCV type to DataType
163   /// \param cv_type
164   /// \return
165   static DataType FromCVType(int cv_type);
166 #endif
167 
168   // Returns a string representation of the type
169   /// \return
170   std::string ToString() const;
171 
172   // returns true if the template type is the same as the Tensor type_
173   /// \tparam T
174   /// \return true or false
175   template <typename T>
IsCompatible()176   bool IsCompatible() const {
177     return type_ == FromCType<T>();
178   }
179 
180   // returns true if the template type is the same as the Tensor type_
181   /// \tparam T
182   /// \return true or false
183   template <typename T>
184   bool IsLooselyCompatible() const;
185 
186   // << Stream output operator overload
187   /// \notes This allows you to print the info using stream operators
188   /// \param out - reference to the output stream being overloaded
189   /// \param rO - reference to the DataType to display
190   /// \return - the output stream must be returned
191   friend std::ostream &operator<<(std::ostream &out, const DataType &so) {
192     out << so.ToString();
193     return out;
194   }
195 
196   template <typename T>
197   static DataType FromCType();
198 
199 #ifdef ENABLE_PYTHON
200   // Convert from DataType to Pybind type
201   /// \return
202   py::dtype AsNumpyType() const;
203 
204   // Convert from NP type to DataType
205   /// \param type
206   /// \return
207   static DataType FromNpType(const py::dtype &type);
208 
209   // Convert from NP array to DataType
210   /// \param py array
211   /// \return
212   static DataType FromNpArray(const py::array &arr);
213 #endif
214 
215   // Get the buffer string format of the current type. Used in pybind buffer protocol.
216   /// \return
217   std::string GetPybindFormat() const;
218 
IsSignedInt()219   bool IsSignedInt() const {
220     return type_ == DataType::DE_INT8 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT32 ||
221            type_ == DataType::DE_INT64;
222   }
223 
IsUnsignedInt()224   bool IsUnsignedInt() const {
225     return type_ == DataType::DE_UINT8 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT32 ||
226            type_ == DataType::DE_UINT64;
227   }
228 
IsInt()229   bool IsInt() const { return IsSignedInt() || IsUnsignedInt(); }
230 
IsFloat()231   bool IsFloat() const {
232     return type_ == DataType::DE_FLOAT16 || type_ == DataType::DE_FLOAT32 || type_ == DataType::DE_FLOAT64;
233   }
234 
IsBool()235   bool IsBool() const { return type_ == DataType::DE_BOOL; }
236 
IsNumeric()237   bool IsNumeric() const { return type_ != DataType::DE_STRING; }
238 
value()239   Type value() const { return type_; }
240 
241  private:
242   Type type_;
243 };
244 
245 template <>
246 inline DataType DataType::FromCType<bool>() {
247   return DataType(DataType::DE_BOOL);
248 }
249 
250 template <>
251 inline DataType DataType::FromCType<double>() {
252   return DataType(DataType::DE_FLOAT64);
253 }
254 
255 template <>
256 inline DataType DataType::FromCType<float>() {
257   return DataType(DataType::DE_FLOAT32);
258 }
259 
260 template <>
261 inline DataType DataType::FromCType<float16>() {
262   return DataType(DataType::DE_FLOAT16);
263 }
264 
265 template <>
266 inline DataType DataType::FromCType<int64_t>() {
267   return DataType(DataType::DE_INT64);
268 }
269 
270 template <>
271 inline DataType DataType::FromCType<uint64_t>() {
272   return DataType(DataType::DE_UINT64);
273 }
274 
275 template <>
276 inline DataType DataType::FromCType<int32_t>() {
277   return DataType(DataType::DE_INT32);
278 }
279 
280 template <>
281 inline DataType DataType::FromCType<uint32_t>() {
282   return DataType(DataType::DE_UINT32);
283 }
284 
285 template <>
286 inline DataType DataType::FromCType<int16_t>() {
287   return DataType(DataType::DE_INT16);
288 }
289 
290 template <>
291 inline DataType DataType::FromCType<uint16_t>() {
292   return DataType(DataType::DE_UINT16);
293 }
294 
295 template <>
296 inline DataType DataType::FromCType<int8_t>() {
297   return DataType(DataType::DE_INT8);
298 }
299 
300 template <>
301 inline DataType DataType::FromCType<uint8_t>() {
302   return DataType(DataType::DE_UINT8);
303 }
304 
305 template <>
306 inline DataType DataType::FromCType<std::string_view>() {
307   return DataType(DataType::DE_STRING);
308 }
309 
310 template <>
311 inline DataType DataType::FromCType<std::string>() {
312   return DataType(DataType::DE_STRING);
313 }
314 
315 template <>
316 inline bool DataType::IsLooselyCompatible<bool>() const {
317   return type_ == DataType::DE_BOOL;
318 }
319 
320 template <>
321 inline bool DataType::IsLooselyCompatible<double>() const {
322   return type_ == DataType::DE_FLOAT64 || type_ == DataType::DE_FLOAT32;
323 }
324 
325 template <>
326 inline bool DataType::IsLooselyCompatible<float>() const {
327   return type_ == DataType::DE_FLOAT32;
328 }
329 
330 template <>
331 inline bool DataType::IsLooselyCompatible<float16>() const {
332   return type_ == DataType::DE_FLOAT16;
333 }
334 
335 template <>
336 inline bool DataType::IsLooselyCompatible<int64_t>() const {
337   return type_ == DataType::DE_INT64 || type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 ||
338          type_ == DataType::DE_INT8;
339 }
340 
341 template <>
342 inline bool DataType::IsLooselyCompatible<uint64_t>() const {
343   return type_ == DataType::DE_UINT64 || type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 ||
344          type_ == DataType::DE_UINT8;
345 }
346 
347 template <>
348 inline bool DataType::IsLooselyCompatible<int32_t>() const {
349   return type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8;
350 }
351 
352 template <>
353 inline bool DataType::IsLooselyCompatible<uint32_t>() const {
354   return type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8;
355 }
356 
357 template <>
358 inline bool DataType::IsLooselyCompatible<int16_t>() const {
359   return type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8;
360 }
361 
362 template <>
363 inline bool DataType::IsLooselyCompatible<uint16_t>() const {
364   return type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8;
365 }
366 
367 template <>
368 inline bool DataType::IsLooselyCompatible<int8_t>() const {
369   return type_ == DataType::DE_INT8;
370 }
371 
372 template <>
373 inline bool DataType::IsLooselyCompatible<uint8_t>() const {
374   return type_ == DataType::DE_UINT8;
375 }
376 }  // namespace dataset
377 }  // namespace mindspore
378 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_
379