• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/core/data_type.h"
17 #ifdef ENABLE_PYTHON
18 #include "minddata/dataset/core/pybind_support.h"
19 #endif
20 
21 #ifndef ENABLE_ANDROID
22 #include "utils/log_adapter.h"
23 #else
24 #include "mindspore/lite/src/common/log_adapter.h"
25 #endif
26 
27 namespace mindspore {
28 namespace dataset {
29 
SizeInBytes() const30 uint8_t DataType::SizeInBytes() const {
31   if (type_ < DataType::NUM_OF_TYPES)
32     return kTypeInfo[type_].sizeInBytes_;
33   else
34     return 0;
35 }
36 
37 #ifdef ENABLE_PYTHON
AsNumpyType() const38 py::dtype DataType::AsNumpyType() const {
39   if (type_ < DataType::NUM_OF_TYPES)
40     return py::dtype(kTypeInfo[type_].pybindType_);
41   else
42     return py::dtype("unknown");
43 }
44 #endif
45 
46 #ifndef ENABLE_ANDROID
AsCVType() const47 uint8_t DataType::AsCVType() const {
48   uint8_t res = kCVInvalidType;
49   if (type_ < DataType::NUM_OF_TYPES) {
50     res = kTypeInfo[type_].cvType_;
51   }
52 
53   if (res == kCVInvalidType) {
54     std::string type_name = "unknown";
55     if (type_ < DataType::NUM_OF_TYPES) {
56       type_name = std::string(kTypeInfo[type_].name_);
57     }
58     std::string err_msg = "Cannot convert [" + type_name + "] to OpenCV type.";
59     err_msg += " Currently unsupported data type: [uint32, int64, uint64, string]";
60     MS_LOG(ERROR) << err_msg;
61   }
62 
63   return res;
64 }
65 
FromCVType(int cv_type)66 DataType DataType::FromCVType(int cv_type) {
67   auto depth = static_cast<uchar>(cv_type) & static_cast<uchar>(CV_MAT_DEPTH_MASK);
68   switch (depth) {
69     case CV_8S:
70       return DataType(DataType::DE_INT8);
71     case CV_8U:
72       return DataType(DataType::DE_UINT8);
73     case CV_16S:
74       return DataType(DataType::DE_INT16);
75     case CV_16U:
76       return DataType(DataType::DE_UINT16);
77     case CV_32S:
78       return DataType(DataType::DE_INT32);
79     case CV_16F:
80       return DataType(DataType::DE_FLOAT16);
81     case CV_32F:
82       return DataType(DataType::DE_FLOAT32);
83     case CV_64F:
84       return DataType(DataType::DE_FLOAT64);
85     default:
86       MS_LOG(ERROR) << "Cannot convert from OpenCV type, unknown CV type. Unknown data type is returned!";
87       return DataType(DataType::DE_UNKNOWN);
88   }
89 }
90 #endif
91 
DataType(const std::string & type_str)92 DataType::DataType(const std::string &type_str) {
93   if (type_str == "bool")
94     type_ = DE_BOOL;
95   else if (type_str == "int8")
96     type_ = DE_INT8;
97   else if (type_str == "uint8")
98     type_ = DE_UINT8;
99   else if (type_str == "int16")
100     type_ = DE_INT16;
101   else if (type_str == "uint16")
102     type_ = DE_UINT16;
103   else if (type_str == "int32")
104     type_ = DE_INT32;
105   else if (type_str == "uint32")
106     type_ = DE_UINT32;
107   else if (type_str == "int64")
108     type_ = DE_INT64;
109   else if (type_str == "uint64")
110     type_ = DE_UINT64;
111   else if (type_str == "float16")
112     type_ = DE_FLOAT16;
113   else if (type_str == "float32")
114     type_ = DE_FLOAT32;
115   else if (type_str == "float64")
116     type_ = DE_FLOAT64;
117   else if (type_str == "string")
118     type_ = DE_STRING;
119   else
120     type_ = DE_UNKNOWN;
121 }
122 
ToString() const123 std::string DataType::ToString() const {
124   if (type_ < DataType::NUM_OF_TYPES)
125     return kTypeInfo[type_].name_;
126   else
127     return "unknown";
128 }
129 
130 #ifdef ENABLE_PYTHON
FromNpArray(const py::array & arr)131 DataType DataType::FromNpArray(const py::array &arr) {
132   if (py::isinstance<py::array_t<bool>>(arr)) {
133     return DataType(DataType::DE_BOOL);
134   } else if (py::isinstance<py::array_t<std::int8_t>>(arr)) {
135     return DataType(DataType::DE_INT8);
136   } else if (py::isinstance<py::array_t<std::uint8_t>>(arr)) {
137     return DataType(DataType::DE_UINT8);
138   } else if (py::isinstance<py::array_t<std::int16_t>>(arr)) {
139     return DataType(DataType::DE_INT16);
140   } else if (py::isinstance<py::array_t<std::uint16_t>>(arr)) {
141     return DataType(DataType::DE_UINT16);
142   } else if (py::isinstance<py::array_t<std::int32_t>>(arr)) {
143     return DataType(DataType::DE_INT32);
144   } else if (py::isinstance<py::array_t<std::uint32_t>>(arr)) {
145     return DataType(DataType::DE_UINT32);
146   } else if (py::isinstance<py::array_t<std::int64_t>>(arr)) {
147     return DataType(DataType::DE_INT64);
148   } else if (py::isinstance<py::array_t<std::uint64_t>>(arr)) {
149     return DataType(DataType::DE_UINT64);
150   } else if (py::isinstance<py::array_t<float16>>(arr)) {
151     return DataType(DataType::DE_FLOAT16);
152   } else if (py::isinstance<py::array_t<std::float_t>>(arr)) {
153     return DataType(DataType::DE_FLOAT32);
154   } else if (py::isinstance<py::array_t<std::double_t>>(arr)) {
155     return DataType(DataType::DE_FLOAT64);
156   } else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') {
157     return DataType(DataType::DE_STRING);
158   } else {
159     MS_LOG(ERROR) << "Cannot convert from numpy type. Unknown data type is returned!";
160     return DataType(DataType::DE_UNKNOWN);
161   }
162 }
163 
GetPybindFormat() const164 std::string DataType::GetPybindFormat() const {
165   std::string res;
166   if (type_ < DataType::NUM_OF_TYPES) {
167     res = kTypeInfo[type_].pybindFormatDescriptor_;
168   }
169 
170   if (res.empty()) {
171     MS_LOG(ERROR) << "Cannot convert from data type to pybind format descriptor!";
172   }
173   return res;
174 }
175 #endif
176 
177 }  // namespace dataset
178 }  // namespace mindspore
179