1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ir/dtype.h"

#include <algorithm>
#include <cstdlib>
#include <stdexcept>
#include <string>

#include "utils/log_adapter.h"
#include "abstract/abstract_value.h"
23
24 namespace mindspore {
DeepCopy() const25 TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
26
operator ()(TypePtr const & type) const27 std::size_t TypeHasher::operator()(TypePtr const &type) const {
28 MS_EXCEPTION_IF_NULL(type);
29 std::size_t hash = std::hash<size_t>()(type->type_id());
30 return hash;
31 }
32
operator ()(const TypePtrList & type_list) const33 std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
34 std::size_t hash_sum = 0;
35 for (auto &type : type_list) {
36 auto type_id = static_cast<std::size_t>(type->type_id());
37 hash_sum = hash_combine(hash_sum, type_id);
38 }
39 return hash_sum;
40 }
41
operator ()(TypePtr const & t1,TypePtr const & t2) const42 bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
43 MS_EXCEPTION_IF_NULL(t1);
44 MS_EXCEPTION_IF_NULL(t2);
45 return t1->type_id() == t2->type_id();
46 }
47
operator ()(TypePtrList const & lhs,TypePtrList const & rhs) const48 bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
49 if (lhs.size() != rhs.size()) {
50 return false;
51 }
52 std::size_t size = lhs.size();
53 for (std::size_t i = 0; i < size; ++i) {
54 MS_EXCEPTION_IF_NULL(lhs[i]);
55 MS_EXCEPTION_IF_NULL(rhs[i]);
56 if (*lhs[i] != *rhs[i]) {
57 return false;
58 }
59 }
60 return true;
61 }
62
TypeIdToType(TypeId id)63 TypePtr TypeIdToType(TypeId id) {
64 static std::unordered_map<TypeId, TypePtr> type_id_to_type = {
65 {kNumberTypeFloat16, kFloat16}, {kNumberTypeFloat, kFloat32}, {kNumberTypeFloat32, kFloat32},
66 {kNumberTypeFloat64, kFloat64}, {kNumberTypeComplex64, kComplex64}, {kNumberTypeInt8, kInt8},
67 {kNumberTypeInt16, kInt16}, {kNumberTypeInt32, kInt32}, {kNumberTypeInt, kInt32},
68 {kNumberTypeInt64, kInt64}, {kNumberTypeUInt8, kUInt8}, {kNumberTypeUInt16, kUInt16},
69 {kNumberTypeUInt32, kUInt32}, {kNumberTypeUInt64, kUInt64}, {kNumberTypeBool, kBool},
70 {kNumberTypeComplex64, kComplex64}, {kNumberTypeComplex128, kComplex128}, {kMetaTypeExternal, kTypeExternal},
71 {kMetaTypeAnything, kAnyType}, {kMetaTypeNone, kTypeNone}, {kMetaTypeNull, kTypeNull},
72 {kMetaTypeEllipsis, kTypeEllipsis}, {kObjectTypeEnvType, kTypeEnv}, {kObjectTypeRefKey, kRefKeyType},
73 {kObjectTypeRef, kRefType}, {kMetaTypeTypeType, kTypeType}, {kObjectTypeString, kString},
74 {kObjectTypeList, kList}, {kObjectTypeTuple, kTuple}, {kObjectTypeDictionary, kDict},
75 {kObjectTypeSlice, kSlice}, {kObjectTypeKeyword, kKeyword}, {kObjectTypeTensorType, kTensorType},
76 {kObjectTypeUMonad, kUMonadType}, {kObjectTypeIOMonad, kIOMonadType}, {kTypeUnknown, kTypeNone},
77 {kMetaTypeProblem, kTypeNone}};
78 const auto &it = type_id_to_type.find(id);
79 if (it == type_id_to_type.end()) {
80 MS_LOG(EXCEPTION) << "Not support the type: " << id;
81 }
82 return it->second;
83 }
84
85 namespace {
86 template <typename T>
StringToNumberType(const std::string & type_name,const std::string & num_type_name)87 TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
88 TypePtr type = nullptr;
89 if (type_name == num_type_name) {
90 type = std::make_shared<T>();
91 } else {
92 if (num_type_name.size() >= type_name.size()) {
93 MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
94 << ")";
95 }
96 auto bits = std::stoi(type_name.substr(num_type_name.size()));
97 type = std::make_shared<T>(bits);
98 }
99 return type;
100 }
101
StringToVectorOfType(const std::string & type_names)102 std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
103 std::vector<TypePtr> types;
104 if (type_names.length() == 0) {
105 return types;
106 }
107 std::string::size_type start = 0;
108 std::string::size_type end = type_names.find_first_of(',');
109 while (end != std::string::npos) {
110 types.push_back(StringToType(type_names.substr(start, end)));
111 // Skip ',' to find the next element.
112 start = end + 1;
113 end = type_names.find_first_of(',', start);
114 }
115 if (start >= type_names.size()) {
116 MS_LOG(EXCEPTION) << "Type name is empty string.";
117 }
118 types.push_back(StringToType(type_names.substr(start)));
119 return types;
120 }
121
TensorStrToType(const std::string & type_name)122 TypePtr TensorStrToType(const std::string &type_name) {
123 TypePtr type = nullptr;
124 if (type_name == "Tensor") {
125 type = std::make_shared<TensorType>();
126 } else {
127 auto start = type_name.find_first_of('[') + 1;
128 auto end = type_name.find_last_of(']');
129 if (start >= type_name.size()) {
130 return nullptr;
131 }
132 auto element_str = type_name.substr(start, end - start);
133 auto element_type = StringToType(element_str);
134 if (element_type == nullptr) {
135 return nullptr;
136 }
137 type = std::make_shared<TensorType>(element_type);
138 }
139 return type;
140 }
141
RowTensorStrToType(const std::string & type_name)142 TypePtr RowTensorStrToType(const std::string &type_name) {
143 if (type_name == "RowTensor") {
144 return std::make_shared<RowTensorType>();
145 }
146 auto start = type_name.find_first_of('[') + 1;
147 auto end = type_name.find_last_of(']');
148 if (start >= type_name.size()) {
149 return nullptr;
150 }
151 auto element_str = type_name.substr(start, end - start);
152 auto element_type = StringToType(element_str);
153 if (element_type == nullptr) {
154 return nullptr;
155 }
156 return std::make_shared<RowTensorType>(element_type);
157 }
158
SparseTensorStrToType(const std::string & type_name)159 TypePtr SparseTensorStrToType(const std::string &type_name) {
160 if (type_name == "SparseTensor") {
161 return std::make_shared<SparseTensorType>();
162 }
163 auto start = type_name.find_first_of('[') + 1;
164 auto end = type_name.find_last_of(']');
165 if (start >= type_name.size()) {
166 return nullptr;
167 }
168 auto element_str = type_name.substr(start, end - start);
169 auto element_type = StringToType(element_str);
170 if (element_type == nullptr) {
171 return nullptr;
172 }
173 return std::make_shared<SparseTensorType>(element_type);
174 }
175
UndeterminedStrToType(const std::string & type_name)176 TypePtr UndeterminedStrToType(const std::string &type_name) {
177 if (type_name == "Undetermined") {
178 return std::make_shared<UndeterminedType>();
179 }
180 auto start = type_name.find_first_of('[') + 1;
181 auto end = type_name.find_last_of(']');
182 if (start >= type_name.size()) {
183 return nullptr;
184 }
185 auto element_str = type_name.substr(start, end - start);
186 auto element_type = StringToType(element_str);
187 if (element_type == nullptr) {
188 return nullptr;
189 }
190 return std::make_shared<UndeterminedType>(element_type);
191 }
192
ListStrToType(const std::string & type_name)193 TypePtr ListStrToType(const std::string &type_name) {
194 TypePtr type = nullptr;
195 if (type_name == "List") {
196 type = std::make_shared<List>();
197 } else {
198 auto start = type_name.find_first_of('[') + 1;
199 auto end = type_name.find_last_of(']');
200 if (start >= type_name.size()) {
201 return nullptr;
202 }
203 std::string element_strs = type_name.substr(start, end - start);
204 std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
205 bool wrong = std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
206 if (wrong) {
207 return nullptr;
208 }
209 type = std::make_shared<List>(element_types);
210 }
211
212 return type;
213 }
214
TupleStrToType(const std::string & type_name)215 TypePtr TupleStrToType(const std::string &type_name) {
216 TypePtr type = nullptr;
217 if (type_name == "Tuple") {
218 type = std::make_shared<Tuple>();
219 } else {
220 size_t start = type_name.find_first_of('[') + 1;
221 size_t end = type_name.find_last_of(']');
222 if (start >= type_name.size()) {
223 return nullptr;
224 }
225 std::string element_strs = type_name.substr(start, end - start);
226 std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
227 bool wrong = std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
228 if (wrong) {
229 return nullptr;
230 }
231 type = std::make_shared<Tuple>(element_types);
232 }
233 return type;
234 }
235
FunctionStrToType(const std::string & type_name)236 TypePtr FunctionStrToType(const std::string &type_name) {
237 TypePtr type = nullptr;
238
239 if (type_name == "Function") {
240 type = std::make_shared<Function>();
241 } else {
242 // format: [(para1, para2, para3, ...) retval]
243 size_t start = type_name.find_first_of('[') + 1;
244 size_t end = type_name.find_last_of(']');
245 if (start >= type_name.size()) {
246 return nullptr;
247 }
248 std::string str_all = type_name.substr(start, end - start);
249 size_t start_a = str_all.find_first_of('(') + 1;
250 size_t end_a = str_all.find_last_of(')');
251 if (start_a >= str_all.size()) {
252 return nullptr;
253 }
254 std::string str_args = str_all.substr(start_a, end_a - start_a);
255 // bypass " " between ")" and retval
256 start = end_a + 2;
257 if (start >= str_all.size()) {
258 return nullptr;
259 }
260 std::string str_retval = str_all.substr(start);
261 std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
262 TypePtr retval = StringToType(str_retval);
263 bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
264 if (retval == nullptr || wrong) {
265 return nullptr;
266 }
267 type = std::make_shared<Function>(args_type, retval);
268 }
269 return type;
270 }
271 } // namespace
272
GetTypeByFullString(const std::string & type_name)273 TypePtr GetTypeByFullString(const std::string &type_name) {
274 static std::map<std::string, TypePtr> type_map = {{"None", std::make_shared<TypeNone>()},
275 {"Ellipsis", std::make_shared<TypeEllipsis>()},
276 {"TypeType", std::make_shared<TypeType>()},
277 {"SymbolicKeyType", std::make_shared<SymbolicKeyType>()},
278 {"RefKeyType", std::make_shared<RefKeyType>()},
279 {"EnvType", std::make_shared<EnvType>()},
280 {"Number", std::make_shared<Number>()},
281 {"Bool", std::make_shared<Bool>()},
282 {"Slice", std::make_shared<Slice>()},
283 {"Dictionary", std::make_shared<Dictionary>()},
284 {"String", std::make_shared<String>()},
285 {"Problem", std::make_shared<Problem>()},
286 {"mstype", std::make_shared<TypeType>()},
287 {"UMonad", kUMonadType},
288 {"IOMonad", kIOMonadType}};
289
290 auto iter = type_map.find(type_name);
291 return iter == type_map.end() ? nullptr : iter->second;
292 }
293
GetTypeByStringStarts(const std::string & type_name)294 TypePtr GetTypeByStringStarts(const std::string &type_name) {
295 struct name_cmp {
296 bool operator()(const std::string &l, const std::string &r) const {
297 auto cmp_len = std::min(l.length(), r.length());
298 return r.compare(0, cmp_len, l, 0, cmp_len) < 0;
299 }
300 };
301 static std::map<std::string, std::function<TypePtr(const std::string &type_name)>, name_cmp> type_map = {
302 {"Int", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Int>(type_name, "Int"); }},
303 {"UInt", [](const std::string &type_name) -> TypePtr { return StringToNumberType<UInt>(type_name, "UInt"); }},
304 {"Float", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Float>(type_name, "Float"); }},
305 {"Tensor", [](const std::string &type_name) -> TypePtr { return TensorStrToType(type_name); }},
306 {"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
307 {"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
308 {"SparseTensor", [](const std::string &type_name) -> TypePtr { return SparseTensorStrToType(type_name); }},
309 {"List", [](const std::string &type_name) -> TypePtr { return ListStrToType(type_name); }},
310 {"Tuple", [](const std::string &type_name) -> TypePtr { return TupleStrToType(type_name); }},
311 {"Function", [](const std::string &type_name) -> TypePtr { return FunctionStrToType(type_name); }}};
312 auto iter = type_map.find(type_name);
313 return iter == type_map.end() ? nullptr : iter->second(type_name);
314 }
315
StringToType(const std::string & type_name)316 TypePtr StringToType(const std::string &type_name) {
317 auto type = GetTypeByFullString(type_name);
318 if (type == nullptr) {
319 type = GetTypeByStringStarts(type_name);
320 }
321 if (type == nullptr) {
322 // - unsupported to convert
323 // Class
324 // SymbolicType
325 // JTagged
326 // Anything
327 // External
328 MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
329 }
330 return type;
331 }
332
IsIdentidityOrSubclass(TypePtr const & x,TypePtr const & base_type)333 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
334 if (x == nullptr || base_type == nullptr) {
335 MS_LOG(ERROR) << "Type is nullptr.";
336 return false;
337 }
338 auto type_id = base_type->type_id();
339 if (type_id == kTypeUnknown || x->type_id() == kTypeUnknown) {
340 return false;
341 } else if (!(base_type->IsGeneric())) {
342 return *(base_type) == *(x);
343 } else if (type_id == x->type_id() || type_id == x->generic_type_id() || type_id == x->object_type() ||
344 type_id == x->meta_type()) {
345 return true;
346 } else {
347 return false;
348 }
349 }
350
IsSubType(TypePtr const & t1,TypePtr const & t2)351 bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
352 MS_EXCEPTION_IF_NULL(t1);
353 if (t1->type_id() == kTypeUnknown) {
354 return false;
355 } else if (t2 != nullptr) {
356 return IsIdentidityOrSubclass(t1, t2);
357 } else {
358 return true;
359 }
360 }
361 } // namespace mindspore
362