• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.net.cts
17 
18 import android.Manifest.permission.MANAGE_TEST_NETWORKS
19 import android.net.ConnectivityManager
20 import android.net.ConnectivityManager.NetworkCallback
21 import android.net.LinkProperties
22 import android.net.Network
23 import android.net.NetworkAgentConfig
24 import android.net.NetworkCapabilities
25 import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
26 import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
27 import android.net.NetworkCapabilities.TRANSPORT_TEST
28 import android.net.NetworkRequest
29 import android.net.TestNetworkInterface
30 import android.net.TestNetworkManager
31 import android.net.TestNetworkSpecifier
32 import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStarted
33 import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.DiscoveryStopped
34 import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceFound
35 import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.ServiceLost
36 import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StartDiscoveryFailed
37 import android.net.cts.NsdManagerTest.NsdDiscoveryRecord.DiscoveryEvent.StopDiscoveryFailed
38 import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.RegistrationFailed
39 import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceRegistered
40 import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
41 import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.UnregistrationFailed
42 import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveFailed
43 import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ServiceResolved
44 import android.net.nsd.NsdManager
45 import android.net.nsd.NsdManager.DiscoveryListener
46 import android.net.nsd.NsdManager.RegistrationListener
47 import android.net.nsd.NsdManager.ResolveListener
48 import android.net.nsd.NsdServiceInfo
49 import android.os.Handler
50 import android.os.HandlerThread
51 import android.os.Process.myTid
52 import android.platform.test.annotations.AppModeFull
53 import android.util.Log
54 import androidx.test.platform.app.InstrumentationRegistry
55 import androidx.test.runner.AndroidJUnit4
56 import com.android.net.module.util.ArrayTrackRecord
57 import com.android.net.module.util.TrackRecord
58 import com.android.networkstack.apishim.NsdShimImpl
59 import com.android.testutils.TestableNetworkAgent
60 import com.android.testutils.TestableNetworkCallback
61 import com.android.testutils.runAsShell
62 import com.android.testutils.tryTest
63 import org.junit.After
64 import org.junit.Assert.assertArrayEquals
65 import org.junit.Assert.assertTrue
66 import org.junit.Assume.assumeTrue
67 import org.junit.Before
68 import org.junit.Test
69 import org.junit.runner.RunWith
70 import java.net.ServerSocket
71 import java.nio.charset.StandardCharsets
72 import java.util.Random
73 import java.util.concurrent.Executor
74 import kotlin.test.assertEquals
75 import kotlin.test.assertFailsWith
76 import kotlin.test.assertNotNull
77 import kotlin.test.assertNull
78 import kotlin.test.assertTrue
79 import kotlin.test.fail
80 
/** Log tag used by this test. */
private const val TAG: String = "NsdManagerTest"

/** DNS-SD service type that every test registers and discovers. */
private const val SERVICE_TYPE: String = "_nmt._tcp"

/** How long to wait for an expected callback, in milliseconds. */
private const val TIMEOUT_MS: Long = 2_000L

/** How long to wait when asserting that no callback arrives, in milliseconds. */
private const val NO_CALLBACK_TIMEOUT_MS: Long = 200L

/** When true, the tests emit extra debug logging. */
private const val DBG: Boolean = false

/** API shim through which the tests call NsdManager / NsdServiceInfo methods needing T+. */
private val nsdShim = NsdShimImpl.newInstance()
88 
@AppModeFull(reason = "Socket cannot bind in instant app mode")
@RunWith(AndroidJUnit4::class)
class NsdManagerTest {
    private val context by lazy { InstrumentationRegistry.getInstrumentation().context }
    private val nsdManager by lazy { context.getSystemService(NsdManager::class.java) }

    private val cm by lazy { context.getSystemService(ConnectivityManager::class.java) }
    // Randomized so that each test instance registers a distinct service name.
    private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
    private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName)

    // Only created by setUp() when T APIs are tested, and possibly not at all if setUp() fails
    // part-way; tearDown() therefore checks isInitialized before closing them.
    private lateinit var testNetwork1: TestTapNetwork
    private lateinit var testNetwork2: TestTapNetwork

    /**
     * A test network backed by a TAP interface, together with the objects keeping it up.
     *
     * @property iface the TAP interface the network runs on.
     * @property requestCb the callback of the network request filed for this network.
     * @property agent the agent that registered [network].
     * @property network the network registered by [agent].
     */
    private class TestTapNetwork(
        val iface: TestNetworkInterface,
        val requestCb: NetworkCallback,
        val agent: TestableNetworkAgent,
        val network: Network
    ) {
        /** Releases the request, unregisters the agent, then closes the interface fd. */
        fun close(cm: ConnectivityManager) {
            cm.unregisterNetworkCallback(requestCb)
            agent.unregister()
            iface.fileDescriptor.close()
        }
    }

    private interface NsdEvent

    /**
     * Records the [NsdEvent]s delivered to a listener so tests can wait for them.
     *
     * @param expectedThreadId when non-null, every recorded event must be delivered on the
     *   thread with this id; otherwise the assertion in [add] fails.
     */
    private open class NsdRecord<T : NsdEvent> private constructor(
        private val history: ArrayTrackRecord<T>,
        private val expectedThreadId: Int? = null
    ) : TrackRecord<T> by history {
        constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)

        /** Read head used by the expect* helpers; tests may rewind it. */
        val nextEvents = history.newReadHead()

        override fun add(e: T): Boolean {
            if (expectedThreadId != null) {
                assertEquals(expectedThreadId, myTid(), "Callback is running on the wrong thread")
            }
            return history.add(e)
        }

        /**
         * Waits up to [TIMEOUT_MS] for an event of type [V] matching [predicate], skipping any
         * other events on the way, and fails if none is seen.
         */
        inline fun <reified V : NsdEvent> expectCallbackEventually(
            crossinline predicate: (V) -> Boolean = { true }
        ): V = nextEvents.poll(TIMEOUT_MS) { e -> e is V && predicate(e) } as V?
                ?: fail("Callback for ${V::class.java.simpleName} not seen after $TIMEOUT_MS ms")

        /** Waits up to [TIMEOUT_MS] for the very next event and asserts that it is a [V]. */
        inline fun <reified V : NsdEvent> expectCallback(): V {
            val nextEvent = nextEvents.poll(TIMEOUT_MS)
            assertNotNull(nextEvent, "No callback received after $TIMEOUT_MS ms")
            assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                    nextEvent.javaClass.simpleName)
            return nextEvent
        }

        /** Asserts that no event arrives within [timeoutMs]. */
        fun assertNoCallback(timeoutMs: Long = NO_CALLBACK_TIMEOUT_MS) {
            val cb = nextEvents.poll(timeoutMs)
            assertNull(cb, "Expected no callback but got $cb")
        }
    }

    /** [RegistrationListener] that records every callback as a [RegistrationEvent]. */
    private class NsdRegistrationRecord(expectedThreadId: Int? = null) : RegistrationListener,
            NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {
        sealed class RegistrationEvent : NsdEvent {
            abstract val serviceInfo: NsdServiceInfo

            data class RegistrationFailed(
                override val serviceInfo: NsdServiceInfo,
                val errorCode: Int
            ) : RegistrationEvent()

            data class UnregistrationFailed(
                override val serviceInfo: NsdServiceInfo,
                val errorCode: Int
            ) : RegistrationEvent()

            data class ServiceRegistered(override val serviceInfo: NsdServiceInfo)
                : RegistrationEvent()
            data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo)
                : RegistrationEvent()
        }

        override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) {
            add(RegistrationFailed(si, err))
        }

        override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) {
            add(UnregistrationFailed(si, err))
        }

        override fun onServiceRegistered(si: NsdServiceInfo) {
            add(ServiceRegistered(si))
        }

        override fun onServiceUnregistered(si: NsdServiceInfo) {
            add(ServiceUnregistered(si))
        }
    }

    /** [DiscoveryListener] that records every callback as a [DiscoveryEvent]. */
    private class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
            DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {
        sealed class DiscoveryEvent : NsdEvent {
            data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int)
                : DiscoveryEvent()

            data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int)
                : DiscoveryEvent()

            data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
            data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
            data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
            data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
        }

        override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
            add(StartDiscoveryFailed(serviceType, err))
        }

        override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
            add(StopDiscoveryFailed(serviceType, err))
        }

        override fun onDiscoveryStarted(serviceType: String) {
            add(DiscoveryStarted(serviceType))
        }

        override fun onDiscoveryStopped(serviceType: String) {
            add(DiscoveryStopped(serviceType))
        }

        override fun onServiceFound(si: NsdServiceInfo) {
            add(ServiceFound(si))
        }

        override fun onServiceLost(si: NsdServiceInfo) {
            add(ServiceLost(si))
        }

        /**
         * Waits for a [ServiceFound] event for [serviceName], optionally restricted to
         * [expectedNetwork], and returns its service info.
         */
        fun waitForServiceDiscovered(
            serviceName: String,
            expectedNetwork: Network? = null
        ): NsdServiceInfo {
            return expectCallbackEventually<ServiceFound> {
                it.serviceInfo.serviceName == serviceName &&
                        (expectedNetwork == null ||
                                expectedNetwork == nsdShim.getNetwork(it.serviceInfo))
            }.serviceInfo
        }
    }

    /** [ResolveListener] that records every callback as a [ResolveEvent]. */
    private class NsdResolveRecord : ResolveListener,
            NsdRecord<NsdResolveRecord.ResolveEvent>() {
        sealed class ResolveEvent : NsdEvent {
            data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int)
                : ResolveEvent()

            data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        }

        override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
            add(ResolveFailed(si, err))
        }

        override fun onServiceResolved(si: NsdServiceInfo) {
            add(ServiceResolved(si))
        }
    }

    /** Starts the handler thread and, when T APIs are tested, creates two TAP test networks. */
    @Before
    fun setUp() {
        handlerThread.start()

        if (TestUtils.shouldTestTApis()) {
            runAsShell(MANAGE_TEST_NETWORKS) {
                testNetwork1 = createTestNetwork()
                testNetwork2 = createTestNetwork()
            }
        }
    }

    /**
     * Creates a TAP interface, requests a test network on it, registers an agent for it and
     * waits until the request sees the network available and validated.
     */
    private fun createTestNetwork(): TestTapNetwork {
        val tnm = context.getSystemService(TestNetworkManager::class.java)
        val iface = tnm.createTapInterface()
        val cb = TestableNetworkCallback()
        val testNetworkSpecifier = TestNetworkSpecifier(iface.interfaceName)
        cm.requestNetwork(NetworkRequest.Builder()
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(testNetworkSpecifier)
                .build(), cb)
        val agent = registerTestNetworkAgent(iface.interfaceName)
        val network = agent.network ?: fail("Registered agent should have a network")
        // The network has no INTERNET capability, so will be marked validated immediately
        cb.expectAvailableThenValidatedCallbacks(network)
        return TestTapNetwork(iface, cb, agent, network)
    }

    /** Registers and connects an untrusted TRANSPORT_TEST agent on interface [ifaceName]. */
    private fun registerTestNetworkAgent(ifaceName: String): TestableNetworkAgent {
        val agent = TestableNetworkAgent(context, handlerThread.looper,
                NetworkCapabilities().apply {
                    removeCapability(NET_CAPABILITY_TRUSTED)
                    addTransportType(TRANSPORT_TEST)
                    setNetworkSpecifier(TestNetworkSpecifier(ifaceName))
                },
                LinkProperties().apply {
                    interfaceName = ifaceName
                },
                NetworkAgentConfig.Builder().build())
        agent.register()
        agent.markConnected()
        return agent
    }

    /**
     * Closes whichever test networks setUp() managed to create, then stops the handler thread.
     */
    @After
    fun tearDown() {
        tryTest {
            if (TestUtils.shouldTestTApis()) {
                runAsShell(MANAGE_TEST_NETWORKS) {
                    // setUp() may have failed before creating both networks: only close the
                    // ones that exist so the original failure is not masked, and still close
                    // testNetwork2 if closing testNetwork1 throws.
                    tryTest {
                        if (::testNetwork1.isInitialized) testNetwork1.close(cm)
                    } cleanup {
                        if (::testNetwork2.isInitialized) testNetwork2.close(cm)
                    }
                }
            }
        } cleanup {
            // Always stop the handler thread, even if closing a network failed.
            handlerThread.quitSafely()
        }
    }

    /**
     * Exercises the APIs that take no Executor and no Network. It validates TXT attributes,
     * then runs a register / discover / resolve / unregister round trip twice.
     */
    @Test
    fun testNsdManager() {
        val si = NsdServiceInfo()
        si.serviceType = SERVICE_TYPE
        si.serviceName = serviceName
        // Test binary data with various bytes
        val testByteArray = byteArrayOf(-128, 127, 2, 1, 0, 1, 2)
        // Test string data with 256 characters (25 blocks of 10 characters + 6)
        val string256 = "1_________2_________3_________4_________5_________6_________" +
                "7_________8_________9_________10________11________12________13________" +
                "14________15________16________17________18________19________20________" +
                "21________22________23________24________25________123456"

        // Illegal attributes
        listOf(
                Triple(null, null, "null key"),
                Triple("", null, "empty key"),
                Triple(string256, null, "key with 256 characters"),
                Triple("key", string256.substring(3),
                        "key+value combination with more than 255 characters"),
                Triple("key", string256.substring(4), "key+value combination with 255 characters"),
                Triple("\u0019", null, "key with invalid character"),
                Triple("=", null, "key with invalid character"),
                Triple("\u007f", null, "key with invalid character")
        ).forEach {
            assertFailsWith<IllegalArgumentException>(
                    "Setting invalid ${it.third} unexpectedly succeeded") {
                si.setAttribute(it.first, it.second)
            }
        }

        // Allowed attributes
        si.setAttribute("booleanAttr", null as String?)
        si.setAttribute("keyValueAttr", "value")
        si.setAttribute("keyEqualsAttr", "=")
        si.setAttribute(" whiteSpaceKeyValueAttr ", " value ")
        si.setAttribute("binaryDataAttr", testByteArray)
        si.setAttribute("nullBinaryDataAttr", null as ByteArray?)
        si.setAttribute("emptyBinaryDataAttr", byteArrayOf())
        si.setAttribute("longkey", string256.substring(9))

        // The socket only provides a free local port to advertise. It is closed in cleanup so
        // that it does not leak when an assertion fails.
        val socket = ServerSocket(0)
        tryTest {
            val localPort = socket.localPort
            si.port = localPort
            if (DBG) Log.d(TAG, "Port = $localPort")

            val registrationRecord = NsdRegistrationRecord()
            // Test registering without an Executor
            nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, registrationRecord)
            val registeredInfo =
                    registrationRecord.expectCallback<ServiceRegistered>().serviceInfo

            val discoveryRecord = NsdDiscoveryRecord()
            // Test discovering without an Executor
            nsdManager.discoverServices(SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)

            // Expect discovery started
            discoveryRecord.expectCallback<DiscoveryStarted>()

            // Expect a service record to be discovered
            val foundInfo = discoveryRecord.waitForServiceDiscovered(registeredInfo.serviceName)

            // Test resolving without an Executor
            val resolveRecord = NsdResolveRecord()
            nsdManager.resolveService(foundInfo, resolveRecord)
            val resolvedService = resolveRecord.expectCallback<ServiceResolved>().serviceInfo

            // Check Txt attributes
            assertEquals(8, resolvedService.attributes.size)
            assertTrue(resolvedService.attributes.containsKey("booleanAttr"))
            assertNull(resolvedService.attributes["booleanAttr"])
            assertEquals("value", resolvedService.attributes["keyValueAttr"].utf8ToString())
            assertEquals("=", resolvedService.attributes["keyEqualsAttr"].utf8ToString())
            assertEquals(" value ",
                    resolvedService.attributes[" whiteSpaceKeyValueAttr "].utf8ToString())
            assertEquals(string256.substring(9),
                    resolvedService.attributes["longkey"].utf8ToString())
            assertArrayEquals(testByteArray, resolvedService.attributes["binaryDataAttr"])
            assertTrue(resolvedService.attributes.containsKey("nullBinaryDataAttr"))
            assertNull(resolvedService.attributes["nullBinaryDataAttr"])
            assertTrue(resolvedService.attributes.containsKey("emptyBinaryDataAttr"))
            assertNull(resolvedService.attributes["emptyBinaryDataAttr"])
            assertEquals(localPort, resolvedService.port)

            // Unregister the service
            nsdManager.unregisterService(registrationRecord)
            registrationRecord.expectCallback<ServiceUnregistered>()

            // Expect a callback for service lost
            discoveryRecord.expectCallbackEventually<ServiceLost> {
                it.serviceInfo.serviceName == serviceName
            }

            // Register service again to see if NsdManager can discover it
            val si2 = NsdServiceInfo()
            si2.serviceType = SERVICE_TYPE
            si2.serviceName = serviceName
            si2.port = localPort
            val registrationRecord2 = NsdRegistrationRecord()
            nsdManager.registerService(si2, NsdManager.PROTOCOL_DNS_SD, registrationRecord2)
            val registeredInfo2 =
                    registrationRecord2.expectCallback<ServiceRegistered>().serviceInfo

            // Expect a service record to be discovered (and filter the ones
            // that are unrelated to this test)
            val foundInfo2 = discoveryRecord.waitForServiceDiscovered(registeredInfo2.serviceName)

            // Resolve the service
            val resolveRecord2 = NsdResolveRecord()
            nsdManager.resolveService(foundInfo2, resolveRecord2)
            val resolvedService2 = resolveRecord2.expectCallback<ServiceResolved>().serviceInfo

            // Check that the resolved service doesn't have any TXT records
            assertEquals(0, resolvedService2.attributes.size)

            nsdManager.stopServiceDiscovery(discoveryRecord)

            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()

            nsdManager.unregisterService(registrationRecord2)
            registrationRecord2.expectCallback<ServiceUnregistered>()
        } cleanup {
            socket.close()
        }
    }

    /** Discovery restricted to testNetwork1 only reports the service on testNetwork1. */
    @Test
    fun testNsdManager_DiscoverOnNetwork() {
        // This test requires shims supporting T+ APIs (discovering on specific network)
        assumeTrue(TestUtils.shouldTestTApis())

        val si = NsdServiceInfo()
        si.serviceType = SERVICE_TYPE
        si.serviceName = this.serviceName
        si.port = 12345 // Test won't try to connect so port does not matter

        val registrationRecord = NsdRegistrationRecord()
        val registeredInfo = registerService(registrationRecord, si)
        // Declared outside tryTest so that discovery can be stopped during cleanup.
        val discoveryRecord = NsdDiscoveryRecord()

        tryTest {
            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
                    testNetwork1.network, Executor { it.run() }, discoveryRecord)

            val foundInfo = discoveryRecord.waitForServiceDiscovered(
                    serviceName, testNetwork1.network)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))

            // Rewind to ensure the service is not found on the other interface
            discoveryRecord.nextEvents.rewind(0)
            assertNull(discoveryRecord.nextEvents.poll(timeoutMs = 100L) {
                it is ServiceFound &&
                        it.serviceInfo.serviceName == registeredInfo.serviceName &&
                        nsdShim.getNetwork(it.serviceInfo) != testNetwork1.network
            }, "The service should not be found on this network")
        } cleanupStep {
            nsdManager.stopServiceDiscovery(discoveryRecord)
            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
        } cleanup {
            nsdManager.unregisterService(registrationRecord)
        }
    }

    /**
     * Discovery driven by a NetworkRequest follows the service across an unregister/re-register
     * and across a teardown/re-creation of the matching network's agent.
     */
    @Test
    fun testNsdManager_DiscoverWithNetworkRequest() {
        // This test requires shims supporting T+ APIs (discovering on network request)
        assumeTrue(TestUtils.shouldTestTApis())

        val si = NsdServiceInfo()
        si.serviceType = SERVICE_TYPE
        si.serviceName = this.serviceName
        si.port = 12345 // Test won't try to connect so port does not matter

        val handler = Handler(handlerThread.looper)
        val executor = Executor { handler.post(it) }

        val registrationRecord = NsdRegistrationRecord(expectedThreadId = handlerThread.threadId)
        val registeredInfo1 = registerService(registrationRecord, si, executor)
        val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId)
        // Agent registered mid-test to replace testNetwork1's agent. testNetwork1 does not own
        // it, so tearDown() would not unregister it: it is unregistered in cleanup below.
        var newAgent: TestableNetworkAgent? = null

        tryTest {
            val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)
            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
                    NetworkRequest.Builder()
                            .removeCapability(NET_CAPABILITY_TRUSTED)
                            .addTransportType(TRANSPORT_TEST)
                            .setNetworkSpecifier(specifier)
                            .build(),
                    executor, discoveryRecord)

            val discoveryStarted = discoveryRecord.expectCallback<DiscoveryStarted>()
            assertEquals(SERVICE_TYPE, discoveryStarted.serviceType)

            val serviceDiscovered = discoveryRecord.expectCallback<ServiceFound>()
            assertEquals(registeredInfo1.serviceName, serviceDiscovered.serviceInfo.serviceName)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered.serviceInfo))

            // Unregister, then register the service back: it should be lost and found again
            nsdManager.unregisterService(registrationRecord)
            val serviceLost1 = discoveryRecord.expectCallback<ServiceLost>()
            assertEquals(registeredInfo1.serviceName, serviceLost1.serviceInfo.serviceName)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost1.serviceInfo))

            registrationRecord.expectCallback<ServiceUnregistered>()
            val registeredInfo2 = registerService(registrationRecord, si, executor)
            val serviceDiscovered2 = discoveryRecord.expectCallback<ServiceFound>()
            assertEquals(registeredInfo2.serviceName, serviceDiscovered2.serviceInfo.serviceName)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceDiscovered2.serviceInfo))

            // Teardown, then bring back up a network on the test interface: the service should
            // go away, then come back
            testNetwork1.agent.unregister()
            val serviceLost = discoveryRecord.expectCallback<ServiceLost>()
            assertEquals(registeredInfo2.serviceName, serviceLost.serviceInfo.serviceName)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(serviceLost.serviceInfo))

            val agent = runAsShell(MANAGE_TEST_NETWORKS) {
                registerTestNetworkAgent(testNetwork1.iface.interfaceName)
            }
            newAgent = agent
            val newNetwork = agent.network ?: fail("Registered agent should have a network")
            val serviceDiscovered3 = discoveryRecord.expectCallback<ServiceFound>()
            assertEquals(registeredInfo2.serviceName, serviceDiscovered3.serviceInfo.serviceName)
            assertEquals(newNetwork, nsdShim.getNetwork(serviceDiscovered3.serviceInfo))
        } cleanupStep {
            nsdManager.stopServiceDiscovery(discoveryRecord)
            discoveryRecord.expectCallback<DiscoveryStopped>()
        } cleanupStep {
            // Stop discovery first, so that tearing this network down cannot queue a
            // ServiceLost event ahead of DiscoveryStopped.
            newAgent?.let { agent -> runAsShell(MANAGE_TEST_NETWORKS) { agent.unregister() } }
        } cleanup {
            nsdManager.unregisterService(registrationRecord)
        }
    }

    /** Discovery on a NetworkRequest that no network satisfies still reports DiscoveryStarted. */
    @Test
    fun testNsdManager_DiscoverWithNetworkRequest_NoMatchingNetwork() {
        // This test requires shims supporting T+ APIs (discovering on network request)
        assumeTrue(TestUtils.shouldTestTApis())

        val handler = Handler(handlerThread.looper)
        val executor = Executor { handler.post(it) }

        val discoveryRecord = NsdDiscoveryRecord(expectedThreadId = handlerThread.threadId)
        val specifier = TestNetworkSpecifier(testNetwork1.iface.interfaceName)

        tryTest {
            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
                    NetworkRequest.Builder()
                            .removeCapability(NET_CAPABILITY_TRUSTED)
                            .addTransportType(TRANSPORT_TEST)
                            // Specified network does not have this capability
                            .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
                            .setNetworkSpecifier(specifier)
                            .build(),
                    executor, discoveryRecord)
            discoveryRecord.expectCallback<DiscoveryStarted>()
        } cleanup {
            nsdManager.stopServiceDiscovery(discoveryRecord)
            discoveryRecord.expectCallback<DiscoveryStopped>()
        }
    }

    /** A service seen on both test networks resolves on the network it was discovered on. */
    @Test
    fun testNsdManager_ResolveOnNetwork() {
        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
        assumeTrue(TestUtils.shouldTestTApis())

        val si = NsdServiceInfo()
        si.serviceType = SERVICE_TYPE
        si.serviceName = this.serviceName
        si.port = 12345 // Test won't try to connect so port does not matter

        val registrationRecord = NsdRegistrationRecord()
        val registeredInfo = registerService(registrationRecord, si)
        // Declared outside tryTest so that discovery can be stopped during cleanup.
        val discoveryRecord = NsdDiscoveryRecord()
        tryTest {
            nsdManager.discoverServices(SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, discoveryRecord)

            val foundInfo1 = discoveryRecord.waitForServiceDiscovered(
                    serviceName, testNetwork1.network)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo1))
            // Rewind as the service could be found on each interface in any order
            discoveryRecord.nextEvents.rewind(0)
            val foundInfo2 = discoveryRecord.waitForServiceDiscovered(
                    serviceName, testNetwork2.network)
            assertEquals(testNetwork2.network, nsdShim.getNetwork(foundInfo2))

            val resolvedInfo = resolveService(foundInfo1)
            // Resolved service type has leading dot
            assertEquals(".$SERVICE_TYPE", resolvedInfo.serviceType)
            assertEquals(registeredInfo.serviceName, resolvedInfo.serviceName)
            assertEquals(si.port, resolvedInfo.port)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(resolvedInfo))
            // TODO: check that MDNS packets are sent only on testNetwork1.
        } cleanupStep {
            nsdManager.stopServiceDiscovery(discoveryRecord)
            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
        } cleanupStep {
            nsdManager.unregisterService(registrationRecord)
        } cleanup {
            registrationRecord.expectCallback<ServiceUnregistered>()
        }
    }

    /**
     * A service registered on testNetwork1 is discovered on testNetwork1 and on all networks,
     * but not on testNetwork2.
     */
    @Test
    fun testNsdManager_RegisterOnNetwork() {
        // This test requires shims supporting T+ APIs (NsdServiceInfo.network)
        assumeTrue(TestUtils.shouldTestTApis())

        val si = NsdServiceInfo()
        si.serviceType = SERVICE_TYPE
        si.serviceName = this.serviceName
        si.network = testNetwork1.network
        si.port = 12345 // Test won't try to connect so port does not matter

        // Register service on testNetwork1
        val registrationRecord = NsdRegistrationRecord()
        registerService(registrationRecord, si)
        val discoveryRecord = NsdDiscoveryRecord()
        val discoveryRecord2 = NsdDiscoveryRecord()
        val discoveryRecord3 = NsdDiscoveryRecord()

        tryTest {
            // Discover service on testNetwork1.
            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
                testNetwork1.network, Executor { it.run() }, discoveryRecord)
            // Expect that service is found on testNetwork1
            val foundInfo = discoveryRecord.waitForServiceDiscovered(
                serviceName, testNetwork1.network)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo))

            // Discover service on testNetwork2.
            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
                testNetwork2.network, Executor { it.run() }, discoveryRecord2)
            // Expect that discovery is started then no other callbacks.
            discoveryRecord2.expectCallback<DiscoveryStarted>()
            discoveryRecord2.assertNoCallback()

            // Discover service on all networks (not specify any network).
            nsdShim.discoverServices(nsdManager, SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD,
                null as Network? /* network */, Executor { it.run() }, discoveryRecord3)
            // Expect that service is found on testNetwork1
            val foundInfo3 = discoveryRecord3.waitForServiceDiscovered(
                    serviceName, testNetwork1.network)
            assertEquals(testNetwork1.network, nsdShim.getNetwork(foundInfo3))
        } cleanupStep {
            // Every discovery started above must be stopped, or its listener leaks into later
            // tests.
            nsdManager.stopServiceDiscovery(discoveryRecord)
            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
        } cleanupStep {
            nsdManager.stopServiceDiscovery(discoveryRecord2)
            discoveryRecord2.expectCallback<DiscoveryStopped>()
        } cleanupStep {
            nsdManager.stopServiceDiscovery(discoveryRecord3)
            discoveryRecord3.expectCallbackEventually<DiscoveryStopped>()
        } cleanup {
            nsdManager.unregisterService(registrationRecord)
        }
    }

    /** A service name containing non-ASCII and DNS-special characters can be resolved. */
    @Test
    fun testNsdManager_RegisterServiceNameWithNonStandardCharacters() {
        val serviceNames = "^Nsd.Test|Non-#AsCiI\\Characters&\\ufffe テスト 測試"
        val si = NsdServiceInfo().apply {
            serviceType = SERVICE_TYPE
            serviceName = serviceNames
            port = 12345 // Test won't try to connect so port does not matter
        }

        // Register the service name which contains non-standard characters.
        val registrationRecord = NsdRegistrationRecord()
        nsdManager.registerService(si, NsdManager.PROTOCOL_DNS_SD, registrationRecord)
        registrationRecord.expectCallback<ServiceRegistered>()
        // Declared outside tryTest so that discovery can be stopped during cleanup.
        val discoveryRecord = NsdDiscoveryRecord()

        tryTest {
            // Discover that service name.
            nsdManager.discoverServices(
                SERVICE_TYPE, NsdManager.PROTOCOL_DNS_SD, discoveryRecord
            )
            val foundInfo = discoveryRecord.waitForServiceDiscovered(serviceNames)

            // Expect that resolving the service name works properly even service name contains
            // non-standard characters.
            val resolveRecord = NsdResolveRecord()
            nsdManager.resolveService(foundInfo, resolveRecord)
            val resolvedCb = resolveRecord.expectCallback<ServiceResolved>()
            assertEquals(foundInfo.serviceName, resolvedCb.serviceInfo.serviceName)
        } cleanupStep {
            nsdManager.stopServiceDiscovery(discoveryRecord)
            discoveryRecord.expectCallbackEventually<DiscoveryStopped>()
        } cleanupStep {
            nsdManager.unregisterService(registrationRecord)
        } cleanup {
            registrationRecord.expectCallback<ServiceUnregistered>()
        }
    }

    /**
     * Registers [si] through the shim and waits for [ServiceRegistered] on [record].
     *
     * @param executor executor on which registration callbacks are delivered; by default they
     *   run directly on the delivering thread.
     * @return the service info reported by the [ServiceRegistered] callback. Its name may
     *   differ from the requested one.
     */
    private fun registerService(
        record: NsdRegistrationRecord,
        si: NsdServiceInfo,
        executor: Executor = Executor { it.run() }
    ): NsdServiceInfo {
        nsdShim.registerService(nsdManager, si, NsdManager.PROTOCOL_DNS_SD, executor, record)
        // We may not always get the name that we tried to register;
        // This events tells us the name that was registered.
        val cb = record.expectCallback<ServiceRegistered>()
        return cb.serviceInfo
    }

    /**
     * Resolves [discoveredInfo] through the shim, asserts that the resolved name matches the
     * discovered one, and returns the resolved service info.
     */
    private fun resolveService(discoveredInfo: NsdServiceInfo): NsdServiceInfo {
        val record = NsdResolveRecord()
        nsdShim.resolveService(nsdManager, discoveredInfo, Executor { it.run() }, record)
        val resolvedCb = record.expectCallback<ServiceResolved>()
        assertEquals(discoveredInfo.serviceName, resolvedCb.serviceInfo.serviceName)

        return resolvedCb.serviceInfo
    }
}
719 
/**
 * Decodes this byte array as UTF-8 text. A null array yields the empty string, so absent TXT
 * values compare like empty ones in the assertions above.
 */
private fun ByteArray?.utf8ToString(): String =
    this?.let { bytes -> String(bytes, StandardCharsets.UTF_8) } ?: ""
724