1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.testutils
18
19 import android.net.ConnectivityManager.NetworkCallback
20 import android.net.LinkProperties
21 import android.net.Network
22 import android.net.NetworkCapabilities
23 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
24 import android.util.Log
25 import com.android.net.module.util.ArrayTrackRecord
26 import com.android.testutils.RecorderCallback.CallbackEntry.Available
27 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
28 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatusInt
29 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
30 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
31 import com.android.testutils.RecorderCallback.CallbackEntry.Losing
32 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
33 import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
34 import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
35 import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
36 import kotlin.reflect.KClass
37 import kotlin.test.assertEquals
38 import kotlin.test.assertNotNull
39 import kotlin.test.fail
40
/**
 * Sentinel [Network] constructed with id -1. It is the network reported by callback entries
 * that have no associated network (see `CallbackEntry.Unavailable`).
 */
object NULL_NETWORK : Network(-1)

/**
 * Sentinel [Network] constructed with id -2. When passed as the expected network to the
 * `expect` helpers of [TestableNetworkCallback], the network of the received callback is not
 * checked.
 */
object ANY_NETWORK : Network(-2)

/** Returns the [ANY_NETWORK] wildcard sentinel. */
fun anyNetwork(): ANY_NETWORK {
    return ANY_NETWORK
}
44
/**
 * A [NetworkCallback] that logs every callback it receives and appends it, as a
 * [CallbackEntry], to [history].
 *
 * The entries are stored in an [ArrayTrackRecord]. Instances built with the protected
 * constructor reuse the backing record of `src` (or a fresh one when `src` is null). Each
 * instance reads it through its own read head obtained from `newReadHead()`.
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    public constructor() : this(ArrayTrackRecord())
    protected constructor(src: RecorderCallback?) : this(src?.backingRecord ?: ArrayTrackRecord())

    private val TAG = this::class.simpleName

    /** One recorded callback. Each subclass corresponds to one [NetworkCallback] method. */
    sealed class CallbackEntry {
        // To get equals(), hashcode(), componentN() etc for free, the child classes of
        // this class are data classes. But while data classes can inherit from other classes,
        // they may only have visible members in the constructors, so they couldn't declare
        // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
        // subclasses to implement a `network' property, which can be done in a data class
        // constructor by specifying override.
        abstract val network: Network

        data class Available(override val network: Network) : CallbackEntry()
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()
        data class Suspended(override val network: Network) : CallbackEntry()
        data class Resumed(override val network: Network) : CallbackEntry()
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
        data class Lost(override val network: Network) : CallbackEntry()
        // onUnavailable carries no network, so the entry always reports NULL_NETWORK.
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()
        data class BlockedStatusInt(
            override val network: Network,
            val reason: Int
        ) : CallbackEntry()
        // Convenience constants for expecting a type
        companion object {
            @JvmField
            val AVAILABLE = Available::class
            @JvmField
            val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
            @JvmField
            val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
            @JvmField
            val SUSPENDED = Suspended::class
            @JvmField
            val RESUMED = Resumed::class
            @JvmField
            val LOSING = Losing::class
            @JvmField
            val LOST = Lost::class
            @JvmField
            val UNAVAILABLE = Unavailable::class
            @JvmField
            val BLOCKED_STATUS = BlockedStatus::class
            @JvmField
            val BLOCKED_STATUS_INT = BlockedStatusInt::class
        }
    }

    /** This callback's read head over the (possibly shared) record of received callbacks. */
    val history = backingRecord.newReadHead()

    /** The current mark of [history]. */
    val mark get() = history.mark

    override fun onAvailable(network: Network) {
        Log.d(TAG, "onAvailable $network")
        history.add(Available(network))
    }

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
        Log.d(TAG, "onCapabilitiesChanged $network $caps")
        history.add(CapabilitiesChanged(network, caps))
    }

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
        Log.d(TAG, "onLinkPropertiesChanged $network $lp")
        history.add(LinkPropertiesChanged(network, lp))
    }

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
        Log.d(TAG, "onBlockedStatusChanged $network $blocked")
        history.add(BlockedStatus(network, blocked))
    }

    // Cannot do:
    // fun onBlockedStatusChanged(network: Network, blocked: Int) {
    // because on S, that needs to be "override fun", and on R, that cannot be "override fun".
    // As a result, this class never records BlockedStatusInt entries by itself.

    override fun onNetworkSuspended(network: Network) {
        Log.d(TAG, "onNetworkSuspended $network")
        history.add(Suspended(network))
    }

    override fun onNetworkResumed(network: Network) {
        Log.d(TAG, "onNetworkResumed $network")
        history.add(Resumed(network))
    }

    override fun onLosing(network: Network, maxMsToLive: Int) {
        Log.d(TAG, "onLosing $network $maxMsToLive")
        history.add(Losing(network, maxMsToLive))
    }

    override fun onLost(network: Network) {
        Log.d(TAG, "onLost $network")
        history.add(Lost(network))
    }

    override fun onUnavailable() {
        Log.d(TAG, "onUnavailable")
        history.add(Unavailable())
    }
}
167
/** Default time, in milliseconds, to wait when expecting a callback (30 seconds). */
private const val DEFAULT_TIMEOUT = 30 * 1_000L

/** Default time, in milliseconds, to wait when asserting that no callback arrives. */
private const val DEFAULT_NO_CALLBACK_TIMEOUT = 200L

/** Waiter function that does nothing; used when no waiter function is supplied. */
private val NOOP: Runnable = Runnable { /* intentionally empty */ }
171
/**
 * A [RecorderCallback] with helpers to poll for and assert on the callbacks it received.
 *
 * See comments on the public constructor below for a description of the arguments.
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
    val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
    val waiterFunc: Runnable = NOOP // "() -> Unit" would forbid calling with a void func from Java
) : RecorderCallback(src) {
    /**
     * Construct a testable network callback.
     * @param timeoutMs the default timeout for expecting a callback. Default 30 seconds. This
     *                  should be long in most cases, because the success case doesn't incur
     *                  the wait.
     * @param noCallbackTimeoutMs the timeout for expecting that no callback is received. Default
     *                            200ms. Because the success case does incur the timeout, this
     *                            should be short in most cases, but not so short as to frequently
     *                            time out before an incorrect callback is received.
     * @param waiterFunc a function to use before asserting no callback. For some specific tests,
     *                   it is useful to run test-specific code before asserting no callback to
     *                   increase the likelihood that a spurious callback is correctly detected.
     *                   As an example, a unit test using mock loopers may want to use this to
     *                   make sure the loopers are drained before asserting no callback, since
     *                   one of them may cause a callback to be called. @see ConnectivityServiceTest
     *                   for such an example.
     */
    @JvmOverloads
    constructor(
        timeoutMs: Long = DEFAULT_TIMEOUT,
        noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT,
        waiterFunc: Runnable = NOOP
    ) : this(null, timeoutMs, noCallbackTimeoutMs, waiterFunc)

    /**
     * Creates a new callback that reuses this callback's backing record, with the same
     * timeouts and waiter function.
     */
    fun createLinkedCopy() = TestableNetworkCallback(
            this, defaultTimeoutMs, defaultNoCallbackTimeoutMs, waiterFunc)

    // The last available network, or null if any network was lost since the last call to
    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
    val lastAvailableNetwork: Network?
        get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
            is Available -> it.network
            else -> null
        }

    /**
     * Get the next callback or null if timeout.
     *
     * With no argument, this method waits out the default timeout. To wait forever, pass
     * Long.MAX_VALUE.
     */
    @JvmOverloads
    fun poll(timeoutMs: Long = defaultTimeoutMs, predicate: (CallbackEntry) -> Boolean = { true }) =
            history.poll(timeoutMs, predicate)

    /*****
     * expect family of methods.
     * These methods fetch the next callback and assert it matches the conditions : type,
     * passed predicate. If no callback is received within the timeout, these methods fail.
     */
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
        // T is erased at runtime, so `it as? T` would be an unchecked cast that never fails and
        // would let a callback of the wrong type through. Check against the class token instead.
        if (type.java.isInstance(it)) {
            test(it as T)
        } else {
            fail("Expected callback ${type.simpleName}, got $it")
        }
    } as T

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = expect(type, network.network, timeoutMs, errorMsg, test)

    // Java needs an explicit overload to let it omit arguments in the middle, so define these
    // here. Note that @JvmOverloads give us the versions without the last arguments too, so
    // there is no need to explicitly define versions without the test predicate.
    // Without |network|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        timeoutMs: Long,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)

    // Without |timeout|, in Network and HasNetwork versions
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, network, defaultTimeoutMs, errorMsg, test)

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)

    // Without |errorMsg|, in Network and HasNetwork versions
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network,
        timeoutMs: Long,
        test: (T) -> Boolean
    ) = expect(type, network, timeoutMs, null, test)

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        timeoutMs: Long,
        test: (T) -> Boolean
    ) = expect(type, network.network, timeoutMs, null, test)

    // Without |network| or |timeout|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        errorMsg: String?,
        test: (T) -> Boolean = { true }
    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)

    // Without |network| or |errorMsg|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        timeoutMs: Long,
        test: (T) -> Boolean = { true }
    ) = expect(type, ANY_NETWORK, timeoutMs, null, test)

    // Without |timeout| or |errorMsg|, in Network and HasNetwork versions
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: Network,
        test: (T) -> Boolean
    ) = expect(type, network, defaultTimeoutMs, null, test)

    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        network: HasNetwork,
        test: (T) -> Boolean
    ) = expect(type, network.network, defaultTimeoutMs, null, test)

    // Without |network| or |timeout| or |errorMsg|
    @JvmOverloads
    fun <T : CallbackEntry> expect(
        type: KClass<T>,
        test: (T) -> Boolean
    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)

    // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
    inline fun <reified T : CallbackEntry> expect(
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = (poll(timeoutMs) ?: fail("Did not receive ${T::class.simpleName} after ${timeoutMs}ms"))
            .also {
                if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
                // ANY_NETWORK is compared by identity: it acts as a wildcard for the network.
                if (ANY_NETWORK !== network && it.network != network) {
                    fail("Expected network $network for callback : $it")
                }
                if (!test(it)) {
                    fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
                }
            } as T

    inline fun <reified T : CallbackEntry> expect(
        network: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs,
        errorMsg: String? = null,
        test: (T) -> Boolean = { true }
    ) = expect(network.network, timeoutMs, errorMsg, test)

    /*****
     * assertNoCallback family of methods.
     * These methods make sure that no callback that matches the predicate was received.
     * If no predicate is given, they make sure that no callback at all was received.
     * These methods run the waiter func given in the constructor if any.
     */
    @JvmOverloads
    fun assertNoCallback(
        timeoutMs: Long = defaultNoCallbackTimeoutMs,
        valid: (CallbackEntry) -> Boolean = { true }
    ) {
        waiterFunc.run()
        history.poll(timeoutMs) { valid(it) }?.let { fail("Expected no callback but got $it") }
    }

    fun assertNoCallback(valid: (CallbackEntry) -> Boolean) =
            assertNoCallback(defaultNoCallbackTimeoutMs, valid)

    /*****
     * eventuallyExpect family of methods.
     * These methods make sure a callback that matches the type/predicate is received eventually.
     * Any callback of the wrong type, or doesn't match the optional predicate, is ignored.
     * They fail if no callback matching the predicate is received within the timeout.
     */
    inline fun <reified T : CallbackEntry> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T = history.poll(timeoutMs, from) { it is T && predicate(it) }.also {
        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
    } as T

    @JvmOverloads
    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        predicate: (cb: T) -> Boolean = { true }
    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        predicate: (cb: T) -> Boolean = { true }
    ) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    /**
     * Expects onAvailable and the callbacks that follow it. These are:
     * - onNetworkSuspended, iff the network was suspended when the callbacks fire.
     * - onCapabilitiesChanged.
     * - onLinkPropertiesChanged.
     * - onBlockedStatusChanged.
     *
     * @param net the network to expect the callbacks on.
     * @param suspended whether to expect a SUSPENDED callback.
     * @param validated the expected value of the VALIDATED capability in the
     *                  onCapabilitiesChanged callback, or null to accept either value.
     * @param blocked the expected blocked status in the onBlockedStatusChanged callback.
     * @param tmt how long to wait for each of the callbacks.
     */
    @JvmOverloads
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean = false,
        validated: Boolean? = true,
        blocked: Boolean = false,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        expect<BlockedStatus>(net, tmt) { it.blocked == blocked }
    }

    /**
     * Same as the overload taking a Boolean `blocked`, except that the last callback is
     * expected to be a [BlockedStatusInt] whose reason equals [blockedReason]. [tmt] applies to
     * every expected callback, including that last one.
     */
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        blockedReason: Int,
        tmt: Long
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        expect<BlockedStatusInt>(net, tmt) { it.reason == blockedReason }
    }

    // Expects Available, optionally Suspended, CapabilitiesChanged (checking VALIDATED unless
    // validated is null) and LinkPropertiesChanged, in that order, each within tmt.
    private fun expectAvailableCallbacksCommon(
        net: Network,
        suspended: Boolean,
        validated: Boolean?,
        tmt: Long
    ) {
        expect<Available>(net, tmt)
        if (suspended) {
            expect<Suspended>(net, tmt)
        }
        expect<CapabilitiesChanged>(net, tmt) {
            validated == null || validated == it.caps.hasCapability(NET_CAPABILITY_VALIDATED)
        }
        expect<LinkPropertiesChanged>(net, tmt)
    }

    // Backward compatibility for existing Java code. Use named arguments instead and remove all
    // these when there is no user left.
    fun expectAvailableAndSuspendedCallbacks(
        net: Network,
        validated: Boolean,
        tmt: Long = defaultTimeoutMs
    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)

    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
    // one we just sent.
    // TODO: this is likely a bug. Fix it and remove this method.
    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        val mark = history.mark
        expectAvailableCallbacks(net, tmt = tmt)
        // Look up the first CapabilitiesChanged recorded since `mark` without moving the read
        // head, then require the next callback to be an identical one.
        val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
        assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
    }

    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
    // when a network connects and satisfies a callback, and then immediately validates.
    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        expectAvailableCallbacks(net, validated = false, tmt = tmt)
        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    fun expectAvailableThenValidatedCallbacks(
        net: Network,
        blockedReason: Int,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacks(net, validated = false, suspended = false,
                blockedReason = blockedReason, tmt = tmt)
        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
    // TODO: clean this up, remove this method.
    interface HasNetwork {
        val network: Network
    }

    fun expectAvailableCallbacks(
        n: HasNetwork,
        suspended: Boolean,
        validated: Boolean,
        blocked: Boolean,
        timeoutMs: Long
    ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)

    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
    }

    fun expectAvailableCallbacksValidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network)
    }

    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, blocked = true)
    }

    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false)
    }

    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false, blocked = true)
    }

    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    /*****
     * expectCaps family of methods.
     * Each expects the next callback to be a CapabilitiesChanged (on the given network, if any)
     * whose capabilities satisfy [valid], and returns those capabilities.
     */
    @JvmOverloads
    fun expectCaps(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean = { true }
    ) = expect<CapabilitiesChanged>(n.network, tmt) { valid(it.caps) }.caps

    @JvmOverloads
    fun expectCaps(
        n: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ) = expect<CapabilitiesChanged>(n, tmt) { valid(it.caps) }.caps

    fun expectCaps(
        n: HasNetwork,
        valid: (NetworkCapabilities) -> Boolean
    ) = expect<CapabilitiesChanged>(n.network) { valid(it.caps) }.caps

    fun expectCaps(
        tmt: Long,
        valid: (NetworkCapabilities) -> Boolean
    ) = expect<CapabilitiesChanged>(ANY_NETWORK, tmt) { valid(it.caps) }.caps
}
563