• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.mediav2.common.cts;
18 
19 import static org.junit.Assert.assertEquals;
20 
21 import android.util.Log;
22 
23 import com.android.compatibility.common.util.Preconditions;
24 
25 import java.io.File;
26 import java.io.FileInputStream;
27 import java.io.IOException;
28 import java.io.InputStream;
29 import java.io.RandomAccessFile;
30 import java.nio.ByteBuffer;
31 import java.nio.ByteOrder;
32 import java.util.ArrayList;
33 import java.util.Arrays;
34 
35 /**
36  * Class to compute per-frame PSNR, minimum PSNR and global PSNR between two YUV420P yuv streams.
37  */
38 public class VideoErrorManager {
39     private static final String LOG_TAG = VideoErrorManager.class.getSimpleName();
40     private static final boolean ENABLE_LOGS = false;
41 
42     private final RawResource mRefYuv;
43     private final RawResource mTestYuv;
44     private final boolean mAllowLoopBack;
45 
46     private boolean mGenerateStats;
47     private final double[] mGlobalMSE;
48     private final double[] mMinimumMSE;
49     private final double[] mGlobalPSNR;
50     private final double[] mMinimumPSNR;
51     private final double[] mAvgPSNR;
52     private final ArrayList<double[]> mFramesPSNR;
53 
VideoErrorManager(RawResource refYuv, RawResource testYuv, boolean allowLoopBack)54     public VideoErrorManager(RawResource refYuv, RawResource testYuv, boolean allowLoopBack) {
55         mRefYuv = refYuv;
56         mTestYuv = testYuv;
57         mAllowLoopBack = allowLoopBack;
58         if (mRefYuv.mHeight != mTestYuv.mHeight || mRefYuv.mWidth != mTestYuv.mWidth
59                 || mRefYuv.mBytesPerSample != mTestYuv.mBytesPerSample) {
60             String msg = String.format(
61                     "Reference file attributes and Test file attributes are not same. Reference "
62                             + "width : %d, height : %d, bytesPerSample : %d, Test width : %d, "
63                             + "height : %d, bytesPerSample : %d \n",
64                     mRefYuv.mWidth, mRefYuv.mHeight, mRefYuv.mBytesPerSample, mTestYuv.mWidth,
65                     mTestYuv.mHeight, mTestYuv.mBytesPerSample);
66             throw new IllegalArgumentException(msg);
67         }
68         if (((mRefYuv.mWidth & 1) != 0) || ((mRefYuv.mHeight & 1) != 0) || (
69                 (mRefYuv.mBytesPerSample != 1) && (mRefYuv.mBytesPerSample != 2))) {
70             String msg = String.format(LOG_TAG
71                             + " handles only YUV420p 8bit or 16bit inputs. Current file "
72                             + "attributes are width : %d, height : %d, bytesPerSample : %d",
73                     mRefYuv.mWidth, mRefYuv.mHeight, mRefYuv.mBytesPerSample);
74             throw new IllegalArgumentException(msg);
75         }
76         mMinimumMSE = new double[3];
77         Arrays.fill(mMinimumMSE, Float.MAX_VALUE);
78         mGlobalMSE = new double[3];
79         Arrays.fill(mGlobalMSE, 0.0);
80         mGlobalPSNR = new double[3];
81         mMinimumPSNR = new double[3];
82         mAvgPSNR = new double[3];
83         Arrays.fill(mAvgPSNR, 0.0);
84         mFramesPSNR = new ArrayList<>();
85     }
86 
computeMSE(byte[] data0, byte[] data1, int bytesPerSample)87     static double computeMSE(byte[] data0, byte[] data1, int bytesPerSample) {
88         assertEquals(data0.length, data1.length);
89         int length = data0.length / bytesPerSample;
90         long squareError = 0;
91 
92         if (bytesPerSample == 2) {
93             short[] dataA = new short[length];
94             ByteBuffer.wrap(data0).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(dataA);
95             short[] dataB = new short[length];
96             ByteBuffer.wrap(data1).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(dataB);
97             for (int i = 0; i < length; i++) {
98                 long diff = ((int) dataA[i] & 0xffff) - ((int) dataB[i] & 0xffff);
99                 squareError += diff * diff;
100             }
101         } else {
102             for (int i = 0; i < length; i++) {
103                 int diff = ((int) data0[i] & 0xff) - ((int) data1[i] & 0xff);
104                 squareError += diff * diff;
105             }
106         }
107         return (double) squareError / length;
108     }
109 
computePSNR(double mse, int bytesPerSample)110     static double computePSNR(double mse, int bytesPerSample) {
111         if (mse == 0) return 100.0;
112         final int peakSignal = (1 << (8 * bytesPerSample)) - 1;
113         return 10 * Math.log10((double) peakSignal * peakSignal / mse);
114     }
115 
generateErrorStats()116     private void generateErrorStats() throws IOException {
117         Preconditions.assertTestFileExists(mRefYuv.mFileName);
118         Preconditions.assertTestFileExists(mTestYuv.mFileName);
119 
120         try (RandomAccessFile refStream = new RandomAccessFile(new File(mRefYuv.mFileName), "r");
121              InputStream testStream = new FileInputStream(mTestYuv.mFileName)) {
122             int ySize = mRefYuv.mWidth * mRefYuv.mHeight * mRefYuv.mBytesPerSample;
123             int uvSize = ySize >> 2;
124             byte[] yRef = new byte[ySize];
125             byte[] uvRef = new byte[uvSize];
126             byte[] yTest = new byte[ySize];
127             byte[] uvTest = new byte[uvSize];
128 
129             int frames = 0;
130             while (true) {
131                 int bytesReadRef = refStream.read(yRef);
132                 int bytesReadDec = testStream.read(yTest);
133                 if (bytesReadDec != ySize || (!mAllowLoopBack && bytesReadRef != ySize)) break;
134                 if (bytesReadRef != ySize) {
135                     refStream.seek(0);
136                     refStream.read(yRef);
137                 }
138                 double curYMSE = computeMSE(yRef, yTest, mRefYuv.mBytesPerSample);
139                 mGlobalMSE[0] += curYMSE;
140                 mMinimumMSE[0] = Math.min(mMinimumMSE[0], curYMSE);
141 
142                 assertEquals("failed to read U Plane " + mRefYuv.mFileName
143                                 + " contains insufficient bytes", uvSize,
144                         refStream.read(uvRef));
145                 assertEquals("failed to read U Plane " + mTestYuv.mFileName
146                                 + " contains insufficient bytes", uvSize,
147                         testStream.read(uvTest));
148                 double curUMSE = computeMSE(uvRef, uvTest, mRefYuv.mBytesPerSample);
149                 mGlobalMSE[1] += curUMSE;
150                 mMinimumMSE[1] = Math.min(mMinimumMSE[1], curUMSE);
151 
152                 assertEquals("failed to read V Plane " + mRefYuv.mFileName
153                                 + " contains insufficient bytes", uvSize,
154                         refStream.read(uvRef));
155                 assertEquals("failed to read V Plane " + mTestYuv.mFileName
156                                 + " contains insufficient bytes", uvSize,
157                         testStream.read(uvTest));
158                 double curVMSE = computeMSE(uvRef, uvTest, mRefYuv.mBytesPerSample);
159                 mGlobalMSE[2] += curVMSE;
160                 mMinimumMSE[2] = Math.min(mMinimumMSE[2], curVMSE);
161 
162                 double yFramePSNR = computePSNR(curYMSE, mRefYuv.mBytesPerSample);
163                 double uFramePSNR = computePSNR(curUMSE, mRefYuv.mBytesPerSample);
164                 double vFramePSNR = computePSNR(curVMSE, mRefYuv.mBytesPerSample);
165                 mAvgPSNR[0] += yFramePSNR;
166                 mAvgPSNR[1] += uFramePSNR;
167                 mAvgPSNR[2] += vFramePSNR;
168                 mFramesPSNR.add(new double[]{yFramePSNR, uFramePSNR, vFramePSNR});
169 
170                 if (ENABLE_LOGS) {
171                     String msg = String.format(
172                             "frame: %d mse_y:%,.2f mse_u:%,.2f mse_v:%,.2f psnr_y:%,.2f psnr_u:%,"
173                                     + ".2f psnr_v:%,.2f",
174                             frames, curYMSE, curUMSE, curVMSE, mFramesPSNR.get(frames)[0],
175                             mFramesPSNR.get(frames)[1], mFramesPSNR.get(frames)[2]);
176                     Log.v(LOG_TAG, msg);
177                 }
178                 frames++;
179             }
180             for (int i = 0; i < mGlobalPSNR.length; i++) {
181                 mGlobalMSE[i] /= frames;
182                 mGlobalPSNR[i] = computePSNR(mGlobalMSE[i], mRefYuv.mBytesPerSample);
183                 mMinimumPSNR[i] = computePSNR(mMinimumMSE[i], mRefYuv.mBytesPerSample);
184                 mAvgPSNR[i] /= frames;
185             }
186             if (ENABLE_LOGS) {
187                 String msg = String.format(
188                         "global_psnr_y:%.2f, global_psnr_u:%.2f, global_psnr_v:%.2f, min_psnr_y:%"
189                                 + ".2f, min_psnr_u:%.2f, min_psnr_v:%.2f avg_psnr_y:%.2f, "
190                                 + "avg_psnr_u:%.2f, avg_psnr_v:%.2f",
191                         mGlobalPSNR[0], mGlobalPSNR[1], mGlobalPSNR[2], mMinimumPSNR[0],
192                         mMinimumPSNR[1], mMinimumPSNR[2], mAvgPSNR[0], mAvgPSNR[1], mAvgPSNR[2]);
193                 Log.v(LOG_TAG, msg);
194             }
195         }
196     }
197 
198     /**
199      * Returns Min(Ypsnr of all frames), Min(Upsnr of all frames), Min(Vpsnr of all frames) as an
200      * array at subscripts 0, 1, 2 respectively
201      */
getMinimumPSNR()202     public double[] getMinimumPSNR() throws IOException {
203         if (!mGenerateStats) {
204             generateErrorStats();
205             mGenerateStats = true;
206         }
207         return mMinimumPSNR;
208     }
209 
210     /**
211      * Returns GlobalYpsnr, GlobalUpsnr, GlobalVpsnr as an array at subscripts 0, 1, 2 respectively.
212      * Globalpsnr = 10 * log10 (peakSignal * peakSignal / global mse)
213      * GlobalMSE = Sum of all frames MSE / Total Frames
214      * MSE = Sum of all (error * error) / Total pixels
215      * error = ref[i] - test[i]
216      */
getGlobalPSNR()217     public double[] getGlobalPSNR() throws IOException {
218         if (!mGenerateStats) {
219             generateErrorStats();
220             mGenerateStats = true;
221         }
222         return mGlobalPSNR;
223     }
224 
225     /**
226      * returns list of all frames PSNR. Each entry in the list is an array of 3 elements,
227      * representing Y, U, V Planes PSNR separately
228      */
getFramesPSNR()229     public ArrayList<double[]> getFramesPSNR() throws IOException {
230         if (!mGenerateStats) {
231             generateErrorStats();
232             mGenerateStats = true;
233         }
234         return mFramesPSNR;
235     }
236 
237     /**
238      * Returns Avg(Ypsnr of all frames), Avg(Upsnr of all frames), Avg(Vpsnr of all frames) as an
239      * array at subscripts 0, 1, 2 respectively
240      */
getAvgPSNR()241     public double[] getAvgPSNR() throws IOException {
242         if (!mGenerateStats) {
243             generateErrorStats();
244             mGenerateStats = true;
245         }
246         return mAvgPSNR;
247     }
248 }
249