1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ 18 #define MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ 19 20 #include <map> 21 #include <memory> 22 #include <sstream> 23 #include <string> 24 #include "utils/hash_map.h" 25 #include "base/base.h" 26 #include "ir/named.h" 27 #include "ir/dtype/type.h" 28 29 namespace mindspore { 30 /// \brief Number defines an Object class whose type is number. 31 class MS_CORE_API Number : public Object { 32 public: 33 /// \brief Default constructor for Number. Number()34 Number() : Object(kObjectTypeNumber), number_type_(kObjectTypeNumber), nbits_(0) {} 35 36 /// \brief Constructor for Number. 37 /// 38 /// \param[in] number_type Define the number type of Number object. 39 /// \param[in] nbits Define the bit length of Number object. 40 /// \param[in] is_generic Define whether it is generic for Number object. 41 Number(const TypeId number_type, const int nbits, bool is_generic = true) Object(kObjectTypeNumber,is_generic)42 : Object(kObjectTypeNumber, is_generic), number_type_(number_type), nbits_(nbits) {} 43 44 /// \brief Destructor of Number. 45 ~Number() override = default; MS_DECLARE_PARENT(Number,Object)46 MS_DECLARE_PARENT(Number, Object) 47 48 /// \brief Get the bit length of Number object. 49 /// 50 /// \return bit length of Number object. 51 int nbits() const { return nbits_; } 52 number_type()53 TypeId number_type() const override { return number_type_; } type_id()54 TypeId type_id() const override { return number_type_; } generic_type_id()55 TypeId generic_type_id() const override { return kObjectTypeNumber; } 56 bool operator==(const Type &other) const override; 57 std::size_t hash() const override; DeepCopy()58 TypePtr DeepCopy() const override { return std::make_shared<Number>(); } ToString()59 std::string ToString() const override { return "Number"; } ToReprString()60 std::string ToReprString() const override { return "number"; } DumpText()61 std::string DumpText() const override { return "Number"; } 62 63 /// \brief Get type name for Number object. 64 /// 65 /// \param type_name Define the type name. 66 /// \return The full type name of the Number object. GetTypeName(const std::string & type_name)67 std::string GetTypeName(const std::string &type_name) const { 68 std::ostringstream oss; 69 oss << type_name; 70 if (nbits() != 0) { 71 oss << nbits(); 72 } 73 return oss.str(); 74 } 75 76 private: 77 const TypeId number_type_; 78 const int nbits_; 79 }; 80 81 using NumberPtr = std::shared_ptr<Number>; 82 83 // Bool 84 /// \brief Bool defines a Number class whose type is boolean. 85 class MS_CORE_API Bool : public Number { 86 public: 87 /// \brief Default constructor for Bool. Bool()88 Bool() : Number(kNumberTypeBool, 8) {} 89 90 /// \brief Destructor of Bool. 91 ~Bool() override = default; MS_DECLARE_PARENT(Bool,Number)92 MS_DECLARE_PARENT(Bool, Number) 93 94 TypeId generic_type_id() const override { return kNumberTypeBool; } DeepCopy()95 TypePtr DeepCopy() const override { return std::make_shared<Bool>(); } ToString()96 std::string ToString() const override { return "Bool"; } ToReprString()97 std::string ToReprString() const override { return "bool_"; } DumpText()98 std::string DumpText() const override { return "Bool"; } 99 }; 100 101 // Int 102 /// \brief Int defines a Number class whose type is int. 103 class MS_CORE_API Int : public Number { 104 public: 105 /// \brief Default constructor for Int. Int()106 Int() : Number(kNumberTypeInt, 0) {} 107 108 /// \brief Constructor for Int. 109 /// 110 /// \param nbits Define the bit length of Int object. 111 explicit Int(const int nbits); 112 113 /// \brief Destructor of Int. 114 ~Int() override = default; MS_DECLARE_PARENT(Int,Number)115 MS_DECLARE_PARENT(Int, Number) 116 117 TypeId generic_type_id() const override { return kNumberTypeInt; } DeepCopy()118 TypePtr DeepCopy() const override { 119 if (nbits() == 0) { 120 return std::make_shared<Int>(); 121 } 122 return std::make_shared<Int>(nbits()); 123 } 124 ToString()125 std::string ToString() const override { return GetTypeName("Int"); } ToReprString()126 std::string ToReprString() const override { return nbits() == 0 ? "int_" : GetTypeName("int"); } DumpText()127 std::string DumpText() const override { 128 return nbits() == 0 ? std::string("Int") : std::string("I") + std::to_string(nbits()); 129 } 130 }; 131 132 // UInt 133 /// \brief UInt defines a Number class whose type is uint. 134 class MS_CORE_API UInt : public Number { 135 public: 136 /// \brief Default constructor for UInt. UInt()137 UInt() : Number(kNumberTypeUInt, 0) {} 138 139 /// \brief Constructor for UInt. 140 /// 141 /// \param nbits Define the bit length of UInt object. 142 explicit UInt(const int nbits); 143 generic_type_id()144 TypeId generic_type_id() const override { return kNumberTypeUInt; } 145 146 /// \brief Destructor of UInt. ~UInt()147 ~UInt() override {} MS_DECLARE_PARENT(UInt,Number)148 MS_DECLARE_PARENT(UInt, Number) 149 150 TypePtr DeepCopy() const override { 151 if (nbits() == 0) { 152 return std::make_shared<UInt>(); 153 } 154 return std::make_shared<UInt>(nbits()); 155 } 156 ToString()157 std::string ToString() const override { return GetTypeName("UInt"); } ToReprString()158 std::string ToReprString() const override { return GetTypeName("uint"); } DumpText()159 std::string DumpText() const override { 160 return nbits() == 0 ? std::string("UInt") : std::string("U") + std::to_string(nbits()); 161 } 162 }; 163 164 // Float 165 /// \brief Float defines a Number class whose type is float. 166 class MS_CORE_API Float : public Number { 167 public: 168 /// \brief Default constructor for Float. Float()169 Float() : Number(kNumberTypeFloat, 0) {} 170 171 /// \brief Constructor for Float. 172 /// 173 /// \param nbits Define the bit length of Float object. 174 explicit Float(const int nbits); 175 176 /// \brief Destructor of Float. ~Float()177 ~Float() override {} MS_DECLARE_PARENT(Float,Number)178 MS_DECLARE_PARENT(Float, Number) 179 180 TypeId generic_type_id() const override { return kNumberTypeFloat; } DeepCopy()181 TypePtr DeepCopy() const override { 182 if (nbits() == 0) { 183 return std::make_shared<Float>(); 184 } 185 return std::make_shared<Float>(nbits()); 186 } 187 ToString()188 std::string ToString() const override { return GetTypeName("Float"); } ToReprString()189 std::string ToReprString() const override { return nbits() == 0 ? "float_" : GetTypeName("float"); } DumpText()190 std::string DumpText() const override { 191 return nbits() == 0 ? std::string("Float") : std::string("F") + std::to_string(nbits()); 192 } 193 }; 194 195 // BFloat 196 /// \brief BFloat defines a Number class whose type is brain float. 197 class MS_CORE_API BFloat : public Number { 198 public: 199 /// \brief Default constructor for BFloat. BFloat()200 BFloat() : Number(kNumberTypeBFloat16, 0) {} 201 202 /// \brief Constructor for BFloat. 203 /// 204 /// \param nbits Define the bit length of BFloat object. 205 explicit BFloat(const int nbits); 206 207 /// \brief Destructor of BFloat. ~BFloat()208 ~BFloat() override {} MS_DECLARE_PARENT(BFloat,Number)209 MS_DECLARE_PARENT(BFloat, Number) 210 211 TypeId generic_type_id() const override { return kNumberTypeBFloat16; } DeepCopy()212 TypePtr DeepCopy() const override { 213 if (nbits() == 0) { 214 return std::make_shared<BFloat>(); 215 } 216 return std::make_shared<BFloat>(nbits()); 217 } 218 ToString()219 std::string ToString() const override { return GetTypeName("BFloat"); } ToReprString()220 std::string ToReprString() const override { return nbits() == 0 ? "bfloat" : GetTypeName("bfloat"); } DumpText()221 std::string DumpText() const override { 222 return nbits() == 0 ? std::string("BFloat") : std::string("BF") + std::to_string(nbits()); 223 } 224 }; 225 226 // Complex 227 /// \brief Complex defines a Number class whose type is complex. 228 class MS_CORE_API Complex : public Number { 229 public: 230 /// \brief Default constructor for Complex. Complex()231 Complex() : Number(kNumberTypeComplex, 0) {} 232 233 /// \brief Constructor for Complex. 234 /// 235 /// \param nbits Define the bit length of Complex object. 236 explicit Complex(const int nbits); 237 238 /// \brief Destructor of Complex. ~Complex()239 ~Complex() override {} MS_DECLARE_PARENT(Complex,Number)240 MS_DECLARE_PARENT(Complex, Number) 241 242 TypeId generic_type_id() const override { return kNumberTypeComplex; } DeepCopy()243 TypePtr DeepCopy() const override { 244 if (nbits() == 0) { 245 return std::make_shared<Complex>(); 246 } 247 return std::make_shared<Complex>(nbits()); 248 } 249 ToString()250 std::string ToString() const override { return GetTypeName("Complex"); } ToReprString()251 std::string ToReprString() const override { return GetTypeName("complex"); } DumpText()252 std::string DumpText() const override { return std::string("Complex") + std::to_string(nbits()); } 253 }; 254 255 GVAR_DEF(TypePtr, kBool, std::make_shared<Bool>()); 256 GVAR_DEF(TypePtr, kInt4, std::make_shared<Int>(static_cast<int>(BitsNum::eBits4))); 257 GVAR_DEF(TypePtr, kInt8, std::make_shared<Int>(static_cast<int>(BitsNum::eBits8))); 258 GVAR_DEF(TypePtr, kInt16, std::make_shared<Int>(static_cast<int>(BitsNum::eBits16))); 259 GVAR_DEF(TypePtr, kInt32, std::make_shared<Int>(static_cast<int>(BitsNum::eBits32))); 260 GVAR_DEF(TypePtr, kInt64, std::make_shared<Int>(static_cast<int>(BitsNum::eBits64))); 261 GVAR_DEF(TypePtr, kUInt8, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits8))); 262 GVAR_DEF(TypePtr, kUInt16, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits16))); 263 GVAR_DEF(TypePtr, kUInt32, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits32))); 264 GVAR_DEF(TypePtr, kUInt64, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits64))); 265 GVAR_DEF(TypePtr, kFloat16, std::make_shared<Float>(static_cast<int>(BitsNum::eBits16))); 266 GVAR_DEF(TypePtr, kFloat32, std::make_shared<Float>(static_cast<int>(BitsNum::eBits32))); 267 GVAR_DEF(TypePtr, kFloat64, std::make_shared<Float>(static_cast<int>(BitsNum::eBits64))); 268 GVAR_DEF(TypePtr, kBFloat16, std::make_shared<BFloat>(static_cast<int>(BitsNum::eBits16))); 269 GVAR_DEF(TypePtr, kInt, std::make_shared<Int>()); 270 GVAR_DEF(TypePtr, kUInt, std::make_shared<UInt>()); 271 GVAR_DEF(TypePtr, kFloat, std::make_shared<Float>()); 272 GVAR_DEF(TypePtr, kBFloat, std::make_shared<BFloat>()); 273 GVAR_DEF(TypePtr, kNumber, std::make_shared<Number>()); 274 GVAR_DEF(TypePtr, kComplex64, std::make_shared<Complex>(static_cast<int>(BitsNum::eBits64))); 275 GVAR_DEF(TypePtr, kComplex128, std::make_shared<Complex>(static_cast<int>(BitsNum::eBits128))); 276 } // namespace mindspore 277 278 #endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ 279