• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/dtype/type.h"
20 
21 #include <algorithm>
22 #include <cstdlib>
23 #include <climits>
24 
25 #include "ir/dtype/number.h"
26 #include "utils/log_adapter.h"
27 #include "utils/convert_utils_base.h"
28 
29 namespace mindspore {
30 static mindspore::HashMap<TypeId, std::string> g_type_2_lable{{kTypeUnknown, "Unknown"},
31                                                               {kMetaTypeType, "Type"},
32                                                               {kMetaTypeAny, "Any"},
33                                                               {kMetaTypeObject, "Object"},
34                                                               {kMetaTypeTypeType, "TypeType"},
35                                                               {kMetaTypeProblem, "Problem"},
36                                                               {kMetaTypeExternal, "External"},
37                                                               {kMetaTypeNone, "None"},
38                                                               {kMetaTypeNull, "Null"},
39                                                               {kMetaTypeEllipsis, "Ellipsis"},
40                                                               {kObjectTypeNumber, "Number"},
41                                                               {kObjectTypeString, "String"},
42                                                               {kObjectTypeList, "List"},
43                                                               {kObjectTypeTuple, "Tuple"},
44                                                               {kObjectTypeSlice, "Slice"},
45                                                               {kObjectTypeKeyword, "Keyword"},
46                                                               {kObjectTypeTensorType, "Tensor"},
47                                                               {kObjectTypeMapTensorType, "MapTensor"},
48                                                               {kObjectTypeRowTensorType, "RowTensor"},
49                                                               {kObjectTypeCOOTensorType, "COOTensor"},
50                                                               {kObjectTypeCSRTensorType, "CSRTensor"},
51                                                               {kObjectTypeUndeterminedType, "Undetermined"},
52                                                               {kObjectTypeClass, "Class"},
53                                                               {kObjectTypeDictionary, "Dictionary"},
54                                                               {kObjectTypeFunction, "Function"},
55                                                               {kObjectTypeJTagged, "JTagged"},
56                                                               {kObjectTypeSymbolicKeyType, "SymbolicKey"},
57                                                               {kObjectTypeEnvType, "EnvType"},
58                                                               {kObjectTypeRefKey, "RefKey"},
59                                                               {kObjectTypeRef, "Ref"},
60                                                               {kNumberTypeBool, "Bool"},
61                                                               {kNumberTypeInt, "Int"},
62                                                               {kNumberTypeInt4, "QInt4x2"},
63                                                               {kNumberTypeInt8, "Int8"},
64                                                               {kNumberTypeInt16, "Int16"},
65                                                               {kNumberTypeInt32, "Int32"},
66                                                               {kNumberTypeInt64, "Int64"},
67                                                               {kNumberTypeUInt, "UInt"},
68                                                               {kNumberTypeUInt8, "UInt8"},
69                                                               {kNumberTypeUInt16, "UInt16"},
70                                                               {kNumberTypeUInt32, "UInt32"},
71                                                               {kNumberTypeUInt64, "UInt64"},
72                                                               {kNumberTypeFloat, "Float"},
73                                                               {kNumberTypeFloat16, "Float16"},
74                                                               {kNumberTypeFloat32, "Float32"},
75                                                               {kNumberTypeFloat64, "Float64"},
76                                                               {kNumberTypeBFloat16, "BFloat16"},
77                                                               {kNumberTypeComplex, "Complex"},
78                                                               {kNumberTypeComplex64, "Complex64"},
79                                                               {kNumberTypeComplex128, "Complex128"},
80                                                               {kNumberTypeGLUInt, "GLUInt"},
81                                                               {kObjectTypeMonad, "Monad"},
82                                                               {kObjectTypeUMonad, "UMonad"},
83                                                               {kObjectTypeIOMonad, "IOMonad"}};
84 
type_priority_map()85 const mindspore::HashMap<TypeId, int> &type_priority_map() {
86   static const mindspore::HashMap<TypeId, int> type_priority_map = {
87     {kNumberTypeBool, 0},    {kNumberTypeUInt8, 1},   {kNumberTypeInt8, 2},    {kNumberTypeInt16, 3},
88     {kNumberTypeInt32, 4},   {kNumberTypeInt64, 5},   {kNumberTypeFloat16, 6}, {kNumberTypeFloat32, 7},
89     {kNumberTypeFloat64, 8}, {kNumberTypeBFloat16, 9}};
90   return type_priority_map;
91 }
92 
type_name_map()93 const mindspore::HashMap<TypeId, std::string> &type_name_map() {
94   static const mindspore::HashMap<TypeId, std::string> type_name_map = {
95     {kNumberTypeBool, "bool_"},        {kNumberTypeInt8, "int8"},       {kNumberTypeUInt8, "uint8"},
96     {kNumberTypeInt16, "int16"},       {kNumberTypeInt32, "int32"},     {kNumberTypeInt64, "int64"},
97     {kNumberTypeFloat16, "float16"},   {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"},
98     {kNumberTypeBFloat16, "bfloat16"}, {kNumberTypeInt4, "int4"}};
99   return type_name_map;
100 }
101 
IntBitsToTypeId(const int nbits)102 TypeId IntBitsToTypeId(const int nbits) {
103   switch (nbits) {
104     case static_cast<int>(BitsNum::eBits4):
105       return kNumberTypeInt4;
106     case static_cast<int>(BitsNum::eBits8):
107       return kNumberTypeInt8;
108     case static_cast<int>(BitsNum::eBits16):
109       return kNumberTypeInt16;
110     case static_cast<int>(BitsNum::eBits32):
111       return kNumberTypeInt32;
112     case static_cast<int>(BitsNum::eBits64):
113       return kNumberTypeInt64;
114     default:
115       MS_LOG(EXCEPTION) << "For Int type only support number of 8bits, 16bits, 32bits and 64bits, but got " << nbits
116                         << "bits";
117   }
118 }
119 
UIntBitsToTypeId(const int nbits)120 TypeId UIntBitsToTypeId(const int nbits) {
121   switch (nbits) {
122     case static_cast<int>(BitsNum::eBits8):
123       return kNumberTypeUInt8;
124     case static_cast<int>(BitsNum::eBits16):
125       return kNumberTypeUInt16;
126     case static_cast<int>(BitsNum::eBits32):
127       return kNumberTypeUInt32;
128     case static_cast<int>(BitsNum::eBits64):
129       return kNumberTypeUInt64;
130     default:
131       MS_LOG(EXCEPTION) << "For UInt type only support number of 8bits, 16bits, 32bits and 64bits, but got " << nbits
132                         << "bits";
133   }
134 }
135 
FloatBitsToTypeId(const int nbits)136 TypeId FloatBitsToTypeId(const int nbits) {
137   switch (nbits) {
138     case static_cast<int>(BitsNum::eBits16):
139       return kNumberTypeFloat16;
140     case static_cast<int>(BitsNum::eBits32):
141       return kNumberTypeFloat32;
142     case static_cast<int>(BitsNum::eBits64):
143       return kNumberTypeFloat64;
144     default:
145       MS_LOG(EXCEPTION) << "For Float type only support number of 16bits, 32bits and 64bits, but got " << nbits
146                         << "bits";
147   }
148 }
149 
BFloatBitsToTypeId(const int nbits)150 TypeId BFloatBitsToTypeId(const int nbits) {
151   switch (nbits) {
152     case static_cast<int>(BitsNum::eBits16):
153       return kNumberTypeBFloat16;
154     default:
155       MS_LOG(EXCEPTION) << "For BFloat type only support number of 16bits, but got " << nbits << "bits";
156   }
157 }
158 
ComplexBitsToTypeId(const int nbits)159 TypeId ComplexBitsToTypeId(const int nbits) {
160   switch (nbits) {
161     case static_cast<int>(BitsNum::eBits64):
162       return kNumberTypeComplex64;
163     case static_cast<int>(BitsNum::eBits128):
164       return kNumberTypeComplex128;
165     default:
166       MS_LOG(EXCEPTION) << "For Complex type only support number of 64bits and 128bits, but got " << nbits << "bits";
167   }
168 }
169 
TypeIdLabel(const TypeId & v)170 const std::string &TypeIdLabel(const TypeId &v) {
171   static const std::string unknown("[Unknown Type Id]");
172   auto iter = g_type_2_lable.find(v);
173   if (iter != g_type_2_lable.end()) {
174     return iter->second;
175   } else {
176     return unknown;
177   }
178 }
179 
// Collapses sized numeric TypeIds onto their generic family:
//   Int/Int8..Int64 -> Int, Float/Float16..Float64 -> Float, UInt/UInt8..UInt64 -> UInt,
//   Complex/Complex64/Complex128 -> Complex.
// BFloat16 maps to itself, and every other id (e.g. Bool, Int4, non-numeric ids) is returned unchanged.
TypeId NormalizeTypeId(const TypeId type_id) {
  switch (type_id) {
    case kNumberTypeInt:
    case kNumberTypeInt8:
    case kNumberTypeInt16:
    case kNumberTypeInt32:
    case kNumberTypeInt64:
      return kNumberTypeInt;
    case kNumberTypeFloat:
    case kNumberTypeFloat16:
    case kNumberTypeFloat32:
    case kNumberTypeFloat64:
      return kNumberTypeFloat;
    case kNumberTypeBFloat16:
      return kNumberTypeBFloat16;
    case kNumberTypeUInt:
    case kNumberTypeUInt8:
    case kNumberTypeUInt16:
    case kNumberTypeUInt32:
    case kNumberTypeUInt64:
      return kNumberTypeUInt;
    case kNumberTypeComplex:
    case kNumberTypeComplex64:
    case kNumberTypeComplex128:
      return kNumberTypeComplex;
    default:
      return type_id;
  }
}
199 
IsSameObjectType(const Type & lhs,const Type & rhs)200 bool IsSameObjectType(const Type &lhs, const Type &rhs) {
201   if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
202     return false;
203   }
204   return lhs.object_type() == rhs.object_type();
205 }
206 
GetTypeByte(const TypePtr & type_ptr)207 size_t GetTypeByte(const TypePtr &type_ptr) {
208   if (type_ptr && type_ptr->isa<Number>()) {
209     auto number = dyn_cast<Number>(type_ptr);
210     if (!number) {
211       MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel.";
212       return 0;
213     } else {
214       if (number->nbits() < CHAR_BIT) {
215         MS_LOG(DEBUG) << "Number of bit " << number->nbits() << " is less than CHAR_BIT " << CHAR_BIT << ", return 1.";
216         return 1;
217       }
218       return IntToSize(number->nbits() / CHAR_BIT);
219     }
220   } else {
221     MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel:" << (type_ptr == nullptr ? "null" : type_ptr->ToString());
222     return 0;
223   }
224 }
225 
// Returns the underlying integral value of a TypeId, widened to int64_t.
int64_t GetTypeId(const TypeId &type_id) {
  const auto raw_id = static_cast<int64_t>(type_id);
  return raw_id;
}
227 
operator ==(const Value & other) const228 bool Type::operator==(const Value &other) const {
229   if (!other.isa<Type>()) {
230     return false;
231   }
232   auto other_type = static_cast<const Type *>(&other);
233   return *this == *other_type;
234 }
235 
operator <<(std::ostream & os,const Type & type)236 std::ostream &operator<<(std::ostream &os, const Type &type) {
237   os << type.ToString();
238   return os;
239 }
240 
operator <<(std::ostream & os,const TypePtr type)241 std::ostream &operator<<(std::ostream &os, const TypePtr type) {
242   os << type->ToString();
243   return os;
244 }
245 
equal(const TypePtr other) const246 bool Object::equal(const TypePtr other) const {
247   auto same_other = dyn_cast<Object>(other);
248   if (same_other != nullptr) {
249     return *this == *same_other;
250   }
251   return false;
252 }
253 
operator <<(std::ostream & os,const Object & obj)254 std::ostream &operator<<(std::ostream &os, const Object &obj) {
255   os << obj.ToString();
256   return os;
257 }
258 
operator <<(std::ostream & os,const std::shared_ptr<Object> obj)259 std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
260   os << obj->ToString();
261   return os;
262 }
263 
operator <<(std::ostream & os,const TypePtrList & types)264 std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
265   os << "[";
266   for (size_t i = 0; i < types.size(); ++i) {
267     if (i > 0) {
268       os << ", ";
269     }
270     os << (types[i] == nullptr ? "nullptr" : types[i]->ToString());
271   }
272   os << "]";
273   return os;
274 }
275 }  // namespace mindspore
276