• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/dtype/type.h"
20 
21 #include <algorithm>
22 #include <cstdlib>
23 #include <string>
24 #include <climits>
25 
26 #include "ir/dtype/number.h"
27 #include "utils/log_adapter.h"
28 #include "utils/convert_utils_base.h"
29 
30 namespace mindspore {
31 #define MS_TYPE2LABLE(type_id) #type_id
32 static std::unordered_map<TypeId, std::string> g_type_2_lable{
33   {kTypeUnknown, MS_TYPE2LABLE(kTypeUnknown)},
34   {kMetaTypeType, MS_TYPE2LABLE(kMetaTypeType)},
35   {kMetaTypeAnything, MS_TYPE2LABLE(kMetaTypeAnything)},
36   {kMetaTypeObject, MS_TYPE2LABLE(kMetaTypeObject)},
37   {kMetaTypeTypeType, MS_TYPE2LABLE(kMetaTypeTypeType)},
38   {kMetaTypeProblem, MS_TYPE2LABLE(kMetaTypeProblem)},
39   {kMetaTypeExternal, MS_TYPE2LABLE(kMetaTypeExternal)},
40   {kMetaTypeNone, MS_TYPE2LABLE(kMetaTypeNone)},
41   {kMetaTypeNull, MS_TYPE2LABLE(kMetaTypeNull)},
42   {kMetaTypeEllipsis, MS_TYPE2LABLE(kMetaTypeEllipsis)},
43   {kMetaTypeEnd, MS_TYPE2LABLE(kMetaTypeEnd)},
44   {kObjectTypeNumber, MS_TYPE2LABLE(kObjectTypeNumber)},
45   {kObjectTypeString, MS_TYPE2LABLE(kObjectTypeString)},
46   {kObjectTypeList, MS_TYPE2LABLE(kObjectTypeList)},
47   {kObjectTypeTuple, MS_TYPE2LABLE(kObjectTypeTuple)},
48   {kObjectTypeSlice, MS_TYPE2LABLE(kObjectTypeSlice)},
49   {kObjectTypeKeyword, MS_TYPE2LABLE(kObjectTypeKeyword)},
50   {kObjectTypeTensorType, MS_TYPE2LABLE(kObjectTypeTensorType)},
51   {kObjectTypeRowTensorType, MS_TYPE2LABLE(kObjectTypeRowTensorType)},
52   {kObjectTypeSparseTensorType, MS_TYPE2LABLE(kObjectTypeSparseTensorType)},
53   {kObjectTypeUndeterminedType, MS_TYPE2LABLE(kObjectTypeUndeterminedType)},
54   {kObjectTypeClass, MS_TYPE2LABLE(kObjectTypeClass)},
55   {kObjectTypeDictionary, MS_TYPE2LABLE(kObjectTypeDictionary)},
56   {kObjectTypeFunction, MS_TYPE2LABLE(kObjectTypeFunction)},
57   {kObjectTypeJTagged, MS_TYPE2LABLE(kObjectTypeJTagged)},
58   {kObjectTypeSymbolicKeyType, MS_TYPE2LABLE(kObjectTypeSymbolicKeyType)},
59   {kObjectTypeEnvType, MS_TYPE2LABLE(kObjectTypeEnvType)},
60   {kObjectTypeRefKey, MS_TYPE2LABLE(kObjectTypeRefKey)},
61   {kObjectTypeRef, MS_TYPE2LABLE(kObjectTypeRef)},
62   {kObjectTypeEnd, MS_TYPE2LABLE(kObjectTypeEnd)},
63   {kNumberTypeBool, MS_TYPE2LABLE(kNumberTypeBool)},
64   {kNumberTypeInt, MS_TYPE2LABLE(kNumberTypeInt)},
65   {kNumberTypeInt8, MS_TYPE2LABLE(kNumberTypeInt8)},
66   {kNumberTypeInt16, MS_TYPE2LABLE(kNumberTypeInt16)},
67   {kNumberTypeInt32, MS_TYPE2LABLE(kNumberTypeInt32)},
68   {kNumberTypeInt64, MS_TYPE2LABLE(kNumberTypeInt64)},
69   {kNumberTypeUInt, MS_TYPE2LABLE(kNumberTypeUInt)},
70   {kNumberTypeUInt8, MS_TYPE2LABLE(kNumberTypeUInt8)},
71   {kNumberTypeUInt16, MS_TYPE2LABLE(kNumberTypeUInt16)},
72   {kNumberTypeUInt32, MS_TYPE2LABLE(kNumberTypeUInt32)},
73   {kNumberTypeUInt64, MS_TYPE2LABLE(kNumberTypeUInt64)},
74   {kNumberTypeFloat, MS_TYPE2LABLE(kNumberTypeFloat)},
75   {kNumberTypeFloat16, MS_TYPE2LABLE(kNumberTypeFloat16)},
76   {kNumberTypeFloat32, MS_TYPE2LABLE(kNumberTypeFloat32)},
77   {kNumberTypeFloat64, MS_TYPE2LABLE(kNumberTypeFloat64)},
78   {kNumberTypeComplex64, MS_TYPE2LABLE(kNumberTypeComplex64)},
79   {kNumberTypeEnd, MS_TYPE2LABLE(kNumberTypeEnd)},
80   {kObjectTypeMonad, MS_TYPE2LABLE(kObjectTypeMonad)},
81   {kObjectTypeUMonad, MS_TYPE2LABLE(kObjectTypeUMonad)},
82   {kObjectTypeIOMonad, MS_TYPE2LABLE(kObjectTypeIOMonad)},
83   {kMonadTypeEnd, MS_TYPE2LABLE(kMonadTypeEnd)}};
84 
IntBitsToTypeId(const int nbits)85 TypeId IntBitsToTypeId(const int nbits) {
86   switch (nbits) {
87     case static_cast<int>(BitsNum::eBits8):
88       return kNumberTypeInt8;
89     case static_cast<int>(BitsNum::eBits16):
90       return kNumberTypeInt16;
91     case static_cast<int>(BitsNum::eBits32):
92       return kNumberTypeInt32;
93     case static_cast<int>(BitsNum::eBits64):
94       return kNumberTypeInt64;
95     default:
96       MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
97   }
98 }
99 
UIntBitsToTypeId(const int nbits)100 TypeId UIntBitsToTypeId(const int nbits) {
101   switch (nbits) {
102     case static_cast<int>(BitsNum::eBits8):
103       return kNumberTypeUInt8;
104     case static_cast<int>(BitsNum::eBits16):
105       return kNumberTypeUInt16;
106     case static_cast<int>(BitsNum::eBits32):
107       return kNumberTypeUInt32;
108     case static_cast<int>(BitsNum::eBits64):
109       return kNumberTypeUInt64;
110     default:
111       MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
112   }
113 }
114 
FloatBitsToTypeId(const int nbits)115 TypeId FloatBitsToTypeId(const int nbits) {
116   switch (nbits) {
117     case static_cast<int>(BitsNum::eBits16):
118       return kNumberTypeFloat16;
119     case static_cast<int>(BitsNum::eBits32):
120       return kNumberTypeFloat32;
121     case static_cast<int>(BitsNum::eBits64):
122       return kNumberTypeFloat64;
123     default:
124       MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
125   }
126 }
127 
ComplexBitsToTypeId(const int nbits)128 TypeId ComplexBitsToTypeId(const int nbits) {
129   switch (nbits) {
130     case static_cast<int>(BitsNum::eBits64):
131       return kNumberTypeComplex64;
132     case static_cast<int>(BitsNum::eBits128):
133       return kNumberTypeComplex128;
134     default:
135       MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits;
136   }
137 }
138 
TypeIdLabel(const TypeId & v)139 const std::string &TypeIdLabel(const TypeId &v) {
140   static const std::string unknown("[Unknown Type Id]");
141   auto iter = g_type_2_lable.find(v);
142   if (iter != g_type_2_lable.end()) {
143     return iter->second;
144   } else {
145     return unknown;
146   }
147 }
148 
// Collapses sized numeric ids onto their generic family:
//   kNumberTypeInt / Int8 / Int16 / Int32 / Int64     -> kNumberTypeInt
//   kNumberTypeFloat / Float16 / Float32 / Float64    -> kNumberTypeFloat
// Every other id (unsigned ints, bool, complex, non-numeric ids) is returned unchanged.
TypeId NormalizeTypeId(const TypeId type_id) {
  switch (type_id) {
    case kNumberTypeInt:
    case kNumberTypeInt8:
    case kNumberTypeInt16:
    case kNumberTypeInt32:
    case kNumberTypeInt64:
      return kNumberTypeInt;
    case kNumberTypeFloat:
    case kNumberTypeFloat16:
    case kNumberTypeFloat32:
    case kNumberTypeFloat64:
      return kNumberTypeFloat;
    default:
      return type_id;
  }
}
160 
IsSameObjectType(const Type & lhs,const Type & rhs)161 bool IsSameObjectType(const Type &lhs, const Type &rhs) {
162   if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
163     return false;
164   }
165   return lhs.object_type() == rhs.object_type();
166 }
167 
GetTypeByte(const TypePtr & type_ptr)168 size_t GetTypeByte(const TypePtr &type_ptr) {
169   if (type_ptr && type_ptr->isa<Number>()) {
170     auto number = dyn_cast<Number>(type_ptr);
171     if (!number) {
172       MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel.";
173       return 0;
174     } else {
175       return IntToSize(number->nbits() / CHAR_BIT);
176     }
177   } else {
178     MS_LOG(DEBUG) << "Invalid TypePtr got from ApplyKernel.";
179     return 0;
180   }
181 }
182 
operator ==(const Value & other) const183 bool Type::operator==(const Value &other) const {
184   if (other.isa<Type>()) {
185     auto other_type = static_cast<const Type *>(&other);
186     return *this == *other_type;
187   } else {
188     return false;
189   }
190 }
191 
operator <<(std::ostream & os,const Type & type)192 std::ostream &operator<<(std::ostream &os, const Type &type) {
193   os << type.ToString();
194   return os;
195 }
196 
operator <<(std::ostream & os,const TypePtr type)197 std::ostream &operator<<(std::ostream &os, const TypePtr type) {
198   os << type->ToString();
199   return os;
200 }
201 
equal(const TypePtr other) const202 bool Object::equal(const TypePtr other) const {
203   auto same_other = dyn_cast<Object>(other);
204   if (same_other != nullptr) {
205     return *this == *same_other;
206   }
207   return false;
208 }
209 
operator <<(std::ostream & os,const Object & obj)210 std::ostream &operator<<(std::ostream &os, const Object &obj) {
211   os << obj.ToString();
212   return os;
213 }
214 
operator <<(std::ostream & os,const std::shared_ptr<Object> obj)215 std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
216   os << obj->ToString();
217   return os;
218 }
219 
operator <<(std::ostream & os,const TypePtrList & types)220 std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
221   os << "[";
222   for (size_t i = 0; i < types.size(); ++i) {
223     if (i > 0) {
224       os << ", ";
225     }
226     os << (types[i] == nullptr ? "nullptr" : types[i]->ToString());
227   }
228   os << "]";
229   return os;
230 }
231 }  // namespace mindspore
232