1 /*
<lambda>null2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.ui.scrollcapture
18 
19 import android.graphics.Bitmap
20 import android.graphics.Bitmap.Config.ARGB_8888
21 import android.graphics.ColorSpace
22 import android.graphics.PixelFormat
23 import android.graphics.Point
24 import android.graphics.Rect
25 import android.hardware.HardwareBuffer.USAGE_GPU_COLOR_OUTPUT
26 import android.hardware.HardwareBuffer.USAGE_GPU_SAMPLED_IMAGE
27 import android.media.Image
28 import android.media.ImageReader
29 import android.os.CancellationSignal
30 import android.os.Handler
31 import android.os.Looper
32 import android.view.ScrollCaptureCallback
33 import android.view.ScrollCaptureSession
34 import android.view.ScrollCaptureTarget
35 import android.view.Surface
36 import android.view.View
37 import androidx.annotation.RequiresApi
38 import androidx.compose.runtime.Composable
39 import androidx.compose.ui.geometry.Offset
40 import androidx.compose.ui.internal.checkPreconditionNotNull
41 import androidx.compose.ui.internal.requirePrecondition
42 import androidx.compose.ui.platform.AndroidComposeView
43 import androidx.compose.ui.platform.AndroidUiDispatcher
44 import androidx.compose.ui.platform.LocalView
45 import androidx.compose.ui.test.junit4.ComposeContentTestRule
46 import java.util.concurrent.CountDownLatch
47 import kotlin.coroutines.resume
48 import kotlin.math.roundToInt
49 import kotlin.test.fail
50 import kotlinx.coroutines.CancellableContinuation
51 import kotlinx.coroutines.CoroutineScope
52 import kotlinx.coroutines.CoroutineStart
53 import kotlinx.coroutines.ExperimentalCoroutinesApi
54 import kotlinx.coroutines.channels.Channel
55 import kotlinx.coroutines.channels.ReceiveChannel
56 import kotlinx.coroutines.coroutineScope
57 import kotlinx.coroutines.currentCoroutineContext
58 import kotlinx.coroutines.ensureActive
59 import kotlinx.coroutines.launch
60 import kotlinx.coroutines.selects.onTimeout
61 import kotlinx.coroutines.selects.select
62 import kotlinx.coroutines.suspendCancellableCoroutine
63 import kotlinx.coroutines.withContext
64 
/**
 * Helps tests pretend to be the Android platform performing scroll capture search and image
 * capture. Tests must call [setContent] on this class instead of on [rule], and the entire test
 * should be run in the coroutine started by the [runTest] method on this class.
 */
@RequiresApi(31)
class ScrollCaptureTester(private val rule: ComposeContentTestRule) {

    /** Operations available to the block passed to [capture] while a capture session is open. */
    interface CaptureSessionScope {
        /** Height of the capture window, as passed to [capture]. */
        val windowHeight: Int

        /**
         * Requests an image for the current capture window from the target's callback and returns
         * the area the callback reported together with the bitmap it drew, if any.
         */
        suspend fun performCapture(): CaptureResult

        /** Moves the capture window vertically by [offset]. Negative values move it up. */
        fun shiftWindowBy(offset: Int)
    }

    /**
     * Result of a single [CaptureSessionScope.performCapture] call.
     *
     * @property bitmap The image drawn for this request, or null if [capturedRect] is empty.
     * @property capturedRect The area the callback reported as actually captured.
     */
    class CaptureResult(val bitmap: Bitmap?, val capturedRect: Rect)

    /** The host view of the composition, recorded by [setContent]. */
    private var view: View? = null

    /**
     * Sets [content] on [rule] and remembers the composition's host view so that
     * [findCaptureTargets] can search it later.
     */
    fun setContent(content: @Composable () -> Unit) {
        rule.setContent {
            this.view = LocalView.current
            content()
        }
    }

    /**
     * Workaround for standard kotlin runTest because it deadlocks when a composition coroutine
     * calls `delay()`.
     *
     * Launches [block] on [AndroidUiDispatcher.Main] and blocks the calling thread, via
     * [ComposeContentTestRule.waitUntil], until the block has finished. Any exception thrown by
     * [block] is rethrown to the caller.
     *
     * @param timeoutMillis How long to wait for [block] to finish before failing.
     * @param block The test body.
     */
    fun runTest(timeoutMillis: Long = 5_000, block: suspend CoroutineScope.() -> Unit) {
        val scope = CoroutineScope(AndroidUiDispatcher.Main)
        val latch = CountDownLatch(1)
        var result: Result<Unit>? = null
        val job =
            scope.launch {
                result = runCatching { block() }
                latch.countDown()
            }
        try {
            // Poll the latch rather than `result` directly: the latch is counted down after
            // `result` is written, so observing a zero count also guarantees the write to
            // `result` is visible on this thread.
            rule.waitUntil("Test coroutine completed", timeoutMillis) { latch.count == 0L }
        } catch (e: Throwable) {
            // Don't leave the test coroutine running on the main dispatcher after the test has
            // already failed (e.g. on timeout).
            job.cancel()
            throw e
        }
        val outcome = checkNotNull(result) { "Test coroutine completed without a result." }
        return outcome.getOrThrow()
    }

    /**
     * Calls [View.onScrollCaptureSearch] on the Compose host view, which searches the composition
     * from [setContent] for scroll containers, and returns all the [ScrollCaptureTarget]s produced
     * that would be given to the platform in production.
     */
    suspend fun findCaptureTargets(): List<ScrollCaptureTarget> {
        rule.awaitIdle()
        return withContext(AndroidUiDispatcher.Main) {
            val view =
                checkNotNull(view as? AndroidComposeView) {
                    "Must call setContent on ScrollCaptureTester before capturing."
                }
            val localVisibleRect = Rect().also(view::getLocalVisibleRect)
            val windowOffset = view.calculatePositionInWindow(Offset.Zero).roundToPoint()
            val targets = mutableListOf<ScrollCaptureTarget>()
            view.onScrollCaptureSearch(localVisibleRect, windowOffset, targets::add)
            targets
        }
    }

    /**
     * Runs a capture session. [block] should call methods on [CaptureSessionScope] to incrementally
     * capture bitmaps of [target].
     *
     * The session is started before [block] runs and ended after [block] returns normally.
     *
     * @param target The target whose callback is driven by this session.
     * @param captureWindowHeight The height of the capture window. Must not be greater than
     *   viewport height.
     * @param block The body of the session.
     * @return Whatever [block] returns.
     */
    suspend fun <T> capture(
        target: ScrollCaptureTarget,
        captureWindowHeight: Int,
        block: suspend CaptureSessionScope.() -> T
    ): T =
        withContext(AndroidUiDispatcher.Main) {
            val callback = target.callback
            // Use the bounds returned from the callback, not the ones from the target, because
            // that's what the system does.
            val scrollBounds = callback.onScrollCaptureSearch()
            val captureWidth = scrollBounds.width()
            requirePrecondition(captureWindowHeight <= scrollBounds.height()) {
                "Expected windowSize ($captureWindowHeight) ≤ viewport height " +
                    "(${scrollBounds.height()})"
            }

            val result =
                withSurfaceBitmaps(captureWidth, captureWindowHeight) { surface, bitmapsFromSurface
                    ->
                    val session =
                        ScrollCaptureSession(surface, scrollBounds, target.positionInWindow)
                    callback.onScrollCaptureStart(session)

                    block(
                        object : CaptureSessionScope {
                            // Top-left corner of the capture window; only y is ever shifted.
                            private var captureOffset = Point(0, 0)

                            override val windowHeight: Int
                                get() = captureWindowHeight

                            override fun shiftWindowBy(offset: Int) {
                                captureOffset = Point(0, captureOffset.y + offset)
                            }

                            override suspend fun performCapture(): CaptureResult {
                                val requestedCaptureArea =
                                    Rect(
                                        captureOffset.x,
                                        captureOffset.y,
                                        captureOffset.x + captureWidth,
                                        captureOffset.y + captureWindowHeight
                                    )
                                val resultCaptureArea =
                                    callback.onScrollCaptureImageRequest(
                                        session,
                                        requestedCaptureArea
                                    )

                                // Empty results shouldn't produce an image.
                                val bitmap =
                                    if (!resultCaptureArea.isEmpty) {
                                        bitmapsFromSurface.receiveWithTimeout(1_000) {
                                            "No bitmap received after 1 second for capture area " +
                                                resultCaptureArea
                                        }
                                    } else {
                                        null
                                    }
                                return CaptureResult(
                                    bitmap = bitmap,
                                    capturedRect = resultCaptureArea
                                )
                            }
                        }
                    )
                }
            callback.onScrollCaptureEnd()
            return@withContext result
        }

    /**
     * Creates a [Surface] and passes it to [block] along with a channel that will receive all
     * images written to the [Surface]. The backing [ImageReader] is closed when [block] finishes.
     *
     * @param width Width of the images produced by the surface.
     * @param height Height of the images produced by the surface.
     */
    private suspend inline fun <T> withSurfaceBitmaps(
        width: Int,
        height: Int,
        crossinline block: suspend (Surface, ReceiveChannel<Bitmap>) -> T
    ): T = coroutineScope {
        // ImageReader gives us the Surface that we'll provide to the session.
        ImageReader.newInstance(
                width,
                height,
                PixelFormat.RGBA_8888,
                // Each image is read, processed, and closed before the next request to draw is
                // made, so we don't need multiple images.
                /* maxImages= */ 1,
                USAGE_GPU_SAMPLED_IMAGE or USAGE_GPU_COLOR_OUTPUT
            )
            .use { imageReader ->
                // If the collector is cancelled while handing a bitmap over, nobody will ever
                // receive it, so recycle it instead of leaking its pixels.
                val bitmapsChannel =
                    Channel<Bitmap>(
                        capacity = Channel.RENDEZVOUS,
                        onUndeliveredElement = { it.recycle() }
                    )

                // Must register the OnImageAvailableListener before any code in block runs to avoid
                // race conditions.
                val imageCollectorJob =
                    launch(start = CoroutineStart.UNDISPATCHED) {
                        imageReader.collectImages {
                            val bitmap = it.toSoftwareBitmap()
                            bitmapsChannel.send(bitmap)
                        }
                    }

                try {
                    block(imageReader.surface, bitmapsChannel)
                } finally {
                    // ImageReader has no signal that it's finished, so in the happy path we have to
                    // stop the collector job explicitly.
                    imageCollectorJob.cancel()
                    bitmapsChannel.close()
                }
            }
    }

    /**
     * Reads all images from this [ImageReader] and passes them to [onImage]. The [Image] will
     * automatically be closed when [onImage] returns.
     *
     * Propagates backpressure to the [ImageReader] – only one image will be acquired from the
     * [ImageReader] at a time, and the next image won't be acquired until [onImage] returns.
     *
     * Never returns normally; it runs until the calling coroutine is cancelled, at which point the
     * image-available listener is removed.
     */
    private suspend inline fun ImageReader.collectImages(onImage: (Image) -> Unit): Nothing {
        // Conflated: we only need to know that at least one image became available.
        val imageAvailableChannel = Channel<Unit>(capacity = Channel.CONFLATED)
        setOnImageAvailableListener(
            { imageAvailableChannel.trySend(Unit) },
            Handler(Looper.getMainLooper())
        )
        val context = currentCoroutineContext()

        try {
            // Read all images until cancelled.
            while (true) {
                context.ensureActive()
                // Fast path – if an image is immediately available, don't suspend.
                var image: Image? = acquireNextImage()
                // If no image was available, suspend until the callback fires.
                while (image == null) {
                    imageAvailableChannel.receive()
                    image = acquireNextImage()
                }
                image.use { onImage(image) }
            }
        } finally {
            setOnImageAvailableListener(null, null)
        }
    }

    /**
     * Helper function for converting an [Image] to a [Bitmap] by copying the hardware buffer into a
     * software bitmap. The hardware buffer and the intermediate hardware bitmap are released
     * before returning.
     */
    private fun Image.toSoftwareBitmap(): Bitmap {
        val buffer = checkPreconditionNotNull(hardwareBuffer) { "No hardware buffer" }
        return buffer.use {
            val hardwareBitmap =
                Bitmap.wrapHardwareBuffer(buffer, ColorSpace.get(ColorSpace.Named.SRGB))
                    ?: error("wrapHardwareBuffer returned null")
            try {
                // Bitmap.copy is documented to return null when the copy can't be made; fail with
                // a clear message instead of an implicit null check.
                checkNotNull(hardwareBitmap.copy(ARGB_8888, false)) {
                    "Failed to copy hardware bitmap into a software bitmap"
                }
            } finally {
                hardwareBitmap.recycle()
            }
        }
    }

    /**
     * Receives the next element from this channel, or fails the test with [timeoutMessage] if
     * nothing arrives within [timeoutMillis].
     */
    @OptIn(ExperimentalCoroutinesApi::class)
    private suspend inline fun <E> ReceiveChannel<E>.receiveWithTimeout(
        timeoutMillis: Long,
        crossinline timeoutMessage: () -> String
    ): E = select {
        onReceive { it }
        onTimeout(timeoutMillis) { fail(timeoutMessage()) }
    }
}
310 
/**
 * Roughly emulates the way the platform drives a [ScrollCaptureCallback] to assemble a screenshot
 * of everything inside [target], one capture window at a time. Unlike the platform there is no
 * upper bound on the total size: the whole scroll content is always captured, so tests should keep
 * the content small enough that holding every bitmap in memory is not a problem.
 *
 * @param target The scroll capture target to capture.
 * @param captureHeight The height of the capture window. Must not be greater than viewport height.
 * @return The bitmaps produced by [captureAllFromTop], in the order they were captured.
 */
@RequiresApi(31)
suspend fun ScrollCaptureTester.captureBitmapsVertically(
    target: ScrollCaptureTarget,
    captureHeight: Int
): List<Bitmap> {
    return capture(target, captureHeight) {
        buildList<Bitmap> { captureAllFromTop { bitmap -> add(bitmap) } }
    }
}
324 
/**
 * Performs a single capture, recycles the resulting bitmap (if there is one) and returns only the
 * area that was reported as captured.
 */
@RequiresApi(31)
internal suspend fun ScrollCaptureTester.CaptureSessionScope.performCaptureDiscardingBitmap() =
    with(performCapture()) {
        bitmap?.recycle()
        capturedRect
    }
328 
/**
 * Captures the whole scroll content from top to bottom.
 *
 * Beginning at the current capture window, the window is first moved upwards one window height at
 * a time until a capture comes back shorter than the window, which marks the top. The window is
 * then aligned with the top of the content and moved downwards, handing every bitmap captured on
 * the way down to [onBitmap], until a capture again comes back shorter than the window.
 *
 * Bitmaps captured while searching for the top are recycled and never passed to [onBitmap].
 *
 * @param onBitmap Receives each bitmap captured on the way down, in top-to-bottom order.
 */
@RequiresApi(31)
suspend fun ScrollCaptureTester.CaptureSessionScope.captureAllFromTop(
    onBitmap: suspend (Bitmap) -> Unit
) {
    // False while still searching for the top, true once we've bounced and are collecting.
    var movingDown = false
    while (true) {
        val capture = performCapture()
        capture.bitmap?.let { image ->
            // Only the downward pass contributes images; anything else is thrown away.
            if (movingDown) onBitmap(image) else image.recycle()
        }

        val capturedHeight = capture.capturedRect.height()
        val hitEdge = capturedHeight < windowHeight
        when {
            // Short capture on the way down: that was the bottom, nothing left to do.
            hitEdge && movingDown -> return
            // Short capture on the way up: "bounce" off the top by reversing direction and
            // lining the window up with the top of the content.
            hitEdge -> {
                movingDown = true
                shiftWindowBy(windowHeight - capturedHeight)
            }
            // Full capture: advance one window in the current direction and go again.
            movingDown -> shiftWindowBy(windowHeight)
            else -> shiftWindowBy(-windowHeight)
        }
    }
}
371 
/**
 * Suspending wrapper around [ScrollCaptureCallback.onScrollCaptureSearch]. The
 * [CancellationSignal] handed to the callback is tied to the calling coroutine, and the coroutine
 * resumes with the [Rect] the callback reports.
 */
@RequiresApi(31)
suspend fun ScrollCaptureCallback.onScrollCaptureSearch(): Rect =
    suspendCancellableCoroutine { cont ->
        val signal = cont.createCancellationSignal()
        onScrollCaptureSearch(signal) { bounds -> cont.resume(bounds) }
    }
381 
/**
 * Suspending wrapper around [ScrollCaptureCallback.onScrollCaptureStart]. The
 * [CancellationSignal] handed to the callback is tied to the calling coroutine, which resumes once
 * the callback signals that it is ready.
 *
 * @param session The session being started.
 */
@RequiresApi(31)
suspend fun ScrollCaptureCallback.onScrollCaptureStart(session: ScrollCaptureSession) {
    suspendCancellableCoroutine { cont ->
        val signal = cont.createCancellationSignal()
        onScrollCaptureStart(session, signal) { cont.resume(Unit) }
    }
}
394 
/**
 * Suspending wrapper around [ScrollCaptureCallback.onScrollCaptureImageRequest]. The
 * [CancellationSignal] handed to the callback is tied to the calling coroutine, and the coroutine
 * resumes with the [Rect] the callback reports for the request.
 *
 * @param session The active capture session.
 * @param captureArea The area being requested.
 */
@RequiresApi(31)
suspend fun ScrollCaptureCallback.onScrollCaptureImageRequest(
    session: ScrollCaptureSession,
    captureArea: Rect
): Rect = suspendCancellableCoroutine { cont ->
    val signal = cont.createCancellationSignal()
    onScrollCaptureImageRequest(session, signal, captureArea) { captured ->
        cont.resume(captured)
    }
}
408 
/**
 * Suspending wrapper around [ScrollCaptureCallback.onScrollCaptureEnd]. The coroutine resumes when
 * the callback runs the completion runnable. No [CancellationSignal] is passed to the callback by
 * this wrapper.
 */
@RequiresApi(31)
suspend fun ScrollCaptureCallback.onScrollCaptureEnd() {
    suspendCancellableCoroutine { cont ->
        onScrollCaptureEnd { cont.resume(Unit) }
    }
}
417 
/** Converts this [Offset] to a [Point] by rounding each coordinate to the nearest integer. */
fun Offset.roundToPoint(): Point {
    val roundedX = x.roundToInt()
    val roundedY = y.roundToInt()
    return Point(roundedX, roundedY)
}
419 
/**
 * Builds a [CancellationSignal] whose cancellation is linked in both directions with this
 * continuation: cancelling the signal cancels the continuation, and cancelling the continuation
 * cancels the signal.
 */
private fun CancellableContinuation<*>.createCancellationSignal(): CancellationSignal {
    val continuation = this
    return CancellationSignal().also { signal ->
        // Signal -> coroutine.
        signal.setOnCancelListener { continuation.cancel() }
        // Coroutine -> signal.
        continuation.invokeOnCancellation { signal.cancel() }
    }
}
430