• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net
18 
19 import android.app.Instrumentation
20 import android.content.Context
21 import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED
22 import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
23 import android.net.NetworkCapabilities.TRANSPORT_TEST
24 import android.net.NetworkProviderTest.TestNetworkCallback.CallbackEntry.OnUnavailable
25 import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequestWithdrawn
26 import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequested
27 import android.os.Build
28 import android.os.Handler
29 import android.os.HandlerThread
30 import android.os.Looper
31 import android.util.Log
32 import androidx.test.InstrumentationRegistry
33 import com.android.modules.utils.build.SdkLevel.isAtLeastS
34 import com.android.net.module.util.ArrayTrackRecord
35 import com.android.testutils.CompatUtil
36 import com.android.testutils.ConnectivityModuleTest
37 import com.android.testutils.DevSdkIgnoreRule
38 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
39 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
40 import com.android.testutils.DevSdkIgnoreRunner
41 import com.android.testutils.TestableNetworkOfferCallback
42 import org.junit.After
43 import org.junit.Before
44 import org.junit.Rule
45 import org.junit.Test
46 import org.junit.runner.RunWith
47 import org.mockito.Mockito.doReturn
48 import org.mockito.Mockito.mock
49 import org.mockito.Mockito.verifyNoMoreInteractions
50 import java.util.UUID
51 import java.util.concurrent.Executor
52 import java.util.concurrent.RejectedExecutionException
53 import kotlin.test.assertEquals
54 import kotlin.test.assertNotEquals
55 import kotlin.test.fail
56 
/** Maximum time, in milliseconds, to wait for a callback that the test expects to arrive. */
private const val DEFAULT_TIMEOUT_MS = 5000L

/** Time, in milliseconds, to wait while asserting that no callback arrives. */
private const val DEFAULT_NO_CALLBACK_TIMEOUT_MS = 200L

private val instrumentation: Instrumentation
    get() = InstrumentationRegistry.getInstrumentation()

private val context: Context get() = InstrumentationRegistry.getContext()

/** Name passed to [NetworkProvider] by every provider this test creates. */
private const val PROVIDER_NAME = "NetworkProviderTest"
63 
/**
 * Tests for [NetworkProvider]: provider registration, the R-only `onNetworkRequested` flow,
 * the S+ network offer flow, and [NetworkProvider.declareNetworkRequestUnfulfillable].
 *
 * Each test runs with the shell permission identity adopted (see [setUp]), and providers,
 * agents and offer callbacks run on a per-test [HandlerThread].
 */
@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.Q)
@ConnectivityModuleTest
class NetworkProviderTest {
    @Rule @JvmField
    val mIgnoreRule = DevSdkIgnoreRule()
    private val mCm = context.getSystemService(ConnectivityManager::class.java)
    private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")

    @Before
    fun setUp() {
        // Run the test with the shell's permissions; they are dropped again in tearDown().
        instrumentation.getUiAutomation().adoptShellPermissionIdentity()
        mHandlerThread.start()
    }

    @After
    fun tearDown() {
        mHandlerThread.quitSafely()
        mHandlerThread.join()
        instrumentation.getUiAutomation().dropShellPermissionIdentity()
    }

    /**
     * [NetworkProvider] that records every `onNetworkRequested` / `onNetworkRequestWithdrawn`
     * callback it receives, so the test thread can assert on them.
     */
    private class TestNetworkProvider(context: Context, looper: Looper) :
            NetworkProvider(context, looper, PROVIDER_NAME) {
        private val TAG = this::class.simpleName
        private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()

        /** Callbacks recorded by [TestNetworkProvider]. */
        sealed class CallbackEntry {
            data class OnNetworkRequested(
                val request: NetworkRequest,
                val score: Int,
                val id: Int
            ) : CallbackEntry()
            data class OnNetworkRequestWithdrawn(val request: NetworkRequest) : CallbackEntry()
        }

        override fun onNetworkRequested(request: NetworkRequest, score: Int, id: Int) {
            Log.d(TAG, "onNetworkRequested $request, $score, $id")
            seenEvents.add(OnNetworkRequested(request, score, id))
        }

        override fun onNetworkRequestWithdrawn(request: NetworkRequest) {
            Log.d(TAG, "onNetworkRequestWithdrawn $request")
            seenEvents.add(OnNetworkRequestWithdrawn(request))
        }

        /**
         * Waits up to [DEFAULT_TIMEOUT_MS] for a recorded callback of type [T] that satisfies
         * [predicate] and returns it; fails the test if none is found in time.
         */
        inline fun <reified T : CallbackEntry> eventuallyExpectCallbackThat(
            crossinline predicate: (T) -> Boolean
        ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) }
                ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")

        /** Fails the test if any callback is recorded within [DEFAULT_NO_CALLBACK_TIMEOUT_MS]. */
        fun assertNoCallback() {
            val cb = seenEvents.poll(DEFAULT_NO_CALLBACK_TIMEOUT_MS)
            if (null != cb) fail("Expected no callback but got $cb")
        }
    }

    /** Creates an unregistered [TestNetworkProvider] whose callbacks run on the test thread. */
    private fun createNetworkProvider(ctx: Context = context): TestNetworkProvider {
        return TestNetworkProvider(ctx, mHandlerThread.looper)
    }

    /**
     * Creates a [TestNetworkProvider] and registers it with [ConnectivityManager], checking that
     * it has no provider ID before registration and has one afterwards.
     */
    private fun createAndRegisterNetworkProvider(ctx: Context = context) =
        createNetworkProvider(ctx).also {
            assertEquals(NetworkProvider.ID_NONE, it.getProviderId())
            mCm.registerNetworkProvider(it)
            assertNotEquals(NetworkProvider.ID_NONE, it.getProviderId())
        }

    // Skipped on S+: there a provider no longer receives onNetworkRequested for every request.
    // Instead it calls {@code registerNetworkOffer} to describe the networks it may be able to
    // set up, and is told about matching requests via {@link NetworkOfferCallback#onNetworkNeeded}.
    @IgnoreAfter(Build.VERSION_CODES.R)
    @Test
    fun testOnNetworkRequested() {
        val provider = createAndRegisterNetworkProvider()

        val specifier = CompatUtil.makeTestNetworkSpecifier(
                UUID.randomUUID().toString())
        // Test network is not allowed to be trusted.
        val nr: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .setNetworkSpecifier(specifier)
                .build()
        val cb = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nr, cb)
        provider.eventuallyExpectCallbackThat<OnNetworkRequested> { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.request.hasTransport(TRANSPORT_TEST)
        }

        val initialScore = 40
        val updatedScore = 60
        val nc = NetworkCapabilities().apply {
            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
            removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
            setNetworkSpecifier(specifier)
        }
        val lp = LinkProperties()
        val config = NetworkAgentConfig.Builder().build()
        val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc, lp,
                initialScore, config, provider) {}
        agent.register()
        agent.markConnected()

        // Once the agent connects, the provider is told the request's current score and the
        // ID of the provider serving it.
        provider.eventuallyExpectCallbackThat<OnNetworkRequested> { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.score == initialScore &&
            callback.id == agent.providerId
        }

        agent.sendNetworkScore(updatedScore)
        provider.eventuallyExpectCallbackThat<OnNetworkRequested> { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.score == updatedScore &&
            callback.id == agent.providerId
        }

        mCm.unregisterNetworkCallback(cb)
        provider.eventuallyExpectCallbackThat<OnNetworkRequestWithdrawn> { callback ->
            callback.request.getNetworkSpecifier() == specifier &&
            callback.request.hasTransport(TRANSPORT_TEST)
        }
        mCm.unregisterNetworkProvider(provider)
        // Provider id should be ID_NONE after unregistering the network provider.
        assertEquals(NetworkProvider.ID_NONE, provider.getProviderId())
        // unregisterNetworkProvider should not crash even if it's called on an
        // already unregistered provider.
        mCm.unregisterNetworkProvider(provider)

        // Tear down the test network so that it does not outlive this test.
        agent.unregister()
    }

    // Mainline module can't use internal HandlerExecutor, so add an identical executor here.
    // TODO: Refactor with the one in MultiNetworkPolicyTracker.
    /** [Executor] that posts commands to [handler]; rejects them if the post fails. */
    private class HandlerExecutor(private val handler: Handler) : Executor {
        override fun execute(command: Runnable) {
            if (!handler.post(command)) {
                throw RejectedExecutionException("$handler is shutting down")
            }
        }
    }

    @IgnoreUpTo(Build.VERSION_CODES.R)
    @Test
    fun testRegisterNetworkOffer() {
        val provider = createAndRegisterNetworkProvider()
        val provider2 = createAndRegisterNetworkProvider()
        // Every offer callback is delivered on the test handler thread.
        val executor = HandlerExecutor(mHandlerThread.threadHandler)

        // Prepare the materials which will be used to create different offers.
        val specifier1 = CompatUtil.makeTestNetworkSpecifier("TEST-SPECIFIER-1")
        val specifier2 = CompatUtil.makeTestNetworkSpecifier("TEST-SPECIFIER-2")
        val scoreWeaker = NetworkScore.Builder().build()
        val scoreStronger = NetworkScore.Builder().setTransportPrimary(true).build()
        val ncFilter1 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier1).build()
        val ncFilter2 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                .setNetworkSpecifier(specifier1).build()
        val ncFilter3 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier2).build()
        val ncFilter4 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier2).build()

        // Make 4 offers: 1 doesn't have NOT_VCN, 2 has NOT_VCN, 3 is similar to 1 but with a
        // different specifier, and 4 is identical to 3 but offered by a different provider.
        val offerCallback1 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback2 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback3 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback4 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        provider.registerNetworkOffer(scoreWeaker, ncFilter1, executor, offerCallback1)
        provider.registerNetworkOffer(scoreStronger, ncFilter2, executor, offerCallback2)
        provider.registerNetworkOffer(scoreWeaker, ncFilter3, executor, offerCallback3)
        provider2.registerNetworkOffer(scoreWeaker, ncFilter4, executor, offerCallback4)
        // Unlike Android R, an Android S+ provider only receives interesting requests via the
        // offer callback. Verify that the callbacks do not see any existing request, such as
        // the default requests.
        offerCallback1.assertNoCallback()
        offerCallback2.assertNoCallback()
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // File a request with a specifier but without NOT_VCN, and verify that the network is
        // needed for the offers with the same specifier.
        val nrNoNotVcn: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                // Test network is not allowed to be trusted.
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .setNetworkSpecifier(specifier1)
                .build()
        val cb1 = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nrNoNotVcn, cb1)
        offerCallback1.expectOnNetworkNeeded(ncFilter1)
        offerCallback2.expectOnNetworkNeeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        mCm.unregisterNetworkCallback(cb1)
        offerCallback1.expectOnNetworkUnneeded(ncFilter1)
        offerCallback2.expectOnNetworkUnneeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // File a request without a specifier but with NOT_VCN, and verify that the network is
        // needed for the offer with NOT_VCN.
        val nrNotVcn: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                // Test network is not allowed to be trusted.
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .build()
        val cb2 = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nrNotVcn, cb2)
        offerCallback1.assertNoCallback()
        offerCallback2.expectOnNetworkNeeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // Upgrade offers 3 & 4 to satisfy the previous request, then verify they are also needed.
        ncFilter3.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
        provider.registerNetworkOffer(scoreWeaker, ncFilter3, executor, offerCallback3)
        ncFilter4.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
        provider2.registerNetworkOffer(scoreWeaker, ncFilter4, executor, offerCallback4)
        offerCallback1.assertNoCallback()
        offerCallback2.assertNoCallback()
        offerCallback3.expectOnNetworkNeeded(ncFilter3)
        offerCallback4.expectOnNetworkNeeded(ncFilter4)

        // Connect an agent to fulfill the request, and verify that offer 4 is not needed: it is
        // not from the currently serving provider and cannot beat the current satisfier.
        val nc = NetworkCapabilities().apply {
            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
            setNetworkSpecifier(specifier1)
        }
        val config = NetworkAgentConfig.Builder().build()
        val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc,
                LinkProperties(), scoreWeaker, config, provider) {}
        agent.register()
        agent.markConnected()
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        offerCallback4.expectOnNetworkUnneeded(ncFilter4)

        // Upgrade the agent and verify that nothing changes: the framework treats an offer as
        // needed if a request is currently satisfied by a network from the same provider.
        // TODO: Consider offers with weaker score are unneeded.
        agent.sendNetworkScore(scoreStronger)
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        offerCallback4.assertNoCallback()  // Still unneeded.

        // Verify that an offer callback receives no events once its offer is unregistered.
        provider2.unregisterNetworkOffer(offerCallback4)
        agent.unregister()
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        // With the agent gone, offer 4 could satisfy the request, so this callback would
        // receive onNetworkNeeded if its offer were still registered.
        offerCallback4.assertNoCallback()

        // Verify that offer callbacks receive no events once their provider is unregistered.
        mCm.unregisterNetworkProvider(provider)
        mCm.unregisterNetworkCallback(cb2)
        offerCallback1.assertNoCallback()  // No callback since it is still unneeded.
        offerCallback2.assertNoCallback()  // Should be unneeded if not unregistered.
        offerCallback3.assertNoCallback()  // Should be unneeded if not unregistered.
        offerCallback4.assertNoCallback()  // Already unregistered.

        // Clean up, and verify that the providers received no callback during the entire test.
        mCm.unregisterNetworkProvider(provider2)
        provider.assertNoCallback()
        provider2.assertNoCallback()
    }

    /** [ConnectivityManager.NetworkCallback] that records `onUnavailable` so tests can await it. */
    private class TestNetworkCallback : ConnectivityManager.NetworkCallback() {
        private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()

        /** Callbacks recorded by [TestNetworkCallback]. */
        sealed class CallbackEntry {
            object OnUnavailable : CallbackEntry()
        }

        override fun onUnavailable() {
            seenEvents.add(OnUnavailable)
        }

        /**
         * Waits up to [DEFAULT_TIMEOUT_MS] for a recorded callback of type [T] that satisfies
         * [predicate] and returns it; fails the test if none is found in time.
         */
        inline fun <reified T : CallbackEntry> expectCallback(
            crossinline predicate: (T) -> Boolean
        ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) }
                // Without this, a missing callback would be silently ignored by callers.
                ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")
    }

    @Test
    fun testDeclareNetworkRequestUnfulfillable() {
        val mockContext = mock(Context::class.java)
        doReturn(mCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
        val provider = createNetworkProvider(mockContext)
        // ConnectivityManager not required at creation time after R
        if (isAtLeastS()) {
            verifyNoMoreInteractions(mockContext)
        }

        mCm.registerNetworkProvider(provider)

        val specifier = CompatUtil.makeTestNetworkSpecifier(
                UUID.randomUUID().toString())
        val nr: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier)
                .build()

        val cb = TestNetworkCallback()
        mCm.requestNetwork(nr, cb)
        provider.declareNetworkRequestUnfulfillable(nr)
        // OnUnavailable carries no payload; the predicate only sanity-checks the filed request.
        cb.expectCallback<OnUnavailable> { nr.getNetworkSpecifier() == specifier }
        mCm.unregisterNetworkProvider(provider)
    }
}
400