1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.testutils
18
19 import android.net.ConnectivityManager.NetworkCallback
20 import android.net.LinkProperties
21 import android.net.Network
22 import android.net.NetworkCapabilities
23 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
24 import android.util.Log
25 import com.android.net.module.util.ArrayTrackRecord
26 import com.android.testutils.RecorderCallback.CallbackEntry.Available
27 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
28 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatusInt
29 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
30 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
31 import com.android.testutils.RecorderCallback.CallbackEntry.Losing
32 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
33 import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
34 import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
35 import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
36 import kotlin.reflect.KClass
37 import kotlin.test.assertEquals
38 import kotlin.test.assertNotNull
39 import kotlin.test.assertTrue
40 import kotlin.test.fail
41
/**
 * Sentinel [Network] (constructed with id -1) used where a callback carries no network:
 * [RecorderCallback.CallbackEntry.Unavailable] reports it, and `expectCallback` maps a null
 * network argument to it.
 */
object NULL_NETWORK : Network(-1)

/**
 * Sentinel [Network] (constructed with id -2). Callback expectations compare against it by
 * identity and, when given it, accept a callback on any network.
 */
object ANY_NETWORK : Network(-2)

/** Function form of [ANY_NETWORK]. */
fun anyNetwork(): ANY_NETWORK = ANY_NETWORK
45
/**
 * Name of this NET_CAPABILITY_* value, as returned by [NetworkCapabilities.capabilityNameOf].
 */
// NOTE(review): nothing in this file reads this property; consider removing it.
private val Int.capabilityName
    get() = NetworkCapabilities.capabilityNameOf(this)
47
/**
 * A [NetworkCallback] that records every callback it receives, as a [CallbackEntry], so that
 * tests can inspect the sequence of callbacks afterwards.
 *
 * An instance built through the protected constructor from another RecorderCallback shares
 * that callback's backing record; each instance reads the record through its own read head
 * ([history]).
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    public constructor() : this(ArrayTrackRecord())
    protected constructor(src: RecorderCallback?) : this(src?.backingRecord ?: ArrayTrackRecord())

    private val TAG = this::class.simpleName

    /** One recorded [NetworkCallback] invocation together with its arguments. */
    sealed class CallbackEntry {
        // To get equals(), hashCode(), componentN() etc. for free, the child classes of this
        // class are data classes. A data class's primary constructor may only declare
        // properties (val/var), so a subclass cannot take a plain constructor argument and
        // forward it to CallbackEntry. Instead, every subclass must implement the abstract
        // `network` property, which a data class does by declaring `override val network`
        // in its primary constructor.
        abstract val network: Network

        data class Available(override val network: Network) : CallbackEntry()
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()
        data class Suspended(override val network: Network) : CallbackEntry()
        data class Resumed(override val network: Network) : CallbackEntry()
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
        data class Lost(override val network: Network) : CallbackEntry()
        // onUnavailable() has no network argument, so this entry always reports NULL_NETWORK.
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()
        data class BlockedStatusInt(
            override val network: Network,
            val blocked: Int
        ) : CallbackEntry()

        // Convenience constants for expecting a type
        companion object {
            @JvmField
            val AVAILABLE = Available::class
            @JvmField
            val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
            @JvmField
            val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
            @JvmField
            val SUSPENDED = Suspended::class
            @JvmField
            val RESUMED = Resumed::class
            @JvmField
            val LOSING = Losing::class
            @JvmField
            val LOST = Lost::class
            @JvmField
            val UNAVAILABLE = Unavailable::class
            @JvmField
            val BLOCKED_STATUS = BlockedStatus::class
            @JvmField
            val BLOCKED_STATUS_INT = BlockedStatusInt::class
        }
    }

    /** This instance's read head over the backing record; every callback below is added here. */
    val history = backingRecord.newReadHead()

    /** Shortcut for `history.mark`. */
    val mark get() = history.mark

    override fun onAvailable(network: Network) {
        Log.d(TAG, "onAvailable $network")
        history.add(Available(network))
    }

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
        Log.d(TAG, "onCapabilitiesChanged $network $caps")
        history.add(CapabilitiesChanged(network, caps))
    }

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
        Log.d(TAG, "onLinkPropertiesChanged $network $lp")
        history.add(LinkPropertiesChanged(network, lp))
    }

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
        Log.d(TAG, "onBlockedStatusChanged $network $blocked")
        history.add(BlockedStatus(network, blocked))
    }

    // Cannot do:
    // fun onBlockedStatusChanged(network: Network, blocked: Int) {
    // because on S, that needs to be "override fun", and on R, that cannot be "override fun".

    override fun onNetworkSuspended(network: Network) {
        // Same "<callback> <args>" log format as the other callbacks; the network is logged once.
        Log.d(TAG, "onNetworkSuspended $network")
        history.add(Suspended(network))
    }

    override fun onNetworkResumed(network: Network) {
        // Same "<callback> <args>" log format as the other callbacks.
        Log.d(TAG, "onNetworkResumed $network")
        history.add(Resumed(network))
    }

    override fun onLosing(network: Network, maxMsToLive: Int) {
        Log.d(TAG, "onLosing $network $maxMsToLive")
        history.add(Losing(network, maxMsToLive))
    }

    override fun onLost(network: Network) {
        Log.d(TAG, "onLost $network")
        history.add(Lost(network))
    }

    override fun onUnavailable() {
        Log.d(TAG, "onUnavailable")
        history.add(Unavailable())
    }
}
170
/** Default timeout, in milliseconds, for [TestableNetworkCallback] expectations. */
private const val DEFAULT_TIMEOUT: Long = 200
172
/**
 * A [RecorderCallback] with assertion helpers for tests.
 *
 * The `expect*` / `assert*` helpers poll the recorded history and fail the current test
 * (through kotlin.test's `fail`, `assertTrue`, `assertEquals`, `assertNotNull`) when the
 * expected callback does not arrive within the timeout or does not match.
 *
 * @property defaultTimeoutMs timeout in milliseconds used by every helper not given one.
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
) : RecorderCallback(src) {
    @JvmOverloads
    constructor(timeoutMs: Long = DEFAULT_TIMEOUT) : this(null, timeoutMs)

    /**
     * Creates a callback sharing this callback's backing record and default timeout. The copy
     * reads the same recorded entries through its own read head.
     */
    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)

    // The last available network, or null if any network was lost since the last call to
    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
    val lastAvailableNetwork: Network?
        get() = when (val entry = history.lastOrNull { it is Available || it is Lost }) {
            is Available -> entry.network
            else -> null
        }

    /** Returns the next callback from [history], failing if none arrives within [timeoutMs]. */
    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
        return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
    }

    // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
    // TODO : remove the necessity to overload this, remove the open qualifier, and give a
    // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
    open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs)

    /** Fails if any callback arrives within [timeoutMs]. */
    fun assertNoCallback(timeoutMs: Long) {
        val cb = history.poll(timeoutMs)
        if (null != cb) fail("Expected no callback but got $cb")
    }

    /**
     * Fails if a callback for which [valid] returns true arrives within [timeoutMs].
     * Despite its name, [valid] selects the callbacks that must NOT be received.
     */
    fun assertNoCallbackThat(
        timeoutMs: Long = defaultTimeoutMs,
        valid: (CallbackEntry) -> Boolean
    ) {
        val cb = history.poll(timeoutMs) { valid(it) }
        if (null != cb) fail("Expected no callback but got $cb")
    }

    // Expects a callback of the specified type on the specified network within the timeout.
    // If no callback arrives, or a different callback arrives, fail. Returns the callback.
    inline fun <reified T : CallbackEntry> expectCallback(
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs
    ): T = pollForNextCallback(timeoutMs).let {
        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
        } else {
            it
        }
    }

    // Expects a callback of the specified type matching the predicate within the timeout.
    // Any callback that doesn't match the predicate will be skipped. Fails only if
    // no matching callback is received within the timeout.
    inline fun <reified T : CallbackEntry> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T = eventuallyExpectOrNull(timeoutMs, from, predicate).also {
        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
    } as T

    /**
     * Non-reified counterpart of `eventuallyExpect`: waits up to [timeoutMs] for a callback
     * that is an instance of [type] and satisfies [predicate], failing if none arrives.
     *
     * The predicate receives the entry as a [CallbackEntry]; it has not been cast to [T] yet.
     * Unlike the reified overload, this variant takes no `from` argument and calls
     * `history.poll` without a position.
     */
    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        predicate: (CallbackEntry) -> Boolean = { true }
    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    // TODO (b/157405399) straighten and unify the method names
    inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?

    /** Polls the next callback and asserts that [valid] accepts it; returns that callback. */
    fun expectCallbackThat(
        timeoutMs: Long = defaultTimeoutMs,
        valid: (CallbackEntry) -> Boolean
    ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }

    /** Expects onCapabilitiesChanged on [net] whose capabilities satisfy [valid]. */
    fun expectCapabilitiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ): CapabilitiesChanged {
        return expectCallback<CapabilitiesChanged>(net, tmt).also {
            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
        }
    }

    /** Expects onLinkPropertiesChanged on [net] whose link properties satisfy [valid]. */
    fun expectLinkPropertiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ): LinkPropertiesChanged {
        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
        }
    }

    /**
     * Expects onAvailable and the callbacks that follow it, in this order:
     * - onNetworkSuspended, iff [suspended] is true;
     * - onCapabilitiesChanged;
     * - onLinkPropertiesChanged;
     * - onBlockedStatusChanged.
     *
     * @param net the network to expect the callbacks on.
     * @param suspended whether to expect a SUSPENDED callback.
     * @param validated the expected value of the VALIDATED capability in the
     *        onCapabilitiesChanged callback, or null to accept either value.
     * @param blocked the expected value in the onBlockedStatusChanged callback.
     * @param tmt how long to wait for each of the callbacks, in milliseconds.
     */
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean = false,
        validated: Boolean? = true,
        blocked: Boolean = false,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        expectBlockedStatusCallback(blocked, net, tmt)
    }

    /**
     * Same as the Boolean-blocked overload, but expects the Int flavour of the blocked status
     * callback ([BlockedStatusInt]) carrying [blockedStatus].
     */
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        blockedStatus: Int,
        tmt: Long
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        // Honour the caller's timeout, as the Boolean overload does; this call previously
        // fell back to defaultTimeoutMs.
        expectBlockedStatusCallback(blockedStatus, net, tmt)
    }

    /** Expects every available-callback except the blocked status one (see above). */
    private fun expectAvailableCallbacksCommon(
        net: Network,
        suspended: Boolean,
        validated: Boolean?,
        tmt: Long
    ) {
        expectCallback<Available>(net, tmt)
        if (suspended) {
            expectCallback<Suspended>(net, tmt)
        }
        expectCapabilitiesThat(net, tmt) {
            validated == null || validated == it.hasCapability(NET_CAPABILITY_VALIDATED)
        }
        expectCallback<LinkPropertiesChanged>(net, tmt)
    }

    // Backward compatibility for existing Java code. Use named arguments instead and remove all
    // these when there is no user left.
    fun expectAvailableAndSuspendedCallbacks(
        net: Network,
        validated: Boolean,
        tmt: Long = defaultTimeoutMs
    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)

    /** Expects onBlockedStatusChanged(Boolean) on [net] carrying [blocked]. */
    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatus>(net, tmt).also {
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    /** Expects the Int blocked status callback on [net] carrying [blocked]. */
    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatusInt>(net, tmt).also {
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
    // one we just sent.
    // TODO: this is likely a bug. Fix it and remove this method.
    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        // Named so it does not shadow the `mark` property.
        val start = history.mark
        expectAvailableCallbacks(net, tmt = tmt)
        // Look up, starting at `start`, the CapabilitiesChanged seen by expectAvailableCallbacks.
        val firstCaps = history.poll(tmt, start) { it is CapabilitiesChanged }
        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
    }

    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
    // when a network connects and satisfies a callback, and then immediately validates.
    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        expectAvailableCallbacks(net, validated = false, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    /** As above, but expects the Int blocked status callback carrying [blockedStatus]. */
    fun expectAvailableThenValidatedCallbacks(
        net: Network,
        blockedStatus: Int,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacks(net, validated = false, suspended = false,
                blockedStatus = blockedStatus, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
    // TODO: clean this up, remove this method.
    interface HasNetwork {
        val network: Network
    }

    /**
     * Expects the next callback to be of [type] on network [n] ([ANY_NETWORK] matches any
     * network; null means [NULL_NETWORK]) and returns it.
     */
    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: Network?,
        timeoutMs: Long = defaultTimeoutMs
    ) = pollForNextCallback(timeoutMs).also {
        val network = n ?: NULL_NETWORK
        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
        // this writing this would be the only use of this library in the tests.
        assertTrue(type.java.isInstance(it) && (ANY_NETWORK === n || it.network == network),
                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
    } as T

    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: HasNetwork?,
        timeoutMs: Long = defaultTimeoutMs
    ) = expectCallback(type, n?.network, timeoutMs)

    fun expectAvailableCallbacks(
        n: HasNetwork,
        suspended: Boolean,
        validated: Boolean,
        blocked: Boolean,
        timeoutMs: Long
    ) = expectAvailableCallbacks(n.network, suspended = suspended, validated = validated,
            blocked = blocked, tmt = timeoutMs)

    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
    }

    fun expectAvailableCallbacksValidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network)
    }

    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, blocked = true)
    }

    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false)
    }

    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false, blocked = true)
    }

    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    @JvmOverloads
    fun expectLinkPropertiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ) = expectLinkPropertiesThat(n.network, tmt, valid)

    @JvmOverloads
    fun expectCapabilitiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ) = expectCapabilitiesThat(n.network, tmt, valid)

    /** Expects onCapabilitiesChanged on [n] that has [capability]; returns those capabilities. */
    @JvmOverloads
    fun expectCapabilitiesWith(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
    }

    /** Expects onCapabilitiesChanged on [n] lacking [capability]; returns those capabilities. */
    @JvmOverloads
    fun expectCapabilitiesWithout(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
    }

    fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
        expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
    }
}
476