• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.mediav2.common.cts;
18 
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertTrue;
21 
22 import android.graphics.Rect;
23 import android.util.Log;
24 import android.util.Pair;
25 
26 import com.android.compatibility.common.util.Preconditions;
27 
28 import java.io.File;
29 import java.io.FileInputStream;
30 import java.io.IOException;
31 import java.io.InputStream;
32 import java.io.RandomAccessFile;
33 import java.nio.ByteBuffer;
34 import java.nio.ByteOrder;
35 import java.util.ArrayList;
36 import java.util.Arrays;
37 
38 /**
39  * Class to compute per-frame PSNR, minimum PSNR and global PSNR between two YUV420P yuv streams.
40  */
41 public class VideoErrorManager {
42     private static final String LOG_TAG = VideoErrorManager.class.getSimpleName();
43     private static final boolean ENABLE_LOGS = false;
44 
45     private final RawResource mRefYuv;
46     private final RawResource mTestYuv;
47     private final boolean mAllowLoopBack;
48 
49     private boolean mGenerateStats;
50     private final double[] mGlobalMSE;
51     private final double[] mMinimumMSE;
52     private final double[] mGlobalPSNR;
53     private final double[] mMinimumPSNR;
54     private final double[] mAvgPSNR;
55     private final ArrayList<double[]> mFramesPSNR;
56 
VideoErrorManager(RawResource refYuv, RawResource testYuv, boolean allowLoopBack)57     public VideoErrorManager(RawResource refYuv, RawResource testYuv, boolean allowLoopBack) {
58         mRefYuv = refYuv;
59         mTestYuv = testYuv;
60         mAllowLoopBack = allowLoopBack;
61         if (mRefYuv.mHeight != mTestYuv.mHeight || mRefYuv.mWidth != mTestYuv.mWidth
62                 || mRefYuv.mBytesPerSample != mTestYuv.mBytesPerSample) {
63             String msg = String.format(
64                     "Reference file attributes and Test file attributes are not same. Reference "
65                             + "width : %d, height : %d, bytesPerSample : %d, Test width : %d, "
66                             + "height : %d, bytesPerSample : %d \n",
67                     mRefYuv.mWidth, mRefYuv.mHeight, mRefYuv.mBytesPerSample, mTestYuv.mWidth,
68                     mTestYuv.mHeight, mTestYuv.mBytesPerSample);
69             throw new IllegalArgumentException(msg);
70         }
71         if (((mRefYuv.mWidth & 1) != 0) || ((mRefYuv.mHeight & 1) != 0) || (
72                 (mRefYuv.mBytesPerSample != 1) && (mRefYuv.mBytesPerSample != 2))) {
73             String msg = String.format(LOG_TAG
74                             + " handles only YUV420p 8bit or 16bit inputs. Current file "
75                             + "attributes are width : %d, height : %d, bytesPerSample : %d",
76                     mRefYuv.mWidth, mRefYuv.mHeight, mRefYuv.mBytesPerSample);
77             throw new IllegalArgumentException(msg);
78         }
79         mMinimumMSE = new double[3];
80         Arrays.fill(mMinimumMSE, Float.MAX_VALUE);
81         mGlobalMSE = new double[3];
82         Arrays.fill(mGlobalMSE, 0.0);
83         mGlobalPSNR = new double[3];
84         mMinimumPSNR = new double[3];
85         mAvgPSNR = new double[3];
86         Arrays.fill(mAvgPSNR, 0.0);
87         mFramesPSNR = new ArrayList<>();
88     }
89 
computeFrameVariance(int width, int height, T luma)90     public static <T> Pair<Double, Integer> computeFrameVariance(int width, int height, T luma) {
91         final int bSize = 16;
92         assertTrue("chosen block size is too large with respect to image dimensions",
93                 width > bSize && height > bSize);
94         double varianceSum = 0;
95         int blocks = 0;
96         for (int i = 0; i < height - bSize; i += bSize) {
97             for (int j = 0; j < width - bSize; j += bSize) {
98                 long sse = 0, sum = 0;
99                 int offset = i * width + j;
100                 for (int p = 0; p < bSize; p++) {
101                     for (int q = 0; q < bSize; q++) {
102                         int sample;
103                         if (luma instanceof byte[]) {
104                             sample = Byte.toUnsignedInt(((byte[]) luma)[offset + p * width + q]);
105                         } else if (luma instanceof short[]) {
106                             sample = Short.toUnsignedInt(((short[]) luma)[offset + p * width + q]);
107                             sample >>= 6;
108                         } else {
109                             throw new IllegalArgumentException("Unsupported data type");
110                         }
111                         sum += sample;
112                         sse += sample * sample;
113                     }
114                 }
115                 double meanOfSquares = ((double) sse) / (bSize * bSize);
116                 double mean = ((double) sum) / (bSize * bSize);
117                 double squareOfMean = mean * mean;
118                 double blockVariance = (meanOfSquares - squareOfMean);
119                 assertTrue("variance can't be negative", blockVariance >= 0.0f);
120                 varianceSum += blockVariance;
121                 assertTrue("caution overflow", varianceSum >= 0.0);
122                 blocks++;
123             }
124         }
125         return Pair.create(varianceSum, blocks);
126     }
127 
computeMSE(byte[] data0, byte[] data1, int bytesPerSample, int imgWidth, int imgHeight, Rect cropRect)128     static double computeMSE(byte[] data0, byte[] data1, int bytesPerSample, int imgWidth,
129             int imgHeight, Rect cropRect) {
130         assertEquals(data0.length, data1.length);
131         int length = data0.length / bytesPerSample;
132         long squareError = 0;
133         int cropLeft = 0;
134         int cropTop = 0;
135         int cropWidth = imgWidth;
136         int cropHeight = imgHeight;
137         if (cropRect != null) {
138             cropLeft = cropRect.left;
139             cropTop = cropRect.top;
140             cropWidth = cropRect.width();
141             cropHeight = cropRect.height();
142         }
143 
144         if (bytesPerSample == 2) {
145             short[] dataA = new short[length];
146             ByteBuffer.wrap(data0).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(dataA);
147             short[] dataB = new short[length];
148             ByteBuffer.wrap(data1).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(dataB);
149             for (int h = 0; h < cropHeight; h++) {
150                 int offset = (cropTop + h) * imgWidth + cropLeft;
151                 for (int w = 0; w < cropWidth; w++) {
152                     long diff = (long) ((int) dataA[offset + w] & 0xffff) - ((int) dataB[offset + w]
153                             & 0xffff);
154                     squareError += diff * diff;
155                 }
156             }
157         } else {
158             for (int h = 0; h < cropHeight; h++) {
159                 int offset = (cropTop + h) * imgWidth + cropLeft;
160                 for (int w = 0; w < cropWidth; w++) {
161                     int diff = ((int) data0[offset + w] & 0xff) - ((int) data1[offset + w] & 0xff);
162                     squareError += diff * ((long) diff);
163                 }
164             }
165         }
166         return (double) squareError / (cropWidth * cropHeight);
167     }
168 
computePSNR(double mse, int bytesPerSample)169     static double computePSNR(double mse, int bytesPerSample) {
170         if (mse == 0) return 100.0;
171         final int peakSignal = (1 << (8 * bytesPerSample)) - 1;
172         return 10 * Math.log10((double) peakSignal * peakSignal / mse);
173     }
174 
generateErrorStats()175     private void generateErrorStats() throws IOException {
176         Preconditions.assertTestFileExists(mRefYuv.mFileName);
177         Preconditions.assertTestFileExists(mTestYuv.mFileName);
178 
179         try (RandomAccessFile refStream = new RandomAccessFile(new File(mRefYuv.mFileName), "r");
180              InputStream testStream = new FileInputStream(mTestYuv.mFileName)) {
181             int ySize = mRefYuv.mWidth * mRefYuv.mHeight * mRefYuv.mBytesPerSample;
182             int uvSize = ySize >> 2;
183             byte[] yRef = new byte[ySize];
184             byte[] uvRef = new byte[uvSize];
185             byte[] yTest = new byte[ySize];
186             byte[] uvTest = new byte[uvSize];
187 
188             int frames = 0;
189             while (true) {
190                 int bytesReadRef = refStream.read(yRef);
191                 int bytesReadDec = testStream.read(yTest);
192                 if (bytesReadDec != ySize || (!mAllowLoopBack && bytesReadRef != ySize)) break;
193                 if (bytesReadRef != ySize) {
194                     refStream.seek(0);
195                     refStream.read(yRef);
196                 }
197                 double curYMSE = computeMSE(yRef, yTest, mRefYuv.mBytesPerSample, mRefYuv.mWidth,
198                         mRefYuv.mHeight, null);
199                 mGlobalMSE[0] += curYMSE;
200                 mMinimumMSE[0] = Math.min(mMinimumMSE[0], curYMSE);
201 
202                 assertEquals("failed to read U Plane " + mRefYuv.mFileName
203                                 + " contains insufficient bytes", uvSize,
204                         refStream.read(uvRef));
205                 assertEquals("failed to read U Plane " + mTestYuv.mFileName
206                                 + " contains insufficient bytes", uvSize,
207                         testStream.read(uvTest));
208                 double curUMSE = computeMSE(uvRef, uvTest, mRefYuv.mBytesPerSample,
209                         mRefYuv.mWidth / 2, mRefYuv.mHeight / 2, null);
210                 mGlobalMSE[1] += curUMSE;
211                 mMinimumMSE[1] = Math.min(mMinimumMSE[1], curUMSE);
212 
213                 assertEquals("failed to read V Plane " + mRefYuv.mFileName
214                                 + " contains insufficient bytes", uvSize,
215                         refStream.read(uvRef));
216                 assertEquals("failed to read V Plane " + mTestYuv.mFileName
217                                 + " contains insufficient bytes", uvSize,
218                         testStream.read(uvTest));
219                 double curVMSE = computeMSE(uvRef, uvTest, mRefYuv.mBytesPerSample,
220                         mRefYuv.mWidth / 2, mRefYuv.mHeight / 2, null);
221                 mGlobalMSE[2] += curVMSE;
222                 mMinimumMSE[2] = Math.min(mMinimumMSE[2], curVMSE);
223 
224                 double yFramePSNR = computePSNR(curYMSE, mRefYuv.mBytesPerSample);
225                 double uFramePSNR = computePSNR(curUMSE, mRefYuv.mBytesPerSample);
226                 double vFramePSNR = computePSNR(curVMSE, mRefYuv.mBytesPerSample);
227                 mAvgPSNR[0] += yFramePSNR;
228                 mAvgPSNR[1] += uFramePSNR;
229                 mAvgPSNR[2] += vFramePSNR;
230                 mFramesPSNR.add(new double[]{yFramePSNR, uFramePSNR, vFramePSNR});
231 
232                 if (ENABLE_LOGS) {
233                     String msg = String.format(
234                             "frame: %d mse_y:%,.2f mse_u:%,.2f mse_v:%,.2f psnr_y:%,.2f psnr_u:%,"
235                                     + ".2f psnr_v:%,.2f",
236                             frames, curYMSE, curUMSE, curVMSE, mFramesPSNR.get(frames)[0],
237                             mFramesPSNR.get(frames)[1], mFramesPSNR.get(frames)[2]);
238                     Log.v(LOG_TAG, msg);
239                 }
240                 frames++;
241             }
242             for (int i = 0; i < mGlobalPSNR.length; i++) {
243                 mGlobalMSE[i] /= frames;
244                 mGlobalPSNR[i] = computePSNR(mGlobalMSE[i], mRefYuv.mBytesPerSample);
245                 mMinimumPSNR[i] = computePSNR(mMinimumMSE[i], mRefYuv.mBytesPerSample);
246                 mAvgPSNR[i] /= frames;
247             }
248             if (ENABLE_LOGS) {
249                 String msg = String.format(
250                         "global_psnr_y:%.2f, global_psnr_u:%.2f, global_psnr_v:%.2f, min_psnr_y:%"
251                                 + ".2f, min_psnr_u:%.2f, min_psnr_v:%.2f avg_psnr_y:%.2f, "
252                                 + "avg_psnr_u:%.2f, avg_psnr_v:%.2f",
253                         mGlobalPSNR[0], mGlobalPSNR[1], mGlobalPSNR[2], mMinimumPSNR[0],
254                         mMinimumPSNR[1], mMinimumPSNR[2], mAvgPSNR[0], mAvgPSNR[1], mAvgPSNR[2]);
255                 Log.v(LOG_TAG, msg);
256             }
257         }
258     }
259 
260     /**
261      * Returns Min(Ypsnr of all frames), Min(Upsnr of all frames), Min(Vpsnr of all frames) as an
262      * array at subscripts 0, 1, 2 respectively
263      */
getMinimumPSNR()264     public double[] getMinimumPSNR() throws IOException {
265         if (!mGenerateStats) {
266             generateErrorStats();
267             mGenerateStats = true;
268         }
269         return mMinimumPSNR;
270     }
271 
272     /**
273      * Returns GlobalYpsnr, GlobalUpsnr, GlobalVpsnr as an array at subscripts 0, 1, 2 respectively.
274      * Globalpsnr = 10 * log10 (peakSignal * peakSignal / global mse)
275      * GlobalMSE = Sum of all frames MSE / Total Frames
276      * MSE = Sum of all (error * error) / Total pixels
277      * error = ref[i] - test[i]
278      */
getGlobalPSNR()279     public double[] getGlobalPSNR() throws IOException {
280         if (!mGenerateStats) {
281             generateErrorStats();
282             mGenerateStats = true;
283         }
284         return mGlobalPSNR;
285     }
286 
287     /**
288      * returns list of all frames PSNR. Each entry in the list is an array of 3 elements,
289      * representing Y, U, V Planes PSNR separately
290      */
getFramesPSNR()291     public ArrayList<double[]> getFramesPSNR() throws IOException {
292         if (!mGenerateStats) {
293             generateErrorStats();
294             mGenerateStats = true;
295         }
296         return mFramesPSNR;
297     }
298 
299     /**
300      * Returns Avg(Ypsnr of all frames), Avg(Upsnr of all frames), Avg(Vpsnr of all frames) as an
301      * array at subscripts 0, 1, 2 respectively
302      */
getAvgPSNR()303     public double[] getAvgPSNR() throws IOException {
304         if (!mGenerateStats) {
305             generateErrorStats();
306             mGenerateStats = true;
307         }
308         return mAvgPSNR;
309     }
310 }
311