• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.net.apf
17 
18 import android.net.apf.ApfCounterTracker.Counter
19 import android.net.apf.ApfCounterTracker.Counter.APF_PROGRAM_ID
20 import android.net.apf.ApfCounterTracker.Counter.APF_VERSION
21 import android.net.apf.ApfCounterTracker.Counter.TOTAL_PACKETS
22 import android.net.apf.BaseApfGenerator.APF_VERSION_6
23 import com.android.net.module.util.HexDump
24 import kotlin.test.assertEquals
25 import org.mockito.ArgumentCaptor
26 import org.mockito.ArgumentMatchers.any
27 import org.mockito.Mockito.clearInvocations
28 import org.mockito.Mockito.timeout
29 import org.mockito.Mockito.verify
30 
/**
 * Test helpers that run APF programs through the native simulator wrapped by [ApfJniUtils] and
 * assert on the resulting verdict, data memory, counters and transmitted packets.
 *
 * @param apfInterpreterVersion passed through to the [ApfJniUtils] instance used by this helper.
 */
class ApfTestHelpers(apfInterpreterVersion: Int) {
    private val apfJniUtils = ApfJniUtils(apfInterpreterVersion)

    companion object {
        /** Mockito verification timeout used by [consumeInstalledProgram], in milliseconds. */
        const val TIMEOUT_MS: Long = 1000

        /** Simulator return code meaning the packet is passed. */
        const val PASS: Int = 1

        /** Simulator return code meaning the packet is dropped. */
        const val DROP: Int = 0

        // Interpreter will just accept packets without link layer headers, so pad fake packet to at
        // least the minimum packet size.
        const val MIN_PKT_SIZE: Int = 15

        /**
         * Returns a readable label for a simulator return code. Codes other than [PASS] / [DROP]
         * keep their numeric value so that two different unexpected codes never compare equal.
         */
        private fun label(code: Int): String = when (code) {
            PASS -> "PASS"
            DROP -> "DROP"
            else -> "UNKNOWN($code)"
        }

        /** Asserts [got] equals [expected] (compared via [label]); [msg] is shown on failure. */
        private fun assertReturnCodesEqual(msg: String, expected: Int, got: Int) {
            assertEquals(label(expected), label(got), msg)
        }

        /** Asserts [got] equals [expected] (compared via [label]). */
        private fun assertReturnCodesEqual(expected: Int, got: Int) {
            assertEquals(label(expected), label(got))
        }

        /**
         * Checks the generated APF program equals to the expected value.
         *
         * @throws AssertionError containing hex dumps of both programs if they differ.
         */
        @Throws(AssertionError::class)
        @JvmStatic
        fun assertProgramEquals(expected: ByteArray, program: ByteArray?) {
            // assertArrayEquals() would only print one byte, making debugging difficult.
            if (!expected.contentEquals(program)) {
                throw AssertionError(
                    "\nexpected: " + HexDump.toHexString(expected) +
                    "\nactual:   " + HexDump.toHexString(program)
                )
            }
        }

        /**
         * Decodes the counters stored in [counterBytes] into a map keyed by [Counter].
         *
         * The first two [Counter] constants, [APF_PROGRAM_ID] and [APF_VERSION] are skipped, and
         * counters whose decoded value is 0 are omitted from the result.
         */
        fun decodeCountersIntoMap(counterBytes: ByteArray): Map<Counter, Long> {
            val skippedCounters = setOf(APF_PROGRAM_ID, APF_VERSION)
            // Starting from index 2 to skip the endianness mark.
            return Counter.values()
                .drop(2)
                .filterNot { it in skippedCounters }
                .associateWith { ApfCounterTracker.getCounterValue(counterBytes, it) }
                .filterValues { it != 0L }
        }
    }

    /**
     * Runs [program] on [packet] with no data region and asserts the verdict equals [expected].
     * Delegates to the data-aware overload so every failure prints the same one-line `apf_run`
     * reproduction command.
     */
    private fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        program: ByteArray,
        packet: ByteArray,
        filterAge: Int
    ) {
        assertVerdict(apfVersion, expected, program, packet, null, filterAge)
    }

    /** Generates the program from [gen] and asserts its verdict on [packet] equals [expected]. */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    private fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        gen: ApfV4Generator,
        packet: ByteArray,
        filterAge: Int
    ) {
        assertVerdict(apfVersion, expected, gen.generate(), packet, null, filterAge)
    }

    /**
     * Runs [program] on [packet] with an optional [data] region and asserts the verdict equals
     * [expected]. On failure the message contains an `apf_run` command reproducing the run.
     */
    private fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        program: ByteArray,
        packet: ByteArray,
        data: ByteArray?,
        filterAge: Int
    ) {
        // Built before simulating so that the --data dump reflects the input to this run.
        val msg = "Unexpected APF verdict. To debug: \n " +
            apfRunCommand(apfVersion, program, packet, data, filterAge) + "\n"
        assertReturnCodesEqual(
            msg,
            expected,
            apfJniUtils.apfSimulate(apfVersion, program, packet, data, filterAge)
        )
    }

    /**
     * Builds a single-line `apf_run` command line (piped to `less`) that re-runs the given
     * program/packet/data with tracing enabled. `--v6` is added for versions above 4, and
     * `--data` only when [data] is non-null.
     */
    private fun apfRunCommand(
        apfVersion: Int,
        program: ByteArray,
        packet: ByteArray,
        data: ByteArray?,
        filterAge: Int
    ): String = buildString {
        append("apf_run")
        append(" --program ").append(HexDump.toHexString(program))
        append(" --packet ").append(HexDump.toHexString(packet))
        if (data != null) {
            append(" --data ").append(HexDump.toHexString(data))
        }
        append(" --age ").append(filterAge)
        if (apfVersion > 4) {
            append(" --v6")
        }
        append(" --trace | less")
    }

    /**
     * Runs the APF program with customized data region and checks the return code.
     */
    fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        program: ByteArray,
        packet: ByteArray,
        data: ByteArray?
    ) {
        assertVerdict(apfVersion, expected, program, packet, data, filterAge = 0)
    }

    /**
     * Runs the APF program and checks the return code is equals to expected value. If not, the
     * customized message is printed.
     */
    fun assertVerdict(
        apfVersion: Int,
        msg: String,
        expected: Int,
        program: ByteArray?,
        packet: ByteArray?,
        filterAge: Int
    ) {
        assertReturnCodesEqual(
            msg,
            expected,
            apfJniUtils.apfSimulate(apfVersion, program, packet, null, filterAge)
        )
    }

    /**
     * Runs the APF program and checks the return code is equals to expected value.
     */
    fun assertVerdict(apfVersion: Int, expected: Int, program: ByteArray, packet: ByteArray) {
        assertVerdict(apfVersion, expected, program, packet, 0)
    }

    /**
     * Runs the APF program and checks the return code is PASS.
     */
    fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, PASS, program, packet, filterAge)
    }

    /**
     * Runs the APF program and checks the return code is PASS.
     */
    fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray) {
        assertVerdict(apfVersion, PASS, program, packet)
    }

    /**
     * Runs the APF program and checks the return code is DROP.
     */
    fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, DROP, program, packet, filterAge)
    }

    /**
     * Runs the APF program and checks the return code is DROP.
     */
    fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray) {
        assertVerdict(apfVersion, DROP, program, packet)
    }

    /**
     * Runs the APF program and checks the return code is PASS.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertPass(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, PASS, gen, packet, filterAge)
    }

    /**
     * Runs the APF program and checks the return code is DROP.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertDrop(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, DROP, gen, packet, filterAge)
    }

    /**
     * Runs the APF program on a zeroed [MIN_PKT_SIZE]-byte packet and checks the return code
     * is PASS.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertPass(apfVersion: Int, gen: ApfV4Generator) {
        assertVerdict(apfVersion, PASS, gen, ByteArray(MIN_PKT_SIZE), 0)
    }

    /**
     * Runs the APF program on a zeroed [MIN_PKT_SIZE]-byte packet and checks the return code
     * is DROP.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertDrop(apfVersion: Int, gen: ApfV4Generator) {
        assertVerdict(apfVersion, DROP, gen, ByteArray(MIN_PKT_SIZE), 0)
    }

    /**
     * Runs the APF program and checks the return code and data regions
     * equals to expected value.
     *
     * When [ignoreInterpreterVersion] is true, the 4 bytes at the [APF_VERSION] and
     * [APF_PROGRAM_ID] slots of [data] are zeroed (in place) before comparing.
     *
     * @throws Exception with hex dumps of the program, actual and expected data on mismatch.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class, Exception::class)
    fun assertDataMemoryContents(
        apfVersion: Int,
        expected: Int,
        program: ByteArray?,
        packet: ByteArray?,
        data: ByteArray,
        expectedData: ByteArray,
        ignoreInterpreterVersion: Boolean
    ) {
        assertReturnCodesEqual(
            expected,
            apfJniUtils.apfSimulate(apfVersion, program, packet, data, 0)
        )

        if (ignoreInterpreterVersion) {
            for (counter in listOf(APF_VERSION, APF_PROGRAM_ID)) {
                val idx = Counter.totalSize() + counter.offset()
                data.fill(0, idx, idx + 4)
            }
        }
        // assertArrayEquals() would only print one byte, making debugging difficult.
        if (!expectedData.contentEquals(data)) {
            throw Exception(
                ("\nprogram:     " + HexDump.toHexString(program) +
                        "\ndata memory: " + HexDump.toHexString(data) +
                        "\nexpected:    " + HexDump.toHexString(expectedData))
            )
        }
    }

    /**
     * Runs [program] on [pkt] using [dataRegion] as data memory, asserts the verdict equals
     * [result], then asserts the counters decoded from [dataRegion] equal [cntMap] after
     * [targetCnt] (and, if [incTotal], [TOTAL_PACKETS]) is incremented in [cntMap].
     *
     * [cntMap] and [dataRegion] are updated in place so callers can chain several runs.
     */
    fun verifyProgramRun(
        version: Int,
        program: ByteArray,
        pkt: ByteArray,
        targetCnt: Counter,
        cntMap: MutableMap<Counter, Long> = mutableMapOf(),
        dataRegion: ByteArray = ByteArray(Counter.totalSize()) { 0 },
        incTotal: Boolean = true,
        result: Int = if (targetCnt.name.startsWith("PASSED")) PASS else DROP
    ) {
        // Capture the reproduction command before running: the run writes counters back into
        // dataRegion, so a post-run dump would not reproduce this run.
        val debugCmd = apfRunCommand(version, program, pkt, dataRegion, filterAge = 0)
        assertVerdict(version, result, program, pkt, dataRegion)
        cntMap[targetCnt] = cntMap.getOrDefault(targetCnt, 0) + 1
        if (incTotal) {
            cntMap[TOTAL_PACKETS] = cntMap.getOrDefault(TOTAL_PACKETS, 0) + 1
        }
        val errMsg = "Counter is not increased properly. To debug: \n $debugCmd\n"
        assertEquals(cntMap, decodeCountersIntoMap(dataRegion), errMsg)
    }

    /**
     * Verifies (waiting up to [TIMEOUT_MS]) that `installPacketFilter` was called [installCnt]
     * times on [apfController], clears the mock's recorded invocations, and returns the program
     * bytes captured from the most recent call.
     */
    fun consumeInstalledProgram(
        apfController: ApfFilter.IApfController,
        installCnt: Int
    ): ByteArray {
        val programCaptor = ArgumentCaptor.forClass(ByteArray::class.java)

        verify(apfController, timeout(TIMEOUT_MS).times(installCnt)).installPacketFilter(
            programCaptor.capture(),
            any()
        )

        clearInvocations<Any>(apfController)
        return programCaptor.value
    }

    /**
     * Returns all packets recorded as transmitted, asserting there are exactly [expectCnt] of
     * them, and then resets the transmitted-packet memory.
     */
    fun consumeTransmittedPackets(
        expectCnt: Int
    ): List<ByteArray> {
        val transmittedPackets = getAllTransmittedPackets()
        assertEquals(expectCnt, transmittedPackets.size)
        resetTransmittedPacketMemory()
        return transmittedPackets
    }

    /** Clears the packets recorded as transmitted by the native helper. */
    fun resetTransmittedPacketMemory() {
        apfJniUtils.resetTransmittedPacketMemory()
    }

    /** Returns the disassembly of [program] as produced by the native helper. */
    fun disassembleApf(program: ByteArray): Array<String> = apfJniUtils.disassembleApf(program)

    /** Returns the packets currently recorded as transmitted by the native helper. */
    fun getAllTransmittedPackets(): List<ByteArray> = apfJniUtils.getAllTransmittedPackets()

    /** Delegates to [ApfJniUtils.compareBpfApf] with the same arguments. */
    fun compareBpfApf(
        apfVersion: Int,
        filter: String,
        pcapFilename: String,
        apfProgram: ByteArray
    ): Boolean = apfJniUtils.compareBpfApf(apfVersion, filter, pcapFilename, apfProgram)

    /** Delegates to [ApfJniUtils.compileToBpf] for the given [filter] expression. */
    fun compileToBpf(filter: String): String = apfJniUtils.compileToBpf(filter)

    /** Delegates to [ApfJniUtils.dropsAllPackets] with the same arguments. */
    fun dropsAllPackets(
        apfVersion: Int,
        program: ByteArray,
        data: ByteArray,
        pcapFilename: String
    ): Boolean = apfJniUtils.dropsAllPackets(apfVersion, program, data, pcapFilename)
}
368