• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:JvmName("PacketReflectorUtil")
18 
19 package com.android.testutils
20 
21 import android.system.ErrnoException
22 import android.system.Os
23 import com.android.net.module.util.IpUtils
24 import com.android.testutils.PacketReflector.IPV4_HEADER_LENGTH
25 import com.android.testutils.PacketReflector.IPV6_HEADER_LENGTH
26 import java.io.FileDescriptor
27 import java.io.IOException
28 import java.net.InetAddress
29 import java.nio.ByteBuffer
30 
/**
 * Reads from [fd] into [buf], requesting up to `buf.size` bytes written starting at offset 0.
 *
 * Read failures are not propagated: an [ErrnoException] or [IOException] thrown by [Os.read]
 * is converted into a return value of -1.
 *
 * @return the byte count returned by [Os.read], or -1 if the read failed.
 */
fun readPacket(fd: FileDescriptor, buf: ByteArray): Int =
    try {
        Os.read(fd, buf, 0, buf.size)
    } catch (errno: ErrnoException) {
        // Best-effort read: report failure to the caller instead of throwing.
        -1
    } catch (io: IOException) {
        -1
    }
40 
/**
 * Builds an [InetAddress] from the [len] bytes of [buf] that start at index [pos].
 *
 * The bytes are copied, so later modifications of [buf] do not affect the returned address.
 *
 * @throws java.net.UnknownHostException if [len] is not a valid address length
 *   (as decided by [InetAddress.getByAddress]).
 */
fun getInetAddressAt(buf: ByteArray, pos: Int, len: Int): InetAddress {
    val end = pos + len
    val rawAddress = buf.copyOfRange(pos, end)
    return InetAddress.getByAddress(rawAddress)
}
43 
/**
 * Decodes the unsigned 16-bit big-endian value stored in `buf[pos]` and `buf[pos + 1]`.
 *
 * No alignment is required; [pos] may be any index such that `pos + 1` is in bounds.
 *
 * @return a value in `0..0xffff`.
 */
fun getPortAt(buf: ByteArray, pos: Int): Int {
    val highByte = buf[pos].toInt() and 0xff
    val lowByte = buf[pos + 1].toInt() and 0xff
    return (highByte shl 8) or lowByte
}
50 
/**
 * Stores the low 16 bits of [port] into `buf[pos]` and `buf[pos + 1]` in big-endian order.
 *
 * No alignment is required. Bits of [port] above the low 16 are discarded.
 */
fun setPortAt(port: Int, buf: ByteArray, pos: Int) {
    val highByte = (port ushr 8).toByte()
    val lowByte = port.toByte()
    buf[pos] = highByte
    buf[pos + 1] = lowByte
}
55 
/**
 * Returns `(offset, length)` of the address field(s) in an IP header of the given [version].
 *
 * The pair is built from PacketReflector's `IPV4_ADDR_OFFSET`/`IPV4_ADDR_LENGTH` for version 4
 * and `IPV6_ADDR_OFFSET`/`IPV6_ADDR_LENGTH` for version 6.
 *
 * @throws IllegalArgumentException if [version] is neither 4 nor 6.
 */
fun getAddressPositionAndLength(version: Int): Pair<Int, Int> {
    require(version == 4 || version == 6) { "Unknown IP version $version" }
    return if (version == 4) {
        Pair(PacketReflector.IPV4_ADDR_OFFSET, PacketReflector.IPV4_ADDR_LENGTH)
    } else {
        Pair(PacketReflector.IPV6_ADDR_OFFSET, PacketReflector.IPV6_ADDR_LENGTH)
    }
}
61 
/** Byte offset of the header-checksum field from the start of the IPv4 header. */
private const val IPV4_CHKSUM_OFFSET = 10

/** Byte offset of the checksum field from the start of the UDP header. */
private const val UDP_CHECKSUM_OFFSET = 6

/** Byte offset of the checksum field from the start of the TCP header. */
private const val TCP_CHECKSUM_OFFSET = 16

/** Writes the low 16 bits of [value] into `buf[pos]` and `buf[pos + 1]` in big-endian order. */
private fun putUint16BigEndian(buf: ByteArray, pos: Int, value: Int) {
    buf[pos] = (value ushr 8).toByte()
    buf[pos + 1] = value.toByte()
}

/**
 * Recomputes, in place, the checksums of the IP packet stored at the start of [buf].
 *
 * For IPv4 the IP header checksum is recomputed (IPv6 headers have no checksum field). Then the
 * UDP or TCP checksum is recomputed. Each checksum field is zeroed before its checksum is
 * computed. A stale value left in the field (e.g. after addresses or ports were rewritten)
 * would otherwise be summed into the new checksum.
 *
 * The transport header is assumed to start immediately after a fixed-size IP header
 * ([IPV4_HEADER_LENGTH] or [IPV6_HEADER_LENGTH]). IPv4 options and IPv6 extension headers are
 * therefore not handled.
 *
 * @param buf the packet, with the IP header at offset 0.
 * @param len total packet length in bytes; used to derive the TCP segment length.
 * @param version IP version of the packet: 4 or 6.
 * @param protocol transport protocol number; only UDP and TCP are supported.
 * @throws IllegalArgumentException if [version] is not 4 or 6, or [protocol] is unsupported.
 */
fun fixPacketChecksum(buf: ByteArray, len: Int, version: Int, protocol: Byte) {
    val transportOffset = when (version) {
        4 -> IPV4_HEADER_LENGTH
        6 -> IPV6_HEADER_LENGTH
        else -> throw IllegalArgumentException("Unknown IP version $version")
    }

    // Fill the IP checksum for IPv4. IPv6 headers don't have a checksum field.
    if (version == 4) {
        // RFC 791: the header checksum is computed with the checksum field set to zero, and
        // IpUtils.ipChecksum sums the entire header including that field.
        putUint16BigEndian(buf, IPV4_CHKSUM_OFFSET, 0)
        val checksum = IpUtils.ipChecksum(ByteBuffer.wrap(buf), 0)
        putUint16BigEndian(buf, IPV4_CHKSUM_OFFSET, checksum.toInt())
    }

    // Fill the transport-layer checksum.
    when (protocol) {
        PacketReflector.IPPROTO_UDP -> {
            val checksumPos = transportOffset + UDP_CHECKSUM_OFFSET
            // Clear before calculating, since the field lies inside the summed range.
            putUint16BigEndian(buf, checksumPos, 0)
            val checksum = IpUtils.udpChecksum(ByteBuffer.wrap(buf), 0, transportOffset)
            putUint16BigEndian(buf, checksumPos, checksum.toInt())
        }
        PacketReflector.IPPROTO_TCP -> {
            val checksumPos = transportOffset + TCP_CHECKSUM_OFFSET
            // Clear before calculating, since the field lies inside the summed range.
            putUint16BigEndian(buf, checksumPos, 0)
            val transportLen = len - transportOffset
            val checksum = IpUtils.tcpChecksum(
                ByteBuffer.wrap(buf), 0, transportOffset, transportLen
            )
            putUint16BigEndian(buf, checksumPos, checksum.toInt())
        }
        // TODO: Support ICMP.
        else -> throw IllegalArgumentException("Unsupported protocol: $protocol")
    }
}
107