1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/dtype/type.h"
20
21 #include <algorithm>
22 #include <cstdlib>
23 #include <climits>
24
25 #include "ir/dtype/number.h"
26 #include "utils/log_adapter.h"
27 #include "utils/convert_utils_base.h"
28
29 namespace mindspore {
30 static mindspore::HashMap<TypeId, std::string> g_type_2_lable{{kTypeUnknown, "Unknown"},
31 {kMetaTypeType, "Type"},
32 {kMetaTypeAny, "Any"},
33 {kMetaTypeObject, "Object"},
34 {kMetaTypeTypeType, "TypeType"},
35 {kMetaTypeProblem, "Problem"},
36 {kMetaTypeExternal, "External"},
37 {kMetaTypeNone, "None"},
38 {kMetaTypeNull, "Null"},
39 {kMetaTypeEllipsis, "Ellipsis"},
40 {kObjectTypeNumber, "Number"},
41 {kObjectTypeString, "String"},
42 {kObjectTypeList, "List"},
43 {kObjectTypeTuple, "Tuple"},
44 {kObjectTypeSlice, "Slice"},
45 {kObjectTypeKeyword, "Keyword"},
46 {kObjectTypeTensorType, "Tensor"},
47 {kObjectTypeMapTensorType, "MapTensor"},
48 {kObjectTypeRowTensorType, "RowTensor"},
49 {kObjectTypeCOOTensorType, "COOTensor"},
50 {kObjectTypeCSRTensorType, "CSRTensor"},
51 {kObjectTypeUndeterminedType, "Undetermined"},
52 {kObjectTypeClass, "Class"},
53 {kObjectTypeDictionary, "Dictionary"},
54 {kObjectTypeFunction, "Function"},
55 {kObjectTypeJTagged, "JTagged"},
56 {kObjectTypeSymbolicKeyType, "SymbolicKey"},
57 {kObjectTypeEnvType, "EnvType"},
58 {kObjectTypeRefKey, "RefKey"},
59 {kObjectTypeRef, "Ref"},
60 {kNumberTypeBool, "Bool"},
61 {kNumberTypeInt, "Int"},
62 {kNumberTypeInt4, "QInt4x2"},
63 {kNumberTypeInt8, "Int8"},
64 {kNumberTypeInt16, "Int16"},
65 {kNumberTypeInt32, "Int32"},
66 {kNumberTypeInt64, "Int64"},
67 {kNumberTypeUInt, "UInt"},
68 {kNumberTypeUInt8, "UInt8"},
69 {kNumberTypeUInt16, "UInt16"},
70 {kNumberTypeUInt32, "UInt32"},
71 {kNumberTypeUInt64, "UInt64"},
72 {kNumberTypeFloat, "Float"},
73 {kNumberTypeFloat16, "Float16"},
74 {kNumberTypeFloat32, "Float32"},
75 {kNumberTypeFloat64, "Float64"},
76 {kNumberTypeBFloat16, "BFloat16"},
77 {kNumberTypeComplex, "Complex"},
78 {kNumberTypeComplex64, "Complex64"},
79 {kNumberTypeComplex128, "Complex128"},
80 {kNumberTypeGLUInt, "GLUInt"},
81 {kObjectTypeMonad, "Monad"},
82 {kObjectTypeUMonad, "UMonad"},
83 {kObjectTypeIOMonad, "IOMonad"}};
84
type_priority_map()85 const mindspore::HashMap<TypeId, int> &type_priority_map() {
86 static const mindspore::HashMap<TypeId, int> type_priority_map = {
87 {kNumberTypeBool, 0}, {kNumberTypeUInt8, 1}, {kNumberTypeInt8, 2}, {kNumberTypeInt16, 3},
88 {kNumberTypeInt32, 4}, {kNumberTypeInt64, 5}, {kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7},
89 {kNumberTypeFloat64, 8}, {kNumberTypeBFloat16, 9}};
90 return type_priority_map;
91 }
92
type_name_map()93 const mindspore::HashMap<TypeId, std::string> &type_name_map() {
94 static const mindspore::HashMap<TypeId, std::string> type_name_map = {
95 {kNumberTypeBool, "bool_"}, {kNumberTypeInt8, "int8"}, {kNumberTypeUInt8, "uint8"},
96 {kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"},
97 {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"},
98 {kNumberTypeBFloat16, "bfloat16"}, {kNumberTypeInt4, "int4"}};
99 return type_name_map;
100 }
101
IntBitsToTypeId(const int nbits)102 TypeId IntBitsToTypeId(const int nbits) {
103 switch (nbits) {
104 case static_cast<int>(BitsNum::eBits4):
105 return kNumberTypeInt4;
106 case static_cast<int>(BitsNum::eBits8):
107 return kNumberTypeInt8;
108 case static_cast<int>(BitsNum::eBits16):
109 return kNumberTypeInt16;
110 case static_cast<int>(BitsNum::eBits32):
111 return kNumberTypeInt32;
112 case static_cast<int>(BitsNum::eBits64):
113 return kNumberTypeInt64;
114 default:
115 MS_LOG(EXCEPTION) << "For Int type only support number of 8bits, 16bits, 32bits and 64bits, but got " << nbits
116 << "bits";
117 }
118 }
119
UIntBitsToTypeId(const int nbits)120 TypeId UIntBitsToTypeId(const int nbits) {
121 switch (nbits) {
122 case static_cast<int>(BitsNum::eBits8):
123 return kNumberTypeUInt8;
124 case static_cast<int>(BitsNum::eBits16):
125 return kNumberTypeUInt16;
126 case static_cast<int>(BitsNum::eBits32):
127 return kNumberTypeUInt32;
128 case static_cast<int>(BitsNum::eBits64):
129 return kNumberTypeUInt64;
130 default:
131 MS_LOG(EXCEPTION) << "For UInt type only support number of 8bits, 16bits, 32bits and 64bits, but got " << nbits
132 << "bits";
133 }
134 }
135
FloatBitsToTypeId(const int nbits)136 TypeId FloatBitsToTypeId(const int nbits) {
137 switch (nbits) {
138 case static_cast<int>(BitsNum::eBits16):
139 return kNumberTypeFloat16;
140 case static_cast<int>(BitsNum::eBits32):
141 return kNumberTypeFloat32;
142 case static_cast<int>(BitsNum::eBits64):
143 return kNumberTypeFloat64;
144 default:
145 MS_LOG(EXCEPTION) << "For Float type only support number of 16bits, 32bits and 64bits, but got " << nbits
146 << "bits";
147 }
148 }
149
BFloatBitsToTypeId(const int nbits)150 TypeId BFloatBitsToTypeId(const int nbits) {
151 switch (nbits) {
152 case static_cast<int>(BitsNum::eBits16):
153 return kNumberTypeBFloat16;
154 default:
155 MS_LOG(EXCEPTION) << "For BFloat type only support number of 16bits, but got " << nbits << "bits";
156 }
157 }
158
ComplexBitsToTypeId(const int nbits)159 TypeId ComplexBitsToTypeId(const int nbits) {
160 switch (nbits) {
161 case static_cast<int>(BitsNum::eBits64):
162 return kNumberTypeComplex64;
163 case static_cast<int>(BitsNum::eBits128):
164 return kNumberTypeComplex128;
165 default:
166 MS_LOG(EXCEPTION) << "For Complex type only support number of 64bits and 128bits, but got " << nbits << "bits";
167 }
168 }
169
TypeIdLabel(const TypeId & v)170 const std::string &TypeIdLabel(const TypeId &v) {
171 static const std::string unknown("[Unknown Type Id]");
172 auto iter = g_type_2_lable.find(v);
173 if (iter != g_type_2_lable.end()) {
174 return iter->second;
175 } else {
176 return unknown;
177 }
178 }
179
// Collapses a sized numeric TypeId onto its generic family id:
// Int8/16/32/64 -> Int, UInt8/16/32/64 -> UInt, Float16/32/64 -> Float,
// Complex64/128 -> Complex. Every other id, BFloat16 included, is returned unchanged.
TypeId NormalizeTypeId(const TypeId type_id) {
  switch (type_id) {
    case kNumberTypeInt:
    case kNumberTypeInt8:
    case kNumberTypeInt16:
    case kNumberTypeInt32:
    case kNumberTypeInt64:
      return kNumberTypeInt;
    case kNumberTypeFloat:
    case kNumberTypeFloat16:
    case kNumberTypeFloat32:
    case kNumberTypeFloat64:
      return kNumberTypeFloat;
    case kNumberTypeBFloat16:
      return kNumberTypeBFloat16;
    case kNumberTypeUInt:
    case kNumberTypeUInt8:
    case kNumberTypeUInt16:
    case kNumberTypeUInt32:
    case kNumberTypeUInt64:
      return kNumberTypeUInt;
    case kNumberTypeComplex:
    case kNumberTypeComplex64:
    case kNumberTypeComplex128:
      return kNumberTypeComplex;
    default:
      return type_id;
  }
}
199
IsSameObjectType(const Type & lhs,const Type & rhs)200 bool IsSameObjectType(const Type &lhs, const Type &rhs) {
201 if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
202 return false;
203 }
204 return lhs.object_type() == rhs.object_type();
205 }
206
GetTypeByte(const TypePtr & type_ptr)207 size_t GetTypeByte(const TypePtr &type_ptr) {
208 if (type_ptr && type_ptr->isa<Number>()) {
209 auto number = dyn_cast<Number>(type_ptr);
210 if (!number) {
211 MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel.";
212 return 0;
213 } else {
214 if (number->nbits() < CHAR_BIT) {
215 MS_LOG(DEBUG) << "Number of bit " << number->nbits() << " is less than CHAR_BIT " << CHAR_BIT << ", return 1.";
216 return 1;
217 }
218 return IntToSize(number->nbits() / CHAR_BIT);
219 }
220 } else {
221 MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel:" << (type_ptr == nullptr ? "null" : type_ptr->ToString());
222 return 0;
223 }
224 }
225
// Returns the numeric value of `type_id` widened to int64_t.
int64_t GetTypeId(const TypeId &type_id) {
  const auto numeric_id = static_cast<int64_t>(type_id);
  return numeric_id;
}
227
operator ==(const Value & other) const228 bool Type::operator==(const Value &other) const {
229 if (!other.isa<Type>()) {
230 return false;
231 }
232 auto other_type = static_cast<const Type *>(&other);
233 return *this == *other_type;
234 }
235
operator <<(std::ostream & os,const Type & type)236 std::ostream &operator<<(std::ostream &os, const Type &type) {
237 os << type.ToString();
238 return os;
239 }
240
operator <<(std::ostream & os,const TypePtr type)241 std::ostream &operator<<(std::ostream &os, const TypePtr type) {
242 os << type->ToString();
243 return os;
244 }
245
equal(const TypePtr other) const246 bool Object::equal(const TypePtr other) const {
247 auto same_other = dyn_cast<Object>(other);
248 if (same_other != nullptr) {
249 return *this == *same_other;
250 }
251 return false;
252 }
253
operator <<(std::ostream & os,const Object & obj)254 std::ostream &operator<<(std::ostream &os, const Object &obj) {
255 os << obj.ToString();
256 return os;
257 }
258
operator <<(std::ostream & os,const std::shared_ptr<Object> obj)259 std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
260 os << obj->ToString();
261 return os;
262 }
263
operator <<(std::ostream & os,const TypePtrList & types)264 std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
265 os << "[";
266 for (size_t i = 0; i < types.size(); ++i) {
267 if (i > 0) {
268 os << ", ";
269 }
270 os << (types[i] == nullptr ? "nullptr" : types[i]->ToString());
271 }
272 os << "]";
273 return os;
274 }
275 } // namespace mindspore
276