1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
17 #define MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
18
19 #include "mir_type.h"
20
21 namespace maple {
22
23 /// @brief this class provides different operations on signed and unsigned integers with arbitrary bit-width
24 class IntVal {
25 public:
26 /// @brief create zero value with zero bit-width
IntVal()27 IntVal() : value(0), width(0), sign(false) {}
28
IntVal(uint64 val,uint8 bitWidth,bool isSigned)29 IntVal(uint64 val, uint8 bitWidth, bool isSigned) : value(val), width(bitWidth), sign(isSigned)
30 {
31 DEBUG_ASSERT(width <= valBitSize && width != 0, "bit-width is too wide");
32 TruncInPlace();
33 }
34
IntVal(uint64 val,PrimType type)35 IntVal(uint64 val, PrimType type) : IntVal(val, GetPrimTypeActualBitSize(type), IsSignedInteger(type))
36 {
37 DEBUG_ASSERT(IsPrimitiveInteger(type), "Type must be integral");
38 }
39
IntVal(const IntVal & val)40 IntVal(const IntVal &val) : IntVal(val.value, val.width, val.sign) {}
41
IntVal(const IntVal & val,PrimType type)42 IntVal(const IntVal &val, PrimType type) : IntVal(val.value, type) {}
43
IntVal(const IntVal & val,uint8 bitWidth,bool isSigned)44 IntVal(const IntVal &val, uint8 bitWidth, bool isSigned) : IntVal(val.value, bitWidth, isSigned) {}
45
IntVal(const IntVal & val,bool isSigned)46 IntVal(const IntVal &val, bool isSigned) : IntVal(val.value, val.width, isSigned) {}
47
48 IntVal &operator=(const IntVal &other)
49 {
50 if (width == 0) {
51 // Allow 'this' to be assigned with new bit-width and sign iff
52 // its original bit-width is zero (i.e. the value was created by the default ctor)
53 Assign(other);
54 } else {
55 // Otherwise, assign only new value, but sign and width must be the same
56 DEBUG_ASSERT(width == other.width && sign == other.sign, "different bit-width or sign");
57 value = other.value;
58 }
59
60 return *this;
61 }
62
63 IntVal &operator=(uint64 other)
64 {
65 DEBUG_ASSERT(width != 0, "can't be assigned to value with unknown size");
66 value = other;
67 TruncInPlace();
68
69 return *this;
70 }
71
Assign(const IntVal & other)72 void Assign(const IntVal &other)
73 {
74 value = other.value;
75 width = other.width;
76 sign = other.sign;
77 }
78
79 /// @return bit-width of the value
GetBitWidth()80 uint8 GetBitWidth() const
81 {
82 return width;
83 }
84
85 /// @return true if the value is signed
IsSigned()86 bool IsSigned() const
87 {
88 return sign;
89 }
90
91 /// @return sign or zero extended value depending on its signedness
92 int64 GetExtValue(uint8 size = 0) const
93 {
94 return sign ? GetSXTValue(size) : GetZXTValue(size);
95 }
96
97 /// @return zero extended value
98 uint64 GetZXTValue(uint8 size = 0) const
99 {
100 // if size == 0, just return the value itself because it's already truncated for an appropriate width
101 return size ? (value << (valBitSize - size)) >> (valBitSize - size) : value;
102 }
103
104 /// @return sign extended value
105 int64 GetSXTValue(uint8 size = 0) const
106 {
107 uint8 bitWidth = size ? size : width;
108 return static_cast<int64>(value << (valBitSize - bitWidth)) >> (valBitSize - bitWidth);
109 }
110
111 /// @return true if the (most significant bit) MSB is set
GetSignBit()112 bool GetSignBit() const
113 {
114 return GetBit(width - 1);
115 }
116
117 /// @return true if all bits are 1
AreAllBitsOne()118 bool AreAllBitsOne() const
119 {
120 return value == (allOnes >> (valBitSize - width));
121 }
122
123 /// @return true if the value is maximum considering its signedness
IsMaxValue()124 bool IsMaxValue() const
125 {
126 return sign ? value == ((uint64(1) << (width - 1)) - 1) : AreAllBitsOne();
127 }
128
129 /// @return true if the value is minimum considering its signedness
IsMinValue()130 bool IsMinValue() const
131 {
132 return sign ? value == (uint64(1) << (width - 1)) : value == 0;
133 }
134
135 // Comparison operators that manipulate on values with the same sign and bit-width
136 bool operator==(const IntVal &rhs) const
137 {
138 DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
139 return value == rhs.value;
140 }
141
142 bool operator!=(const IntVal &rhs) const
143 {
144 return !(*this == rhs);
145 }
146
147 bool operator<(const IntVal &rhs) const
148 {
149 DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
150 return sign ? GetSXTValue() < rhs.GetSXTValue() : value < rhs.value;
151 }
152
153 bool operator>(const IntVal &rhs) const
154 {
155 DEBUG_ASSERT(width == rhs.width && sign == rhs.sign, "bit-width and sign must be the same");
156 return sign ? GetSXTValue() > rhs.GetSXTValue() : value > rhs.value;
157 }
158
159 bool operator<=(const IntVal &rhs) const
160 {
161 return !(*this > rhs);
162 }
163
164 bool operator>=(const IntVal &rhs) const
165 {
166 return !(*this < rhs);
167 }
168
169 // Arithmetic and bitwise operators that manipulate on values with the same sign and bit-width
170 IntVal operator+(const IntVal &val) const
171 {
172 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
173 return IntVal(value + val.value, width, sign);
174 }
175
176 IntVal operator-(const IntVal &val) const
177 {
178 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
179 return IntVal(value - val.value, width, sign);
180 }
181
182 IntVal operator*(const IntVal &val) const
183 {
184 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
185 return IntVal(value * val.value, width, sign);
186 }
187
188 IntVal operator/(const IntVal &divisor) const;
189
190 IntVal operator%(const IntVal &divisor) const;
191
192 IntVal operator&(const IntVal &val) const
193 {
194 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
195 return IntVal(value & val.value, width, sign);
196 }
197
198 IntVal operator|(const IntVal &val) const
199 {
200 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
201 return IntVal(value | val.value, width, sign);
202 }
203
204 IntVal operator^(const IntVal &val) const
205 {
206 DEBUG_ASSERT(width == val.width && sign == val.sign, "bit-width and sign must be the same");
207 return IntVal(value ^ val.value, width, sign);
208 }
209
210 /// @brief shift-left operator
211 IntVal operator<<(uint64 bits) const
212 {
213 DEBUG_ASSERT(bits <= width, "invalid shift value");
214 return IntVal(value << bits, width, sign);
215 }
216
217 IntVal operator<<(const IntVal &bits) const
218 {
219 return *this << bits.value;
220 }
221
222 /// @brief shift-right operator
223 /// @note if value is signed this operator works as arithmetic shift-right operator,
224 /// otherwise it's logical shift-right operator
225 IntVal operator>>(uint64 bits) const
226 {
227 DEBUG_ASSERT(bits <= width, "invalid shift value");
228 return IntVal(sign ? GetSXTValue() >> bits : value >> bits, width, sign);
229 }
230
231 IntVal operator>>(const IntVal &bits) const
232 {
233 return *this >> bits.value;
234 }
235
236 // Comparison operators that compare values obtained from primitive type.
237 /// @note Note that these functions work as follows:
238 /// 1) sign or zero extend both values (*this and/or rhs) to the new bit-width
239 /// obtained from pType depending on their original signedness;
240 /// or truncate the values if their original bit-width is greater than new one
241 /// 2) then perform the operation itself on new given values that have the same bit-width and sign
242 /// @warning it's better to avoid using these function in favor of operator==, operator< etc
Equal(const IntVal & rhs,PrimType pType)243 bool Equal(const IntVal &rhs, PrimType pType) const
244 {
245 return TruncOrExtend(pType) == rhs.TruncOrExtend(pType);
246 }
247
Less(const IntVal & rhs,PrimType pType)248 bool Less(const IntVal &rhs, PrimType pType) const
249 {
250 return TruncOrExtend(pType) < rhs.TruncOrExtend(pType);
251 }
252
Greater(const IntVal & rhs,PrimType pType)253 bool Greater(const IntVal &rhs, PrimType pType) const
254 {
255 return TruncOrExtend(pType) > rhs.TruncOrExtend(pType);
256 }
257
258 // Arithmetic and bitwise operators that allow creating a new value
259 // with the bit-width and sign obtained from primitive type
260 /// @note Note that these functions work as follows:
261 /// 1) sign or zero extend both values (*this and rhs) to the new bit-width
262 /// obtained from pType depending on their original signedness;
263 /// or truncate the values if their original bit-width is greater than new one
264 /// 2) then perform the operation itself on new given values that have the same bit-width and sign
265 /// @warning it's better to avoid using these function in favor of operator+, operator- etc
Add(const IntVal & val,PrimType pType)266 IntVal Add(const IntVal &val, PrimType pType) const
267 {
268 return TruncOrExtend(pType) + val.TruncOrExtend(pType);
269 }
270
Sub(const IntVal & val,PrimType pType)271 IntVal Sub(const IntVal &val, PrimType pType) const
272 {
273 return TruncOrExtend(pType) - val.TruncOrExtend(pType);
274 }
275
Mul(const IntVal & val,PrimType pType)276 IntVal Mul(const IntVal &val, PrimType pType) const
277 {
278 return TruncOrExtend(pType) * val.TruncOrExtend(pType);
279 }
280
Div(const IntVal & divisor,PrimType pType)281 IntVal Div(const IntVal &divisor, PrimType pType) const
282 {
283 return TruncOrExtend(pType) / divisor.TruncOrExtend(pType);
284 }
285
286 // sign division in terms of new bitWidth
SDiv(const IntVal & divisor,uint8 bitWidth)287 IntVal SDiv(const IntVal &divisor, uint8 bitWidth) const
288 {
289 return TruncOrExtend(bitWidth, true) / divisor.TruncOrExtend(bitWidth, true);
290 }
291
292 // unsigned division in terms of new bitWidth
UDiv(const IntVal & divisor,uint8 bitWidth)293 IntVal UDiv(const IntVal &divisor, uint8 bitWidth) const
294 {
295 return TruncOrExtend(bitWidth, false) / divisor.TruncOrExtend(bitWidth, false);
296 }
297
298 // unsigned division in terms of new bitWidth
Rem(const IntVal & divisor,PrimType pType)299 IntVal Rem(const IntVal &divisor, PrimType pType) const
300 {
301 return TruncOrExtend(pType) % divisor.TruncOrExtend(pType);
302 }
303
304 // signed modulo in terms of new bitWidth
SRem(const IntVal & divisor,uint8 bitWidth)305 IntVal SRem(const IntVal &divisor, uint8 bitWidth) const
306 {
307 return TruncOrExtend(bitWidth, true) % divisor.TruncOrExtend(bitWidth, true);
308 }
309
310 // unsigned modulo in terms of new bitWidth
URem(const IntVal & divisor,uint8 bitWidth)311 IntVal URem(const IntVal &divisor, uint8 bitWidth) const
312 {
313 return TruncOrExtend(bitWidth, false) % divisor.TruncOrExtend(bitWidth, false);
314 }
315
And(const IntVal & val,PrimType pType)316 IntVal And(const IntVal &val, PrimType pType) const
317 {
318 return TruncOrExtend(pType) & val.TruncOrExtend(pType);
319 }
320
Or(const IntVal & val,PrimType pType)321 IntVal Or(const IntVal &val, PrimType pType) const
322 {
323 return TruncOrExtend(pType) | val.TruncOrExtend(pType);
324 }
325
Xor(const IntVal & val,PrimType pType)326 IntVal Xor(const IntVal &val, PrimType pType) const
327 {
328 return TruncOrExtend(pType) ^ val.TruncOrExtend(pType);
329 }
330
331 // left-shift operators
Shl(const IntVal & shift,PrimType pType)332 IntVal Shl(const IntVal &shift, PrimType pType) const
333 {
334 return Shl(shift.value, pType);
335 }
336
Shl(uint64 shift,PrimType pType)337 IntVal Shl(uint64 shift, PrimType pType) const
338 {
339 return TruncOrExtend(pType) << shift;
340 }
341
342 // logical right-shift operators (MSB is zero extended)
LShr(const IntVal & shift,PrimType pType)343 IntVal LShr(const IntVal &shift, PrimType pType) const
344 {
345 return LShr(shift.value, pType);
346 }
347
LShr(uint64 shift,PrimType pType)348 IntVal LShr(uint64 shift, PrimType pType) const
349 {
350 IntVal ret = TruncOrExtend(pType);
351
352 DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
353 ret.value >>= shift;
354
355 return ret;
356 }
357
LShr(uint64 shift)358 IntVal LShr(uint64 shift) const
359 {
360 DEBUG_ASSERT(shift <= width, "invalid shift value");
361 return IntVal(value >> shift, width, sign);
362 }
363
364 // arithmetic right-shift operators (MSB is sign extended)
AShr(const IntVal & shift,PrimType pType)365 IntVal AShr(const IntVal &shift, PrimType pType) const
366 {
367 return AShr(shift.value, pType);
368 }
369
AShr(uint64 shift,PrimType pType)370 IntVal AShr(uint64 shift, PrimType pType) const
371 {
372 IntVal ret = TruncOrExtend(pType);
373
374 DEBUG_ASSERT(shift <= ret.width, "invalid shift value");
375 ret.value = ret.GetSXTValue() >> shift;
376 ret.TruncInPlace();
377
378 return ret;
379 }
380
381 /// @brief invert all bits of value
382 IntVal operator~() const
383 {
384 return IntVal(~value, width, sign);
385 }
386
387 /// @return negated value
388 IntVal operator-() const
389 {
390 return IntVal(~value + 1, width, sign);
391 }
392
393 /// @brief truncate value to the given bit-width in-place.
394 /// @note if bitWidth is not passed (or zero), the original
395 /// bit-width is preserved and the value is truncated to the original bit-width
396 void TruncInPlace(uint8 bitWidth = 0)
397 {
398 DEBUG_ASSERT(valBitSize >= bitWidth, "invalid bit-width for truncate");
399
400 value &= allOnes >> (valBitSize - (bitWidth ? bitWidth : width));
401
402 if (bitWidth) {
403 width = bitWidth;
404 }
405 }
406
407 /// @return truncated value to the given bit-width
408 /// @note returned value will have bit-width and sign obtained from newType
Trunc(PrimType newType)409 IntVal Trunc(PrimType newType) const
410 {
411 return Trunc(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
412 }
413
414 /// @return sign or zero extended value depending on its signedness
415 /// @note returned value will have bit-width and sign obtained from newType
Extend(PrimType newType)416 IntVal Extend(PrimType newType) const
417 {
418 return Extend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
419 }
420
421 /// @return sign/zero extended value or truncated value depending on bit-width
422 /// @note returned value will have bit-width and sign obtained from newType
TruncOrExtend(PrimType newType)423 IntVal TruncOrExtend(PrimType newType) const
424 {
425 return TruncOrExtend(GetPrimTypeActualBitSize(newType), IsSignedInteger(newType));
426 }
427
TruncOrExtend(uint8 newWidth,bool isSigned)428 IntVal TruncOrExtend(uint8 newWidth, bool isSigned) const
429 {
430 return newWidth <= width ? Trunc(newWidth, isSigned) : Extend(newWidth, isSigned);
431 }
432
433 private:
GetBit(uint8 bit)434 bool GetBit(uint8 bit) const
435 {
436 DEBUG_ASSERT(bit < width, "Required bit is out of value range");
437 return (value & (uint64(1) << bit)) != 0;
438 }
439
Trunc(uint8 newWidth,bool isSigned)440 IntVal Trunc(uint8 newWidth, bool isSigned) const
441 {
442 return {value, newWidth, isSigned};
443 }
444
Extend(uint8 newWidth,bool isSigned)445 IntVal Extend(uint8 newWidth, bool isSigned) const
446 {
447 DEBUG_ASSERT(newWidth > width, "invalid size for extension");
448 return IntVal(GetExtValue(), newWidth, isSigned);
449 }
450
451 static constexpr uint8 valBitSize = sizeof(uint64) * CHAR_BIT;
452 static constexpr uint64 allOnes = uint64(~0);
453
454 uint64 value;
455 uint8 width;
456 bool sign;
457 };
458
459 // Additional comparison operators
460 inline bool operator==(const IntVal &v1, int64 v2)
461 {
462 return v1.GetExtValue() == v2;
463 }
464
465 inline bool operator==(int64 v1, const IntVal &v2)
466 {
467 return v2 == v1;
468 }
469
470 inline bool operator!=(const IntVal &v1, int64 v2)
471 {
472 return !(v1 == v2);
473 }
474
475 inline bool operator!=(int64 v1, const IntVal &v2)
476 {
477 return !(v2 == v1);
478 }
479
480 /// @return the smaller of two values
481 /// @note bit-width and sign must be the same for both parameters
Min(const IntVal & a,const IntVal & b)482 inline IntVal Min(const IntVal &a, const IntVal &b)
483 {
484 return a < b ? a : b;
485 }
486
487 /// @return the smaller of two values in terms of newType
488 /// @note returned value will have bit-width and sign obtained from newType
Min(const IntVal & a,const IntVal & b,PrimType newType)489 inline IntVal Min(const IntVal &a, const IntVal &b, PrimType newType)
490 {
491 return a.Less(b, newType) ? IntVal(a, newType) : IntVal(b, newType);
492 }
493
494 /// @return the larger of two values
495 /// @note bit-width and sign must be the same for both parameters
Max(const IntVal & a,const IntVal & b)496 inline IntVal Max(const IntVal &a, const IntVal &b)
497 {
498 return Min(a, b) == a ? b : a;
499 }
500
501 /// @return the larger of two values in terms of newType
502 /// @note returned value will have bit-width and sign obtained from newType
Max(const IntVal & a,const IntVal & b,PrimType newType)503 inline IntVal Max(const IntVal &a, const IntVal &b, PrimType newType)
504 {
505 return Min(a, b, newType) == a ? b : a;
506 }
507
/// @brief dump IntVal object to the output stream
/// @note only declared in this header; the definition, and therefore the exact output format,
///       is not visible here
/// @return presumably `os`, to allow chaining — TODO confirm against the definition
std::ostream &operator<<(std::ostream &os, const IntVal &value);
510
511 // Arithmetic operators that manipulate on scalar (uint64) value and IntVal object
512 // in terms of sign and bit-width of IntVal object
513 inline IntVal operator+(const IntVal &v1, uint64 v2)
514 {
515 return v1 + IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
516 }
517
518 inline IntVal operator+(uint64 v1, const IntVal &v2)
519 {
520 return v2 + v1;
521 }
522
523 inline IntVal operator-(const IntVal &v1, uint64 v2)
524 {
525 return v1 - IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
526 }
527
528 inline IntVal operator-(uint64 v1, const IntVal &v2)
529 {
530 return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) - v2;
531 }
532
533 inline IntVal operator*(const IntVal &v1, uint64 v2)
534 {
535 return v1 * IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
536 }
537
538 inline IntVal operator*(uint64 v1, const IntVal &v2)
539 {
540 return v2 * v1;
541 }
542
543 inline IntVal operator/(const IntVal &v1, uint64 v2)
544 {
545 return v1 / IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
546 }
547
548 inline IntVal operator/(uint64 v1, const IntVal &v2)
549 {
550 return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) / v2;
551 }
552
553 inline IntVal operator%(const IntVal &v1, uint64 v2)
554 {
555 return v1 % IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
556 }
557
558 inline IntVal operator%(uint64 v1, const IntVal &v2)
559 {
560 return IntVal(v1, v2.GetBitWidth(), v2.IsSigned()) % v2;
561 }
562
563 inline IntVal operator&(const IntVal &v1, uint64 v2)
564 {
565 return v1 & IntVal(v2, v1.GetBitWidth(), v1.IsSigned());
566 }
567
568 inline IntVal operator&(uint64 v1, const IntVal &v2)
569 {
570 return v2 & v1;
571 }
572
573 } // namespace maple
574
575 #endif // MAPLE_UTIL_INCLUDE_MPL_INT_VAL_H
576