• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
18 #define MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
19 
20 #include <map>
21 #include <memory>
22 #include <sstream>
23 #include <string>
24 #include "utils/hash_map.h"
25 #include "base/base.h"
26 #include "ir/named.h"
27 #include "ir/dtype/type.h"
28 
29 namespace mindspore {
30 /// \brief Number defines an Object class whose type is number.
31 class MS_CORE_API Number : public Object {
32  public:
33   /// \brief Default constructor for Number.
Number()34   Number() : Object(kObjectTypeNumber), number_type_(kObjectTypeNumber), nbits_(0) {}
35 
36   /// \brief Constructor for  Number.
37   ///
38   /// \param[in] number_type Define the number type of Number object.
39   /// \param[in] nbits Define the bit length of Number object.
40   /// \param[in] is_generic Define whether it is generic for Number object.
41   Number(const TypeId number_type, const int nbits, bool is_generic = true)
Object(kObjectTypeNumber,is_generic)42       : Object(kObjectTypeNumber, is_generic), number_type_(number_type), nbits_(nbits) {}
43 
44   /// \brief Destructor of Number.
45   ~Number() override = default;
MS_DECLARE_PARENT(Number,Object)46   MS_DECLARE_PARENT(Number, Object)
47 
48   /// \brief Get the bit length of Number object.
49   ///
50   /// \return bit length of Number object.
51   int nbits() const { return nbits_; }
52 
number_type()53   TypeId number_type() const override { return number_type_; }
type_id()54   TypeId type_id() const override { return number_type_; }
generic_type_id()55   TypeId generic_type_id() const override { return kObjectTypeNumber; }
56   bool operator==(const Type &other) const override;
57   std::size_t hash() const override;
DeepCopy()58   TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
ToString()59   std::string ToString() const override { return "Number"; }
ToReprString()60   std::string ToReprString() const override { return "number"; }
DumpText()61   std::string DumpText() const override { return "Number"; }
62 
63   /// \brief Get type name for Number object.
64   ///
65   /// \param type_name Define the type name.
66   /// \return The full type name of the Number object.
GetTypeName(const std::string & type_name)67   std::string GetTypeName(const std::string &type_name) const {
68     std::ostringstream oss;
69     oss << type_name;
70     if (nbits() != 0) {
71       oss << nbits();
72     }
73     return oss.str();
74   }
75 
76  private:
77   const TypeId number_type_;
78   const int nbits_;
79 };
80 
81 using NumberPtr = std::shared_ptr<Number>;
82 
83 // Bool
84 /// \brief Bool defines a Number class whose type is boolean.
85 class MS_CORE_API Bool : public Number {
86  public:
87   /// \brief Default constructor for Bool.
Bool()88   Bool() : Number(kNumberTypeBool, 8) {}
89 
90   /// \brief Destructor of Bool.
91   ~Bool() override = default;
MS_DECLARE_PARENT(Bool,Number)92   MS_DECLARE_PARENT(Bool, Number)
93 
94   TypeId generic_type_id() const override { return kNumberTypeBool; }
DeepCopy()95   TypePtr DeepCopy() const override { return std::make_shared<Bool>(); }
ToString()96   std::string ToString() const override { return "Bool"; }
ToReprString()97   std::string ToReprString() const override { return "bool_"; }
DumpText()98   std::string DumpText() const override { return "Bool"; }
99 };
100 
101 // Int
102 /// \brief Int defines a Number class whose type is int.
103 class MS_CORE_API Int : public Number {
104  public:
105   /// \brief Default constructor for Int.
Int()106   Int() : Number(kNumberTypeInt, 0) {}
107 
108   /// \brief Constructor for Int.
109   ///
110   /// \param nbits Define the bit length of Int object.
111   explicit Int(const int nbits);
112 
113   /// \brief Destructor of Int.
114   ~Int() override = default;
MS_DECLARE_PARENT(Int,Number)115   MS_DECLARE_PARENT(Int, Number)
116 
117   TypeId generic_type_id() const override { return kNumberTypeInt; }
DeepCopy()118   TypePtr DeepCopy() const override {
119     if (nbits() == 0) {
120       return std::make_shared<Int>();
121     }
122     return std::make_shared<Int>(nbits());
123   }
124 
ToString()125   std::string ToString() const override { return GetTypeName("Int"); }
ToReprString()126   std::string ToReprString() const override { return nbits() == 0 ? "int_" : GetTypeName("int"); }
DumpText()127   std::string DumpText() const override {
128     return nbits() == 0 ? std::string("Int") : std::string("I") + std::to_string(nbits());
129   }
130 };
131 
132 // UInt
133 /// \brief UInt defines a Number class whose type is uint.
134 class MS_CORE_API UInt : public Number {
135  public:
136   /// \brief Default constructor for UInt.
UInt()137   UInt() : Number(kNumberTypeUInt, 0) {}
138 
139   /// \brief Constructor for UInt.
140   ///
141   /// \param nbits Define the bit length of UInt object.
142   explicit UInt(const int nbits);
143 
generic_type_id()144   TypeId generic_type_id() const override { return kNumberTypeUInt; }
145 
146   /// \brief Destructor of UInt.
~UInt()147   ~UInt() override {}
MS_DECLARE_PARENT(UInt,Number)148   MS_DECLARE_PARENT(UInt, Number)
149 
150   TypePtr DeepCopy() const override {
151     if (nbits() == 0) {
152       return std::make_shared<UInt>();
153     }
154     return std::make_shared<UInt>(nbits());
155   }
156 
ToString()157   std::string ToString() const override { return GetTypeName("UInt"); }
ToReprString()158   std::string ToReprString() const override { return GetTypeName("uint"); }
DumpText()159   std::string DumpText() const override {
160     return nbits() == 0 ? std::string("UInt") : std::string("U") + std::to_string(nbits());
161   }
162 };
163 
164 // Float
165 /// \brief Float defines a Number class whose type is float.
166 class MS_CORE_API Float : public Number {
167  public:
168   /// \brief Default constructor for Float.
Float()169   Float() : Number(kNumberTypeFloat, 0) {}
170 
171   /// \brief Constructor for Float.
172   ///
173   /// \param nbits Define the bit length of Float object.
174   explicit Float(const int nbits);
175 
176   /// \brief Destructor of Float.
~Float()177   ~Float() override {}
MS_DECLARE_PARENT(Float,Number)178   MS_DECLARE_PARENT(Float, Number)
179 
180   TypeId generic_type_id() const override { return kNumberTypeFloat; }
DeepCopy()181   TypePtr DeepCopy() const override {
182     if (nbits() == 0) {
183       return std::make_shared<Float>();
184     }
185     return std::make_shared<Float>(nbits());
186   }
187 
ToString()188   std::string ToString() const override { return GetTypeName("Float"); }
ToReprString()189   std::string ToReprString() const override { return nbits() == 0 ? "float_" : GetTypeName("float"); }
DumpText()190   std::string DumpText() const override {
191     return nbits() == 0 ? std::string("Float") : std::string("F") + std::to_string(nbits());
192   }
193 };
194 
195 // BFloat
196 /// \brief BFloat defines a Number class whose type is brain float.
197 class MS_CORE_API BFloat : public Number {
198  public:
199   /// \brief Default constructor for BFloat.
BFloat()200   BFloat() : Number(kNumberTypeBFloat16, 0) {}
201 
202   /// \brief Constructor for BFloat.
203   ///
204   /// \param nbits Define the bit length of BFloat object.
205   explicit BFloat(const int nbits);
206 
207   /// \brief Destructor of BFloat.
~BFloat()208   ~BFloat() override {}
MS_DECLARE_PARENT(BFloat,Number)209   MS_DECLARE_PARENT(BFloat, Number)
210 
211   TypeId generic_type_id() const override { return kNumberTypeBFloat16; }
DeepCopy()212   TypePtr DeepCopy() const override {
213     if (nbits() == 0) {
214       return std::make_shared<BFloat>();
215     }
216     return std::make_shared<BFloat>(nbits());
217   }
218 
ToString()219   std::string ToString() const override { return GetTypeName("BFloat"); }
ToReprString()220   std::string ToReprString() const override { return nbits() == 0 ? "bfloat" : GetTypeName("bfloat"); }
DumpText()221   std::string DumpText() const override {
222     return nbits() == 0 ? std::string("BFloat") : std::string("BF") + std::to_string(nbits());
223   }
224 };
225 
226 // Complex
227 /// \brief Complex defines a Number class whose type is complex.
228 class MS_CORE_API Complex : public Number {
229  public:
230   /// \brief Default constructor for Complex.
Complex()231   Complex() : Number(kNumberTypeComplex, 0) {}
232 
233   /// \brief Constructor for Complex.
234   ///
235   /// \param nbits Define the bit length of Complex object.
236   explicit Complex(const int nbits);
237 
238   /// \brief Destructor of Complex.
~Complex()239   ~Complex() override {}
MS_DECLARE_PARENT(Complex,Number)240   MS_DECLARE_PARENT(Complex, Number)
241 
242   TypeId generic_type_id() const override { return kNumberTypeComplex; }
DeepCopy()243   TypePtr DeepCopy() const override {
244     if (nbits() == 0) {
245       return std::make_shared<Complex>();
246     }
247     return std::make_shared<Complex>(nbits());
248   }
249 
ToString()250   std::string ToString() const override { return GetTypeName("Complex"); }
ToReprString()251   std::string ToReprString() const override { return GetTypeName("complex"); }
DumpText()252   std::string DumpText() const override { return std::string("Complex") + std::to_string(nbits()); }
253 };
254 
255 GVAR_DEF(TypePtr, kBool, std::make_shared<Bool>());
256 GVAR_DEF(TypePtr, kInt4, std::make_shared<Int>(static_cast<int>(BitsNum::eBits4)));
257 GVAR_DEF(TypePtr, kInt8, std::make_shared<Int>(static_cast<int>(BitsNum::eBits8)));
258 GVAR_DEF(TypePtr, kInt16, std::make_shared<Int>(static_cast<int>(BitsNum::eBits16)));
259 GVAR_DEF(TypePtr, kInt32, std::make_shared<Int>(static_cast<int>(BitsNum::eBits32)));
260 GVAR_DEF(TypePtr, kInt64, std::make_shared<Int>(static_cast<int>(BitsNum::eBits64)));
261 GVAR_DEF(TypePtr, kUInt8, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits8)));
262 GVAR_DEF(TypePtr, kUInt16, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits16)));
263 GVAR_DEF(TypePtr, kUInt32, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits32)));
264 GVAR_DEF(TypePtr, kUInt64, std::make_shared<UInt>(static_cast<int>(BitsNum::eBits64)));
265 GVAR_DEF(TypePtr, kFloat16, std::make_shared<Float>(static_cast<int>(BitsNum::eBits16)));
266 GVAR_DEF(TypePtr, kFloat32, std::make_shared<Float>(static_cast<int>(BitsNum::eBits32)));
267 GVAR_DEF(TypePtr, kFloat64, std::make_shared<Float>(static_cast<int>(BitsNum::eBits64)));
268 GVAR_DEF(TypePtr, kBFloat16, std::make_shared<BFloat>(static_cast<int>(BitsNum::eBits16)));
269 GVAR_DEF(TypePtr, kInt, std::make_shared<Int>());
270 GVAR_DEF(TypePtr, kUInt, std::make_shared<UInt>());
271 GVAR_DEF(TypePtr, kFloat, std::make_shared<Float>());
272 GVAR_DEF(TypePtr, kBFloat, std::make_shared<BFloat>());
273 GVAR_DEF(TypePtr, kNumber, std::make_shared<Number>());
274 GVAR_DEF(TypePtr, kComplex64, std::make_shared<Complex>(static_cast<int>(BitsNum::eBits64)));
275 GVAR_DEF(TypePtr, kComplex128, std::make_shared<Complex>(static_cast<int>(BitsNum::eBits128)));
276 }  // namespace mindspore
277 
278 #endif  // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_
279