1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2020 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/dtype/type.h"
20
21 #include <algorithm>
22 #include <cstdlib>
23 #include <string>
24 #include <climits>
25
26 #include "ir/dtype/number.h"
27 #include "utils/log_adapter.h"
28 #include "utils/convert_utils_base.h"
29
30 namespace mindspore {
31 #define MS_TYPE2LABLE(type_id) #type_id
32 static std::unordered_map<TypeId, std::string> g_type_2_lable{
33 {kTypeUnknown, MS_TYPE2LABLE(kTypeUnknown)},
34 {kMetaTypeType, MS_TYPE2LABLE(kMetaTypeType)},
35 {kMetaTypeAnything, MS_TYPE2LABLE(kMetaTypeAnything)},
36 {kMetaTypeObject, MS_TYPE2LABLE(kMetaTypeObject)},
37 {kMetaTypeTypeType, MS_TYPE2LABLE(kMetaTypeTypeType)},
38 {kMetaTypeProblem, MS_TYPE2LABLE(kMetaTypeProblem)},
39 {kMetaTypeExternal, MS_TYPE2LABLE(kMetaTypeExternal)},
40 {kMetaTypeNone, MS_TYPE2LABLE(kMetaTypeNone)},
41 {kMetaTypeNull, MS_TYPE2LABLE(kMetaTypeNull)},
42 {kMetaTypeEllipsis, MS_TYPE2LABLE(kMetaTypeEllipsis)},
43 {kMetaTypeEnd, MS_TYPE2LABLE(kMetaTypeEnd)},
44 {kObjectTypeNumber, MS_TYPE2LABLE(kObjectTypeNumber)},
45 {kObjectTypeString, MS_TYPE2LABLE(kObjectTypeString)},
46 {kObjectTypeList, MS_TYPE2LABLE(kObjectTypeList)},
47 {kObjectTypeTuple, MS_TYPE2LABLE(kObjectTypeTuple)},
48 {kObjectTypeSlice, MS_TYPE2LABLE(kObjectTypeSlice)},
49 {kObjectTypeKeyword, MS_TYPE2LABLE(kObjectTypeKeyword)},
50 {kObjectTypeTensorType, MS_TYPE2LABLE(kObjectTypeTensorType)},
51 {kObjectTypeRowTensorType, MS_TYPE2LABLE(kObjectTypeRowTensorType)},
52 {kObjectTypeSparseTensorType, MS_TYPE2LABLE(kObjectTypeSparseTensorType)},
53 {kObjectTypeUndeterminedType, MS_TYPE2LABLE(kObjectTypeUndeterminedType)},
54 {kObjectTypeClass, MS_TYPE2LABLE(kObjectTypeClass)},
55 {kObjectTypeDictionary, MS_TYPE2LABLE(kObjectTypeDictionary)},
56 {kObjectTypeFunction, MS_TYPE2LABLE(kObjectTypeFunction)},
57 {kObjectTypeJTagged, MS_TYPE2LABLE(kObjectTypeJTagged)},
58 {kObjectTypeSymbolicKeyType, MS_TYPE2LABLE(kObjectTypeSymbolicKeyType)},
59 {kObjectTypeEnvType, MS_TYPE2LABLE(kObjectTypeEnvType)},
60 {kObjectTypeRefKey, MS_TYPE2LABLE(kObjectTypeRefKey)},
61 {kObjectTypeRef, MS_TYPE2LABLE(kObjectTypeRef)},
62 {kObjectTypeEnd, MS_TYPE2LABLE(kObjectTypeEnd)},
63 {kNumberTypeBool, MS_TYPE2LABLE(kNumberTypeBool)},
64 {kNumberTypeInt, MS_TYPE2LABLE(kNumberTypeInt)},
65 {kNumberTypeInt8, MS_TYPE2LABLE(kNumberTypeInt8)},
66 {kNumberTypeInt16, MS_TYPE2LABLE(kNumberTypeInt16)},
67 {kNumberTypeInt32, MS_TYPE2LABLE(kNumberTypeInt32)},
68 {kNumberTypeInt64, MS_TYPE2LABLE(kNumberTypeInt64)},
69 {kNumberTypeUInt, MS_TYPE2LABLE(kNumberTypeUInt)},
70 {kNumberTypeUInt8, MS_TYPE2LABLE(kNumberTypeUInt8)},
71 {kNumberTypeUInt16, MS_TYPE2LABLE(kNumberTypeUInt16)},
72 {kNumberTypeUInt32, MS_TYPE2LABLE(kNumberTypeUInt32)},
73 {kNumberTypeUInt64, MS_TYPE2LABLE(kNumberTypeUInt64)},
74 {kNumberTypeFloat, MS_TYPE2LABLE(kNumberTypeFloat)},
75 {kNumberTypeFloat16, MS_TYPE2LABLE(kNumberTypeFloat16)},
76 {kNumberTypeFloat32, MS_TYPE2LABLE(kNumberTypeFloat32)},
77 {kNumberTypeFloat64, MS_TYPE2LABLE(kNumberTypeFloat64)},
78 {kNumberTypeComplex64, MS_TYPE2LABLE(kNumberTypeComplex64)},
79 {kNumberTypeEnd, MS_TYPE2LABLE(kNumberTypeEnd)},
80 {kObjectTypeMonad, MS_TYPE2LABLE(kObjectTypeMonad)},
81 {kObjectTypeUMonad, MS_TYPE2LABLE(kObjectTypeUMonad)},
82 {kObjectTypeIOMonad, MS_TYPE2LABLE(kObjectTypeIOMonad)},
83 {kMonadTypeEnd, MS_TYPE2LABLE(kMonadTypeEnd)}};
84
IntBitsToTypeId(const int nbits)85 TypeId IntBitsToTypeId(const int nbits) {
86 switch (nbits) {
87 case static_cast<int>(BitsNum::eBits8):
88 return kNumberTypeInt8;
89 case static_cast<int>(BitsNum::eBits16):
90 return kNumberTypeInt16;
91 case static_cast<int>(BitsNum::eBits32):
92 return kNumberTypeInt32;
93 case static_cast<int>(BitsNum::eBits64):
94 return kNumberTypeInt64;
95 default:
96 MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
97 }
98 }
99
UIntBitsToTypeId(const int nbits)100 TypeId UIntBitsToTypeId(const int nbits) {
101 switch (nbits) {
102 case static_cast<int>(BitsNum::eBits8):
103 return kNumberTypeUInt8;
104 case static_cast<int>(BitsNum::eBits16):
105 return kNumberTypeUInt16;
106 case static_cast<int>(BitsNum::eBits32):
107 return kNumberTypeUInt32;
108 case static_cast<int>(BitsNum::eBits64):
109 return kNumberTypeUInt64;
110 default:
111 MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
112 }
113 }
114
FloatBitsToTypeId(const int nbits)115 TypeId FloatBitsToTypeId(const int nbits) {
116 switch (nbits) {
117 case static_cast<int>(BitsNum::eBits16):
118 return kNumberTypeFloat16;
119 case static_cast<int>(BitsNum::eBits32):
120 return kNumberTypeFloat32;
121 case static_cast<int>(BitsNum::eBits64):
122 return kNumberTypeFloat64;
123 default:
124 MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
125 }
126 }
127
ComplexBitsToTypeId(const int nbits)128 TypeId ComplexBitsToTypeId(const int nbits) {
129 switch (nbits) {
130 case static_cast<int>(BitsNum::eBits64):
131 return kNumberTypeComplex64;
132 case static_cast<int>(BitsNum::eBits128):
133 return kNumberTypeComplex128;
134 default:
135 MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
136 }
137 }
138
TypeIdLabel(const TypeId & v)139 const std::string &TypeIdLabel(const TypeId &v) {
140 static const std::string unknown("[Unknown Type Id]");
141 auto iter = g_type_2_lable.find(v);
142 if (iter != g_type_2_lable.end()) {
143 return iter->second;
144 } else {
145 return unknown;
146 }
147 }
148
// Collapses width-specific numeric ids onto their generic family id.
//
// type_id: the id to normalise.
// Returns kNumberTypeInt for the generic and the 8/16/32/64-bit signed integer
// ids, kNumberTypeFloat for the generic and the 16/32/64-bit float ids, and
// the input unchanged for every other id (unsigned integers included).
TypeId NormalizeTypeId(const TypeId type_id) {
  switch (type_id) {
    case kNumberTypeInt:
    case kNumberTypeInt8:
    case kNumberTypeInt16:
    case kNumberTypeInt32:
    case kNumberTypeInt64:
      return kNumberTypeInt;
    case kNumberTypeFloat:
    case kNumberTypeFloat16:
    case kNumberTypeFloat32:
    case kNumberTypeFloat64:
      return kNumberTypeFloat;
    default:
      return type_id;
  }
}
160
IsSameObjectType(const Type & lhs,const Type & rhs)161 bool IsSameObjectType(const Type &lhs, const Type &rhs) {
162 if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
163 return false;
164 }
165 return lhs.object_type() == rhs.object_type();
166 }
167
GetTypeByte(const TypePtr & type_ptr)168 size_t GetTypeByte(const TypePtr &type_ptr) {
169 if (type_ptr && type_ptr->isa<Number>()) {
170 auto number = dyn_cast<Number>(type_ptr);
171 if (!number) {
172 MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel.";
173 return 0;
174 } else {
175 return IntToSize(number->nbits() / CHAR_BIT);
176 }
177 } else {
178 MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel.";
179 return 0;
180 }
181 }
182
operator ==(const Value & other) const183 bool Type::operator==(const Value &other) const {
184 if (other.isa<Type>()) {
185 auto other_type = static_cast<const Type *>(&other);
186 return *this == *other_type;
187 } else {
188 return false;
189 }
190 }
191
operator <<(std::ostream & os,const Type & type)192 std::ostream &operator<<(std::ostream &os, const Type &type) {
193 os << type.ToString();
194 return os;
195 }
196
operator <<(std::ostream & os,const TypePtr type)197 std::ostream &operator<<(std::ostream &os, const TypePtr type) {
198 os << type->ToString();
199 return os;
200 }
201
equal(const TypePtr other) const202 bool Object::equal(const TypePtr other) const {
203 auto same_other = dyn_cast<Object>(other);
204 if (same_other != nullptr) {
205 return *this == *same_other;
206 }
207 return false;
208 }
209
operator <<(std::ostream & os,const Object & obj)210 std::ostream &operator<<(std::ostream &os, const Object &obj) {
211 os << obj.ToString();
212 return os;
213 }
214
operator <<(std::ostream & os,const std::shared_ptr<Object> obj)215 std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
216 os << obj->ToString();
217 return os;
218 }
219
operator <<(std::ostream & os,const TypePtrList & types)220 std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
221 os << "[";
222 for (size_t i = 0; i < types.size(); ++i) {
223 if (i > 0) {
224 os << ", ";
225 }
226 os << (types[i] == nullptr ? "nullptr" : types[i]->ToString());
227 }
228 os << "]";
229 return os;
230 }
231 } // namespace mindspore
232