/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.net.apf

import android.net.apf.ApfCounterTracker.Counter
import android.net.apf.ApfCounterTracker.Counter.APF_PROGRAM_ID
import android.net.apf.ApfCounterTracker.Counter.APF_VERSION
import android.net.apf.ApfCounterTracker.Counter.TOTAL_PACKETS
import android.net.apf.BaseApfGenerator.APF_VERSION_6
import com.android.net.module.util.HexDump
import kotlin.test.assertEquals
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.clearInvocations
import org.mockito.Mockito.timeout
import org.mockito.Mockito.verify

/**
 * Helpers for APF unit tests: run programs through [ApfJniUtils] and assert on the returned
 * verdict, the resulting data memory, the decoded counters and the transmitted packets.
 *
 * @param apfInterpreterVersion forwarded to the [ApfJniUtils] constructor.
 */
class ApfTestHelpers(apfInterpreterVersion: Int) {
    private val apfJniUtils = ApfJniUtils(apfInterpreterVersion)

    companion object {
        const val TIMEOUT_MS: Long = 1000
        const val PASS: Int = 1
        const val DROP: Int = 0

        // The interpreter accepts packets without link-layer headers, so fake packets are
        // padded to at least the minimum packet size.
        const val MIN_PKT_SIZE: Int = 15

        /**
         * Human-readable name for a verdict. Unknown codes keep their numeric value so that two
         * different unknown codes never compare equal in [assertReturnCodesEqual].
         */
        private fun label(code: Int): String = when (code) {
            PASS -> "PASS"
            DROP -> "DROP"
            else -> "UNKNOWN($code)"
        }

        /** Asserts two return codes match, comparing their labels so failures read clearly. */
        private fun assertReturnCodesEqual(msg: String, expected: Int, got: Int) {
            assertEquals(label(expected), label(got), msg)
        }

        /** Same as the overload above, without a custom failure message. */
        private fun assertReturnCodesEqual(expected: Int, got: Int) {
            assertEquals(label(expected), label(got))
        }

        /**
         * Checks the generated APF program equals the expected value.
         *
         * @throws AssertionError with both programs hex-dumped if they differ.
         */
        @Throws(AssertionError::class)
        @JvmStatic
        fun assertProgramEquals(expected: ByteArray, program: ByteArray?) {
            // assertArrayEquals() would only print one byte, making debugging difficult.
            if (!expected.contentEquals(program)) {
                throw AssertionError(
                    "\nexpected: " + HexDump.toHexString(expected) +
                        "\nactual: " + HexDump.toHexString(program)
                )
            }
        }

        /**
         * Decodes the counters stored in [counterBytes] into a map.
         *
         * [APF_PROGRAM_ID] and [APF_VERSION] are skipped, and only counters with a non-zero
         * value are included in the result.
         */
        fun decodeCountersIntoMap(counterBytes: ByteArray): Map<Counter, Long> {
            val skippedCounters = setOf(APF_PROGRAM_ID, APF_VERSION)
            // Start from index 2 to skip the leading entries, which include the endianness mark.
            return Counter.values()
                .drop(2)
                .filterNot { it in skippedCounters }
                .associateWith { ApfCounterTracker.getCounterValue(counterBytes, it) }
                .filterValues { it != 0L }
        }
    }

    /**
     * Builds a single-line `apf_run` invocation that replays one simulated run, so it can be
     * copied straight from a failure message into a shell.
     *
     * @param data the data region to replay with, or null to omit `--data`.
     */
    private fun apfRunCommand(
        apfVersion: Int,
        program: ByteArray,
        packet: ByteArray,
        data: ByteArray?,
        filterAge: Int
    ): String = buildString {
        append("apf_run --program ").append(HexDump.toHexString(program))
        append(" --packet ").append(HexDump.toHexString(packet))
        if (data != null) append(" --data ").append(HexDump.toHexString(data))
        append(" --age ").append(filterAge)
        // `>=` rather than `==` so versions numbered above APF_VERSION_6 also get the v6
        // interpreter (assumes later versions are numbered above it — confirm when adding one).
        if (apfVersion >= APF_VERSION_6) append(" --v6")
        append(" --trace | less")
    }

    /** Runs [program] on [packet] without a data region and checks the verdict. */
    private fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        program: ByteArray,
        packet: ByteArray,
        filterAge: Int
    ) {
        // Share the data-region overload so the failure message and simulation are identical.
        assertVerdict(apfVersion, expected, program, packet, null, filterAge)
    }

    /** Generates the program from [gen], runs it on [packet] and checks the verdict. */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    private fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        gen: ApfV4Generator,
        packet: ByteArray,
        filterAge: Int
    ) {
        assertVerdict(apfVersion, expected, gen.generate(), packet, null, filterAge)
    }

    /**
     * Runs [program] on [packet] with an optional [data] region and checks the verdict. On
     * failure the message contains an `apf_run` command that replays the run.
     */
    private fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        program: ByteArray,
        packet: ByteArray,
        data: ByteArray?,
        filterAge: Int
    ) {
        // Built before the run so the command replays the data region as it was passed in.
        val msg = "Unexpected APF verdict. To debug: \n" +
            "  ${apfRunCommand(apfVersion, program, packet, data, filterAge)}\n"
        assertReturnCodesEqual(
            msg,
            expected,
            apfJniUtils.apfSimulate(apfVersion, program, packet, data, filterAge)
        )
    }

    /**
     * Runs the APF program with customized data region and checks the return code.
     */
    fun assertVerdict(
        apfVersion: Int,
        expected: Int,
        program: ByteArray,
        packet: ByteArray,
        data: ByteArray?
    ) {
        assertVerdict(apfVersion, expected, program, packet, data, filterAge = 0)
    }

    /**
     * Runs the APF program and checks the return code is equal to the expected value. If not,
     * the customized [msg] is printed.
     */
    fun assertVerdict(
        apfVersion: Int,
        msg: String,
        expected: Int,
        program: ByteArray?,
        packet: ByteArray?,
        filterAge: Int
    ) {
        assertReturnCodesEqual(
            msg,
            expected,
            apfJniUtils.apfSimulate(apfVersion, program, packet, null, filterAge)
        )
    }

    /**
     * Runs the APF program and checks the return code is equal to the expected value.
     */
    fun assertVerdict(apfVersion: Int, expected: Int, program: ByteArray, packet: ByteArray) {
        assertVerdict(apfVersion, expected, program, packet, 0)
    }

    /**
     * Runs the APF program and checks the return code is PASS.
     */
    fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, PASS, program, packet, filterAge)
    }

    /**
     * Runs the APF program and checks the return code is PASS.
     */
    fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray) {
        assertVerdict(apfVersion, PASS, program, packet)
    }

    /**
     * Runs the APF program and checks the return code is DROP.
     */
    fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, DROP, program, packet, filterAge)
    }

    /**
     * Runs the APF program and checks the return code is DROP.
     */
    fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray) {
        assertVerdict(apfVersion, DROP, program, packet)
    }

    /**
     * Runs the APF program and checks the return code is PASS.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertPass(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, PASS, gen, packet, filterAge)
    }

    /**
     * Runs the APF program and checks the return code is DROP.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertDrop(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
        assertVerdict(apfVersion, DROP, gen, packet, filterAge)
    }

    /**
     * Runs the APF program on a zero-filled packet of [MIN_PKT_SIZE] bytes and checks the return
     * code is PASS.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertPass(apfVersion: Int, gen: ApfV4Generator) {
        assertVerdict(apfVersion, PASS, gen, ByteArray(MIN_PKT_SIZE), 0)
    }

    /**
     * Runs the APF program on a zero-filled packet of [MIN_PKT_SIZE] bytes and checks the return
     * code is DROP.
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class)
    fun assertDrop(apfVersion: Int, gen: ApfV4Generator) {
        assertVerdict(apfVersion, DROP, gen, ByteArray(MIN_PKT_SIZE), 0)
    }

    /**
     * Runs the APF program and checks that the return code and the resulting data region equal
     * the expected values.
     *
     * @param data the data region passed to the run; it is inspected after the run and, when
     *   [ignoreInterpreterVersion] is set, modified in place.
     * @param ignoreInterpreterVersion zero the APF_VERSION and APF_PROGRAM_ID slots of [data]
     *   before comparing.
     * @throws Exception with hex dumps if the data region differs from [expectedData].
     */
    @Throws(BaseApfGenerator.IllegalInstructionException::class, Exception::class)
    fun assertDataMemoryContents(
        apfVersion: Int,
        expected: Int,
        program: ByteArray?,
        packet: ByteArray?,
        data: ByteArray,
        expectedData: ByteArray,
        ignoreInterpreterVersion: Boolean
    ) {
        assertReturnCodesEqual(
            expected,
            apfJniUtils.apfSimulate(apfVersion, program, packet, data, 0)
        )

        if (ignoreInterpreterVersion) {
            // Zero the 4 bytes at each slot (Counter.totalSize() + offset()) so these values do
            // not take part in the comparison below.
            val apfVersionIdx = Counter.totalSize() + APF_VERSION.offset()
            val apfProgramIdIdx = Counter.totalSize() + APF_PROGRAM_ID.offset()
            data.fill(0, apfVersionIdx, apfVersionIdx + 4)
            data.fill(0, apfProgramIdIdx, apfProgramIdIdx + 4)
        }
        // assertArrayEquals() would only print one byte, making debugging difficult.
        if (!expectedData.contentEquals(data)) {
            throw Exception(
                ("\nprogram: " + HexDump.toHexString(program) +
                    "\ndata memory: " + HexDump.toHexString(data) +
                    "\nexpected: " + HexDump.toHexString(expectedData))
            )
        }
    }

    /**
     * Runs [program] on [pkt], checks the verdict, then checks that the counters decoded from
     * [dataRegion] equal [cntMap] after incrementing [targetCnt] (and [TOTAL_PACKETS] when
     * [incTotal] is set).
     *
     * [cntMap] and [dataRegion] are updated in place, so callers can pass the same instances
     * to successive calls to accumulate expectations.
     */
    fun verifyProgramRun(
        version: Int,
        program: ByteArray,
        pkt: ByteArray,
        targetCnt: Counter,
        cntMap: MutableMap<Counter, Long> = mutableMapOf(),
        dataRegion: ByteArray = ByteArray(Counter.totalSize()) { 0 },
        incTotal: Boolean = true,
        result: Int = if (targetCnt.name.startsWith("PASSED")) PASS else DROP
    ) {
        // The run updates dataRegion in place; keep its starting state so the debug command
        // below replays the same run rather than one starting from already-bumped counters.
        val dataBeforeRun = dataRegion.copyOf()
        assertVerdict(version, result, program, pkt, dataRegion)
        cntMap[targetCnt] = cntMap.getOrDefault(targetCnt, 0) + 1
        if (incTotal) {
            cntMap[TOTAL_PACKETS] = cntMap.getOrDefault(TOTAL_PACKETS, 0) + 1
        }
        val errMsg = "Counter is not increased properly. To debug: \n" +
            "  ${apfRunCommand(version, program, pkt, dataBeforeRun, filterAge = 0)}\n"
        assertEquals(cntMap, decodeCountersIntoMap(dataRegion), errMsg)
    }

    /**
     * Verifies, waiting up to [TIMEOUT_MS], that [apfController] received [installCnt]
     * `installPacketFilter` calls. Then clears the mock's recorded invocations and returns the
     * last captured program.
     */
    fun consumeInstalledProgram(
        apfController: ApfFilter.IApfController,
        installCnt: Int
    ): ByteArray {
        val programCaptor = ArgumentCaptor.forClass(ByteArray::class.java)

        verify(apfController, timeout(TIMEOUT_MS).times(installCnt)).installPacketFilter(
            programCaptor.capture(),
            any()
        )

        clearInvocations<Any>(apfController)
        return programCaptor.value
    }

    /**
     * Returns the packets transmitted so far after asserting there are exactly [expectCnt] of
     * them, then resets the transmitted-packet memory.
     */
    fun consumeTransmittedPackets(expectCnt: Int): List<ByteArray> {
        val transmittedPackets = getAllTransmittedPackets()
        assertEquals(expectCnt, transmittedPackets.size)
        resetTransmittedPacketMemory()
        return transmittedPackets
    }

    /** Clears the packets recorded as transmitted by the interpreter. */
    fun resetTransmittedPacketMemory() {
        apfJniUtils.resetTransmittedPacketMemory()
    }

    /** Disassembles [program], one string per entry as returned by [ApfJniUtils]. */
    fun disassembleApf(program: ByteArray): Array<String> = apfJniUtils.disassembleApf(program)

    /** Returns the packets recorded as transmitted, without resetting them. */
    fun getAllTransmittedPackets(): List<ByteArray> = apfJniUtils.allTransmittedPackets

    /** Delegates to [ApfJniUtils.compareBpfApf] for the given filter, pcap file and program. */
    fun compareBpfApf(
        apfVersion: Int,
        filter: String,
        pcapFilename: String,
        apfProgram: ByteArray
    ): Boolean = apfJniUtils.compareBpfApf(apfVersion, filter, pcapFilename, apfProgram)

    /** Delegates to [ApfJniUtils.compileToBpf] for the given filter expression. */
    fun compileToBpf(filter: String): String = apfJniUtils.compileToBpf(filter)

    /** Delegates to [ApfJniUtils.dropsAllPackets] for the given program, data and pcap file. */
    fun dropsAllPackets(
        apfVersion: Int,
        program: ByteArray,
        data: ByteArray,
        pcapFilename: String
    ): Boolean = apfJniUtils.dropsAllPackets(apfVersion, program, data, pcapFilename)
}