• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_
18 
19 #if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON)
20 #include <opencv2/core/hal/interface.h>
21 #endif
22 
23 #include <string>
24 #include <utility>
25 
26 #ifdef ENABLE_MINDDATA_PYTHON
27 #include "pybind11/numpy.h"
28 #include "pybind11/pybind11.h"
29 #include "minddata/dataset/core/pybind_support.h"
30 namespace py = pybind11;
31 #else
32 #include "base/bfloat16.h"
33 #include "base/float16.h"
34 #endif
35 #include "minddata/dataset/include/dataset/constants.h"
36 
37 namespace mindspore {
38 namespace dataset {
39 // Class that represents basic data types in DataEngine.
40 class DataType {
41  public:
42   enum Type : uint8_t {
43     DE_UNKNOWN = 0,
44     DE_BOOL,
45     DE_INT8,
46     DE_UINT8,
47     DE_INT16,
48     DE_UINT16,
49     DE_INT32,
50     DE_UINT32,
51     DE_INT64,
52     DE_UINT64,
53     DE_FLOAT16,
54     DE_FLOAT32,
55     DE_FLOAT64,
56     DE_STRING,
57     DE_BYTES,
58     DE_PYTHON,
59     NUM_OF_TYPES
60   };
61 
62   struct TypeInfo {
63     const char *name_;                          // name to be represent the type while printing
64     const uint8_t sizeInBytes_;                 // number of bytes needed for this type
65     const char *pybindType_;                    //  Python matching type, used in get_output_types
66     const std::string pybindFormatDescriptor_;  // pybind format used for numpy types
67     const uint8_t cvType_;                      // OpenCv matching type
68   };
69 
70 #ifdef ENABLE_MINDDATA_PYTHON
71   static inline const TypeInfo kTypeInfo[] = {
72     // name, sizeInBytes, pybindType, pybindFormatDescriptor, openCV
73     {"unknown", 0, "object", "", kCVInvalidType},                                        // DE_UNKNOWN
74     {"bool", 1, "bool", py::format_descriptor<bool>::format(), CV_8U},                   // DE_BOOL
75     {"int8", 1, "int8", py::format_descriptor<int8_t>::format(), CV_8S},                 // DE_INT8
76     {"uint8", 1, "uint8", py::format_descriptor<uint8_t>::format(), CV_8U},              // DE_UINT8
77     {"int16", 2, "int16", py::format_descriptor<int16_t>::format(), CV_16S},             // DE_INT16
78     {"uint16", 2, "uint16", py::format_descriptor<uint16_t>::format(), CV_16U},          // DE_UINT16
79     {"int32", 4, "int32", py::format_descriptor<int32_t>::format(), CV_32S},             // DE_INT32
80     {"uint32", 4, "uint32", py::format_descriptor<uint32_t>::format(), kCVInvalidType},  // DE_UINT32
81     {"int64", 8, "int64", py::format_descriptor<int64_t>::format(), kCVInvalidType},     // DE_INT64
82     {"uint64", 8, "uint64", py::format_descriptor<uint64_t>::format(), kCVInvalidType},  // DE_UINT64
83     {"float16", 2, "float16", "e", CV_16F},                                              // DE_FLOAT16
84     {"float32", 4, "float32", py::format_descriptor<float>::format(), CV_32F},           // DE_FLOAT32
85     {"float64", 8, "double", py::format_descriptor<double>::format(), CV_64F},           // DE_FLOAT64
86     {"string", 0, "str", "U", kCVInvalidType},                                           // DE_STRING
87     {"bytes", 0, "bytes", "S", CV_8U},                                                   // DE_BYTES
88     {"python", 0, "object", "O", kCVInvalidType}                                         // DE_PYTHON
89   };
90 #else
91 #if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON)
92   static inline const TypeInfo kTypeInfo[] = {
93     // name, sizeInBytes, pybindTypem formatDescriptor, openCV
94     {"unknown", 0, "object", "", kCVInvalidType},  // DE_UNKNOWN
95     {"bool", 1, "bool", "", CV_8U},                // DE_BOOL
96     {"int8", 1, "int8", "", CV_8S},                // DE_INT8
97     {"uint8", 1, "uint8", "", CV_8U},              // DE_UINT8
98     {"int16", 2, "int16", "", CV_16S},             // DE_INT16
99     {"uint16", 2, "uint16", "", CV_16U},           // DE_UINT16
100     {"int32", 4, "int32", "", CV_32S},             // DE_INT32
101     {"uint32", 4, "uint32", "", kCVInvalidType},   // DE_UINT32
102     {"int64", 8, "int64", "", kCVInvalidType},     // DE_INT64
103     {"uint64", 8, "uint64", "", kCVInvalidType},   // DE_UINT64
104     {"float16", 2, "float16", "", CV_16F},         // DE_FLOAT16
105     {"float32", 4, "float32", "", CV_32F},         // DE_FLOAT32
106     {"float64", 8, "double", "", CV_64F},          // DE_FLOAT64
107     {"string", 0, "str", "", kCVInvalidType},      // DE_STRING
108     {"bytes", 0, "bytes", "", CV_8U},              // DE_BYTES
109     {"python", 0, "object", "O", kCVInvalidType}   // DE_PYTHON
110   };
111 #else
112   // android and no python
113   static inline const TypeInfo kTypeInfo[] = {
114     // name, sizeInBytes, formatDescriptor
115     {"unknown", 0, "object", "", kCVInvalidType},  // DE_UNKNOWN
116     {"bool", 1, "bool", ""},                       // DE_BOOL
117     {"int8", 1, "int8", ""},                       // DE_INT8
118     {"uint8", 1, "uint8", ""},                     // DE_UINT8
119     {"int16", 2, "int16", ""},                     // DE_INT16
120     {"uint16", 2, "uint16", ""},                   // DE_UINT16
121     {"int32", 4, "int32", ""},                     // DE_INT32
122     {"uint32", 4, "uint32", "", kCVInvalidType},   // DE_UINT32
123     {"int64", 8, "int64", "", kCVInvalidType},     // DE_INT64
124     {"uint64", 8, "uint64", "", kCVInvalidType},   // DE_UINT64
125     {"float16", 2, "float16", ""},                 // DE_FLOAT16
126     {"float32", 4, "float32", ""},                 // DE_FLOAT32
127     {"float64", 8, "double", ""},                  // DE_FLOAT64
128     {"string", 0, "str", "", kCVInvalidType},      // DE_STRING
129     {"bytes", 0, "bytes", ""},                     // DE_BYTES
130     {"python", 0, "object", "O", kCVInvalidType}   // DE_PYTHON
131   };
132 #endif
133 #endif
134   // No arg constructor to create an unknown shape
DataType()135   DataType() : type_(DE_UNKNOWN) {}
136 
137   // Create a type from a given string
138   /// \param type_str
139   explicit DataType(const std::string &type_str);
140 
141   // Default destructor
142   ~DataType() = default;
143 
144   // Create a type from a given enum
145   /// \param type
DataType(const Type & type)146   constexpr explicit DataType(const Type &type) : type_(std::move(type)) {}
147 
148   constexpr bool operator==(const DataType a) const { return type_ == a.type_; }
149 
150   constexpr bool operator==(const Type a) const { return type_ == a; }
151 
152   constexpr bool operator!=(const DataType a) const { return type_ != a.type_; }
153 
154   constexpr bool operator!=(const Type a) const { return type_ != a; }
155 
156   // Disable this usage `if(d)` where d is of type DataType
157   /// \return return nothing since we deiable this function.
158   operator bool() = delete;
159 
160   // To be used in Switch/case
161   /// \return data type internal.
Type()162   operator Type() const { return type_; }
163 
164   // The number of bytes needed to store one value of this type
165   /// \return the number of bytes of the type.
166   uint8_t SizeInBytes() const;
167 
168 #if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON)
169   // Convert from DataType to OpenCV type
170   /// \return
171   uint8_t AsCVType() const;
172 
173   // Convert from OpenCV type to DataType
174   /// \param cv_type
175   /// \return
176   static DataType FromCVType(int cv_type);
177 #endif
178 
179   // Returns a string representation of the type
180   /// \return
181   std::string ToString() const;
182 
183   // returns true if the template type is the same as the Tensor type_
184   /// \tparam T
185   /// \return true or false
186   template <typename T>
IsCompatible()187   bool IsCompatible() const {
188     return type_ == FromCType<T>();
189   }
190 
191   // returns true if the template type is the same as the Tensor type_
192   /// \tparam T
193   /// \return true or false
194   template <typename T>
195   bool IsLooselyCompatible() const;
196 
197   // << Stream output operator overload
198   /// \notes This allows you to print the info using stream operators
199   /// \param out - reference to the output stream being overloaded
200   /// \param rO - reference to the DataType to display
201   /// \return - the output stream must be returned
202   friend std::ostream &operator<<(std::ostream &out, const DataType &so) {
203     out << so.ToString();
204     return out;
205   }
206 
207   template <typename T>
208   static DataType FromCType();
209 
210 #ifdef ENABLE_MINDDATA_PYTHON
211   // Convert from DataType to Pybind type
212   /// \return
213   py::dtype AsNumpyType() const;
214 
215   // Convert from NP type to DataType
216   /// \param type
217   /// \return
218   static DataType FromNpType(const py::dtype &type);
219 
220   // Convert from NP array to DataType
221   /// \param py array
222   /// \return
223   static DataType FromNpArray(const py::array &arr);
224 #endif
225 
226   // Get the buffer string format of the current type. Used in pybind buffer protocol.
227   /// \return
228   std::string GetPybindFormat() const;
229 
IsSignedInt()230   bool IsSignedInt() const {
231     return type_ == DataType::DE_INT8 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT32 ||
232            type_ == DataType::DE_INT64;
233   }
234 
IsUnsignedInt()235   bool IsUnsignedInt() const {
236     return type_ == DataType::DE_UINT8 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT32 ||
237            type_ == DataType::DE_UINT64;
238   }
239 
IsInt()240   bool IsInt() const { return IsSignedInt() || IsUnsignedInt(); }
241 
IsFloat()242   bool IsFloat() const {
243     return type_ == DataType::DE_FLOAT16 || type_ == DataType::DE_FLOAT32 || type_ == DataType::DE_FLOAT64;
244   }
245 
IsBool()246   bool IsBool() const { return type_ == DataType::DE_BOOL; }
247 
IsNumeric()248   bool IsNumeric() const { return IsInt() || IsFloat() || IsBool(); }
249 
IsString()250   bool IsString() const { return type_ == DataType::DE_STRING || type_ == DataType::DE_BYTES; }
251 
IsPython()252   bool IsPython() const { return type_ == DataType::DE_PYTHON; }
253 
value()254   Type value() const { return type_; }
255 
256  private:
257   Type type_;
258 };
259 
260 template <>
261 inline DataType DataType::FromCType<bool>() {
262   return DataType(DataType::DE_BOOL);
263 }
264 
265 template <>
266 inline DataType DataType::FromCType<double>() {
267   return DataType(DataType::DE_FLOAT64);
268 }
269 
270 template <>
271 inline DataType DataType::FromCType<float>() {
272   return DataType(DataType::DE_FLOAT32);
273 }
274 
275 template <>
276 inline DataType DataType::FromCType<float16>() {
277   return DataType(DataType::DE_FLOAT16);
278 }
279 
280 template <>
281 inline DataType DataType::FromCType<int64_t>() {
282   return DataType(DataType::DE_INT64);
283 }
284 
285 template <>
286 inline DataType DataType::FromCType<uint64_t>() {
287   return DataType(DataType::DE_UINT64);
288 }
289 
290 template <>
291 inline DataType DataType::FromCType<int32_t>() {
292   return DataType(DataType::DE_INT32);
293 }
294 
295 template <>
296 inline DataType DataType::FromCType<uint32_t>() {
297   return DataType(DataType::DE_UINT32);
298 }
299 
300 template <>
301 inline DataType DataType::FromCType<int16_t>() {
302   return DataType(DataType::DE_INT16);
303 }
304 
305 template <>
306 inline DataType DataType::FromCType<uint16_t>() {
307   return DataType(DataType::DE_UINT16);
308 }
309 
310 template <>
311 inline DataType DataType::FromCType<int8_t>() {
312   return DataType(DataType::DE_INT8);
313 }
314 
315 template <>
316 inline DataType DataType::FromCType<uint8_t>() {
317   return DataType(DataType::DE_UINT8);
318 }
319 
320 template <>
321 inline DataType DataType::FromCType<std::string_view>() {
322   return DataType(DataType::DE_STRING);
323 }
324 
325 template <>
326 inline DataType DataType::FromCType<std::string>() {
327   return DataType(DataType::DE_STRING);
328 }
329 
330 template <>
331 inline bool DataType::IsLooselyCompatible<bool>() const {
332   return type_ == DataType::DE_BOOL;
333 }
334 
335 template <>
336 inline bool DataType::IsLooselyCompatible<double>() const {
337   return type_ == DataType::DE_FLOAT64 || type_ == DataType::DE_FLOAT32;
338 }
339 
340 template <>
341 inline bool DataType::IsLooselyCompatible<float>() const {
342   return type_ == DataType::DE_FLOAT32;
343 }
344 
345 template <>
346 inline bool DataType::IsLooselyCompatible<float16>() const {
347   return type_ == DataType::DE_FLOAT16;
348 }
349 
350 template <>
351 inline bool DataType::IsLooselyCompatible<int64_t>() const {
352   return type_ == DataType::DE_INT64 || type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 ||
353          type_ == DataType::DE_INT8;
354 }
355 
356 template <>
357 inline bool DataType::IsLooselyCompatible<uint64_t>() const {
358   return type_ == DataType::DE_UINT64 || type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 ||
359          type_ == DataType::DE_UINT8;
360 }
361 
362 template <>
363 inline bool DataType::IsLooselyCompatible<int32_t>() const {
364   return type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8;
365 }
366 
367 template <>
368 inline bool DataType::IsLooselyCompatible<uint32_t>() const {
369   return type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8;
370 }
371 
372 template <>
373 inline bool DataType::IsLooselyCompatible<int16_t>() const {
374   return type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8;
375 }
376 
377 template <>
378 inline bool DataType::IsLooselyCompatible<uint16_t>() const {
379   return type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8;
380 }
381 
382 template <>
383 inline bool DataType::IsLooselyCompatible<int8_t>() const {
384   return type_ == DataType::DE_INT8;
385 }
386 
387 template <>
388 inline bool DataType::IsLooselyCompatible<uint8_t>() const {
389   return type_ == DataType::DE_UINT8;
390 }
391 
392 template <>
393 inline bool DataType::IsLooselyCompatible<std::string>() const {
394   return type_ == DataType::DE_STRING || type_ == DataType::DE_BYTES;
395 }
396 }  // namespace dataset
397 }  // namespace mindspore
398 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_
399