• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.ConnectivityManager.NetworkCallback
20 import android.net.LinkProperties
21 import android.net.Network
22 import android.net.NetworkCapabilities
23 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
24 import android.util.Log
25 import com.android.net.module.util.ArrayTrackRecord
26 import com.android.testutils.RecorderCallback.CallbackEntry.Available
27 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
28 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatusInt
29 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
30 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
31 import com.android.testutils.RecorderCallback.CallbackEntry.Losing
32 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
33 import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
34 import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
35 import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
36 import kotlin.reflect.KClass
37 import kotlin.test.assertEquals
38 import kotlin.test.assertNotNull
39 import kotlin.test.fail
40 
/**
 * Sentinel network attached to callback entries that have no real network, such as the
 * entry recorded for onUnavailable.
 */
object NULL_NETWORK : Network(-1)

/**
 * Wildcard network. When passed as the expected network to the expect methods, the network
 * of the received callback is not checked.
 */
object ANY_NETWORK : Network(-2)

/** Returns the [ANY_NETWORK] wildcard. */
fun anyNetwork(): ANY_NETWORK = ANY_NETWORK
44 
/**
 * A [NetworkCallback] that records every callback it receives as a [CallbackEntry].
 *
 * Entries are appended to a backing [ArrayTrackRecord] through [history], a read head on
 * that record. An instance built with the protected constructor and a non-null `src` reuses
 * `src`'s backing record, so both objects append to and read from the same sequence.
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    public constructor() : this(ArrayTrackRecord())
    protected constructor(src: RecorderCallback?) : this(src?.backingRecord ?: ArrayTrackRecord())

    private val TAG = this::class.simpleName

    /** One recorded callback invocation. */
    sealed class CallbackEntry {
        // To get equals(), hashCode(), componentN() etc. for free, the child classes of
        // this class are data classes. But while data classes can inherit from other classes,
        // they may only have visible members in the constructors, so they couldn't declare
        // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
        // subclasses to implement a `network` property, which can be done in a data class
        // constructor by specifying override.
        abstract val network: Network

        data class Available(override val network: Network) : CallbackEntry()
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()
        data class Suspended(override val network: Network) : CallbackEntry()
        data class Resumed(override val network: Network) : CallbackEntry()
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
        data class Lost(override val network: Network) : CallbackEntry()
        // onUnavailable carries no network, so this entry always reports NULL_NETWORK.
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()
        data class BlockedStatusInt(
            override val network: Network,
            val reason: Int
        ) : CallbackEntry()
        // Convenience constants for expecting a type
        companion object {
            @JvmField
            val AVAILABLE = Available::class
            @JvmField
            val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
            @JvmField
            val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
            @JvmField
            val SUSPENDED = Suspended::class
            @JvmField
            val RESUMED = Resumed::class
            @JvmField
            val LOSING = Losing::class
            @JvmField
            val LOST = Lost::class
            @JvmField
            val UNAVAILABLE = Unavailable::class
            @JvmField
            val BLOCKED_STATUS = BlockedStatus::class
            @JvmField
            val BLOCKED_STATUS_INT = BlockedStatusInt::class
        }
    }

    /** Read head on the backing record; every callback below is appended through it. */
    val history = backingRecord.newReadHead()

    /** The current mark of [history]. */
    val mark get() = history.mark

    override fun onAvailable(network: Network) {
        Log.d(TAG, "onAvailable $network")
        history.add(Available(network))
    }

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
        Log.d(TAG, "onCapabilitiesChanged $network $caps")
        history.add(CapabilitiesChanged(network, caps))
    }

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
        Log.d(TAG, "onLinkPropertiesChanged $network $lp")
        history.add(LinkPropertiesChanged(network, lp))
    }

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
        Log.d(TAG, "onBlockedStatusChanged $network $blocked")
        history.add(BlockedStatus(network, blocked))
    }

    // Cannot do:
    // fun onBlockedStatusChanged(network: Network, blocked: Int) {
    // because on S, that needs to be "override fun", and on R, that cannot be "override fun".

    override fun onNetworkSuspended(network: Network) {
        // Previously logged "$network $network"; use the same "<event> <network>" format as
        // the other callbacks.
        Log.d(TAG, "onNetworkSuspended $network")
        history.add(Suspended(network))
    }

    override fun onNetworkResumed(network: Network) {
        // Previously logged "$network onNetworkResumed $network"; normalized likewise.
        Log.d(TAG, "onNetworkResumed $network")
        history.add(Resumed(network))
    }

    override fun onLosing(network: Network, maxMsToLive: Int) {
        Log.d(TAG, "onLosing $network $maxMsToLive")
        history.add(Losing(network, maxMsToLive))
    }

    override fun onLost(network: Network) {
        Log.d(TAG, "onLost $network")
        history.add(Lost(network))
    }

    override fun onUnavailable() {
        Log.d(TAG, "onUnavailable")
        history.add(Unavailable())
    }
}
167 
/** Default timeout, in milliseconds, for expecting that a callback arrives. */
private const val DEFAULT_TIMEOUT: Long = 30_000

/** Default timeout, in milliseconds, for asserting that no callback arrives. */
private const val DEFAULT_NO_CALLBACK_TIMEOUT: Long = 200

/** A waiter function that does nothing; used as the default waiterFunc. */
private val NOOP: Runnable = Runnable { /* intentionally empty */ }
171 
/**
 * A [RecorderCallback] with assertion helpers for tests.
 *
 * See the KDoc on the public constructor below for a description of the arguments. The
 * private constructor also takes `src`: when non-null, the new callback shares `src`'s
 * backing callback record (see [createLinkedCopy]).
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
    val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
    val waiterFunc: Runnable = NOOP // "() -> Unit" would forbid calling with a void func from Java
) : RecorderCallback(src) {
    /**
     * Construct a testable network callback.
     * @param timeoutMs the default timeout for expecting a callback. Default 30 seconds. This
     *                  should be long in most cases, because the success case doesn't incur
     *                  the wait.
     * @param noCallbackTimeoutMs the timeout for expecting that no callback is received. Default
     *                            200ms. Because the success case does incur the timeout, this
     *                            should be short in most cases, but not so short as to frequently
     *                            time out before an incorrect callback is received.
     * @param waiterFunc a function to use before asserting no callback. For some specific tests,
     *                   it is useful to run test-specific code before asserting no callback to
     *                   increase the likelihood that a spurious callback is correctly detected.
     *                   As an example, a unit test using mock loopers may want to use this to
     *                   make sure the loopers are drained before asserting no callback, since
     *                   one of them may cause a callback to be called. @see ConnectivityServiceTest
     *                   for such an example.
     */
    @JvmOverloads
    constructor(
        timeoutMs: Long = DEFAULT_TIMEOUT,
        noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
        waiterFunc: Runnable = NOOP
    ) : this(null, timeoutMs, noCallbackTimeoutMs, waiterFunc)

    /**
     * Returns a new callback with the same timeouts and waiter function, which records into
     * the same backing record as this one.
     */
    fun createLinkedCopy() = TestableNetworkCallback(
            this, defaultTimeoutMs, defaultNoCallbackTimeoutMs, waiterFunc)

    // The last available network, or null if any network was lost since the last call to
    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
    val lastAvailableNetwork: Network?
        get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
            is Available -> it.network
            else -> null
        }

    /**
     * Get the next callback or null if timeout.
     *
     * With no argument, this method waits out the default timeout. To wait forever, pass
     * Long.MAX_VALUE.
     */
    @JvmOverloads
    fun poll(timeoutMs: Long = defaultTimeoutMs, predicate: (CallbackEntry) -> Boolean = { true }) =
            history.poll(timeoutMs, predicate)

    /*****
     * expect family of methods.
     * These methods fetch the next callback and assert that it matches the conditions: type,
     * network, and the passed predicate. If no callback is received within the timeout, these
     * methods fail.
     */
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
        // T is erased in this non-reified function, so `it as? T` would always succeed and a
        // callback of the wrong type would be accepted. Check the runtime class explicitly.
        if (!type.java.isInstance(it)) fail("Expected callback ${type.simpleName}, got $it")
        test(it as T)
    } as T

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = expect(type, network.network, timeoutMs, errorMsg, test)

    // Java needs an explicit overload to let it omit arguments in the middle, so define these
    // here. Note that @JvmOverloads give us the versions without the last arguments too, so
    // there is no need to explicitly define versions without the test predicate.
    // Without |network|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        timeoutMs: Long,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)

    // Without |timeout|, in Network and HasNetwork versions
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, network, defaultTimeoutMs, errorMsg, test)

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)

    // Without |errorMsg|, in Network and HasNetwork versions
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network,
        timeoutMs: Long,
        test: (T) -> Boolean
    ) = expect(type, network, timeoutMs, null, test)

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        timeoutMs: Long,
        test: (T) -> Boolean
    ) = expect(type, network.network, timeoutMs, null, test)

    // Without |network| or |timeout|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)

    // Without |network| or |errorMsg|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        timeoutMs: Long,
        test: (T) -> Boolean = { true }
    ) = expect(type, ANY_NETWORK, timeoutMs, null, test)

    // Without |timeout| or |errorMsg|, in Network and HasNetwork versions
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network,
        test: (T) -> Boolean
    ) = expect(type, network, defaultTimeoutMs, null, test)

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        test: (T) -> Boolean
    ) = expect(type, network.network, defaultTimeoutMs, null, test)

    // Without |network| or |timeout| or |errorMsg|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        test: (T) -> Boolean
    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)

    // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
    inline fun <reified T : CallbackEntry> expect(
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = (poll(timeoutMs) ?: fail("Did not receive ${T::class.simpleName} after ${timeoutMs}ms"))
            .also {
                if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
                // ANY_NETWORK is compared by identity: it is the wildcard sentinel.
                if (ANY_NETWORK !== network && it.network != network) {
                    fail("Expected network $network for callback : $it")
                }
                if (!test(it)) {
                    fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
                }
            } as T

    inline fun <reified T : CallbackEntry> expect(
        network: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = expect(network.network, timeoutMs, errorMsg, test)

    /*****
     * assertNoCallback family of methods.
     * These methods make sure that no callback that matches the predicate was received.
     * If no predicate is given, they make sure that no callback at all was received.
     * These methods run the waiter func given in the constructor if any.
     */
    @JvmOverloads
    fun assertNoCallback(
        timeoutMs: Long = defaultNoCallbackTimeoutMs,
        valid: (CallbackEntry) -> Boolean = { true }
    ) {
        waiterFunc.run()
        history.poll(timeoutMs) { valid(it) }?.let { fail("Expected no callback but got $it") }
    }

    fun assertNoCallback(valid: (CallbackEntry) -> Boolean) =
            assertNoCallback(defaultNoCallbackTimeoutMs, valid)

    /*****
     * eventuallyExpect family of methods.
     * These methods make sure a callback that matches the type/predicate is received eventually.
     * Any callback of the wrong type, or that doesn't match the optional predicate, is ignored.
     * They fail if no callback matching the predicate is received within the timeout.
     */
    inline fun <reified T : CallbackEntry> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T = history.poll(timeoutMs, from) { it is T && predicate(it) }.also {
        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
    } as T

    @JvmOverloads
    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        predicate: (cb: T) -> Boolean = { true }
    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        predicate: (cb: T) -> Boolean = { true }
    ) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    /**
     * Expects onAvailable and the callbacks that follow it, in this order:
     * - onAvailable.
     * - onSuspended, iff [suspended] is true.
     * - onCapabilitiesChanged.
     * - onLinkPropertiesChanged.
     * - onBlockedStatusChanged (boolean form).
     *
     * @param net the network to expect the callbacks on.
     * @param suspended whether to expect a SUSPENDED callback.
     * @param validated the expected value of the VALIDATED capability in the
     *        onCapabilitiesChanged callback, or null to accept either value.
     * @param blocked the expected value of the blocked status.
     * @param tmt how long to wait for each of the callbacks.
     */
    @JvmOverloads
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean = false,
        validated: Boolean? = true,
        blocked: Boolean = false,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        expect<BlockedStatus>(net, tmt) { it.blocked == blocked }
    }

    /**
     * Same as the boolean-blocked overload, but expects the integer form of the blocked status
     * callback, whose reason must equal [blockedReason]. All callbacks, including the final
     * blocked status one, are waited for with [tmt].
     */
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        blockedReason: Int,
        tmt: Long
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        // Pass tmt here too; previously this expectation silently used defaultTimeoutMs.
        expect<BlockedStatusInt>(net, tmt) { it.reason == blockedReason }
    }

    // Expects the callbacks shared by both expectAvailableCallbacks overloads, i.e. everything
    // up to and including onLinkPropertiesChanged. A null |validated| skips the VALIDATED check.
    private fun expectAvailableCallbacksCommon(
        net: Network,
        suspended: Boolean,
        validated: Boolean?,
        tmt: Long
    ) {
        expect<Available>(net, tmt)
        if (suspended) {
            expect<Suspended>(net, tmt)
        }
        expect<CapabilitiesChanged>(net, tmt) {
            validated == null || validated == it.caps.hasCapability(NET_CAPABILITY_VALIDATED)
        }
        expect<LinkPropertiesChanged>(net, tmt)
    }

    // Backward compatibility for existing Java code. Use named arguments instead and remove all
    // these when there is no user left.
    fun expectAvailableAndSuspendedCallbacks(
        net: Network,
        validated: Boolean,
        tmt: Long = defaultTimeoutMs
    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)

    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
    // one we just sent.
    // TODO: this is likely a bug. Fix it and remove this method.
    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        val mark = history.mark
        expectAvailableCallbacks(net, tmt = tmt)
        // Re-read the first CapabilitiesChanged from the mark taken before the expectations.
        val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
        assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
    }

    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
    // when a network connects and satisfies a callback, and then immediately validates.
    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        expectAvailableCallbacks(net, validated = false, tmt = tmt)
        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    fun expectAvailableThenValidatedCallbacks(
        net: Network,
        blockedReason: Int,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacks(net, validated = false, suspended = false,
                blockedReason = blockedReason, tmt = tmt)
        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
    // TODO: clean this up, remove this method.
    interface HasNetwork {
        val network: Network
    }

    fun expectAvailableCallbacks(
        n: HasNetwork,
        suspended: Boolean,
        validated: Boolean,
        blocked: Boolean,
        timeoutMs: Long
    ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)

    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
    }

    fun expectAvailableCallbacksValidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network)
    }

    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, blocked = true)
    }

    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false)
    }

    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false, blocked = true)
    }

    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    /**
     * Expects the next callback to be onCapabilitiesChanged on [n] whose capabilities satisfy
     * [valid], and returns those capabilities.
     */
    @JvmOverloads
    fun expectCaps(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean = { true }
    ) = expect<CapabilitiesChanged>(n.network, tmt) { valid(it.caps) }.caps

    @JvmOverloads
    fun expectCaps(
        n: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ) = expect<CapabilitiesChanged>(n, tmt) { valid(it.caps) }.caps

    fun expectCaps(
        n: HasNetwork,
        valid: (NetworkCapabilities) -> Boolean
    ) = expect<CapabilitiesChanged>(n.network) { valid(it.caps) }.caps

    fun expectCaps(
        tmt: Long,
        valid: (NetworkCapabilities) -> Boolean
    ) = expect<CapabilitiesChanged>(ANY_NETWORK, tmt) { valid(it.caps) }.caps
}
563