1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
17 #define MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
18
19 #include "mir_type.h"
20
21 namespace maple {
22
23 /// @brief this class provides different operations on signed and unsigned integers with arbitrary bit-width
24 class IntVal {
25 public:
26 /// @brief create zero value with zero bit-width
IntVal()27 IntVal() : value(0), width(0), sign(false) {}
28
IntVal(uint64 val,uint8 bitWidth,bool isSigned)29 IntVal(uint64 val, uint8 bitWidth, bool isSigned) : value(val), width(bitWidth), sign(isSigned)
30 {
31 DEBUG_ASSERT(width <= valBitSize && width != 0, "bit-width is too wide");
32 TruncInPlace();
33 }
34
IntVal(uint64 val,PrimType type)35 IntVal(uint64 val, PrimType type) : IntVal(val, GetPrimTypeActualBitSize(type), IsSignedInteger(type))
36 {
37 DEBUG_ASSERT(IsPrimitiveInteger(type), "Type must be integral");
38 }
39
IntVal(const IntVal & val)40 IntVal(const IntVal &val) : IntVal(val.value, val.width, val.sign) {}
41
IntVal(const IntVal & val,PrimType type)42 IntVal(const IntVal &val, PrimType type) : IntVal(val.value, type) {}
43
IntVal(const IntVal & val,uint8 bitWidth,bool isSigned)44 IntVal(const IntVal &val, uint8 bitWidth, bool isSigned) : IntVal(val.value, bitWidth, isSigned) {}
45
IntVal(const IntVal & val,bool isSigned)46 IntVal(const IntVal &val, bool isSigned) : IntVal(val.value, val.width, isSigned) {}
47
48 IntVal &operator=(const IntVal &other)
49 {
50 if (width == 0) {
51 // Allow 'this' to be assigned with new bit-width and sign iff
52 // its original bit-width is zero (i.e. the value was created by the default ctor)
53 Assign(other);
54 } else {
55 // Otherwise, assign only new value, but sign and width must be the same
56 DEBUG_ASSERT(width == other.width && sign == other.sign, "different bit-width or sign");
57 value = other.value;
58 }
59
60 return *this;
61 }
62
63 IntVal &operator=(uint64 other)
64 {
65 DEBUG_ASSERT(width != 0, "can't be assigned to value with unknown size");
66 value = other;
67 TruncInPlace();
68
69 return *this;
70 }
71
Assign(const IntVal & other)72 void Assign(const IntVal &other)
73 {
74 value = other.value;
75 width = other.width;
76 sign = other.sign;
77 }
78
79 /// @return bit-width of the value
GetBitWidth()80 uint8 GetBitWidth() const
81 {
82 return width;
83 }
84
85 /// @return true if the value is signed
IsSigned()86 bool IsSigned() const
87 {
88 return sign;
89 }
90
91 /// @return sign or zero extended value depending on its signedness
92 int64 GetExtValue(uint8 size = 0) const
93 {
94 return sign ? GetSXTValue(size) : GetZXTValue(size);
95 }
96
97 /// @return zero extended value
98 uint64 GetZXTValue(uint8 size = 0) const
99 {
100 // if size == 0, just return the value itself because it's already truncated for an appropriate width
101 return size ? (value << (valBitSize - size)) >> (valBitSize - size) : value;
102 }
103
104 /// @return sign extended value
105 int64 GetSXTValue(uint8 size = 0) const
106 {
107 uint8 bitWidth = size ? size : width;
108 return static_cast<int64>(value << (valBitSize - bitWidth)) >> (valBitSize - bitWidth);
109 }
110
111 /// @return true if the (most significant bit) MSB is set
GetSignBit()112 bool GetSignBit() const
113 {
114 return GetBit(width - 1);
115 }
116
117 /// @return true if all bits are 1
AreAllBitsOne()118 bool AreAllBitsOne() const
119 {
120 return value == (allOnes >> (valBitSize - width));
121 }
122
123 /// @return true if the value is maximum considering its signedness
IsMaxValue()124 bool IsMaxValue() const
125 {
126 return sign ? value == ((uint64(1) << (width - 1)) - 1) : AreAllBitsOne();
127 }
128
129 /// @return true if the value is minimum considering its signedness
IsMinValue()130 bool IsMinValue() const
131 {
132 return sign ? value == (uint64(1) << (width - 1)) : value == 0;
133 }
134
135 // Comparison operators that manipulate on values with the same sign and bit-width
136 bool operator==(const IntVal &rhs) const
137 {
138 DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
139 return value == rhs.value;
140 }
141
142 bool operator!=(const IntVal &rhs) const
143 {
144 return !(*this == rhs);
145 }
146
147 bool operator<(const IntVal &rhs) const
148 {
149 DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
150 return sign ? GetSXTValue() < rhs.GetSXTValue() : value < rhs.value;
151 }
152
153 bool operator>(const IntVal &rhs) const
154 {
155 DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
156 return sign ? GetSXTValue() > rhs.GetSXTValue() : value > rhs.value;
157 }
158
159 bool operator<=(const IntVal &rhs) const
160 {
161 return !(*this > rhs);
162 }
163
164 bool operator>=(const IntVal &rhs) const
165 {
166 return !(*this < rhs);
167 }
168
169 // Arithmetic and bitwise operators that manipulate on values with the same sign and bit-width
170 IntVal operator+(const IntVal &val) const
171 {
172 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
173 return IntVal(value + val.value, width, sign);
174 }
175
176 IntVal operator-(const IntVal &val) const
177 {
178 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
179 return IntVal(value - val.value, width, sign);
180 }
181
182 IntVal operator*(const IntVal &val) const
183 {
184 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
185 return IntVal(value * val.value, width, sign);
186 }
187
188 IntVal operator/(const IntVal &divisor) const
189 {
190 DEBUG_ASSERT(width == divisor.width && sign == divisor.sign, "bit-width and sign must be the same");
191 DEBUG_ASSERT(divisor.value != 0, "division by zero");
192 DEBUG_ASSERT(!sign || (!IsMinValue() || !divisor.AreAllBitsOne()), "minValue / -1 leads to overflow");
193
194 bool isNeg = sign && GetSignBit();
195 bool isDivisorNeg = divisor.sign && divisor.GetSignBit();
196
197 uint64 dividendVal = isNeg ? (-*this).value : value;
198 uint64 divisorVal = isDivisorNeg ? (-divisor).value : divisor.value;
199
200 return isNeg != isDivisorNeg ? -IntVal(dividendVal / divisorVal, width, sign)
201 : IntVal(dividendVal / divisorVal, width, sign);
202 }
203
204 IntVal operator%(const IntVal &divisor) const
205 {
206 DEBUG_ASSERT(width == divisor.width && sign == divisor.sign, "bit-width and sign must be the same");
207 DEBUG_ASSERT(divisor.value != 0, "division by zero");
208 DEBUG_ASSERT(!sign || (!IsMinValue() || !divisor.AreAllBitsOne()), "minValue % -1 leads to overflow");
209
210 bool isNeg = sign && GetSignBit();
211 bool isDivisorNeg = divisor.sign && divisor.GetSignBit();
212
213 uint64 dividendVal = isNeg ? (-*this).value : value;
214 uint64 divisorVal = isDivisorNeg ? (-divisor).value : divisor.value;
215
216 return isNeg ? -IntVal(dividendVal % divisorVal, width, sign) : IntVal(dividendVal % divisorVal, width, sign);
217 }
218
219 IntVal operator&(const IntVal &val) const
220 {
221 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
222 return IntVal(value & val.value, width, sign);
223 }
224
225 IntVal operator|(const IntVal &val) const
226 {
227 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
228 return IntVal(value | val.value, width, sign);
229 }
230
231 IntVal operator^(const IntVal &val) const
232 {
233 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
234 return IntVal(value ^ val.value, width, sign);
235 }
236
237 /// @brief shift-left operator
238 IntVal operator<<(uint64 bits) const
239 {
240 DEBUG_ASSERT(bits <= width, "invalid shift value");
241 return IntVal(value << bits, width, sign);
242 }
243
244 IntVal operator<<(const IntVal &bits) const
245 {
246 return *this << bits.value;
247 }
248
249 /// @brief shift-right operator
250 /// @note if value is signed this operator works as arithmetic shift-right operator,
251 /// otherwise it's logical shift-right operator
252 IntVal operator>>(uint64 bits) const
253 {
254 DEBUG_ASSERT(bits <= width, "invalid shift value");
255 return IntVal(sign ? GetSXTValue() >> bits : value >> bits, width, sign);
256 }
257
258 IntVal operator>>(const IntVal &bits) const
259 {
260 return *this >> bits.value;
261 }
262
263 // Comparison operators that compare values obtained from primitive type.
264 /// @note Note that these functions work as follows:
265 /// 1) sign or zero extend both values (*this and/or rhs) to the new bit-width
266 /// obtained from pType depending on their original signedness;
267 /// or truncate the values if their original bit-width is greater than new one
268 /// 2) then perform the operation itself on new given values that have the same bit-width and sign
269 /// @warning it's better to avoid using these function in favor of operator==, operator< etc
Equal(const IntVal & rhs,PrimType pType)270 bool Equal(const IntVal &rhs, PrimType pType) const
271 {
272 return TruncOrExtend(pType) == rhs.TruncOrExtend(pType);
273 }
274
Less(const IntVal & rhs,PrimType pType)275 bool Less(const IntVal &rhs, PrimType pType) const
276 {
277 return TruncOrExtend(pType) < rhs.TruncOrExtend(pType);
278 }
279
Greater(const IntVal & rhs,PrimType pType)280 bool Greater(const IntVal &rhs, PrimType pType) const
281 {
282 return TruncOrExtend(pType) > rhs.TruncOrExtend(pType);
283 }
284
285 // Arithmetic and bitwise operators that allow creating a new value
286 // with the bit-width and sign obtained from primitive type
287 /// @note Note that these functions work as follows:
288 /// 1) sign or zero extend both values (*this and rhs) to the new bit-width
289 /// obtained from pType depending on their original signedness;
290 /// or truncate the values if their original bit-width is greater than new one
291 /// 2) then perform the operation itself on new given values that have the same bit-width and sign
292 /// @warning it's better to avoid using these function in favor of operator+, operator- etc
Add(const IntVal & val,PrimType pType)293 IntVal Add(const IntVal &val, PrimType pType) const
294 {
295 return TruncOrExtend(pType) + val.TruncOrExtend(pType);
296 }
297
Sub(const IntVal & val,PrimType pType)298 IntVal Sub(const IntVal &val, PrimType pType) const
299 {
300 return TruncOrExtend(pType) - val.TruncOrExtend(pType);
301 }
302
Mul(const IntVal & val,PrimType pType)303 IntVal Mul(const IntVal &val, PrimType pType) const
304 {
305 return TruncOrExtend(pType) * val.TruncOrExtend(pType);
306 }
307
Div(const IntVal & divisor,PrimType pType)308 IntVal Div(const IntVal &divisor, PrimType pType) const
309 {
310 return TruncOrExtend(pType) / divisor.TruncOrExtend(pType);
311 }
312
313 // sign division in terms of new bitWidth
SDiv(const IntVal & divisor,uint8 bitWidth)314 IntVal SDiv(const IntVal &divisor, uint8 bitWidth) const
315 {
316 return TruncOrExtend(bitWidth, true) / divisor.TruncOrExtend(bitWidth, true);
317 }
318
319 // unsigned division in terms of new bitWidth
UDiv(const IntVal & divisor,uint8 bitWidth)320 IntVal UDiv(const IntVal &divisor, uint8 bitWidth) const
321 {
322 return TruncOrExtend(bitWidth, false) / divisor.TruncOrExtend(bitWidth, false);
323 }
324
325 // unsigned division in terms of new bitWidth
Rem(const IntVal & divisor,PrimType pType)326 IntVal Rem(const IntVal &divisor, PrimType pType) const
327 {
328 return TruncOrExtend(pType) % divisor.TruncOrExtend(pType);
329 }
330
331 // signed modulo in terms of new bitWidth
SRem(const IntVal & divisor,uint8 bitWidth)332 IntVal SRem(const IntVal &divisor, uint8 bitWidth) const
333 {
334 return TruncOrExtend(bitWidth, true) % divisor.TruncOrExtend(bitWidth, true);
335 }
336
337 // unsigned modulo in terms of new bitWidth
URem(const IntVal & divisor,uint8 bitWidth)338 IntVal URem(const IntVal &divisor, uint8 bitWidth) const
339 {
340 return TruncOrExtend(bitWidth, false) % divisor.TruncOrExtend(bitWidth, false);
341 }
342
And(const IntVal & val,PrimType pType)343 IntVal And(const IntVal &val, PrimType pType) const
344 {
345 return TruncOrExtend(pType) & val.TruncOrExtend(pType);
346 }
347
Or(const IntVal & val,PrimType pType)348 IntVal Or(const IntVal &val, PrimType pType) const
349 {
350 return TruncOrExtend(pType) | val.TruncOrExtend(pType);
351 }
352
Xor(const IntVal & val,PrimType pType)353 IntVal Xor(const IntVal &val, PrimType pType) const
354 {
355 return TruncOrExtend(pType) ^ val.TruncOrExtend(pType);
356 }
357
358 // left-shift operators
Shl(const IntVal & shift,PrimType pType)359 IntVal Shl(const IntVal &shift, PrimType pType) const
360 {
361 return Shl(shift.value, pType);
362 }
363
Shl(uint64 shift,PrimType pType)364 IntVal Shl(uint64 shift, PrimType pType) const
365 {
366 return TruncOrExtend(pType) << shift;
367 }
368
369 // logical right-shift operators (MSB is zero extended)
LShr(const IntVal & shift,PrimType pType)370 IntVal LShr(const IntVal &shift, PrimType pType) const
371 {
372 return LShr(shift.value, pType);
373 }
374
LShr(uint64 shift,PrimType pType)375 IntVal LShr(uint64 shift, PrimType pType) const
376 {
377 IntVal ret = TruncOrExtend(pType);
378
379 DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
380 ret.value >>= shift;
381
382 return ret;
383 }
384
LShr(uint64 shift)385 IntVal LShr(uint64 shift) const
386 {
387 DEBUG_ASSERT(shift <= width, "invalid shift value");
388 return IntVal(value >> shift, width, sign);
389 }
390
391 // arithmetic right-shift operators (MSB is sign extended)
AShr(const IntVal & shift,PrimType pType)392 IntVal AShr(const IntVal &shift, PrimType pType) const
393 {
394 return AShr(shift.value, pType);
395 }
396
AShr(uint64 shift,PrimType pType)397 IntVal AShr(uint64 shift, PrimType pType) const
398 {
399 IntVal ret = TruncOrExtend(pType);
400
401 DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
402 ret.value = ret.GetSXTValue() >> shift;
403 ret.TruncInPlace();
404
405 return ret;
406 }
407
408 /// @brief invert all bits of value
409 IntVal operator~() const
410 {
411 return IntVal(~value, width, sign);
412 }
413
414 /// @return negated value
415 IntVal operator-() const
416 {
417 return IntVal(~value + 1, width, sign);
418 }
419
420 /// @brief truncate value to the given bit-width in-place.
421 /// @note if bitWidth is not passed (or zero), the original
422 /// bit-width is preserved and the value is truncated to the original bit-width
423 void TruncInPlace(uint8 bitWidth = 0)
424 {
425 DEBUG_ASSERT(valBitSize >= bitWidth, "invalid bit-width for truncate");
426
427 value &= allOnes >> (valBitSize - (bitWidth ? bitWidth : width));
428
429 if (bitWidth) {
430 width = bitWidth;
431 }
432 }
433
434 /// @return truncated value to the given bit-width
435 /// @note returned value will have bit-width and sign obtained from newType
Trunc(PrimType newType)436 IntVal Trunc(PrimType newType) const
437 {
438 return Trunc(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
439 }
440
441 /// @return sign or zero extended value depending on its signedness
442 /// @note returned value will have bit-width and sign obtained from newType
Extend(PrimType newType)443 IntVal Extend(PrimType newType) const
444 {
445 return Extend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
446 }
447
448 /// @return sign/zero extended value or truncated value depending on bit-width
449 /// @note returned value will have bit-width and sign obtained from newType
TruncOrExtend(PrimType newType)450 IntVal TruncOrExtend(PrimType newType) const
451 {
452 return TruncOrExtend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
453 }
454
TruncOrExtend(uint8 newWidth,bool isSigned)455 IntVal TruncOrExtend(uint8 newWidth, bool isSigned) const
456 {
457 return newWidth <= width ? Trunc(newWidth, isSigned) : Extend(newWidth, isSigned);
458 }
459
460 private:
GetBit(uint8 bit)461 bool GetBit(uint8 bit) const
462 {
463 DEBUG_ASSERT(bit < width, "Required bit is out of value range");
464 return (value & (uint64(1) << bit)) != 0;
465 }
466
Trunc(uint8 newWidth,bool isSigned)467 IntVal Trunc(uint8 newWidth, bool isSigned) const
468 {
469 return {value, newWidth, isSigned};
470 }
471
Extend(uint8 newWidth,bool isSigned)472 IntVal Extend(uint8 newWidth, bool isSigned) const
473 {
474 DEBUG_ASSERT(newWidth > width, "invalid size for extension");
475 return IntVal(GetExtValue(), newWidth, isSigned);
476 }
477
478 static constexpr uint8 valBitSize = sizeof(uint64) * CHAR_BIT;
479 static constexpr uint64 allOnes = uint64(~0);
480
481 uint64 value;
482 uint8 width;
483 bool sign;
484 };
485
486 // Additional comparison operators
487 inline bool operator==(const IntVal &v1, int64 v2)
488 {
489 return v1.GetExtValue() == v2;
490 }
491
492 inline bool operator==(int64 v1, const IntVal &v2)
493 {
494 return v2 == v1;
495 }
496
497 inline bool operator!=(const IntVal &v1, int64 v2)
498 {
499 return !(v1 == v2);
500 }
501
502 inline bool operator!=(int64 v1, const IntVal &v2)
503 {
504 return !(v2 == v1);
505 }
506
507 /// @return the smaller of two values
508 /// @note bit-width and sign must be the same for both parameters
Min(const IntVal & a,const IntVal & b)509 inline IntVal Min(const IntVal &a, const IntVal &b)
510 {
511 return a < b ? a : b;
512 }
513
514 /// @return the smaller of two values in terms of newType
515 /// @note returned value will have bit-width and sign obtained from newType
Min(const IntVal & a,const IntVal & b,PrimType newType)516 inline IntVal Min(const IntVal &a, const IntVal &b, PrimType newType)
517 {
518 return a.Less(b, newType) ? IntVal(a, newType) : IntVal(b, newType);
519 }
520
521 /// @return the larger of two values
522 /// @note bit-width and sign must be the same for both parameters
Max(const IntVal & a,const IntVal & b)523 inline IntVal Max(const IntVal &a, const IntVal &b)
524 {
525 return Min(a, b) == a ? b : a;
526 }
527
528 /// @return the larger of two values in terms of newType
529 /// @note returned value will have bit-width and sign obtained from newType
Max(const IntVal & a,const IntVal & b,PrimType newType)530 inline IntVal Max(const IntVal &a, const IntVal &b, PrimType newType)
531 {
532 return Min(a, b, newType) == a ? b : a;
533 }
534
535 /// @brief dump IntVal object to the output stream
536 std::ostream &operator<<(std::ostream &os, const IntVal &value);
537
538 // Arithmetic operators that manipulate on scalar (uint64) value and IntVal object
539 // in terms of sign and bit-width of IntVal object
540 inline IntVal operator+(const IntVal &v1, uint64 v2)
541 {
542 return v1 + IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
543 }
544
545 inline IntVal operator+(uint64 v1, const IntVal &v2)
546 {
547 return v2 + v1;
548 }
549
550 inline IntVal operator-(const IntVal &v1, uint64 v2)
551 {
552 return v1 - IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
553 }
554
555 inline IntVal operator-(uint64 v1, const IntVal &v2)
556 {
557 return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) - v2;
558 }
559
560 inline IntVal operator*(const IntVal &v1, uint64 v2)
561 {
562 return v1 * IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
563 }
564
565 inline IntVal operator*(uint64 v1, const IntVal &v2)
566 {
567 return v2 * v1;
568 }
569
570 inline IntVal operator/(const IntVal &v1, uint64 v2)
571 {
572 return v1 / IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
573 }
574
575 inline IntVal operator/(uint64 v1, const IntVal &v2)
576 {
577 return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) / v2;
578 }
579
580 inline IntVal operator%(const IntVal &v1, uint64 v2)
581 {
582 return v1 % IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
583 }
584
585 inline IntVal operator%(uint64 v1, const IntVal &v2)
586 {
587 return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) % v2;
588 }
589
590 inline IntVal operator&(const IntVal &v1, uint64 v2)
591 {
592 return v1 & IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
593 }
594
595 inline IntVal operator&(uint64 v1, const IntVal &v2)
596 {
597 return v2 & v1;
598 }
599
600 } // namespace maple
601
602 #endif // MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
603