• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.net.cts
17 
18 import android.net.DnsResolver
19 import android.net.Network
20 import android.net.nsd.NsdManager
21 import android.net.nsd.NsdServiceInfo
22 import android.os.Process
23 import com.android.net.module.util.ArrayTrackRecord
24 import com.android.net.module.util.DnsPacket
25 import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
26 import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
27 import com.android.net.module.util.NetworkStackConstants.IPV6_DST_ADDR_OFFSET
28 import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
29 import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
30 import com.android.net.module.util.TrackRecord
31 import com.android.testutils.IPv6UdpFilter
32 import com.android.testutils.TapPacketReader
33 import java.net.Inet6Address
34 import java.net.InetAddress
35 import kotlin.test.assertEquals
36 import kotlin.test.assertNotNull
37 import kotlin.test.assertNull
38 import kotlin.test.assertTrue
39 import kotlin.test.fail
40 
/** Default timeout for the TapPacketReader mDNS polling helpers in this file. */
private const val MDNS_REGISTRATION_TIMEOUT_MS: Long = 10_000

/** UDP port used as both source and destination port when filtering captured mDNS packets. */
private const val MDNS_PORT: Short = 5353

/** Default timeout used by [NsdRecord.expectCallback] and [NsdRecord.expectCallbackEventually]. */
const val MDNS_CALLBACK_TIMEOUT: Long = 2000

/** Default window used by [NsdRecord.assertNoCallback]. */
const val MDNS_NO_CALLBACK_TIMEOUT_MS: Long = 200

/** Marker type for every event recorded by an [NsdRecord]. */
interface NsdEvent
/**
 * Base class for NSD listener records.
 *
 * Events passed to [add] are stored in an [ArrayTrackRecord]. Tests then consume them in order
 * through [nextEvents] using the expect/assert helpers below.
 *
 * @param T the event type recorded by the concrete listener.
 */
open class NsdRecord<T : NsdEvent> private constructor(
    private val history: ArrayTrackRecord<T>,
    private val expectedThreadId: Int? = null
) : TrackRecord<T> by history {
    /**
     * @param expectedThreadId if non-null, every [add] asserts that it is called on the thread
     *   whose id ([Process.myTid]) equals this value.
     */
    constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)

    /** Read head over the recorded history; the expect/assert helpers poll events from it. */
    val nextEvents = history.newReadHead()

    /** Records [e], first checking the calling thread if an expected thread id was supplied. */
    override fun add(e: T): Boolean {
        if (expectedThreadId != null) {
            assertEquals(
                expectedThreadId, Process.myTid(),
                "Callback is running on the wrong thread"
            )
        }
        return history.add(e)
    }

    /**
     * Polls [nextEvents] for up to [timeoutMs] for an event of type [V] satisfying [predicate].
     *
     * @return the matching event.
     * @throws AssertionError (via [fail]) if no such event is seen in time.
     */
    inline fun <reified V : NsdEvent> expectCallbackEventually(
        timeoutMs: Long = MDNS_CALLBACK_TIMEOUT,
        crossinline predicate: (V) -> Boolean = { true }
    ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
        ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")

    /**
     * Polls the next event from [nextEvents] and asserts that it exists and is of type [V].
     *
     * @return the event, cast to [V].
     */
    inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V {
        val nextEvent = nextEvents.poll(timeoutMs)
        assertNotNull(
            nextEvent, "No callback received after $timeoutMs ms, expected " +
                    "${V::class.java.simpleName}"
        )
        // The kotlin.test contracts on assertNotNull/assertTrue smart-cast nextEvent to V below.
        assertTrue(
            nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                    nextEvent.javaClass.simpleName
        )
        return nextEvent
    }

    /**
     * Asserts that [nextEvents] yields no event within [timeoutMs].
     *
     * Not `inline`: this function has neither lambda nor reified type parameters, so inlining
     * brings no benefit.
     */
    fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) {
        val cb = nextEvents.poll(timeoutMs)
        assertNull(cb, "Expected no callback but got $cb")
    }
}
89 
/**
 * [NsdManager.DiscoveryListener] that records every callback as a [DiscoveryEvent].
 *
 * @param expectedThreadId forwarded to [NsdRecord]; if non-null, each recorded callback is
 *   asserted to run on that thread.
 */
class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
    NsdManager.DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {

    /** One event type per [NsdManager.DiscoveryListener] callback. */
    sealed class DiscoveryEvent : NsdEvent {
        data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
            DiscoveryEvent()
        data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
            DiscoveryEvent()
        data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
        data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
        data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
        data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
    }

    override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
        val event = DiscoveryEvent.StartDiscoveryFailed(serviceType, err)
        add(event)
    }

    override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
        val event = DiscoveryEvent.StopDiscoveryFailed(serviceType, err)
        add(event)
    }

    override fun onDiscoveryStarted(serviceType: String) {
        val event = DiscoveryEvent.DiscoveryStarted(serviceType)
        add(event)
    }

    override fun onDiscoveryStopped(serviceType: String) {
        val event = DiscoveryEvent.DiscoveryStopped(serviceType)
        add(event)
    }

    override fun onServiceFound(si: NsdServiceInfo) {
        val event = DiscoveryEvent.ServiceFound(si)
        add(event)
    }

    override fun onServiceLost(si: NsdServiceInfo) {
        val event = DiscoveryEvent.ServiceLost(si)
        add(event)
    }

    /**
     * Waits for a [DiscoveryEvent.ServiceFound] whose service name is [serviceName], and whose
     * network is [expectedNetwork] when that is non-null. Then asserts that the discovered
     * service type is [serviceType] followed by a trailing dot.
     *
     * @return the discovered [NsdServiceInfo].
     */
    fun waitForServiceDiscovered(
        serviceName: String,
        serviceType: String,
        expectedNetwork: Network? = null
    ): NsdServiceInfo {
        val foundEvent = expectCallbackEventually<DiscoveryEvent.ServiceFound> { candidate ->
            val candidateInfo = candidate.serviceInfo
            val nameMatches = candidateInfo.serviceName == serviceName
            nameMatches && (expectedNetwork == null || expectedNetwork == candidateInfo.network)
        }
        val discovered = foundEvent.serviceInfo
        // Discovered service types have a dot at the end
        assertEquals("$serviceType.", discovered.serviceType)
        return discovered
    }
}
144 
/**
 * [NsdManager.RegistrationListener] that records every callback as a [RegistrationEvent].
 *
 * @param expectedThreadId forwarded to [NsdRecord]; if non-null, each recorded callback is
 *   asserted to run on that thread.
 */
class NsdRegistrationRecord(expectedThreadId: Int? = null) : NsdManager.RegistrationListener,
    NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {

    /** One event type per [NsdManager.RegistrationListener] callback; all carry the service info. */
    sealed class RegistrationEvent : NsdEvent {
        abstract val serviceInfo: NsdServiceInfo

        data class RegistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class UnregistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) :
            RegistrationEvent()

        data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) :
            RegistrationEvent()
    }

    override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
        val event = RegistrationEvent.RegistrationFailed(si, err)
        add(event)
    }

    override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
        val event = RegistrationEvent.UnregistrationFailed(si, err)
        add(event)
    }

    override fun onServiceRegistered(si: NsdServiceInfo) {
        val event = RegistrationEvent.ServiceRegistered(si)
        add(event)
    }

    override fun onServiceUnregistered(si: NsdServiceInfo) {
        val event = RegistrationEvent.ServiceUnregistered(si)
        add(event)
    }
}
183 
/** [NsdManager.ResolveListener] that records every callback as a [ResolveEvent]. */
class NsdResolveRecord : NsdManager.ResolveListener,
    NsdRecord<NsdResolveRecord.ResolveEvent>() {

    /** One event type per [NsdManager.ResolveListener] callback. */
    sealed class ResolveEvent : NsdEvent {
        data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()
        data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()
    }

    override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
        val event = ResolveEvent.ResolveFailed(si, err)
        add(event)
    }

    override fun onServiceResolved(si: NsdServiceInfo) {
        val event = ResolveEvent.ServiceResolved(si)
        add(event)
    }

    override fun onResolutionStopped(si: NsdServiceInfo) {
        val event = ResolveEvent.ResolutionStopped(si)
        add(event)
    }

    override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
        // Delegate to the interface's default implementation before recording the event.
        super.onStopResolutionFailed(si, err)
        val event = ResolveEvent.StopResolutionFailed(si, err)
        add(event)
    }
}
213 
/** [NsdManager.ServiceInfoCallback] that records every callback as a [ServiceInfoCallbackEvent]. */
class NsdServiceInfoCallbackRecord : NsdManager.ServiceInfoCallback,
    NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {

    /** One event type per [NsdManager.ServiceInfoCallback] callback. */
    sealed class ServiceInfoCallbackEvent : NsdEvent {
        data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
        data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()

        /** Recorded for [NsdManager.ServiceInfoCallback.onServiceLost]. */
        object ServiceUpdatedLost : ServiceInfoCallbackEvent()

        /** Recorded for [NsdManager.ServiceInfoCallback.onServiceInfoCallbackUnregistered]. */
        object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
    }

    override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
        val event = ServiceInfoCallbackEvent.RegisterCallbackFailed(err)
        add(event)
    }

    override fun onServiceUpdated(si: NsdServiceInfo) {
        val event = ServiceInfoCallbackEvent.ServiceUpdated(si)
        add(event)
    }

    override fun onServiceLost() {
        add(ServiceInfoCallbackEvent.ServiceUpdatedLost)
    }

    override fun onServiceInfoCallbackUnregistered() {
        add(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded)
    }
}
239 
/**
 * Returns the bytes of [packet] that follow the Ethernet, IPv6 and UDP headers.
 *
 * The payload offset is a fixed sum of header lengths, so this assumes an IPv6 header without
 * extension headers.
 */
private fun getMdnsPayload(packet: ByteArray): ByteArray {
    val payloadStart = ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN
    return packet.copyOfRange(payloadStart, packet.size)
}
242 
/** Extracts the IPv6 destination address from an Ethernet frame carrying an IPv6 packet. */
private fun getDstAddr(packet: ByteArray): Inet6Address {
    val addrStart = ETHER_HEADER_LEN + IPV6_DST_ADDR_OFFSET
    val addrEnd = addrStart + IPV6_ADDR_LEN
    val addrBytes = packet.copyOfRange(addrStart, addrEnd)
    return Inet6Address.getByAddress(addrBytes) as Inet6Address
}
248 
/**
 * Polls this reader for an IPv6 UDP packet with source and destination port [MDNS_PORT] whose
 * payload parses as DNS and satisfies [predicate].
 *
 * Payloads that fail to parse (`DnsPacket.ParseException`) are treated as non-matching.
 *
 * @return the matching packet parsed as a [TestDnsPacket], or null if none arrives within
 *   [timeoutMs].
 */
fun TapPacketReader.pollForMdnsPacket(
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
    predicate: (TestDnsPacket) -> Boolean
): TestDnsPacket? {
    val mdnsPortFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT)
    val matchingFilter = mdnsPortFilter.and { raw ->
        val dstAddr = getDstAddr(raw)
        val payload = getMdnsPayload(raw)
        try {
            predicate(TestDnsPacket(payload, dstAddr))
        } catch (e: DnsPacket.ParseException) {
            false
        }
    }
    val matched = poll(timeoutMs, matchingFilter) ?: return null
    return TestDnsPacket(getMdnsPayload(matched), getDstAddr(matched))
}
266 
/**
 * Polls for an mDNS probe (see [TestDnsPacket.isProbeFor]) for the instance
 * "[serviceName].[serviceType].local".
 */
fun TapPacketReader.pollForProbe(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isProbeFor(instanceName) }
}
274 
/**
 * Polls for an mDNS packet whose answer section has an SRV record (the default type of
 * [TestDnsPacket.isReplyFor]) for "[serviceName].[serviceType].local".
 */
fun TapPacketReader.pollForAdvertisement(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isReplyFor(instanceName) }
}
282 
/**
 * Polls for an mDNS packet that [TestDnsPacket.isQueryFor] accepts for [recordName] and
 * [requiredTypes].
 *
 * Because [requiredTypes] is a vararg, [timeoutMs] must be passed by name.
 */
fun TapPacketReader.pollForQuery(
    recordName: String,
    vararg requiredTypes: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    return pollForMdnsPacket(timeoutMs) { candidate ->
        candidate.isQueryFor(recordName, *requiredTypes)
    }
}
288 
/** Polls for an mDNS packet whose answer section has a record of [type] for [recordName]. */
fun TapPacketReader.pollForReply(
    recordName: String,
    type: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isReplyFor(recordName, type) }
}
294 
/**
 * Polls for an mDNS packet whose answer section has an SRV record (the default type of
 * [TestDnsPacket.isReplyFor]) for "[serviceName].[serviceType].local".
 */
fun TapPacketReader.pollForReply(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isReplyFor(instanceName) }
}
302 
/**
 * A [DnsPacket] parsed from a captured mDNS payload, paired with the destination address of the
 * IPv6 packet that carried it.
 *
 * @param data the DNS message bytes to parse.
 * @property dstAddr destination address of the enclosing IPv6 packet.
 */
class TestDnsPacket(data: ByteArray, val dstAddr: InetAddress) : DnsPacket(data) {
    /** The parsed DNS header. */
    val header: DnsHeader
        get() = mHeader

    /** Parsed records, indexed by section (e.g. [QDSECTION], [ANSECTION]). */
    val records: Array<List<DnsRecord>>
        get() = mRecords

    /** Returns true if the question section has a [DnsResolver.TYPE_ANY] question for [name]. */
    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
        it.dName == name && it.nsType == DnsResolver.TYPE_ANY
    }

    /** Returns true if the answer section has a record of [type] (default SRV) for [name]. */
    fun isReplyFor(name: String, type: Int = DnsResolver.TYPE_SRV): Boolean =
        mRecords[ANSECTION].any {
            it.dName == name && it.nsType == type
        }

    /**
     * Returns true if the question section has at least one question for [name], and the
     * questions for [name] cover every type in [requiredTypes].
     *
     * With no [requiredTypes], this checks only that [name] is queried.
     */
    fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean {
        val queriedTypes = mRecords[QDSECTION].filter { it.dName == name }.map { it.nsType }
        // The emptiness check matters when requiredTypes is empty: `all` over an empty array is
        // vacuously true, which would otherwise make any packet match.
        return queriedTypes.isNotEmpty() && requiredTypes.all { it in queriedTypes }
    }
}
323