• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdbool.h>
18 #include <stdint.h>
19 #include <stdio.h>
20 #include <stdarg.h>
21 
22 #include "next/apf_defs.h"
23 #include "next/apf.h"
24 #include "disassembler.h"
25 
// Evaluates to true when (c) survives a round trip through uint32_t. Its real
// job is at compile time: if (c) has a signed type, the comparison mixes signed
// and unsigned operands and produces a sign-compare warning, which this build
// promotes to an error. Bounds checks can therefore omit a redundant ">= 0"
// test, which would itself warn when applied to an unsigned expression.
#define ENFORCE_UNSIGNED(c) ((uint32_t)(c) == (c))
30 
// Output state shared by the disassembler. apf_disassemble() resets the write
// cursor and rewrites both buffers on every call.
char prefix_buf[16];     // instruction-length prefix, formatted as "(%4u) "
char print_buf[8196];    // text of the instruction currently being printed
char* buf_ptr;           // next write position inside print_buf (see bprintf)
int buf_remain;          // bytes still available in print_buf after buf_ptr
bool v6_mode = false;    // latched to true once a "data" instruction is decoded
36 
37 __attribute__ ((format (printf, 1, 2) ))
bprintf(const char * format,...)38 static void bprintf(const char* format, ...) {
39     va_list args;
40     va_start(args, format);
41     int ret = vsnprintf(buf_ptr, buf_remain, format, args);
42     va_end(args);
43     if (ret < 0) return;
44     if (ret >= buf_remain) ret = buf_remain;
45     buf_ptr += ret;
46     buf_remain -= ret;
47 }
48 
// Prints an instruction mnemonic left-justified in a 12-column field so that
// the operands of consecutive instructions line up.
static void print_opcode(const char* opcode) {
    bprintf("%-12s", opcode);
}
52 
53 // Mapping from opcode number to opcode name.
54 static const char* opcode_names [] = {
55     [PASSDROP_OPCODE] = NULL,
56     [LDB_OPCODE] = "ldb",
57     [LDH_OPCODE] = "ldh",
58     [LDW_OPCODE] = "ldw",
59     [LDBX_OPCODE] = "ldbx",
60     [LDHX_OPCODE] = "ldhx",
61     [LDWX_OPCODE] = "ldwx",
62     [ADD_OPCODE] = "add",
63     [MUL_OPCODE] = "mul",
64     [DIV_OPCODE] = "div",
65     [AND_OPCODE] = "and",
66     [OR_OPCODE] = "or",
67     [SH_OPCODE] = "sh",
68     [LI_OPCODE] = "li",
69     [JMP_OPCODE] = "jmp",
70     [JEQ_OPCODE] = "jeq",
71     [JNE_OPCODE] = "jne",
72     [JGT_OPCODE] = "jgt",
73     [JLT_OPCODE] = "jlt",
74     [JSET_OPCODE] = "jset",
75     [JBSMATCH_OPCODE] = NULL,
76     [EXT_OPCODE] = NULL,
77     [LDDW_OPCODE] = "lddw",
78     [STDW_OPCODE] = "stdw",
79     [WRITE_OPCODE] = "write",
80     [PKTDATACOPY_OPCODE] = NULL,
81     [JNSET_OPCODE] = "jnset",
82     [JBSPTRMATCH_OPCODE] = NULL,
83     [ALLOC_XMIT_OPCODE] = NULL,
84 };
85 
// Prints a jump target.
//   target == program_len      -> "PASS" (implicit pass exit)
//   target == program_len + 1  -> "DROP" (implicit drop exit)
//   target >  program_len + 1  -> counter-and-exit: the low bit of
//                                 (target - program_len) selects drop (1) or
//                                 pass (0), the remaining bits are printed as cnt
//   otherwise                  -> the target pc in decimal
static void print_jump_target(uint32_t target, uint32_t program_len) {
    if (target == program_len) {
        bprintf("PASS");
    } else if (target == program_len + 1) {
        bprintf("DROP");
    } else if (target > program_len + 1) {
        const uint32_t ofs = target - program_len;
        // cnt is unsigned, so print it with %u (the original used %d).
        bprintf("%s[cnt=%u]", (ofs & 1) ? "cnt_and_drop" : "cnt_and_pass", ofs >> 1);
    } else {
        bprintf("%u", target);
    }
}
100 
// Prints a DNS query type followed by ", ". The well-known types (A, AAAA, PTR,
// SRV, TXT) are printed by name; any other value is printed in decimal.
static void print_qtype(int qtype) {
    static const struct {
        int code;
        const char* text;
    } named_qtypes[] = {
        { 1, "A, " },
        { 28, "AAAA, " },
        { 12, "PTR, " },
        { 33, "SRV, " },
        { 16, "TXT, " },
    };
    const size_t count = sizeof(named_qtypes) / sizeof(named_qtypes[0]);
    for (size_t k = 0; k < count; ++k) {
        if (named_qtypes[k].code == qtype) {
            bprintf("%s", named_qtypes[k].text);
            return;
        }
    }
    bprintf("%d, ", qtype);
}
122 
apf_disassemble(const uint8_t * program,uint32_t program_len,uint32_t * const ptr2pc,bool is_v6)123 disas_ret apf_disassemble(const uint8_t* program, uint32_t program_len, uint32_t* const ptr2pc, bool is_v6) {
124     buf_ptr = print_buf;
125     buf_remain = sizeof(print_buf);
126     if (*ptr2pc > program_len + 1) {
127         snprintf(prefix_buf, sizeof(prefix_buf), "(%4u) ", 0);
128         bprintf("pc is overflow: pc %d, program_len: %d", *ptr2pc, program_len);
129         disas_ret ret = {
130             .prefix = prefix_buf,
131             .content = print_buf
132         };
133         return ret;
134     }
135     uint32_t prev_pc = *ptr2pc;
136 
137     bprintf("%4u: ", *ptr2pc);
138 
139     if (*ptr2pc == program_len) {
140         snprintf(prefix_buf, sizeof(prefix_buf), "(%4u) ", 0);
141         bprintf("PASS");
142         ++(*ptr2pc);
143         disas_ret ret = {
144             .prefix = prefix_buf,
145             .content = print_buf
146         };
147         return ret;
148     }
149 
150     if (*ptr2pc == program_len + 1) {
151         snprintf(prefix_buf, sizeof(prefix_buf), "(%4u) ", 0);
152         bprintf("DROP");
153         ++(*ptr2pc);
154         disas_ret ret = {
155             .prefix = prefix_buf,
156             .content = print_buf
157         };
158         return ret;
159     }
160 
161     const uint8_t bytecode = program[(*ptr2pc)++];
162     const uint32_t opcode = EXTRACT_OPCODE(bytecode);
163 
164 #define PRINT_OPCODE() print_opcode(opcode_names[opcode])
165 #define DECODE_IMM(length)  ({                                        \
166     uint32_t value = 0;                                               \
167     for (uint32_t i = 0; i < (length) && *ptr2pc < program_len; i++)  \
168         value = (value << 8) | program[(*ptr2pc)++];                  \
169     value;})
170 
171     const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
172     // All instructions have immediate fields, so load them now.
173     const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
174     uint32_t imm = 0;
175     int32_t signed_imm = 0;
176     if (len_field != 0) {
177         const uint32_t imm_len = 1 << (len_field - 1);
178         imm = DECODE_IMM(imm_len);
179         // Sign extend imm into signed_imm.
180         signed_imm = imm << ((4 - imm_len) * 8);
181         signed_imm >>= (4 - imm_len) * 8;
182     }
183     switch (opcode) {
184         case PASSDROP_OPCODE:
185             if (reg_num == 0) {
186                 print_opcode("pass");
187             } else {
188                 print_opcode("drop");
189             }
190             if (imm > 0) {
191                 bprintf("counter=%d", imm);
192             }
193             break;
194         case LDB_OPCODE:
195         case LDH_OPCODE:
196         case LDW_OPCODE:
197             PRINT_OPCODE();
198             bprintf("r%d, [%u]", reg_num, imm);
199             break;
200         case LDBX_OPCODE:
201         case LDHX_OPCODE:
202         case LDWX_OPCODE:
203             PRINT_OPCODE();
204             if (imm) {
205                 bprintf("r%d, [r1+%u]", reg_num, imm);
206             } else {
207                 bprintf("r%d, [r1]", reg_num);
208             }
209             break;
210         case JMP_OPCODE:
211             if (reg_num == 0) {
212                 PRINT_OPCODE();
213                 print_jump_target(*ptr2pc + imm, program_len);
214             } else {
215                 v6_mode = true;
216                 print_opcode("data");
217                 bprintf("%d, ", imm);
218                 uint32_t len = imm;
219                 while (len--) bprintf("%02x", program[(*ptr2pc)++]);
220             }
221             break;
222         case JEQ_OPCODE:
223         case JNE_OPCODE:
224         case JGT_OPCODE:
225         case JLT_OPCODE:
226         case JSET_OPCODE:
227         case JNSET_OPCODE: {
228             PRINT_OPCODE();
229             bprintf("r0, ");
230             // Load second immediate field.
231             if (reg_num == 1) {
232                 bprintf("r1, ");
233             } else if (len_field == 0) {
234                 bprintf("0, ");
235             } else {
236                 uint32_t cmp_imm = DECODE_IMM(1 << (len_field - 1));
237                 bprintf("0x%x, ", cmp_imm);
238             }
239             print_jump_target(*ptr2pc + imm, program_len);
240             break;
241         }
242         case JBSMATCH_OPCODE: {
243             if (reg_num == 0) {
244                 print_opcode("jbsne");
245             } else {
246                 print_opcode("jbseq");
247             }
248             bprintf("r0, ");
249             const uint32_t cmp_imm = DECODE_IMM(1 << (len_field - 1));
250             const uint32_t cnt = (cmp_imm >> 11) + 1; // 1+, up to 32 fits in u16
251             const uint32_t len = cmp_imm & 2047; // 0..2047
252             bprintf("(%u), ", len);
253             print_jump_target(*ptr2pc + imm + cnt * len, program_len);
254             bprintf(", ");
255             if (cnt > 1) {
256                 bprintf("{ ");
257             }
258             for (uint32_t i = 0; i < cnt; ++i) {
259                 for (uint32_t j = 0; j < len; ++j) {
260                     uint8_t byte = program[(*ptr2pc)++];
261                     bprintf("%02x", byte);
262                 }
263                 if (i != cnt - 1) {
264                     bprintf(", ");
265                 }
266             }
267             if (cnt > 1) {
268                 bprintf(" }[%d]", cnt);
269             }
270             break;
271         }
272         case SH_OPCODE:
273             PRINT_OPCODE();
274             if (reg_num) {
275                 bprintf("r0, r1");
276             } else {
277                 bprintf("r0, %d", signed_imm);
278             }
279             break;
280         case ADD_OPCODE:
281         case AND_OPCODE: {
282             PRINT_OPCODE();
283             if (is_v6) {
284                 bprintf("r%d, ", reg_num);
285                 if (!imm) {
286                     bprintf("r%d", 1 - reg_num);
287                 } else if (opcode == AND_OPCODE) {
288                     bprintf("0x%x", signed_imm);
289                 } else {
290                     bprintf("%d", signed_imm);
291                 }
292             } else {
293                 if (reg_num) {
294                     bprintf("r0, r1");
295                 } else if (opcode == AND_OPCODE) {
296                     bprintf("r0, 0x%x", imm);
297                 } else {
298                     bprintf("r0, %u", imm);
299                 }
300             }
301             break;
302         }
303         case MUL_OPCODE:
304         case DIV_OPCODE:
305         case OR_OPCODE:
306             PRINT_OPCODE();
307             if (reg_num) {
308                 bprintf("r0, r1");
309             } else if (!imm && opcode == DIV_OPCODE) {
310                 bprintf("pass (div 0)");
311             } else if (opcode == OR_OPCODE) {
312                 bprintf("r0, 0x%x", imm);
313             } else {
314                 bprintf("r0, %u", imm);
315             }
316             break;
317         case LI_OPCODE:
318             PRINT_OPCODE();
319             bprintf("r%d, %d", reg_num, signed_imm);
320             break;
321         case EXT_OPCODE:
322             if (
323 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
324 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
325 #if LDM_EXT_OPCODE == 0
326                 ENFORCE_UNSIGNED(imm) &&
327 #else
328                 imm >= LDM_EXT_OPCODE &&
329 #endif
330                 imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
331                 print_opcode("ldm");
332                 bprintf("r%d, m[%u]", reg_num, imm - LDM_EXT_OPCODE);
333             } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
334                 print_opcode("stm");
335                 bprintf("r%d, m[%u]", reg_num, imm - STM_EXT_OPCODE);
336             } else switch (imm) {
337                 case NOT_EXT_OPCODE:
338                     print_opcode("not");
339                     bprintf("r%d", reg_num);
340                     break;
341                 case NEG_EXT_OPCODE:
342                     print_opcode("neg");
343                     bprintf("r%d", reg_num);
344                     break;
345                 case SWAP_EXT_OPCODE:
346                     print_opcode("swap");
347                     break;
348                 case MOV_EXT_OPCODE:
349                     print_opcode("mov");
350                     bprintf("r%d, r%d", reg_num, reg_num ^ 1);
351                     break;
352                 case ALLOCATE_EXT_OPCODE:
353                     print_opcode("allocate");
354                     if (reg_num == 0) {
355                         bprintf("r%d", reg_num);
356                     } else {
357                         uint32_t alloc_len = DECODE_IMM(2);
358                         bprintf("%d", alloc_len);
359                     }
360                     break;
361                 case TRANSMIT_EXT_OPCODE:
362                     print_opcode(reg_num ? "transmitudp" : "transmit");
363                     u8 ip_ofs = DECODE_IMM(1);
364                     u8 csum_ofs = DECODE_IMM(1);
365                     if (csum_ofs < 255) {
366                         u8 csum_start = DECODE_IMM(1);
367                         u16 partial_csum = DECODE_IMM(2);
368                         bprintf("ip_ofs=%d, csum_ofs=%d, csum_start=%d, partial_csum=0x%04x",
369                                 ip_ofs, csum_ofs, csum_start, partial_csum);
370                     } else {
371                         bprintf("ip_ofs=%d", ip_ofs);
372                     }
373                     break;
374                 case EWRITE1_EXT_OPCODE: print_opcode("ewrite1"); bprintf("r%d", reg_num); break;
375                 case EWRITE2_EXT_OPCODE: print_opcode("ewrite2"); bprintf("r%d", reg_num); break;
376                 case EWRITE4_EXT_OPCODE: print_opcode("ewrite4"); bprintf("r%d", reg_num); break;
377                 case EPKTDATACOPYIMM_EXT_OPCODE:
378                 case EPKTDATACOPYR1_EXT_OPCODE: {
379                     if (reg_num == 0) {
380                         print_opcode("epktcopy");
381                     } else {
382                         print_opcode("edatacopy");
383                     }
384                     if (imm == EPKTDATACOPYIMM_EXT_OPCODE) {
385                         uint32_t len = DECODE_IMM(1);
386                         if (!len) len = 256 + DECODE_IMM(1);
387                         bprintf("src=r0, len=%d", len);
388                     } else {
389                         bprintf("src=r0, len=r1");
390                     }
391 
392                     break;
393                 }
394                 case JDNSAMATCH_EXT_OPCODE:
395                 case JDNSQMATCH_EXT_OPCODE:
396                 case JDNSQMATCH1_EXT_OPCODE:
397                 case JDNSQMATCH2_EXT_OPCODE:
398                 case JDNSAMATCHSAFE_EXT_OPCODE:
399                 case JDNSQMATCHSAFE_EXT_OPCODE:
400                 case JDNSQMATCHSAFE1_EXT_OPCODE:
401                 case JDNSQMATCHSAFE2_EXT_OPCODE: {
402                     uint32_t offs = DECODE_IMM(1 << (len_field - 1));
403                     int qtype1 = -1;
404                     int qtype2 = -1;
405                     switch (imm) {
406                         case JDNSQMATCH_EXT_OPCODE:
407                             print_opcode(reg_num ? "jdnsqeq" : "jdnsqne");
408                             qtype1 = DECODE_IMM(1);
409                             break;
410                         case JDNSQMATCHSAFE_EXT_OPCODE:
411                             print_opcode(reg_num ? "jdnsqeqsafe" : "jdnsqnesafe");
412                             qtype1 = DECODE_IMM(1);
413                             break;
414                         case JDNSAMATCH_EXT_OPCODE:
415                             print_opcode(reg_num ? "jdnsaeq" : "jdnsane"); break;
416                         case JDNSAMATCHSAFE_EXT_OPCODE:
417                             print_opcode(reg_num ? "jdnsaeqsafe" : "jdnsanesafe"); break;
418                         case JDNSQMATCH2_EXT_OPCODE:
419                             qtype1 = DECODE_IMM(1);
420                             qtype2 = DECODE_IMM(1);
421                             print_opcode(reg_num ? "jdnsqeq2" : "jdnsqne2");
422                             break;
423                         case JDNSQMATCHSAFE2_EXT_OPCODE:
424                             qtype1 = DECODE_IMM(1);
425                             qtype2 = DECODE_IMM(1);
426                             print_opcode(reg_num ? "jdnsqeqsafe2" : "jdnsqnesafe2");
427                             break;
428                         case JDNSQMATCH1_EXT_OPCODE:
429                             qtype1 = DECODE_IMM(2);
430                             print_opcode(reg_num ? "jdnsqeq1" : "jdnsqne1");
431                             break;
432                         case JDNSQMATCHSAFE1_EXT_OPCODE:
433                             qtype1 = DECODE_IMM(2);
434                             print_opcode(reg_num ? "jdnsqeqsafe1" : "jdnsqnesafe1");
435                             break;
436                         default:
437                             bprintf("unknown_ext %u", imm); break;
438                     }
439                     bprintf("r0, ");
440                     uint32_t end = *ptr2pc;
441                     while (end + 1 < program_len && !(program[end] == 0 && program[end + 1] == 0)) {
442                         end++;
443                     }
444                     end += 2;
445                     print_jump_target(end + offs, program_len);
446                     bprintf(", ");
447                     if (imm == JDNSQMATCH_EXT_OPCODE || imm == JDNSQMATCHSAFE_EXT_OPCODE ||
448                         imm == JDNSQMATCH1_EXT_OPCODE || imm == JDNSQMATCHSAFE1_EXT_OPCODE) {
449                         print_qtype(qtype1);
450                     } else if (imm == JDNSQMATCH2_EXT_OPCODE || imm == JDNSQMATCHSAFE2_EXT_OPCODE) {
451                         print_qtype(qtype1);
452                         print_qtype(qtype2);
453                     }
454                     while (*ptr2pc < end) {
455                         uint8_t byte = program[(*ptr2pc)++];
456                         // value == 0xff is a wildcard that consumes the whole label.
457                         // values < 0x40 could be lengths, but - and 0..9 are in practice usually
458                         // too long to be lengths so print them as characters. All other chars < 0x40
459                         // are not valid in dns character.
460                         if (byte == 0xff) {
461                             bprintf("(*)");
462                         } else if (byte == '-' || (byte >= '0' && byte <= '9') || byte >= 0x40) {
463                             bprintf("%c", byte);
464                         } else {
465                             bprintf("(%d)", byte);
466                         }
467                     }
468                     break;
469                 }
470                 case JONEOF_EXT_OPCODE: {
471                     const uint32_t imm_len = 1 << (len_field - 1);
472                     uint32_t jump_offs = DECODE_IMM(imm_len);
473                     uint8_t imm3 = DECODE_IMM(1);
474                     bool jmp = imm3 & 1;
475                     uint8_t len = ((imm3 >> 1) & 3) + 1;
476                     uint8_t cnt = (imm3 >> 3) + 2;
477                     if (jmp) {
478                         print_opcode("jnoneof");
479                     } else {
480                         print_opcode("joneof");
481                     }
482                     bprintf("r%d, ", reg_num);
483                     print_jump_target(*ptr2pc + jump_offs + cnt * len, program_len);
484                     bprintf(", { ");
485                     while (cnt--) {
486                         uint32_t v = DECODE_IMM(len);
487                         if (cnt) {
488                             bprintf("%d, ", v);
489                         } else {
490                             bprintf("%d ", v);
491                         }
492                     }
493                     bprintf("}");
494                     break;
495                 }
496                 case EXCEPTIONBUFFER_EXT_OPCODE: {
497                     uint32_t buf_size = DECODE_IMM(2);
498                     print_opcode("debugbuf");
499                     bprintf("size=%d", buf_size);
500                     break;
501                 }
502                 default:
503                     bprintf("unknown_ext %u", imm);
504                     break;
505             }
506             break;
507         case LDDW_OPCODE:
508         case STDW_OPCODE:
509             PRINT_OPCODE();
510             if (v6_mode) {
511                 if (opcode == LDDW_OPCODE) {
512                     bprintf("r%u, counter=%d", reg_num, imm);
513                 } else {
514                     bprintf("counter=%d, r%u", imm, reg_num);
515                 }
516             } else {
517                 if (signed_imm > 0) {
518                     bprintf("r%u, [r%u+%d]", reg_num, reg_num ^ 1, signed_imm);
519                 } else if (signed_imm < 0) {
520                     bprintf("r%u, [r%u-%d]", reg_num, reg_num ^ 1, -signed_imm);
521                 } else {
522                     bprintf("r%u, [r%u]", reg_num, reg_num ^ 1);
523                 }
524             }
525             break;
526         case WRITE_OPCODE: {
527             PRINT_OPCODE();
528             uint32_t write_len = 1 << (len_field - 1);
529             if (write_len > 0) {
530                 bprintf("0x");
531             }
532             for (uint32_t i = 0; i < write_len; ++i) {
533                 uint8_t byte =
534                     (uint8_t) ((imm >> (write_len - 1 - i) * 8) & 0xff);
535                 bprintf("%02x", byte);
536 
537             }
538             break;
539         }
540         case PKTDATACOPY_OPCODE: {
541             uint32_t src_offs = imm;
542             uint32_t copy_len = DECODE_IMM(1);
543             if (!copy_len) copy_len = 256 + DECODE_IMM(1);
544             if (reg_num == 0) {
545                 print_opcode("pktcopy");
546                 bprintf("src=%d, len=%d", src_offs, copy_len);
547             } else {
548                 print_opcode("datacopy");
549                 bprintf("src=%d, (%d)", src_offs, copy_len);
550                 for (uint32_t i = 0; i < copy_len; ++i) {
551                     uint8_t byte = program[src_offs + i];
552                     bprintf("%02x", byte);
553                 }
554             }
555             break;
556         }
557         // JNSET_OPCODE handled up above
558         case JBSPTRMATCH_OPCODE: {
559             print_opcode(reg_num ? "jbsptreq" : "jbsptrne");
560             bprintf("pktofs=%d, ", DECODE_IMM(1));
561             const uint8_t cmp_imm = DECODE_IMM(1);
562             const uint8_t cnt = (cmp_imm >> 4) + 1; // 1..16
563             const uint8_t len = (cmp_imm & 15) + 1; // 1..16
564             bprintf("(%u), ", len);
565             print_jump_target(*ptr2pc + imm + cnt, program_len);
566             bprintf(", ");
567             if (cnt > 1) bprintf("{ ");
568             for (int i = 0; i < cnt; ++i) {
569                 uint8_t ofs = program[(*ptr2pc)++];
570                 bprintf("@%d[", ofs * 2);
571                 for (int j = 0; j < len; ++j) bprintf("%02x", program[3 + 2 * ofs + j]);
572                 bprintf("]");
573                 if (i != cnt - 1) bprintf(", ");
574             }
575             if (cnt > 1) bprintf(" }[%d]", cnt);
576             break;
577         }
578         case ALLOC_XMIT_OPCODE:
579             if (reg_num) {
580                 print_opcode("allocate");
581                 bprintf("(%d)", 266 + 8 * imm);
582             } else {
583                 if (len_field) {
584                     static const char * const protocol[4] = { "udp", "tcp", "icmp", "alert/icmp" };
585                     print_opcode(imm & 3 ? "transmit" : "transmitudp");
586                     bprintf("offload=%s/%s, partial_csum=0x%x", imm & 4 ? "ipv6" : "ipv4",
587                             protocol[imm & 3], imm >> 3);
588                 } else {
589                     print_opcode("transmit");
590                 }
591             }
592             break;
593         // Unknown opcode
594         default:
595             bprintf("unknown %u", opcode);
596             break;
597     }
598     snprintf(prefix_buf, sizeof(prefix_buf), "(%4u) ", (*ptr2pc - prev_pc));
599     disas_ret ret = {
600         .prefix = prefix_buf,
601         .content = print_buf
602     };
603     return ret;
604 }
605