• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package platform.test.screenshot.matchers
18 
19 import android.graphics.Color
20 import android.graphics.Rect
21 import androidx.annotation.FloatRange
22 import kotlin.collections.List
23 import kotlin.math.pow
24 import platform.test.screenshot.proto.ScreenshotResultProto
25 
/**
 * Image comparison using the Structural Similarity Index (SSIM), developed by Wang, Bovik,
 * Sheikh, and Simoncelli. Details can be read in their paper:
 * https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
 *
 * The bitmaps are split into non-overlapping square windows of [WINDOW_SIZE] pixels (clipped at
 * the right and bottom edges). SSIM is computed per window on the pixel intensities and then
 * averaged, weighted by the number of pixels compared in each window. The bitmaps match when the
 * derived number of "similar" pixels is at least [threshold] times the number of compared pixels.
 *
 * @param threshold Fraction, in [0, 1], of the compared pixels that must be similar for a match.
 */
class MSSIMMatcher(
    @FloatRange(from = 0.0, to = 1.0) private val threshold: Double = 0.98
) : BitmapMatcher() {

    /**
     * Compares [expected] against [given] using SSIM. Only pixels whose entry in the filter built
     * by [getFilter] from [regions] is non-zero take part in the comparison.
     *
     * When the bitmaps do not match, the diff image is produced by [PixelPerfectMatcher].
     */
    override fun compareBitmaps(
        expected: IntArray,
        given: IntArray,
        width: Int,
        height: Int,
        regions: List<Rect>
    ): MatchResult {
        val filter = getFilter(width, height, regions)
        val ssimResult = calculateSSIM(expected, given, width, height, filter)

        val stats = ScreenshotResultProto.DiffResult.ComparisonStatistics
            .newBuilder()
            .setNumberPixelsCompared(ssimResult.numPixelsCompared)
            .setNumberPixelsSimilar(ssimResult.numPixelsSimilar)
            .setNumberPixelsIgnored(ssimResult.numPixelsIgnored)
            .setNumberPixelsDifferent(ssimResult.numPixelsCompared - ssimResult.numPixelsSimilar)
            .build()

        if (ssimResult.numPixelsSimilar >= threshold * ssimResult.numPixelsCompared.toDouble()) {
            return MatchResult(
                matches = true,
                diff = null,
                comparisonStatistics = stats
            )
        }

        // SSIM does not yield a per-pixel diff image, so the pixel-perfect matcher is used to
        // visualize where the bitmaps differ.
        val pixelPerfectResult = PixelPerfectMatcher()
            .compareBitmaps(expected, given, width, height, regions)
        return MatchResult(
            matches = false,
            diff = pixelPerfectResult.diff,
            comparisonStatistics = stats
        )
    }

    /**
     * Computes the SSIM of [ideal] versus [given], both [width] x [height] bitmaps stored row by
     * row with a stride of [width]. Pixels whose [filter] entry is 0 are excluded.
     */
    internal fun calculateSSIM(
        ideal: IntArray,
        given: IntArray,
        width: Int,
        height: Int,
        filter: IntArray
    ): SSIMResult = calculateSSIM(ideal, given, 0, width, width, height, filter)

    /**
     * Computes the pixel-weighted average SSIM over all windows of the [width] x [height] area
     * that begins at [offset] in arrays laid out with the given [stride].
     *
     * A window is skipped (all its pixels counted as ignored) when, in both bitmaps, every
     * non-filtered pixel is white. In compared windows, filtered-out pixels are counted as
     * ignored as well, so `numPixelsCompared + numPixelsIgnored` covers the whole area.
     */
    private fun calculateSSIM(
        ideal: IntArray,
        given: IntArray,
        offset: Int,
        stride: Int,
        width: Int,
        height: Int,
        filter: IntArray
    ): SSIMResult {
        var weightedSSIMSum = 0.0
        var totalNumPixelsCompared = 0.0
        var numPixelsIgnored = 0

        for (windowY in 0 until height step WINDOW_SIZE) {
            val windowHeight = computeWindowSize(windowY, height)
            for (windowX in 0 until width step WINDOW_SIZE) {
                val windowWidth = computeWindowSize(windowX, width)
                val windowPixels = windowWidth * windowHeight
                val start = indexFromXAndY(windowX, windowY, stride, offset)

                if (shouldIgnoreWindow(ideal, start, stride, windowWidth, windowHeight, filter) &&
                    shouldIgnoreWindow(given, start, stride, windowWidth, windowHeight, filter)
                ) {
                    numPixelsIgnored += windowPixels
                    continue
                }

                val means = getMeans(ideal, given, filter, start, stride, windowWidth, windowHeight)
                val meanX = means[0]
                val meanY = means[1]
                val variances = getVariances(
                    ideal, given, filter, meanX, meanY, start, stride, windowWidth, windowHeight
                )
                val varX = variances[0]
                val varY = variances[1]
                val covariance = variances[2]
                val windowSSIM = SSIM(meanX, meanY, varX, varY, covariance)

                val numPixelsCompared = numPixelsToCompareInWindow(
                    start, stride, windowWidth, windowHeight, filter
                )
                // Filtered-out pixels of a compared window are ignored, not compared.
                numPixelsIgnored += windowPixels - numPixelsCompared
                weightedSSIMSum += windowSSIM * numPixelsCompared
                totalNumPixelsCompared += numPixelsCompared.toDouble()
            }
        }

        // When no pixel was compared (every window was ignored) nothing differs, so report a
        // perfect score instead of the NaN produced by 0.0 / 0.0.
        val averageSSIM = if (totalNumPixelsCompared > 0.0) {
            weightedSSIMSum / totalNumPixelsCompared
        } else {
            1.0
        }
        // SSIM is negative for anti-correlated windows; a negative number of similar pixels is
        // meaningless (it would report more different pixels than compared), so clamp it at 0.
        val numPixelsSimilar =
            (averageSSIM * totalNumPixelsCompared + 0.5).toInt().coerceAtLeast(0)
        return SSIMResult(
            SSIM = averageSSIM,
            numPixelsSimilar = numPixelsSimilar,
            numPixelsIgnored = numPixelsIgnored,
            numPixelsCompared = (totalNumPixelsCompared + 0.5).toInt()
        )
    }

    /**
     * Size of the window starting at [coordinateStart] along an axis of length [dimension]:
     * WINDOW_SIZE, clipped so that the window stays within [dimension].
     */
    private fun computeWindowSize(coordinateStart: Int, dimension: Int): Int =
        if (coordinateStart + WINDOW_SIZE <= dimension) {
            WINDOW_SIZE
        } else {
            dimension - coordinateStart
        }

    /**
     * Whether the pixel at ([x], [y]) relative to [start] is excluded, i.e. its [filter] entry
     * is zero.
     */
    private fun shouldIgnorePixel(
        x: Int,
        y: Int,
        start: Int,
        stride: Int,
        filter: IntArray
    ): Boolean = filter[indexFromXAndY(x, y, stride, start)] == 0

    /**
     * Whether a whole window of [colors] can be skipped: every pixel in it is either excluded by
     * [filter] or exactly [Color.WHITE].
     */
    private fun shouldIgnoreWindow(
        colors: IntArray,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int,
        filter: IntArray
    ): Boolean {
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (shouldIgnorePixel(x, y, start, stride, filter)) {
                    continue
                }
                if (colors[indexFromXAndY(x, y, stride, start)] != Color.WHITE) {
                    return false
                }
            }
        }
        return true
    }

    /** Number of pixels in the window that are not excluded by [filter]. */
    private fun numPixelsToCompareInWindow(
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int,
        filter: IntArray
    ): Int {
        var numPixelsToCompare = 0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (!shouldIgnorePixel(x, y, start, stride, filter)) {
                    numPixelsToCompare++
                }
            }
        }
        return numPixelsToCompare
    }

    /** Index of pixel ([x], [y]) in an array laid out row by row with [stride] from [offset]. */
    private fun indexFromXAndY(x: Int, y: Int, stride: Int, offset: Int): Int =
        x + y * stride + offset

    /**
     * SSIM of a single window, given the means [muX] and [muY], the variances [sigX] and [sigY]
     * (already squared), and the covariance [sigXY] of the two windows:
     *
     *     ((2·muX·muY + C1) · (2·sigXY + C2)) / ((muX² + muY² + C1) · (sigX + sigY + C2))
     */
    private fun SSIM(muX: Double, muY: Double, sigX: Double, sigY: Double, sigXY: Double): Double {
        val numerator = (2 * muX * muY + CONSTANT_C1) * (2 * sigXY + CONSTANT_C2)
        val denominator = (muX * muX + muY * muY + CONSTANT_C1) * (sigX + sigY + CONSTANT_C2)
        return numerator / denominator
    }

    /**
     * Mean intensity of the non-filtered pixels of a window in both bitmaps. Returns
     * `[mean of pixels0, mean of pixels1]`. Callers must ensure at least one pixel is not
     * filtered, otherwise the means are NaN.
     */
    private fun getMeans(
        pixels0: IntArray,
        pixels1: IntArray,
        filter: IntArray,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): DoubleArray {
        var avg0 = 0.0
        var avg1 = 0.0
        var numPixelsCounted = 0.0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (shouldIgnorePixel(x, y, start, stride, filter)) {
                    continue
                }
                val index = indexFromXAndY(x, y, stride, start)
                avg0 += getIntensity(pixels0[index])
                avg1 += getIntensity(pixels1[index])
                numPixelsCounted += 1.0
            }
        }
        avg0 /= numPixelsCounted
        avg1 /= numPixelsCounted
        return doubleArrayOf(avg0, avg1)
    }

    /**
     * Sample variances (divided by n - 1) of the non-filtered pixel intensities of a window in
     * both bitmaps, and their sample covariance. Returns
     * `[variance of pixels0, variance of pixels1, covariance]`; all three are 0 when at most one
     * pixel is counted.
     */
    private fun getVariances(
        pixels0: IntArray,
        pixels1: IntArray,
        filter: IntArray,
        mean0: Double,
        mean1: Double,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): DoubleArray {
        var var0 = 0.0
        var var1 = 0.0
        var varBoth = 0.0
        var numPixelsCounted = 0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (shouldIgnorePixel(x, y, start, stride, filter)) {
                    continue
                }
                val index = indexFromXAndY(x, y, stride, start)
                val v0 = getIntensity(pixels0[index]) - mean0
                val v1 = getIntensity(pixels1[index]) - mean1
                var0 += v0 * v0
                var1 += v1 * v1
                varBoth += v0 * v1
                numPixelsCounted += 1
            }
        }
        if (numPixelsCounted <= 1) {
            return doubleArrayOf(0.0, 0.0, 0.0)
        }
        val divisor = (numPixelsCounted - 1).toDouble()
        return doubleArrayOf(var0 / divisor, var1 / divisor, varBoth / divisor)
    }

    /**
     * Intensity of [pixel] in [0, 1], using the luminosity formula
     *
     * l = 0.21R' + 0.72G' + 0.07B'
     *
     * The prime symbols dictate a gamma correction of 1. The alpha channel is not used.
     */
    private fun getIntensity(pixel: Int): Double {
        val gamma = 1.0
        // The weights are Float literals on purpose: replacing them with Double literals would
        // slightly change the computed intensities and therefore existing comparison results.
        var l = 0.0
        l += 0.21f * (Color.red(pixel) / 255f.toDouble()).pow(gamma)
        l += 0.72f * (Color.green(pixel) / 255f.toDouble()).pow(gamma)
        l += 0.07f * (Color.blue(pixel) / 255f.toDouble()).pow(gamma)
        return l
    }

    companion object {
        // NOTE: The SSIM paper uses K1 = 0.01, K2 = 0.03 and L = 255 (the dynamic range of 8-bit
        // pixels). The values below differ from those, and the intensities fed into the formula
        // are normalized to [0, 1] by getIntensity. They are kept unchanged so that existing
        // thresholds keep producing the same results.
        private const val CONSTANT_L = 254.0
        private const val CONSTANT_K1 = 0.00001
        private const val CONSTANT_K2 = 0.00003
        private val CONSTANT_C1 = (CONSTANT_L * CONSTANT_K1).pow(2.0)
        private val CONSTANT_C2 = (CONSTANT_L * CONSTANT_K2).pow(2.0)

        /** Side length, in pixels, of the square windows SSIM is computed on. */
        private const val WINDOW_SIZE = 10
    }
}
330 
/**
 * Result of the calculation of SSIM.
 *
 * Declared as a data class so that results compare structurally and print their values, which
 * makes test failures readable.
 *
 * @param SSIM The average SSIM over the compared pixels, weighted by the number of pixels
 *   compared in each window.
 * @param numPixelsSimilar The number of similar pixels, derived from [SSIM] and
 *   [numPixelsCompared].
 * @param numPixelsIgnored The number of ignored pixels.
 * @param numPixelsCompared The number of compared pixels.
 */
data class SSIMResult(
    val SSIM: Double,
    val numPixelsSimilar: Int,
    val numPixelsIgnored: Int,
    val numPixelsCompared: Int
)
344