• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_SCALAR_H_
18 #define MINDSPORE_CORE_IR_SCALAR_H_
19 
20 #include <type_traits>
21 #include <algorithm>
22 #include <cmath>
23 #include <vector>
24 #include <string>
25 #include <memory>
26 #include <sstream>
27 #include <utility>
28 #include <cfloat>
29 
30 #include "base/base.h"
31 #include "ir/dtype.h"
32 #include "ir/dtype/number.h"
33 #include "utils/hashing.h"
34 #include "base/bfloat16.h"
35 
36 using std::fabs;
37 
38 namespace mindspore {
/// \brief Format a floating-point scalar using default iostream formatting.
///
/// The stream keeps its default precision (6 significant digits), so 1.0 / 3.0 yields "0.333333" and 3.0 yields "3".
/// \param[in] v The value to format; any type with an std::ostream operator<< is accepted.
/// \return The text written to the stream.
template <typename T>
inline std::string scalar_float_to_string(T v) {
  std::ostringstream out;
  out << v;
  return out.str();
}
44 }
45 
/// \brief Convert an integral scalar to its decimal text via std::to_string.
///
/// \param[in] v The value to format; narrow integer types (e.g. int8_t, uint8_t) are promoted before formatting,
///              so they print as numbers rather than characters.
/// \return The text produced by std::to_string(v).
template <typename T>
inline std::string scalar_to_string(T v) {
  auto text = std::to_string(v);
  return text;
}
50 
51 /// \brief Scalar defines interface for scalar data.
52 class MS_CORE_API Scalar : public Value {
53  public:
54   /// \brief The default constructor for Scalar.
55   Scalar() = default;
56   /// \brief The constructor for Scalar.
57   ///
58   /// \param[in] t The type of scalar.
Scalar(const TypePtr t)59   explicit Scalar(const TypePtr t) : Value(t) {}
60   /// \brief The destructor of Scalar.
61   ~Scalar() override = default;
62   MS_DECLARE_PARENT(Scalar, Value)
63   /// \brief Check whether the value of scalar is zero.
64   ///
65   /// \return Return true if the value of scalar is zero ,else return false.
66   virtual bool IsZero() = 0;
67   /// \brief Check whether the value of scalar is zero.
68   ///
69   /// \return Return true if the value of scalar is zero ,else return false.
70   virtual bool IsOne() = 0;
71   abstract::AbstractBasePtr ToAbstract() override;
72 
73  protected:
74   std::size_t hash_ = 0;
75 };
76 using ScalarPtr = std::shared_ptr<Scalar>;
77 
78 /// \brief BoolImm defines interface for bool data.
79 class MS_CORE_API BoolImm final : public Scalar {
80  public:
81   /// \brief The constructor of BoolImm.
82   ///
83   /// \param[in] b The value of bool data.
BoolImm(bool b)84   explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
85   /// \brief The destructor of BoolImm.
86   ~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm,Scalar)87   MS_DECLARE_PARENT(BoolImm, Scalar)
88   std::size_t hash() const override { return hash_; }
89   /// \brief Get the value of BoolImm.
90   ///
91   /// \return Return the value of BoolImm.
value()92   bool value() const { return v_; }
IsZero()93   bool IsZero() override { return v_ == false; }
IsOne()94   bool IsOne() override { return v_ == true; }
95   bool operator==(const Value &other) const override;
96   /// \brief Compare two BoolImm objects is equal.
97   ///
98   /// \param[in] other The other BoolImm to be compared with.
99   /// \return Return true if other's value and the value of current object are the same,else return false.
100   bool operator==(const BoolImm &other) const;
ToString()101   std::string ToString() const override {
102     if (v_) {
103       return "true";
104     } else {
105       return "false";
106     }
107   }
108 
DumpText()109   std::string DumpText() const override {
110     std::ostringstream oss;
111     oss << "Bool(" << v_ << ")";
112     return oss.str();
113   }
114 
115  private:
116   bool v_;
117 };
118 using BoolImmPtr = std::shared_ptr<BoolImm>;
IMM_TRAITS(BoolImmPtr,bool)119 IMM_TRAITS(BoolImmPtr, bool)
120 
121 /// \brief IntegerImm defines interface for integer data.
122 class MS_CORE_API IntegerImm : public Scalar {
123  public:
124   /// \brief The default constructor for IntegerImm.
125   IntegerImm() = default;
126   /// \brief The constructor for IntegerImm.
127   ///
128   /// \param[in] t The type of IntegerImm.
129   explicit IntegerImm(const TypePtr &t) : Scalar(t) {}
130   /// \brief The destructor of Scalar.
131   ~IntegerImm() override = default;
132   MS_DECLARE_PARENT(IntegerImm, Scalar)
133 };
134 
135 /// \brief Int8Imm defines interface for int8 data.
136 class MS_CORE_API Int8Imm final : public IntegerImm {
137  public:
138   /// \brief The default constructor for Int8Imm.
Int8Imm()139   Int8Imm() : IntegerImm(kInt8), v_(0) {}
140   /// \brief The constructor for Int8Imm.
141   ///
142   /// \param[in] v The value of Int8Imm.
Int8Imm(int8_t v)143   explicit Int8Imm(int8_t v) : IntegerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
144   /// \brief The destructor of Int8Imm.
145   ~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm,IntegerImm)146   MS_DECLARE_PARENT(Int8Imm, IntegerImm)
147   std::size_t hash() const override { return hash_; }
IsZero()148   bool IsZero() override { return v_ == 0; }
IsOne()149   bool IsOne() override { return v_ == 1; }
150   /// \brief Get the value of Int8Imm.
151   ///
152   /// \return Return the value of Int8Imm.
value()153   int8_t value() const { return v_; }
154   bool operator==(const Value &other) const override;
155   /// \brief Compare two Int8Imm objects is equal.
156   ///
157   /// \param[in] other The other Int8Imm to be compared with.
158   /// \return Return true if other's value and the value of current object are the same,else return false.
159   bool operator==(const Int8Imm &other) const;
ToString()160   std::string ToString() const override { return scalar_to_string(v_); }
161 
DumpText()162   std::string DumpText() const override {
163     std::ostringstream oss;
164     oss << "I8(" << int(v_) << ")";
165     return oss.str();
166   }
167 
168  private:
169   int8_t v_;
170 };
171 using Int8ImmPtr = std::shared_ptr<Int8Imm>;
IMM_TRAITS(Int8ImmPtr,int8_t)172 IMM_TRAITS(Int8ImmPtr, int8_t)
173 /// \brief Int16Imm defines interface for int16 data.
174 class MS_CORE_API Int16Imm final : public IntegerImm {
175  public:
176   /// \brief The default constructor for Int16Imm.
177   Int16Imm() : IntegerImm(kInt16), v_(0) {}
178   /// \brief The constructor for Int16Imm.
179   ///
180   /// \param[in] v The value of Int16Imm.
181   explicit Int16Imm(int16_t v) : IntegerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
182   /// \brief The destructor of Int16Imm.
183   ~Int16Imm() override = default;
184   MS_DECLARE_PARENT(Int16Imm, IntegerImm)
185   std::size_t hash() const override { return hash_; }
186   bool IsZero() override { return v_ == 0; }
187   bool IsOne() override { return v_ == 1; }
188   /// \brief Get the value of Int16Imm.
189   ///
190   /// \return Return the value of Int16Imm.
191   int16_t value() const { return v_; }
192   bool operator==(const Value &other) const override;
193   /// \brief Compare two Int16Imm objects is equal.
194   ///
195   /// \param[in] other The other Int16Imm to be compared with.
196   /// \return Return true if other's value and the value of current object are the same,else return false.
197   bool operator==(const Int16Imm &other) const;
198   std::string ToString() const override { return scalar_to_string(v_); }
199 
200   std::string DumpText() const override {
201     std::ostringstream oss;
202     oss << "I16(" << int(v_) << ")";
203     return oss.str();
204   }
205 
206  private:
207   int16_t v_;
208 };
209 using Int16ImmPtr = std::shared_ptr<Int16Imm>;
IMM_TRAITS(Int16ImmPtr,int16_t)210 IMM_TRAITS(Int16ImmPtr, int16_t)
211 
212 /// \brief Int32Imm defines interface for int32 data.
213 class MS_CORE_API Int32Imm final : public IntegerImm {
214  public:
215   /// \brief The default constructor for Int32Imm.
216   Int32Imm() : IntegerImm(kInt32), v_(0) {}
217   /// \brief The constructor for Int32Imm.
218   ///
219   /// \param[in] v The value of Int32Imm.
220   explicit Int32Imm(int v) : IntegerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
221   /// \brief The destructor of Int32Imm.
222   ~Int32Imm() override = default;
223   MS_DECLARE_PARENT(Int32Imm, IntegerImm)
224   std::size_t hash() const override { return hash_; }
225   bool IsZero() override { return v_ == 0; }
226   bool IsOne() override { return v_ == 1; }
227   /// \brief Get the value of Int32Imm.
228   ///
229   /// \return Return the value of Int32Imm.
230   int32_t value() const { return v_; }
231   bool operator==(const Value &other) const override;
232   /// \brief Compare two Int32Imm objects is equal.
233   ///
234   /// \param[in] other The other Int32Imm to be compared with.
235   /// \return Return true if other's value and the value of current object are the same,else return false.
236   bool operator==(const Int32Imm &other) const;
237   std::string ToString() const override { return scalar_to_string(v_); }
238 
239   std::string DumpText() const override {
240     std::ostringstream oss;
241     oss << "I32(" << int(v_) << ")";
242     return oss.str();
243   }
244 
245  private:
246   int32_t v_;
247 };
248 using Int32ImmPtr = std::shared_ptr<Int32Imm>;
IMM_TRAITS(Int32ImmPtr,int32_t)249 IMM_TRAITS(Int32ImmPtr, int32_t)
250 
251 /// \brief Int64Imm defines interface for int64 data.
252 class MS_CORE_API Int64Imm final : public IntegerImm {
253  public:
254   /// \brief The default constructor for Int64Imm.
255   Int64Imm() : IntegerImm(kInt64), v_(0) {}
256   /// \brief The constructor for Int64Imm.
257   ///
258   /// \param[in] v The value of Int64Imm.
259   explicit Int64Imm(int64_t v) : IntegerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
260   /// \brief The destructor of Int64Imm.
261   ~Int64Imm() override = default;
262   MS_DECLARE_PARENT(Int64Imm, IntegerImm)
263   std::size_t hash() const override { return hash_; }
264   bool IsZero() override { return v_ == 0; }
265   bool IsOne() override { return v_ == 1; }
266   /// \brief Get the value of Int64Imm.
267   ///
268   /// \return Return the value of Int64Imm.
269   int64_t value() const { return v_; }
270   bool operator==(const Value &other) const override;
271   /// \brief Compare two Int64Imm objects is equal.
272   ///
273   /// \param[in] other The other Int64Imm to be compared with.
274   /// \return Return true if other's value and the value of current object are the same,else return false.
275   bool operator==(const Int64Imm &other) const;
276   std::string ToString() const override { return scalar_to_string(v_); }
277 
278   std::string DumpText() const override {
279     std::ostringstream oss;
280     oss << "I64(" << v_ << ")";
281     return oss.str();
282   }
283 
284  private:
285   int64_t v_;
286 };
287 using Int64ImmPtr = std::shared_ptr<Int64Imm>;
IMM_TRAITS(Int64ImmPtr,int64_t)288 IMM_TRAITS(Int64ImmPtr, int64_t)
289 /// \brief UInt8Imm defines interface for uint8 data.
290 class MS_CORE_API UInt8Imm final : public IntegerImm {
291  public:
292   /// \brief The default constructor for UInt8Imm.
293   UInt8Imm() : IntegerImm(kUInt8), v_(0) {}
294   /// \brief The constructor for UInt8Imm.
295   ///
296   /// \param[in] v The value of UInt8Imm.
297   explicit UInt8Imm(uint8_t v) : IntegerImm(kUInt8), v_(v) {
298     hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
299   }
300   /// \brief The destructor of UInt8Imm.
301   ~UInt8Imm() override = default;
302   MS_DECLARE_PARENT(UInt8Imm, IntegerImm)
303   std::size_t hash() const override { return hash_; }
304   bool IsZero() override { return v_ == 0; }
305   bool IsOne() override { return v_ == 1; }
306   /// \brief Get the value of UInt8Imm.
307   ///
308   /// \return Return the value of UInt8Imm.
309   uint8_t value() const { return v_; }
310   bool operator==(const Value &other) const override;
311   /// \brief Compare two UInt8Imm objects is equal.
312   ///
313   /// \param[in] other The other UInt8Imm to be compared with.
314   /// \return Return true if other's value and the value of current object are the same,else return false.
315   bool operator==(const UInt8Imm &other) const;
316   std::string ToString() const override { return scalar_to_string(v_); }
317 
318   std::string DumpText() const override {
319     std::ostringstream oss;
320     oss << "U8(" << unsigned(v_) << ")";
321     return oss.str();
322   }
323 
324  private:
325   uint8_t v_;
326 };
327 using UInt8ImmPtr = std::shared_ptr<UInt8Imm>;
328 IMM_TRAITS(UInt8ImmPtr, uint8_t);
329 
330 /// \brief UInt16Imm defines interface for uint16 data.
331 class MS_CORE_API UInt16Imm final : public IntegerImm {
332  public:
333   /// \brief The default constructor for UInt16Imm.
UInt16Imm()334   UInt16Imm() : IntegerImm(kUInt16), v_(0) {}
335   /// \brief The constructor for UInt16Imm.
336   ///
337   /// \param[in] v The value of UInt16Imm.
UInt16Imm(uint16_t v)338   explicit UInt16Imm(uint16_t v) : IntegerImm(kUInt16), v_(v) {
339     hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
340   }
341   /// \brief The destructor of UInt16Imm.
342   ~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm,IntegerImm)343   MS_DECLARE_PARENT(UInt16Imm, IntegerImm)
344   std::size_t hash() const override { return hash_; }
IsZero()345   bool IsZero() override { return v_ == 0; }
IsOne()346   bool IsOne() override { return v_ == 1; }
347   /// \brief Get the value of UInt16Imm.
348   ///
349   /// \return Return the value of UInt16Imm.
value()350   uint16_t value() const { return v_; }
351   bool operator==(const Value &other) const override;
352   /// \brief Compare two UInt16Imm objects is equal.
353   ///
354   /// \param[in] other The other UInt16Imm to be compared with.
355   /// \return Return true if other's value and the value of current object are the same,else return false.
356   bool operator==(const UInt16Imm &other) const;
ToString()357   std::string ToString() const override { return scalar_to_string(v_); }
358 
DumpText()359   std::string DumpText() const override {
360     std::ostringstream oss;
361     oss << "U16(" << unsigned(v_) << ")";
362     return oss.str();
363   }
364 
365  private:
366   uint16_t v_;
367 };
368 using UInt16ImmPtr = std::shared_ptr<UInt16Imm>;
369 IMM_TRAITS(UInt16ImmPtr, uint16_t);
370 
371 /// \brief UInt32Imm defines interface for uint32 data.
372 class MS_CORE_API UInt32Imm final : public IntegerImm {
373  public:
374   /// \brief The default constructor for UInt32Imm.
UInt32Imm()375   UInt32Imm() : IntegerImm(kUInt32), v_(0) {}
376   /// \brief The constructor for UInt32Imm.
377   ///
378   /// \param[in] v The value of UInt32Imm.
UInt32Imm(uint32_t v)379   explicit UInt32Imm(uint32_t v) : IntegerImm(kUInt32), v_(v) {
380     hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
381   }
382   /// \brief The destructor of UInt32Imm.
383   ~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm,IntegerImm)384   MS_DECLARE_PARENT(UInt32Imm, IntegerImm)
385   std::size_t hash() const override { return hash_; }
IsZero()386   bool IsZero() override { return v_ == 0; }
IsOne()387   bool IsOne() override { return v_ == 1; }
388   /// \brief Get the value of UInt32Imm.
389   ///
390   /// \return Return the value of UInt32Imm.
value()391   uint32_t value() const { return v_; }
392   bool operator==(const Value &other) const override;
393   /// \brief Compare two UInt32Imm objects is equal.
394   ///
395   /// \param[in] other The other UInt32Imm to be compared with.
396   /// \return Return true if other's value and the value of current object are the same,else return false.
397   bool operator==(const UInt32Imm &other) const;
ToString()398   std::string ToString() const override { return scalar_to_string(v_); }
399 
DumpText()400   std::string DumpText() const override {
401     std::ostringstream oss;
402     oss << "U32(" << unsigned(v_) << ")";
403     return oss.str();
404   }
405 
406  private:
407   uint32_t v_;
408 };
409 using UInt32ImmPtr = std::shared_ptr<UInt32Imm>;
410 IMM_TRAITS(UInt32ImmPtr, uint32_t);
411 /// \brief UInt64Imm defines interface for uint64 data.
412 class MS_CORE_API UInt64Imm final : public IntegerImm {
413  public:
414   /// \brief The default constructor for UInt64Imm.
UInt64Imm()415   UInt64Imm() : IntegerImm(kUInt64), v_(0) {}
416   /// \brief The constructor for UInt64Imm.
417   ///
418   /// \param[in] v The value of UInt64Imm.
UInt64Imm(uint64_t v)419   explicit UInt64Imm(uint64_t v) : IntegerImm(kUInt64), v_(v) {
420     hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
421   }
422   /// \brief The destructor of UInt64Imm.
423   ~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm,IntegerImm)424   MS_DECLARE_PARENT(UInt64Imm, IntegerImm)
425   std::size_t hash() const override { return hash_; }
IsZero()426   bool IsZero() override { return v_ == 0; }
IsOne()427   bool IsOne() override { return v_ == 1; }
428   /// \brief Get the value of UInt64Imm.
429   ///
430   /// \return Return the value of UInt64Imm.
value()431   uint64_t value() const { return v_; }
432   bool operator==(const Value &other) const override;
433   /// \brief Compare two UInt64Imm objects is equal.
434   ///
435   /// \param[in] other The other UInt64Imm to be compared with.
436   /// \return Return true if other's value and the value of current object are the same,else return false.
437   bool operator==(const UInt64Imm &other) const;
ToString()438   std::string ToString() const override { return scalar_to_string(v_); }
439 
DumpText()440   std::string DumpText() const override {
441     std::ostringstream oss;
442     oss << "U64(" << v_ << ")";
443     return oss.str();
444   }
445 
446  private:
447   uint64_t v_;
448 };
449 using UInt64ImmPtr = std::shared_ptr<UInt64Imm>;
450 IMM_TRAITS(UInt64ImmPtr, uint64_t);
451 
#if defined(__APPLE__)
// On Apple builds, size_t gets its own IMM_TRAITS entry backed by UInt64Imm.
// NOTE(review): presumably because size_t and uint64_t are distinct types on that platform — confirm.
using SizetImmPtr = std::shared_ptr<UInt64Imm>;
IMM_TRAITS(SizetImmPtr, size_t)
#endif
456 
457 /// \brief FloatImm defines interface for float data.
458 class MS_CORE_API FloatImm : public Scalar {
459  public:
460   /// \brief The default constructor for FloatImm.
461   FloatImm() = default;
462   /// \brief The constructor for FloatImm.
463   ///
464   /// \param[in] t The value of FloatImm.
FloatImm(const TypePtr & t)465   explicit FloatImm(const TypePtr &t) : Scalar(t) {}
466   /// \brief The destructor of FloatImm.
467   ~FloatImm() override = default;
468   MS_DECLARE_PARENT(FloatImm, Scalar)
469 };
470 using FloatImmPtr = std::shared_ptr<FloatImm>;
471 
472 /// \brief FP32Imm defines interface for float32 data.
473 class MS_CORE_API FP32Imm final : public FloatImm {
474  public:
475   /// \brief The default constructor for FP32Imm.
FP32Imm()476   FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
477   /// \brief The constructor for FP32Imm.
478   ///
479   /// \param[in] v The value of FP32Imm.
FP32Imm(float v)480   explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
481   /// \brief The destructor of FP32Imm.
482   ~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm,FloatImm)483   MS_DECLARE_PARENT(FP32Imm, FloatImm)
484   std::size_t hash() const override { return hash_; }
IsZero()485   bool IsZero() override { return fabs(v_) <= FLT_EPSILON; }
IsOne()486   bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; }
487   /// \brief Get the value of FP32Imm.
488   ///
489   /// \return Return the value of FP32Imm.
value()490   float value() const { return v_; }
491   /// \brief Get the double type value of FP32Imm.
492   ///
493   /// \return Return the double type value of FP32Imm.
prim_value()494   double prim_value() const { return prim_v_; }
495   /// \brief Set the double type value of FP32Imm.
496   ///
497   /// \param[prim_v] double type value for FP32IMM.
set_prim_value(double prim_v)498   void set_prim_value(double prim_v) { prim_v_ = prim_v; }
499   bool operator==(const Value &other) const override;
500   /// \brief Compare two FP32Imm objects is equal.
501   ///
502   /// \param[in] other The other FP32Imm to be compared with.
503   /// \return Return true if other's value and the value of current object are the same,else return false.
504   bool operator==(const FP32Imm &other) const;
505 
ToString()506   std::string ToString() const override { return scalar_float_to_string(v_); }
507 
DumpText()508   std::string DumpText() const override {
509     std::ostringstream oss;
510     oss << "F32(" << v_ << ")";
511     return oss.str();
512   }
513 
514  private:
515   float v_;
516   double prim_v_;
517 };
518 using FP32ImmPtr = std::shared_ptr<FP32Imm>;
IMM_TRAITS(FP32ImmPtr,float)519 IMM_TRAITS(FP32ImmPtr, float)
520 
521 /// \brief FP64Imm defines interface for float64 data.
522 class MS_CORE_API FP64Imm final : public FloatImm {
523  public:
524   /// \brief The default constructor for FP64Imm.
525   FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
526   /// \brief The constructor for FP64Imm.
527   ///
528   /// \param[in] v The value of FP64Imm.
529   explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
530   /// \brief The destructor of FP64Imm.
531   ~FP64Imm() override = default;
532   MS_DECLARE_PARENT(FP64Imm, FloatImm)
533   std::size_t hash() const override { return hash_; }
534   bool IsZero() override { return fabs(v_) <= DBL_EPSILON; }
535   bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; }
536   /// \brief Get the value of FP64Imm.
537   ///
538   /// \return Return the value of FP64Imm.
539   double value() const { return v_; }
540   bool operator==(const Value &other) const override;
541   /// \brief Compare two FP64Imm objects is equal.
542   ///
543   /// \param[in] other The other FP64Imm to be compared with.
544   /// \return Return true if other's value and the value of current object are the same,else return false.
545   bool operator==(const FP64Imm &other) const;
546   std::string ToString() const override { return scalar_float_to_string(v_); }
547 
548   std::string DumpText() const override {
549     std::ostringstream oss;
550     oss << "F64(" << v_ << ")";
551     return oss.str();
552   }
553 
554  private:
555   double v_;
556 };
557 using FP64ImmPtr = std::shared_ptr<FP64Imm>;
IMM_TRAITS(FP64ImmPtr,double)558 IMM_TRAITS(FP64ImmPtr, double)
559 
560 /// \brief BF16Imm defines interface for bfloat16 data.
561 class MS_CORE_API BF16Imm final : public FloatImm {
562  public:
563   /// \brief The default constructor for BF16Imm.
564   BF16Imm() : FloatImm(kBFloat16), v_(0.0) {}
565   /// \brief The constructor for BF16Imm.
566   ///
567   /// \param[in] v The value of BF16Imm.
568   explicit BF16Imm(bfloat16 v) : FloatImm(kBFloat16), v_(v) {
569     hash_ = hash_combine({tid(), std::hash<bfloat16>{}(v_)});
570   }
571   /// \brief The destructor of BF16Imm.
572   ~BF16Imm() override = default;
573   MS_DECLARE_PARENT(BF16Imm, FloatImm)
574   std::size_t hash() const override { return hash_; }
575   bool IsZero() override { return v_ == BFloat16(0.0); }
576   bool IsOne() override { return v_ == BFloat16(1.0); }
577   /// \brief Get the value of BF16Imm.
578   ///
579   /// \return Return the value of BF16Imm.
580   bfloat16 value() const { return v_; }
581   bool operator==(const Value &other) const override;
582   /// \brief Compare two BF16Imm objects is equal.
583   ///
584   /// \param[in] other The other BF16Imm to be compared with.
585   /// \return Return true if other's value and the value of current object are the same,else return false.
586   bool operator==(const BF16Imm &other) const;
587   std::string ToString() const override { return scalar_float_to_string(v_); }
588 
589   std::string DumpText() const override {
590     std::ostringstream oss;
591     oss << "BF16(" << v_ << ")";
592     return oss.str();
593   }
594 
595  private:
596   bfloat16 v_;
597 };
598 using BF16ImmPtr = std::shared_ptr<BF16Imm>;
599 IMM_TRAITS(BF16ImmPtr, bfloat16)
600 }  // namespace mindspore
601 
602 #endif  // MINDSPORE_CORE_IR_SCALAR_H_
603