• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_SCALAR_H_
18 #define MINDSPORE_CORE_IR_SCALAR_H_
19 
20 #include <type_traits>
21 #include <algorithm>
22 #include <cmath>
23 #include <vector>
24 #include <string>
25 #include <memory>
26 #include <sstream>
27 #include <utility>
28 #include <cfloat>
29 
30 #include "base/base.h"
31 #include "ir/dtype.h"
32 #include "ir/dtype/number.h"
33 #include "utils/hashing.h"
34 
35 using std::fabs;
36 
37 namespace mindspore {
38 class MS_CORE_API Scalar : public Value {
39  public:
40   Scalar() = default;
Scalar(const TypePtr t)41   explicit Scalar(const TypePtr t) : Value(t) {}
42   ~Scalar() override = default;
43   MS_DECLARE_PARENT(Scalar, Value)
44   virtual bool IsZero() = 0;
45   virtual bool IsOne() = 0;
46   abstract::AbstractBasePtr ToAbstract() override;
47 
48  protected:
49   std::size_t hash_ = 0;
50 };
51 using ScalarPtr = std::shared_ptr<Scalar>;
52 
53 class MS_CORE_API BoolImm : public Scalar {
54  public:
BoolImm(bool b)55   explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
56   ~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm,Scalar)57   MS_DECLARE_PARENT(BoolImm, Scalar)
58   std::size_t hash() const override { return hash_; }
value()59   bool value() const { return v_; }
IsZero()60   bool IsZero() override { return v_ == false; }
IsOne()61   bool IsOne() override { return v_ == true; }
62   bool operator==(const Value &other) const override;
63   bool operator==(const BoolImm &other) const;
ToString()64   std::string ToString() const override {
65     if (v_) {
66       return "true";
67     } else {
68       return "false";
69     }
70   }
71 
DumpText()72   std::string DumpText() const override {
73     std::ostringstream oss;
74     oss << "Bool(" << v_ << ")";
75     return oss.str();
76   }
77 
78  private:
79   bool v_;
80 };
81 using BoolImmPtr = std::shared_ptr<BoolImm>;
IMM_TRAITS(BoolImmPtr,bool)82 IMM_TRAITS(BoolImmPtr, bool)
83 
84 class MS_CORE_API IntergerImm : public Scalar {
85  public:
86   IntergerImm() = default;
87   explicit IntergerImm(const TypePtr &t) : Scalar(t) {}
88   ~IntergerImm() override = default;
89   MS_DECLARE_PARENT(IntergerImm, Scalar)
90 };
91 
92 class MS_CORE_API Int8Imm : public IntergerImm {
93  public:
Int8Imm()94   Int8Imm() : IntergerImm(kInt8), v_(0) {}
Int8Imm(int8_t v)95   explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
96   ~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm,IntergerImm)97   MS_DECLARE_PARENT(Int8Imm, IntergerImm)
98   std::size_t hash() const override { return hash_; }
IsZero()99   bool IsZero() override { return v_ == 0; }
IsOne()100   bool IsOne() override { return v_ == 1; }
value()101   int8_t value() const { return v_; }
102   bool operator==(const Value &other) const override;
103   bool operator==(const Int8Imm &other) const;
ToString()104   std::string ToString() const override { return std::to_string(v_); }
105 
DumpText()106   std::string DumpText() const override {
107     std::ostringstream oss;
108     oss << "I8(" << int(v_) << ")";
109     return oss.str();
110   }
111 
112  private:
113   int8_t v_;
114 };
115 using Int8ImmPtr = std::shared_ptr<Int8Imm>;
IMM_TRAITS(Int8ImmPtr,int8_t)116 IMM_TRAITS(Int8ImmPtr, int8_t)
117 
118 class MS_CORE_API Int16Imm : public IntergerImm {
119  public:
120   Int16Imm() : IntergerImm(kInt16), v_(0) {}
121   explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
122   ~Int16Imm() override = default;
123   MS_DECLARE_PARENT(Int16Imm, IntergerImm)
124   std::size_t hash() const override { return hash_; }
125   bool IsZero() override { return v_ == 0; }
126   bool IsOne() override { return v_ == 1; }
127   int16_t value() const { return v_; }
128   bool operator==(const Value &other) const override;
129   bool operator==(const Int16Imm &other) const;
130   std::string ToString() const override { return std::to_string(v_); }
131 
132   std::string DumpText() const override {
133     std::ostringstream oss;
134     oss << "I16(" << int(v_) << ")";
135     return oss.str();
136   }
137 
138  private:
139   int16_t v_;
140 };
141 using Int16ImmPtr = std::shared_ptr<Int16Imm>;
IMM_TRAITS(Int16ImmPtr,int16_t)142 IMM_TRAITS(Int16ImmPtr, int16_t)
143 
144 class MS_CORE_API Int32Imm : public IntergerImm {
145  public:
146   Int32Imm() : IntergerImm(kInt32), v_(0) {}
147   explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
148   ~Int32Imm() override = default;
149   MS_DECLARE_PARENT(Int32Imm, IntergerImm)
150   std::size_t hash() const override { return hash_; }
151   bool IsZero() override { return v_ == 0; }
152   bool IsOne() override { return v_ == 1; }
153   int32_t value() const { return v_; }
154   bool operator==(const Value &other) const override;
155   bool operator==(const Int32Imm &other) const;
156   std::string ToString() const override { return std::to_string(v_); }
157 
158   std::string DumpText() const override {
159     std::ostringstream oss;
160     oss << "I32(" << int(v_) << ")";
161     return oss.str();
162   }
163 
164  private:
165   int32_t v_;
166 };
167 using Int32ImmPtr = std::shared_ptr<Int32Imm>;
IMM_TRAITS(Int32ImmPtr,int32_t)168 IMM_TRAITS(Int32ImmPtr, int32_t)
169 
170 class MS_CORE_API Int64Imm : public IntergerImm {
171  public:
172   Int64Imm() : IntergerImm(kInt64), v_(0) {}
173   explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
174   ~Int64Imm() override = default;
175   MS_DECLARE_PARENT(Int64Imm, IntergerImm)
176   std::size_t hash() const override { return hash_; }
177   bool IsZero() override { return v_ == 0; }
178   bool IsOne() override { return v_ == 1; }
179   int64_t value() const { return v_; }
180   bool operator==(const Value &other) const override;
181   bool operator==(const Int64Imm &other) const;
182   std::string ToString() const override { return std::to_string(v_); }
183 
184   std::string DumpText() const override {
185     std::ostringstream oss;
186     oss << "I64(" << v_ << ")";
187     return oss.str();
188   }
189 
190  private:
191   int64_t v_;
192 };
193 using Int64ImmPtr = std::shared_ptr<Int64Imm>;
IMM_TRAITS(Int64ImmPtr,int64_t)194 IMM_TRAITS(Int64ImmPtr, int64_t)
195 
196 class MS_CORE_API UInt8Imm : public IntergerImm {
197  public:
198   UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
199   explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
200     hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
201   }
202   ~UInt8Imm() override = default;
203   MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
204   std::size_t hash() const override { return hash_; }
205   bool IsZero() override { return v_ == 0; }
206   bool IsOne() override { return v_ == 1; }
207   uint8_t value() const { return v_; }
208   bool operator==(const Value &other) const override;
209   bool operator==(const UInt8Imm &other) const;
210   std::string ToString() const override { return std::to_string(v_); }
211 
212   std::string DumpText() const override {
213     std::ostringstream oss;
214     oss << "U8(" << unsigned(v_) << ")";
215     return oss.str();
216   }
217 
218  private:
219   uint8_t v_;
220 };
221 using UInt8ImmPtr = std::shared_ptr<UInt8Imm>;
222 IMM_TRAITS(UInt8ImmPtr, uint8_t);
223 
224 class MS_CORE_API UInt16Imm : public IntergerImm {
225  public:
UInt16Imm()226   UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
UInt16Imm(uint16_t v)227   explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
228     hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
229   }
230   ~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm,IntergerImm)231   MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
232   std::size_t hash() const override { return hash_; }
IsZero()233   bool IsZero() override { return v_ == 0; }
IsOne()234   bool IsOne() override { return v_ == 1; }
value()235   uint16_t value() const { return v_; }
236   bool operator==(const Value &other) const override;
237   bool operator==(const UInt16Imm &other) const;
ToString()238   std::string ToString() const override { return std::to_string(v_); }
239 
DumpText()240   std::string DumpText() const override {
241     std::ostringstream oss;
242     oss << "U16(" << unsigned(v_) << ")";
243     return oss.str();
244   }
245 
246  private:
247   uint16_t v_;
248 };
249 using UInt16ImmPtr = std::shared_ptr<UInt16Imm>;
250 IMM_TRAITS(UInt16ImmPtr, uint16_t);
251 
252 class MS_CORE_API UInt32Imm : public IntergerImm {
253  public:
UInt32Imm()254   UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
UInt32Imm(uint32_t v)255   explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
256     hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
257   }
258   ~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm,IntergerImm)259   MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
260   std::size_t hash() const override { return hash_; }
IsZero()261   bool IsZero() override { return v_ == 0; }
IsOne()262   bool IsOne() override { return v_ == 1; }
value()263   uint32_t value() const { return v_; }
264   bool operator==(const Value &other) const override;
265   bool operator==(const UInt32Imm &other) const;
ToString()266   std::string ToString() const override { return std::to_string(v_); }
267 
DumpText()268   std::string DumpText() const override {
269     std::ostringstream oss;
270     oss << "U32(" << unsigned(v_) << ")";
271     return oss.str();
272   }
273 
274  private:
275   uint32_t v_;
276 };
277 using UInt32ImmPtr = std::shared_ptr<UInt32Imm>;
278 IMM_TRAITS(UInt32ImmPtr, uint32_t);
279 
280 class MS_CORE_API UInt64Imm : public IntergerImm {
281  public:
UInt64Imm()282   UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
UInt64Imm(uint64_t v)283   explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
284     hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
285   }
286   ~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm,IntergerImm)287   MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
288   std::size_t hash() const override { return hash_; }
IsZero()289   bool IsZero() override { return v_ == 0; }
IsOne()290   bool IsOne() override { return v_ == 1; }
value()291   uint64_t value() const { return v_; }
292   bool operator==(const Value &other) const override;
293   bool operator==(const UInt64Imm &other) const;
ToString()294   std::string ToString() const override { return std::to_string(v_); }
295 
DumpText()296   std::string DumpText() const override {
297     std::ostringstream oss;
298     oss << "U64(" << v_ << ")";
299     return oss.str();
300   }
301 
302  private:
303   uint64_t v_;
304 };
305 using UInt64ImmPtr = std::shared_ptr<UInt64Imm>;
306 IMM_TRAITS(UInt64ImmPtr, uint64_t);
307 
308 class MS_CORE_API FloatImm : public Scalar {
309  public:
310   FloatImm() = default;
FloatImm(const TypePtr & t)311   explicit FloatImm(const TypePtr &t) : Scalar(t) {}
312   ~FloatImm() override = default;
313   MS_DECLARE_PARENT(FloatImm, Scalar)
314 };
315 using FloatImmPtr = std::shared_ptr<FloatImm>;
316 
317 class MS_CORE_API FP32Imm : public FloatImm {
318  public:
FP32Imm()319   FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
FP32Imm(float v)320   explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
321   ~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm,FloatImm)322   MS_DECLARE_PARENT(FP32Imm, FloatImm)
323   std::size_t hash() const override { return hash_; }
IsZero()324   bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
IsOne()325   bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
value()326   float value() const { return v_; }
327   bool operator==(const Value &other) const override;
328   bool operator==(const FP32Imm &other) const;
ToString()329   std::string ToString() const override { return std::to_string(v_); }
330 
DumpText()331   std::string DumpText() const override {
332     std::ostringstream oss;
333     oss << "F32(" << v_ << ")";
334     return oss.str();
335   }
336 
337  private:
338   float v_;
339 };
340 using FP32ImmPtr = std::shared_ptr<FP32Imm>;
IMM_TRAITS(FP32ImmPtr,float)341 IMM_TRAITS(FP32ImmPtr, float)
342 
343 class MS_CORE_API FP64Imm : public FloatImm {
344  public:
345   FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
346   explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
347   ~FP64Imm() override = default;
348   MS_DECLARE_PARENT(FP64Imm, FloatImm)
349   std::size_t hash() const override { return hash_; }
350   bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
351   bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
352   double value() const { return v_; }
353   bool operator==(const Value &other) const override;
354   bool operator==(const FP64Imm &other) const;
355   std::string ToString() const override { return std::to_string(v_); }
356 
357   std::string DumpText() const override {
358     std::ostringstream oss;
359     oss << "F64(" << v_ << ")";
360     return oss.str();
361   }
362 
363  private:
364   double v_;
365 };
366 using FP64ImmPtr = std::shared_ptr<FP64Imm>;
367 IMM_TRAITS(FP64ImmPtr, double)
368 
369 }  // namespace mindspore
370 
371 #endif  // MINDSPORE_CORE_IR_SCALAR_H_
372