1 /*
<lambda>null2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 package android.net.cts
17
18 import android.net.DnsResolver
19 import android.net.Network
20 import android.net.nsd.NsdManager
21 import android.net.nsd.NsdServiceInfo
22 import android.os.Process
23 import com.android.net.module.util.ArrayTrackRecord
24 import com.android.net.module.util.DnsPacket
25 import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
26 import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
27 import com.android.net.module.util.NetworkStackConstants.IPV6_DST_ADDR_OFFSET
28 import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
29 import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
30 import com.android.net.module.util.TrackRecord
31 import com.android.testutils.IPv6UdpFilter
32 import com.android.testutils.TapPacketReader
33 import java.net.Inet6Address
34 import java.net.InetAddress
35 import kotlin.test.assertEquals
36 import kotlin.test.assertNotNull
37 import kotlin.test.assertNull
38 import kotlin.test.assertTrue
39 import kotlin.test.fail
40
/** Default time, in milliseconds, that the mDNS packet-polling helpers wait for a match. */
private const val MDNS_REGISTRATION_TIMEOUT_MS: Long = 10000

/** mDNS UDP port; used as both source and destination port when filtering captured packets. */
private const val MDNS_PORT: Short = 5353

/** Default time, in milliseconds, to wait for an expected NSD callback. */
const val MDNS_CALLBACK_TIMEOUT: Long = 2000

/** Default window, in milliseconds, during which no NSD callback is expected to arrive. */
const val MDNS_NO_CALLBACK_TIMEOUT_MS: Long = 200
45
/**
 * Marker type for events captured by [NsdRecord]. The listener records in this file each
 * declare a sealed event hierarchy whose root implements this interface.
 */
interface NsdEvent
/**
 * Base class for NSD test listeners.
 *
 * Every [NsdEvent] passed to [add] is stored in an [ArrayTrackRecord]. The class offers helpers
 * that wait for subsequent events, or assert that none arrive, through the [nextEvents] read head.
 *
 * @param T the event type recorded by the concrete listener.
 */
open class NsdRecord<T : NsdEvent> private constructor(
    private val history: ArrayTrackRecord<T>,
    private val expectedThreadId: Int? = null
) : TrackRecord<T> by history {
    /**
     * @param expectedThreadId when non-null, [add] asserts that it runs on the thread whose id
     *   (as returned by [Process.myTid]) equals this value.
     */
    constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)

    /** Read head over the recorded history; consumed by the expectation helpers below. */
    val nextEvents = history.newReadHead()

    /**
     * Records [e]. If an expected thread id was supplied, it first asserts that the caller is
     * running on that thread.
     */
    override fun add(e: T): Boolean {
        if (expectedThreadId != null) {
            assertEquals(
                expectedThreadId, Process.myTid(),
                "Callback is running on the wrong thread"
            )
        }
        return history.add(e)
    }

    /**
     * Polls [nextEvents] for up to [timeoutMs] for an event of type [V] that satisfies
     * [predicate]. Fails the test if no such event is seen.
     *
     * NOTE(review): whether non-matching events are consumed along the way depends on
     * `ArrayTrackRecord.ReadHead.poll` — confirm before chaining further expectations.
     *
     * @return the matching event.
     */
    inline fun <reified V : NsdEvent> expectCallbackEventually(
        timeoutMs: Long = MDNS_CALLBACK_TIMEOUT,
        crossinline predicate: (V) -> Boolean = { true }
    ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
        ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")

    /**
     * Polls the next event, waiting up to [timeoutMs]. Asserts that an event arrived and that it
     * is of type [V].
     *
     * @return the event, smart-cast to [V].
     */
    inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V {
        val nextEvent = nextEvents.poll(timeoutMs)
        assertNotNull(
            nextEvent, "No callback received after $timeoutMs ms, expected " +
                    "${V::class.java.simpleName}"
        )
        assertTrue(
            nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                    nextEvent.javaClass.simpleName
        )
        return nextEvent
    }

    /**
     * Asserts that no event is returned by [nextEvents] within [timeoutMs].
     *
     * Deliberately not `inline`: the function has neither lambda nor reified type parameters.
     * Inlining it gives no benefit and only triggers the compiler's NOTHING_TO_INLINE warning.
     */
    fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) {
        val cb = nextEvents.poll(timeoutMs)
        assertNull(cb, "Expected no callback but got $cb")
    }
}
89
/**
 * [NsdManager.DiscoveryListener] that records every discovery callback as a [DiscoveryEvent].
 *
 * @param expectedThreadId if non-null, [NsdRecord] asserts that each callback runs on the thread
 *   with this id.
 */
class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
    NsdManager.DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {

    /** One event type per [NsdManager.DiscoveryListener] callback. */
    sealed class DiscoveryEvent : NsdEvent {
        data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
            DiscoveryEvent()

        data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
            DiscoveryEvent()

        data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()

        data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()

        data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()

        data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
    }

    /** Appends [event] to the history, discarding [add]'s Boolean result. */
    private fun recordEvent(event: DiscoveryEvent) {
        add(event)
    }

    override fun onStartDiscoveryFailed(serviceType: String, err: Int) =
        recordEvent(DiscoveryEvent.StartDiscoveryFailed(serviceType, err))

    override fun onStopDiscoveryFailed(serviceType: String, err: Int) =
        recordEvent(DiscoveryEvent.StopDiscoveryFailed(serviceType, err))

    override fun onDiscoveryStarted(serviceType: String) =
        recordEvent(DiscoveryEvent.DiscoveryStarted(serviceType))

    override fun onDiscoveryStopped(serviceType: String) =
        recordEvent(DiscoveryEvent.DiscoveryStopped(serviceType))

    override fun onServiceFound(si: NsdServiceInfo) =
        recordEvent(DiscoveryEvent.ServiceFound(si))

    override fun onServiceLost(si: NsdServiceInfo) =
        recordEvent(DiscoveryEvent.ServiceLost(si))

    /**
     * Waits, using the default timeout of [expectCallbackEventually], for a
     * [DiscoveryEvent.ServiceFound] that matches.
     *
     * A match has service name [serviceName]. When [expectedNetwork] is non-null, its network
     * must also equal [expectedNetwork]. Then asserts that the reported service type is
     * [serviceType] followed by a dot.
     *
     * @return the [NsdServiceInfo] of the matching event.
     */
    fun waitForServiceDiscovered(
        serviceName: String,
        serviceType: String,
        expectedNetwork: Network? = null
    ): NsdServiceInfo {
        val foundEvent = expectCallbackEventually<DiscoveryEvent.ServiceFound> { event ->
            val candidate = event.serviceInfo
            if (candidate.serviceName != serviceName) {
                false
            } else {
                expectedNetwork == null || expectedNetwork == candidate.network
            }
        }
        val info = foundEvent.serviceInfo
        // Service types reported by discovery carry a trailing dot.
        assertEquals("$serviceType.", info.serviceType)
        return info
    }
}
144
/**
 * [NsdManager.RegistrationListener] that records every registration callback as a
 * [RegistrationEvent].
 *
 * @param expectedThreadId if non-null, [NsdRecord] asserts that each callback runs on the thread
 *   with this id.
 */
class NsdRegistrationRecord(expectedThreadId: Int? = null) : NsdManager.RegistrationListener,
    NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {

    /** One event type per [NsdManager.RegistrationListener] callback; each carries the service. */
    sealed class RegistrationEvent : NsdEvent {
        /** The [NsdServiceInfo] passed to the callback that produced this event. */
        abstract val serviceInfo: NsdServiceInfo

        data class RegistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class UnregistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class ServiceRegistered(
            override val serviceInfo: NsdServiceInfo
        ) : RegistrationEvent()

        data class ServiceUnregistered(
            override val serviceInfo: NsdServiceInfo
        ) : RegistrationEvent()
    }

    /** Appends [event] to the history, discarding [add]'s Boolean result. */
    private fun recordEvent(event: RegistrationEvent) {
        add(event)
    }

    override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) =
        recordEvent(RegistrationEvent.RegistrationFailed(si, err))

    override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) =
        recordEvent(RegistrationEvent.UnregistrationFailed(si, err))

    override fun onServiceRegistered(si: NsdServiceInfo) =
        recordEvent(RegistrationEvent.ServiceRegistered(si))

    override fun onServiceUnregistered(si: NsdServiceInfo) =
        recordEvent(RegistrationEvent.ServiceUnregistered(si))
}
183
/**
 * [NsdManager.ResolveListener] that records every resolve callback as a [ResolveEvent].
 * No thread check is configured for this record.
 */
class NsdResolveRecord : NsdManager.ResolveListener,
    NsdRecord<NsdResolveRecord.ResolveEvent>() {

    /** One event type per [NsdManager.ResolveListener] callback. */
    sealed class ResolveEvent : NsdEvent {
        data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()

        data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()

        data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()

        data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()
    }

    /** Appends [event] to the history, discarding [add]'s Boolean result. */
    private fun recordEvent(event: ResolveEvent) {
        add(event)
    }

    override fun onResolveFailed(si: NsdServiceInfo, err: Int) =
        recordEvent(ResolveEvent.ResolveFailed(si, err))

    override fun onServiceResolved(si: NsdServiceInfo) =
        recordEvent(ResolveEvent.ServiceResolved(si))

    override fun onResolutionStopped(si: NsdServiceInfo) =
        recordEvent(ResolveEvent.ResolutionStopped(si))

    override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
        // Invoke the interface's default implementation before recording the event.
        super.onStopResolutionFailed(si, err)
        recordEvent(ResolveEvent.StopResolutionFailed(si, err))
    }
}
213
/**
 * [NsdManager.ServiceInfoCallback] that records every callback as a [ServiceInfoCallbackEvent].
 * No thread check is configured for this record.
 */
class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
    NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {

    /** One event type per [NsdManager.ServiceInfoCallback] callback. */
    sealed class ServiceInfoCallbackEvent : NsdEvent {
        data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()

        data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()

        /** Recorded when `onServiceLost` is invoked. */
        object ServiceUpdatedLost : ServiceInfoCallbackEvent()

        /** Recorded when `onServiceInfoCallbackUnregistered` is invoked. */
        object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
    }

    /** Appends [event] to the history, discarding [add]'s Boolean result. */
    private fun recordEvent(event: ServiceInfoCallbackEvent) {
        add(event)
    }

    override fun onServiceInfoCallbackRegistrationFailed(err: Int) =
        recordEvent(ServiceInfoCallbackEvent.RegisterCallbackFailed(err))

    override fun onServiceUpdated(si: NsdServiceInfo) =
        recordEvent(ServiceInfoCallbackEvent.ServiceUpdated(si))

    override fun onServiceLost() =
        recordEvent(ServiceInfoCallbackEvent.ServiceUpdatedLost)

    override fun onServiceInfoCallbackUnregistered() =
        recordEvent(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded)
}
239
/**
 * Returns a copy of the bytes of [packet] that follow the Ethernet, IPv6 and UDP headers.
 *
 * The offset is the sum of fixed header lengths, so IPv6 extension headers are not skipped.
 */
private fun getMdnsPayload(packet: ByteArray): ByteArray {
    val payloadStart = ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN
    return packet.copyOfRange(payloadStart, packet.size)
}
242
/**
 * Reads the IPv6 destination address from an Ethernet-framed IPv6 [packet].
 *
 * NOTE(review): `getByAddress` maps IPv4-mapped IPv6 addresses to `Inet4Address`, which would
 * make the cast below throw. This is presumably never hit for mDNS traffic — confirm if reused.
 */
private fun getDstAddr(packet: ByteArray): Inet6Address {
    val addrStart = ETHER_HEADER_LEN + IPV6_DST_ADDR_OFFSET
    val addrEnd = addrStart + IPV6_ADDR_LEN
    val addrBytes = packet.copyOfRange(addrStart, addrEnd)
    val parsed = Inet6Address.getByAddress(addrBytes)
    return parsed as Inet6Address
}
248
/**
 * Polls this reader for an IPv6/UDP packet sent from and to [MDNS_PORT] whose payload, parsed as
 * a [TestDnsPacket], satisfies [predicate].
 *
 * Packets whose payload fails to parse ([DnsPacket.ParseException]) are treated as non-matching.
 *
 * @param timeoutMs how long to wait for a matching packet.
 * @return the matching packet parsed as a [TestDnsPacket], or null if none matched in time.
 */
fun TapPacketReader.pollForMdnsPacket(
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
    predicate: (TestDnsPacket) -> Boolean
): TestDnsPacket? {
    // Builds a TestDnsPacket from a raw frame; the destination address is read before the
    // payload is copied.
    fun parse(raw: ByteArray): TestDnsPacket {
        val dst = getDstAddr(raw)
        return TestDnsPacket(getMdnsPayload(raw), dst)
    }

    val mdnsFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { raw ->
        try {
            predicate(parse(raw))
        } catch (e: DnsPacket.ParseException) {
            // Not a well-formed DNS message: not a match.
            false
        }
    }
    val matched = poll(timeoutMs, mdnsFilter) ?: return null
    return parse(matched)
}
266
/**
 * Polls for an mDNS probe for the instance `serviceName.serviceType.local`. A probe is a packet
 * whose question section holds that name with type ANY (see [TestDnsPacket.isProbeFor]).
 *
 * @return the probe packet, or null if none was seen within [timeoutMs].
 */
fun TapPacketReader.pollForProbe(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { packet -> packet.isProbeFor(instanceName) }
}
274
/**
 * Polls for an mDNS advertisement of `serviceName.serviceType.local`. An advertisement is a
 * packet whose answer section has a record of that name with [TestDnsPacket.isReplyFor]'s
 * default type (SRV).
 *
 * @return the advertisement packet, or null if none was seen within [timeoutMs].
 */
fun TapPacketReader.pollForAdvertisement(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { packet -> packet.isReplyFor(instanceName) }
}
282
/**
 * Polls for an mDNS query about [recordName] covering [requiredTypes], as decided by
 * [TestDnsPacket.isQueryFor].
 *
 * Because [requiredTypes] is a vararg, [timeoutMs] must be passed by name.
 *
 * @return the query packet, or null if none was seen within [timeoutMs].
 */
fun TapPacketReader.pollForQuery(
    recordName: String,
    vararg requiredTypes: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? = pollForMdnsPacket(timeoutMs = timeoutMs) { packet ->
    packet.isQueryFor(recordName, *requiredTypes)
}
288
/**
 * Polls for an mDNS reply whose answer section contains a record named [recordName] with DNS
 * type [type].
 *
 * @return the reply packet, or null if none was seen within [timeoutMs].
 */
fun TapPacketReader.pollForReply(
    recordName: String,
    type: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? = pollForMdnsPacket(timeoutMs = timeoutMs) { packet ->
    packet.isReplyFor(recordName, type)
}
294
/**
 * Polls for an mDNS reply for the instance `serviceName.serviceType.local`. A reply is a packet
 * whose answer section has a record of that name with [TestDnsPacket.isReplyFor]'s default
 * type (SRV).
 *
 * @return the reply packet, or null if none was seen within [timeoutMs].
 */
fun TapPacketReader.pollForReply(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs = timeoutMs) { packet -> packet.isReplyFor(instanceName) }
}
302
/**
 * A [DnsPacket] parsed from an mDNS payload, together with the destination address of the IP
 * packet that carried it.
 *
 * @param data raw DNS message bytes. NOTE(review): parse failures presumably surface as
 *   [DnsPacket.ParseException] from the [DnsPacket] constructor — callers catch it — confirm.
 * @property dstAddr destination address of the carrying IP packet.
 */
class TestDnsPacket(data: ByteArray, val dstAddr: InetAddress) : DnsPacket(data) {
    /** The parsed DNS header. */
    val header: DnsHeader
        get() = mHeader

    /** Parsed records, indexed by section (e.g. [QDSECTION], [ANSECTION]). */
    val records: Array<List<DnsRecord>>
        get() = mRecords

    /**
     * True if the question section contains [name] with type ANY. mDNS probes query with type
     * ANY.
     */
    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
        it.dName == name && it.nsType == DnsResolver.TYPE_ANY
    }

    /** True if the answer section contains a record named [name] of DNS type [type]. */
    fun isReplyFor(name: String, type: Int = DnsResolver.TYPE_SRV): Boolean =
        mRecords[ANSECTION].any {
            it.dName == name && it.nsType == type
        }

    /**
     * True if the question section asks about [name] and, for every type in [requiredTypes],
     * holds a question for [name] of that type.
     *
     * With no [requiredTypes], any question for [name] matches.
     */
    fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean {
        val questionsForName = mRecords[QDSECTION].filter { it.dName == name }
        // Require at least one question for the name. Otherwise an empty requiredTypes would make
        // `all` vacuously true, and any parseable packet (even a reply) would match.
        return questionsForName.isNotEmpty() &&
            requiredTypes.all { type -> questionsForName.any { it.nsType == type } }
    }
}
323