• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.apf;
18 
19 import static android.net.apf.BaseApfGenerator.Rbit.Rbit0;
20 import static android.net.apf.BaseApfGenerator.Rbit.Rbit1;
21 import static android.net.apf.BaseApfGenerator.Register.R0;
22 
23 import android.annotation.NonNull;
24 import android.util.SparseArray;
25 
26 import com.android.net.module.util.HexDump;
27 
28 import java.util.ArrayList;
29 import java.util.Arrays;
30 import java.util.List;
31 import java.util.Objects;
32 
33 /**
34  * The base class for APF assembler/generator.
35  *
36  * @hide
37  */
38 public abstract class BaseApfGenerator {
39 
/**
 * Creates a generator and records its configuration.
 *
 * @param version the APF version to generate for; stored in {@code mVersion}
 * @param ramSize stored in {@code mRamSize} (NOTE(review): presumably the interpreter RAM size
 *     in bytes — confirm against the field's consumers)
 * @param clampSize stored in {@code mClampSize}
 * @param disableCounterRangeCheck stored in {@code mDisableCounterRangeCheck}
 */
public BaseApfGenerator(int version, int ramSize, int clampSize,
        boolean disableCounterRangeCheck) {
    this.mDisableCounterRangeCheck = disableCounterRangeCheck;
    this.mClampSize = clampSize;
    this.mRamSize = ramSize;
    this.mVersion = version;
}
47 
48     /**
49      * This exception is thrown when an attempt is made to generate an illegal instruction.
50      */
51     public static class IllegalInstructionException extends Exception {
IllegalInstructionException(String msg)52         IllegalInstructionException(String msg) {
53             super(msg);
54         }
55     }
/**
 * Primary APF opcodes. Each constant's {@code value} is shifted left by 3 into the first byte
 * of an encoded instruction. {@link #LABEL} is a pseudo-opcode (value -1) that marks a
 * position in the program and emits no bytes.
 */
enum Opcodes {
    // Pseudo-opcode marking a jump target; occupies no bytes in the generated program.
    LABEL(-1),
    // Unconditionally pass (if R=0) or drop (if R=1) the packet.
    // An optional unsigned immediate can be provided to encode a counter number.
    // If the value is non-zero, the instruction increments that counter.
    // The counter is located (-4 * counter number) bytes from the end of the data region.
    // It is a U32 native-endian value and is always incremented by 1.
    // This is roughly equivalent to: lddw R0, -N4; add R0,1; stdw R0, -N4; {pass,drop}
    // e.g. "pass", "pass 1", "drop", "drop 1"
    PASSDROP(0),
    // Load 1 byte from an immediate offset, e.g. "ldb R0, [5]"
    LDB(1),
    // Load 2 bytes from an immediate offset, e.g. "ldh R0, [5]"
    LDH(2),
    // Load 4 bytes from an immediate offset, e.g. "ldw R0, [5]"
    LDW(3),
    // Load 1 byte from immediate offset plus register, e.g. "ldbx R0, [5]R0"
    LDBX(4),
    // Load 2 bytes from immediate offset plus register, e.g. "ldhx R0, [5]R0"
    LDHX(5),
    // Load 4 bytes from immediate offset plus register, e.g. "ldwx R0, [5]R0"
    LDWX(6),
    // Add, e.g. "add R0,5"
    ADD(7),
    // Multiply, e.g. "mul R0,5"
    MUL(8),
    // Divide, e.g. "div R0,5"
    DIV(9),
    // And, e.g. "and R0,5"
    AND(10),
    // Or, e.g. "or R0,5"
    OR(11),
    // Left shift, e.g. "sh R0, 5" or "sh R0, -5" (a negative amount shifts right)
    SH(12),
    // Load immediate, e.g. "li R0,5" (immediate encoded as a signed value)
    LI(13),
    // Jump, e.g. "jmp label"
    // In APFv6, JMP with R=1 encodes the DATA instruction, which is executed as a jump.
    // It encodes how many bytes of the program region are used to store the data,
    // followed by the actual data bytes.
    // e.g. "data 5, abcde"
    JMP(14),
    // Compare equal and branch, e.g. "jeq R0,5,label"
    JEQ(15),
    // Compare not equal and branch, e.g. "jne R0,5,label"
    JNE(16),
    // Compare greater than and branch, e.g. "jgt R0,5,label"
    JGT(17),
    // Compare less than and branch, e.g. "jlt R0,5,label"
    JLT(18),
    // Compare any bits set and branch, e.g. "jset R0,5,label"
    JSET(19),
    // Compare not equal byte sequence, e.g. "jnebs R0,5,label,0x1122334455"
    // NOTE: Only APFv6+ implements the R=1 'jbseq' version and multi-match.
    // imm1 is the jump target, imm2 is (cnt - 1) * 2048 + compare_len,
    // which is followed by cnt * compare_len bytes to compare against.
    // Warning: do not specify the same byte sequence multiple times.
    JBSMATCH(20),
    // Followed by an immediate holding an ExtendedOpcodes value.
    EXT(21),
    // Load 4 bytes from data memory address (register + immediate): "lddw R0, [5]R1"
    LDDW(22),
    // Store 4 bytes to data memory address (register + immediate): "stdw R0, [5]R1"
    STDW(23),
    // Write a 1, 2 or 4 byte immediate to the output buffer and auto-increment the write
    // pointer, e.g. "write 5"
    WRITE(24),
    // Copy bytes from the input packet or the APF program/data region to the output buffer
    // and auto-increment the output buffer pointer.
    // The register bit selects the copy source:
    //   R=0 copies from the packet,
    //   R=1 copies from the APF program/data region.
    // The copy length is stored in (u8)imm2.
    // e.g. "pktcopy 5, 5" "datacopy 5, 5"
    PKTDATACOPY(25),
    // JSET with the reverse condition (jump if no bits set).
    JNSET(26),
    // APFv6.1: Compare byte sequence [R=0 not] equal, e.g. "jbsptrne 22,16,label,<dataptr>"
    // imm1 is the jump target
    // imm2(u8) is the offset [0..255] into the packet
    // imm3(u8) is (count - 1) * 16 + (compare_len - 1), so both count and compare_len are in
    // [1..16]. It is followed by one u8 'even offset' ptr per compared sequence (count ptrs)
    // into a data section of at most 526 bytes - i.e. each ptr is multiplied by 2 and has 3
    // added to it (to skip over 'datajmp u16').
    // Warning: do not specify the same byte sequence multiple times.
    JBSPTRMATCH(27),
    // APFv6.1: Bytecode-optimized allocate | transmit instruction.
    // R=1 -> allocate(266 + imm * 8)
    // R=0 -> transmit
    //   immlen=0 -> no checksum offload (transmit ip_ofs=255)
    //   immlen>0 -> with checksum offload (transmit(udp) ip_ofs=14 ...)
    //     imm & 7 | type of offload      | ip_ofs | udp | csum_start  | csum_ofs      | partial_csum |
    //         0   | ip4/udp              |   14   |  X  | 14+20-8 =26 | 14+20   +6=40 |   imm >> 3   |
    //         1   | ip4/tcp              |   14   |     | 14+20-8 =26 | 14+20  +10=44 |     --"--    |
    //         2   | ip4/icmp             |   14   |     | 14+20   =34 | 14+20   +2=36 |     --"--    |
    //         3   | ip4/routeralert/icmp |   14   |     | 14+20+4 =38 | 14+20+4 +2=40 |     --"--    |
    //         4   | ip6/udp              |   14   |  X  | 14+40-32=22 | 14+40   +6=60 |     --"--    |
    //         5   | ip6/tcp              |   14   |     | 14+40-32=22 | 14+40  +10=64 |     --"--    |
    //         6   | ip6/icmp             |   14   |     | 14+40-32=22 | 14+40   +2=56 |     --"--    |
    //         7   | ip6/routeralert/icmp |   14   |     | 14+40-32=22 | 14+40+8 +2=64 |     --"--    |
    // NOTE(review): the TCP checksum field sits at byte 16 of the TCP header (RFC 9293), but
    // the ip4/tcp and ip6/tcp rows above add 10; confirm these csum_ofs values against the
    // interpreter.
    ALLOC_XMIT(28);

    /** Numeric opcode placed in the high bits of the instruction's first byte. */
    final int value;

    Opcodes(int encoding) {
        value = encoding;
    }
}
/**
 * Extended opcodes. The primary opcode is {@code Opcodes.EXT}; the ExtendedOpcodes value is
 * encoded in the immediate field that follows it.
 */
enum ExtendedOpcodes {
    // Load from memory, e.g. "ldm R0,5" (the memory slot number is added to this value).
    LDM(0),
    // Store to memory, e.g. "stm R0,5" (the memory slot number is added to this value).
    STM(16),
    // Not, e.g. "not R0"
    NOT(32),
    // Negate, e.g. "neg R0"
    NEG(33),
    // Swap, e.g. "swap R0,R1"
    SWAP(34),
    // Move, e.g. "move R0,R1"
    MOVE(35),
    // Allocate a writable output buffer.
    // R=0 uses register R0 to hold the length; R=1 encodes the length in the u16 imm2.
    // e.g. "allocate R0"
    // e.g. "allocate 123"
    ALLOCATE(36),
    // Transmit and deallocate the buffer (transmission can be delayed until the program
    // terminates). The buffer length is the output buffer pointer (0 means discard).
    // R=1 iff udp-style L4 checksum
    // u8 imm2 - ip header offset from start of buffer (255 for non-ip packets)
    // u8 imm3 - offset from start of buffer to store L4 checksum (255 for no L4 checksum)
    // u8 imm4 - offset from start of buffer to begin L4 checksum calc (present iff imm3 != 255)
    // u16 imm5 - partial checksum value to include in L4 checksum (present iff imm3 != 255)
    // e.g. "transmit"
    TRANSMIT(37),
    // Write a 1, 2 or 4 byte value from a register to the output buffer and auto-increment
    // the output buffer pointer.
    // e.g. "ewrite1 r0"
    EWRITE1(38),
    EWRITE2(39),
    EWRITE4(40),
    // Copy bytes from the input packet or the APF program/data region to the output buffer
    // and auto-increment the output buffer pointer.
    // The register bit selects the copy source:
    //   R=0 copies from the packet,
    //   R=1 copies from the APF program/data region.
    // The source offset is held in R0; the copy length is in u8 imm2 or in R1.
    // e.g. "epktcopy r0, 16", "edatacopy r0, 16", "epktcopy r0, r1", "edatacopy r0, r1"
    EPKTDATACOPYIMM(41),
    EPKTDATACOPYR1(42),
    // Jumps if the UDP payload content (starting at R0) does [not] match one
    // of the specified QNAMEs in question records, applying case insensitivity.
    // The SAFE version PASSES corrupt packets, while the other one DROPS them.
    // R=0/1 meaning 'does not match'/'matches'
    // R0: Offset to UDP payload content
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(u8): Question type (PTR/SRV/TXT/A/AAAA)
    // imm4(bytes): null terminated list of null terminated LV-encoded QNAMEs
    // e.g.: "jdnsqeq R0,label,0xc,\002aa\005local\0\0",
    //       "jdnsqne R0,label,0xc,\002aa\005local\0\0"
    JDNSQMATCH(43),
    JDNSQMATCHSAFE(45),
    // Jumps if the UDP payload content (starting at R0) does [not] match one
    // of the specified NAMEs in answers/authority/additional records, applying
    // case insensitivity.
    // The SAFE version PASSES corrupt packets, while the other one DROPS them.
    // R=0/1 meaning 'does not match'/'matches'
    // R0: Offset to UDP payload content
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(bytes): null terminated list of null terminated LV-encoded NAMEs
    // (there is no question type operand)
    // e.g.: "jdnsaeq R0,label,\002aa\005local\0\0",
    //       "jdnsane R0,label,\002aa\005local\0\0"
    JDNSAMATCH(44),
    JDNSAMATCHSAFE(46),
    // Jump if a register is [not] one of a list of values
    // R bit - selects the register (R0/R1) to test
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(u8): top 5 bits - number of following values - 1
    //        middle 2 bits - byte length (1..4) of each value - 1
    //        bottom 1 bit  - =0 jmp if in set, =1 if not in set
    // imm4(count * length bytes): the *UNIQUE* values to compare against
    JONEOF(47),
    // Specify the length of the exception buffer, which is populated on abnormal program
    // termination.
    // imm1: Extended opcode
    // imm2(u16): Length of exception buffer (located *immediately* after the program itself)
    EXCEPTIONBUFFER(48),
    // Jumps if the UDP payload content (starting at R0) does [not] match one
    // of the specified QNAMEs in question records, applying case insensitivity.
    // The qtype in the input packet may match either of the two supplied qtypes.
    // The SAFE version PASSES corrupt packets, while the other one DROPS them.
    // R=0/1 meaning 'does not match'/'matches'
    // R0: Offset to UDP payload content
    // imm1: Extended opcode
    // imm2: Jump label offset
    // imm3(u8): Question type1 (PTR/SRV/TXT/A/AAAA)
    // imm4(u8): Question type2 (PTR/SRV/TXT/A/AAAA)
    // imm5(bytes): null terminated list of null terminated LV-encoded QNAMEs
    // e.g.: "jdnsqeq2 R0,label,A,AAAA,\002aa\005local\0\0",
    //       "jdnsqne2 R0,label,A,AAAA,\002aa\005local\0\0"
    JDNSQMATCH2(51),
    JDNSQMATCHSAFE2(53);

    /** Numeric extended opcode written as the first immediate of an EXT instruction. */
    final int value;

    ExtendedOpcodes(int encoding) {
        value = encoding;
    }
}
/** The two APF registers. */
public enum Register {
    R0,
    R1;

    /** Returns the register that is not {@code this} one. */
    Register other() {
        if (this == R0) {
            return R1;
        }
        return R0;
    }
}
251 
/** The R bit, OR-ed into the lowest bit of an encoded instruction's first byte. */
public enum Rbit {
    Rbit0(0),
    Rbit1(1);

    /** The bit value, 0 or 1. */
    final int value;

    Rbit(int bit) {
        value = bit;
    }
}
262 
263     private enum IntImmediateType {
264         INDETERMINATE_SIZE_SIGNED,
265         INDETERMINATE_SIZE_UNSIGNED,
266         SIGNED_8,
267         UNSIGNED_8,
268         SIGNED_BE16,
269         UNSIGNED_BE16,
270         SIGNED_BE32,
271         UNSIGNED_BE32;
272     }
273 
274     private static class IntImmediate {
275         public final IntImmediateType mImmediateType;
276         public final int mValue;
277 
IntImmediate(int value, IntImmediateType type)278         IntImmediate(int value, IntImmediateType type) {
279             mImmediateType = type;
280             mValue = value;
281         }
282 
calculateIndeterminateSize()283         private int calculateIndeterminateSize() {
284             switch (mImmediateType) {
285                 case INDETERMINATE_SIZE_SIGNED:
286                     return calculateImmSize(mValue, true /* signed */);
287                 case INDETERMINATE_SIZE_UNSIGNED:
288                     return calculateImmSize(mValue, false /* signed */);
289                 default:
290                     // For IMM with determinate size, return 0 to allow Math.max() calculation in
291                     // caller function.
292                     return 0;
293             }
294         }
295 
getEncodingSize(int immFieldSize)296         private int getEncodingSize(int immFieldSize) {
297             switch (mImmediateType) {
298                 case SIGNED_8:
299                 case UNSIGNED_8:
300                     return 1;
301                 case SIGNED_BE16:
302                 case UNSIGNED_BE16:
303                     return 2;
304                 case SIGNED_BE32:
305                 case UNSIGNED_BE32:
306                     return 4;
307                 case INDETERMINATE_SIZE_SIGNED:
308                 case INDETERMINATE_SIZE_UNSIGNED: {
309                     int minSizeRequired = calculateIndeterminateSize();
310                     if (minSizeRequired > immFieldSize) {
311                         throw new IllegalStateException(
312                                 String.format("immFieldSize: %d is too small to encode value %d",
313                                         immFieldSize, mValue));
314                     }
315                     return immFieldSize;
316                 }
317             }
318             throw new IllegalStateException("UnhandledInvalid IntImmediateType: " + mImmediateType);
319         }
320 
writeValue(byte[] bytecode, Integer writingOffset, int immFieldSize)321         private int writeValue(byte[] bytecode, Integer writingOffset, int immFieldSize) {
322             return Instruction.writeValue(mValue, bytecode, writingOffset,
323                     getEncodingSize(immFieldSize));
324         }
325 
newSigned(int imm)326         public static IntImmediate newSigned(int imm) {
327             return new IntImmediate(imm, IntImmediateType.INDETERMINATE_SIZE_SIGNED);
328         }
329 
newUnsigned(long imm)330         public static IntImmediate newUnsigned(long imm) {
331             // upperBound is 2^32 - 1
332             checkRange("Unsigned IMM", imm, 0 /* lowerBound */,
333                     4294967295L /* upperBound */);
334             return new IntImmediate((int) imm, IntImmediateType.INDETERMINATE_SIZE_UNSIGNED);
335         }
336 
newTwosComplementUnsigned(long imm)337         public static IntImmediate newTwosComplementUnsigned(long imm) {
338             checkRange("Unsigned TwosComplement IMM", imm, Integer.MIN_VALUE,
339                     4294967295L /* upperBound */);
340             return new IntImmediate((int) imm, IntImmediateType.INDETERMINATE_SIZE_UNSIGNED);
341         }
342 
newTwosComplementSigned(long imm)343         public static IntImmediate newTwosComplementSigned(long imm) {
344             checkRange("Signed TwosComplement IMM", imm, Integer.MIN_VALUE,
345                     4294967295L /* upperBound */);
346             return new IntImmediate((int) imm, IntImmediateType.INDETERMINATE_SIZE_SIGNED);
347         }
348 
newS8(byte imm)349         public static IntImmediate newS8(byte imm) {
350             checkRange("S8 IMM", imm, Byte.MIN_VALUE, Byte.MAX_VALUE);
351             return new IntImmediate(imm, IntImmediateType.SIGNED_8);
352         }
353 
newU8(int imm)354         public static IntImmediate newU8(int imm) {
355             checkRange("U8 IMM", imm, 0, 255);
356             return new IntImmediate(imm, IntImmediateType.UNSIGNED_8);
357         }
358 
newS16(short imm)359         public static IntImmediate newS16(short imm) {
360             return new IntImmediate(imm, IntImmediateType.SIGNED_BE16);
361         }
362 
newU16(int imm)363         public static IntImmediate newU16(int imm) {
364             checkRange("U16 IMM", imm, 0, 65535);
365             return new IntImmediate(imm, IntImmediateType.UNSIGNED_BE16);
366         }
367 
newS32(int imm)368         public static IntImmediate newS32(int imm) {
369             return new IntImmediate(imm, IntImmediateType.SIGNED_BE32);
370         }
371 
newU32(long imm)372         public static IntImmediate newU32(long imm) {
373             // upperBound is 2^32 - 1
374             checkRange("U32 IMM", imm, 0 /* lowerBound */,
375                     4294967295L /* upperBound */);
376             return new IntImmediate((int) imm, IntImmediateType.UNSIGNED_BE32);
377         }
378 
379         @Override
toString()380         public String toString() {
381             return "IntImmediate{" + "mImmediateType=" + mImmediateType + ", mValue=" + mValue
382                     + '}';
383         }
384     }
385 
386     class Instruction {
387         public final Opcodes mOpcode;
388         private final Rbit mRbit;
389         public final List<IntImmediate> mIntImms = new ArrayList<>();
390         // When mOpcode is a jump:
391         private int mTargetLabelSize;
392         private int mImmSizeOverride = -1;
393         // mTargetLabel == -1 indicates it is uninitialized. mTargetLabel < -1 indicates a label
394         // within the program used for offset calculation. mTargetLabel >= 0 indicates a pass/drop
395         // label, its offset is mTargetLabel + program size.
396         private short mTargetLabel = -1;
397         public byte[] mBytesImm;
398         // Offset in bytes from the beginning of this program.
399         // Set by {@link BaseApfGenerator#generate}.
400         int offset;
401 
/** Creates an instruction with the given primary opcode and R bit and no immediates yet. */
Instruction(Opcodes opcode, Rbit rbit) {
    this.mRbit = rbit;
    this.mOpcode = opcode;
}
406 
/** Creates an instruction whose R bit selects {@code register} (R0 -> 0, otherwise 1). */
Instruction(Opcodes opcode, Register register) {
    this(opcode, register != R0 ? Rbit1 : Rbit0);
}
410 
Instruction(ExtendedOpcodes extendedOpcodes, Rbit rbit)411         Instruction(ExtendedOpcodes extendedOpcodes, Rbit rbit) {
412             this(Opcodes.EXT, rbit);
413             addUnsigned(extendedOpcodes.value);
414         }
415 
Instruction(ExtendedOpcodes extendedOpcodes, Register register)416         Instruction(ExtendedOpcodes extendedOpcodes, Register register) {
417             this(Opcodes.EXT, register);
418             addUnsigned(extendedOpcodes.value);
419         }
420 
Instruction(ExtendedOpcodes extendedOpcodes, int slot, Register register)421         Instruction(ExtendedOpcodes extendedOpcodes, int slot, Register register)
422                 throws IllegalInstructionException {
423             this(Opcodes.EXT, register);
424             if (slot < 0 || slot >= MEMORY_SLOTS) {
425                 throw new IllegalInstructionException("illegal memory slot number: " + slot);
426             }
427             addUnsigned(extendedOpcodes.value + slot);
428         }
429 
/** Creates an instruction with the given primary opcode operating on R0 (R bit 0). */
Instruction(Opcodes opcode) {
    this(opcode, Register.R0);
}
433 
/** Creates an EXT instruction for {@code extendedOpcodes} operating on R0. */
Instruction(ExtendedOpcodes extendedOpcodes) {
    this(extendedOpcodes, Register.R0);
}
437 
addSigned(int imm)438         Instruction addSigned(int imm) {
439             mIntImms.add(IntImmediate.newSigned(imm));
440             return this;
441         }
442 
addUnsigned(long imm)443         Instruction addUnsigned(long imm) {
444             mIntImms.add(IntImmediate.newUnsigned(imm));
445             return this;
446         }
447 
448         // in practice, 'int' always enough for packet offset
addPacketOffset(int imm)449         Instruction addPacketOffset(int imm) {
450             return addUnsigned(imm);
451         }
452 
453         // in practice, 'int' always enough for data offset
addDataOffset(int imm)454         Instruction addDataOffset(int imm) {
455             return addUnsigned(imm);
456         }
457 
addTwosCompSigned(long imm)458         Instruction addTwosCompSigned(long imm) {
459             mIntImms.add(IntImmediate.newTwosComplementSigned(imm));
460             return this;
461         }
462 
addTwosCompUnsigned(long imm)463         Instruction addTwosCompUnsigned(long imm) {
464             mIntImms.add(IntImmediate.newTwosComplementUnsigned(imm));
465             return this;
466         }
467 
addS8(byte imm)468         Instruction addS8(byte imm) {
469             mIntImms.add(IntImmediate.newS8(imm));
470             return this;
471         }
472 
addU8(int imm)473         Instruction addU8(int imm) {
474             mIntImms.add(IntImmediate.newU8(imm));
475             return this;
476         }
477 
addS16(short imm)478         Instruction addS16(short imm) {
479             mIntImms.add(IntImmediate.newS16(imm));
480             return this;
481         }
482 
addU16(int imm)483         Instruction addU16(int imm) {
484             mIntImms.add(IntImmediate.newU16(imm));
485             return this;
486         }
487 
addS32(int imm)488         Instruction addS32(int imm) {
489             mIntImms.add(IntImmediate.newS32(imm));
490             return this;
491         }
492 
addU32(long imm)493         Instruction addU32(long imm) {
494             mIntImms.add(IntImmediate.newU32(imm));
495             return this;
496         }
497 
setLabel(short label)498         Instruction setLabel(short label) throws IllegalInstructionException {
499             if (mLabels.get(label) != null) {
500                 throw new IllegalInstructionException("duplicate label " + label);
501             }
502             if (mOpcode != Opcodes.LABEL) {
503                 throw new IllegalStateException("adding label to non-label instruction");
504             }
505             mLabels.put(label, this);
506             return this;
507         }
508 
setTargetLabel(short label)509         Instruction setTargetLabel(short label) {
510             mTargetLabel = label;
511             mTargetLabelSize = 4; // May shrink later on in generate().
512             return this;
513         }
514 
overrideImmSize(int size)515         Instruction overrideImmSize(int size) {
516             mImmSizeOverride = size;
517             return this;
518         }
519 
setBytesImm(byte[] bytes)520         Instruction setBytesImm(byte[] bytes) {
521             mBytesImm = bytes;
522             return this;
523         }
524 
findMatchInDataBytes(@onNull byte[] content, int fromIndex, int toIndex)525         int findMatchInDataBytes(@NonNull byte[] content, int fromIndex, int toIndex)
526                 throws IllegalInstructionException {
527             if (fromIndex >= toIndex || fromIndex < 0 || toIndex > content.length) {
528                 throw new IllegalArgumentException(
529                         String.format("fromIndex: %d, toIndex: %d, content length: %d", fromIndex,
530                                 toIndex, content.length));
531             }
532             if (mOpcode != Opcodes.JMP || mBytesImm == null) {
533                 throw new IllegalInstructionException(String.format(
534                         "this method is only valid for jump data instruction, mOpcode "
535                                 + ":%s, mBytesImm: %s", Opcodes.JMP,
536                         mBytesImm == null ? "(empty)" : HexDump.toHexString(mBytesImm)));
537             }
538             if (mImmSizeOverride != 2) {
539                 throw new IllegalInstructionException(
540                         "mImmSizeOverride must be 2, mImmSizeOverride: " + mImmSizeOverride);
541             }
542             final int subArrayLength = toIndex - fromIndex;
543             for (int i = 0; i < mBytesImm.length - subArrayLength + 1; i++) {
544                 boolean found = true;
545                 for (int j = 0; j < subArrayLength; j++) {
546                     if (mBytesImm[i + j] != content[fromIndex + j]) {
547                         found = false;
548                         break;
549                     }
550                 }
551                 if (found) {
552                     return i;
553                 }
554             }
555             return -1;
556         }
557 
concat(byte[] prefix, byte[] suffix, int suffixFrom, int suffixTo)558         private static byte[] concat(byte[] prefix, byte[] suffix, int suffixFrom, int suffixTo) {
559             final byte[] newArray = new byte[prefix.length + suffixTo - suffixFrom];
560             System.arraycopy(prefix, 0, newArray, 0, prefix.length);
561             System.arraycopy(suffix, suffixFrom, newArray, prefix.length, suffixTo - suffixFrom);
562             return newArray;
563         }
564 
565         /**
566          * Manages and updates the data region.
567          * <p>
568          * Searches for the specified subarray within the existing data region. If the subarray
569          * is not found, it is appended to the data region. The subarray is defined as the
570          * portion of the {@code content} starting at {@code fromIndex} (inclusive)
571          * and ending at {@code toIndex} (exclusive).
572          * <p>
573          * @return The starting position of the subarray within the data region.
574          */
maybeUpdateBytesImm(byte[] content, int fromIndex, int toIndex)575         int maybeUpdateBytesImm(byte[] content, int fromIndex, int toIndex)
576                 throws IllegalInstructionException {
577             int offsetInDataBytes = findMatchInDataBytes(content, fromIndex, toIndex);
578             if (offsetInDataBytes == -1) {
579                 offsetInDataBytes = mBytesImm.length;
580                 mBytesImm = concat(mBytesImm, content, fromIndex, toIndex);
581                 // Update the length immediate (first imm) value. Due to mValue within
582                 // IntImmediate being final, we must remove and re-add the value to apply changes.
583                 mIntImms.remove(0);
584                 addDataOffset(mBytesImm.length);
585             }
586             // Note that the data instruction encoding consumes 1 byte and the data length
587             // encoding consumes 2 bytes.
588             return 1 + mImmSizeOverride + offsetInDataBytes;
589         }
590 
591         /**
592          * Updates exception buffer size.
593          * @param bufSize the new exception buffer size
594          */
updateExceptionBufferSize(int bufSize)595         void updateExceptionBufferSize(int bufSize) throws IllegalInstructionException {
596             if (mOpcode != Opcodes.EXT || mIntImms.get(0).mValue
597                     != ExtendedOpcodes.EXCEPTIONBUFFER.value) {
598                 throw new IllegalInstructionException(
599                         "updateExceptionBuffer() is only valid for EXCEPTIONBUFFER opcode");
600             }
601             // Update the buffer size immediate (second imm) value. Due to mValue within
602             // IntImmediate being final, we must remove and re-add the value to apply changes.
603             mIntImms.remove(1);
604             addU16(bufSize);
605         }
606 
607         /**
608          * @return size of instruction in bytes.
609          */
size()610         int size() {
611             if (mOpcode == Opcodes.LABEL) {
612                 return 0;
613             }
614             int size = 1;
615             int indeterminateSize = calculateRequiredIndeterminateSize();
616             for (IntImmediate imm : mIntImms) {
617                 size += imm.getEncodingSize(indeterminateSize);
618             }
619             if (mTargetLabel != -1) {
620                 size += indeterminateSize;
621             }
622             if (mBytesImm != null) {
623                 size += mBytesImm.length;
624             }
625             return size;
626         }
627 
628         /**
629          * Resize immediate value field so that it's only as big as required to
630          * contain the offset of the jump destination.
631          * @return {@code true} if shrunk.
632          */
shrink()633         boolean shrink() throws IllegalInstructionException {
634             if (mTargetLabel == -1) {
635                 return false;
636             }
637             int oldTargetLabelSize = mTargetLabelSize;
638             mTargetLabelSize = calculateImmSize(calculateTargetLabelOffset(), false);
639             if (mTargetLabelSize > oldTargetLabelSize) {
640                 throw new IllegalStateException("instruction grew");
641             }
642             return mTargetLabelSize < oldTargetLabelSize;
643         }
644 
645         /**
646          * Assemble value for instruction size field.
647          */
generateImmSizeField()648         private int generateImmSizeField() {
649             int immSize = calculateRequiredIndeterminateSize();
650             // Encode size field to fit in 2 bits: 0->0, 1->1, 2->2, 3->4.
651             return immSize == 4 ? 3 : immSize;
652         }
653 
654         /**
655          * Assemble first byte of generated instruction.
656          */
generateInstructionByte()657         private byte generateInstructionByte() {
658             int sizeField = generateImmSizeField();
659             return (byte) ((mOpcode.value << 3) | (sizeField << 1) | (byte) mRbit.value);
660         }
661 
662         /**
663          * Write {@code value} at offset {@code writingOffset} into {@code bytecode}.
664          * {@code immSize} bytes are written. {@code value} is truncated to
665          * {@code immSize} bytes. {@code value} is treated simply as a
666          * 32-bit value, so unsigned values should be zero extended and the truncation
667          * should simply throw away their zero-ed upper bits, and signed values should
668          * be sign extended and the truncation should simply throw away their signed
669          * upper bits.
670          */
writeValue(int value, byte[] bytecode, int writingOffset, int immSize)671         private static int writeValue(int value, byte[] bytecode, int writingOffset, int immSize) {
672             for (int i = immSize - 1; i >= 0; i--) {
673                 bytecode[writingOffset++] = (byte) ((value >> (i * 8)) & 255);
674             }
675             return writingOffset;
676         }
677 
678         /**
679          * Generate bytecode for this instruction at offset {@link Instruction#offset}.
680          */
generate(byte[] bytecode)681         void generate(byte[] bytecode) throws IllegalInstructionException {
682             if (mOpcode == Opcodes.LABEL) {
683                 return;
684             }
685             int writingOffset = offset;
686             bytecode[writingOffset++] = generateInstructionByte();
687             int indeterminateSize = calculateRequiredIndeterminateSize();
688             int startOffset = 0;
689             if (mOpcode == Opcodes.EXT) {
690                 // For extend opcode, always write the actual opcode first.
691                 writingOffset = mIntImms.get(startOffset++).writeValue(bytecode, writingOffset,
692                         indeterminateSize);
693             }
694             if (mTargetLabel != -1) {
695                 writingOffset = writeValue(calculateTargetLabelOffset(), bytecode, writingOffset,
696                         indeterminateSize);
697             }
698             for (int i = startOffset; i < mIntImms.size(); ++i) {
699                 writingOffset = mIntImms.get(i).writeValue(bytecode, writingOffset,
700                         indeterminateSize);
701             }
702             if (mBytesImm != null) {
703                 System.arraycopy(mBytesImm, 0, bytecode, writingOffset, mBytesImm.length);
704                 writingOffset += mBytesImm.length;
705             }
706             if ((writingOffset - offset) != size()) {
707                 throw new IllegalStateException("wrote " + (writingOffset - offset)
708                         + " but should have written " + size());
709             }
710         }
711 
712         /**
713          * Calculates the maximum indeterminate size of all IMMs in this instruction.
714          * <p>
715          * This method finds the largest size needed to encode any indeterminate-sized IMMs in
716          * the instruction. This size will be stored in the immLen field.
717          */
calculateRequiredIndeterminateSize()718         private int calculateRequiredIndeterminateSize() {
719             int maxSize = mTargetLabelSize;
720             for (IntImmediate imm : mIntImms) {
721                 maxSize = Math.max(maxSize, imm.calculateIndeterminateSize());
722             }
723             if (mImmSizeOverride != -1 && maxSize > mImmSizeOverride) {
724                 throw new IllegalStateException(String.format(
725                         "maxSize: %d should not be greater than mImmSizeOverride: %d", maxSize,
726                         mImmSizeOverride));
727             }
728             // If we already know the size the length field, just use it
729             switch (mImmSizeOverride) {
730                 case -1:
731                     return maxSize;
732                 case 1:
733                 case 2:
734                 case 4:
735                     return mImmSizeOverride;
736                 default:
737                     throw new IllegalStateException(
738                             "mImmSizeOverride has invalid value: " + mImmSizeOverride);
739             }
740         }
741 
calculateTargetLabelOffset()742         private int calculateTargetLabelOffset() throws IllegalInstructionException {
743             int targetOffset;
744             if (mTargetLabel >= 0) {
745                 targetOffset = mTotalSize + mTargetLabel;
746             } else {
747                 final Instruction targetLabelInstruction = mLabels.get(mTargetLabel);
748                 if (targetLabelInstruction == null) {
749                     throw new IllegalInstructionException("label not found: " + mTargetLabel);
750                 }
751                 targetOffset = targetLabelInstruction.offset;
752             }
753             // Calculate distance from end of this instruction to targetOffset.
754             return targetOffset - (offset + size());
755         }
756     }
757 
758     /**
759      * Updates instruction offset fields using latest instruction sizes.
760      * @return current program length in bytes.
761      */
updateInstructionOffsets()762     private int updateInstructionOffsets() {
763         int offset = 0;
764         for (Instruction instruction : mInstructions) {
765             instruction.offset = offset;
766             offset += instruction.size();
767         }
768         return offset;
769     }
770 
771     /**
772      * Calculate the size of the imm.
773      */
calculateImmSize(int imm, boolean signed)774     static int calculateImmSize(int imm, boolean signed) {
775         if (imm == 0) {
776             return 0;
777         }
778         if (signed && (imm >= -128 && imm <= 127) || !signed && (imm >= 0 && imm <= 255)) {
779             return 1;
780         }
781         if (signed && (imm >= -32768 && imm <= 32767) || !signed && (imm >= 0 && imm <= 65535)) {
782             return 2;
783         }
784         return 4;
785     }
786 
checkRange(@onNull String variableName, long value, long lowerBound, long upperBound)787     static void checkRange(@NonNull String variableName, long value, long lowerBound,
788                            long upperBound) {
789         if (value >= lowerBound && value <= upperBound) {
790             return;
791         }
792         throw new IllegalArgumentException(
793                 String.format("%s: %d, must be in range [%d, %d]", variableName, value, lowerBound,
794                         upperBound));
795     }
796 
checkPassCounterRange(ApfCounterTracker.Counter cnt)797     void checkPassCounterRange(ApfCounterTracker.Counter cnt) {
798         if (mDisableCounterRangeCheck) return;
799         cnt.getJumpPassLabel();
800     }
801 
checkDropCounterRange(ApfCounterTracker.Counter cnt)802     void checkDropCounterRange(ApfCounterTracker.Counter cnt) {
803         if (mDisableCounterRangeCheck) return;
804         cnt.getJumpDropLabel();
805     }
806 
807     /**
808      * Returns an overestimate of the size of the generated program. {@link #generate} may return
809      * a program that is smaller.
810      */
programLengthOverEstimate()811     public int programLengthOverEstimate() {
812         return updateInstructionOffsets();
813     }
814 
    /**
     * Hook invoked by {@link #generate()} with the final program length. It runs after the
     * shrink passes have settled instruction sizes and before any instruction writes bytecode.
     * <p>
     * NOTE(review): the output buffer has already been allocated with {@code programSize}
     * bytes when this runs. Implementations therefore presumably must not change any
     * instruction's size here. TODO confirm against subclasses.
     *
     * @param programSize the total program length in bytes.
     * @throws IllegalInstructionException if the implementation rejects the update; conditions
     *     are defined by subclasses.
     */
    abstract void updateExceptionBufferSize(int programSize) throws IllegalInstructionException;
819 
    // Total program length in bytes, refreshed on every offset update inside generate().
    // Non-negative (PASS/DROP) jump targets are resolved relative to this value.
    private int mTotalSize;
821 
822     /**
823      * Generate the bytecode for the APF program.
824      * @return the bytecode.
825      * @throws IllegalStateException if a label is referenced but not defined.
826      */
generate()827     public byte[] generate() throws IllegalInstructionException {
828         // Enforce that we can only generate once because we cannot unshrink instructions and
829         // PASS/DROP labels may move further away requiring unshrinking if we add further
830         // instructions.
831         if (mGenerated) {
832             throw new IllegalStateException("Can only generate() once!");
833         }
834         mGenerated = true;
835         boolean shrunk;
836         // Shrink the immediate value fields of instructions.
837         // As we shrink the instructions some branch offset
838         // fields may shrink also, thereby shrinking the
839         // instructions further. Loop until we've reached the
840         // minimum size. Rarely will this loop more than a few times.
841         // Limit iterations to avoid O(n^2) behavior.
842         int iterations_remaining = 10;
843         do {
844             mTotalSize = updateInstructionOffsets();
845             // Limit run-time in aberant circumstances.
846             if (iterations_remaining-- == 0) break;
847             // Attempt to shrink instructions.
848             shrunk = false;
849             for (Instruction instruction : mInstructions) {
850                 if (instruction.shrink()) {
851                     shrunk = true;
852                 }
853             }
854         } while (shrunk);
855         // Generate bytecode for instructions.
856         byte[] bytecode = new byte[mTotalSize];
857         updateExceptionBufferSize(mTotalSize);
858         for (Instruction instruction : mInstructions) {
859             instruction.generate(bytecode);
860         }
861         return bytecode;
862     }
863 
validateBytes(byte[] bytes)864     void validateBytes(byte[] bytes) {
865         Objects.requireNonNull(bytes);
866         if (bytes.length > 2047) {
867             throw new IllegalArgumentException(
868                     "bytes array size must be in less than 2048, current size: " + bytes.length);
869         }
870     }
871 
validateDeduplicateBytesList(List<byte[]> bytesList)872     List<byte[]> validateDeduplicateBytesList(List<byte[]> bytesList) {
873         if (bytesList == null || bytesList.size() == 0) {
874             throw new IllegalArgumentException(
875                     "bytesList size must > 0, current size: "
876                             + (bytesList == null ? "null" : bytesList.size()));
877         }
878         for (byte[] bytes : bytesList) {
879             validateBytes(bytes);
880         }
881         final int elementSize = bytesList.get(0).length;
882         if (elementSize > 2097151) { // 2 ^ 21 - 1
883             throw new IllegalArgumentException("too many elements");
884         }
885         List<byte[]> deduplicatedList = new ArrayList<>();
886         deduplicatedList.add(bytesList.get(0));
887         for (int i = 1; i < bytesList.size(); ++i) {
888             if (elementSize != bytesList.get(i).length) {
889                 throw new IllegalArgumentException("byte arrays in the set have different size");
890             }
891             int j = 0;
892             for (; j < deduplicatedList.size(); ++j) {
893                 if (Arrays.equals(bytesList.get(i), deduplicatedList.get(j))) {
894                     break;
895                 }
896             }
897             if (j == deduplicatedList.size()) {
898                 deduplicatedList.add(bytesList.get(i));
899             }
900         }
901         return deduplicatedList;
902     }
903 
requireApfVersion(int minimumVersion)904     void requireApfVersion(int minimumVersion) throws IllegalInstructionException {
905         if (mVersion < minimumVersion) {
906             throw new IllegalInstructionException("Requires APF >= " + minimumVersion);
907         }
908     }
909 
910     private short mLabelCount = 0;
911 
912     /**
913      * Return a unique label string.
914      */
getUniqueLabel()915     public short getUniqueLabel() {
916         final short nextLabel = (short) -(2 + mLabelCount++);
917         if (nextLabel == Short.MIN_VALUE) {
918             throw new IllegalStateException("Running out of unique labels");
919         }
920         return nextLabel;
921     }
922 
    /**
     * Jump to this label to terminate the program and indicate that the packet should be
     * dropped. As a non-negative label it is resolved relative to the end of the program,
     * at program length + 1.
     */
    public static final short DROP_LABEL = 1;

    /**
     * Jump to this label to terminate the program and indicate that the packet should be
     * passed to the AP. As a non-negative label it is resolved relative to the end of the
     * program, exactly at the program length.
     */
    public static final short PASS_LABEL = 0;

    /**
     * Number of memory slots accessible via APF stores to memory and loads from memory.
     * The slots are numbered 0 to {@code MEMORY_SLOTS} - 1. This value must be kept in sync
     * with the APF interpreter.
     */
    public static final int MEMORY_SLOTS = 16;
941 
942     public enum MemorySlot {
943         /**
944          * These slots start with value 0 and are unused.
945          */
946         SLOT_0(0),
947         SLOT_1(1),
948         SLOT_2(2),
949         SLOT_3(3),
950         SLOT_4(4),
951         SLOT_5(5),
952         SLOT_6(6),
953         SLOT_7(7),
954 
955         /**
956          * First memory slot containing prefilled (ie. non-zero) values.
957          * Can be used in range comparisons to determine if memory slot index
958          * is within prefilled slots.
959          */
960         FIRST_PREFILLED(8),
961 
962         /**
963          * Slot #8 is used for the APFv6+ version.
964          */
965         APF_VERSION(8),
966 
967         /**
968          * Slot #9 is used for the filter age in 16384ths of a second (APFv6+).
969          */
970         FILTER_AGE_16384THS(9),
971 
972         /**
973          * Slot #10 starts at zero, implicitly used as tx buffer output pointer.
974          */
975         TX_BUFFER_OUTPUT_POINTER(10),
976 
977         /**
978          * Slot #11 is used for the program byte code size (APFv2+).
979          */
980         PROGRAM_SIZE(11),
981 
982         /**
983          * Slot #12 is used for the total RAM length.
984          */
985         RAM_LEN(12),
986 
987         /**
988          * Slot #13 is the IPv4 header length (in bytes).
989          */
990         IPV4_HEADER_SIZE(13),
991 
992         /**
993          * Slot #14 is the size of the packet being filtered in bytes.
994          */
995         PACKET_SIZE(14),
996 
997         /**
998          * Slot #15 is the age of the filter (time since filter was installed
999          * till now) in seconds.
1000          */
1001         FILTER_AGE_SECONDS(15);
1002 
1003         public final int value;
1004 
MemorySlot(int value)1005         MemorySlot(int value) {
1006             this.value = value;
1007         }
1008     }
1009 
    // APF interpreter version numbers. These must stay in sync with APF_VERSION in
    // hardware/google/apf/apf_interpreter.h. requireApfVersion() compares them against mVersion.
    public static final int APF_VERSION_2 = 2;
    public static final int APF_VERSION_3 = 3;
    public static final int APF_VERSION_4 = 4;
    public static final int APF_VERSION_6 = 6000;
    // TODO: update the version code once we finalized APFv6.1.
    public static final int APF_VERSION_61 = 20250228;
1017 
1018 
    // Program instructions in emission order; offsets are assigned by updateInstructionOffsets().
    final ArrayList<Instruction> mInstructions = new ArrayList<Instruction>();
    // Negative (unique) labels mapped to the instruction whose offset marks the label position.
    private final SparseArray<Instruction> mLabels = new SparseArray<>();
    // Targeted APF interpreter version, checked by requireApfVersion().
    public final int mVersion;
    // Set from the constructor's ramSize argument; not read in this section. TODO confirm use.
    public final int mRamSize;
    // Set from the constructor's clampSize argument; not read in this section. TODO confirm use.
    public final int mClampSize;
    // Set by generate() so that a program can only be generated once.
    public boolean mGenerated;
    // When true, checkPassCounterRange()/checkDropCounterRange() skip their checks.
    private final boolean mDisableCounterRangeCheck;
1026 }
1027