• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
17 #define MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
18 
19 #include "mir_type.h"
20 
21 namespace maple {
22 
23 /// @brief this class provides different operations on signed and unsigned integers with arbitrary bit-width
24 class IntVal {
25 public:
26     /// @brief create zero value with zero bit-width
IntVal()27     IntVal() : value(0), width(0), sign(false) {}
28 
IntVal(uint64 val,uint8 bitWidth,bool isSigned)29     IntVal(uint64 val, uint8 bitWidth, bool isSigned) : value(val), width(bitWidth), sign(isSigned)
30     {
31         DEBUG_ASSERT(width <= valBitSize && width != 0, "bit-width is too wide");
32         TruncInPlace();
33     }
34 
IntVal(uint64 val,PrimType type)35     IntVal(uint64 val, PrimType type) : IntVal(val, GetPrimTypeActualBitSize(type), IsSignedInteger(type))
36     {
37         DEBUG_ASSERT(IsPrimitiveInteger(type), "Type must be integral");
38     }
39 
IntVal(const IntVal & val)40     IntVal(const IntVal &val) : IntVal(val.value, val.width, val.sign) {}
41 
IntVal(const IntVal & val,PrimType type)42     IntVal(const IntVal &val, PrimType type) : IntVal(val.value, type) {}
43 
IntVal(const IntVal & val,uint8 bitWidth,bool isSigned)44     IntVal(const IntVal &val, uint8 bitWidth, bool isSigned) : IntVal(val.value, bitWidth, isSigned) {}
45 
IntVal(const IntVal & val,bool isSigned)46     IntVal(const IntVal &val, bool isSigned) : IntVal(val.value, val.width, isSigned) {}
47 
48     IntVal &operator=(const IntVal &other)
49     {
50         if (width == 0) {
51             // Allow 'this' to be assigned with new bit-width and sign iff
52             // its original bit-width is zero (i.e. the value was created by the default ctor)
53             Assign(other);
54         } else {
55             // Otherwise, assign only new value, but sign and width must be the same
56             DEBUG_ASSERT(width == other.width && sign == other.sign, "different bit-width or sign");
57             value = other.value;
58         }
59 
60         return *this;
61     }
62 
63     IntVal &operator=(uint64 other)
64     {
65         DEBUG_ASSERT(width != 0, "can't be assigned to value with unknown size");
66         value = other;
67         TruncInPlace();
68 
69         return *this;
70     }
71 
Assign(const IntVal & other)72     void Assign(const IntVal &other)
73     {
74         value = other.value;
75         width = other.width;
76         sign = other.sign;
77     }
78 
79     /// @return bit-width of the value
GetBitWidth()80     uint8 GetBitWidth() const
81     {
82         return width;
83     }
84 
85     /// @return true if the value is signed
IsSigned()86     bool IsSigned() const
87     {
88         return sign;
89     }
90 
91     /// @return sign or zero extended value depending on its signedness
92     int64 GetExtValue(uint8 size = 0) const
93     {
94         return sign ? GetSXTValue(size) : GetZXTValue(size);
95     }
96 
97     /// @return zero extended value
98     uint64 GetZXTValue(uint8 size = 0) const
99     {
100         // if size == 0, just return the value itself because it's already truncated for an appropriate width
101         return size ? (value << (valBitSize - size)) >> (valBitSize - size) : value;
102     }
103 
104     /// @return sign extended value
105     int64 GetSXTValue(uint8 size = 0) const
106     {
107         uint8 bitWidth = size ? size : width;
108         return static_cast<int64>(value << (valBitSize - bitWidth)) >> (valBitSize - bitWidth);
109     }
110 
111     /// @return true if the (most significant bit) MSB is set
GetSignBit()112     bool GetSignBit() const
113     {
114         return GetBit(width - 1);
115     }
116 
117     /// @return true if all bits are 1
AreAllBitsOne()118     bool AreAllBitsOne() const
119     {
120         return value == (allOnes >> (valBitSize - width));
121     }
122 
123     /// @return true if the value is maximum considering its signedness
IsMaxValue()124     bool IsMaxValue() const
125     {
126         return sign ? value == ((uint64(1) << (width - 1)) - 1) : AreAllBitsOne();
127     }
128 
129     /// @return true if the value is minimum considering its signedness
IsMinValue()130     bool IsMinValue() const
131     {
132         return sign ? value == (uint64(1) << (width - 1)) : value == 0;
133     }
134 
135     // Comparison operators that manipulate on values with the same sign and bit-width
136     bool operator==(const IntVal &rhs) const
137     {
138         DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
139         return value == rhs.value;
140     }
141 
142     bool operator!=(const IntVal &rhs) const
143     {
144         return !(*this == rhs);
145     }
146 
147     bool operator<(const IntVal &rhs) const
148     {
149         DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
150         return sign ? GetSXTValue() < rhs.GetSXTValue() : value < rhs.value;
151     }
152 
153     bool operator>(const IntVal &rhs) const
154     {
155         DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
156         return sign ? GetSXTValue() > rhs.GetSXTValue() : value > rhs.value;
157     }
158 
159     bool operator<=(const IntVal &rhs) const
160     {
161         return !(*this > rhs);
162     }
163 
164     bool operator>=(const IntVal &rhs) const
165     {
166         return !(*this < rhs);
167     }
168 
169     // Arithmetic and bitwise operators that manipulate on values with the same sign and bit-width
170     IntVal operator+(const IntVal &val) const
171     {
172         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
173         return IntVal(value + val.value, width, sign);
174     }
175 
176     IntVal operator-(const IntVal &val) const
177     {
178         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
179         return IntVal(value - val.value, width, sign);
180     }
181 
182     IntVal operator*(const IntVal &val) const
183     {
184         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
185         return IntVal(value * val.value, width, sign);
186     }
187 
188     IntVal operator/(const IntVal &divisor) const;
189 
190     IntVal operator%(const IntVal &divisor) const;
191 
192     IntVal operator&(const IntVal &val) const
193     {
194         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
195         return IntVal(value & val.value, width, sign);
196     }
197 
198     IntVal operator|(const IntVal &val) const
199     {
200         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
201         return IntVal(value | val.value, width, sign);
202     }
203 
204     IntVal operator^(const IntVal &val) const
205     {
206         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
207         return IntVal(value ^ val.value, width, sign);
208     }
209 
210     /// @brief shift-left operator
211     IntVal operator<<(uint64 bits) const
212     {
213         DEBUG_ASSERT(bits <= width, "invalid shift value");
214         return IntVal(value << bits, width, sign);
215     }
216 
217     IntVal operator<<(const IntVal &bits) const
218     {
219         return *this << bits.value;
220     }
221 
222     /// @brief shift-right operator
223     /// @note if value is signed this operator works as arithmetic shift-right operator,
224     ///       otherwise it's logical shift-right operator
225     IntVal operator>>(uint64 bits) const
226     {
227         DEBUG_ASSERT(bits <= width, "invalid shift value");
228         return IntVal(sign ? GetSXTValue() >> bits : value >> bits, width, sign);
229     }
230 
231     IntVal operator>>(const IntVal &bits) const
232     {
233         return *this >> bits.value;
234     }
235 
236     // Comparison operators that compare values obtained from primitive type.
237     /// @note Note that these functions work as follows:
238     ///          1) sign or zero extend both values (*this and/or rhs) to the new bit-width
239     ///             obtained from pType depending on their original signedness;
240     ///             or truncate the values if their original bit-width is greater than new one
241     ///          2) then perform the operation itself on new given values that have the same bit-width and sign
242     /// @warning it's better to avoid using these function in favor of operator==, operator< etc
Equal(const IntVal & rhs,PrimType pType)243     bool Equal(const IntVal &rhs, PrimType pType) const
244     {
245         return TruncOrExtend(pType) == rhs.TruncOrExtend(pType);
246     }
247 
Less(const IntVal & rhs,PrimType pType)248     bool Less(const IntVal &rhs, PrimType pType) const
249     {
250         return TruncOrExtend(pType) < rhs.TruncOrExtend(pType);
251     }
252 
Greater(const IntVal & rhs,PrimType pType)253     bool Greater(const IntVal &rhs, PrimType pType) const
254     {
255         return TruncOrExtend(pType) > rhs.TruncOrExtend(pType);
256     }
257 
258     // Arithmetic and bitwise operators that allow creating a new value
259     // with the bit-width and sign obtained from primitive type
260     /// @note Note that these functions work as follows:
261     ///          1) sign or zero extend both values (*this and rhs) to the new bit-width
262     ///             obtained from pType depending on their original signedness;
263     ///             or truncate the values if their original bit-width is greater than new one
264     ///          2) then perform the operation itself on new given values that have the same bit-width and sign
265     /// @warning it's better to avoid using these function in favor of operator+, operator- etc
Add(const IntVal & val,PrimType pType)266     IntVal Add(const IntVal &val, PrimType pType) const
267     {
268         return TruncOrExtend(pType) + val.TruncOrExtend(pType);
269     }
270 
Sub(const IntVal & val,PrimType pType)271     IntVal Sub(const IntVal &val, PrimType pType) const
272     {
273         return TruncOrExtend(pType) - val.TruncOrExtend(pType);
274     }
275 
Mul(const IntVal & val,PrimType pType)276     IntVal Mul(const IntVal &val, PrimType pType) const
277     {
278         return TruncOrExtend(pType) * val.TruncOrExtend(pType);
279     }
280 
Div(const IntVal & divisor,PrimType pType)281     IntVal Div(const IntVal &divisor, PrimType pType) const
282     {
283         return TruncOrExtend(pType) / divisor.TruncOrExtend(pType);
284     }
285 
286     // sign division in terms of new bitWidth
SDiv(const IntVal & divisor,uint8 bitWidth)287     IntVal SDiv(const IntVal &divisor, uint8 bitWidth) const
288     {
289         return TruncOrExtend(bitWidth, true) / divisor.TruncOrExtend(bitWidth, true);
290     }
291 
292     // unsigned division in terms of new bitWidth
UDiv(const IntVal & divisor,uint8 bitWidth)293     IntVal UDiv(const IntVal &divisor, uint8 bitWidth) const
294     {
295         return TruncOrExtend(bitWidth, false) / divisor.TruncOrExtend(bitWidth, false);
296     }
297 
298     // unsigned division in terms of new bitWidth
Rem(const IntVal & divisor,PrimType pType)299     IntVal Rem(const IntVal &divisor, PrimType pType) const
300     {
301         return TruncOrExtend(pType) % divisor.TruncOrExtend(pType);
302     }
303 
304     // signed modulo in terms of new bitWidth
SRem(const IntVal & divisor,uint8 bitWidth)305     IntVal SRem(const IntVal &divisor, uint8 bitWidth) const
306     {
307         return TruncOrExtend(bitWidth, true) % divisor.TruncOrExtend(bitWidth, true);
308     }
309 
310     // unsigned modulo in terms of new bitWidth
URem(const IntVal & divisor,uint8 bitWidth)311     IntVal URem(const IntVal &divisor, uint8 bitWidth) const
312     {
313         return TruncOrExtend(bitWidth, false) % divisor.TruncOrExtend(bitWidth, false);
314     }
315 
And(const IntVal & val,PrimType pType)316     IntVal And(const IntVal &val, PrimType pType) const
317     {
318         return TruncOrExtend(pType) & val.TruncOrExtend(pType);
319     }
320 
Or(const IntVal & val,PrimType pType)321     IntVal Or(const IntVal &val, PrimType pType) const
322     {
323         return TruncOrExtend(pType) | val.TruncOrExtend(pType);
324     }
325 
Xor(const IntVal & val,PrimType pType)326     IntVal Xor(const IntVal &val, PrimType pType) const
327     {
328         return TruncOrExtend(pType) ^ val.TruncOrExtend(pType);
329     }
330 
331     // left-shift operators
Shl(const IntVal & shift,PrimType pType)332     IntVal Shl(const IntVal &shift, PrimType pType) const
333     {
334         return Shl(shift.value, pType);
335     }
336 
Shl(uint64 shift,PrimType pType)337     IntVal Shl(uint64 shift, PrimType pType) const
338     {
339         return TruncOrExtend(pType) << shift;
340     }
341 
342     // logical right-shift operators (MSB is zero extended)
LShr(const IntVal & shift,PrimType pType)343     IntVal LShr(const IntVal &shift, PrimType pType) const
344     {
345         return LShr(shift.value, pType);
346     }
347 
LShr(uint64 shift,PrimType pType)348     IntVal LShr(uint64 shift, PrimType pType) const
349     {
350         IntVal ret = TruncOrExtend(pType);
351 
352         DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
353         ret.value >>= shift;
354 
355         return ret;
356     }
357 
LShr(uint64 shift)358     IntVal LShr(uint64 shift) const
359     {
360         DEBUG_ASSERT(shift <= width, "invalid shift value");
361         return IntVal(value >> shift, width, sign);
362     }
363 
364     // arithmetic right-shift operators (MSB is sign extended)
AShr(const IntVal & shift,PrimType pType)365     IntVal AShr(const IntVal &shift, PrimType pType) const
366     {
367         return AShr(shift.value, pType);
368     }
369 
AShr(uint64 shift,PrimType pType)370     IntVal AShr(uint64 shift, PrimType pType) const
371     {
372         IntVal ret = TruncOrExtend(pType);
373 
374         DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
375         ret.value = ret.GetSXTValue() >> shift;
376         ret.TruncInPlace();
377 
378         return ret;
379     }
380 
381     /// @brief invert all bits of value
382     IntVal operator~() const
383     {
384         return IntVal(~value, width, sign);
385     }
386 
387     /// @return negated value
388     IntVal operator-() const
389     {
390         return IntVal(~value + 1, width, sign);
391     }
392 
393     /// @brief truncate value to the given bit-width in-place.
394     /// @note if bitWidth is not passed (or zero), the original
395     ///       bit-width is preserved and the value is truncated to the original bit-width
396     void TruncInPlace(uint8 bitWidth = 0)
397     {
398         DEBUG_ASSERT(valBitSize >= bitWidth, "invalid bit-width for truncate");
399 
400         value &= allOnes >> (valBitSize - (bitWidth ? bitWidth : width));
401 
402         if (bitWidth) {
403             width = bitWidth;
404         }
405     }
406 
407     /// @return truncated value to the given bit-width
408     /// @note returned value will have bit-width and sign obtained from newType
Trunc(PrimType newType)409     IntVal Trunc(PrimType newType) const
410     {
411         return Trunc(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
412     }
413 
414     /// @return sign or zero extended value depending on its signedness
415     /// @note returned value will have bit-width and sign obtained from newType
Extend(PrimType newType)416     IntVal Extend(PrimType newType) const
417     {
418         return Extend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
419     }
420 
421     /// @return sign/zero extended value or truncated value depending on bit-width
422     /// @note returned value will have bit-width and sign obtained from newType
TruncOrExtend(PrimType newType)423     IntVal TruncOrExtend(PrimType newType) const
424     {
425         return TruncOrExtend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
426     }
427 
TruncOrExtend(uint8 newWidth,bool isSigned)428     IntVal TruncOrExtend(uint8 newWidth, bool isSigned) const
429     {
430         return newWidth <= width ? Trunc(newWidth, isSigned) : Extend(newWidth, isSigned);
431     }
432 
433 private:
GetBit(uint8 bit)434     bool GetBit(uint8 bit) const
435     {
436         DEBUG_ASSERT(bit < width, "Required bit is out of value range");
437         return (value & (uint64(1) << bit)) != 0;
438     }
439 
Trunc(uint8 newWidth,bool isSigned)440     IntVal Trunc(uint8 newWidth, bool isSigned) const
441     {
442         return {value, newWidth, isSigned};
443     }
444 
Extend(uint8 newWidth,bool isSigned)445     IntVal Extend(uint8 newWidth, bool isSigned) const
446     {
447         DEBUG_ASSERT(newWidth > width, "invalid size for extension");
448         return IntVal(GetExtValue(), newWidth, isSigned);
449     }
450 
451     static constexpr uint8 valBitSize = sizeof(uint64) * CHAR_BIT;
452     static constexpr uint64 allOnes = uint64(~0);
453 
454     uint64 value;
455     uint8 width;
456     bool sign;
457 };
458 
459 // Additional comparison operators
460 inline bool operator==(const IntVal &v1, int64 v2)
461 {
462     return v1.GetExtValue() == v2;
463 }
464 
465 inline bool operator==(int64 v1, const IntVal &v2)
466 {
467     return v2 == v1;
468 }
469 
470 inline bool operator!=(const IntVal &v1, int64 v2)
471 {
472     return !(v1 == v2);
473 }
474 
475 inline bool operator!=(int64 v1, const IntVal &v2)
476 {
477     return !(v2 == v1);
478 }
479 
480 /// @return the smaller of two values
481 /// @note bit-width and sign must be the same for both parameters
Min(const IntVal & a,const IntVal & b)482 inline IntVal Min(const IntVal &a, const IntVal &b)
483 {
484     return a < b ? a : b;
485 }
486 
487 /// @return the smaller of two values in terms of newType
488 /// @note returned value will have bit-width and sign obtained from newType
Min(const IntVal & a,const IntVal & b,PrimType newType)489 inline IntVal Min(const IntVal &a, const IntVal &b, PrimType newType)
490 {
491     return a.Less(b, newType) ? IntVal(a, newType) : IntVal(b, newType);
492 }
493 
494 /// @return the larger of two values
495 /// @note bit-width and sign must be the same for both parameters
Max(const IntVal & a,const IntVal & b)496 inline IntVal Max(const IntVal &a, const IntVal &b)
497 {
498     return Min(a, b) == a ? b : a;
499 }
500 
501 /// @return the larger of two values in terms of newType
502 /// @note returned value will have bit-width and sign obtained from newType
Max(const IntVal & a,const IntVal & b,PrimType newType)503 inline IntVal Max(const IntVal &a, const IntVal &b, PrimType newType)
504 {
505     return Min(a, b, newType) == a ? b : a;
506 }
507 
508 /// @brief dump IntVal object to the output stream
509 std::ostream &operator<<(std::ostream &os, const IntVal &value);
510 
511 // Arithmetic operators that manipulate on scalar (uint64) value and IntVal object
512 // in terms of sign and bit-width of IntVal object
513 inline IntVal operator+(const IntVal &v1, uint64 v2)
514 {
515     return v1 + IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
516 }
517 
518 inline IntVal operator+(uint64 v1, const IntVal &v2)
519 {
520     return v2 + v1;
521 }
522 
523 inline IntVal operator-(const IntVal &v1, uint64 v2)
524 {
525     return v1 - IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
526 }
527 
528 inline IntVal operator-(uint64 v1, const IntVal &v2)
529 {
530     return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) - v2;
531 }
532 
533 inline IntVal operator*(const IntVal &v1, uint64 v2)
534 {
535     return v1 * IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
536 }
537 
538 inline IntVal operator*(uint64 v1, const IntVal &v2)
539 {
540     return v2 * v1;
541 }
542 
543 inline IntVal operator/(const IntVal &v1, uint64 v2)
544 {
545     return v1 / IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
546 }
547 
548 inline IntVal operator/(uint64 v1, const IntVal &v2)
549 {
550     return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) / v2;
551 }
552 
553 inline IntVal operator%(const IntVal &v1, uint64 v2)
554 {
555     return v1 % IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
556 }
557 
558 inline IntVal operator%(uint64 v1, const IntVal &v2)
559 {
560     return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) % v2;
561 }
562 
563 inline IntVal operator&(const IntVal &v1, uint64 v2)
564 {
565     return v1 & IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
566 }
567 
568 inline IntVal operator&(uint64 v1, const IntVal &v2)
569 {
570     return v2 & v1;
571 }
572 
573 }  // namespace maple
574 
575 #endif  // MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
576