• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
17 #define MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
18 
19 #include "mir_type.h"
20 
21 namespace maple {
22 
23 /// @brief this class provides different operations on signed and unsigned integers with arbitrary bit-width
24 class IntVal {
25 public:
26     /// @brief create zero value with zero bit-width
IntVal()27     IntVal() : value(0), width(0), sign(false) {}
28 
IntVal(uint64 val,uint8 bitWidth,bool isSigned)29     IntVal(uint64 val, uint8 bitWidth, bool isSigned) : value(val), width(bitWidth), sign(isSigned)
30     {
31         DEBUG_ASSERT(width <= valBitSize && width != 0, "bit-width is too wide");
32         TruncInPlace();
33     }
34 
IntVal(uint64 val,PrimType type)35     IntVal(uint64 val, PrimType type) : IntVal(val, GetPrimTypeActualBitSize(type), IsSignedInteger(type))
36     {
37         DEBUG_ASSERT(IsPrimitiveInteger(type), "Type must be integral");
38     }
39 
IntVal(const IntVal & val)40     IntVal(const IntVal &val) : IntVal(val.value, val.width, val.sign) {}
41 
IntVal(const IntVal & val,PrimType type)42     IntVal(const IntVal &val, PrimType type) : IntVal(val.value, type) {}
43 
IntVal(const IntVal & val,uint8 bitWidth,bool isSigned)44     IntVal(const IntVal &val, uint8 bitWidth, bool isSigned) : IntVal(val.value, bitWidth, isSigned) {}
45 
IntVal(const IntVal & val,bool isSigned)46     IntVal(const IntVal &val, bool isSigned) : IntVal(val.value, val.width, isSigned) {}
47 
48     IntVal &operator=(const IntVal &other)
49     {
50         if (width == 0) {
51             // Allow 'this' to be assigned with new bit-width and sign iff
52             // its original bit-width is zero (i.e. the value was created by the default ctor)
53             Assign(other);
54         } else {
55             // Otherwise, assign only new value, but sign and width must be the same
56             DEBUG_ASSERT(width == other.width && sign == other.sign, "different bit-width or sign");
57             value = other.value;
58         }
59 
60         return *this;
61     }
62 
63     IntVal &operator=(uint64 other)
64     {
65         DEBUG_ASSERT(width != 0, "can't be assigned to value with unknown size");
66         value = other;
67         TruncInPlace();
68 
69         return *this;
70     }
71 
Assign(const IntVal & other)72     void Assign(const IntVal &other)
73     {
74         value = other.value;
75         width = other.width;
76         sign = other.sign;
77     }
78 
79     /// @return bit-width of the value
GetBitWidth()80     uint8 GetBitWidth() const
81     {
82         return width;
83     }
84 
85     /// @return true if the value is signed
IsSigned()86     bool IsSigned() const
87     {
88         return sign;
89     }
90 
91     /// @return sign or zero extended value depending on its signedness
92     int64 GetExtValue(uint8 size = 0) const
93     {
94         return sign ? GetSXTValue(size) : GetZXTValue(size);
95     }
96 
97     /// @return zero extended value
98     uint64 GetZXTValue(uint8 size = 0) const
99     {
100         // if size == 0, just return the value itself because it's already truncated for an appropriate width
101         return size ? (value << (valBitSize - size)) >> (valBitSize - size) : value;
102     }
103 
104     /// @return sign extended value
105     int64 GetSXTValue(uint8 size = 0) const
106     {
107         uint8 bitWidth = size ? size : width;
108         return static_cast<int64>(value << (valBitSize - bitWidth)) >> (valBitSize - bitWidth);
109     }
110 
111     /// @return true if the (most significant bit) MSB is set
GetSignBit()112     bool GetSignBit() const
113     {
114         return GetBit(width - 1);
115     }
116 
117     /// @return true if all bits are 1
AreAllBitsOne()118     bool AreAllBitsOne() const
119     {
120         return value == (allOnes >> (valBitSize - width));
121     }
122 
123     /// @return true if the value is maximum considering its signedness
IsMaxValue()124     bool IsMaxValue() const
125     {
126         return sign ? value == ((uint64(1) << (width - 1)) - 1) : AreAllBitsOne();
127     }
128 
129     /// @return true if the value is minimum considering its signedness
IsMinValue()130     bool IsMinValue() const
131     {
132         return sign ? value == (uint64(1) << (width - 1)) : value == 0;
133     }
134 
135     // Comparison operators that manipulate on values with the same sign and bit-width
136     bool operator==(const IntVal &rhs) const
137     {
138         DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
139         return value == rhs.value;
140     }
141 
142     bool operator!=(const IntVal &rhs) const
143     {
144         return !(*this == rhs);
145     }
146 
147     bool operator<(const IntVal &rhs) const
148     {
149         DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
150         return sign ? GetSXTValue() < rhs.GetSXTValue() : value < rhs.value;
151     }
152 
153     bool operator>(const IntVal &rhs) const
154     {
155         DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
156         return sign ? GetSXTValue() > rhs.GetSXTValue() : value > rhs.value;
157     }
158 
159     bool operator<=(const IntVal &rhs) const
160     {
161         return !(*this > rhs);
162     }
163 
164     bool operator>=(const IntVal &rhs) const
165     {
166         return !(*this < rhs);
167     }
168 
169     // Arithmetic and bitwise operators that manipulate on values with the same sign and bit-width
170     IntVal operator+(const IntVal &val) const
171     {
172         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
173         return IntVal(value + val.value, width, sign);
174     }
175 
176     IntVal operator-(const IntVal &val) const
177     {
178         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
179         return IntVal(value - val.value, width, sign);
180     }
181 
182     IntVal operator*(const IntVal &val) const
183     {
184         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
185         return IntVal(value * val.value, width, sign);
186     }
187 
188     IntVal operator/(const IntVal &divisor) const
189     {
190         DEBUG_ASSERT(width == divisor.width && sign == divisor.sign, "bit-width and sign must be the same");
191         DEBUG_ASSERT(divisor.value != 0, "division by zero");
192         DEBUG_ASSERT(!sign || (!IsMinValue() || !divisor.AreAllBitsOne()), "minValue / -1 leads to overflow");
193 
194         bool isNeg = sign && GetSignBit();
195         bool isDivisorNeg = divisor.sign && divisor.GetSignBit();
196 
197         uint64 dividendVal = isNeg ? (-*this).value : value;
198         uint64 divisorVal = isDivisorNeg ? (-divisor).value : divisor.value;
199 
200         return isNeg != isDivisorNeg ? -IntVal(dividendVal / divisorVal, width, sign)
201                                      : IntVal(dividendVal / divisorVal, width, sign);
202     }
203 
204     IntVal operator%(const IntVal &divisor) const
205     {
206         DEBUG_ASSERT(width == divisor.width && sign == divisor.sign, "bit-width and sign must be the same");
207         DEBUG_ASSERT(divisor.value != 0, "division by zero");
208         DEBUG_ASSERT(!sign || (!IsMinValue() || !divisor.AreAllBitsOne()), "minValue % -1 leads to overflow");
209 
210         bool isNeg = sign && GetSignBit();
211         bool isDivisorNeg = divisor.sign && divisor.GetSignBit();
212 
213         uint64 dividendVal = isNeg ? (-*this).value : value;
214         uint64 divisorVal = isDivisorNeg ? (-divisor).value : divisor.value;
215 
216         return isNeg ? -IntVal(dividendVal % divisorVal, width, sign) : IntVal(dividendVal % divisorVal, width, sign);
217     }
218 
219     IntVal operator&(const IntVal &val) const
220     {
221         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
222         return IntVal(value & val.value, width, sign);
223     }
224 
225     IntVal operator|(const IntVal &val) const
226     {
227         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
228         return IntVal(value | val.value, width, sign);
229     }
230 
231     IntVal operator^(const IntVal &val) const
232     {
233         DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
234         return IntVal(value ^ val.value, width, sign);
235     }
236 
237     /// @brief shift-left operator
238     IntVal operator<<(uint64 bits) const
239     {
240         DEBUG_ASSERT(bits <= width, "invalid shift value");
241         return IntVal(value << bits, width, sign);
242     }
243 
244     IntVal operator<<(const IntVal &bits) const
245     {
246         return *this << bits.value;
247     }
248 
249     /// @brief shift-right operator
250     /// @note if value is signed this operator works as arithmetic shift-right operator,
251     ///       otherwise it's logical shift-right operator
252     IntVal operator>>(uint64 bits) const
253     {
254         DEBUG_ASSERT(bits <= width, "invalid shift value");
255         return IntVal(sign ? GetSXTValue() >> bits : value >> bits, width, sign);
256     }
257 
258     IntVal operator>>(const IntVal &bits) const
259     {
260         return *this >> bits.value;
261     }
262 
263     // Comparison operators that compare values obtained from primitive type.
264     /// @note Note that these functions work as follows:
265     ///          1) sign or zero extend both values (*this and/or rhs) to the new bit-width
266     ///             obtained from pType depending on their original signedness;
267     ///             or truncate the values if their original bit-width is greater than new one
268     ///          2) then perform the operation itself on new given values that have the same bit-width and sign
269     /// @warning it's better to avoid using these function in favor of operator==, operator< etc
Equal(const IntVal & rhs,PrimType pType)270     bool Equal(const IntVal &rhs, PrimType pType) const
271     {
272         return TruncOrExtend(pType) == rhs.TruncOrExtend(pType);
273     }
274 
Less(const IntVal & rhs,PrimType pType)275     bool Less(const IntVal &rhs, PrimType pType) const
276     {
277         return TruncOrExtend(pType) < rhs.TruncOrExtend(pType);
278     }
279 
Greater(const IntVal & rhs,PrimType pType)280     bool Greater(const IntVal &rhs, PrimType pType) const
281     {
282         return TruncOrExtend(pType) > rhs.TruncOrExtend(pType);
283     }
284 
285     // Arithmetic and bitwise operators that allow creating a new value
286     // with the bit-width and sign obtained from primitive type
287     /// @note Note that these functions work as follows:
288     ///          1) sign or zero extend both values (*this and rhs) to the new bit-width
289     ///             obtained from pType depending on their original signedness;
290     ///             or truncate the values if their original bit-width is greater than new one
291     ///          2) then perform the operation itself on new given values that have the same bit-width and sign
292     /// @warning it's better to avoid using these function in favor of operator+, operator- etc
Add(const IntVal & val,PrimType pType)293     IntVal Add(const IntVal &val, PrimType pType) const
294     {
295         return TruncOrExtend(pType) + val.TruncOrExtend(pType);
296     }
297 
Sub(const IntVal & val,PrimType pType)298     IntVal Sub(const IntVal &val, PrimType pType) const
299     {
300         return TruncOrExtend(pType) - val.TruncOrExtend(pType);
301     }
302 
Mul(const IntVal & val,PrimType pType)303     IntVal Mul(const IntVal &val, PrimType pType) const
304     {
305         return TruncOrExtend(pType) * val.TruncOrExtend(pType);
306     }
307 
Div(const IntVal & divisor,PrimType pType)308     IntVal Div(const IntVal &divisor, PrimType pType) const
309     {
310         return TruncOrExtend(pType) / divisor.TruncOrExtend(pType);
311     }
312 
313     // sign division in terms of new bitWidth
SDiv(const IntVal & divisor,uint8 bitWidth)314     IntVal SDiv(const IntVal &divisor, uint8 bitWidth) const
315     {
316         return TruncOrExtend(bitWidth, true) / divisor.TruncOrExtend(bitWidth, true);
317     }
318 
319     // unsigned division in terms of new bitWidth
UDiv(const IntVal & divisor,uint8 bitWidth)320     IntVal UDiv(const IntVal &divisor, uint8 bitWidth) const
321     {
322         return TruncOrExtend(bitWidth, false) / divisor.TruncOrExtend(bitWidth, false);
323     }
324 
325     // unsigned division in terms of new bitWidth
Rem(const IntVal & divisor,PrimType pType)326     IntVal Rem(const IntVal &divisor, PrimType pType) const
327     {
328         return TruncOrExtend(pType) % divisor.TruncOrExtend(pType);
329     }
330 
331     // signed modulo in terms of new bitWidth
SRem(const IntVal & divisor,uint8 bitWidth)332     IntVal SRem(const IntVal &divisor, uint8 bitWidth) const
333     {
334         return TruncOrExtend(bitWidth, true) % divisor.TruncOrExtend(bitWidth, true);
335     }
336 
337     // unsigned modulo in terms of new bitWidth
URem(const IntVal & divisor,uint8 bitWidth)338     IntVal URem(const IntVal &divisor, uint8 bitWidth) const
339     {
340         return TruncOrExtend(bitWidth, false) % divisor.TruncOrExtend(bitWidth, false);
341     }
342 
And(const IntVal & val,PrimType pType)343     IntVal And(const IntVal &val, PrimType pType) const
344     {
345         return TruncOrExtend(pType) & val.TruncOrExtend(pType);
346     }
347 
Or(const IntVal & val,PrimType pType)348     IntVal Or(const IntVal &val, PrimType pType) const
349     {
350         return TruncOrExtend(pType) | val.TruncOrExtend(pType);
351     }
352 
Xor(const IntVal & val,PrimType pType)353     IntVal Xor(const IntVal &val, PrimType pType) const
354     {
355         return TruncOrExtend(pType) ^ val.TruncOrExtend(pType);
356     }
357 
358     // left-shift operators
Shl(const IntVal & shift,PrimType pType)359     IntVal Shl(const IntVal &shift, PrimType pType) const
360     {
361         return Shl(shift.value, pType);
362     }
363 
Shl(uint64 shift,PrimType pType)364     IntVal Shl(uint64 shift, PrimType pType) const
365     {
366         return TruncOrExtend(pType) << shift;
367     }
368 
369     // logical right-shift operators (MSB is zero extended)
LShr(const IntVal & shift,PrimType pType)370     IntVal LShr(const IntVal &shift, PrimType pType) const
371     {
372         return LShr(shift.value, pType);
373     }
374 
LShr(uint64 shift,PrimType pType)375     IntVal LShr(uint64 shift, PrimType pType) const
376     {
377         IntVal ret = TruncOrExtend(pType);
378 
379         DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
380         ret.value >>= shift;
381 
382         return ret;
383     }
384 
LShr(uint64 shift)385     IntVal LShr(uint64 shift) const
386     {
387         DEBUG_ASSERT(shift <= width, "invalid shift value");
388         return IntVal(value >> shift, width, sign);
389     }
390 
391     // arithmetic right-shift operators (MSB is sign extended)
AShr(const IntVal & shift,PrimType pType)392     IntVal AShr(const IntVal &shift, PrimType pType) const
393     {
394         return AShr(shift.value, pType);
395     }
396 
AShr(uint64 shift,PrimType pType)397     IntVal AShr(uint64 shift, PrimType pType) const
398     {
399         IntVal ret = TruncOrExtend(pType);
400 
401         DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
402         ret.value = ret.GetSXTValue() >> shift;
403         ret.TruncInPlace();
404 
405         return ret;
406     }
407 
408     /// @brief invert all bits of value
409     IntVal operator~() const
410     {
411         return IntVal(~value, width, sign);
412     }
413 
414     /// @return negated value
415     IntVal operator-() const
416     {
417         return IntVal(~value + 1, width, sign);
418     }
419 
420     /// @brief truncate value to the given bit-width in-place.
421     /// @note if bitWidth is not passed (or zero), the original
422     ///       bit-width is preserved and the value is truncated to the original bit-width
423     void TruncInPlace(uint8 bitWidth = 0)
424     {
425         DEBUG_ASSERT(valBitSize >= bitWidth, "invalid bit-width for truncate");
426 
427         value &= allOnes >> (valBitSize - (bitWidth ? bitWidth : width));
428 
429         if (bitWidth) {
430             width = bitWidth;
431         }
432     }
433 
434     /// @return truncated value to the given bit-width
435     /// @note returned value will have bit-width and sign obtained from newType
Trunc(PrimType newType)436     IntVal Trunc(PrimType newType) const
437     {
438         return Trunc(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
439     }
440 
441     /// @return sign or zero extended value depending on its signedness
442     /// @note returned value will have bit-width and sign obtained from newType
Extend(PrimType newType)443     IntVal Extend(PrimType newType) const
444     {
445         return Extend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
446     }
447 
448     /// @return sign/zero extended value or truncated value depending on bit-width
449     /// @note returned value will have bit-width and sign obtained from newType
TruncOrExtend(PrimType newType)450     IntVal TruncOrExtend(PrimType newType) const
451     {
452         return TruncOrExtend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
453     }
454 
TruncOrExtend(uint8 newWidth,bool isSigned)455     IntVal TruncOrExtend(uint8 newWidth, bool isSigned) const
456     {
457         return newWidth <= width ? Trunc(newWidth, isSigned) : Extend(newWidth, isSigned);
458     }
459 
460 private:
GetBit(uint8 bit)461     bool GetBit(uint8 bit) const
462     {
463         DEBUG_ASSERT(bit < width, "Required bit is out of value range");
464         return (value & (uint64(1) << bit)) != 0;
465     }
466 
Trunc(uint8 newWidth,bool isSigned)467     IntVal Trunc(uint8 newWidth, bool isSigned) const
468     {
469         return {value, newWidth, isSigned};
470     }
471 
Extend(uint8 newWidth,bool isSigned)472     IntVal Extend(uint8 newWidth, bool isSigned) const
473     {
474         DEBUG_ASSERT(newWidth > width, "invalid size for extension");
475         return IntVal(GetExtValue(), newWidth, isSigned);
476     }
477 
478     static constexpr uint8 valBitSize = sizeof(uint64) * CHAR_BIT;
479     static constexpr uint64 allOnes = uint64(~0);
480 
481     uint64 value;
482     uint8 width;
483     bool sign;
484 };
485 
486 // Additional comparison operators
487 inline bool operator==(const IntVal &v1, int64 v2)
488 {
489     return v1.GetExtValue() == v2;
490 }
491 
492 inline bool operator==(int64 v1, const IntVal &v2)
493 {
494     return v2 == v1;
495 }
496 
497 inline bool operator!=(const IntVal &v1, int64 v2)
498 {
499     return !(v1 == v2);
500 }
501 
502 inline bool operator!=(int64 v1, const IntVal &v2)
503 {
504     return !(v2 == v1);
505 }
506 
507 /// @return the smaller of two values
508 /// @note bit-width and sign must be the same for both parameters
Min(const IntVal & a,const IntVal & b)509 inline IntVal Min(const IntVal &a, const IntVal &b)
510 {
511     return a < b ? a : b;
512 }
513 
514 /// @return the smaller of two values in terms of newType
515 /// @note returned value will have bit-width and sign obtained from newType
Min(const IntVal & a,const IntVal & b,PrimType newType)516 inline IntVal Min(const IntVal &a, const IntVal &b, PrimType newType)
517 {
518     return a.Less(b, newType) ? IntVal(a, newType) : IntVal(b, newType);
519 }
520 
521 /// @return the larger of two values
522 /// @note bit-width and sign must be the same for both parameters
Max(const IntVal & a,const IntVal & b)523 inline IntVal Max(const IntVal &a, const IntVal &b)
524 {
525     return Min(a, b) == a ? b : a;
526 }
527 
528 /// @return the larger of two values in terms of newType
529 /// @note returned value will have bit-width and sign obtained from newType
Max(const IntVal & a,const IntVal & b,PrimType newType)530 inline IntVal Max(const IntVal &a, const IntVal &b, PrimType newType)
531 {
532     return Min(a, b, newType) == a ? b : a;
533 }
534 
535 /// @brief dump IntVal object to the output stream
536 std::ostream &operator<<(std::ostream &os, const IntVal &value);
537 
538 // Arithmetic operators that manipulate on scalar (uint64) value and IntVal object
539 // in terms of sign and bit-width of IntVal object
540 inline IntVal operator+(const IntVal &v1, uint64 v2)
541 {
542     return v1 + IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
543 }
544 
545 inline IntVal operator+(uint64 v1, const IntVal &v2)
546 {
547     return v2 + v1;
548 }
549 
550 inline IntVal operator-(const IntVal &v1, uint64 v2)
551 {
552     return v1 - IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
553 }
554 
555 inline IntVal operator-(uint64 v1, const IntVal &v2)
556 {
557     return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) - v2;
558 }
559 
560 inline IntVal operator*(const IntVal &v1, uint64 v2)
561 {
562     return v1 * IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
563 }
564 
565 inline IntVal operator*(uint64 v1, const IntVal &v2)
566 {
567     return v2 * v1;
568 }
569 
570 inline IntVal operator/(const IntVal &v1, uint64 v2)
571 {
572     return v1 / IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
573 }
574 
575 inline IntVal operator/(uint64 v1, const IntVal &v2)
576 {
577     return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) / v2;
578 }
579 
580 inline IntVal operator%(const IntVal &v1, uint64 v2)
581 {
582     return v1 % IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
583 }
584 
585 inline IntVal operator%(uint64 v1, const IntVal &v2)
586 {
587     return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) % v2;
588 }
589 
590 inline IntVal operator&(const IntVal &v1, uint64 v2)
591 {
592     return v1 & IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
593 }
594 
595 inline IntVal operator&(uint64 v1, const IntVal &v2)
596 {
597     return v2 & v1;
598 }
599 
600 }  // namespace maple
601 
602 #endif  // MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
603