1 /*
2  * Copyright 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.test.screenshot.matchers
18 
19 import android.graphics.Color
20 import androidx.annotation.FloatRange
21 import kotlin.math.pow
22 
/**
 * Image comparison using the Structural Similarity Index (SSIM), developed by Wang, Bovik, Sheikh,
 * and Simoncelli. Details can be read in their paper:
 * https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
 *
 * The image is split into non-overlapping windows of at most `WINDOW_SIZE` x `WINDOW_SIZE` pixels.
 * An SSIM value is computed for every window that is not entirely white in both images, and the
 * mean of those values is compared against [threshold].
 *
 * @param threshold the minimum mean SSIM at which two images are reported as matching.
 */
class MSSIMMatcher(@FloatRange(from = 0.0, to = 1.0) private val threshold: Double = 0.98) :
    BitmapMatcher {

    /**
     * Compares two images given as pixel arrays of the same [width] and [height].
     *
     * @param expected the pixels of the reference image.
     * @param given the pixels of the image under test.
     * @param width the width of both images, in pixels.
     * @param height the height of both images, in pixels.
     * @return a [MatchResult] that matches when the mean SSIM is at least [threshold]. On a
     *   mismatch the diff is the one produced by [PixelPerfectMatcher]; the statistics string
     *   always reports the required and the actual SSIM.
     */
    override fun compareBitmaps(
        expected: IntArray,
        given: IntArray,
        width: Int,
        height: Int
    ): MatchResult {
        val meanSsim = calculateSSIM(expected, given, width, height)

        // Use a fixed locale so the statistics look the same on every device. The threshold is
        // rendered through Double.toString(), which is locale-independent, so a default-locale
        // "%.3f" could otherwise use a different decimal separator within the same message.
        val formattedSsim = "%.3f".format(java.util.Locale.ROOT, meanSsim)
        val stats = "[MSSIM] Required SSIM: $threshold, Actual SSIM: $formattedSsim"

        if (meanSsim >= threshold) {
            return MatchResult(matches = true, diff = null, comparisonStatistics = stats)
        }

        // The images are too different: run the pixel-perfect matcher only to obtain a diff.
        val pixelResult = PixelPerfectMatcher().compareBitmaps(expected, given, width, height)
        return MatchResult(matches = false, diff = pixelResult.diff, comparisonStatistics = stats)
    }

    /**
     * Computes the mean SSIM of two images whose rows are stored back to back, i.e. with a stride
     * equal to [width] and no leading offset.
     */
    internal fun calculateSSIM(ideal: IntArray, given: IntArray, width: Int, height: Int): Double =
        calculateSSIM(ideal, given, offset = 0, stride = width, width = width, height = height)

    /**
     * Computes the mean SSIM over all windows of the two images.
     *
     * @param offset the array index of the first pixel of the image.
     * @param stride the distance, in array elements, between the starts of two consecutive rows.
     * @return the mean of the per-window SSIM values, or 1.0 when no window was scored (the image
     *   is empty or every window is entirely white in both images).
     */
    private fun calculateSSIM(
        ideal: IntArray,
        given: IntArray,
        offset: Int,
        stride: Int,
        width: Int,
        height: Int
    ): Double {
        var ssimSum = 0.0
        var scoredWindows = 0
        for (windowY in 0 until height step WINDOW_SIZE) {
            val windowHeight = computeWindowSize(windowY, height)
            for (windowX in 0 until width step WINDOW_SIZE) {
                val windowWidth = computeWindowSize(windowX, width)
                val start = indexFromXAndY(windowX, windowY, stride, offset)

                // Windows that are completely white in both images are skipped and do not
                // contribute to the mean.
                val bothWhite =
                    isWindowWhite(ideal, start, stride, windowWidth, windowHeight) &&
                        isWindowWhite(given, start, stride, windowWidth, windowHeight)
                if (bothWhite) continue

                val (meanX, meanY) =
                    getMeans(ideal, given, start, stride, windowWidth, windowHeight)
                val (varX, varY, covXY) =
                    getVariances(
                        ideal,
                        given,
                        meanX,
                        meanY,
                        start,
                        stride,
                        windowWidth,
                        windowHeight
                    )
                ssimSum += ssim(meanX, meanY, varX, varY, covXY)
                scoredWindows++
            }
        }
        return if (scoredWindows == 0) 1.0 else ssimSum / scoredWindows
    }

    /**
     * Compute the size of the window. The window defaults to WINDOW_SIZE, but must be contained
     * within dimension.
     */
    private fun computeWindowSize(coordinateStart: Int, dimension: Int): Int =
        minOf(WINDOW_SIZE, dimension - coordinateStart)

    /**
     * Returns true when every pixel of the window starting at index [start] is exactly
     * [Color.WHITE].
     */
    private fun isWindowWhite(
        colors: IntArray,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): Boolean {
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                if (colors[indexFromXAndY(x, y, stride, start)] != Color.WHITE) {
                    return false
                }
            }
        }
        return true
    }

    /**
     * This calculates the position in an array that would represent a bitmap given the parameters.
     */
    private fun indexFromXAndY(x: Int, y: Int, stride: Int, offset: Int): Int =
        x + y * stride + offset

    /**
     * Evaluates the SSIM formula for a single window.
     *
     * @param muX the mean intensity of the first window.
     * @param muY the mean intensity of the second window.
     * @param sigX the variance of the first window.
     * @param sigY the variance of the second window.
     * @param sigXY the covariance of the two windows.
     */
    private fun ssim(muX: Double, muY: Double, sigX: Double, sigY: Double, sigXY: Double): Double {
        val numerator = (2 * muX * muY + CONSTANT_C1) * (2 * sigXY + CONSTANT_C2)
        val denominator = (muX * muX + muY * muY + CONSTANT_C1) * (sigX + sigY + CONSTANT_C2)
        return numerator / denominator
    }

    /**
     * This method will find the mean of a window in both sets of pixels. The return is an array
     * where the first double is the mean of the first set and the second double is the mean of the
     * second set.
     */
    private fun getMeans(
        pixels0: IntArray,
        pixels1: IntArray,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): DoubleArray {
        var sum0 = 0.0
        var sum1 = 0.0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                val index = indexFromXAndY(x, y, stride, start)
                sum0 += getIntensity(pixels0[index])
                sum1 += getIntensity(pixels1[index])
            }
        }
        val pixelCount = windowWidth * windowHeight.toDouble()
        return doubleArrayOf(sum0 / pixelCount, sum1 / pixelCount)
    }

    /**
     * Finds the variance of the two sets of pixels, as well as the covariance of the windows. The
     * return value is an array of doubles, the first is the variance of the first set of pixels,
     * the second is the variance of the second set of pixels, and the third is the covariance.
     */
    private fun getVariances(
        pixels0: IntArray,
        pixels1: IntArray,
        mean0: Double,
        mean1: Double,
        start: Int,
        stride: Int,
        windowWidth: Int,
        windowHeight: Int
    ): DoubleArray {
        if (windowHeight == 1 && windowWidth == 1) {
            // There is only one item. The variance of a single item would be 0.
            // Since Bessel's correction is used below, it will return NaN instead of 0.
            return doubleArrayOf(0.0, 0.0, 0.0)
        }

        var squares0 = 0.0
        var squares1 = 0.0
        var products = 0.0
        for (y in 0 until windowHeight) {
            for (x in 0 until windowWidth) {
                val index = indexFromXAndY(x, y, stride, start)
                val deviation0 = getIntensity(pixels0[index]) - mean0
                val deviation1 = getIntensity(pixels1[index]) - mean1
                squares0 += deviation0 * deviation0
                squares1 += deviation1 * deviation1
                products += deviation0 * deviation1
            }
        }
        // Using Bessel's correction. Hence, subtracting one.
        val denominatorWithBesselsCorrection = windowWidth * windowHeight - 1.0
        return doubleArrayOf(
            squares0 / denominatorWithBesselsCorrection,
            squares1 / denominatorWithBesselsCorrection,
            products / denominatorWithBesselsCorrection
        )
    }

    /**
     * Gets the intensity of a given pixel in RGB using luminosity formula
     *
     * l = 0.21R' + 0.72G' + 0.07B'
     *
     * The prime symbols dictate a gamma correction of 1, which leaves each channel unchanged, so
     * no exponentiation is performed.
     */
    private fun getIntensity(pixel: Int): Double {
        val red = Color.red(pixel) / MAX_CHANNEL_VALUE
        val green = Color.green(pixel) / MAX_CHANNEL_VALUE
        val blue = Color.blue(pixel) / MAX_CHANNEL_VALUE
        // The weights are deliberately Float literals: widening them to Double gives values that
        // differ slightly from 0.21/0.72/0.07, and replacing them with Double literals would
        // slightly change every computed SSIM.
        return 0.21f * red + 0.72f * green + 0.07f * blue
    }

    companion object {
        // These values were taken from the publication
        // NOTE(review): the paper is usually quoted with K1 = 0.01, K2 = 0.03 and L equal to the
        // dynamic range of the pixel values; the constants below differ from that. Confirm before
        // changing them, because every SSIM value and therefore every threshold depends on them.
        private const val CONSTANT_L = 254.0
        private const val CONSTANT_K1 = 0.00001
        private const val CONSTANT_K2 = 0.00003
        private val CONSTANT_C1 = (CONSTANT_L * CONSTANT_K1).pow(2.0)
        private val CONSTANT_C2 = (CONSTANT_L * CONSTANT_K2).pow(2.0)

        // Side length, in pixels, of the square windows the image is split into.
        private const val WINDOW_SIZE = 10

        // Largest value of a colour channel; used to scale channels into [0, 1].
        private const val MAX_CHANNEL_VALUE = 255.0
    }
}
247