1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2022 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_H_ 20 #define MINDSPORE_CORE_IR_DTYPE_TYPE_H_ 21 22 #include <cstddef> 23 #include <iostream> 24 #include <initializer_list> 25 #include <unordered_map> 26 #include <memory> 27 #include <utility> 28 #include <sstream> 29 #include <string> 30 #include <vector> 31 #include <type_traits> 32 #include <algorithm> 33 #include "utils/hash_map.h" 34 #include "base/base.h" 35 #include "ir/named.h" 36 #include "ir/dtype/type_id.h" 37 #include "utils/ms_utils.h" 38 39 namespace mindspore { 40 41 TypeId IntBitsToTypeId(const int nbits); 42 TypeId UIntBitsToTypeId(const int nbits); 43 TypeId FloatBitsToTypeId(const int nbits); 44 TypeId BFloatBitsToTypeId(const int nbits); 45 TypeId ComplexBitsToTypeId(const int nbits); 46 47 /// \brief Get label of the input TypeId. 48 /// 49 /// \param[in] v Define the input TypeId. 50 /// \return The label of input TypeId. 51 MS_CORE_API const std::string &TypeIdLabel(const TypeId &v); 52 MS_CORE_API TypeId NormalizeTypeId(const TypeId type_id); 53 bool IsSameObjectType(const Type &lhs, const Type &rhs); 54 MS_CORE_API size_t GetTypeByte(const TypePtr &type_ptr); 55 MS_CORE_API int64_t GetTypeId(const TypeId &type_id); 56 57 enum class BitsNum : int { 58 eBits4 = 4, 59 eBits8 = 8, 60 eBits16 = 16, 61 eBits32 = 32, 62 eBits64 = 64, 63 eBits128 = 128, 64 }; 65 66 /// \brief Type defines an Value class for type. 67 class MS_CORE_API Type : public Value { 68 public: 69 /// \brief Default constructor for Type. Type()70 Type() : meta_type_(kMetaTypeType), is_generic_(true) {} 71 72 /// \brief Constructor for Type. 73 /// 74 /// \param[in] t Define TypeId for Type object. 75 /// \param[in] is_generic Define whether the Type object is generic. meta_type_(t)76 explicit Type(TypeId t, bool is_generic = true) : meta_type_(t), is_generic_(is_generic) {} 77 78 /// \brief Destructor of Type. 79 ~Type() override = default; 80 MS_DECLARE_PARENT(Type, Value) 81 82 bool operator==(const Value &other) const override; 83 84 /// \brief Show the meta type of the Type object. 85 /// 86 /// \return The meta type of the Type object. meta_type()87 TypeId meta_type() const { return meta_type_; } 88 89 /// \brief Show the type id of the Type object. 90 /// 91 /// \return The type id of the Type object. type_id()92 virtual TypeId type_id() const { return meta_type_; } 93 94 /// \brief Show the generic type id for the Number object. 95 /// 96 /// \return The generic type id. generic_type_id()97 virtual TypeId generic_type_id() const { return kMetaTypeType; } 98 99 /// \brief Check whether the input is not the current Type object. 100 /// 101 /// \param[in] other Define a Value object. 102 /// \return Check whether the current object and other object are different. 103 virtual bool operator!=(const Type &other) const { return !(*this == other); } 104 105 /// \brief Check whether the input is the current Type object. 106 /// 107 /// \param[in] other Define a Value object. 108 /// \return Check whether the current object and other object have the same type id. 109 virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } 110 111 /// \brief Check whether the input is the current Type object. 112 /// 113 /// \param[in] other Define a TypePtr. 114 /// \return Check whether the current object and other object are the same. equal(const TypePtr other)115 virtual bool equal(const TypePtr other) const { return *this == *other; } 116 117 /// \brief Get the object type of the Type object. 118 /// 119 /// \return The object type of the Type object. object_type()120 virtual TypeId object_type() const { return kTypeUnknown; } 121 122 /// \brief Get the parent type of the Type object. 123 /// 124 /// \return The parent type of the Type object. parent_type()125 virtual TypeId parent_type() const { return kTypeUnknown; } 126 127 /// \brief Get the number type of the Type object. 128 /// 129 /// \return The number type of the Type object. number_type()130 virtual TypeId number_type() const { return kTypeUnknown; } 131 132 /// \brief Deep copy the Type object. 133 /// 134 /// \return The deep copy of the Type object. 135 virtual TypePtr DeepCopy() const = 0; 136 137 /// \brief Clone the Type object. 138 /// 139 /// \return The clone of the Type object. Clone()140 virtual TypePtr Clone() const { return DeepCopy(); } 141 hash()142 std::size_t hash() const override { return static_cast<size_t>(type_id()); } ToString()143 std::string ToString() const override { return TypeIdLabel(meta_type_); } 144 145 /// \brief Get Type object ToReprString description. 146 /// 147 /// \return The description of Type object. ToReprString()148 virtual std::string ToReprString() const { return ToString(); } 149 150 /// \brief Get Type object ToReprString description. 151 /// 152 /// \return The description of Type object. ReprString()153 std::string ReprString() const { return "mindspore." + ToReprString(); } dump()154 void dump() const override { std::cout << ToString() << std::endl; } 155 156 /// \brief Check whether the Type object is unknown. 157 /// 158 /// \return whether the Type object is unknown. IsUnknown()159 bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } 160 161 /// \brief Check whether the Type object is generic. 162 /// 163 /// \return whether the Type object is generic. IsGeneric()164 bool IsGeneric() const { return is_generic_; } 165 abstract::AbstractBasePtr ToAbstract() override; 166 167 /// \brief Get Type object ToString description. 168 /// 169 /// \param[in] os The ostream to receive the description 170 /// \param[in] type The Type object need to show the description 171 /// \return The ostream with Type object description 172 MS_CORE_API friend std::ostream &operator<<(std::ostream &os, const Type &type); 173 174 /// \brief Get Type object ToString description. 175 /// 176 /// \param[in] os The ostream to receive the description 177 /// \param[in] type The TypePtr need to show the description 178 /// \return The ostream with Type object description 179 MS_CORE_API friend std::ostream &operator<<(std::ostream &os, const TypePtr type); 180 181 private: 182 TypeId meta_type_; 183 bool is_generic_; 184 }; 185 186 using TypePtrList = std::vector<TypePtr>; 187 188 /// \brief Type defines an Type class for object. 189 class MS_CORE_API Object : public Type { 190 public: 191 /// \brief Default constructor for Object. Object()192 Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {} 193 194 /// \brief Constructor for Object. 195 /// 196 /// \param[in] object_type Define object type for Object object. 197 /// \param[in] is_generic Define whether the Object object is generic. 198 explicit Object(const TypeId object_type, bool is_generic = true) Type(kMetaTypeObject,is_generic)199 : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {} 200 201 /// \brief Constructor for Object. 202 /// 203 /// \param[in] object_type Define object type for Object object. 204 /// \param[in] parent_type Define the parent type for Object object. 205 /// \param[in] is_generic Define whether the Object object is generic. 206 explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true) Type(kMetaTypeObject,is_generic)207 : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {} 208 209 /// \brief Destructor of Object. 210 ~Object() override = default; MS_DECLARE_PARENT(Object,Type)211 MS_DECLARE_PARENT(Object, Type) 212 213 TypeId object_type() const override { return object_type_; } parent_type()214 TypeId parent_type() const override { return parent_type_; } type_id()215 TypeId type_id() const override { return object_type_; } generic_type_id()216 TypeId generic_type_id() const override { return kMetaTypeObject; } 217 bool equal(const TypePtr other) const override; ToString()218 std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } 219 220 /// \brief Get Object object ToString description. 221 /// 222 /// \param[in] os The ostream to receive the description 223 /// \param[in] obj The Object object need to show the description 224 /// \return The ostream with Object object description 225 friend std::ostream &operator<<(std::ostream &os, const Object &obj); 226 227 /// \brief Get Object object ToString description. 228 /// 229 /// \param[in] os The ostream to receive the description 230 /// \param[in] obj The Object object need to show the description 231 /// \return The ostream with Object object description 232 friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj); 233 234 private: 235 const TypeId object_type_; 236 const TypeId parent_type_; 237 }; 238 239 /// \brief Gettype_name_map. 240 /// 241 /// \return type_name_map 242 MS_CORE_API const mindspore::HashMap<TypeId, std::string> &type_name_map(); 243 244 /// \brief type_priority_map. 245 /// 246 /// \return type_priority_map 247 MS_CORE_API const mindspore::HashMap<TypeId, int> &type_priority_map(); 248 249 /// \brief Get TypePtrList description. 250 /// 251 /// \param[in] os The ostream to receive the description 252 /// \param[in] types The TypePtrList need to show the description 253 /// \return The ostream with TypePtrList description 254 MS_CORE_API std::ostream &operator<<(std::ostream &os, const TypePtrList &types); 255 256 /// \brief TypeHashById provides a hash function by Type id. 257 struct MS_CORE_API TypeHashById { operatorTypeHashById258 std::size_t operator()(TypePtr const &type) const { 259 return type == nullptr ? 0 : static_cast<size_t>(type->type_id()); 260 } 261 }; 262 263 /// \brief TypeEqualById provides an equivalent function by Type id. 264 struct MS_CORE_API TypeEqualById { operatorTypeEqualById265 bool operator()(const TypePtr &t1, const TypePtr &t2) const { 266 return (t1 == t2) || (t1 != nullptr && t2 != nullptr && t1->type_id() == t2->type_id()); 267 } 268 }; 269 270 /// \brief TypeListHasher provides a hash function for the list of shared_ptr of Type. 271 struct MS_CORE_API TypeListHasher { operatorTypeListHasher272 std::size_t operator()(const TypePtrList &type_list) const { 273 // Hash for empty list is zero. 274 if (type_list.empty()) { 275 return 0; 276 } 277 // Hashing all elements is costly, we only calculate hash from 278 // the first element and last few elements base on some experiments. 279 // In some scenarios, this may lead high hash conflicts. Therefore, 280 // we should use this hash function in hash tables that can tolerate 281 // high hash conflicts, such as std::unordered_map. 282 constexpr size_t max_last_types = 4; 283 const size_t n_args = type_list.size(); 284 // Hash from list size and the first element. 285 const auto &first_type = type_list[0]; 286 std::size_t hash_sum = hash_combine(n_args, (first_type == nullptr ? 0 : first_type->hash())); 287 // Hash from last few elements. 288 const size_t start = ((n_args > max_last_types) ? (n_args - max_last_types) : 1); 289 for (size_t i = start; i < n_args; ++i) { 290 const auto &type = type_list[i]; 291 hash_sum = hash_combine(hash_sum, (type == nullptr ? 0 : type->hash())); 292 } 293 return hash_sum; 294 } 295 }; 296 297 /// \brief TypeListEqual provides an equivalent function for the list of shared_ptr of Type. 298 struct MS_CORE_API TypeListEqual { operatorTypeListEqual299 bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { 300 const auto size = lhs.size(); 301 if (size != rhs.size()) { 302 return false; 303 } 304 for (std::size_t i = 0; i < size; ++i) { 305 if (!common::IsEqual(lhs[i], rhs[i])) { 306 return false; 307 } 308 } 309 return true; 310 } 311 }; 312 313 // Hash map that using TypePtrList as the key. 314 template <typename T> 315 using TypeListMap = std::unordered_map<TypePtrList, T, TypeListHasher, TypeListEqual>; 316 } // namespace mindspore 317 318 #endif // MINDSPORE_CORE_IR_DTYPE_TYPE_H_ 319