• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef BERBERIS_DECODER_RISCV64_DECODER_H_
18 #define BERBERIS_DECODER_RISCV64_DECODER_H_
19 
20 #include <climits>
21 #include <cstdint>
22 #include <cstdlib>
23 #include <cstring>
24 #include <type_traits>
25 
26 #include "berberis/base/bit_util.h"
27 
28 namespace berberis {
29 
// Decoder for RISC-V (RV64) instructions.
//
// Decode() takes a pointer to the first 16-bit parcel of an instruction and decodes it into the
// instruction opcode and operand fields. The InsnConsumer method corresponding to the decoded
// opcode is called with the decoded fields as an argument (Unimplemented() is called for
// encodings this decoder does not handle). Returned is the instruction size in bytes: 2 for
// compressed (C extension) instructions, 4 for base-length instructions.
template <class InsnConsumer>
class Decoder {
 public:
  // insn_consumer is stored (not owned) and dereferenced on every Decode() call, so it must be
  // non-null and outlive this decoder.
  explicit Decoder(InsnConsumer* insn_consumer) : insn_consumer_(insn_consumer) {}

  // https://eel.is/c++draft/enum#dcl.enum-8
  // For an enumeration whose underlying type is fixed, the values of the enumeration are the values
  // of the underlying type. Otherwise, the values of the enumeration are the values representable
  // by a hypothetical integer type with minimal width M such that all enumerators can be
  // represented. The width of the smallest bit-field large enough to hold all the values of the
  // enumeration type is M. It is possible to define an enumeration that has values not defined by
  // any of its enumerators. If the enumerator-list is empty, the values of the enumeration are as
  // if the enumeration had a single enumerator with value 0.

  // To ensure that we wouldn't trigger UB by accident each opcode includes a kMaxXXX value
  // (kMaxOpOpcode, kMaxSystemOpcode and so on) which is meant to have all possible bit values set.
  // Note: all enums below are scoped (`enum class`) without an explicit enum-base, so their
  // underlying type is fixed to int and, per the rule quoted above, any decoded field value is a
  // valid value of the enum; the kMaxXXX enumerators are an additional safeguard.

  // Major opcode: bits 6:2 of a base-length (non-compressed) instruction.
  enum class BaseOpcode {
    kLoad = 0b00'000,
    kLoadFp = 0b00'001,
    kCustom0 = 0b00'010,
    kMiscMem = 0b00'011,
    kOpImm = 0b00'100,
    kAuipc = 0b00'101,
    kOpImm32 = 0b00'110,
    // Reserved 0b00'111,
    kStore = 0b01'000,
    kStoreFp = 0b01'001,
    kCustom1 = 0b01'010,
    kAmo = 0b01'011,
    kOp = 0b01'100,
    kLui = 0b01'101,
    kOp32 = 0b01'110,
    // Reserved 0b01'111,
    kMAdd = 0b10'000,
    kMSub = 0b10'001,
    kNmSub = 0b10'010,
    kNmAdd = 0b10'011,
    kOpFp = 0b10'100,
    // Reserved 0b10'101,
    kCustom2 = 0b10'110,
    // Reserved 0b10'111,
    kBranch = 0b11'000,
    kJalr = 0b11'001,
    // Reserved 0b11'010,
    kJal = 0b11'011,
    kSystem = 0b11'100,
    // Reserved 0b11'101,
    kCustom3 = 0b11'110,
    // Reserved 0b11'111,
    kMaxBaseOpcode = 0b11'111,
  };

  // Compressed opcode: funct3 (bits 15:13) in the upper three bits, op (bits 1:0) in the lower two.
  enum class CompressedOpcode {
    kAddi4spn = 0b000'00,
    kFld = 0b001'00,
    kLw = 0b010'00,
    kLd = 0b011'00,
    // Reserved 0b100'00
    kFsd = 0b101'00,
    kSw = 0b110'00,
    kSd = 0b111'00,
    kAddi = 0b000'01,
    kAddiw = 0b001'01,
    kLi = 0b010'01,
    kLui_Addi16sp = 0b011'01,
    kMisc_Alu = 0b100'01,
    kJ = 0b101'01,
    kBeqz = 0b110'01,
    kBnez = 0b111'01,
    kSlli = 0b000'10,
    kFldsp = 0b001'10,
    kLwsp = 0b010'10,
    kDsp = 0b011'10,  // C.LDSP
    kJr_Jalr_Mv_Add = 0b100'10,
    kFdsp = 0b101'10,  // C.FSDSP
    kSwsp = 0b110'10,
    kSdsp = 0b111'10,
    // Instructions with 0bxxx'11 opcodes are not compressed instructions and cannot be in this
    // table.
    kMaxCompressedOpcode = 0b111'11,
  };

  enum class CsrOpcode {
    kCsrrw = 0b01,
    kCsrrs = 0b10,
    kCsrrc = 0b11,
    kMaxCsrOpcode = 0b11,
  };

  enum class CsrImmOpcode {
    kCsrrwi = 0b01,
    kCsrrsi = 0b10,
    kCsrrci = 0b11,
    kMaxCsrOpcode = 0b11,
  };

  // FENCE "fm" field (bits 31:28).
  enum class FenceOpcode {
    kFence = 0b0000,
    kFenceTso = 0b1000,
    kFenceMaxOpcode = 0b1111,
  };

  // funct7 (bits 31:25) << 3 | funct3 (bits 14:12).
  enum class OpOpcode {
    kAdd = 0b0000'000'000,
    kSub = 0b0100'000'000,
    kSll = 0b0000'000'001,
    kSlt = 0b0000'000'010,
    kSltu = 0b0000'000'011,
    kXor = 0b0000'000'100,
    kSrl = 0b0000'000'101,
    kSra = 0b0100'000'101,
    kOr = 0b0000'000'110,
    kAnd = 0b0000'000'111,
    kMul = 0b0000'001'000,
    kMulh = 0b0000'001'001,
    kMulhsu = 0b0000'001'010,
    kMulhu = 0b0000'001'011,
    kDiv = 0b0000'001'100,
    kDivu = 0b0000'001'101,
    kRem = 0b0000'001'110,
    kRemu = 0b0000'001'111,
    kMaxOpOpcode = 0b1111'111'111,
  };

  // funct7 (bits 31:25) << 3 | funct3 (bits 14:12).
  enum class Op32Opcode {
    kAddw = 0b0000'000'000,
    kSubw = 0b0100'000'000,
    kSllw = 0b0000'000'001,
    kSrlw = 0b0000'000'101,
    kSraw = 0b0100'000'101,
    kMulw = 0b0000'001'000,
    kDivw = 0b0000'001'100,
    kDivuw = 0b0000'001'101,
    kRemw = 0b0000'001'110,
    kRemuw = 0b0000'001'111,
    kMaxOp32Opcode = 0b1111'111'111,
  };

  // funct5 (bits 31:27) << 3 | funct3 (bits 14:12).
  enum class AmoOpcode {
    kLrW = 0b00010'010,
    kScW = 0b00011'010,
    kAmoswapW = 0b00001'010,
    kAmoaddW = 0b00000'010,
    kAmoxorW = 0b00100'010,
    kAmoandW = 0b01100'010,
    kAmoorW = 0b01000'010,
    kAmominW = 0b10000'010,
    kAmomaxW = 0b10100'010,
    kAmominuW = 0b11000'010,
    kAmomaxuW = 0b11100'010,
    kLrD = 0b00010'011,
    kScD = 0b00011'011,
    kAmoswapD = 0b00001'011,
    kAmoaddD = 0b00000'011,
    kAmoxorD = 0b00100'011,
    kAmoandD = 0b01100'011,
    kAmoorD = 0b01000'011,
    kAmominD = 0b10000'011,
    kAmomaxD = 0b10100'011,
    kAmominuD = 0b11000'011,
    kAmomaxuD = 0b11100'011,
    kMaxAmoOpcode = 0b11111'111,
  };

  enum class OpFpOpcode {
    // Bit #2 = 1 means rm is an opcode extension.
    // Bit #3 = 1 means rs2 is an opcode extension
    // Bits #4, #1, and #0 - actual opcode.
    kFAdd = 0b0'0'0'00,
    kFSub = 0b0'0'0'01,
    kFMul = 0b0'0'0'10,
    kFDiv = 0b0'0'0'11,
    kMaxOpFpOpcode = 0b1'1'1'11,
  };

  enum class LoadOpcode {
    kLb = 0b000,
    kLh = 0b001,
    kLw = 0b010,
    kLd = 0b011,
    kLbu = 0b100,
    kLhu = 0b101,
    kLwu = 0b110,
    kMaxLoadOpcode = 0b1111,
  };

  enum class LoadFpOpcode {
    kFlw = 0b010,
    kFld = 0b011,
    kLoadFpMaxOpcode = 0b111,
  };

  enum class OpImmOpcode {
    kAddi = 0b000,
    kSlti = 0b010,
    kSltiu = 0b011,
    kXori = 0b100,
    kOri = 0b110,
    kAndi = 0b111,
    kMaxOpImmOpcode = 0b111,
  };

  enum class OpImm32Opcode {
    kAddiw = 0b000,
    kMaxOpImm32Opcode = 0b111,
  };

  // Bits above the 6-bit shamt (bits 31:26) << 3 | funct3.
  enum class ShiftImmOpcode {
    kSlli = 0b000000'001,
    kSrli = 0b000000'101,
    kSrai = 0b010000'101,
    kMaxShiftImmOpcode = 0b11111'111,
  };

  // Bits above the 5-bit shamt (bits 31:25) << 3 | funct3.
  enum class ShiftImm32Opcode {
    kSlliw = 0b0000000'001,
    kSrliw = 0b0000000'101,
    kSraiw = 0b0100000'101,
    kMaxShiftImm32Opcode = 0b111111'111,
  };

  enum class StoreOpcode {
    kSb = 0b000,
    kSh = 0b001,
    kSw = 0b010,
    kSd = 0b011,
    kMaxStoreOpcode = 0b111,
  };

  enum class StoreFpOpcode {
    kFsw = 0b010,
    kFsd = 0b011,
    kMaxStoreFpOpcode = 0b111,
  };

  // Instruction bits 31:7 taken as a whole.
  enum class SystemOpcode {
    kEcall = 0b000000000000'00000'000'00000,
    kEbreak = 0b000000000001'00000'000'00000,
    kMaxSystemOpcode = 0b111111111111'11111'111'11111,
  };

  enum class BranchOpcode {
    kBeq = 0b000,
    kBne = 0b001,
    kBlt = 0b100,
    kBge = 0b101,
    kBltu = 0b110,
    kBgeu = 0b111,
    kMaxBranchOpcode = 0b111,
  };

  enum class CsrRegister {
    kFFlags = 0b00'00'0000'0001,
    kFrm = 0b00'00'0000'0010,
    kFCsr = 0b00'00'0000'0011,
    kMaxCsrRegister = 0b11'11'1111'1111,
  };

  enum class FloatSize {
    kFloat = 0b00,
    kDouble = 0b01,
    kHalf = 0b10,
    kQuad = 0b11,
    kMaxFloatSize = 0b11,
  };

  // rl and aq are the ordering bits taken from instruction bits 25 and 26.
  struct AmoArgs {
    AmoOpcode opcode;
    uint8_t dst;
    uint8_t src1;
    uint8_t src2;
    bool rl : 1;
    bool aq : 1;
  };

  struct CsrArgs {
    CsrOpcode opcode;
    uint8_t dst;
    uint8_t src;
    CsrRegister csr;
  };

  struct CsrImmArgs {
    CsrImmOpcode opcode;
    uint8_t dst;
    uint8_t imm;
    CsrRegister csr;
  };

  // Successor set: sw/sr/so/si (bits 20..23); predecessor set: pw/pr/po/pi (bits 24..27).
  struct FenceArgs {
    FenceOpcode opcode;
    uint8_t dst;
    uint8_t src;
    bool sw : 1;
    bool sr : 1;
    bool so : 1;
    bool si : 1;
    bool pw : 1;
    bool pr : 1;
    bool po : 1;
    bool pi : 1;
  };

  struct FenceIArgs {
    uint8_t dst;
    uint8_t src;
    int16_t imm;
  };

  template <typename OpcodeType>
  struct OpArgsTemplate {
    OpcodeType opcode;
    uint8_t dst;
    uint8_t src1;
    uint8_t src2;
  };

  using OpArgs = OpArgsTemplate<OpOpcode>;
  using Op32Args = OpArgsTemplate<Op32Opcode>;

  template <typename OpcodeType>
  struct LoadArgsTemplate {
    OpcodeType opcode;
    uint8_t dst;
    uint8_t src;
    int16_t offset;
  };

  using LoadArgs = LoadArgsTemplate<LoadOpcode>;
  using LoadFpArgs = LoadArgsTemplate<LoadFpOpcode>;

  template <typename OpcodeType>
  struct OpImmArgsTemplate {
    OpcodeType opcode;
    uint8_t dst;
    uint8_t src;
    int16_t imm;
  };

  using OpImmArgs = OpImmArgsTemplate<OpImmOpcode>;
  using OpImm32Args = OpImmArgsTemplate<OpImm32Opcode>;

  struct SystemArgs {
    SystemOpcode opcode;
  };

  template <typename OpcodeType>
  struct ShiftImmArgsTemplate {
    OpcodeType opcode;
    uint8_t dst;
    uint8_t src;
    uint8_t imm;
  };

  using ShiftImmArgs = ShiftImmArgsTemplate<ShiftImmOpcode>;
  using ShiftImm32Args = ShiftImmArgsTemplate<ShiftImm32Opcode>;

  // src is the base address register, data is the register being stored.
  template <typename OpcodeType>
  struct StoreArgsTemplate {
    OpcodeType opcode;
    uint8_t src;
    int16_t offset;
    uint8_t data;
  };

  using StoreArgs = StoreArgsTemplate<StoreOpcode>;
  using StoreFpArgs = StoreArgsTemplate<StoreFpOpcode>;

  struct OpFpArgs {
    OpFpOpcode opcode;
    FloatSize float_size;
    uint8_t dst;
    uint8_t src1;
    uint8_t src2;
    uint8_t rm;
  };

  struct BranchArgs {
    BranchOpcode opcode;
    uint8_t src1;
    uint8_t src2;
    int16_t offset;
  };

  // imm already has the 20-bit immediate shifted into bits 31:12.
  struct UpperImmArgs {
    uint8_t dst;
    int32_t imm;
  };

  struct JumpAndLinkArgs {
    uint8_t dst;
    int32_t offset;
    uint8_t insn_len;
  };

  struct JumpAndLinkRegisterArgs {
    uint8_t dst;
    uint8_t base;
    int16_t offset;
    uint8_t insn_len;
  };

  // Decodes the instruction starting at `code`, invokes the matching consumer method and returns
  // the instruction size in bytes (2 or 4). `code` only needs 2-byte alignment.
  // NOTE(review): both paths read the parcels in host byte order, i.e. this assumes a
  // little-endian host.
  uint8_t Decode(const uint16_t* code) {
    constexpr uint16_t kInsnLenMask = uint16_t{0b11};
    if ((*code & kInsnLenMask) != kInsnLenMask) {
      code_ = *code;
      return DecodeCompressedInstruction();
    }
    // Warning: do not cast and dereference the pointer
    // since the address may not be 4-bytes aligned.
    memcpy(&code_, code, sizeof(code_));
    return DecodeBaseInstruction();
  }

  // Decodes the 16-bit instruction held in code_ and returns 2.
  uint8_t DecodeCompressedInstruction() {
    CompressedOpcode opcode_bits{(GetBits<uint8_t, 13, 3>() << 2) | GetBits<uint8_t, 0, 2>()};

    switch (opcode_bits) {
      case CompressedOpcode::kJ:
        DecodeCJ();
        break;
      case CompressedOpcode::kAddi4spn:
        DecodeCAddi4spn();
        break;
      case CompressedOpcode::kAddi:
        DecodeCAddi();
        break;
      case CompressedOpcode::kFld:
        DecodeCompressedLoadStore<LoadFpOpcode::kFld>();
        break;
      case CompressedOpcode::kLw:
        DecodeCompressedLoadStore<LoadOpcode::kLw>();
        break;
      case CompressedOpcode::kLd:
        DecodeCompressedLoadStore<LoadOpcode::kLd>();
        break;
      case CompressedOpcode::kFsd:
        DecodeCompressedLoadStore<StoreFpOpcode::kFsd>();
        break;
      case CompressedOpcode::kSw:
        // C.SW shares C.LW's immediate layout (word-scaled offset).
        DecodeCompressedLoadStore<StoreOpcode::kSw>();
        break;
      case CompressedOpcode::kSd:
        DecodeCompressedLoadStore<StoreOpcode::kSd>();
        break;
      default:
        insn_consumer_->Unimplemented();
    }
    return 2;
  }

  // Decodes C.LW/C.LD/C.FLD/C.SW/C.SD/C.FSD. `opcode` selects Load() vs Store() by its enum type.
  // The unsigned offset is uimm[5:3] = bits 12:10 plus, in bits 6:5:
  //   - uimm[2|6] for word accesses (even opcode value: kLw, kSw);
  //   - uimm[7:6] for doubleword accesses (odd opcode value: kLd, kSd, kFld, kFsd).
  template <auto opcode>
  void DecodeCompressedLoadStore() {
    uint8_t low_imm = GetBits<uint8_t, 5, 2>();
    uint8_t high_imm = GetBits<uint8_t, 10, 3>();
    uint8_t imm;
    if constexpr ((uint8_t(opcode) & 1) == 0) {
      // Indexed by bits 6:5: bit 6 maps to uimm[2] (0x04), bit 5 maps to uimm[6] (0x40).
      static constexpr uint8_t kLwLow[4] = {0x0, 0x40, 0x04, 0x44};
      imm = (kLwLow[low_imm] | high_imm << 3);
    } else {
      imm = (low_imm << 6 | high_imm << 3);
    }
    // 3-bit register fields select x8..x15.
    uint8_t rd = GetBits<uint8_t, 2, 3>();
    uint8_t rs = GetBits<uint8_t, 7, 3>();
    if constexpr (std::is_same_v<decltype(opcode), StoreOpcode> ||
                  std::is_same_v<decltype(opcode), StoreFpOpcode>) {
      const StoreArgsTemplate<decltype(opcode)> args = {
          .opcode = opcode,
          .src = uint8_t(8 + rs),
          .offset = imm,
          .data = uint8_t(8 + rd),
      };
      insn_consumer_->Store(args);
    } else {
      const LoadArgsTemplate<decltype(opcode)> args = {
          .opcode = opcode,
          .dst = uint8_t(8 + rd),
          .src = uint8_t(8 + rs),
          .offset = imm,
      };
      insn_consumer_->Load(args);
    }
  }

  // Decodes C.ADDI: imm[5] = bit 12, imm[4:0] = bits 6:2, rd/rs1 = bits 11:7.
  void DecodeCAddi() {
    uint8_t low_imm = GetBits<uint8_t, 2, 5>();
    uint8_t high_imm = GetBits<uint8_t, 12, 1>();
    int8_t imm = SignExtend<6>(high_imm << 5 | low_imm);
    uint8_t r = GetBits<uint8_t, 7, 5>();
    if (r == 0 || imm == 0) {
      // C.NOP and the HINT encodings (rd == x0 or imm == 0) have no architectural effect.
      // Report them as Nop only; do not also emit an OpImm.
      insn_consumer_->Nop();
      return;
    }
    const OpImmArgs args = {
        .opcode = OpImmOpcode::kAddi,
        .dst = r,
        .src = r,
        .imm = imm,
    };
    insn_consumer_->OpImm(args);
  }

  // Decodes C.J (jal x0, offset). Bits 12:2 hold offset[11|4|9:8|10|6|7|3:1|5]; kJHigh maps
  // bits 12:8 and kJLow maps bits 7:2 onto their offset positions. offset[11] is the sign bit.
  void DecodeCJ() {
    static constexpr uint16_t kJHigh[32] = {
        0x0,    0x400,  0x100,  0x500,  0x200,  0x600,  0x300,  0x700,  0x10,   0x410,  0x110,
        0x510,  0x210,  0x610,  0x310,  0x710,  0xf800, 0xfc00, 0xf900, 0xfd00, 0xfa00, 0xfe00,
        0xfb00, 0xff00, 0xf810, 0xfc10, 0xf910, 0xfd10, 0xfa10, 0xfe10, 0xfb10, 0xff10,
    };
    static constexpr uint8_t kJLow[64] = {
        0x0,  0x20, 0x2,  0x22, 0x4,  0x24, 0x6,  0x26, 0x8,  0x28, 0xa,  0x2a, 0xc,
        0x2c, 0xe,  0x2e, 0x80, 0xa0, 0x82, 0xa2, 0x84, 0xa4, 0x86, 0xa6, 0x88, 0xa8,
        0x8a, 0xaa, 0x8c, 0xac, 0x8e, 0xae, 0x40, 0x60, 0x42, 0x62, 0x44, 0x64, 0x46,
        0x66, 0x48, 0x68, 0x4a, 0x6a, 0x4c, 0x6c, 0x4e, 0x6e, 0xc0, 0xe0, 0xc2, 0xe2,
        0xc4, 0xe4, 0xc6, 0xe6, 0xc8, 0xe8, 0xca, 0xea, 0xcc, 0xec, 0xce, 0xee,
    };
    const JumpAndLinkArgs args = {
        .dst = 0,
        // Sign-extend from offset bit 11 with the same helper the 32-bit decoders use (the
        // pre-extended upper bits in kJHigh are redundant with this).
        .offset = SignExtend<12>(kJHigh[GetBits<uint16_t, 8, 5>()] |
                                 kJLow[GetBits<uint16_t, 2, 6>()]),
        .insn_len = 2,
    };
    insn_consumer_->JumpAndLink(args);
  }

  // Decodes C.ADDI4SPN (addi rd', x2, nzuimm). The tables map bits 12:9 and 8:5 onto
  // nzuimm[9:2]; the result is then scaled by 4.
  void DecodeCAddi4spn() {
    static constexpr uint8_t kAddi4spnHigh[16] = {
        0x0, 0x40, 0x80, 0xc0, 0x4, 0x44, 0x84, 0xc4, 0x8, 0x48, 0x88, 0xc8, 0xc, 0x4c, 0x8c, 0xcc};
    static constexpr uint8_t kAddi4spnLow[16] = {
        0x0, 0x2, 0x1, 0x3, 0x10, 0x12, 0x11, 0x13, 0x20, 0x22, 0x21, 0x23, 0x30, 0x32, 0x31, 0x33};
    int16_t imm = (kAddi4spnHigh[GetBits<uint8_t, 9, 4>()] | kAddi4spnLow[GetBits<uint8_t, 5, 4>()])
                  << 2;
    // If immediate is zero then this instruction is treated as unimplemented.
    // This includes RISC-V dedicated 16bit “unimplemented instruction” 0x0000.
    if (imm == 0) {
      Undefined();
      return;
    }
    const OpImmArgs args = {
        .opcode = OpImmOpcode::kAddi,
        .dst = uint8_t(8 + GetBits<uint8_t, 2, 3>()),
        .src = 2,
        .imm = imm,
    };
    insn_consumer_->OpImm(args);
  }

  // Decodes the 32-bit instruction held in code_ by its major opcode (bits 6:2) and returns 4.
  uint8_t DecodeBaseInstruction() {
    BaseOpcode opcode_bits{GetBits<uint8_t, 2, 5>()};

    switch (opcode_bits) {
      case BaseOpcode::kMiscMem:
        DecodeMiscMem();
        break;
      case BaseOpcode::kOp:
        DecodeOp<OpOpcode>();
        break;
      case BaseOpcode::kOp32:
        DecodeOp<Op32Opcode>();
        break;
      case BaseOpcode::kAmo:
        DecodeAmo();
        break;
      case BaseOpcode::kLoad:
        DecodeLoad<LoadOpcode>();
        break;
      case BaseOpcode::kLoadFp:
        DecodeLoad<LoadFpOpcode>();
        break;
      case BaseOpcode::kOpImm:
        DecodeOpImm<OpImmOpcode, ShiftImmOpcode, 6>();
        break;
      case BaseOpcode::kOpImm32:
        DecodeOpImm<OpImm32Opcode, ShiftImm32Opcode, 5>();
        break;
      case BaseOpcode::kOpFp:
        DecodeOpFp();
        break;
      case BaseOpcode::kStore:
        DecodeStore<StoreOpcode>();
        break;
      case BaseOpcode::kStoreFp:
        DecodeStore<StoreFpOpcode>();
        break;
      case BaseOpcode::kBranch:
        DecodeBranch();
        break;
      case BaseOpcode::kJal:
        DecodeJumpAndLink();
        break;
      case BaseOpcode::kJalr:
        DecodeJumpAndLinkRegister();
        break;
      case BaseOpcode::kSystem:
        DecodeSystem();
        break;
      case BaseOpcode::kLui:
        DecodeLui();
        break;
      case BaseOpcode::kAuipc:
        DecodeAuipc();
        break;
      default:
        insn_consumer_->Unimplemented();
    }
    return 4;
  }

 private:
  // Returns bits [start, start + size) of code_, zero-extended into ResultType.
  template <typename ResultType, uint32_t start, uint32_t size>
  ResultType GetBits() {
    static_assert(std::is_unsigned_v<ResultType>, "Only unsigned types are supported");
    static_assert(sizeof(ResultType) * CHAR_BIT >= size, "Too small ResultType for size");
    static_assert((start + size) <= 32 && size > 0, "Invalid start or size value");
    uint32_t shifted_val = code_ << (32 - start - size);
    return static_cast<ResultType>(shifted_val >> (32 - size));
  }

  // Signextend bits from size to the corresponding signed type of sizeof(Type) size.
  // If the result of this function is assigned to a wider signed type it'll automatically
  // sign-extend.
  template <unsigned size, typename Type>
  static auto SignExtend(const Type val) {
    static_assert(std::is_integral_v<Type>, "Only integral types are supported");
    static_assert(size > 0 && size < (sizeof(Type) * CHAR_BIT), "Invalid size value");
    typedef std::make_signed_t<Type> SignedType;
    struct {
      SignedType val : size;
    } holder = {.val = static_cast<SignedType>(val)};
    // Compiler takes care of sign-extension of the field with the specified bit-length.
    return static_cast<SignedType>(holder.val);
  }

  // Returns the U-type immediate (bits 31:12) in place, as a signed 32-bit value.
  int32_t GetUpperImm() {
    // Shift as unsigned to avoid left-shifting a signed value into the sign bit.
    return static_cast<int32_t>(GetBits<uint32_t, 12, 20>() << 12);
  }

  void Undefined() {
    // TODO(b/265372622): Handle undefined differently from unimplemented.
    insn_consumer_->Unimplemented();
  }

  // MISC-MEM: funct3 000 is FENCE, 001 is FENCE.I; everything else is undefined.
  void DecodeMiscMem() {
    uint8_t low_opcode = GetBits<uint8_t, 12, 3>();
    switch (low_opcode) {
      case 0b000: {
        // FENCE layout: fm[31:28] | PI PO PR PW [27:24] | SI SO SR SW [23:20].
        uint8_t high_opcode = GetBits<uint8_t, 28, 4>();
        FenceOpcode opcode = FenceOpcode{high_opcode};
        const FenceArgs args = {
            .opcode = opcode,
            .dst = GetBits<uint8_t, 7, 5>(),
            .src = GetBits<uint8_t, 15, 5>(),
            .sw = bool(GetBits<uint8_t, 20, 1>()),
            .sr = bool(GetBits<uint8_t, 21, 1>()),
            .so = bool(GetBits<uint8_t, 22, 1>()),
            .si = bool(GetBits<uint8_t, 23, 1>()),
            .pw = bool(GetBits<uint8_t, 24, 1>()),
            .pr = bool(GetBits<uint8_t, 25, 1>()),
            .po = bool(GetBits<uint8_t, 26, 1>()),
            .pi = bool(GetBits<uint8_t, 27, 1>()),
        };
        insn_consumer_->Fence(args);
        break;
      }
      case 0b001: {
        uint16_t imm = GetBits<uint16_t, 20, 12>();
        const FenceIArgs args = {
            .dst = GetBits<uint8_t, 7, 5>(),
            .src = GetBits<uint8_t, 15, 5>(),
            .imm = SignExtend<12>(imm),
        };
        insn_consumer_->FenceI(args);
        break;
      }
      default:
        Undefined();
        return;
    }
  }

  // OP / OP-32 (register-register): opcode = funct7 << 3 | funct3.
  template <typename OpcodeType>
  void DecodeOp() {
    uint16_t low_opcode = GetBits<uint16_t, 12, 3>();
    uint16_t high_opcode = GetBits<uint16_t, 25, 7>();
    OpcodeType opcode{int16_t(low_opcode | (high_opcode << 3))};
    const OpArgsTemplate<OpcodeType> args = {
        .opcode = opcode,
        .dst = GetBits<uint8_t, 7, 5>(),
        .src1 = GetBits<uint8_t, 15, 5>(),
        .src2 = GetBits<uint8_t, 20, 5>(),
    };
    insn_consumer_->Op(args);
  }

  // AMO: opcode = funct5 << 3 | funct3. LR (funct5 0b00010) requires rs2 == 0.
  void DecodeAmo() {
    uint16_t low_opcode = GetBits<uint16_t, 12, 3>();
    uint16_t high_opcode = GetBits<uint16_t, 27, 5>();
    // lr instruction must have rs2 == 0
    if (high_opcode == 0b00010 && GetBits<uint8_t, 20, 5>() != 0) {
      Undefined();
      return;
    }
    AmoOpcode opcode = AmoOpcode{low_opcode | (high_opcode << 3)};
    const AmoArgs args = {
        .opcode = opcode,
        .dst = GetBits<uint8_t, 7, 5>(),
        .src1 = GetBits<uint8_t, 15, 5>(),
        .src2 = GetBits<uint8_t, 20, 5>(),
        .rl = bool(GetBits<uint8_t, 25, 1>()),
        .aq = bool(GetBits<uint8_t, 26, 1>()),
    };
    insn_consumer_->Amo(args);
  }

  void DecodeLui() {
    const UpperImmArgs args = {
        .dst = GetBits<uint8_t, 7, 5>(),
        .imm = GetUpperImm(),
    };
    insn_consumer_->Lui(args);
  }

  void DecodeAuipc() {
    const UpperImmArgs args = {
        .dst = GetBits<uint8_t, 7, 5>(),
        .imm = GetUpperImm(),
    };
    insn_consumer_->Auipc(args);
  }

  // LOAD / LOAD-FP: funct3 is the opcode, offset is the sign-extended imm[11:0] (bits 31:20).
  template <typename OpcodeType>
  void DecodeLoad() {
    OpcodeType opcode{GetBits<uint8_t, 12, 3>()};
    const LoadArgsTemplate<OpcodeType> args = {
        .opcode = opcode,
        .dst = GetBits<uint8_t, 7, 5>(),
        .src = GetBits<uint8_t, 15, 5>(),
        .offset = SignExtend<12>(GetBits<uint16_t, 20, 12>()),
    };
    insn_consumer_->Load(args);
  }

  // STORE / STORE-FP: offset is imm[11:5] (bits 31:25) : imm[4:0] (bits 11:7), sign-extended.
  template <typename OpcodeType>
  void DecodeStore() {
    OpcodeType opcode{GetBits<uint8_t, 12, 3>()};

    uint16_t low_imm = GetBits<uint16_t, 7, 5>();
    uint16_t high_imm = GetBits<uint16_t, 25, 7>();

    const StoreArgsTemplate<OpcodeType> args = {
        .opcode = opcode,
        .src = GetBits<uint8_t, 15, 5>(),
        .offset = SignExtend<12>(int16_t(low_imm | (high_imm << 5))),
        .data = GetBits<uint8_t, 20, 5>(),
    };
    insn_consumer_->Store(args);
  }

  // OP-IMM / OP-IMM-32. funct3 001 and 101 are shifts whose shamt is kShiftFieldSize bits wide
  // (bits 20 and up); the bits above shamt extend the opcode. Everything else is an
  // immediate op with a sign-extended 12-bit immediate.
  template <typename OpOpcodeType, typename ShiftOpcodeType, uint32_t kShiftFieldSize>
  void DecodeOpImm() {
    uint8_t low_opcode = GetBits<uint8_t, 12, 3>();
    if (low_opcode != 0b001 && low_opcode != 0b101) {
      OpOpcodeType opcode{low_opcode};

      uint16_t imm = GetBits<uint16_t, 20, 12>();

      const OpImmArgsTemplate<OpOpcodeType> args = {
          .opcode = opcode,
          .dst = GetBits<uint8_t, 7, 5>(),
          .src = GetBits<uint8_t, 15, 5>(),
          .imm = SignExtend<12>(imm),
      };
      insn_consumer_->OpImm(args);
    } else {
      uint16_t high_opcode = GetBits<uint16_t, 20 + kShiftFieldSize, 12 - kShiftFieldSize>();
      ShiftOpcodeType opcode{int16_t(low_opcode | (high_opcode << 3))};

      const ShiftImmArgsTemplate<ShiftOpcodeType> args = {
          .opcode = opcode,
          .dst = GetBits<uint8_t, 7, 5>(),
          .src = GetBits<uint8_t, 15, 5>(),
          .imm = GetBits<uint8_t, 20, kShiftFieldSize>(),
      };
      insn_consumer_->OpImm(args);
    }
  }

  // BRANCH: offset is imm[12|10:5] (bits 31:25) and imm[4:1|11] (bits 11:7), sign-extended.
  void DecodeBranch() {
    BranchOpcode opcode{GetBits<uint8_t, 12, 3>()};

    // Decode the offset.
    auto low_imm = GetBits<uint16_t, 8, 4>();
    auto mid_imm = GetBits<uint16_t, 25, 6>();
    auto bit11_imm = GetBits<uint16_t, 7, 1>();
    auto bit12_imm = GetBits<uint16_t, 31, 1>();
    auto offset =
        static_cast<int16_t>(low_imm | (mid_imm << 4) | (bit11_imm << 10) | (bit12_imm << 11));

    const BranchArgs args = {
        .opcode = opcode,
        .src1 = GetBits<uint8_t, 15, 5>(),
        .src2 = GetBits<uint8_t, 20, 5>(),
        // The offset is encoded as 2-byte units, we need to multiply by 2.
        .offset = SignExtend<13>(int16_t(offset * 2)),
    };
    insn_consumer_->Branch(args);
  }

  // JAL: offset is imm[20|10:1|11|19:12] in bits 31:12, sign-extended.
  void DecodeJumpAndLink() {
    // Decode the offset.
    auto low_imm = GetBits<uint32_t, 21, 10>();
    auto mid_imm = GetBits<uint32_t, 12, 8>();
    auto bit11_imm = GetBits<uint32_t, 20, 1>();
    auto bit20_imm = GetBits<uint32_t, 31, 1>();
    auto offset =
        static_cast<int32_t>(low_imm | (bit11_imm << 10) | (mid_imm << 11) | (bit20_imm << 19));

    const JumpAndLinkArgs args = {
        .dst = GetBits<uint8_t, 7, 5>(),
        // The offset is encoded as 2-byte units, we need to multiply by 2.
        .offset = SignExtend<21>(offset * 2),
        .insn_len = 4,
    };
    insn_consumer_->JumpAndLink(args);
  }

  // OP-FP: funct5 (bits 31:27) is the opcode, fmt (bits 26:25) the float size, rm bits 14:12.
  void DecodeOpFp() {
    uint8_t float_size = GetBits<uint8_t, 25, 2>();
    uint8_t opcode_bits = GetBits<uint8_t, 27, 5>();
    const OpFpArgs args = {
        .opcode = OpFpOpcode(opcode_bits),
        .float_size = FloatSize(float_size),
        .dst = GetBits<uint8_t, 7, 5>(),
        .src1 = GetBits<uint8_t, 15, 5>(),
        .src2 = GetBits<uint8_t, 20, 5>(),
        .rm = GetBits<uint8_t, 12, 3>(),
    };
    insn_consumer_->OpFp(args);
  }

  // SYSTEM: funct3[1:0] == 0 -> System() keyed on bits 31:7; otherwise a CSR access, using the
  // immediate form when funct3[2] (bit 14) is set.
  void DecodeSystem() {
    uint8_t low_opcode = GetBits<uint8_t, 12, 2>();
    if (low_opcode == 0b00) {
      int32_t opcode = GetBits<uint32_t, 7, 25>();
      const SystemArgs args = {
          .opcode = SystemOpcode(opcode),
      };
      insn_consumer_->System(args);
      return;
    }
    if (GetBits<uint8_t, 14, 1>()) {
      CsrImmOpcode opcode = CsrImmOpcode(low_opcode);
      const CsrImmArgs args = {
          .opcode = opcode,
          .dst = GetBits<uint8_t, 7, 5>(),
          .imm = GetBits<uint8_t, 15, 5>(),
          .csr = CsrRegister(GetBits<uint16_t, 20, 12>()),
      };
      insn_consumer_->Csr(args);
      return;
    }
    CsrOpcode opcode = CsrOpcode(low_opcode);
    const CsrArgs args = {
        .opcode = opcode,
        .dst = GetBits<uint8_t, 7, 5>(),
        .src = GetBits<uint8_t, 15, 5>(),
        .csr = CsrRegister(GetBits<uint16_t, 20, 12>()),
    };
    insn_consumer_->Csr(args);
  }

  // JALR: funct3 must be 000; offset is the sign-extended imm[11:0] (bits 31:20).
  void DecodeJumpAndLinkRegister() {
    if (GetBits<uint8_t, 12, 3>() != 0b000) {
      Undefined();
      return;
    }
    const JumpAndLinkRegisterArgs args = {
        .dst = GetBits<uint8_t, 7, 5>(),
        .base = GetBits<uint8_t, 15, 5>(),
        .offset = SignExtend<12>(GetBits<uint16_t, 20, 12>()),
        .insn_len = 4,
    };
    insn_consumer_->JumpAndLinkRegister(args);
  }

  InsnConsumer* insn_consumer_;
  uint32_t code_;
};
913 
914 }  // namespace berberis
915 
916 #endif  // BERBERIS_DECODER_RISCV64_DECODER_H_
917