1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_IR_SCALAR_H_
18 #define MINDSPORE_CORE_IR_SCALAR_H_
19
20 #include <type_traits>
21 #include <algorithm>
22 #include <cmath>
23 #include <vector>
24 #include <string>
25 #include <memory>
26 #include <sstream>
27 #include <utility>
28 #include <cfloat>
29
30 #include "base/base.h"
31 #include "ir/dtype.h"
32 #include "ir/dtype/number.h"
33 #include "utils/hashing.h"
34
35 using std::fabs;
36
37 namespace mindspore {
38 class MS_CORE_API Scalar : public Value {
39 public:
40 Scalar() = default;
Scalar(const TypePtr t)41 explicit Scalar(const TypePtr t) : Value(t) {}
42 ~Scalar() override = default;
43 MS_DECLARE_PARENT(Scalar, Value)
44 virtual bool IsZero() = 0;
45 virtual bool IsOne() = 0;
46 abstract::AbstractBasePtr ToAbstract() override;
47
48 protected:
49 std::size_t hash_ = 0;
50 };
51 using ScalarPtr = std::shared_ptr<Scalar>;
52
53 class MS_CORE_API BoolImm : public Scalar {
54 public:
BoolImm(bool b)55 explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
56 ~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm,Scalar)57 MS_DECLARE_PARENT(BoolImm, Scalar)
58 std::size_t hash() const override { return hash_; }
value()59 bool value() const { return v_; }
IsZero()60 bool IsZero() override { return v_ == false; }
IsOne()61 bool IsOne() override { return v_ == true; }
62 bool operator==(const Value &other) const override;
63 bool operator==(const BoolImm &other) const;
ToString()64 std::string ToString() const override {
65 if (v_) {
66 return "true";
67 } else {
68 return "false";
69 }
70 }
71
DumpText()72 std::string DumpText() const override {
73 std::ostringstream oss;
74 oss << "Bool(" << v_ << ")";
75 return oss.str();
76 }
77
78 private:
79 bool v_;
80 };
81 using BoolImmPtr = std::shared_ptr<BoolImm>;
IMM_TRAITS(BoolImmPtr,bool)82 IMM_TRAITS(BoolImmPtr, bool)
83
84 class MS_CORE_API IntergerImm : public Scalar {
85 public:
86 IntergerImm() = default;
87 explicit IntergerImm(const TypePtr &t) : Scalar(t) {}
88 ~IntergerImm() override = default;
89 MS_DECLARE_PARENT(IntergerImm, Scalar)
90 };
91
92 class MS_CORE_API Int8Imm : public IntergerImm {
93 public:
Int8Imm()94 Int8Imm() : IntergerImm(kInt8), v_(0) {}
Int8Imm(int8_t v)95 explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
96 ~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm,IntergerImm)97 MS_DECLARE_PARENT(Int8Imm, IntergerImm)
98 std::size_t hash() const override { return hash_; }
IsZero()99 bool IsZero() override { return v_ == 0; }
IsOne()100 bool IsOne() override { return v_ == 1; }
value()101 int8_t value() const { return v_; }
102 bool operator==(const Value &other) const override;
103 bool operator==(const Int8Imm &other) const;
ToString()104 std::string ToString() const override { return std::to_string(v_); }
105
DumpText()106 std::string DumpText() const override {
107 std::ostringstream oss;
108 oss << "I8(" << int(v_) << ")";
109 return oss.str();
110 }
111
112 private:
113 int8_t v_;
114 };
115 using Int8ImmPtr = std::shared_ptr<Int8Imm>;
IMM_TRAITS(Int8ImmPtr,int8_t)116 IMM_TRAITS(Int8ImmPtr, int8_t)
117
118 class MS_CORE_API Int16Imm : public IntergerImm {
119 public:
120 Int16Imm() : IntergerImm(kInt16), v_(0) {}
121 explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
122 ~Int16Imm() override = default;
123 MS_DECLARE_PARENT(Int16Imm, IntergerImm)
124 std::size_t hash() const override { return hash_; }
125 bool IsZero() override { return v_ == 0; }
126 bool IsOne() override { return v_ == 1; }
127 int16_t value() const { return v_; }
128 bool operator==(const Value &other) const override;
129 bool operator==(const Int16Imm &other) const;
130 std::string ToString() const override { return std::to_string(v_); }
131
132 std::string DumpText() const override {
133 std::ostringstream oss;
134 oss << "I16(" << int(v_) << ")";
135 return oss.str();
136 }
137
138 private:
139 int16_t v_;
140 };
141 using Int16ImmPtr = std::shared_ptr<Int16Imm>;
IMM_TRAITS(Int16ImmPtr,int16_t)142 IMM_TRAITS(Int16ImmPtr, int16_t)
143
144 class MS_CORE_API Int32Imm : public IntergerImm {
145 public:
146 Int32Imm() : IntergerImm(kInt32), v_(0) {}
147 explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
148 ~Int32Imm() override = default;
149 MS_DECLARE_PARENT(Int32Imm, IntergerImm)
150 std::size_t hash() const override { return hash_; }
151 bool IsZero() override { return v_ == 0; }
152 bool IsOne() override { return v_ == 1; }
153 int32_t value() const { return v_; }
154 bool operator==(const Value &other) const override;
155 bool operator==(const Int32Imm &other) const;
156 std::string ToString() const override { return std::to_string(v_); }
157
158 std::string DumpText() const override {
159 std::ostringstream oss;
160 oss << "I32(" << int(v_) << ")";
161 return oss.str();
162 }
163
164 private:
165 int32_t v_;
166 };
167 using Int32ImmPtr = std::shared_ptr<Int32Imm>;
IMM_TRAITS(Int32ImmPtr,int32_t)168 IMM_TRAITS(Int32ImmPtr, int32_t)
169
170 class MS_CORE_API Int64Imm : public IntergerImm {
171 public:
172 Int64Imm() : IntergerImm(kInt64), v_(0) {}
173 explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
174 ~Int64Imm() override = default;
175 MS_DECLARE_PARENT(Int64Imm, IntergerImm)
176 std::size_t hash() const override { return hash_; }
177 bool IsZero() override { return v_ == 0; }
178 bool IsOne() override { return v_ == 1; }
179 int64_t value() const { return v_; }
180 bool operator==(const Value &other) const override;
181 bool operator==(const Int64Imm &other) const;
182 std::string ToString() const override { return std::to_string(v_); }
183
184 std::string DumpText() const override {
185 std::ostringstream oss;
186 oss << "I64(" << v_ << ")";
187 return oss.str();
188 }
189
190 private:
191 int64_t v_;
192 };
193 using Int64ImmPtr = std::shared_ptr<Int64Imm>;
IMM_TRAITS(Int64ImmPtr,int64_t)194 IMM_TRAITS(Int64ImmPtr, int64_t)
195
196 class MS_CORE_API UInt8Imm : public IntergerImm {
197 public:
198 UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
199 explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
200 hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
201 }
202 ~UInt8Imm() override = default;
203 MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
204 std::size_t hash() const override { return hash_; }
205 bool IsZero() override { return v_ == 0; }
206 bool IsOne() override { return v_ == 1; }
207 uint8_t value() const { return v_; }
208 bool operator==(const Value &other) const override;
209 bool operator==(const UInt8Imm &other) const;
210 std::string ToString() const override { return std::to_string(v_); }
211
212 std::string DumpText() const override {
213 std::ostringstream oss;
214 oss << "U8(" << unsigned(v_) << ")";
215 return oss.str();
216 }
217
218 private:
219 uint8_t v_;
220 };
221 using UInt8ImmPtr = std::shared_ptr<UInt8Imm>;
222 IMM_TRAITS(UInt8ImmPtr, uint8_t);
223
224 class MS_CORE_API UInt16Imm : public IntergerImm {
225 public:
UInt16Imm()226 UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
UInt16Imm(uint16_t v)227 explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
228 hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
229 }
230 ~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm,IntergerImm)231 MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
232 std::size_t hash() const override { return hash_; }
IsZero()233 bool IsZero() override { return v_ == 0; }
IsOne()234 bool IsOne() override { return v_ == 1; }
value()235 uint16_t value() const { return v_; }
236 bool operator==(const Value &other) const override;
237 bool operator==(const UInt16Imm &other) const;
ToString()238 std::string ToString() const override { return std::to_string(v_); }
239
DumpText()240 std::string DumpText() const override {
241 std::ostringstream oss;
242 oss << "U16(" << unsigned(v_) << ")";
243 return oss.str();
244 }
245
246 private:
247 uint16_t v_;
248 };
249 using UInt16ImmPtr = std::shared_ptr<UInt16Imm>;
250 IMM_TRAITS(UInt16ImmPtr, uint16_t);
251
252 class MS_CORE_API UInt32Imm : public IntergerImm {
253 public:
UInt32Imm()254 UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
UInt32Imm(uint32_t v)255 explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
256 hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
257 }
258 ~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm,IntergerImm)259 MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
260 std::size_t hash() const override { return hash_; }
IsZero()261 bool IsZero() override { return v_ == 0; }
IsOne()262 bool IsOne() override { return v_ == 1; }
value()263 uint32_t value() const { return v_; }
264 bool operator==(const Value &other) const override;
265 bool operator==(const UInt32Imm &other) const;
ToString()266 std::string ToString() const override { return std::to_string(v_); }
267
DumpText()268 std::string DumpText() const override {
269 std::ostringstream oss;
270 oss << "U32(" << unsigned(v_) << ")";
271 return oss.str();
272 }
273
274 private:
275 uint32_t v_;
276 };
277 using UInt32ImmPtr = std::shared_ptr<UInt32Imm>;
278 IMM_TRAITS(UInt32ImmPtr, uint32_t);
279
280 class MS_CORE_API UInt64Imm : public IntergerImm {
281 public:
UInt64Imm()282 UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
UInt64Imm(uint64_t v)283 explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
284 hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
285 }
286 ~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm,IntergerImm)287 MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
288 std::size_t hash() const override { return hash_; }
IsZero()289 bool IsZero() override { return v_ == 0; }
IsOne()290 bool IsOne() override { return v_ == 1; }
value()291 uint64_t value() const { return v_; }
292 bool operator==(const Value &other) const override;
293 bool operator==(const UInt64Imm &other) const;
ToString()294 std::string ToString() const override { return std::to_string(v_); }
295
DumpText()296 std::string DumpText() const override {
297 std::ostringstream oss;
298 oss << "U64(" << v_ << ")";
299 return oss.str();
300 }
301
302 private:
303 uint64_t v_;
304 };
305 using UInt64ImmPtr = std::shared_ptr<UInt64Imm>;
306 IMM_TRAITS(UInt64ImmPtr, uint64_t);
307
308 class MS_CORE_API FloatImm : public Scalar {
309 public:
310 FloatImm() = default;
FloatImm(const TypePtr & t)311 explicit FloatImm(const TypePtr &t) : Scalar(t) {}
312 ~FloatImm() override = default;
313 MS_DECLARE_PARENT(FloatImm, Scalar)
314 };
315 using FloatImmPtr = std::shared_ptr<FloatImm>;
316
317 class MS_CORE_API FP32Imm : public FloatImm {
318 public:
FP32Imm()319 FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
FP32Imm(float v)320 explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
321 ~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm,FloatImm)322 MS_DECLARE_PARENT(FP32Imm, FloatImm)
323 std::size_t hash() const override { return hash_; }
IsZero()324 bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
IsOne()325 bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
value()326 float value() const { return v_; }
327 bool operator==(const Value &other) const override;
328 bool operator==(const FP32Imm &other) const;
ToString()329 std::string ToString() const override { return std::to_string(v_); }
330
DumpText()331 std::string DumpText() const override {
332 std::ostringstream oss;
333 oss << "F32(" << v_ << ")";
334 return oss.str();
335 }
336
337 private:
338 float v_;
339 };
340 using FP32ImmPtr = std::shared_ptr<FP32Imm>;
IMM_TRAITS(FP32ImmPtr,float)341 IMM_TRAITS(FP32ImmPtr, float)
342
343 class MS_CORE_API FP64Imm : public FloatImm {
344 public:
345 FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
346 explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
347 ~FP64Imm() override = default;
348 MS_DECLARE_PARENT(FP64Imm, FloatImm)
349 std::size_t hash() const override { return hash_; }
350 bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
351 bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
352 double value() const { return v_; }
353 bool operator==(const Value &other) const override;
354 bool operator==(const FP64Imm &other) const;
355 std::string ToString() const override { return std::to_string(v_); }
356
357 std::string DumpText() const override {
358 std::ostringstream oss;
359 oss << "F64(" << v_ << ")";
360 return oss.str();
361 }
362
363 private:
364 double v_;
365 };
366 using FP64ImmPtr = std::shared_ptr<FP64Imm>;
367 IMM_TRAITS(FP64ImmPtr, double)
368
369 } // namespace mindspore
370
371 #endif // MINDSPORE_CORE_IR_SCALAR_H_
372