• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/dtype.h"
18 #include <string>
19 #include <cstdlib>
20 #include <algorithm>
21 #include "utils/log_adapter.h"
22 #include "abstract/abstract_value.h"
23 
24 namespace mindspore {
DeepCopy() const25 TypePtr TypeAnything::DeepCopy() const { return kAnyType; }
26 
operator ()(TypePtr const & type) const27 std::size_t TypeHasher::operator()(TypePtr const &type) const {
28   MS_EXCEPTION_IF_NULL(type);
29   std::size_t hash = std::hash<size_t>()(type->type_id());
30   return hash;
31 }
32 
operator ()(const TypePtrList & type_list) const33 std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
34   std::size_t hash_sum = 0;
35   for (auto &type : type_list) {
36     auto type_id = static_cast<std::size_t>(type->type_id());
37     hash_sum = hash_combine(hash_sum, type_id);
38   }
39   return hash_sum;
40 }
41 
operator ()(TypePtr const & t1,TypePtr const & t2) const42 bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
43   MS_EXCEPTION_IF_NULL(t1);
44   MS_EXCEPTION_IF_NULL(t2);
45   return t1->type_id() == t2->type_id();
46 }
47 
operator ()(TypePtrList const & lhs,TypePtrList const & rhs) const48 bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
49   if (lhs.size() != rhs.size()) {
50     return false;
51   }
52   std::size_t size = lhs.size();
53   for (std::size_t i = 0; i < size; ++i) {
54     MS_EXCEPTION_IF_NULL(lhs[i]);
55     MS_EXCEPTION_IF_NULL(rhs[i]);
56     if (*lhs[i] != *rhs[i]) {
57       return false;
58     }
59   }
60   return true;
61 }
62 
TypeIdToType(TypeId id)63 TypePtr TypeIdToType(TypeId id) {
64   static std::unordered_map<TypeId, TypePtr> type_id_to_type = {
65     {kNumberTypeFloat16, kFloat16},     {kNumberTypeFloat, kFloat32},         {kNumberTypeFloat32, kFloat32},
66     {kNumberTypeFloat64, kFloat64},     {kNumberTypeComplex64, kComplex64},   {kNumberTypeInt8, kInt8},
67     {kNumberTypeInt16, kInt16},         {kNumberTypeInt32, kInt32},           {kNumberTypeInt, kInt32},
68     {kNumberTypeInt64, kInt64},         {kNumberTypeUInt8, kUInt8},           {kNumberTypeUInt16, kUInt16},
69     {kNumberTypeUInt32, kUInt32},       {kNumberTypeUInt64, kUInt64},         {kNumberTypeBool, kBool},
70     {kNumberTypeComplex64, kComplex64}, {kNumberTypeComplex128, kComplex128}, {kMetaTypeExternal, kTypeExternal},
71     {kMetaTypeAnything, kAnyType},      {kMetaTypeNone, kTypeNone},           {kMetaTypeNull, kTypeNull},
72     {kMetaTypeEllipsis, kTypeEllipsis}, {kObjectTypeEnvType, kTypeEnv},       {kObjectTypeRefKey, kRefKeyType},
73     {kObjectTypeRef, kRefType},         {kMetaTypeTypeType, kTypeType},       {kObjectTypeString, kString},
74     {kObjectTypeList, kList},           {kObjectTypeTuple, kTuple},           {kObjectTypeDictionary, kDict},
75     {kObjectTypeSlice, kSlice},         {kObjectTypeKeyword, kKeyword},       {kObjectTypeTensorType, kTensorType},
76     {kObjectTypeUMonad, kUMonadType},   {kObjectTypeIOMonad, kIOMonadType},   {kTypeUnknown, kTypeNone},
77     {kMetaTypeProblem, kTypeNone}};
78   const auto &it = type_id_to_type.find(id);
79   if (it == type_id_to_type.end()) {
80     MS_LOG(EXCEPTION) << "Not support the type: " << id;
81   }
82   return it->second;
83 }
84 
85 namespace {
86 template <typename T>
StringToNumberType(const std::string & type_name,const std::string & num_type_name)87 TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
88   TypePtr type = nullptr;
89   if (type_name == num_type_name) {
90     type = std::make_shared<T>();
91   } else {
92     if (num_type_name.size() >= type_name.size()) {
93       MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name
94                         << ")";
95     }
96     auto bits = std::stoi(type_name.substr(num_type_name.size()));
97     type = std::make_shared<T>(bits);
98   }
99   return type;
100 }
101 
StringToVectorOfType(const std::string & type_names)102 std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
103   std::vector<TypePtr> types;
104   if (type_names.length() == 0) {
105     return types;
106   }
107   std::string::size_type start = 0;
108   std::string::size_type end = type_names.find_first_of(',');
109   while (end != std::string::npos) {
110     types.push_back(StringToType(type_names.substr(start, end)));
111     // Skip ',' to find the next element.
112     start = end + 1;
113     end = type_names.find_first_of(',', start);
114   }
115   if (start >= type_names.size()) {
116     MS_LOG(EXCEPTION) << "Type name is empty string.";
117   }
118   types.push_back(StringToType(type_names.substr(start)));
119   return types;
120 }
121 
TensorStrToType(const std::string & type_name)122 TypePtr TensorStrToType(const std::string &type_name) {
123   TypePtr type = nullptr;
124   if (type_name == "Tensor") {
125     type = std::make_shared<TensorType>();
126   } else {
127     auto start = type_name.find_first_of('[') + 1;
128     auto end = type_name.find_last_of(']');
129     if (start >= type_name.size()) {
130       return nullptr;
131     }
132     auto element_str = type_name.substr(start, end - start);
133     auto element_type = StringToType(element_str);
134     if (element_type == nullptr) {
135       return nullptr;
136     }
137     type = std::make_shared<TensorType>(element_type);
138   }
139   return type;
140 }
141 
RowTensorStrToType(const std::string & type_name)142 TypePtr RowTensorStrToType(const std::string &type_name) {
143   if (type_name == "RowTensor") {
144     return std::make_shared<RowTensorType>();
145   }
146   auto start = type_name.find_first_of('[') + 1;
147   auto end = type_name.find_last_of(']');
148   if (start >= type_name.size()) {
149     return nullptr;
150   }
151   auto element_str = type_name.substr(start, end - start);
152   auto element_type = StringToType(element_str);
153   if (element_type == nullptr) {
154     return nullptr;
155   }
156   return std::make_shared<RowTensorType>(element_type);
157 }
158 
SparseTensorStrToType(const std::string & type_name)159 TypePtr SparseTensorStrToType(const std::string &type_name) {
160   if (type_name == "SparseTensor") {
161     return std::make_shared<SparseTensorType>();
162   }
163   auto start = type_name.find_first_of('[') + 1;
164   auto end = type_name.find_last_of(']');
165   if (start >= type_name.size()) {
166     return nullptr;
167   }
168   auto element_str = type_name.substr(start, end - start);
169   auto element_type = StringToType(element_str);
170   if (element_type == nullptr) {
171     return nullptr;
172   }
173   return std::make_shared<SparseTensorType>(element_type);
174 }
175 
UndeterminedStrToType(const std::string & type_name)176 TypePtr UndeterminedStrToType(const std::string &type_name) {
177   if (type_name == "Undetermined") {
178     return std::make_shared<UndeterminedType>();
179   }
180   auto start = type_name.find_first_of('[') + 1;
181   auto end = type_name.find_last_of(']');
182   if (start >= type_name.size()) {
183     return nullptr;
184   }
185   auto element_str = type_name.substr(start, end - start);
186   auto element_type = StringToType(element_str);
187   if (element_type == nullptr) {
188     return nullptr;
189   }
190   return std::make_shared<UndeterminedType>(element_type);
191 }
192 
ListStrToType(const std::string & type_name)193 TypePtr ListStrToType(const std::string &type_name) {
194   TypePtr type = nullptr;
195   if (type_name == "List") {
196     type = std::make_shared<List>();
197   } else {
198     auto start = type_name.find_first_of('[') + 1;
199     auto end = type_name.find_last_of(']');
200     if (start >= type_name.size()) {
201       return nullptr;
202     }
203     std::string element_strs = type_name.substr(start, end - start);
204     std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
205     bool wrong = std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
206     if (wrong) {
207       return nullptr;
208     }
209     type = std::make_shared<List>(element_types);
210   }
211 
212   return type;
213 }
214 
TupleStrToType(const std::string & type_name)215 TypePtr TupleStrToType(const std::string &type_name) {
216   TypePtr type = nullptr;
217   if (type_name == "Tuple") {
218     type = std::make_shared<Tuple>();
219   } else {
220     size_t start = type_name.find_first_of('[') + 1;
221     size_t end = type_name.find_last_of(']');
222     if (start >= type_name.size()) {
223       return nullptr;
224     }
225     std::string element_strs = type_name.substr(start, end - start);
226     std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
227     bool wrong = std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
228     if (wrong) {
229       return nullptr;
230     }
231     type = std::make_shared<Tuple>(element_types);
232   }
233   return type;
234 }
235 
FunctionStrToType(const std::string & type_name)236 TypePtr FunctionStrToType(const std::string &type_name) {
237   TypePtr type = nullptr;
238 
239   if (type_name == "Function") {
240     type = std::make_shared<Function>();
241   } else {
242     // format: [(para1, para2, para3, ...) retval]
243     size_t start = type_name.find_first_of('[') + 1;
244     size_t end = type_name.find_last_of(']');
245     if (start >= type_name.size()) {
246       return nullptr;
247     }
248     std::string str_all = type_name.substr(start, end - start);
249     size_t start_a = str_all.find_first_of('(') + 1;
250     size_t end_a = str_all.find_last_of(')');
251     if (start_a >= str_all.size()) {
252       return nullptr;
253     }
254     std::string str_args = str_all.substr(start_a, end_a - start_a);
255     // bypass " " between ")" and retval
256     start = end_a + 2;
257     if (start >= str_all.size()) {
258       return nullptr;
259     }
260     std::string str_retval = str_all.substr(start);
261     std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
262     TypePtr retval = StringToType(str_retval);
263     bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
264     if (retval == nullptr || wrong) {
265       return nullptr;
266     }
267     type = std::make_shared<Function>(args_type, retval);
268   }
269   return type;
270 }
271 }  // namespace
272 
GetTypeByFullString(const std::string & type_name)273 TypePtr GetTypeByFullString(const std::string &type_name) {
274   static std::map<std::string, TypePtr> type_map = {{"None", std::make_shared<TypeNone>()},
275                                                     {"Ellipsis", std::make_shared<TypeEllipsis>()},
276                                                     {"TypeType", std::make_shared<TypeType>()},
277                                                     {"SymbolicKeyType", std::make_shared<SymbolicKeyType>()},
278                                                     {"RefKeyType", std::make_shared<RefKeyType>()},
279                                                     {"EnvType", std::make_shared<EnvType>()},
280                                                     {"Number", std::make_shared<Number>()},
281                                                     {"Bool", std::make_shared<Bool>()},
282                                                     {"Slice", std::make_shared<Slice>()},
283                                                     {"Dictionary", std::make_shared<Dictionary>()},
284                                                     {"String", std::make_shared<String>()},
285                                                     {"Problem", std::make_shared<Problem>()},
286                                                     {"mstype", std::make_shared<TypeType>()},
287                                                     {"UMonad", kUMonadType},
288                                                     {"IOMonad", kIOMonadType}};
289 
290   auto iter = type_map.find(type_name);
291   return iter == type_map.end() ? nullptr : iter->second;
292 }
293 
GetTypeByStringStarts(const std::string & type_name)294 TypePtr GetTypeByStringStarts(const std::string &type_name) {
295   struct name_cmp {
296     bool operator()(const std::string &l, const std::string &r) const {
297       auto cmp_len = std::min(l.length(), r.length());
298       return r.compare(0, cmp_len, l, 0, cmp_len) < 0;
299     }
300   };
301   static std::map<std::string, std::function<TypePtr(const std::string &type_name)>, name_cmp> type_map = {
302     {"Int", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Int>(type_name, "Int"); }},
303     {"UInt", [](const std::string &type_name) -> TypePtr { return StringToNumberType<UInt>(type_name, "UInt"); }},
304     {"Float", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Float>(type_name, "Float"); }},
305     {"Tensor", [](const std::string &type_name) -> TypePtr { return TensorStrToType(type_name); }},
306     {"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
307     {"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
308     {"SparseTensor", [](const std::string &type_name) -> TypePtr { return SparseTensorStrToType(type_name); }},
309     {"List", [](const std::string &type_name) -> TypePtr { return ListStrToType(type_name); }},
310     {"Tuple", [](const std::string &type_name) -> TypePtr { return TupleStrToType(type_name); }},
311     {"Function", [](const std::string &type_name) -> TypePtr { return FunctionStrToType(type_name); }}};
312   auto iter = type_map.find(type_name);
313   return iter == type_map.end() ? nullptr : iter->second(type_name);
314 }
315 
StringToType(const std::string & type_name)316 TypePtr StringToType(const std::string &type_name) {
317   auto type = GetTypeByFullString(type_name);
318   if (type == nullptr) {
319     type = GetTypeByStringStarts(type_name);
320   }
321   if (type == nullptr) {
322     // - unsupported to convert
323     // Class
324     // SymbolicType
325     // JTagged
326     // Anything
327     // External
328     MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!";
329   }
330   return type;
331 }
332 
IsIdentidityOrSubclass(TypePtr const & x,TypePtr const & base_type)333 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
334   if (x == nullptr || base_type == nullptr) {
335     MS_LOG(ERROR) << "Type is nullptr.";
336     return false;
337   }
338   auto type_id = base_type->type_id();
339   if (type_id == kTypeUnknown || x->type_id() == kTypeUnknown) {
340     return false;
341   } else if (!(base_type->IsGeneric())) {
342     return *(base_type) == *(x);
343   } else if (type_id == x->type_id() || type_id == x->generic_type_id() || type_id == x->object_type() ||
344              type_id == x->meta_type()) {
345     return true;
346   } else {
347     return false;
348   }
349 }
350 
IsSubType(TypePtr const & t1,TypePtr const & t2)351 bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
352   MS_EXCEPTION_IF_NULL(t1);
353   if (t1->type_id() == kTypeUnknown) {
354     return false;
355   } else if (t2 != nullptr) {
356     return IsIdentidityOrSubclass(t1, t2);
357   } else {
358     return true;
359   }
360 }
361 }  // namespace mindspore
362