• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_H_
20 #define MINDSPORE_CORE_IR_DTYPE_TYPE_H_
21 
22 #include <cstddef>
23 #include <iostream>
24 #include <initializer_list>
25 #include <unordered_map>
26 #include <memory>
27 #include <utility>
28 #include <sstream>
29 #include <string>
30 #include <vector>
31 #include <type_traits>
32 #include <algorithm>
33 #include "utils/hash_map.h"
34 #include "base/base.h"
35 #include "ir/named.h"
36 #include "ir/dtype/type_id.h"
37 #include "utils/ms_utils.h"
38 
39 namespace mindspore {
40 
41 TypeId IntBitsToTypeId(const int nbits);
42 TypeId UIntBitsToTypeId(const int nbits);
43 TypeId FloatBitsToTypeId(const int nbits);
44 TypeId BFloatBitsToTypeId(const int nbits);
45 TypeId ComplexBitsToTypeId(const int nbits);
46 
47 /// \brief Get label of the input TypeId.
48 ///
49 /// \param[in] v Define the input TypeId.
50 /// \return The label of input TypeId.
51 MS_CORE_API const std::string &TypeIdLabel(const TypeId &v);
52 MS_CORE_API TypeId NormalizeTypeId(const TypeId type_id);
53 bool IsSameObjectType(const Type &lhs, const Type &rhs);
54 MS_CORE_API size_t GetTypeByte(const TypePtr &type_ptr);
55 MS_CORE_API int64_t GetTypeId(const TypeId &type_id);
56 
/// \brief Supported bit widths; the integer value of each enumerator equals the number of bits it names.
enum class BitsNum : int {
  eBits4 = 4,
  eBits8 = 8,
  eBits16 = 16,
  eBits32 = 32,
  eBits64 = 64,
  eBits128 = 128
};
65 
66 /// \brief Type defines an Value class for type.
67 class MS_CORE_API Type : public Value {
68  public:
69   /// \brief Default constructor for Type.
Type()70   Type() : meta_type_(kMetaTypeType), is_generic_(true) {}
71 
72   /// \brief Constructor for Type.
73   ///
74   /// \param[in] t Define TypeId for Type object.
75   /// \param[in] is_generic Define whether the Type object is generic.
meta_type_(t)76   explicit Type(TypeId t, bool is_generic = true) : meta_type_(t), is_generic_(is_generic) {}
77 
78   /// \brief Destructor of Type.
79   ~Type() override = default;
80   MS_DECLARE_PARENT(Type, Value)
81 
82   bool operator==(const Value &other) const override;
83 
84   /// \brief Show the meta type of the Type object.
85   ///
86   /// \return The meta type of the Type object.
meta_type()87   TypeId meta_type() const { return meta_type_; }
88 
89   /// \brief Show the type id of the Type object.
90   ///
91   /// \return The type id of the Type object.
type_id()92   virtual TypeId type_id() const { return meta_type_; }
93 
94   /// \brief Show the generic type id for the Number object.
95   ///
96   /// \return The generic type id.
generic_type_id()97   virtual TypeId generic_type_id() const { return kMetaTypeType; }
98 
99   /// \brief Check whether the input is not the current Type object.
100   ///
101   /// \param[in] other Define a Value object.
102   /// \return Check whether the current object and other object are different.
103   virtual bool operator!=(const Type &other) const { return !(*this == other); }
104 
105   /// \brief Check whether the input is the current Type object.
106   ///
107   /// \param[in] other Define a Value object.
108   /// \return Check whether the current object and other object have the same type id.
109   virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); }
110 
111   /// \brief Check whether the input is the current Type object.
112   ///
113   /// \param[in] other Define a TypePtr.
114   /// \return Check whether the current object and other object are the same.
equal(const TypePtr other)115   virtual bool equal(const TypePtr other) const { return *this == *other; }
116 
117   /// \brief Get the object type of the Type object.
118   ///
119   /// \return The object type of the Type object.
object_type()120   virtual TypeId object_type() const { return kTypeUnknown; }
121 
122   /// \brief Get the parent type of the Type object.
123   ///
124   /// \return The parent type of the Type object.
parent_type()125   virtual TypeId parent_type() const { return kTypeUnknown; }
126 
127   /// \brief Get the number type of the Type object.
128   ///
129   /// \return The number type of the Type object.
number_type()130   virtual TypeId number_type() const { return kTypeUnknown; }
131 
132   /// \brief Deep copy the Type object.
133   ///
134   /// \return The deep copy of the Type object.
135   virtual TypePtr DeepCopy() const = 0;
136 
137   /// \brief Clone the Type object.
138   ///
139   /// \return The clone of the Type object.
Clone()140   virtual TypePtr Clone() const { return DeepCopy(); }
141 
hash()142   std::size_t hash() const override { return static_cast<size_t>(type_id()); }
ToString()143   std::string ToString() const override { return TypeIdLabel(meta_type_); }
144 
145   /// \brief Get Type object ToReprString description.
146   ///
147   /// \return The description of Type object.
ToReprString()148   virtual std::string ToReprString() const { return ToString(); }
149 
150   /// \brief Get Type object ToReprString description.
151   ///
152   /// \return The description of Type object.
ReprString()153   std::string ReprString() const { return "mindspore." + ToReprString(); }
dump()154   void dump() const override { std::cout << ToString() << std::endl; }
155 
156   /// \brief Check whether the Type object is unknown.
157   ///
158   /// \return whether the Type object is unknown.
IsUnknown()159   bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
160 
161   /// \brief Check whether the Type object is generic.
162   ///
163   /// \return whether the Type object is generic.
IsGeneric()164   bool IsGeneric() const { return is_generic_; }
165   abstract::AbstractBasePtr ToAbstract() override;
166 
167   /// \brief Get Type object ToString description.
168   ///
169   /// \param[in] os The ostream to receive the description
170   /// \param[in] type The Type object need to show the description
171   /// \return The ostream with Type object description
172   MS_CORE_API friend std::ostream &operator<<(std::ostream &os, const Type &type);
173 
174   /// \brief Get Type object ToString description.
175   ///
176   /// \param[in] os The ostream to receive the description
177   /// \param[in] type The TypePtr need to show the description
178   /// \return The ostream with Type object description
179   MS_CORE_API friend std::ostream &operator<<(std::ostream &os, const TypePtr type);
180 
181  private:
182   TypeId meta_type_;
183   bool is_generic_;
184 };
185 
186 using TypePtrList = std::vector<TypePtr>;
187 
188 /// \brief Type defines an Type class for object.
189 class MS_CORE_API Object : public Type {
190  public:
191   /// \brief Default constructor for Object.
Object()192   Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {}
193 
194   /// \brief Constructor for Object.
195   ///
196   /// \param[in] object_type Define object type for Object object.
197   /// \param[in] is_generic Define whether the Object object is generic.
198   explicit Object(const TypeId object_type, bool is_generic = true)
Type(kMetaTypeObject,is_generic)199       : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {}
200 
201   /// \brief Constructor for Object.
202   ///
203   /// \param[in] object_type Define object type for Object object.
204   /// \param[in] parent_type Define the parent type for Object object.
205   /// \param[in] is_generic Define whether the Object object is generic.
206   explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true)
Type(kMetaTypeObject,is_generic)207       : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {}
208 
209   /// \brief Destructor of Object.
210   ~Object() override = default;
MS_DECLARE_PARENT(Object,Type)211   MS_DECLARE_PARENT(Object, Type)
212 
213   TypeId object_type() const override { return object_type_; }
parent_type()214   TypeId parent_type() const override { return parent_type_; }
type_id()215   TypeId type_id() const override { return object_type_; }
generic_type_id()216   TypeId generic_type_id() const override { return kMetaTypeObject; }
217   bool equal(const TypePtr other) const override;
ToString()218   std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
219 
220   /// \brief Get Object object ToString description.
221   ///
222   /// \param[in] os The ostream to receive the description
223   /// \param[in] obj The Object object need to show the description
224   /// \return The ostream with Object object description
225   friend std::ostream &operator<<(std::ostream &os, const Object &obj);
226 
227   /// \brief Get Object object ToString description.
228   ///
229   /// \param[in] os The ostream to receive the description
230   /// \param[in] obj The Object object need to show the description
231   /// \return The ostream with Object object description
232   friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj);
233 
234  private:
235   const TypeId object_type_;
236   const TypeId parent_type_;
237 };
238 
239 /// \brief Gettype_name_map.
240 ///
241 /// \return type_name_map
242 MS_CORE_API const mindspore::HashMap<TypeId, std::string> &type_name_map();
243 
244 /// \brief type_priority_map.
245 ///
246 /// \return type_priority_map
247 MS_CORE_API const mindspore::HashMap<TypeId, int> &type_priority_map();
248 
249 /// \brief Get TypePtrList description.
250 ///
251 /// \param[in] os The ostream to receive the description
252 /// \param[in] types The TypePtrList need to show the description
253 /// \return The ostream with TypePtrList description
254 MS_CORE_API std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
255 
256 /// \brief TypeHashById provides a hash function by Type id.
257 struct MS_CORE_API TypeHashById {
operatorTypeHashById258   std::size_t operator()(TypePtr const &type) const {
259     return type == nullptr ? 0 : static_cast<size_t>(type->type_id());
260   }
261 };
262 
263 /// \brief TypeEqualById provides an equivalent function by Type id.
264 struct MS_CORE_API TypeEqualById {
operatorTypeEqualById265   bool operator()(const TypePtr &t1, const TypePtr &t2) const {
266     return (t1 == t2) || (t1 != nullptr && t2 != nullptr && t1->type_id() == t2->type_id());
267   }
268 };
269 
270 /// \brief TypeListHasher provides a hash function for the list of shared_ptr of Type.
271 struct MS_CORE_API TypeListHasher {
operatorTypeListHasher272   std::size_t operator()(const TypePtrList &type_list) const {
273     // Hash for empty list is zero.
274     if (type_list.empty()) {
275       return 0;
276     }
277     // Hashing all elements is costly, we only calculate hash from
278     // the first element and last few elements base on some experiments.
279     // In some scenarios, this may lead high hash conflicts. Therefore,
280     // we should use this hash function in hash tables that can tolerate
281     // high hash conflicts, such as std::unordered_map.
282     constexpr size_t max_last_types = 4;
283     const size_t n_args = type_list.size();
284     // Hash from list size and the first element.
285     const auto &first_type = type_list[0];
286     std::size_t hash_sum = hash_combine(n_args, (first_type == nullptr ? 0 : first_type->hash()));
287     // Hash from last few elements.
288     const size_t start = ((n_args > max_last_types) ? (n_args - max_last_types) : 1);
289     for (size_t i = start; i < n_args; ++i) {
290       const auto &type = type_list[i];
291       hash_sum = hash_combine(hash_sum, (type == nullptr ? 0 : type->hash()));
292     }
293     return hash_sum;
294   }
295 };
296 
297 /// \brief TypeListEqual provides an equivalent function for the list of shared_ptr of Type.
298 struct MS_CORE_API TypeListEqual {
operatorTypeListEqual299   bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
300     const auto size = lhs.size();
301     if (size != rhs.size()) {
302       return false;
303     }
304     for (std::size_t i = 0; i < size; ++i) {
305       if (!common::IsEqual(lhs[i], rhs[i])) {
306         return false;
307       }
308     }
309     return true;
310   }
311 };
312 
313 // Hash map that using TypePtrList as the key.
314 template <typename T>
315 using TypeListMap = std::unordered_map<TypePtrList, T, TypeListHasher, TypeListEqual>;
316 }  // namespace mindspore
317 
318 #endif  // MINDSPORE_CORE_IR_DTYPE_TYPE_H_
319