1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_IR_SCALAR_H_
18 #define MINDSPORE_CORE_IR_SCALAR_H_
19
20 #include <type_traits>
21 #include <algorithm>
22 #include <cmath>
23 #include <vector>
24 #include <string>
25 #include <memory>
26 #include <sstream>
27 #include <utility>
28 #include <cfloat>
29
30 #include "base/base.h"
31 #include "ir/dtype.h"
32 #include "ir/dtype/number.h"
33 #include "utils/hashing.h"
34 #include "base/bfloat16.h"
35
36 using std::fabs;
37
38 namespace mindspore {
/// \brief Render a floating-point (or any streamable) scalar as text.
///
/// The value is written through a default-configured output string stream, so the
/// stream's default formatting (6 significant digits, no forced fixed/scientific) applies.
///
/// \param[in] v The value to render.
/// \return The text produced by streaming v.
template <typename T>
inline std::string scalar_float_to_string(T v) {
  std::ostringstream out;
  out << v;
  return out.str();
}
45
/// \brief Render a scalar as text using std::to_string.
///
/// \param[in] v The value to render; small integer types are promoted by overload resolution.
/// \return The string produced by std::to_string(v).
template <typename T>
inline std::string scalar_to_string(T v) {
  std::string text = std::to_string(v);
  return text;
}
50
51 /// \brief Scalar defines interface for scalar data.
52 class MS_CORE_API Scalar : public Value {
53 public:
54 /// \brief The default constructor for Scalar.
55 Scalar() = default;
56 /// \brief The constructor for Scalar.
57 ///
58 /// \param[in] t The type of scalar.
Scalar(const TypePtr t)59 explicit Scalar(const TypePtr t) : Value(t) {}
60 /// \brief The destructor of Scalar.
61 ~Scalar() override = default;
62 MS_DECLARE_PARENT(Scalar, Value)
63 /// \brief Check whether the value of scalar is zero.
64 ///
65 /// \return Return true if the value of scalar is zero ,else return false.
66 virtual bool IsZero() = 0;
67 /// \brief Check whether the value of scalar is zero.
68 ///
69 /// \return Return true if the value of scalar is zero ,else return false.
70 virtual bool IsOne() = 0;
71 abstract::AbstractBasePtr ToAbstract() override;
72
73 protected:
74 std::size_t hash_ = 0;
75 };
76 using ScalarPtr = std::shared_ptr<Scalar>;
77
78 /// \brief BoolImm defines interface for bool data.
79 class MS_CORE_API BoolImm final : public Scalar {
80 public:
81 /// \brief The constructor of BoolImm.
82 ///
83 /// \param[in] b The value of bool data.
BoolImm(bool b)84 explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
85 /// \brief The destructor of BoolImm.
86 ~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm,Scalar)87 MS_DECLARE_PARENT(BoolImm, Scalar)
88 std::size_t hash() const override { return hash_; }
89 /// \brief Get the value of BoolImm.
90 ///
91 /// \return Return the value of BoolImm.
value()92 bool value() const { return v_; }
IsZero()93 bool IsZero() override { return v_ == false; }
IsOne()94 bool IsOne() override { return v_ == true; }
95 bool operator==(const Value &other) const override;
96 /// \brief Compare two BoolImm objects is equal.
97 ///
98 /// \param[in] other The other BoolImm to be compared with.
99 /// \return Return true if other's value and the value of current object are the same,else return false.
100 bool operator==(const BoolImm &other) const;
ToString()101 std::string ToString() const override {
102 if (v_) {
103 return "true";
104 } else {
105 return "false";
106 }
107 }
108
DumpText()109 std::string DumpText() const override {
110 std::ostringstream oss;
111 oss << "Bool(" << v_ << ")";
112 return oss.str();
113 }
114
115 private:
116 bool v_;
117 };
118 using BoolImmPtr = std::shared_ptr<BoolImm>;
IMM_TRAITS(BoolImmPtr,bool)119 IMM_TRAITS(BoolImmPtr, bool)
120
121 /// \brief IntegerImm defines interface for integer data.
122 class MS_CORE_API IntegerImm : public Scalar {
123 public:
124 /// \brief The default constructor for IntegerImm.
125 IntegerImm() = default;
126 /// \brief The constructor for IntegerImm.
127 ///
128 /// \param[in] t The type of IntegerImm.
129 explicit IntegerImm(const TypePtr &t) : Scalar(t) {}
130 /// \brief The destructor of Scalar.
131 ~IntegerImm() override = default;
132 MS_DECLARE_PARENT(IntegerImm, Scalar)
133 };
134
135 /// \brief Int8Imm defines interface for int8 data.
136 class MS_CORE_API Int8Imm final : public IntegerImm {
137 public:
138 /// \brief The default constructor for Int8Imm.
Int8Imm()139 Int8Imm() : IntegerImm(kInt8), v_(0) {}
140 /// \brief The constructor for Int8Imm.
141 ///
142 /// \param[in] v The value of Int8Imm.
Int8Imm(int8_t v)143 explicit Int8Imm(int8_t v) : IntegerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
144 /// \brief The destructor of Int8Imm.
145 ~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm,IntegerImm)146 MS_DECLARE_PARENT(Int8Imm, IntegerImm)
147 std::size_t hash() const override { return hash_; }
IsZero()148 bool IsZero() override { return v_ == 0; }
IsOne()149 bool IsOne() override { return v_ == 1; }
150 /// \brief Get the value of Int8Imm.
151 ///
152 /// \return Return the value of Int8Imm.
value()153 int8_t value() const { return v_; }
154 bool operator==(const Value &other) const override;
155 /// \brief Compare two Int8Imm objects is equal.
156 ///
157 /// \param[in] other The other Int8Imm to be compared with.
158 /// \return Return true if other's value and the value of current object are the same,else return false.
159 bool operator==(const Int8Imm &other) const;
ToString()160 std::string ToString() const override { return scalar_to_string(v_); }
161
DumpText()162 std::string DumpText() const override {
163 std::ostringstream oss;
164 oss << "I8(" << int(v_) << ")";
165 return oss.str();
166 }
167
168 private:
169 int8_t v_;
170 };
171 using Int8ImmPtr = std::shared_ptr<Int8Imm>;
IMM_TRAITS(Int8ImmPtr,int8_t)172 IMM_TRAITS(Int8ImmPtr, int8_t)
173 /// \brief Int16Imm defines interface for int16 data.
174 class MS_CORE_API Int16Imm final : public IntegerImm {
175 public:
176 /// \brief The default constructor for Int16Imm.
177 Int16Imm() : IntegerImm(kInt16), v_(0) {}
178 /// \brief The constructor for Int16Imm.
179 ///
180 /// \param[in] v The value of Int16Imm.
181 explicit Int16Imm(int16_t v) : IntegerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
182 /// \brief The destructor of Int16Imm.
183 ~Int16Imm() override = default;
184 MS_DECLARE_PARENT(Int16Imm, IntegerImm)
185 std::size_t hash() const override { return hash_; }
186 bool IsZero() override { return v_ == 0; }
187 bool IsOne() override { return v_ == 1; }
188 /// \brief Get the value of Int16Imm.
189 ///
190 /// \return Return the value of Int16Imm.
191 int16_t value() const { return v_; }
192 bool operator==(const Value &other) const override;
193 /// \brief Compare two Int16Imm objects is equal.
194 ///
195 /// \param[in] other The other Int16Imm to be compared with.
196 /// \return Return true if other's value and the value of current object are the same,else return false.
197 bool operator==(const Int16Imm &other) const;
198 std::string ToString() const override { return scalar_to_string(v_); }
199
200 std::string DumpText() const override {
201 std::ostringstream oss;
202 oss << "I16(" << int(v_) << ")";
203 return oss.str();
204 }
205
206 private:
207 int16_t v_;
208 };
209 using Int16ImmPtr = std::shared_ptr<Int16Imm>;
IMM_TRAITS(Int16ImmPtr,int16_t)210 IMM_TRAITS(Int16ImmPtr, int16_t)
211
212 /// \brief Int32Imm defines interface for int32 data.
213 class MS_CORE_API Int32Imm final : public IntegerImm {
214 public:
215 /// \brief The default constructor for Int32Imm.
216 Int32Imm() : IntegerImm(kInt32), v_(0) {}
217 /// \brief The constructor for Int32Imm.
218 ///
219 /// \param[in] v The value of Int32Imm.
220 explicit Int32Imm(int v) : IntegerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
221 /// \brief The destructor of Int32Imm.
222 ~Int32Imm() override = default;
223 MS_DECLARE_PARENT(Int32Imm, IntegerImm)
224 std::size_t hash() const override { return hash_; }
225 bool IsZero() override { return v_ == 0; }
226 bool IsOne() override { return v_ == 1; }
227 /// \brief Get the value of Int32Imm.
228 ///
229 /// \return Return the value of Int32Imm.
230 int32_t value() const { return v_; }
231 bool operator==(const Value &other) const override;
232 /// \brief Compare two Int32Imm objects is equal.
233 ///
234 /// \param[in] other The other Int32Imm to be compared with.
235 /// \return Return true if other's value and the value of current object are the same,else return false.
236 bool operator==(const Int32Imm &other) const;
237 std::string ToString() const override { return scalar_to_string(v_); }
238
239 std::string DumpText() const override {
240 std::ostringstream oss;
241 oss << "I32(" << int(v_) << ")";
242 return oss.str();
243 }
244
245 private:
246 int32_t v_;
247 };
248 using Int32ImmPtr = std::shared_ptr<Int32Imm>;
IMM_TRAITS(Int32ImmPtr,int32_t)249 IMM_TRAITS(Int32ImmPtr, int32_t)
250
251 /// \brief Int64Imm defines interface for int64 data.
252 class MS_CORE_API Int64Imm final : public IntegerImm {
253 public:
254 /// \brief The default constructor for Int64Imm.
255 Int64Imm() : IntegerImm(kInt64), v_(0) {}
256 /// \brief The constructor for Int64Imm.
257 ///
258 /// \param[in] v The value of Int64Imm.
259 explicit Int64Imm(int64_t v) : IntegerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
260 /// \brief The destructor of Int64Imm.
261 ~Int64Imm() override = default;
262 MS_DECLARE_PARENT(Int64Imm, IntegerImm)
263 std::size_t hash() const override { return hash_; }
264 bool IsZero() override { return v_ == 0; }
265 bool IsOne() override { return v_ == 1; }
266 /// \brief Get the value of Int64Imm.
267 ///
268 /// \return Return the value of Int64Imm.
269 int64_t value() const { return v_; }
270 bool operator==(const Value &other) const override;
271 /// \brief Compare two Int64Imm objects is equal.
272 ///
273 /// \param[in] other The other Int64Imm to be compared with.
274 /// \return Return true if other's value and the value of current object are the same,else return false.
275 bool operator==(const Int64Imm &other) const;
276 std::string ToString() const override { return scalar_to_string(v_); }
277
278 std::string DumpText() const override {
279 std::ostringstream oss;
280 oss << "I64(" << v_ << ")";
281 return oss.str();
282 }
283
284 private:
285 int64_t v_;
286 };
287 using Int64ImmPtr = std::shared_ptr<Int64Imm>;
IMM_TRAITS(Int64ImmPtr,int64_t)288 IMM_TRAITS(Int64ImmPtr, int64_t)
289 /// \brief UInt8Imm defines interface for uint8 data.
290 class MS_CORE_API UInt8Imm final : public IntegerImm {
291 public:
292 /// \brief The default constructor for UInt8Imm.
293 UInt8Imm() : IntegerImm(kUInt8), v_(0) {}
294 /// \brief The constructor for UInt8Imm.
295 ///
296 /// \param[in] v The value of UInt8Imm.
297 explicit UInt8Imm(uint8_t v) : IntegerImm(kUInt8), v_(v) {
298 hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
299 }
300 /// \brief The destructor of UInt8Imm.
301 ~UInt8Imm() override = default;
302 MS_DECLARE_PARENT(UInt8Imm, IntegerImm)
303 std::size_t hash() const override { return hash_; }
304 bool IsZero() override { return v_ == 0; }
305 bool IsOne() override { return v_ == 1; }
306 /// \brief Get the value of UInt8Imm.
307 ///
308 /// \return Return the value of UInt8Imm.
309 uint8_t value() const { return v_; }
310 bool operator==(const Value &other) const override;
311 /// \brief Compare two UInt8Imm objects is equal.
312 ///
313 /// \param[in] other The other UInt8Imm to be compared with.
314 /// \return Return true if other's value and the value of current object are the same,else return false.
315 bool operator==(const UInt8Imm &other) const;
316 std::string ToString() const override { return scalar_to_string(v_); }
317
318 std::string DumpText() const override {
319 std::ostringstream oss;
320 oss << "U8(" << unsigned(v_) << ")";
321 return oss.str();
322 }
323
324 private:
325 uint8_t v_;
326 };
327 using UInt8ImmPtr = std::shared_ptr<UInt8Imm>;
328 IMM_TRAITS(UInt8ImmPtr, uint8_t);
329
330 /// \brief UInt16Imm defines interface for uint16 data.
331 class MS_CORE_API UInt16Imm final : public IntegerImm {
332 public:
333 /// \brief The default constructor for UInt16Imm.
UInt16Imm()334 UInt16Imm() : IntegerImm(kUInt16), v_(0) {}
335 /// \brief The constructor for UInt16Imm.
336 ///
337 /// \param[in] v The value of UInt16Imm.
UInt16Imm(uint16_t v)338 explicit UInt16Imm(uint16_t v) : IntegerImm(kUInt16), v_(v) {
339 hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
340 }
341 /// \brief The destructor of UInt16Imm.
342 ~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm,IntegerImm)343 MS_DECLARE_PARENT(UInt16Imm, IntegerImm)
344 std::size_t hash() const override { return hash_; }
IsZero()345 bool IsZero() override { return v_ == 0; }
IsOne()346 bool IsOne() override { return v_ == 1; }
347 /// \brief Get the value of UInt16Imm.
348 ///
349 /// \return Return the value of UInt16Imm.
value()350 uint16_t value() const { return v_; }
351 bool operator==(const Value &other) const override;
352 /// \brief Compare two UInt16Imm objects is equal.
353 ///
354 /// \param[in] other The other UInt16Imm to be compared with.
355 /// \return Return true if other's value and the value of current object are the same,else return false.
356 bool operator==(const UInt16Imm &other) const;
ToString()357 std::string ToString() const override { return scalar_to_string(v_); }
358
DumpText()359 std::string DumpText() const override {
360 std::ostringstream oss;
361 oss << "U16(" << unsigned(v_) << ")";
362 return oss.str();
363 }
364
365 private:
366 uint16_t v_;
367 };
368 using UInt16ImmPtr = std::shared_ptr<UInt16Imm>;
369 IMM_TRAITS(UInt16ImmPtr, uint16_t);
370
371 /// \brief UInt32Imm defines interface for uint32 data.
372 class MS_CORE_API UInt32Imm final : public IntegerImm {
373 public:
374 /// \brief The default constructor for UInt32Imm.
UInt32Imm()375 UInt32Imm() : IntegerImm(kUInt32), v_(0) {}
376 /// \brief The constructor for UInt32Imm.
377 ///
378 /// \param[in] v The value of UInt32Imm.
UInt32Imm(uint32_t v)379 explicit UInt32Imm(uint32_t v) : IntegerImm(kUInt32), v_(v) {
380 hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
381 }
382 /// \brief The destructor of UInt32Imm.
383 ~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm,IntegerImm)384 MS_DECLARE_PARENT(UInt32Imm, IntegerImm)
385 std::size_t hash() const override { return hash_; }
IsZero()386 bool IsZero() override { return v_ == 0; }
IsOne()387 bool IsOne() override { return v_ == 1; }
388 /// \brief Get the value of UInt32Imm.
389 ///
390 /// \return Return the value of UInt32Imm.
value()391 uint32_t value() const { return v_; }
392 bool operator==(const Value &other) const override;
393 /// \brief Compare two UInt32Imm objects is equal.
394 ///
395 /// \param[in] other The other UInt32Imm to be compared with.
396 /// \return Return true if other's value and the value of current object are the same,else return false.
397 bool operator==(const UInt32Imm &other) const;
ToString()398 std::string ToString() const override { return scalar_to_string(v_); }
399
DumpText()400 std::string DumpText() const override {
401 std::ostringstream oss;
402 oss << "U32(" << unsigned(v_) << ")";
403 return oss.str();
404 }
405
406 private:
407 uint32_t v_;
408 };
409 using UInt32ImmPtr = std::shared_ptr<UInt32Imm>;
410 IMM_TRAITS(UInt32ImmPtr, uint32_t);
411 /// \brief UInt64Imm defines interface for uint64 data.
412 class MS_CORE_API UInt64Imm final : public IntegerImm {
413 public:
414 /// \brief The default constructor for UInt64Imm.
UInt64Imm()415 UInt64Imm() : IntegerImm(kUInt64), v_(0) {}
416 /// \brief The constructor for UInt64Imm.
417 ///
418 /// \param[in] v The value of UInt64Imm.
UInt64Imm(uint64_t v)419 explicit UInt64Imm(uint64_t v) : IntegerImm(kUInt64), v_(v) {
420 hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
421 }
422 /// \brief The destructor of UInt64Imm.
423 ~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm,IntegerImm)424 MS_DECLARE_PARENT(UInt64Imm, IntegerImm)
425 std::size_t hash() const override { return hash_; }
IsZero()426 bool IsZero() override { return v_ == 0; }
IsOne()427 bool IsOne() override { return v_ == 1; }
428 /// \brief Get the value of UInt64Imm.
429 ///
430 /// \return Return the value of UInt64Imm.
value()431 uint64_t value() const { return v_; }
432 bool operator==(const Value &other) const override;
433 /// \brief Compare two UInt64Imm objects is equal.
434 ///
435 /// \param[in] other The other UInt64Imm to be compared with.
436 /// \return Return true if other's value and the value of current object are the same,else return false.
437 bool operator==(const UInt64Imm &other) const;
ToString()438 std::string ToString() const override { return scalar_to_string(v_); }
439
DumpText()440 std::string DumpText() const override {
441 std::ostringstream oss;
442 oss << "U64(" << v_ << ")";
443 return oss.str();
444 }
445
446 private:
447 uint64_t v_;
448 };
449 using UInt64ImmPtr = std::shared_ptr<UInt64Imm>;
450 IMM_TRAITS(UInt64ImmPtr, uint64_t);
451
452 #if defined(__APPLE__)
453 using SizetImmPtr = std::shared_ptr<UInt64Imm>;
454 IMM_TRAITS(SizetImmPtr, size_t);
455 #endif
456
457 /// \brief FloatImm defines interface for float data.
458 class MS_CORE_API FloatImm : public Scalar {
459 public:
460 /// \brief The default constructor for FloatImm.
461 FloatImm() = default;
462 /// \brief The constructor for FloatImm.
463 ///
464 /// \param[in] t The value of FloatImm.
FloatImm(const TypePtr & t)465 explicit FloatImm(const TypePtr &t) : Scalar(t) {}
466 /// \brief The destructor of FloatImm.
467 ~FloatImm() override = default;
468 MS_DECLARE_PARENT(FloatImm, Scalar)
469 };
470 using FloatImmPtr = std::shared_ptr<FloatImm>;
471
472 /// \brief FP32Imm defines interface for float32 data.
473 class MS_CORE_API FP32Imm final : public FloatImm {
474 public:
475 /// \brief The default constructor for FP32Imm.
FP32Imm()476 FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
477 /// \brief The constructor for FP32Imm.
478 ///
479 /// \param[in] v The value of FP32Imm.
FP32Imm(float v)480 explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
481 /// \brief The destructor of FP32Imm.
482 ~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm,FloatImm)483 MS_DECLARE_PARENT(FP32Imm, FloatImm)
484 std::size_t hash() const override { return hash_; }
IsZero()485 bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
IsOne()486 bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
487 /// \brief Get the value of FP32Imm.
488 ///
489 /// \return Return the value of FP32Imm.
value()490 float value() const { return v_; }
491 /// \brief Get the double type value of FP32Imm.
492 ///
493 /// \return Return the double type value of FP32Imm.
prim_value()494 double prim_value() const { return prim_v_; }
495 /// \brief Set the double type value of FP32Imm.
496 ///
497 /// \param[prim_v] double type value for FP32IMM.
set_prim_value(double prim_v)498 void set_prim_value(double prim_v) { prim_v_ = prim_v; }
499 bool operator==(const Value &other) const override;
500 /// \brief Compare two FP32Imm objects is equal.
501 ///
502 /// \param[in] other The other FP32Imm to be compared with.
503 /// \return Return true if other's value and the value of current object are the same,else return false.
504 bool operator==(const FP32Imm &other) const;
505
ToString()506 std::string ToString() const override { return scalar_float_to_string(v_); }
507
DumpText()508 std::string DumpText() const override {
509 std::ostringstream oss;
510 oss << "F32(" << v_ << ")";
511 return oss.str();
512 }
513
514 private:
515 float v_;
516 double prim_v_;
517 };
518 using FP32ImmPtr = std::shared_ptr<FP32Imm>;
IMM_TRAITS(FP32ImmPtr,float)519 IMM_TRAITS(FP32ImmPtr, float)
520
521 /// \brief FP64Imm defines interface for float64 data.
522 class MS_CORE_API FP64Imm final : public FloatImm {
523 public:
524 /// \brief The default constructor for FP64Imm.
525 FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
526 /// \brief The constructor for FP64Imm.
527 ///
528 /// \param[in] v The value of FP64Imm.
529 explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
530 /// \brief The destructor of FP64Imm.
531 ~FP64Imm() override = default;
532 MS_DECLARE_PARENT(FP64Imm, FloatImm)
533 std::size_t hash() const override { return hash_; }
534 bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
535 bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
536 /// \brief Get the value of FP64Imm.
537 ///
538 /// \return Return the value of FP64Imm.
539 double value() const { return v_; }
540 bool operator==(const Value &other) const override;
541 /// \brief Compare two FP64Imm objects is equal.
542 ///
543 /// \param[in] other The other FP64Imm to be compared with.
544 /// \return Return true if other's value and the value of current object are the same,else return false.
545 bool operator==(const FP64Imm &other) const;
546 std::string ToString() const override { return scalar_float_to_string(v_); }
547
548 std::string DumpText() const override {
549 std::ostringstream oss;
550 oss << "F64(" << v_ << ")";
551 return oss.str();
552 }
553
554 private:
555 double v_;
556 };
557 using FP64ImmPtr = std::shared_ptr<FP64Imm>;
IMM_TRAITS(FP64ImmPtr,double)558 IMM_TRAITS(FP64ImmPtr, double)
559
560 /// \brief BF16Imm defines interface for bfloat16 data.
561 class MS_CORE_API BF16Imm final : public FloatImm {
562 public:
563 /// \brief The default constructor for BF16Imm.
564 BF16Imm() : FloatImm(kBFloat16), v_(0.0) {}
565 /// \brief The constructor for BF16Imm.
566 ///
567 /// \param[in] v The value of BF16Imm.
568 explicit BF16Imm(bfloat16 v) : FloatImm(kBFloat16), v_(v) {
569 hash_ = hash_combine({tid(), std::hash<bfloat16>{}(v_)});
570 }
571 /// \brief The destructor of BF16Imm.
572 ~BF16Imm() override = default;
573 MS_DECLARE_PARENT(BF16Imm, FloatImm)
574 std::size_t hash() const override { return hash_; }
575 bool IsZero() override { return v_ == BFloat16(0.0); }
576 bool IsOne() override { return v_ == BFloat16(1.0); }
577 /// \brief Get the value of BF16Imm.
578 ///
579 /// \return Return the value of BF16Imm.
580 bfloat16 value() const { return v_; }
581 bool operator==(const Value &other) const override;
582 /// \brief Compare two BF16Imm objects is equal.
583 ///
584 /// \param[in] other The other BF16Imm to be compared with.
585 /// \return Return true if other's value and the value of current object are the same,else return false.
586 bool operator==(const BF16Imm &other) const;
587 std::string ToString() const override { return scalar_float_to_string(v_); }
588
589 std::string DumpText() const override {
590 std::ostringstream oss;
591 oss << "BF16(" << v_ << ")";
592 return oss.str();
593 }
594
595 private:
596 bfloat16 v_;
597 };
598 using BF16ImmPtr = std::shared_ptr<BF16Imm>;
599 IMM_TRAITS(BF16ImmPtr, bfloat16)
600 } // namespace mindspore
601
602 #endif // MINDSPORE_CORE_IR_SCALAR_H_
603